package brn.swans.radio;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import jist.runtime.Main;
import jist.swans.Constants;
import jist.swans.misc.Util;

import org.apache.commons.math.MathException;
import org.apache.commons.math.special.Erf;

import brn.sim.data.Line;
import brn.sim.data.XplotSerializer;

/**
 * Error-rate models (BER / chunk success rate over SNR) for the IEEE 802.11a/g
 * OFDM transmission modes.
 *
 * TODO as in yans and ns-3, the BER for 9 Mbps is higher than for 12 Mbps
 * (for all SNRs). Is something wrong?
 * Apparently not, the same is reported in: D. Qiao, S. Choi, and K.G. Shin,
 * "Goodput analysis and link adaptation for IEEE 802.11a wireless LANs,"
 * IEEE Transactions on Mobile Computing, vol. 1, no. 4, 2002.
 *
 * TODO do we also have this bug? http://www.dei.unipd.it/wdyn/?IDsezione=5519
 */
public abstract class TransmissionMode {

  /**
   * Lazily filled lookup table of chunk success rates, indexed by chunk length
   * and SNR. Each row is sampled from the enclosing mode's
   * {@link TransmissionMode#getChunkSuccessRate(double, int)} the first time a
   * chunk length is requested.
   */
  protected class SuccessRateCache {

    /** if true, each newly computed curve is also written to an xplot file */
    protected static final boolean PLOT = false;

    /** number of spare rows added whenever the row array has to grow */
    private static final int GROWTH = 25;

    /** chunkSuccessRates[bitIdx][snrIdx]; a row stays null until computed */
    private double[][] chunkSuccessRates = new double[2500][];

    /**
     * Maps a chunk length to its row index: the first 1000 bits are tracked
     * individually, longer chunks are grouped in buckets of 10 bits.
     */
    private int bit2idx(int nbits) {
      return nbits <= 1000 ? nbits : nbits / 10 + 901;
    }

    /**
     * Approximate inverse of {@link #bit2idx(int)}; the result is at most 10
     * bits below any chunk length mapped to idx.
     */
    private int idx2bit(int idx) {
      return idx <= 1000 ? idx : (idx - 901) * 10;
    }

    /**
     * Maps a linear SNR to its column index: 0.1 dB steps starting at -14 dB.
     * Every SNR below -14 dB is mapped to index 0 (assumed not useful).
     */
    private int snr2idx(double snr) {
      return (int) Math.round(Math.max((Util.toDB(snr) + 14.) * 10., .0));
    }

    /** Inverse of {@link #snr2idx(double)}: the linear SNR of column idx. */
    private double idx2snr(int idx) {
      return Util.fromDB(idx / 10. - 14.);
    }

    /**
     * Looks up the success rate of a chunk of nbits bits at the given SNR,
     * computing and caching the curve for this chunk length on demand.
     *
     * @param snr the SNR (linear, W/W)
     * @param nbits chunk length in bits
     * @return the success probability; 1 for chunks without bits and for SNRs
     *   beyond the end of the sampled curve
     */
    public double getChunkSuccessRate(double snr, int nbits) {
      // nothing can be lost; also keeps negative lengths away from the index math
      if (nbits <= 0)
        return 1.;

      int bitIdx = bit2idx(nbits);
      ensureCapacity(bitIdx);
      if (null == chunkSuccessRates[bitIdx])
        chunkSuccessRates[bitIdx] = computeCurve(nbits, bitIdx);

      double[] curve = chunkSuccessRates[bitIdx];
      int snrIdx = snr2idx(snr);
      // sampling stopped once the success rate reached its threshold
      return snrIdx >= curve.length ? 1. : curve[snrIdx];
    }

    /** Grows the row array so that bitIdx becomes a valid row index. */
    private void ensureCapacity(int bitIdx) {
      if (bitIdx < chunkSuccessRates.length)
        return;
      double[][] newlist = new double[bitIdx + GROWTH][];
      System.arraycopy(chunkSuccessRates, 0, newlist, 0, chunkSuccessRates.length);
      chunkSuccessRates = newlist;
    }

