package brn.analysis.dump;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import jist.swans.mac.MacDcfMessage;
import jist.swans.misc.MessageBytes;
import jist.swans.misc.Pickle;

import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;

import brn.analysis.Statistics;
import brn.sim.data.Line;
import brn.sim.data.XplotSerializer;
import brn.sim.data.dump.WiresharkDump;
import brn.sim.data.dump.WiresharkMessage;

/**
 * Correlates the bit-errors of a packet received at one receiver to the errors
 * of a packet received at a different receiver.
 *
 * @author kurth
 *
 */
public class BitErrorCrossCorrelator {

  public static final Logger log = Logger
      .getLogger(BitErrorCrossCorrelator.class.getName());

  /**
   * Names of the rates, indexed by the rate index obtained from
   * {@code AthdescHeader.getExtTxRateIdx()}.
   */
  protected static final String[] rateString = { "1Mbps", "2Mbps", "5Mbps",
      "11Mbps", "6Mbps", "9Mbps", "12Mbps", "18Mbps", "24Mbps", "36Mbps",
      "48Mbps", "54Mbps" };

  /** Plot color per rate index. */
  protected static final String[] rateColor = { "red", // 1
      "blue", // 2
      "yellow", // 5
      "green", // 11
      "white", // 6
      "magenta", // 9
      "orange", // 12
      "purple", // 18
      "red", // 24
      "blue", // 36
      "yellow", // 48
      "green", // 54
  };

  /** Plot point shape per rate index (not referenced in this class). */
  protected static final String[] rateShape = { "dot", // 1
      "dot", // 2
      "dot", // 5
      "dot", // 11
      "x", // 6
      "x", // 9
      "x", // 12
      "x", // 18
      "x", // 24
      "x", // 36
      "x", // 48
      "x", // 54
  };

  /** Presumably the maximum frame size; not referenced in this class. */
  protected static final int MTU = 4000;

  /**
   * Offset of the frame body (behind the athdesc header and the 802.11 data
   * header). The first 8 bytes of the body hold the timestamp used to match
   * frames between dumps.
   */
  protected static final int DATA_OFFSET = AthdescHeader.HEADER_SIZE
      + MacDcfMessage.Data.getHeaderSize();

  /** Number of lags of the bit-error cross-correlation. */
  protected static final int AUTOCORRELLATION_SIZE = 400;

  protected static final String[] statusString = { "ok", "crc", "phy", "all" };

  /** Size of the in-frame timestamp in bytes. */
  private static final int TIMESTAMP_SIZE = 8;

  /**
   * Offset of the 2-byte 802.11 sequence control field (MAC header offset 22).
   * It is not set in the tx feedback, so it is excluded from the comparison.
   */
  private static final int SEQ_CTRL_OFFSET = AthdescHeader.HEADER_SIZE + 22;

  /** Number of log2 scales accumulated for time correlation / allan covariance. */
  private static final int NUM_SCALES = 14;

  /** The first rate indices (1, 2, 5, 11 Mbps) are the 802.11b rates. */
  private static final int NUM_80211B_RATES = 4;

  public static class IEEE80211Header {
    /**
     * Returns the 12-bit sequence number from the sequence control field at
     * byte 22 of the 802.11 header starting at {@code offset}, or -1 if the
     * buffer is too short to hold a 24 byte header.
     */
    public static short getSeqId(byte[] msg, int offset) {
      if (offset + 24 > msg.length)
        return -1;
      return (short) (Pickle.arrayToUShort(msg, offset + 22) >> 4);
    }
  }

  /** Common base of the per-dump contexts; currently has no members. */
  private static class BaseCtx {
  }

  /** State collected for one receiver dump. */
  private static class LinkCtx extends BaseCtx {
    private WiresharkDump dump = null;

    private String fileName = null;

    /** Received frames, keyed by their in-frame timestamp. */
    final Map<Long, WiresharkMessage> mapTimestampToPacket = new HashMap<Long, WiresharkMessage>();

