package brn.analysis.capacity;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import brn.util.NavigableSet;
import java.util.Set;
import brn.util.TreeSet;

import jist.swans.Constants;
import jist.swans.misc.Util;

import org.apache.commons.lang.ArrayUtils;

import brn.swans.radio.TransmissionMode;

/**
 * Java implementation of the capacity approximation.
 *
 * <p>Two conflict graphs are derived from a matrix of received signal powers:
 * a transmitter (carrier-sense) conflict graph and a sender-specific receiver
 * conflict graph. Together with the normalized input rates of all nodes they
 * are used to approximate the probability that a packet sent by one node is
 * received by at least one node of a given receiver set (inclusion-exclusion
 * over the receivers, see {@code prob4}). Interference from hidden nodes,
 * i.e. nodes the sender cannot carrier-sense, is taken into account.
 *
 * <p>An instance keeps the working state of the current computation in fields
 * ({@code cT}, {@code cR}, {@code q}, {@code condQ}, {@code src}). A single
 * instance must therefore not be used by several threads concurrently.
 *
 * TODO exclude from 1.5 build
 *
 * @author Kurth
 */
public class CapacityApprox {

  /** Carrier-sense threshold; links at or above it conflict in the tx graph. */
  public static final double csThreshold_mW = Util.fromDB(-91.44 -.15);
  /** Noise floor used for the SNR computations of the receiver conflict graph. */
  public static final double noiseFloor_mW = Util.fromDB(-92.965);
  /**
   * Divisor applied to the interference-free chunk success rate to obtain the
   * PDR threshold below which an interferer counts as conflicting.
   */
  public static final double factor = 4;

  /** Transmitter (carrier-sense) conflict graph of the current computation. */
  private ConflictGraph cT;
  /** Receiver conflict graph of the current computation (specific to src). */
  private ConflictGraph cR;
  /** Normalized input rates, indexed by node. */
  private double[] q;
  /** Lazily filled cache of conditional rates (getCondQ); null = not yet computed. */
  private Double[] condQ;
  /** Sender of the current computation. */
  private int src;

  /**
   * Builds the transmitter (carrier-sense) conflict graph using the default
   * {@link #csThreshold_mW}.
   *
   * @param signalPower_mW square matrix; entry [a][b] is the power of a's signal at b
   * @return the transmitter conflict graph
   */
  public ConflictGraph calcTxConflictGraph(double[][] signalPower_mW) {
    return this.calcTxConflictGraph(signalPower_mW, csThreshold_mW);
  }

  /**
   * Builds the receiver conflict graph for sender {@code src} using the
   * default {@link #noiseFloor_mW} and {@link #factor}.
   *
   * @param src the sender
   * @param signalPower_mW square matrix; entry [a][b] is the power of a's signal at b
   * @return the receiver conflict graph specific to {@code src}
   */
  public ConflictGraph calcRxConflictGraph(int src, double[][] signalPower_mW) {
    return this.calcRxConflictGraph(src, signalPower_mW, noiseFloor_mW, factor);
  }

  /**
   * Marks {@code src -> dst} as conflicting whenever the signal power of src
   * at dst reaches the carrier-sense threshold.
   */
  private ConflictGraph calcTxConflictGraph(double[][] signalPower_mW, double csThreshold_mW) {
    int nodes = signalPower_mW.length;
    int[][] adjacencyMatrix = new int[nodes][nodes];

    for (int src = 0; src < nodes; src++) {
      for (int dst = 0; dst < nodes; dst++) {
        // TODO use preamble rx prob
        if (signalPower_mW[src][dst] >= csThreshold_mW) {
          adjacencyMatrix[src][dst] = 1;
        }
      }
    }

    return new ConflictGraph(adjacencyMatrix);
  }

