package brn.swans.route;

import jist.runtime.JistAPI;
import jist.swans.Constants;
import jist.swans.radio.RadioInterface;
import jist.swans.misc.Message;
import jist.swans.net.NetAddress;
import org.apache.log4j.Logger;

import brn.swans.route.metric.RouteMetricInterface;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Hashtable;
import java.util.List;

/**
 * Naive flooding helper for McExOR.
 *
 * <p>Builds {@link McFloodingMsg} packets that advertise this node's
 * neighbors (with their link metrics) and its home channel. It learns link
 * metrics and home channels from flooding packets received from other nodes.
 * Outgoing floods are annotated with a channel chosen round-robin from the
 * home channels observed among this node's (sufficiently good) neighbors.
 *
 * @author Zubow
 */
public class RouteMcExORFlood {

  /**
   * logger for BRP events.
   */
  private static final Logger log = Logger.getLogger(RouteMcExORFlood.class.getName());

  /** Base value used by {@link #deferForward()} (msecs). */
  protected static final int DEFER_FORWARD = 20;

  /**
   * Neighbors whose link metric (from {@code getLinkMetricOld}) is larger than
   * this value are ignored when estimating the channels used in the
   * neighborhood.
   */
  final static int NB_MIN_METRIC = 5000;

  protected NetAddress me;

  protected int period; //in msec

  protected int offset; //in msec

  protected int nextSeqNum;

  /** Round-robin index into {@link #channelsInNBhood} for the next flood. */
  protected int adsChIndex;

  protected RouteMetricInterface routeMetric;

  /**
   * Sorted list of channel numbers ({@link Integer}) used as home channels by
   * this node's neighbors.
   */
  protected List /* Integer */ channelsInNBhood;

  /**
   * my home channel
   */
  protected int homeChannel;

  /**
   * Home channel ({@link Integer}) of each node learned from received flooding
   * packets, keyed by the node's {@link NetAddress}.
   */
  protected Hashtable nodeHomeChannels;

  /**
   * Remaining flooding budget; decremented by {@link #calcNextSendPacket()}.
   * Defaults to 0, so {@link #setMaxNoFloodings(long)} must be called before
   * any flooding is scheduled.
   */
  private long maxnofloodings;

  //

  /**
   * Flooding packet
   */
  public static class McFloodingMsg implements Message {

    /**
     * Net address of node issuing the flooding message.
     */
    protected NetAddress src;

    /**
     * Sequence number of node issuing flooding message.
     */
    protected int seqNum;

    /**
     * period of this node's probe broadcasts, in msecs.
     */
    protected int period;

    /**
     * number of linkEntry entries following; must equal linkEntries.size()
     */
    protected int numLinks;

    protected List /** {@link LinkEntry} */
            linkEntries;

    /**
     * Inform all nodes about my home channel.
     */
    protected int homeChannel;

    /**
     * Channel used by this probe packet. Only annotation.
     */
    public int nextChannel;

    /**
     * Constructs new Flooding Message object with no link entries.
     *
     * @param src    net address of this node
     * @param period period of this node's flooding (msecs)
     * @param seqNum sequence number of this node
     */
    public McFloodingMsg(NetAddress src, int period, int seqNum) {
      this.src = src;
      this.period = period;
      this.seqNum = seqNum;
      this.linkEntries = new ArrayList();
    }

    /**
     * Constructs new Multi channel Flooding Message object.
     *
     * @param src         net address of this node
     * @param period      period of this node's flooding (msecs)
     * @param seqNum      sequence number of this node
     * @param homeChannel home channel of this node
     */
    public McFloodingMsg(NetAddress src, int period, int seqNum, int homeChannel) {
      this(src, period, seqNum);
      this.homeChannel = homeChannel;
    }

    public void setHomeChannel(int homeChannel) {
      this.homeChannel = homeChannel;
    }


    /**
     * Constructs a Multi channel Flooding Message as a copy of another one.
     *
     * <p>The link entries and their count are copied; the copy gets its own
     * list, so later changes to {@code flood} do not leak into it.
     *
     * @param flood       a non multi channel flooding packet
     * @param homeChannel my home channel
     */
    public McFloodingMsg(McFloodingMsg flood, int homeChannel) {
      this(flood.src, flood.period, flood.seqNum, homeChannel);
      this.linkEntries = new ArrayList(flood.linkEntries);
      // keep numLinks consistent with linkEntries (asserted in getSize())
      this.numLinks = flood.numLinks;
    }