    /**
     * Per rate: one sample per reference frame, 1 if the frame was received
     * without bit errors, 0 otherwise.
     */
    final List<List<Double>> pktArrivedCorrect = newSeriesPerRate();

    /** Per rate: one sample per reference frame, 1 if received at all, else 0. */
    final List<List<Double>> pktArrived = newSeriesPerRate();

    /**
     * Per rate: number of reference frames received with bit errors at both
     * receivers. Only maintained on link1.
     */
    double[] pktCount = new double[rateString.length];

    /** Per rate and lag: summed bit-error cross-correlation (link1 only). */
    double[][] bitsRecvCOR = new double[rateString.length][AUTOCORRELLATION_SIZE];

    /** Per rate and scale: summed bit-error allan covariance (link1 only). */
    double[][] allanCoV = new double[rateString.length][NUM_SCALES];

    /** Per rate and scale: summed bit-error time correlation (link1 only). */
    public double[][] timeCOR = new double[rateString.length][NUM_SCALES];

    /** Creates one empty sample list per rate. */
    private static List<List<Double>> newSeriesPerRate() {
      List<List<Double>> series = new ArrayList<List<Double>>(rateString.length);
      for (int i = 0; i < rateString.length; i++)
        series.add(new ArrayList<Double>());
      return series;
    }
  }

  /** State of the reference dump. */
  private static class ReferenceCtx {
    private WiresharkDump dump = null;

    private String fileName = null;
  }

  private LinkCtx link1, link2;

  private ReferenceCtx ref;

  /**
   * Constructor.
   *
   * @param fileName
   *          file name of the reference stream.
   * @throws IOException
   */
  public BitErrorCrossCorrelator(String fileName) throws IOException {
    this.link1 = null;
    this.link2 = null;
    this.ref = new ReferenceCtx();
    ref.fileName = fileName;
  }

  /**
   * Reads a receiver dump and indexes all of its frames by timestamp.
   *
   * @param fileName
   *          receiver dump to read.
   * @return the populated link context.
   * @throws IOException
   *           if the dump cannot be read.
   */
  private LinkCtx analyze(String fileName) throws IOException {
    log.info("Analyzing " + fileName);

    LinkCtx link = new LinkCtx();

    link.fileName = fileName;
    link.dump = new WiresharkDump();
    // NOTE(review): the dump is never closed; close it here if WiresharkDump
    // offers a close method.
    link.dump.open(fileName);

    for (WiresharkMessage msg = link.dump.read(); msg != null; msg = link.dump
        .read()) {
      analyzePacket(link, msg);
    }

    return link;
  }

  /**
   * Analyzes two receiver dumps against the reference dump and writes the
   * resulting plots.
   *
   * @param fileName1
   *          dump of the first receiver.
   * @param fileName2
   *          dump of the second receiver.
   * @throws IOException
   *           if a dump cannot be read or a result cannot be written.
   */
  public void analyze(String fileName1, String fileName2) throws IOException {
    try {
      link1 = analyze(fileName1);
      link2 = analyze(fileName2);

      // compare against reference
      analyzeRef(ref.fileName);

      save();
    } finally {
      // release the (potentially large) packet maps even on failure
      link1 = null;
      link2 = null;
    }
  }

  /**
   * Puts a received frame into the link's timestamp map. Frames too short to
   * carry a timestamp are ignored; for duplicate timestamps the frame read
   * last wins.
   *
   * @param link
   *          link context to fill.
   * @param curr
   *          received frame.
   */
  private void analyzePacket(LinkCtx link, WiresharkMessage curr) {
    byte[] currBytes = ((MessageBytes) curr.getPayload()).getBytes();
    if (!hasTimestamp(currBytes))
      return;
    link.mapTimestampToPacket.put(Long.valueOf(readTimestamp(currBytes)), curr);
  }

  /** Returns whether the frame is long enough to contain the timestamp. */
  private static boolean hasTimestamp(byte[] bytes) {
    return bytes.length >= DATA_OFFSET + TIMESTAMP_SIZE;
  }