  /**
   * Marks {@code dst -> other} as conflicting when a concurrent transmission
   * of {@code other} pushes the chunk success rate of {@code src -> dst}
   * below the interference-free rate divided by {@code factor}. Receivers
   * whose interference-free rate is not positive get no conflict edges.
   */
  private ConflictGraph calcRxConflictGraph(int src, double[][] signalPower_mW,
      double noiseFloor_mW, double factor) {
    int nodes = signalPower_mW.length;
    int[][] adjacencyMatrix = new int[nodes][nodes];

    // NOTE(review): this appears to (re)configure static receiver-loss state of
    // TransmissionMode on every call. The map stays a raw type because the exact
    // generic signature expected by setReceiverLoss is not visible here.
    Map thresholds = new HashMap();
    thresholds.put(Constants.BANDWIDTH_6Mbps, 5.4);
    TransmissionMode.setReceiverLoss(1000, .1, thresholds);
    TransmissionMode mode = TransmissionMode.get80211gMode(
        (int)Constants.CHANNEL_WIDTH_80211g, Constants.BANDWIDTH_6Mbps);
    // packet length in bits (1524 bytes) -- TODO confirm the assumed frame size
    int nbits = 1524 * 8;

    for (int dst = 0; dst < nodes; dst++) {
      double cleanSnr = signalPower_mW[src][dst] / noiseFloor_mW;
      double pdrThreshold = mode.getChunkSuccessRate(cleanSnr, nbits) / factor;
      if (pdrThreshold <= .0) {
        // dst does not receive src even without interference
        continue;
      }

      for (int other = 0; other < nodes; other++) {
        if (other == dst) {
          continue;
        }
        double sinr = signalPower_mW[src][dst]
            / (noiseFloor_mW + signalPower_mW[other][dst]);

        double pdr = mode.getChunkSuccessRate(sinr, nbits);
        if (pdr < pdrThreshold) {
          adjacencyMatrix[dst][other] = 1;
        }
      }
    }

    return new ConflictGraph(adjacencyMatrix);
  }

  /**
   * Convenience overload that first derives both conflict graphs from the
   * signal power matrix using the default thresholds.
   *
   * @see #calcPacketSuccessProb(int, int[], ConflictGraph, ConflictGraph, double[], double[][])
   */
  public double calcPacketSuccessProb(int src, int[] dst,
      double[][] signalPower_mW, double[] q, double[][] p) {
    ConflictGraph cT = calcTxConflictGraph(signalPower_mW, csThreshold_mW);
    ConflictGraph cR = calcRxConflictGraph(src, signalPower_mW, noiseFloor_mW, factor);

    return calcPacketSuccessProb(src, dst, cT, cR, q, p);
  }

  /**
   * Calculates the packet success probability for SRC -> DST, i.e. the
   * probability that at least one receiver in {@code dst} gets the packet
   * (aka Prob4).
   *
   * @param src the sender
   * @param dst the receiver(s); must contain at least one node
   * @param cT transmitter (carrier sense) conflict graph
   * @param cR receiver conflict graph (SRC specific)
   * @param q normalized input rates
   * @param p (channel) packet success probability, indexed [sender][receiver]
   * @return the approximated packet success probability
   * @throws IllegalArgumentException if {@code dst} is empty
   */
  public double calcPacketSuccessProb(int src, int[] dst,
      ConflictGraph cT, ConflictGraph cR,
      double[] q, double[][] p) {
    if (dst.length == 0) {
      // prob4 needs a first receiver; fail fast instead of deep in the recursion
      throw new IllegalArgumentException("dst must contain at least one receiver");
    }
    this.cT = cT;
    this.cR = cR;
    this.q  = q;
    this.condQ = new Double[q.length];
    this.src = src;

    NavigableSet<Integer> lstJ = new TreeSet<Integer>(Arrays.asList(ArrayUtils.toObject(dst)));
    NavigableSet<Integer> lstK = new TreeSet<Integer>();

    return prob4(lstJ, lstK, p);
  }

  /**
   * Computes the conditional input rate of every node, assuming that the
   * carrier-sense neighbors of {@code src} remain silent.
   *
   * @param src the sender whose tx-graph neighbors are assumed silent
   * @param cT transmitter (carrier sense) conflict graph
   * @param q normalized input rates
   * @return conditional rates, indexed by node (same length as {@code q})
   */
  public double[] getCondQ(int src, ConflictGraph cT, double[] q) {
    this.cT = cT;
    this.q  = q;
    this.condQ = new Double[q.length];
    this.src = src;

    // get neighbors (carrier sensing)
    NavigableSet<Integer> stNeighbors = cT.getDepartures(src);
    for (int i = 0; i < q.length; i++) {
      this.getCondQ(i, stNeighbors);
    }
    double[] ret = new double[q.length];
    for (int i = 0; i < condQ.length; i++) {
      ret[i] = this.condQ[i];
    }
    return ret;
  }

