package click.swans.net;

import java.util.HashMap;
import java.util.Map;

import org.apache.log4j.Logger;

import jist.runtime.JistAPI;
import jist.swans.Constants;
import jist.swans.misc.MessageAnno;
import jist.swans.net.NetAddress;
import jist.swans.net.NetMessage;
import jist.swans.trans.TransTcp.TcpMessage;
import jist.swans.trans.TransUdp.UdpMessage;

/**
 * Preserves flow/packet ids and tx times when packet is tunneled through click.
 * The idea is to store the information using ip layer keys (src,dst,id),
 * hoping that click will not change them!
 *
 * TODO integrate with netcoding:
 * methods for tracing packets through click
 * it works as follows:
 *
 * 1. The Trans application registers a flow between a source and destination
 * Unfortunately there can only be one flow per pair of nodes, but nothing will
 * crash if there are more registrations. It will only look a bit strange later.
 *
 * 2. Trans application sends packets
 *
 * 3. ClickRouter determines the flow id from source and destination and
 * registers a packet whenever it gets to send one.
 *
 * 4. When the packet arrives at some click node, the MAC addresses of src, dst,
 * next and last are recorded by TraceCollector (in package netcoding) and later
 * fetched by ClickRouteTracer via read handlers.
 *
 * 5. ClickRouteTracer knows all ClickRouter instances and can derive the Net
 * Addresses from the MAC addresses and thus can in turn resolve the flow ids and
 * produce "forward", "dupe", "discard", whatever events.
 *
 * 6. When the packet is received by the Trans application on the other side, the
 * tx time and thus the latency can be determined.
 *
 * @author kurth
 */
public class ClickFlowStore {

  /** Log4j logger named after this class. */
  public static final Logger log = Logger.getLogger(ClickFlowStore.class.getName());

  /**
   * Map key describing one transport flow: source address and port,
   * destination address and port, and the ip protocol number. Two entries are
   * equal when all five components are equal; addresses may be null.
   */
  protected static class FlowEntry {
    /** ip packet source address. */
    public NetAddress src;
    /** transport layer source port. */
    public int srcPort;
    /** ip packet destination address. */
    public NetAddress dst;
    /** transport layer destination port. */
    public int dstPort;
    /** ip packet protocol, such as TCP, UDP, etc. */
    public short protocol;

    /**
     * Creates a flow key from its five components.
     *
     * @param src ip source address (may be null)
     * @param srcPort transport layer source port
     * @param dst ip destination address (may be null)
     * @param dstPort transport layer destination port
     * @param protocol ip protocol number
     */
    public FlowEntry(NetAddress src, int srcPort, NetAddress dst, int dstPort,
        short protocol) {
      this.src = src;
      this.dst = dst;
      this.srcPort = srcPort;
      this.dstPort = dstPort;
      this.protocol = protocol;
    }

    /** Hash of an address; a missing (null) address contributes 0. */
    private static int addressHash(NetAddress address) {
      return (address != null) ? address.hashCode() : 0;
    }

    /** Null-tolerant equality: two nulls match, null never matches non-null. */
    private static boolean sameAddress(NetAddress left, NetAddress right) {
      return (left != null) ? left.equals(right) : right == null;
    }

    /**
     * Combines all five components with the usual 31-multiplier scheme, in the
     * order dst, dstPort, protocol, src, srcPort.
     *
     * @see java.lang.Object#hashCode()
     */
    @Override
    public int hashCode() {
      int hash = 31 + addressHash(dst);
      hash = 31 * hash + dstPort;
      hash = 31 * hash + protocol;
      hash = 31 * hash + addressHash(src);
      return 31 * hash + srcPort;
    }

    /**
     * Equal only to another object of exactly the same class whose addresses,
     * ports and protocol all match.
     *
     * @see java.lang.Object#equals(java.lang.Object)
     */
    @Override
    public boolean equals(Object obj) {
      if (obj == this) {
        return true;
      }
      if (obj == null || obj.getClass() != getClass()) {
        return false;
      }
      final FlowEntry that = (FlowEntry) obj;
      return dstPort == that.dstPort
          && srcPort == that.srcPort
          && protocol == that.protocol
          && sameAddress(dst, that.dst)
          && sameAddress(src, that.src);
    }
  }

  /**
   * Information remembered for one packet: its ip identification together
   * with the flow id, packet id and tx time. Equality and hash code are
   * based on the ip identification only.
   */
  protected static class PacketEntry {
    /** ip packet identification. */
    public short id;

    /** the flow id */
    public Integer flowId;
    /** the packet id */
    public Integer packetId;
    /** the tx time */
    public Long txTime;

