package edu.uulm.scbayes.inference.incremental

import edu.uulm.scbayes.factorgraph.messages.DiscreteMessage
import util.Random
import edu.uulm.scbayes.factorgraph.{DiscreteFactor, DiscreteFactorGraph}
import edu.uulm.scbayes.inference.bp.BeliefPropagation
import edu.uulm.scbayes.probabilities.{DiscreteVariable, DiscreteMarginals}
import edu.uulm.scbayes.inference.SteppingInferer

/**
 * Stepwise EFBP inference over a discrete factor graph.
 *
 * The inference state is a tuple holding the graph, the currently active variable and factor nodes and
 * the most recent messages of both directions (see `TState`). Every step recomputes messages only from
 * active nodes; a node that is not active yet is woken as soon as `messageThreshold` accepts one of the
 * recomputed messages addressed to it.
 *
 * @param messageThreshold Predicate invoked as `messageThreshold(newMessage, referenceMessage)`. Returning
 *   true means that the two messages differ enough to wake the adjacent node.
 */
class SteppingEFBP(val messageThreshold: (DiscreteMessage, DiscreteMessage) => Boolean) extends SteppingInferer {

  /*
   * Abstraction over the two message directions of BP (variable -> factor and factor -> variable).
   * It allows `computeStep` to be written once and reused for both directions.
   */
  private trait FGDirection[TSrc,TDest]{
    /** All nodes of the sending kind. */
    def getSourceNodes: Set[TSrc]

    /** All nodes of the receiving kind. */
    def getDestinationNodes: Set[TDest]

    /** The receiving nodes adjacent to the sending node `n`. */
    def neighboursOfSrc(n: TSrc): Set[TDest]

    /** The sending nodes adjacent to the receiving node `n`. */
    def neighboursOfDest(n: TDest): Set[TSrc]

    /** Computes the messages leaving `nodes`, given the messages travelling in the opposite direction. */
    def computeMessages(incoming: Map[(TDest,TSrc),DiscreteMessage],
                        nodes: Set[TSrc]): Map[(TSrc,TDest),DiscreteMessage]
  }