  /**
   * Recursive inclusion-exclusion over the receiver set (aka prob4).
   *
   * <p>The first receiver of {@code lstJ} is either left out (first summand)
   * or included, in which case it moves into {@code lstK} for the remaining
   * recursion. Receivers in {@code lstK} were fixed by outer recursion levels.
   * They enter the result only through the {@code qFunc}/{@code gFunc}
   * factors; their {@code p} factor is applied by the caller level.
   *
   * @param lstJ receivers still to be processed (non-empty)
   * @param lstK receivers already included by outer recursion levels
   * @param p (channel) packet success probability, indexed [sender][receiver]
   * @return the inclusion-exclusion partial sum for {@code lstJ} given {@code lstK}
   */
  private double prob4(NavigableSet<Integer> lstJ, NavigableSet<Integer> lstK, double[][] p) {
    // generate new K+ = K plus the first remaining receiver
    NavigableSet<Integer> lstKPlus = new TreeSet<Integer>(lstK);
    lstKPlus.add(lstJ.first());

    // base case: only one left
    if (lstJ.size() == 1) {
      return p[src][lstJ.first()] * qFunc(lstKPlus) * gFunc(lstKPlus);
    }

    // generate J- = J without its first receiver
    NavigableSet<Integer> lstJMinus = lstJ.tailSet(lstJ.first(), false);

    // Iterate over J and K
    return prob4(lstJMinus, lstK, p)
      + p[src][lstJ.first()]
        * (qFunc(lstKPlus) * gFunc(lstKPlus) - prob4(lstJMinus, lstKPlus, p));
  }

  /**
   * Probability that none of the hidden nodes of {@code lstJ} starts a
   * transmission (aka g). Hidden nodes are cR-neighbors of {@code lstJ} that
   * are neither carrier-sense neighbors of the sender nor the sender itself.
   *
   * @param lstJ receivers
   * @return {@code exp(-sum of q over the hidden nodes)}
   */
  private double gFunc(NavigableSet<Integer> lstJ) {
    // get neighbors of neighbors
    Set<Integer> stHidden = cR.getDepartures(lstJ);
    // remove direct neighbors
    stHidden.removeAll(cT.getDepartures(src));
    // remove sender -> hidden nodes remain
    stHidden.remove(src);

    // medium release rate (=1/tx time); cancels out in lambda / mu below
    double mu = 470;

    // determine aggregated medium access rate of hidden nodes
    double lambda = .0;
    for (Integer node : stHidden) {
      lambda += this.q[node] * mu;
    }

    // assume exponential inter-arrival at hidden nodes
    return Math.exp( - lambda / mu );
  }

  /**
   * Probability that no hidden node of {@code lstJ} is already transmitting
   * while the sender's carrier-sense neighbors are silent (aka qq).
   *
   * @param lstJ receivers
   * @return {@code 1 - hidden3(hidden nodes, carrier-sense neighbors of src)}
   */
  private double qFunc(NavigableSet<Integer> lstJ) {
    // get neighbors of neighbors
    NavigableSet<Integer> stHidden = cR.getDepartures(lstJ);
    // get neighbors (carrier sensing)
    NavigableSet<Integer> stNeighbors = cT.getDepartures(src);
    // remove direct neighbors
    stHidden.removeAll(stNeighbors);
    // remove sender -> hidden nodes remain
    stHidden.remove(src);

    return 1 - hidden3(stHidden, stNeighbors);
  }

  /**
   * Determines the a-priori transmission probability of hidden nodes, under
   * the constraint that the neighbors remain silent (aka hidden3).
   *
   * @param stHidden hidden nodes
   * @param stSilent neighbors (carrier sense)
   * @return sum of the hidden3a, hidden3b and hidden3d terms
   */
  private double hidden3(NavigableSet<Integer> stHidden, NavigableSet<Integer> stSilent) {
    return hidden3a(stHidden, stSilent)
        + hidden3b(stHidden, stSilent)
        + hidden3d(stHidden, stSilent);
  }

