package test.sim.scenario.mac;

import java.util.ArrayList;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.log4j.Logger;

import jist.runtime.Main;
import jist.swans.Constants;
import jist.swans.mac.AbstractMac;
import jist.swans.mac.Mac802_11Message;
import jist.swans.mac.MacDcfMessage;
import jist.swans.misc.Event;
import jist.swans.misc.Util;
import jist.swans.net.AbstractNet;
import jist.swans.net.NetMessage;

public class MacTestHandler
{
  /** Logger shared by this handler and its nested helper classes. */
  public static final Logger   LOG      = Logger.getLogger(MacTestHandler.class
                                            .getName());

  /**
   * All backoffs observed so far; one {@link Backoff} entry is appended per
   * {@code AbstractMac.BackoffEvent}, in the order the events are handled.
   */
  public List                  backoffs = new ArrayList();

  /** mac msg id (Integer) --> mac msg ({@link MacPkt}), filled lazily */
  public Hashtable             macPkts  = new Hashtable();

  /** net msg id (Integer) --> net msg ({@link NetMsg}), filled lazily */
  public Hashtable             netMsgs  = new Hashtable();

  /**
   * Reference transmission times used by {@link #verifyTransmission(NetMsg)}.
   * While this is null, verification is skipped and reported as successful.
   */
  private MacTransmissionTimes txTimes  = null;