    /**
     * Samples the enclosing mode's success rate for nbits bits on the SNR grid,
     * starting at -14 dB, until it reaches the success threshold.
     * NOTE: chunk lengths sharing a row (above 1000 bits) reuse the curve of
     * the first length requested for that row.
     */
    private double[] computeCurve(int nbits, int bitIdx) {
      if (Main.ASSERT)
        Util.assertion(nbits - idx2bit(bitIdx) <= 10);

      Line line = new Line("SNR-PER curve");
      // calculate BER with 1e-6 accuracy!
      double csrThreshold = nbits <= 1 ? .999999 : .999;
      double csr = .0;
      for (int idx = 0; csr < csrThreshold; idx++) {
        double snrTmp = idx2snr(idx);
        if (Main.ASSERT)
          Util.assertion(snr2idx(snrTmp) == idx);
        csr = TransmissionMode.this.getChunkSuccessRate(snrTmp, nbits);
        // x values are of interest only if we plot the whole thing...
        line.add((PLOT ? Util.toDB(snrTmp) : snrTmp), csr);
      }
      double[] curve = line.getY();

      // QAM produces relatively low BERs at very low SNR, so skip the first 10 bits
      if (Main.ASSERT)
        Util.assertion(bitIdx < 10 || curve[0] < 0.01);
      if (Main.ASSERT)
        Util.assertion(curve[curve.length - 1] > 0.99);
      if (Main.ASSERT)
        Util.assertion(curve.length < 1000);

      if (PLOT)
        plot(line, nbits);
      return curve;
    }

    /** Debug helper: dumps a computed curve to an xplot file. */
    private void plot(Line line, int nbits) {
      try {
        XplotSerializer seri = new XplotSerializer(line.getTitle(), "SNR (db)", "PSR");
        seri.add(line, "yellow");
        seri.saveToFile("PER-" + TransmissionMode.this.getDataRate() + "-"
            + nbits + ".xpl");
      } catch (IOException e) {
        // plotting is best-effort debug output; do not abort the computation
        e.printStackTrace();
      }
    }
  }

  /** singleton modes keyed by data rate [bps]; created by get80211gMode */
  private static Map<Integer, TransmissionMode> modes;

  /** reference payload size [bytes] for adaption of curve to a receiver */
  private static int rxLossTxBytes;
  /** reference PER level for adaption of curve to a receiver */
  private static double rxLossPer;
  /** SNR [dB] thresholds for adaption of curve to a receiver */
  private static Map /*bitrate -> threshold*/ rxLossThresholds_dB = null;
  /** offset to match the current curve with the theoretical one (linear scale) */
  protected double rxLossOffset = 1.;

  private SuccessRateCache cache = new SuccessRateCache();


  private void setRxLossOffset(double rxLossOffset) {
    this.rxLossOffset = rxLossOffset;
  }

  /**
   * Adjust curves to match the given thresholds at the given PER when
   * transmitting a frame of the given size. Basically, the curve is shifted
   * along the x-axis until it matches the point (threshold,per).
   * Must be called before the first call to {@link #get80211gMode(int, int)},
   * since the adjustment is applied when the modes are created.
   *
   * @param txBytes reference frame size in bytes
   * @param per reference packet error rate
   * @param thresholds data rate -> SNR threshold [dB]; modes without an entry
   *   are dropped
   */
  public static void setReceiverLoss(int txBytes, double per, Map thresholds) {
    // calculation is done during creation of a mode, so set in advance!
    // (null-safe: modes may already exist without any thresholds)
    if (Main.ASSERT)
      Util.assertion(null == modes || (null == rxLossThresholds_dB
          ? null == thresholds : rxLossThresholds_dB.equals(thresholds)));

    TransmissionMode.rxLossTxBytes = txBytes;
    TransmissionMode.rxLossPer = per;
    TransmissionMode.rxLossThresholds_dB = thresholds;
  }