    public void setNextChannel(int ch) {
      this.nextChannel = ch;
    }

    /**
     * Return size of packet.
     *
     * @return size of packet in bytes: fixed header, one (address, metric)
     *         pair per link entry, plus the home channel
     */
    public int getSize() {

      assert(numLinks == linkEntries.size());

      int byte_size = 4 /* IP src address */
              + /*Integer.SIZE*/32 / 8  /* seqNum */
              + /*Integer.SIZE*/32 / 8  /* period */
              + /*Integer.SIZE*/32 / 8;  /* numLinks */

      byte_size += numLinks * (4 /* IP src address */
              + /*Integer.SIZE*/32 / 8  /* metric */);

      boolean debug = log.isDebugEnabled();
      if (debug) {
        log.debug("FloodingPacket with size = " + byte_size);
      }

      byte_size += /*Integer.SIZE*/32 / 8; /* homeChannel */

      if (debug) {
        log.debug("McFloodingPacket with size = " + byte_size);
      }

      return byte_size;
    }

    public void getBytes(byte[] msg, int offset) {
      throw new RuntimeException("Not yet implemented.");
    }

    /**
     * Returns message ip field.
     *
     * @return message ip field
     */
    public NetAddress getSrc() {
      return src;
    }

    public int getSeqNum() {
      return seqNum;
    }

    public String toString() {
      StringBuffer str = new StringBuffer();
      str.append(", size ").append(getSize());

      str.append(", sender ").append(src);
      str.append(", seqNum ").append(seqNum);

      str.append(", period ").append(period);
      str.append(", numLinks ").append(numLinks);
      for (int i = 0; i < linkEntries.size(); i++)
        str.append(",, numLinksItem ").append(linkEntries.get(i));
      return str.toString();
    }

    public boolean equals(Object o) {
      if (this == o) return true;
      if (o == null || getClass() != o.getClass()) return false;

      final McFloodingMsg that = (McFloodingMsg) o;

      if (homeChannel != that.homeChannel) return false;
      if (nextChannel != that.nextChannel) return false;
      if (numLinks != that.numLinks) return false;
      if (period != that.period) return false;
      if (seqNum != that.seqNum) return false;
      if (linkEntries != null ? !linkEntries.equals(that.linkEntries) : that.linkEntries != null) return false;
      if (src != null ? !src.equals(that.src) : that.src != null) return false;

      return true;
    }

    public int hashCode() {
      int result;
      result = (src != null ? src.hashCode() : 0);
      result = 29 * result + seqNum;
      result = 29 * result + period;
      result = 29 * result + numLinks;
      result = 29 * result + (linkEntries != null ? linkEntries.hashCode() : 0);
      result = 29 * result + homeChannel;
      result = 29 * result + nextChannel;
      return result;
    }
  }

  /**
   * One advertised neighbor: its address and the link metric towards it.
   */
  protected static class LinkEntry {

    protected NetAddress ip;
    protected int metric;

    public LinkEntry(NetAddress ip, int metric) {
      this.ip = ip;
      this.metric = metric;
    }

    public String toString() {
      StringBuffer str = new StringBuffer();
      str.append("nb ").append(ip);
      str.append(" metric ").append(metric);

      return str.toString();
    }

    public boolean equals(Object o) {
      if (this == o) return true;
      if (o == null || getClass() != o.getClass()) return false;

      final LinkEntry linkEntry = (LinkEntry) o;

      if (metric != linkEntry.metric) return false;
      if (ip != null ? !ip.equals(linkEntry.ip) : linkEntry.ip != null) return false;

      return true;
    }

    public int hashCode() {
      int result;
      result = (ip != null ? ip.hashCode() : 0);
      result = 29 * result + metric;
      return result;
    }
  }

  /**
   * Create new Flooding Helper Class; the home channel is taken from
   * {@code routeMetric.getHomeChannel(me)}.
   */
  public RouteMcExORFlood(NetAddress me, int period, int offset,
      RouteMetricInterface routeMetric) {
    this(me, period, offset, routeMetric, routeMetric.getHomeChannel(me));
  }