  /**
   * Probability that at least one node of {@code stHidden} transmits, treating
   * the nodes as independent with their conditional rates (aka hidden3a).
   *
   * @param stHidden hidden nodes
   * @param stSilent nodes assumed silent
   * @return {@code 1 - prod(1 - condQ(node))}
   */
  private double hidden3a(NavigableSet<Integer> stHidden, Set<Integer> stSilent) {
    double ret = 1.;
    for (Integer node : stHidden) {
      ret *= 1 - getCondQ(node, stSilent);
    }
    return 1 - ret;
  }

  /**
   * Correction term over the non-silent tx-graph neighbors of each hidden
   * node (aka hidden3b).
   *
   * @param stHidden hidden nodes
   * @param stSilent nodes assumed silent
   * @return the accumulated correction
   */
  private double hidden3b(NavigableSet<Integer> stHidden, Set<Integer> stSilent) {
    double ret = .0;
    for (Integer node : stHidden) {
      // n2(J[1],CT) minus S
      NavigableSet<Integer> stNeighbors = cT.getDepartures(node);
      stNeighbors.removeAll(stSilent);

      ret += getCondQ(node, stSilent) / 2. * hidden3a(stNeighbors, stSilent)
        + hidden3c(stNeighbors, stSilent);
    }
    return ret;
  }

  /**
   * Pairwise correction term over all unordered pairs of mutually
   * non-connected nodes in {@code stNeighbors} (aka hidden3c). Each pair
   * contributes {@code condQ(a) * condQ(b) / |C| * (1 - 1/(1 - sum q over C))},
   * where C is the set of common tx-graph neighbors of a and b.
   *
   * @param stNeighbors candidate nodes
   * @param stSilent nodes assumed silent
   * @return the accumulated pairwise correction
   */
  private double hidden3c(NavigableSet<Integer> stNeighbors,
      Set<Integer> stSilent) {
    double ret = .0;
    for (Integer neighbor : stNeighbors) {
      NavigableSet<Integer> stNeighborsOf1 = cT.getDepartures(neighbor);

      // only partners strictly after 'neighbor' -> every unordered pair once
      for (Integer neighbor2 : stNeighbors.tailSet(neighbor, false)) {
        // compare values: '==' on Integer objects compares references
        if (neighbor.intValue() == neighbor2.intValue()
            || cT.isConnected(neighbor, neighbor2)) {
          continue;
        }

        // n2(nn,CT) intersect n2(M[1],CT)
        NavigableSet<Integer> stCommonNeighbors = cT.getDepartures(neighbor2);
        stCommonNeighbors.retainAll(stNeighborsOf1);
        if (stCommonNeighbors.isEmpty()) {
          // sum q over C is 0, so the factor (1 - 1/(1-0)) is 0; dividing by
          // |C| = 0 would otherwise turn the whole result into NaN
          continue;
        }

        // add(y,y=map(x->q[x], stCommonNeighbors))
        double cumlQ = .0;
        for (Integer node : stCommonNeighbors) {
          cumlQ += q[node];
        }

        ret += getCondQ(neighbor, stSilent) * getCondQ(neighbor2, stSilent)
          / (double)stCommonNeighbors.size() * (1. - 1./(1.-cumlQ));
      }
    }
    return ret;
  }

  /**
   * For each silent node, applies the pairwise correction to the hidden nodes
   * that are its tx-graph neighbors (aka hidden3d).
   *
   * @param stHidden hidden nodes
   * @param stSilent nodes assumed silent
   * @return the accumulated correction
   */
  private double hidden3d(
      NavigableSet<Integer> stHidden,
      NavigableSet<Integer> stSilent) {
    double ret = .0;
    for (Integer node : stSilent) {
      // n2(S[1],CT) intersect J
      NavigableSet<Integer> stNewHidden = cT.getDepartures(node);
      stNewHidden.retainAll(stHidden);
      ret += hidden3c(stNewHidden, stSilent);
    }
    return ret;
  }