  /**
   * Access singleton instance representing an IEEE 802.11ag bitrate.
   * The modes are created on the first call (bandwidth must be 20 MHz then);
   * later calls return the same instances regardless of bandwidth.
   *
   * @param bandwidth signal bandwidth [Hz]
   * @param bitrate data rate [bps]
   * @return the mode, or null if the bitrate is unknown or was dropped by
   *   {@link #setReceiverLoss(int, double, Map)}
   */
  public static TransmissionMode get80211gMode(int bandwidth, int bitrate) {
    if (modes == null) {
      Util.assertion(bandwidth == 20e6);
      // publish only the fully built and calibrated map
      modes = createModes(bandwidth);
    }
    return modes.get(Integer.valueOf(bitrate));
  }

  /** Builds all 802.11a/g modes and applies the configured receiver loss. */
  private static Map<Integer, TransmissionMode> createModes(int bandwidth) {
    Map<Integer, TransmissionMode> all = new HashMap<Integer, TransmissionMode>();
    all.put(Integer.valueOf(Constants.BANDWIDTH_6Mbps), new TransmissionMode.FecBpskMode(
        bandwidth, Constants.BANDWIDTH_6Mbps, Constants.BANDWIDTH_12Mbps, 0.5, 10, 11));
    all.put(Integer.valueOf(Constants.BANDWIDTH_9Mbps), new TransmissionMode.FecBpskMode(
        bandwidth, Constants.BANDWIDTH_9Mbps, Constants.BANDWIDTH_12Mbps, 0.75,  5, 8));

    all.put(Integer.valueOf(Constants.BANDWIDTH_12Mbps), new TransmissionMode.FecQamMode(
        bandwidth, Constants.BANDWIDTH_12Mbps, Constants.BANDWIDTH_24Mbps, 0.5,   4, 10, 11, 0));
    all.put(Integer.valueOf(Constants.BANDWIDTH_18Mbps), new TransmissionMode.FecQamMode(
        bandwidth, Constants.BANDWIDTH_18Mbps, Constants.BANDWIDTH_24Mbps, 0.75,  4, 5, 8, 31));
    all.put(Integer.valueOf(Constants.BANDWIDTH_24Mbps), new TransmissionMode.FecQamMode(
        bandwidth, Constants.BANDWIDTH_24Mbps, Constants.BANDWIDTH_48Mbps, 0.5,   16, 10, 11, 0));
    all.put(Integer.valueOf(Constants.BANDWIDTH_36Mbps), new TransmissionMode.FecQamMode(
        bandwidth, Constants.BANDWIDTH_36Mbps, Constants.BANDWIDTH_48Mbps, 0.75,  16, 5, 8, 31));
    all.put(Integer.valueOf(Constants.BANDWIDTH_48Mbps), new TransmissionMode.FecQamMode(
        bandwidth, Constants.BANDWIDTH_48Mbps, Constants.BANDWIDTH_72Mbps, 0.666, 64, 6, 1, 16));
    all.put(Integer.valueOf(Constants.BANDWIDTH_54Mbps), new TransmissionMode.FecQamMode(
        bandwidth, Constants.BANDWIDTH_54Mbps, Constants.BANDWIDTH_72Mbps, 0.75,  64, 5, 8, 31));

    if (null != rxLossThresholds_dB)
      applyReceiverLoss(all);
    return all;
  }

  /**
   * Shifts every mode's curve to its configured threshold; modes without a
   * threshold are removed from the map.
   */
  private static void applyReceiverLoss(Map<Integer, TransmissionMode> all) {
    Iterator<TransmissionMode> iter = all.values().iterator();
    while (iter.hasNext()) {
      TransmissionMode mode = iter.next();
      Integer rate = Integer.valueOf(mode.getDataRate());

      if (!rxLossThresholds_dB.containsKey(rate)) {
        iter.remove();
        continue;
      }

      double rxLossThreshold_dB = (Double) rxLossThresholds_dB.get(rate);
      double requiredSNR = mode.getChunkRequiredSNR(1. - rxLossPer, rxLossTxBytes * 8);
      mode.setRxLossOffset(Util.fromDB(rxLossThreshold_dB - requiredSNR));
    }
  }