  /**
   * Create new Flooding Helper Class.
   *
   * @param me          address of this node
   * @param period      flooding period (msecs)
   * @param offset      offset (msecs)
   * @param routeMetric link/neighbor information source
   * @param homeChannel this node's home channel
   */
  public RouteMcExORFlood(NetAddress me, int period, int offset,
      RouteMetricInterface routeMetric, RadioInterface.RFChannel homeChannel) {
    this.me = me;
    this.period = period;
    this.offset = offset;
    this.routeMetric = routeMetric;
    this.adsChIndex = 0;
    this.channelsInNBhood = new ArrayList();
    this.homeChannel = homeChannel.getChannel();
    this.nodeHomeChannels = new Hashtable();
  }

  public void setMaxNoFloodings(long maxnofloodings) {
    this.maxnofloodings = maxnofloodings;
  }

  public int getHomeChannel() {
    return homeChannel;
  }

  public void setHomeChannel(int homeChannel) {
    this.homeChannel = homeChannel;
  }

  /**
   * @return number of distinct neighbor home channels found so far by
   *         {@link #estimateChannelsInNBhood()}
   */
  public int getNumberOfChannels() {
    return channelsInNBhood.size();
  }

  /**
   * Returns the home channel of {@code node} as reported by the route metric
   * (not the values learned in {@link #nodeHomeChannels}).
   */
  public RadioInterface.RFChannel getNodeHomeChannel(NetAddress node) {
    return routeMetric.getHomeChannel(node);
  }

  /**
   * Adds the home channels of all sufficiently good neighbors to
   * {@link #channelsInNBhood} and sorts it.
   *
   * <p>Channels are stored as {@link Integer} channel numbers, which is what
   * {@link #initFlooding()} reads back. Entries are only ever added, never
   * removed.
   *
   * @return the (shared, mutable) list {@link #channelsInNBhood}
   */
  protected List estimateChannelsInNBhood() {
    // estimate number of channels used by my neighbors
    List neighbors = routeMetric.getNeighbors(me);

    for (int i = 0; i < neighbors.size(); i++) {
      NetAddress nb = (NetAddress) neighbors.get(i);

      // only clear neighbors are considered; skip too bad neighbors
      int nbMetric = routeMetric.getLinkMetricOld(me, nb);
      if (nbMetric > NB_MIN_METRIC) {
        continue;
      }

      RadioInterface.RFChannel nbHomeChannel = routeMetric.getHomeChannel(nb);
      if (nbHomeChannel == null) {
        // NOTE(review): assumes null means "home channel unknown" -- confirm
        continue;
      }

      // store the channel number, not the RFChannel object: initFlooding()
      // casts elements to Integer and Collections.sort needs Comparable
      Integer channel = new Integer(nbHomeChannel.getChannel());
      if (!channelsInNBhood.contains(channel)) {
        channelsInNBhood.add(channel);
      }
    }

    // TODO: HACK
    Collections.sort(channelsInNBhood);
    return channelsInNBhood;
  }

  /**
   * Calculates the time of the next flooding event.
   *
   * <p>Each call consumes one unit of the flooding budget set via
   * {@link #setMaxNoFloodings(long)}; once it is exhausted
   * {@link Integer#MAX_VALUE} is returned. Otherwise the result is
   * {@code p * 500 + jitter}, where {@code p = period / channels} and the
   * jitter is uniform in {@code [0, p * 500]}.
   *
   * @return the time of the next flooding event.
   */
  public long calcNextSendPacket() {

    if (--maxnofloodings <= 0) {
      log.warn(" max number of floodings reached.");
      return Integer.MAX_VALUE;
    }

    // estimate number of channels used by my neighbors
    estimateChannelsInNBhood();

    int channelsInNbhood = getNumberOfChannels();

    // avoid division by zero
    if (channelsInNbhood == 0) {
      channelsInNbhood = 1;
    }

    int p = period / channelsInNbhood; // period (msecs)
    int maxJitter = p * 500;

    return (p * 500) + getJitter(maxJitter);
  }