  /**
   * Add event handlers to certain events.
   * <p>
   * The following handlers are registered:
   * <ul>
   * <li>{@code AbstractNet.SendToMacEvent}: records sender node, send time and
   * size of the IP message in its {@link NetMsg}.</li>
   * <li>{@code AbstractNet.SendToMacFinishEvent}: records the end-of-send time
   * and, if {@code Main.ASSERT} is set, asserts
   * {@link #verifyTransmission(NetMsg)}.</li>
   * <li>{@code AbstractMac.SendEvent}: associates the MAC packet id with a
   * {@link NetMsg} (depending on the packet type) and records the send
   * parameters in the packet's {@link MacPkt}.</li>
   * <li>{@code AbstractMac.ReceiveEvent}: records the receive time in the
   * packet's {@link MacPkt}.</li>
   * <li>{@code AbstractMac.BackoffEvent}: appends a {@link Backoff} to
   * {@link #backoffs}.</li>
   * </ul>
   * Events whose payload is not of the expected type are only logged at info
   * level.
   */
  public void registerHandlers()
  {

    Event.addHandler(AbstractNet.SendToMacEvent.class, new Event.Handler() {
      public void handle(Event event)
      {
        AbstractNet.SendToMacEvent ev = (AbstractNet.SendToMacEvent) event;
        Object data = ev.getData();

        if (data instanceof NetMessage.Ip) {
          NetMessage.Ip msg = (NetMessage.Ip) data;
          NetMsg nm = getNetMsg(new Integer(msg.getId()));
          nm.sendNodeId = ev.nodeId;
          nm.netSentAt = ev.time;
          nm.size = msg.getSize();
        } else {
          logUnexpectedPayload("SendToMacEvent", "IP", data);
        }
      }
    });

    Event.addHandler(AbstractNet.SendToMacFinishEvent.class,
        new Event.Handler() {
          public void handle(Event event)
          {
            AbstractNet.SendToMacFinishEvent ev = (AbstractNet.SendToMacFinishEvent) event;
            Object data = ev.getData();

            if (data instanceof NetMessage.Ip) {
              NetMessage.Ip msg = (NetMessage.Ip) data;
              NetMsg nm = getNetMsg(new Integer(msg.getId()));
              nm.netEndSend = ev.time;

              /*
               * At the end of the transmission (on recpt. of the ACK in MAC)
               * verify that it worked correctly.
               */
              if (Main.ASSERT) Util.assertion(verifyTransmission(nm));

            } else {
              logUnexpectedPayload("SendToMacFinishEvent", "IP", data);
            }
          }
        });

    Event.addHandler(AbstractMac.SendEvent.class, new Event.Handler() {
      public void handle(Event event)
      {
        AbstractMac.SendEvent ev = (AbstractMac.SendEvent) event;
        Object data = ev.getData();

        if (data instanceof MacDcfMessage) {
          MacDcfMessage msg = (MacDcfMessage) data;
          Integer macPktId = new Integer(msg.getId());

          if (msg.getType() == MacDcfMessage.TYPE_DATA) {
            /*
             * add the id of this DATA pkt to the net msg with the contained id
             */
            Object body = ((MacDcfMessage.Data) msg).getBody();
            if (body instanceof NetMessage.Ip) {
              NetMessage.Ip ip = (NetMessage.Ip) body;
              getNetMsg(new Integer(ip.getId())).macPktIds.add(macPktId);
            } else {
              // a non-IP body cannot be mapped to a net msg; the MAC pkt
              // itself is still recorded below
              logUnexpectedPayload("SendEvent (DATA body)", "IP", body);
            }
          } else if (msg.getType() == MacDcfMessage.TYPE_ACK
              || msg.getType() == Mac802_11Message.TYPE_CTS) {
            /*
             * add the id of an ACK to the same net msg as the prev DATA pkt
             * add the id of an CTS to the same net msg as the prev RTS pkt
             */
            addMacReplyPktIdToNetMsg(macPktId);
          } else if (msg.getType() == Mac802_11Message.TYPE_RTS) {
            /*
             * add the id of an RTS to the currently sent net msg of this node
             */
            NetMsg current = getCurrentNetMsgForNode(ev.nodeId);
            if (current != null) {
              current.macPktIds.add(macPktId);
            } else {
              // the lookup returns null if no net msg of this node is known
              LOG.warn("SendEvent: no net msg recorded for node " + ev.nodeId
                  + ", cannot assign RTS #" + macPktId);
            }
          }

          MacPkt ps = getMacPkt(macPktId);
          ps.macSentAt = ev.time;
          ps.macTxBitrate = ev.phyBitRate;
          ps.macRcvdAt = -1; // -1 marks "not received (yet)", see MacPkt.toString()
          ps.macSize = msg.getSize();
          ps.duration = ev.duration;
          ps.retries += ev.retry ? 1 : 0;

        } else {
          logUnexpectedPayload("SendEvent", "MacDcfMessage", data);
        }
      }
    });

    Event.addHandler(AbstractMac.ReceiveEvent.class, new Event.Handler() {
      public void handle(Event event)
      {
        AbstractMac.ReceiveEvent ev = (AbstractMac.ReceiveEvent) event;
        Object data = ev.getData();

        if (data instanceof MacDcfMessage) {
          MacDcfMessage msg = (MacDcfMessage) data;
          MacPkt ps = getMacPkt(new Integer(msg.getId()));
          ps.macRcvdAt = ev.time;
        } else {
          logUnexpectedPayload("ReceiveEvent", "MacDcfMessage", data);
        }
      }
    });

    Event.addHandler(AbstractMac.BackoffEvent.class, new Event.Handler() {
      public void handle(Event event)
      {
        AbstractMac.BackoffEvent ev = (AbstractMac.BackoffEvent) event;

        Backoff bo = new Backoff(ev.nodeId, ev.waitedBo, ev.time);
        backoffs.add(bo);
      }
    });
  } // registerHandlers()

  /**
   * Logs at info level that an event carried a payload of an unexpected type.
   *
   * @param eventName Name of the event whose handler saw the payload
   * @param expected Short name of the payload type the handler expected
   * @param data The actual payload, may be null
   */
  private static void logUnexpectedPayload(String eventName, String expected,
      Object data)
  {
    if (!LOG.isInfoEnabled()) return;
    String actual = (data == null) ? "null" : String.valueOf(data.getClass());
    LOG.info(eventName + ": msg class is not " + expected + " but " + actual);
  }

  /**
   * Installs the reference transmission times that
   * {@link #verifyTransmission(NetMsg)} compares measured times against.
   * 
   * @param times Reference transmission times; passing null makes
   *          verification a no-op that always succeeds
   */
  public void setMacTxTimes(MacTransmissionTimes times)
  {
    this.txTimes = times;
  }