  /**
   * @return the number of Hz used by this signal
   */
  abstract public int getBandwidth();

  /**
   * @return the number of user bits per
   * second achieved by this transmission mode
   */
  abstract public int getDataRate();

  /**
   * @return the number of raw bits per
   * second achieved by this transmission mode.
   * the value returned by getRate and getDataRate
   * will be different only if there is some
   * FEC overhead.
   */
  abstract public int getPhyRate();

  /**
   * calculate success rate for a chunk of nbit _PHY_ bits with given SNR
   * @param snr the snr, (W/W)
   * @param nbits length of transmission (_phy_ bits)
   *
   * @return the probability that nbits be successfully transmitted.
   */
  abstract public double getChunkSuccessRate (double snr, int nbits);

  /**
   * Calculates the required SNR to meet the target chunk success rate by
   * bisection over [-20, 40] dB (assumes the success rate grows with SNR).
   *
   * @param csr the target chunk success rate
   * @param nbits length of transmission (_phy_ bits)
   * @return an SNR in dB scale (not linear) whose success rate is within 0.01
   *   of csr; if the target is not reachable in the search interval, the
   *   interval boundary the search converged to
   */
  private double getChunkRequiredSNR(double csr, int nbits) {
    final double tol = .01;
    // 60 dB halved 100 times is far below double precision
    final int maxIterations = 100;
    double snrLow = -20; //dB
    double snrHigh = 40; //dB

    for (int i = 0; i < maxIterations; i++) {
      double snr = (snrLow + snrHigh) / 2.;
      double currCsr = getChunkSuccessRate(Util.fromDB(snr), nbits);

      if (Math.abs(currCsr - csr) < tol)
        return snr;
      if (Main.ASSERT)
        Util.assertion(snrHigh - snrLow > tol);

      if (currCsr < csr) {
        snrLow = snr;
      } else {
        snrHigh = snr;
      }
    }
    // target not reachable within the interval: best effort instead of looping forever
    return (snrLow + snrHigh) / 2.;
  }

  /**
   * The error rate in _data_ bits
   * @param snr the snr, (W/W)
   * @return bit error rate of a data bit
   */
  public final double getBitErrorRate(double snr) {
    // note: getChunkSuccessRate works with _data_ bits, so that's it:
    return 1 - cache.getChunkSuccessRate(snr, 1);
  }

  /**
   * Cached (SNR quantized to 0.1 dB) variant of getChunkSuccessRate.
   * @param snr the snr, (W/W)
   * @param nbits length of transmission
   * @return the probability that nbits be successfully transmitted
   */
  public double getCachedSuccessRate(double snr, int nbits) {
    return cache.getChunkSuccessRate(snr, nbits);
  }

  /**
   * calculate gross number of bits transferred in specified time
   * don't account for PHY pad, signal extension time or odd symbols
   * @param delay duration in simulation time units
   * @return gross number of bits
   */
  public int getNumDataBits(long delay) {
    int rate = getDataRate();
    return (int)(rate * delay / Constants.SECOND);
  }

  /**
   * NoFEC transmissions.
   */
  abstract public static class NoFecTransmissionMode extends TransmissionMode {

    private int bandwidth;
    private int dataRate;

    public NoFecTransmissionMode (int bandwidth, int dataRate) {
      this.bandwidth = bandwidth;
      this.dataRate = dataRate;
    }

    /**
     * Uncoded BPSK bit error rate, 0.5 * erfc(sqrt(Eb/No)).
     * @param snr the snr, (W/W)
     * @return bit error rate
     */
    protected double getBpskBer(double snr) {
      double EbNo = snr * bandwidth / dataRate;

      // if EbNo > 30 (~14.8 dB), 1 - erf(z) suffers from cancellation; treat BER as 0
      if (EbNo > 30)
        return 0;

      double z = Math.sqrt(EbNo);
      double ber = 0;
      try {
        ber = 0.5 * (1 - Erf.erf(z)); // old: ber = 0.5 * erfc(z)
      } catch (MathException e) {
        throw new RuntimeException(e.getMessage(), e);
      }
      return ber;
    }