  /** View of a factor graph in which factors send and variables receive. */
  private implicit def evFactor2Variable[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V,F]): FGDirection[F,V] =
    new FGDirection[F,V] {
      def getSourceNodes: Set[F] = graph.factors
      def getDestinationNodes: Set[V] = graph.variables
      def neighboursOfSrc(n: F): Set[V] = graph.variablesOf(n)
      def neighboursOfDest(n: V): Set[F] = graph.factorsOf(n)
      def computeMessages(incoming: Map[(V, F), DiscreteMessage],
                          nodes: Set[F]): Map[(F, V), DiscreteMessage] =
        BeliefPropagation.computeFactorMessages(graph, incoming, nodes)
    }

  /** View of a factor graph in which variables send and factors receive. */
  private implicit def evVariable2Factor[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V,F]): FGDirection[V,F] =
    new FGDirection[V,F] {
      def getSourceNodes: Set[V] = graph.variables
      def getDestinationNodes: Set[F] = graph.factors
      def neighboursOfSrc(n: V): Set[F] = graph.factorsOf(n)
      def neighboursOfDest(n: F): Set[V] = graph.variablesOf(n)
      def computeMessages(incoming: Map[(F, V), DiscreteMessage],
                          nodes: Set[V]): Map[(V, F), DiscreteMessage] =
        BeliefPropagation.computeVariableMessages(graph, incoming, nodes)
    }
  /* end of the direction abstraction */

  /**
   * Inference state, in this order: the graph, the active variables, the active factors,
   * the factor -> variable messages and the variable -> factor messages.
   */
  type TState[V <: DiscreteVariable, F <: DiscreteFactor[V]] = (DiscreteFactorGraph[V,F],Set[V],Set[F],Map[(F,V),DiscreteMessage],Map[(V,F),DiscreteMessage])

  /**
   * Computes the marginals of a state by handing its graph and its factor -> variable messages
   * to `BeliefPropagation.marginals`.
   */
  def marginals[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: SteppingEFBP#TState[V, F]): DiscreteMarginals[V] = {
    val (graph, _, _, factorToVariableMessages, _) = state
    BeliefPropagation.marginals(graph, factorToVariableMessages)
  }

  /**
   * Performs one step of EFBP: first the active variables send to the factors, then the (possibly
   * enlarged) set of active factors sends to the variables. Nodes woken along the way are added to the
   * active sets; the active sets never shrink within a step.
   *
   * @param state the state to advance
   * @return the successor state; recomputed messages override the previous ones, all others are kept
   */
  def advanceState[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: SteppingEFBP#TState[V, F]): SteppingEFBP#TState[V, F] = {
    val (graph, activeVars, activeFactors, f2vBefore, v2fBefore) = state

    // half-step 1: variables -> factors, measured against the previous v->f messages
    val (wokenFactors, freshV2f) =
      computeStep[V,F](graph, activeVars, activeFactors, f2vBefore, v2fBefore)

    val factorsNowActive: Set[F] = activeFactors ++ wokenFactors
    val v2fAfter: Map[(V, F), DiscreteMessage] = v2fBefore ++ freshV2f

    // half-step 2: factors -> variables; this already consumes the updated v->f messages
    val (wokenVariables, freshF2v) =
      computeStep[F,V](graph, factorsNowActive, activeVars, v2fAfter, f2vBefore)

    val variablesNowActive = activeVars ++ wokenVariables
    val f2vAfter = f2vBefore ++ freshF2v

    (graph, variablesNowActive, factorsNowActive, f2vAfter, v2fAfter)
  }

  /**
   * Creates the initial state, in which every variable and every factor of the graph is active and the
   * messages are those produced by `BeliefPropagation.createRandomMessages`.
   *
   * Note carried over from the original documentation (the helper is not visible here, so verify
   * against it): the f->v messages are random and the v->f messages are computed from those using one
   * step of the flooding protocol.
   *
   * @param graph the factor graph to run inference on
   * @param query not used by this implementation
   * @param random source of randomness forwarded to the message initialisation
   */
  def createInitialState[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F],
                                                                        query: Set[V],
                                                                        random: Random): SteppingEFBP#TState[V, F] = {
    val (variableToFactor, factorToVariable) = BeliefPropagation.createRandomMessages(graph, random)
    val everyVariable: Set[V] = graph.variables.toSet
    val everyFactor: Set[F] = graph.factors.toSet

    (graph, everyVariable, everyFactor, factorToVariable, variableToFactor)
  }

  /**
   * Packs the given parts into a state without modifying them.
   *
   * This method assumes that the supplied messages are converged.
   */
  def incrementState[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F],
                                                                    activeVarNodes: Set[V],
                                                                    activeFactorNodes: Set[F],
                                                                    initialFactorMessages: Map[(F,V),DiscreteMessage],
                                                                    initialVariableMessages: Map[(V,F),DiscreteMessage]): SteppingEFBP#TState[V, F] = {
    val packed: SteppingEFBP#TState[V, F] =
      (graph, activeVarNodes, activeFactorNodes, initialFactorMessages, initialVariableMessages)
    packed
  }

  /**
   * One EFBP half-step, written once for both directions (factor -> variable and variable -> factor).
   *
   * The half-step
   *  * recomputes the messages leaving the active source nodes,
   *  * determines which destination nodes that are not active yet get woken by those messages,
   *  * returns the woken nodes together with the recomputed messages addressed to destinations that
   *    are active or have just been woken.
   *
   * @tparam TSrc Type of the nodes that produce the messages that are to be computed.
   *     Either a variable or factor type.
   * @tparam TDest Type of the nodes that are sent the computed messages.
   *     The dual type to TSrc, e.g. variables if the sources are factors.
   * @param uniGraph direction-specific view of the factor graph
   * @param activeSources source nodes whose outgoing messages are recomputed
   * @param activeDestinations destination nodes that are already active
   * @param incomingMessages messages travelling from destination to source nodes
   * @param referenceMessages previous source -> destination messages the recomputed ones are compared to
   * @return the newly woken destination nodes and the recomputed messages that are to be kept
   */
  private def computeStep[TSrc,TDest](uniGraph: FGDirection[TSrc,TDest],
                                      activeSources: Set[TSrc],
                                      activeDestinations: Set[TDest],
                                      incomingMessages: Map[(TDest,TSrc),DiscreteMessage],
                                      referenceMessages: Map[(TSrc,TDest),DiscreteMessage]): (Set[TDest], Map[(TSrc,TDest),DiscreteMessage]) = {
    val recomputed = uniGraph.computeMessages(incomingMessages, activeSources)

    // only a not yet active neighbour of an active source can be woken in this half-step
    val sleepingNeighbours: Set[TDest] =
      activeSources.flatMap(src => uniGraph.neighboursOfSrc(src).filterNot(activeDestinations))

    // true if the recomputed message on src -> dest differs enough from its reference;
    // both lookups are plain map applications and thus expect the edge to be present
    def changedEnough(src: TSrc, dest: TDest): Boolean = {
      val edge = (src, dest)
      messageThreshold(recomputed(edge), referenceMessages(edge))
    }

    // a candidate is woken if at least one of its active neighbours sends a sufficiently changed message
    val woken = sleepingNeighbours.filter { dest =>
      uniGraph.neighboursOfDest(dest).intersect(activeSources).exists(src => changedEnough(src, dest))
    }

    // messages addressed to destinations that stay inactive are dropped
    val kept = recomputed.filter { case ((_, dest), _) => activeDestinations(dest) || woken(dest) }

    (woken, kept)
  }
}