  /**
   * Verify the transmission times for the given netMsg. Call this after
   * transmission is completed.
   * <p>
   * The expected duration is the reference time for the bit rate of the first
   * recorded MAC pkt and the size of the net msg, plus the sum of all backoffs
   * of the sending node that lie within the transmission. It is compared to
   * the measured duration with a tolerance of
   * {@code Constants.MICRO_SECOND}.
   * 
   * @param netMsg Net message to verify
   * @return True if transmission times are correct in respect to the member
   *         {@link #txTimes} or if no reference times are set; false if the
   *         times do not match or no MAC pkt was recorded for the message.
   */
  public boolean verifyTransmission(NetMsg netMsg)
  {
    if (txTimes == null) {
      // Without reference times there is nothing to compare against; this is
      // deliberately treated as "passed" instead of an error.
      return true;
    }

    if (netMsg.macPktIds.isEmpty()) {
      // iterator().next() below would throw NoSuchElementException
      LOG.error("Cannot verify net msg #" + netMsg.id()
          + ": no MAC pkt was recorded for it");
      return false;
    }

    // MAC pkt with the lowest id recorded for this NET layer message
    // (macPktIds is a sorted set); its bit rate is used for the reference time
    MacPkt data = getMacPkt((Integer) netMsg.macPktIds.iterator().next());

    // Get the backoffs performed by the STA which sent this net msg
    List bos = getBackoffsDuringMsg(netMsg.sendNodeId, netMsg);

    // Calculate the total time spent in backoff during tx of this net msg
    long totalBackoff = 0;
    Iterator it = bos.iterator();
    while (it.hasNext())
      totalBackoff += ((Backoff) it.next()).length;

    if (LOG.isDebugEnabled())
      LOG.debug("Total backoff during tx of net msg #" + netMsg.id() + ": "
          + totalBackoff + " (slots:" + totalBackoff / Constants.SLOT_TIME_DSSS
          + ")");

    // Expected time = theoretical time without backoff + observed backoff
    long txTimeTheoNoBo = txTimes.get(data.macTxBitrate, netMsg.size);
    long expTxTime = txTimeTheoNoBo + totalBackoff;
    long txTime = netMsg.netEndSend - netMsg.netSentAt;
    if (LOG.isInfoEnabled())
      LOG.info("net msg: " + netMsg.id() + " - real " + txTime + " vs. exp. "
          + txTimeTheoNoBo + "+" + totalBackoff + " (bo)");

    /* do the check */
    if (netMsg.equalsTxTime(expTxTime, Constants.MICRO_SECOND)) {
      if (LOG.isDebugEnabled())
        LOG.debug("First tx SUCCEEDED for net msg #" + netMsg.id());
      return true;
    }

    // Measured and expected time differ by more than the tolerance
    LOG.error("First tx FAILED for net msg #" + netMsg.id() + " - real "
        + txTime + " vs. exp. " + expTxTime + ", rate:" + data.macTxBitrate
        + " size:" + netMsg.size);
    return false;
  }

  /**
   * Attaches a MAC reply packet (CTS or ACK) to the net msg it belongs to.
   * The owning net msg is identified by looking for the id directly preceding
   * the reply's id (macPktId - 1) in each net msg's MAC pkt id set; every net
   * msg containing that id receives the reply id as well.
   * 
   * @param macPktId The mac RESPONSE packet id to add (only CTS or ACK!!!)
   */
  private void addMacReplyPktIdToNetMsg(Integer macPktId)
  {
    // id of the packet this reply is assumed to answer
    final Integer precedingId = new Integer(macPktId.intValue() - 1);

    for (Iterator msgs = netMsgs.values().iterator(); msgs.hasNext();) {
      NetMsg candidate = (NetMsg) msgs.next();
      if (candidate.macPktIds.contains(precedingId)) {
        candidate.macPktIds.add(macPktId);
      }
    }
  }

