package brn.swans.radio;

import java.io.FileInputStream;
import java.io.ObjectInputStream;

import jist.runtime.JistAPI;
import jist.runtime.Main;
import jist.swans.Constants;
import jist.swans.mac.MacInterface;
import jist.swans.misc.Message;
import jist.swans.misc.MessageAnno;
import jist.swans.misc.Util;
import jist.swans.radio.RadioInterface;

import org.apache.log4j.Logger;

/**
 * Connects two MAC entities through an empirical propagation model driven by
 * packet error rate (PER) measurement traces.
 * 
 * @author kurth
 *
 */
public class RadioEmpirical {

  public static final Logger log = Logger.getLogger(RadioEmpirical.class.getName());

  public static final int RECV_STATUS_OK = Constants.RADIO_RECV_STATUS_OK.byteValue();
  public static final int RECV_STATUS_CRC = Constants.RADIO_RECV_STATUS_CRC.byteValue();
  public static final int RECV_STATUS_PHY = Constants.RADIO_RECV_STATUS_PHY.byteValue();

  //////////////////////////////////////////////////
  // locals
  //
  
  /**
   * Array of rate, status, 0, time 
   *          rate, status, 1, packet arrived(1/0
   */
  private long[][][][] timeRateStatus;

  /**
   * Time interval over which the channel is assumed to be stationary 
   * (in trace time units).
   */
  private long stationarityTime;
  
  /**
   * Time offset (in trace time units) to start reading traces.
   */
  private long startTime;

  /**
   * Determines how many trace time units correspond to a simulation time unit.
   */
  private long timeScaleFactor;

  /**
   * Whether to pass corrupted packets up the stack.
   */
  private boolean receiveCorrupted;
  
  /**
   * Array of moving averages per rate and status.
   */
  private MovingAverage[][] averages;
  
  /**
   * Array of contained radios (i.e. network stacks).
   */
  private Radio[] radio;
  
  //////////////////////////////////////////////////
  // initialize 
  //

  public RadioEmpirical(String traceFileName, long stationarityTime, 
      long startTime, long timeScaleFactor, boolean receiveCorrupted) throws Exception {
    this.stationarityTime = stationarityTime;
    this.timeScaleFactor = timeScaleFactor;
    this.receiveCorrupted = receiveCorrupted;
    this.startTime = startTime;
    
    FileInputStream fin = new FileInputStream(traceFileName);
    ObjectInputStream oin = new ObjectInputStream(fin);
    timeRateStatus = (long[][][][]) oin.readObject();
    
    averages = new MovingAverage[timeRateStatus.length][];
    for (int rate=0; rate < timeRateStatus.length; rate++) {
      averages[rate] = new MovingAverage[timeRateStatus[rate].length];
      for (int status=0; status < timeRateStatus[rate].length; status++) {
        if (0 != status && !receiveCorrupted)
          continue;
        averages[rate][status] = new MovingAverage(timeRateStatus[rate][status][0],
            timeRateStatus[rate][status][1], stationarityTime);
      }
    }

    radio = new Radio[2];
    radio[0] = new Radio(0);
    radio[1] = new Radio(1);
    
    if (receiveCorrupted) {
      radio[0].setUseAnnotations(true);
      radio[1].setUseAnnotations(true);
    }
  }
  
  //////////////////////////////////////////////////
  // entity hookups
  //
  
  public void setMacEntity(MacInterface macA, MacInterface macB) {
    radio[0].setMacEntity(macA);
    radio[1].setMacEntity(macB);
  }
  
  public RadioInterface[] getProxy() {
    RadioInterface[] ret = new RadioInterface[2];
    ret[0] = radio[0].getProxy();
    ret[1] = radio[1].getProxy();
    return (ret);
  }
  
  public long getStationarityTime() {
    return stationarityTime;
  }

  public long[][][][] getTimeRateStatus() {
    return timeRateStatus;
  }


  //////////////////////////////////////////////////
  // implementation
  //
  
  private long simToTraceTime(long simTime) {
    return (timeScaleFactor*simTime + startTime);
  }
  