    /**
     * Creates a packet record.
     *
     * @param id ip packet identification
     * @param flowId the flow id (may be null)
     * @param packetId the packet id (may be null)
     * @param txTime the tx time (may be null)
     */
    public PacketEntry(short id, Integer flowId, Integer packetId, Long txTime) {
      this.id = id;
      this.txTime = txTime;
      this.packetId = packetId;
      this.flowId = flowId;
    }

    /**
     * Hash derived from the ip identification alone.
     *
     * @see java.lang.Object#hashCode()
     */
    @Override
    public int hashCode() {
      return 31 + id;
    }

    /**
     * Equal only to another object of exactly the same class carrying the same
     * ip identification; the remaining fields are not compared.
     *
     * @see java.lang.Object#equals(java.lang.Object)
     */
    @Override
    public boolean equals(Object obj) {
      if (obj == this) {
        return true;
      }
      if (obj == null || obj.getClass() != getClass()) {
        return false;
      }
      return ((PacketEntry) obj).id == id;
    }

  }

  /**
   * Store shared by all instances: flow key -> (ip identification -> packet
   * information). It is created once at class initialization, so it is never
   * null before the first instance exists and is not reset when further
   * instances are constructed.
   */
  protected static Map<FlowEntry, Map<Short, PacketEntry>> flowMap =
      new HashMap<FlowEntry, Map<Short, PacketEntry>>();

  /**
   * Creates a store handle. All handles operate on the same static map, so
   * constructing one must not discard entries recorded through another.
   */
  public ClickFlowStore() {
    // The map is static: unconditionally re-creating it here would wipe every
    // packet registered so far. Only create it if it has been cleared to null.
    if (null == flowMap) {
      flowMap = new HashMap<FlowEntry, Map<Short, PacketEntry>>();
    }
  }

  /**
   * Extracts the transport layer source port from the payload of an ip
   * message.
   *
   * @param msg ip message; its payload is cast to a UDP or TCP message
   *     according to the ip protocol field
   * @return the source port reported by the payload
   * @throws IllegalArgumentException if the protocol is neither UDP nor TCP
   */
  protected int getSrcPort(NetMessage.Ip msg) {
    final short protocol = msg.getProtocol();
    if (Constants.NET_PROTOCOL_UDP == protocol) {
      return ((UdpMessage) msg.getPayload()).getSrcPort();
    }
    if (Constants.NET_PROTOCOL_TCP == protocol) {
      return ((TcpMessage) msg.getPayload()).getSrcPort();
    }
    // Still a RuntimeException for existing callers, but now says what went wrong.
    throw new IllegalArgumentException(
        "Unable to determine source port: unsupported protocol " + protocol);
  }

  /**
   * Extracts the transport layer destination port from the payload of an ip
   * message.
   *
   * @param msg ip message; its payload is cast to a UDP or TCP message
   *     according to the ip protocol field
   * @return the destination port reported by the payload
   * @throws IllegalArgumentException if the protocol is neither UDP nor TCP
   */
  protected int getDstPort(NetMessage.Ip msg) {
    final short protocol = msg.getProtocol();
    if (Constants.NET_PROTOCOL_UDP == protocol) {
      return ((UdpMessage) msg.getPayload()).getDstPort();
    }
    if (Constants.NET_PROTOCOL_TCP == protocol) {
      return ((TcpMessage) msg.getPayload()).getDstPort();
    }
    // Still a RuntimeException for existing callers, but now says what went wrong.
    throw new IllegalArgumentException(
        "Unable to determine destination port: unsupported protocol " + protocol);
  }

  /**
   * Remembers the flow id, packet id and tx time found in the annotations,
   * keyed by the flow (addresses, ports, protocol) and the ip identification
   * of the message. Does nothing if either argument is null; messages that
   * are neither UDP nor TCP are reported via the logger and not stored.
   *
   * @param msg the message to send
   * @param anno annotation containing flow/packet id, tx time
   */
  public void put(NetMessage.Ip msg, MessageAnno anno) {
    if (null == msg || null == anno) {
      return;
    }

    final short protocol = msg.getProtocol();
    final boolean supported = Constants.NET_PROTOCOL_UDP == protocol
        || Constants.NET_PROTOCOL_TCP == protocol;
    if (!supported) {
      log.error(this + "(" + JistAPI.getTime() + "): Unable to handle protocol "+
          protocol);
      return;
    }

    // Values to remember; any of them may be absent (null) in the annotations.
    final Integer flow = (Integer) anno.get(MessageAnno.ANNO_RTG_FLOWID);
    final Integer packet = (Integer) anno.get(MessageAnno.ANNO_RTG_PACKETID);
    final Long sentAt = (Long) anno.get(MessageAnno.ANNO_TRANS_TXTIME);

    final FlowEntry key = new FlowEntry(msg.getSrc(), getSrcPort(msg),
        msg.getDst(), getDstPort(msg), protocol);

    // Create the per-flow table on first use.
    Map<Short, PacketEntry> perFlow = flowMap.get(key);
    if (perFlow == null) {
      perFlow = new HashMap<Short, PacketEntry>();
      flowMap.put(key, perFlow);
    }

    // A later packet with the same ip identification replaces the earlier one.
    final short ipId = msg.getId();
    perFlow.put(Short.valueOf(ipId), new PacketEntry(ipId, flow, packet, sentAt));
  }