  /**
   * Collects the backoffs of one node that lie strictly inside the
   * transmission window of a net msg, i.e. that start after the net layer
   * handed the message over and end before the net layer's end-of-send time.
   * 
   * @param nodeId Id of the node whose backoffs are of interest
   * @param netMsg Net message defining the time window
   * @return A list of backoffs that occured between sending the given net msg
   *         by the net layer and receiving the ack for that message
   */
  private List getBackoffsDuringMsg(int nodeId, NetMsg netMsg)
  {
    if (LOG.isDebugEnabled())
      LOG.debug("Backoff list has " + backoffs.size() + " elements");

    List matching = new ArrayList();
    for (Iterator all = backoffs.iterator(); all.hasNext();) {
      Backoff candidate = (Backoff) all.next();

      // skip backoffs of other nodes
      if (candidate.nodeId != nodeId) continue;
      // skip backoffs that started at or before the send time
      if (candidate.start() <= netMsg.netSentAt) continue;
      // skip backoffs that ended at or after the end-of-send time
      if (candidate.endTime >= netMsg.netEndSend) continue;

      matching.add(candidate);
    }
    return matching;
  }

  /**
   * Looks up the most recent net msg of a node by scanning the known net msgs
   * from the highest id downwards.
   * 
   * @param nodeId Id of the sending node
   * @return The last net msg (highest id) sent by the node with the given id,
   *         or null if no net msg of that node is known.
   */
  private NetMsg getCurrentNetMsgForNode(final int nodeId)
  {
    // sorted snapshot of the ids, so the scan can run in descending order
    Object[] idsAscending = new TreeSet(netMsgs.keySet()).toArray();

    for (int i = idsAscending.length - 1; i >= 0; i--) {
      Object key = idsAscending[i];
      NetMsg msg = (NetMsg) netMsgs.get(key);
      // guarded: the message includes NetMsg.toString(), which is not cheap
      if (LOG.isDebugEnabled())
        LOG.debug("LOOKING AT (key:" + key + "):\n\t" + msg);
      if (msg.sendNodeId == nodeId) return msg;
    }
    return null;
  }

  /**
   * Returns the record for a MAC packet, creating and registering an empty
   * one in {@link #macPkts} on first access.
   * 
   * @param macPktId Id of the MAC packet to return
   * @return The MAC pkt with the given id from the local list, never null
   */
  private MacPkt getMacPkt(Integer macPktId)
  {
    Object known = macPkts.get(macPktId);
    if (known != null) return (MacPkt) known;

    MacPkt created = new MacPkt(macPktId);
    macPkts.put(macPktId, created);
    return created;
  }

  /**
   * Returns the record for a NET message, creating and registering an empty
   * one in {@link #netMsgs} on first access.
   * 
   * @param netMsgId Id of the NET message to return
   * @return The NET message with the given id from the local list, never null
   */
  private NetMsg getNetMsg(Integer netMsgId)
  {
    Object known = netMsgs.get(netMsgId);
    if (known != null) return (NetMsg) known;

    NetMsg created = new NetMsg(netMsgId);
    netMsgs.put(netMsgId, created);
    return created;
  }

  ///////////////////////////////////////////////////////////
  // local classes
  //
  
  /**
   * A single backoff period of one node, described by its end time and its
   * length; the start time is derived from the two.
   */
  private static class Backoff
  {
    /** Time at which the backoff ended. */
    public long endTime;

    /** Duration of the backoff. */
    public long length;

    /** Id of the node that performed the backoff. */
    public int  nodeId;

    /**
     * @param nodeId Id of the node that performed the backoff
     * @param length Duration of the backoff
     * @param endTime Time at which the backoff ended
     */
    public Backoff(int nodeId, long length, long endTime)
    {
      this.endTime = endTime;
      this.length = length;
      this.nodeId = nodeId;
    }

    /** @return Time at which the backoff started (end time minus length). */
    public long start()
    {
      return endTime - length;
    }