  /**
   * Conditional input rate of {@code node} given that the nodes in
   * {@code stSilent} are silent (aka q_cond).
   *
   * <p>Results are cached in {@code condQ} by node only. This is valid because
   * within one computation every caller passes the same silent set (the
   * carrier-sense neighbors of {@code src}).
   *
   * @param node the node
   * @param stSilent nodes assumed silent
   * @return {@code q[node] * (1 - xyz) / (1 - abc)}, see inline comments
   */
  private double getCondQ(Integer node, Set<Integer> stSilent) {
    if (null != condQ[node]) {
      return condQ[node];
    }

    // h2(j,CT) intersect S
    NavigableSet<Integer> stHidden = cT.getHidden(node);
    stHidden.retainAll(stSilent);

    // iterate over all hidden nodes
    double xyz = .0;
    for (Integer x : stHidden) {
      // n2(x,CT) intersect n2(j,CT)
      // common neighbors of x and node
      NavigableSet<Integer> stCommonNeighbors = cT.getDepartures(x);
      stCommonNeighbors.retainAll(cT.getDepartures(node));

      double fgh = .0;
      for (Integer commonNeighbor : stCommonNeighbors) {
        fgh += q[commonNeighbor];
      }
      // rates summing to >= 1 indicate infeasible input rates
      assert(fgh < 1.);
      xyz += q[x] / (1-fgh);
    }

    // silent neighbors and hidden nodes
    // (n2(j,CT) union h2(j,CT)) intersect S
    NavigableSet<Integer> stNeighborAndHidden = cT.getNeighborsAndHidden(node);
    stNeighborAndHidden.retainAll(stSilent);

    double abc = .0;
    for (Integer silent : stNeighborAndHidden) {
      abc += q[silent];
    }

    condQ[node] = q[node] * (1 - xyz) / (1-abc);
    return condQ[node];
  }

  /**
   * Writes the arguments of the signal-power overload of
   * {@code calcPacketSuccessProb} to file {@code name} as an {@code Object[]}
   * ({src, dst, signalPower_mW, q, p}).
   */
  public static void serialize(String name, int src, int[] dst,
      double[][] signalPower_mW, double[] q, double[][] p) throws Exception {
    writeObjectToFile(name, new Object[] {src, dst, signalPower_mW, q, p });
  }

  /**
   * Writes the arguments of the conflict-graph overload of
   * {@code calcPacketSuccessProb} to file {@code name} as an {@code Object[]}
   * ({src, dst, cT, cR, q, p}); this is the layout {@link #main} reads.
   */
  public static void serialize(String name, int src, int[] dst,
      ConflictGraph cT, ConflictGraph cR,
      double[] q, double[][] p) throws Exception {
    writeObjectToFile(name, new Object[] {src, dst, cT, cR, q, p });
  }

  /** Java-serializes {@code obj} into file {@code name}, always releasing the file. */
  private static void writeObjectToFile(String name, Object obj) throws Exception {
    FileOutputStream fileOut = new FileOutputStream(name);
    try {
      ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
      try {
        objectOut.writeObject(obj);
      } finally {
        // close the wrapper first: it flushes pending data, then closes fileOut
        objectOut.close();
      }
    } finally {
      // no-op when already closed; releases the file if the
      // ObjectOutputStream constructor failed
      fileOut.close();
    }
  }

  /**
   * Replays a call recorded by the conflict-graph overload of
   * {@code serialize} and prints the resulting probability.
   *
   * <p>Reads {@code args[0]} if given, otherwise {@code call.bin}. Files
   * written by the signal-power overload of {@code serialize} have a different
   * layout and are not supported. Uses Java deserialization, so only feed it
   * trusted files.
   */
  public static void main(String[] args) throws Exception {
    String fileName = args.length > 0 ? args[0] : "call.bin";

    Object[] obj;
    FileInputStream fileIn = new FileInputStream(fileName);
    try {
      ObjectInputStream objectIn = new ObjectInputStream(fileIn);
      try {
        obj = (Object[]) objectIn.readObject();
      } finally {
        objectIn.close();
      }
    } finally {
      fileIn.close();
    }

    int i = 0;
    int src = (Integer) obj[i++];
    int[] dst = (int[]) obj[i++];
    ConflictGraph cT = (ConflictGraph) obj[i++];
    ConflictGraph cR = (ConflictGraph) obj[i++];
    double[] q = (double[]) obj[i++];
    double[][] p = (double[][]) obj[i++];

    CapacityApprox cap = new CapacityApprox();
    double ret = cap.calcPacketSuccessProb(src, dst, cT, cR, q, p);
    System.out.println(ret);
  }
}













