package edu.uulm.scbayes.inference.incremental

import util.Random

import edu.uulm.scbayes.logic.{Signature, Constant, TruthAssignment}
import edu.uulm.scbayes.probabilities.{JointMarginals, DiscreteMarginals}

import edu.uulm.scbayes.mln._
import factorgraph._

import edu.uulm.scbayes.inference.bp.BeliefPropagation
import edu.uulm.scbayes.factorgraph.messages.DiscreteMessage
import edu.uulm.scbayes.inference.ConvergenceRunner

/**
 * Incremental MLN inference driven by expanding-frontier belief propagation ([[SteppingEFBP]]).
 *
 * The initial result grounds the MLN into an [[MLNFactorGraph]], builds an initial stepper state
 * with every variable of that graph passed as the initial set, and runs it through `convergence`.
 * An incremental step re-grounds the network with the added constants and combined evidence. It then
 * marks as active only the nodes that did not exist before, plus their direct neighbours. Those nodes
 * get fresh random messages, which are merged (`old ++ new`) into the previously computed messages
 * before running again.
 *
 * NOTE(review): `createIncrementalResult` prints a one-line node-count summary to stdout.
 *
 * @param messageThreshold Returns true if the two given messages differ enough to wake the adjacent node.
 * @param convergence Runner that drives `efbpStepper` via `runUntilConvergence`.
 */
class ExpandingFrontierBP(messageThreshold: (DiscreteMessage, DiscreteMessage) => Boolean,
                          val convergence: ConvergenceRunner[SteppingEFBP]) extends IncrementalMLNInferer {
  val efbpStepper = new SteppingEFBP(messageThreshold)

  /** Incremental result: (stepper state, MLN it was grounded from, evidence used, random source). */
  type TIncRes = (efbpStepper.TState[LogicNode, WeightedFormulaFactor], MarkovLogicNetwork, TruthAssignment, Random)

  /**
   * Combines, via [[JointMarginals]], the stepper's marginals for the result's BP state with
   * marginals derived from the result's evidence.
   *
   * @tparam V not used by the implementation.
   */
  def marginals[V](state: ExpandingFrontierBP#TIncRes): DiscreteMarginals[LogicNode] = {
    val beliefs = efbpStepper.marginals(state._1)
    val evidenceMarginals = new TruthAssignmentMarginals(state._3)
    new JointMarginals(beliefs, evidenceMarginals)
  }

  /**
   * Extends a previous result with new constants and evidence, then re-runs EFBP.
   *
   * @param oldState     result of a previous `createInitialResult` / `createIncrementalResult`
   * @param newConstants constants added to the MLN's signature
   * @param newEvidence  evidence combined with the previous one via `previous.orElse(new)`
   * @return the converged state together with the enlarged MLN, the combined evidence and the
   *         same random source
   */
  def createIncrementalResult(oldState: ExpandingFrontierBP#TIncRes,
                              newConstants: Set[Constant],
                              newEvidence: TruthAssignment): ExpandingFrontierBP#TIncRes = {
    val (previousBP, previousMLN, previousEvidence, rng) = oldState
    val (previousGraph, _, _, previousF2V, previousV2F) = previousBP

    // Re-ground the network over the enlarged signature and the combined evidence.
    // NOTE(review): the previous evidence is the `orElse` receiver; presumably it wins on conflicts — confirm.
    val evidence = previousEvidence.orElse(newEvidence)
    val MarkovLogicNetwork(formulas, previousSignature) = previousMLN
    val signature: Signature = previousSignature.addConstants(newConstants)
    val mln = MarkovLogicNetwork(formulas, signature)
    val graph: MLNFactorGraph = MLNFactorGraph.fromMLN(mln, evidence)

    // Nodes present in the re-grounded graph but absent from the previous one.
    val addedVariables = graph.variables.toSet -- previousGraph.variables.toSet
    val addedFactors = graph.factors.toSet -- previousGraph.factors.toSet

    // Active frontier: the added nodes together with their direct neighbours in the new graph.
    val frontierVariables = addedVariables ++ addedFactors.flatMap(graph.variablesOf)
    val frontierFactors = addedFactors ++ addedVariables.flatMap(graph.factorsOf)

    // Both message builders receive `generator`, which is built from `rng`.
    // Keep this order so the random stream is consumed the same way.
    val generator = BeliefPropagation.randomMessageGenerator(rng)
    val freshF2V = BeliefPropagation.createInitialFactorMessages(graph, generator, frontierFactors)
    val freshV2F = BeliefPropagation.createInitialVariableMessages(graph, generator, frontierVariables)

    println("->%d/%d - new %d/%d".format(frontierVariables.size, frontierFactors.size, addedVariables.size, addedFactors.size))

    // Messages for frontier nodes come from the fresh maps (right-hand side of `++`).
    val expanded = efbpStepper.incrementState(
      graph,
      frontierVariables,
      frontierFactors,
      previousF2V ++ freshF2V,
      previousV2F ++ freshV2F)

    (runBPUntilConvergence(expanded), mln, evidence, rng)
  }

  /**
   * Grounds `mln` under `evidence` and runs EFBP from a fresh state.
   *
   * @return the converged state together with the unchanged `mln`, `evidence` and `random`
   */
  def createInitialResult(mln: MarkovLogicNetwork,
                          evidence: TruthAssignment,
                          random: Random): ExpandingFrontierBP#TIncRes = {
    val groundGraph = MLNFactorGraph.fromMLN(mln, evidence)
    val converged = runBPUntilConvergence(
      efbpStepper.createInitialState(groundGraph, groundGraph.variables.toSet, random))
    (converged, mln, evidence, random)
  }

  /** Delegates to `convergence.runUntilConvergence` with this inferer's `efbpStepper`. */
  def runBPUntilConvergence(state: efbpStepper.TState[LogicNode, WeightedFormulaFactor]): efbpStepper.TState[LogicNode, WeightedFormulaFactor] =
    convergence.runUntilConvergence(efbpStepper, state)
}

/**
 * Message-difference predicate with the shape expected by `ExpandingFrontierBP`'s `messageThreshold`.
 * It is true when at least one pair of corresponding entries differs in absolute value by strictly
 * more than `lambda`.
 *
 * Entries are paired with `zip`. Assuming standard collection `zip` semantics for [[DiscreteMessage]],
 * only the common prefix is compared when the two messages differ in length.
 *
 * @param lambda tolerance; a difference exactly equal to it does not count
 */
class DifferenceMessageThreshold(val lambda: Double) extends ((DiscreteMessage, DiscreteMessage) => Boolean){
  def apply(dm1: DiscreteMessage, dm2: DiscreteMessage): Boolean =
    dm1.zip(dm2).exists { pair =>
      val gap = math.abs(pair._1 - pair._2)
      lambda < gap
    }
}