    /**
     * Uncoded square M-QAM bit error rate.
     * @param snr the snr, (W/W)
     * @param m constellation size
     * @return bit error rate
     */
    protected double getQamBer(double snr, int m) {
      double EbNo = snr * bandwidth / dataRate;
      double z = Math.sqrt ((1.5 * Util.log2 (m) * EbNo) / (m - 1.0));
      double z1 = 0 ;

      // if z > 5.5, 1 - erf(z) suffers from cancellation; treat BER as 0
      if (z > 5.5)
        return 0;

      try {
        z1 = ((1.0 - 1.0 / Math.sqrt (m)) * (1 - Erf.erf(z)));
      } catch (MathException e) {
        throw new RuntimeException(e.getMessage(), e);
      }
      double z2 = 1 - Math.pow ((1-z1), 2.0);
      double ber = z2 / Util.log2 (m);
      return ber;
    }

    /**
     * @param snr the snr, (W/W)
     * @return bit error rate of the modulation without coding
     */
    abstract protected double getModulationBer(double snr);

    @Override
    public int getBandwidth() {
      return this.bandwidth;
    }

    @Override
    public int getDataRate() {
      return this.dataRate;
    }

    @Override
    public int getPhyRate() {
      return this.dataRate;
    }

    @Override
    public double getChunkSuccessRate (double snr, int nbits)
    {
      // Adjust SNR to match theoretical curve
      snr = snr / this.rxLossOffset;

      double ber = getModulationBer(snr);
      double csr = Math.pow(1. - ber, nbits);
      return csr;
    }

  }

  public static abstract class FecTransmissionMode extends NoFecTransmissionMode {

    protected double codingRate;
    protected int phyRate;
    protected int dFree;
    protected int adFree;

    public FecTransmissionMode(int bandwidth, int dataRate, int phyRate,
        double coding_rate, int d_free, int ad_free) {
      super(bandwidth, dataRate);
      codingRate = coding_rate;
      this.phyRate = phyRate;
      dFree = d_free;
      adFree = ad_free;
      if (Main.ASSERT)
        Util.assertion(phyRate * .01 > Math.abs(this.getDataRate() - codingRate * phyRate));
    }

    @Override
    public int getPhyRate() {
      return this.phyRate;
    }

    /*
     * (non-Javadoc)
     * @see brn.swans.radio.TransmissionMode.NoFecTransmissionMode#getChunkSuccessRate(double, int)
     */
    @Override
    public double getChunkSuccessRate(double snr, int nbits) {
      // Adjust SNR to match theoretical curve
      snr = snr / this.rxLossOffset;
      // Adjust data to phy bits TODO is this right?
      double phybits = ((double)nbits / codingRate);

      double ber = getPhyBer (snr);
      if (ber == 0.0) {
        return 1.0;
      }

      double pms = Math.pow(1.0 - ber, phybits);
      return pms;
    }

    /**
     * obvious this rate is the rate of _fatal_ bit errors
     * on PHY level - the FEC has already been considered
     * @param snr the snr, (W/W)
     * @return bit error rate of a physical bit (fec or data)
     */
    protected abstract double getPhyBer(double snr);

    /**
     * Pairwise error probability for an odd distance d:
     * sum_{k=(d+1)/2}^{d} C(d,k) ber^k (1-ber)^(d-k).
     */
    protected double calculatePdOdd(double ber, int d) {
      if (Main.ASSERT)
        Util.assertion((d % 2) == 1);
      int dstart = (d + 1) / 2;
      int dend = d;
      double pd = 0;

      // the upper bound k = d is part of the sum
      for (int i = dstart; i <= dend; i++) {
        pd += binomial (i, ber, d);
      }
      return pd;
    }