  /**
   * Reads the 64-bit in-frame timestamp from two consecutive 32-bit words at
   * {@link #DATA_OFFSET}. The low word is masked so that its sign bit does not
   * leak into the high word.
   */
  private static long readTimestamp(byte[] bytes) {
    long high = Pickle.arrayToInteger(bytes, DATA_OFFSET);
    long low = Pickle.arrayToInteger(bytes, DATA_OFFSET + 4) & 0xFFFFFFFFL;
    return (high << Integer.SIZE) | low;
  }

  /**
   * Analyze the reference file.
   *
   * @param fileName
   *          reference dump.
   * @throws IOException
   *           if the dump cannot be read.
   */
  private void analyzeRef(String fileName) throws IOException {
    log.info("Analyzing " + fileName);

    ref = new ReferenceCtx();

    ref.fileName = fileName;
    ref.dump = new WiresharkDump();
    // NOTE(review): the dump is never closed; see analyze(String).
    ref.dump.open(fileName);

    for (WiresharkMessage refMsg = ref.dump.read(); refMsg != null; refMsg = ref.dump
        .read()) {
      analyzeRefPacket(refMsg);
    }
  }

  /**
   * Compares one reference frame against the matching frames of both
   * receivers and accumulates the per-rate statistics.
   *
   * @param msg
   *          reference frame.
   */
  private void analyzeRefPacket(WiresharkMessage msg) {
    byte[] refBytes = ((MessageBytes) msg.getPayload()).getBytes();
    if (!hasTimestamp(refBytes)) {
      log.warn("Skipping reference frame without timestamp ("
          + refBytes.length + " bytes)");
      return;
    }

    int idx = AthdescHeader.getExtTxRateIdx(refBytes, 0);
    if (idx < 0 || idx >= rateString.length) {
      log.warn("Skipping reference frame with unknown rate index " + idx);
      return;
    }

    // look for the frame at both receivers
    Long key = Long.valueOf(readTimestamp(refBytes));
    WiresharkMessage curr1 = link1.mapTimestampToPacket.get(key);
    WiresharkMessage curr2 = link2.mapTimestampToPacket.get(key);

    // error masks are null if the frame is missing or received correctly
    double[] errors1 = (curr1 == null) ? null : getBitError(refBytes, curr1);
    double[] errors2 = (curr2 == null) ? null : getBitError(refBytes, curr2);

    recordArrival(link1, idx, curr1 != null, errors1 == null);
    recordArrival(link2, idx, curr2 != null, errors2 == null);

    // cross statistics need bit errors at both receivers
    if (errors1 == null || errors2 == null)
      return;

    // all cross statistics are accumulated in link1
    link1.pktCount[idx]++;

    // NOTE(review): meaning of 0.01 depends on Statistics.correlate
    double[] bitsRecvCOR = new double[AUTOCORRELLATION_SIZE];
    Statistics.correlate(errors1, errors2, bitsRecvCOR, 0.01);
    accumulate(link1.bitsRecvCOR[idx], bitsRecvCOR);

    accumulate(link1.timeCOR[idx], Statistics.timeCorrelation(errors1, errors2));
    accumulate(link1.allanCoV[idx], Statistics.allanCovariance(errors1, errors2));
  }

  /** Appends one arrival and one correct-arrival sample for the given rate. */
  private static void recordArrival(LinkCtx link, int idx, boolean arrived,
      boolean errorFree) {
    link.pktArrived.get(idx).add(arrived ? 1.0 : 0.0);
    link.pktArrivedCorrect.get(idx).add(arrived && errorFree ? 1.0 : 0.0);
  }

  /**
   * Adds {@code values} element-wise to {@code sum}; entries beyond the length
   * of {@code sum} are dropped instead of overflowing the accumulator.
   */
  private static void accumulate(double[] sum, double[] values) {
    int n = Math.min(sum.length, values.length);
    for (int i = 0; i < n; i++)
      sum[i] += values[i];
  }