  /**
   * Initiates a flooding request.
   *
   * <p>Returns <code>null</code> if no neighbor channel is known yet, i.e.
   * {@link #channelsInNBhood} is empty. In that case no sequence number is
   * consumed. Otherwise the packet lists all current neighbors with their link
   * metrics and carries this node's home channel. It is annotated with the next
   * neighborhood channel in round-robin order.
   *
   * @return the flooding packet, or <code>null</code>
   */
  public McFloodingMsg initFlooding() {

    if (log.isDebugEnabled()) {
      log.debug(me + "(" + JistAPI.getTime() + "): initFlooding ");
    }

    // channel number to be used; none known -> node has not any neighbor
    if (channelsInNBhood.isEmpty()) {
      return null;
    }

    McFloodingMsg flood = new McFloodingMsg(me, period, getNextSeqNum(), homeChannel);

    // estimate my neighbors
    List neighbors = routeMetric.getNeighbors(me);

    for (int i = 0; i < neighbors.size(); i++) {
      NetAddress nb = (NetAddress) neighbors.get(i);
      int metric = routeMetric.getLinkMetricOld(me, nb);

      flood.linkEntries.add(new LinkEntry(nb, metric));
      flood.numLinks++;
    }

    // defensive: keep the round-robin index within the list bounds
    adsChIndex = adsChIndex % channelsInNBhood.size();
    int channel = ((Integer) channelsInNBhood.get(adsChIndex)).intValue();

    if (log.isDebugEnabled()) {
      log.debug(me + "(" + JistAPI.getTime() + "): sendProbe ... via ch " + channel);
    }

    // points to the next channel
    adsChIndex = (adsChIndex + 1) % channelsInNBhood.size();

    // annotate packet with the right channel
    flood.setNextChannel(channel);
    return flood;
  }

  /**
   * Handles an incoming flooding request: updates the link table with every
   * advertised link of the sender and remembers the sender's home channel.
   *
   * @param msg flooding message; must be a {@link McFloodingMsg}
   */
  public void handleMsg(Message msg) {

    McFloodingMsg flood = (McFloodingMsg) msg;

    // sender address (neighbor node)
    NetAddress src = flood.getSrc();

    assert (!src.equals(me));

    boolean debug = log.isDebugEnabled();
    if (debug) {
      log.debug(me + "(" + JistAPI.getTime() + "): learning from incoming flooding packet: "
              + src + "/" + flood.getSeqNum());
    }

    // peer node is using the same period, ...
    assert(flood.numLinks == flood.linkEntries.size());
    assert(flood.period == this.period);

    for (int i = 0; i < flood.linkEntries.size(); i++) {
      LinkEntry linkEntry = (LinkEntry) flood.linkEntries.get(i);

      // update link table
      updateLink(src, linkEntry.ip, linkEntry.metric);
    }

    if (debug) {
      log.debug(me + "(" + JistAPI.getTime() + "): learning from incoming mc flooding packet "
              + flood.src + " on home ch " + flood.homeChannel);
    }

    // remember the node's home channel
    nodeHomeChannels.put(flood.src, new Integer(flood.homeChannel));
  }

  /**
   * Updates both directions of the link {@code from}-{@code to} in the route
   * metric; logs an error if the route metric reports failure.
   */
  protected void updateLink(NetAddress from, NetAddress to, int metric) {

    if (log.isDebugEnabled()) {
      log.debug("update link called " + from + " - " + to + " (" + metric + ")");
    }

    /* update linktable */
    if (!routeMetric.updateBothLinks(from, to, metric, false)) {
      log.error("couldn't update link " + from + " " + metric + " " + to);
    }
  }

  /**
   * @return {@code DEFER_FORWARD * 1000} plus a jitter uniform in
   *         {@code [0, DEFER_FORWARD * 500]}
   */
  public int deferForward() {
    int maxJitter = DEFER_FORWARD * 500;

    return (DEFER_FORWARD * 1000) + getJitter(maxJitter);
  }

  private int getNextSeqNum() {
    return nextSeqNum++;
  }

  public int getPeriod() {
    return period;
  }

  public int getOffset() {
    return offset;
  }

  /**
   * Returns a uniformly distributed jitter in {@code [0, maxJitter]}.
   *
   * @param maxJitter inclusive upper bound; must be non-negative
   * @return random value drawn from the shared simulation RNG
   */
  public int getJitter(int maxJitter) {
    // nextInt(n) already yields [0, n); no further modulo needed
    return Constants.random.nextInt(maxJitter + 1);
  }

  /**
   * Constant hash code. It is consistent with the inherited identity
   * {@code equals}, but it degrades hash containers.
   * NOTE(review): the reason for this override is not visible here; confirm
   * before changing.
   */
  public int hashCode() {
    return 1;
  }
}