    /**
     * Pairwise error probability for an even distance d:
     * sum_{k=d/2+1}^{d} C(d,k) ber^k (1-ber)^(d-k) + 1/2 C(d,d/2) (ber(1-ber))^(d/2).
     */
    protected double calculatePdEven(double ber, int d) {
      if (Main.ASSERT)
        Util.assertion((d % 2) == 0);
      int dstart = d / 2 + 1;
      int dend = d;
      double pd = 0;

      // the upper bound k = d is part of the sum
      for (int i = dstart; i <= dend; i++){
        pd +=  binomial (i, ber, d);
      }
      // ties are broken randomly
      pd += 0.5 * binomial (d / 2, ber, d);

      return pd;
    }

    /** Pairwise error probability for Hamming distance d. */
    protected double calculatePd(double ber, int d) {
      double pd;
      if ((d % 2) == 0) {
        pd = calculatePdEven (ber, d);
      } else {
        pd = calculatePdOdd (ber, d);
      }
      return pd;
    }

    /** Probability of exactly k of n independent bit errors, each with probability p. */
    private double binomial(int k, double p, int n) {
      return binomialCoefficient(n, k) * Math.pow(p, k) * Math.pow(1 - p, n - k);
    }

    /**
     * C(n, k) via the multiplicative formula. Every intermediate value is itself
     * a binomial coefficient, so the result is exact while it fits a double
     * mantissa (int factorials overflow for n > 12).
     */
    private static double binomialCoefficient(int n, int k) {
      double c = 1.;
      for (int i = 1; i <= k; i++) {
        c = c * (n - k + i) / i;
      }
      return c;
    }
  }

  public static class NoFecBpskMode extends NoFecTransmissionMode {
    /** phyRate is ignored: without FEC it equals dataRate */
    public NoFecBpskMode (int bandwidth, int dataRate, int phyRate) {
      super(bandwidth, dataRate);
    }

    @Override
    protected double getModulationBer(double snr) {
      return getBpskBer(snr);
    }
  }

  public static class FecBpskMode extends FecTransmissionMode {

    public FecBpskMode (int bandwidth, int dataRate, int phyRate,
        double coding_rate, int d_free, int ad_free) {
      super(bandwidth, dataRate, phyRate, coding_rate, d_free, ad_free);
    }

    @Override
    protected double getModulationBer(double snr) {
      return getBpskBer(snr);
    }

    @Override
    protected double getPhyBer(double snr) {
      double ber = getModulationBer (snr);
      if (ber <= .0) {
        return .0;
      }
      /* only the first term of the union bound */
      double pd = calculatePd(ber, dFree);
      double pmu = Math.min(1., adFree * pd);
      return pmu;
    }

    @Override
    public String toString() {
      return "BPSK("+codingRate+")";
    }
  }

  public static class NoFecQamMode extends NoFecTransmissionMode {
    private int m;
    /** phyRate is ignored: without FEC it equals dataRate */
    public NoFecQamMode (int bandwidth, int dataRate, int phyRate, int M) {
      super(bandwidth, dataRate);
      m = M;
    }

    @Override
    protected double getModulationBer(double snr) {
      return getQamBer(snr, m);
    }

  }

  public static class FecQamMode extends FecTransmissionMode {
    private int m;

    private int adFreePlusOne;

    public FecQamMode (int bandwidth, int dataRate, int phyRate,
        double coding_rate, int M, int d_free, int ad_free, int ad_free_plus_one) {
      super(bandwidth, dataRate, phyRate, coding_rate, d_free, ad_free);
      m = M;
      adFreePlusOne = ad_free_plus_one;
    }

    @Override
    protected double getModulationBer(double snr) {
      return getQamBer(snr, m);
    }

    @Override
    protected double getPhyBer(double snr) {
      double ber = getModulationBer (snr);
      if (ber <= 0.0) {
        return .0;
      }
      /* first term of the union bound */
      double pd = calculatePd (ber, dFree);
      double pmu = adFree * pd;
      // second term
      pd = calculatePd (ber, dFree + 1);
      pmu += adFreePlusOne * pd;
      return Math.min(1., pmu);
    }

    @Override
    public String toString() {
      return "QAM("+m+"/"+codingRate+")";
    }
  }

}