  /**
   * Builds the bit error mask of a received frame against its reference.
   * The athdesc header and the sequence control field are ignored. The mask is
   * sized by the reference, so masks of both receivers have equal length.
   *
   * @return one entry per bit (1 = error), or null if there are no bit errors.
   */
  private static double[] getBitError(byte[] refBytes, WiresharkMessage curr) {
    byte[] currBytes = ((MessageBytes) curr.getPayload()).getBytes();
    double[] errors = new double[refBytes.length * Byte.SIZE];
    int bitErrors = 0;

    // TODO bytes missing from a truncated reception are not counted as errors
    int end = Math.min(refBytes.length, currBytes.length);
    for (int i = AthdescHeader.HEADER_SIZE; i < end; i++) {
      if (i == SEQ_CTRL_OFFSET || i == SEQ_CTRL_OFFSET + 1)
        continue; // seq id is not set in tx feedback

      int xor = (refBytes[i] ^ currBytes[i]) & 0xFF;
      for (int bit = 0; bit < Byte.SIZE; bit++) {
        if ((xor & (1 << bit)) != 0) {
          errors[i * Byte.SIZE + bit] = 1;
          bitErrors++;
        }
      }
    }
    return bitErrors == 0 ? null : errors;
  }

  /**
   * Saves all results as xplot files named after both receiver dumps.
   *
   * @throws IOException
   *           if a file cannot be written.
   */
  private void save() throws IOException {
    log.info("Writing results");
    String fileName = new File(link1.fileName).getName() + "-"
        + new File(link2.fileName).getName();

    savePacketLossTimeCOR(fileName);
    // TODO save packet loss allan cross-correlation (b/g separated, arrived vs. correct)
    saveBitErrorTimeCOR(fileName);
    saveBitErrorAllanCoDev(fileName);
    saveBitErrorCoDev(fileName);
    // TODO correlate arrived/correct packets against higher bit-rates at the
    // same receiver
  }

  /**
   * Saves the packet loss time correlation between both receivers, per rate,
   * for correct and for merely arrived frames. Rates without reference frames
   * are skipped.
   */
  private void savePacketLossTimeCOR(String fileName) throws IOException {
    XplotSerializer seri = new XplotSerializer(fileName
        + " time-COR pkt loss", "log2(sample period)", "COR");
    for (int rate = 0; rate < rateString.length; rate++) {
      // both links hold exactly one sample per reference frame of this rate
      if (link1.pktArrived.get(rate).isEmpty())
        continue;

      Line correct = seriesLine("time-COR pkt loss correct-" + rateString[rate],
          Statistics.timeCorrelation(toArray(link1.pktArrivedCorrect.get(rate)),
              toArray(link2.pktArrivedCorrect.get(rate))));
      addToPlot(seri, correct, rate, "x", null);

      Line arrived = seriesLine("time-COR pkt loss arrived-" + rateString[rate],
          Statistics.timeCorrelation(toArray(link1.pktArrived.get(rate)),
              toArray(link2.pktArrived.get(rate))));
      addToPlot(seri, arrived, rate, "diamond", "box");
    }
    seri.saveToFile(fileName + "-pktTimeCrossCOR.xpl");
  }

  /** Saves the averaged bit error time correlation per rate. */
  private void saveBitErrorTimeCOR(String fileName) throws IOException {
    XplotSerializer seri = new XplotSerializer(fileName
        + " time-COR bit errors", "log2(sample period)", "COR");
    for (int rate = 0; rate < rateString.length; rate++) {
      if (0 == link1.pktCount[rate])
        continue;

      Line timeCOR = seriesLine("time-COR bit errors-" + rateString[rate],
          average(link1.timeCOR[rate], link1.pktCount[rate]));
      addToPlot(seri, timeCOR, rate, "x", null);
    }
    seri.saveToFile(fileName + "-bitsTimeCrossCOR.xpl");
  }