  /** Search the store for the given packet and store information in annos.
   *
   * Does nothing if either argument is null, if the message is neither UDP
   * nor TCP, or if no entry is found for the message's flow and ip
   * identification.
   *
   * @param msg the message to lookup (coming from click)
   * @param anno the target annotations.
   */
  public void get(NetMessage.Ip msg, MessageAnno anno) {
    if (null == anno || null == msg)
      return;

    // put() never stores anything for other protocols, so treat them as a
    // plain miss instead of letting the port extraction throw.
    final short protocol = msg.getProtocol();
    if (Constants.NET_PROTOCOL_UDP != protocol
        && Constants.NET_PROTOCOL_TCP != protocol)
      return;

    FlowEntry fentry = new FlowEntry(msg.getSrc(), getSrcPort(msg),
        msg.getDst(), getDstPort(msg), protocol);
    Map<Short, PacketEntry> packetMap = flowMap.get(fentry);
    if (null == packetMap)
      return;

    PacketEntry pentry = packetMap.get(msg.getId());
    if (null == pentry)
      return;

    // Copy the remembered values back into the caller's annotations.
    anno.put(MessageAnno.ANNO_RTG_FLOWID, pentry.flowId);
    anno.put(MessageAnno.ANNO_RTG_PACKETID, pentry.packetId);
    anno.put(MessageAnno.ANNO_TRANS_TXTIME, pentry.txTime);
  }

//  /** highest flow id in use */
//  protected static int lastFlowId = 0;
//
//  /** source->destination->flowId */
//  protected static Map flowIds = new HashMap();
//
//  /** flowId->packetId->sendTime */
//  protected static Map sendTimes = new HashMap();
//
//  static {
//    sendTimes.put(new Integer(-1), new HashMap());
//  }
//  /**
//   *
//   * @param src
//   * @param dst
//   * @return
//   */
//  public static int determineFlowId(NetAddress src, NetAddress dst) {
//    HashMap dstToFlow = (HashMap) flowIds.get(src);
//    if (dstToFlow == null) {
//      dstToFlow = new HashMap();
//      flowIds.put(src, dstToFlow);
//    }
//
//    Integer id = (Integer) dstToFlow.get(dst);
//    if (id == null) {
//      return -1;
//    } else {
//      return id.intValue();
//    }
//  }
//
//  public static void registerFlow(NetAddress src, NetAddress dst,
//      int flowId) {
//    HashMap dstToFlow = (HashMap) flowIds.get(src);
//    if (dstToFlow == null) {
//      dstToFlow = new HashMap();
//      flowIds.put(src, dstToFlow);
//    }
//    dstToFlow.put(dst, Integer.valueOf(flowId));
//    HashMap idToSend = (HashMap) sendTimes.get(flowId);
//    if (idToSend == null) {
//      idToSend = new HashMap();
//      sendTimes.put(flowId, idToSend);
//    }
//    if (flowId >= lastFlowId)
//      lastFlowId = flowId;
//  }
//
//  public static void registerPacket(int in_flow, short nextId) {
//    Integer flowId = Integer.valueOf(in_flow);
//    Short ipId = Short.valueOf(nextId);
//    HashMap idToSend = (HashMap) sendTimes.get(flowId);
//    idToSend.put(ipId, Long.valueOf(JistAPI.getTime()));
//  }
//
//  protected static long determineTxTime(int flowId, short packetId) {
//    long now = JistAPI.getTime();
//    HashMap packetToSend = (HashMap) sendTimes.get(Integer
//        .valueOf(flowId));
//    if (packetToSend == null) {
//      return now;
//    } else {
//      Long sendTime = (Long) packetToSend
//          .get(Short.valueOf(packetId));
//      if (sendTime != null)
//        return sendTime.longValue();
//      else {
//        //log.warn("no sendTime found");
//        return now;
//      }
//    }
//  }

}