  public static int getBitrate(int idx) {
    switch (idx) {
    case 0:
      return Constants.BANDWIDTH_1Mbps;
    case 1:
      return Constants.BANDWIDTH_2Mbps;
    case 2:
      return Constants.BANDWIDTH_5_5Mbps;
    case 3:
      return Constants.BANDWIDTH_11Mbps;
    case 4:
      return Constants.BANDWIDTH_6Mbps;
    case 5:
      return Constants.BANDWIDTH_9Mbps;
    case 6:
      return Constants.BANDWIDTH_12Mbps;
    case 7:
      return Constants.BANDWIDTH_18Mbps;
    case 8:
      return Constants.BANDWIDTH_24Mbps;
    case 9:
      return Constants.BANDWIDTH_36Mbps;
    case 10:
      return Constants.BANDWIDTH_48Mbps;
    case 11:
      return Constants.BANDWIDTH_54Mbps;
    }
    throw new RuntimeException("illegal bit-rate");
  }

  public static int getBitrateIdx(int bitrate) {
    switch (bitrate) {
    case Constants.BANDWIDTH_1Mbps:
      return 0;
    case Constants.BANDWIDTH_2Mbps:
      return 1;
    case Constants.BANDWIDTH_5_5Mbps:
      return 2;
    case Constants.BANDWIDTH_11Mbps:
      return 3;
    case Constants.BANDWIDTH_6Mbps:
      return 4;
    case Constants.BANDWIDTH_9Mbps:
      return 5;
    case Constants.BANDWIDTH_12Mbps:
      return 6;
    case Constants.BANDWIDTH_18Mbps:
      return 7;
    case Constants.BANDWIDTH_24Mbps:
      return 8;
    case Constants.BANDWIDTH_36Mbps:
      return 9;
    case Constants.BANDWIDTH_48Mbps:
      return 10;
    case Constants.BANDWIDTH_54Mbps:
      return 11;
    }
    throw new RuntimeException("illegal bit-rate");
  }

  /**
   * Retrieve the probabilities for packet reception for the given simulation 
   * time.
   * 
   * @param simTime
   * @return double[3] with prob. of correct, crc and phy corrupt reception
   */
  public double[] getLinkQuality(long simTime, int bitrate) {
    long traceTime = simToTraceTime(simTime);
    int idxRate = getBitrateIdx(bitrate);
    double[] ret = new double[3];
    
    ret[RECV_STATUS_OK] = averages[idxRate][RECV_STATUS_OK].getAverage(traceTime); 
    if (receiveCorrupted) {
      ret[RECV_STATUS_CRC] = averages[idxRate][RECV_STATUS_CRC].getAverage(traceTime); 
      ret[RECV_STATUS_PHY] = averages[idxRate][RECV_STATUS_PHY].getAverage(traceTime);
    }
    else {
      ret[RECV_STATUS_CRC] = .0; 
      ret[RECV_STATUS_PHY] = .0;
    }
    
    return (ret);
  }

  /**
   * TODO entity call or not?
   * 
   * @param id
   * @param msg
   * @param duration
   * @param rate 
   */
  public void transmit(int id, Message msg, long duration, int rate) {
    radio[(id+1)%2].receive(msg, new Double(rate), new Long(duration));
  }

  //////////////////////////////////////////////////
  // sub-class radio
  //

  protected class Radio implements RadioInterface {
    
    /**
     * Identification of this radio
     */
    protected int id;
    
    /**
     * radio mode: IDLE, SENSING, RECEIVING, SENDING, SLEEP.
     */
    protected byte mode;

    /**
     * message being received.
     */
    protected Message signalBuffer;
    
    /**
     * messsage annotations
     */
    protected MessageAnno signalAnno;

    /**
     * whether to use and generate annotations.
     */
    protected boolean useAnnotations;

    /**
     * end of transmission time.
     */
    protected long signalFinish;
    

    // entity hookup

    /**
     * self-referencing radio entity reference.
     */
    protected RadioInterface self;

    /**
     * mac entity upcall reference.
     */
    protected MacInterface macEntity;

    
    //////////////////////////////////////////////////
    // initialize 
    //

    public Radio(int id) {
      this.id = id;
      this.mode = Constants.RADIO_MODE_IDLE;
      this.signalBuffer = null;
      this.signalAnno = null;
      this.useAnnotations = false;

      this.self = (RadioInterface)JistAPI.proxy(new RadioInterface.Dlg(this), 
          RadioInterface.class);
    }

    //////////////////////////////////////////////////
    // entity hookups
    //
  
    /**
     * Return self-referencing radio entity reference.
     *
     * @return self-referencing radio entity reference
     */
    public RadioInterface getProxy()
    {
      return this.self;
    }
  