  /** Saves the bit error allan co-deviation (log2 scale) per rate. */
  private void saveBitErrorAllanCoDev(String fileName) throws IOException {
    XplotSerializer seri = new XplotSerializer(fileName
        + " allan CoDEV bit errors", "log2(sample period)", "log2(allan CoDEV)");
    for (int rate = 0; rate < rateString.length; rate++) {
      if (0 == link1.pktCount[rate])
        continue;

      Line allanDev = new Line("allan CoDEV bit errors-" + rateString[rate]);
      for (int i = 0; i < link1.allanCoV[rate].length; i++) {
        // sqrt/log2 are only defined for positive covariance
        if (link1.allanCoV[rate][i] <= .0)
          continue;
        double adev = Math.sqrt(link1.allanCoV[rate][i] / link1.pktCount[rate]);
        allanDev.add(i + 1, log2(adev));
      }
      addToPlot(seri, allanDev, rate, "x", null);
    }
    seri.saveToFile(fileName + "-bitsAllanCoDev.xpl");
  }

  /** Saves the averaged bit error cross-correlation over the lag per rate. */
  private void saveBitErrorCoDev(String fileName) throws IOException {
    XplotSerializer seri = new XplotSerializer(fileName + " CoDEV bit errors",
        "lag", "CoDEV bit errors");
    for (int rate = 0; rate < rateString.length; rate++) {
      if (0 == link1.pktCount[rate])
        continue;

      Line coDev = seriesLine("CoDEV bit errors-" + rateString[rate],
          average(link1.bitsRecvCOR[rate], link1.pktCount[rate]));
      addToPlot(seri, coDev, rate, "x", null);
    }
    seri.saveToFile(fileName + "-bitsCoDev.xpl");
  }

  /**
   * Adds a line in the rate's color and, depending on whether the rate is an
   * 802.11b rate, points of the given shape (null = no points).
   */
  private static void addToPlot(XplotSerializer seri, Line line, int rate,
      String bShape, String gShape) {
    seri.addLine(line, rateColor[rate]);
    String shape = rate < NUM_80211B_RATES ? bShape : gShape;
    if (shape != null)
      seri.addPoints(line, rateColor[rate], shape);
  }

  /** Creates a line with one point (index, value) per array entry. */
  private static Line seriesLine(String title, double[] values) {
    Line line = new Line(title);
    for (int i = 0; i < values.length; i++)
      line.add(i, values[i]);
    return line;
  }

  /** Returns a new array with every entry of {@code sum} divided by {@code count}. */
  private static double[] average(double[] sum, double count) {
    double[] result = new double[sum.length];
    for (int i = 0; i < sum.length; i++)
      result[i] = sum[i] / count;
    return result;
  }

  /** Unboxes a sample list into a primitive array. */
  private static double[] toArray(List<Double> samples) {
    double[] result = new double[samples.size()];
    for (int i = 0; i < result.length; i++)
      result[i] = samples.get(i).doubleValue();
    return result;
  }

  /** Binary logarithm. */
  private static double log2(double value) {
    return Math.log(value) / Math.log(2);
  }

  /**
   * Main function.
   *
   * @param args
   *          reference dump, first receiver dump, second receiver dump.
   * @throws IOException
   */
  public static void main(String[] args) throws IOException {
    if (args.length < 3) {
      System.err.println("usage: " + BitErrorCrossCorrelator.class.getName()
          + " <reference dump> <receiver dump 1> <receiver dump 2>");
      System.exit(1);
    }

    PatternLayout layout = new PatternLayout("%d{ISO8601} %-5p [%t] %c: %m%n");
    ConsoleAppender consoleAppender = new ConsoleAppender(layout);
    log.addAppender(consoleAppender);
    log.setLevel(Level.ALL);

    BitErrorCrossCorrelator p = new BitErrorCrossCorrelator(args[0]);
    p.analyze(args[1], args[2]);
  }

}