    /** @return Node id, start, end and length in a single-line form. */
    public String toString()
    {
      return " node:" + nodeId + " bo_stt:" + start() + " bo_end:" + endTime
          + " bo_len:" + length;
    }
  }

  /**
   * Bookkeeping record for one MAC layer packet, keyed by its MAC message id.
   */
  private static class MacPkt
  {
    /** Duration value reported with the send event. */
    public long     duration;

    /** Time of reception; toString() treats values below 0 as "not received". */
    public long     macRcvdAt;

    /** Time at which the packet was sent. */
    public long     macSentAt;

    /** Size of the MAC message. */
    public int      macSize;

    /** Bit rate reported with the send event. */
    public int      macTxBitrate;

    /** Number of send events for this packet that were flagged as retry. */
    public short    retries = 0;

    private Integer id;

    /** @param macPktId Id of the MAC packet this record describes */
    public MacPkt(Integer macPktId)
    {
      this.id = macPktId;
    }

    /** @return Id of the MAC packet this record describes */
    public Integer id()
    {
      return this.id;
    }

    /**
     * @return All recorded values in a single-line form; the transmission
     *         time (m_tx) is only included once a receive time was recorded.
     */
    public String toString()
    {
      String txPart = "";
      if (macRcvdAt > -1) {
        txPart = " m_tx:" + (macRcvdAt - macSentAt);
      }
      return "m_id:" + id + " m_snt:" + macSentAt + " m_rcv:" + macRcvdAt
          + txPart + " m_sz:" + macSize + " rate:"
          + (macTxBitrate / Constants.BANDWIDTH_1Mbps) + " dur:" + duration
          + " retr:" + retries;
    }
  }

  /**
   * Bookkeeping record for one NET layer message, keyed by its net message
   * id, together with the ids of the MAC packets associated with it.
   */
  private static class NetMsg
  {
    /** Time at which the net layer finished sending the message. */
    public long     netEndSend;

    /** Time at which the net layer handed the message to the MAC. */
    public long     netSentAt;

    /** Id of the receiving node (only printed by toString() in this class). */
    public int      recvNodeId;

    /** Id of the sending node. */
    public int      sendNodeId;

    /** Size of the net message. */
    public int      size;

    private Integer id;

    /** Ids (Integer) of the associated MAC packets, kept in ascending order. */
    private Set     macPktIds = new TreeSet();

    /** @param netMsgId Id of the net message this record describes */
    public NetMsg(Integer netMsgId)
    {
      this.id = netMsgId;
    }

    /**
     * Compares the measured transmission time (netEndSend - netSentAt) with
     * an expected one.
     * 
     * @param timeToEqual Expected transmission time for this message (in ns)
     * @param tolerance Maximum difference between expected and actual tx time
     * @return True if the given time matches the actual transmission time with
     *         a tolerance of the given amount, false otherwise.
     */
    public boolean equalsTxTime(long timeToEqual, long tolerance)
    {
      final long measured = netEndSend - netSentAt;
      LOG.debug("net msg:" + id + " - real " + measured + " vs. exp. "
          + timeToEqual);
      final long deviation = Math.abs(timeToEqual - measured);
      return deviation <= tolerance;
    }

    /** @return Id of the net message this record describes */
    public Integer id()
    {
      return this.id;
    }

    /**
     * @return All recorded values in a single-line form, followed by the
     *         comma-separated list of associated MAC packet ids.
     */
    public String toString()
    {
      StringBuffer out = new StringBuffer("n_id:");
      out.append(id);
      out.append(" from:").append(sendNodeId);
      out.append(" to:").append(recvNodeId);
      out.append(" n_snt:").append(netSentAt);
      out.append(" n_end:").append(netEndSend);
      out.append(" n_tx:").append(netEndSend - netSentAt);
      out.append(" mac pkt ids: [");
      for (Iterator ids = macPktIds.iterator(); ids.hasNext();) {
        Integer pktId = (Integer) ids.next();
        out.append(pktId);
        // separator only between elements, never after the last one
        if (ids.hasNext()) out.append(",");
      }
      out.append("]");
      return out.toString();
    }
  }

}