    /**
     * Set downcall mac entity reference.
     *
     * @param macEntity downcall mac entity reference
     */
    public void setMacEntity(MacInterface macEntity)
    {
      if(!JistAPI.isEntity(macEntity)) throw new IllegalArgumentException("entity expected");
      this.macEntity = macEntity;
    }
  
    //////////////////////////////////////////////////
    // accessors
    //

    public boolean isUseAnnotations() {
      return useAnnotations;
    }

    public void setUseAnnotations(boolean useAnnotations) {
      this.useAnnotations = useAnnotations;
    }

    /**
     * Set radio mode. Also notifies mac entity.
     *
     * @param mode radio mode
     */
    public void setMode(byte mode)
    {
      if(this.mode!=mode)
      {
        this.mode = mode;
        this.macEntity.setRadioMode(mode);
      }
    }

    //////////////////////////////////////////////////
    // signal acquisition
    //

    /**
     * Lock onto current packet signal.
     *
     * @param msg packet currently on the air
     * @param duration time to EOT (units: simtime)
     * @param status
     * @param bitrate 
     */
    protected void lockSignal(Message msg, long duration, Byte status, int bitrate)
    {
      signalBuffer = msg;
      signalFinish = JistAPI.getTime() + duration;
      signalAnno = null;
      if (useAnnotations) {
        signalAnno = new MessageAnno();
        signalAnno.put(MessageAnno.ANNO_RADIO_RECV_STATUS, status);
        signalAnno.put(MessageAnno.ANNO_MAC_BITRATE, new Integer(bitrate));
      }
      this.macEntity.peek(msg, signalAnno);
    }
    
    /**
     * Unlock from current packet signal.
     */
    protected void unlockSignal()
    {
      signalBuffer = null;
      signalAnno = null;
      signalFinish = -1;
    }

    //////////////////////////////////////////////////
    // reception
    //
  
    public void receive(Message msg, Double powerObj_mW, Long durationObj) {
      // HACK bit-rate is tunneled trough power
      int bitrate = powerObj_mW.intValue();
      
      // get error probability
      double[] prob = RadioEmpirical.this.getLinkQuality(JistAPI.getTime(), bitrate);
      
      // realize random process
      double realization = Constants.random.nextDouble();
      Byte status = Constants.RADIO_RECV_STATUS_LOST; 
      if (realization < prob[RECV_STATUS_OK])
        status = Constants.RADIO_RECV_STATUS_OK;
      else if (realization < prob[RECV_STATUS_OK] + prob[RECV_STATUS_CRC])
        status = Constants.RADIO_RECV_STATUS_CRC;
      else if (realization < prob[RECV_STATUS_OK] + prob[RECV_STATUS_CRC] + 
          prob[RECV_STATUS_PHY])
        status = Constants.RADIO_RECV_STATUS_PHY;
      
      // upcall, if successful
      if (null == status)
        return;
      
      final long duration = durationObj.longValue();
      switch(mode)
      {
        case Constants.RADIO_MODE_IDLE:
        case Constants.RADIO_MODE_SENSING:
          setMode(Constants.RADIO_MODE_RECEIVING);
          lockSignal(msg, duration, status, bitrate);
          break;
        case Constants.RADIO_MODE_TRANSMITTING:
        case Constants.RADIO_MODE_SLEEP:
          break;
        case Constants.RADIO_MODE_RECEIVING:
        default:
          throw new RuntimeException("invalid radio mode: "+mode);
      }
      // schedule an endReceive
      JistAPI.sleep(duration); 
      self.endReceive(null, null, null, null);
    }
  
    public void endReceive(Message msg, Double power, RFChannel channel, Object event) {
      if(mode==Constants.RADIO_MODE_SLEEP) 
        return;
      
      if(mode==Constants.RADIO_MODE_RECEIVING)
      {
        if(signalBuffer!=null && JistAPI.getTime()==signalFinish)
        {
          // signalAnno is created during lockSignal
          this.macEntity.receive(signalBuffer, signalAnno);
          unlockSignal();
        }
      }
      else if (mode==Constants.RADIO_MODE_TRANSMITTING) {
        return;
      }
      if(signalBuffer==null) 
        setMode(Constants.RADIO_MODE_IDLE);
    }
  
    //////////////////////////////////////////////////
    // transmission
    //
    public void transmit(Message msg, long delay, long duration, MessageAnno anno) {
      
    }
    
    public void transmit(Message msg, MessageAnno anno, long predelay,
        long duration, long postdelay) {
      // radio in sleep mode
      if(mode==Constants.RADIO_MODE_SLEEP) return;
      // ensure not currently transmitting
      if(mode==Constants.RADIO_MODE_TRANSMITTING) 
        throw new RuntimeException("radio already transmitting");
      // use default delay, if necessary
      if(predelay==Constants.RADIO_NOUSER_DELAY) predelay = Constants.RX_TX_TURNAROUND__802_11bg;
      // set mode to transmitting
      setMode(Constants.RADIO_MODE_TRANSMITTING);
      // schedule message propagation delay
      JistAPI.sleep(predelay);

      int rate = ((Integer) anno.get(MessageAnno.ANNO_MAC_BITRATE)).intValue();
      RadioEmpirical.this.transmit(id, msg, duration, rate);
      
      // schedule end of transmission
      JistAPI.sleep(duration+postdelay);
      self.endTransmit();
    }
  
    public void endTransmit() {
      // radio in sleep mode
      if(mode==Constants.RADIO_MODE_SLEEP) return;
      // check that we are currently transmitting
      if(mode!=Constants.RADIO_MODE_TRANSMITTING) 
        throw new RuntimeException("radio is not transmitting");
      // set mode
      setMode(Constants.RADIO_MODE_IDLE);
    }
  
    public void setSleepMode(boolean sleep) {
      setMode(sleep ? Constants.RADIO_MODE_SLEEP : Constants.RADIO_MODE_IDLE);
    }

    public double getNoise_mW() throws JistAPI.Continuation {
      return 0;  //To change body of implemented methods use File | Settings | File Templates.
    }

    public void setChannel(RFChannel channel, long delay) {
      //To change body of implemented methods use File | Settings | File Templates.
    }

    public void endSetChannel(RFChannel channel) {
      //To change body of implemented methods use File | Settings | File Templates.
    }

    public RFChannel getChannel() throws JistAPI.Continuation {
      return null;  //To change body of implemented methods use File | Settings | File Templates.
    }

  }

  //////////////////////////////////////////////////
  // moving average
  //

  /**
   * Calculates a floating moving average over a determined time interval.
   */
  protected static class MovingAverage {
    
    /** array of time values */
    private final long[] time;
    
    /** start time stamp */
    private final long startTime;
    
    /** array of values to aggregate */
    private final long[] value;
    
    /** aggregation interval */
    private final long aggregationInt;

    
    /** current sum of values */
    private double sumValues;
    
    /** current number of values in sumValues */
    //private long noValues;
    
    /** start index (incl) for sumValues */
    private int idxStart;
    
    /** end index (excl) for sumValues */
    private int idxEnd;
    
    
    /**
     * Creates an moving average object
     * 
     * @param time array of time values
     * @param value array of values to aggregate
     * @param aggregationInt aggregation interval
     */
    public MovingAverage(long[] time, long[] value, long aggregationInt) {
      this.time = time;
      this.value = value;
      this.startTime = time[0];
      this.aggregationInt = aggregationInt;
      
      if (Main.ASSERT)
        Util.assertion(time.length == value.length);
      
      this.idxStart = 0;
      this.sumValues = 0;
      this.idxEnd = 0;
    }
    
    /**
     * Get the moving average for a particular point in time. Assumption: 
     * Subsequent calls do not have decreasing time values (i.e. time is only
     * advanced). 
     *  
     */
    public double getAverage(long _currTime) {
      // get the current start time relative to the file time
      long currTime = _currTime + startTime;
      
      if (Main.ASSERT)
        Util.assertion(currTime >= time[idxStart]);

      try {
        // advance start index until current time is reached
        while (time[idxStart+1] < currTime) {
          sumValues -= value[idxStart];
          idxStart++;
        } // post-cond: time[idxStart] >= currTime
        
        // if start index > end index, reinit end index and sum
        if (idxStart > idxEnd) {
          idxEnd = idxStart;
          sumValues = 0;
        }
        
        // advance end index until aggregation interval is reached
        while (time[idxEnd] - time[idxStart] < aggregationInt) {
          sumValues += value[idxEnd];
          idxEnd++;
        }
      }
      catch (ArrayIndexOutOfBoundsException e) {
        if (0 == idxStart)
          throw e;
        
        // if we reached the end of the time line, wrap around
        log.warn("wrapping around at curr time " + JistAPI.getTime());
        idxStart = 0;
        sumValues = 0;
        idxEnd = 0;
        return getAverage(currTime);
      }
      
      // calculate the average
      return sumValues / (double)(idxEnd-idxStart);
    }
  }

}
