package edu.uulm.scbayes.inference.incremental

import util.Random

import edu.uulm.scbayes.logic._

import edu.uulm.scbayes.probabilities.{JointMarginals, DiscreteMarginals}

import edu.uulm.scbayes.mln._
import factorgraph._

import edu.uulm.scbayes.inference.NInferer

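/**
 * On each increment, ground the MLN extended by the new constants and cut out the slice of its
 * factor graph that contains the new nodes together with the nodes on their border.
 * The border nodes are augmented with factors that carry the evidence of the old MLN
 * (derived from the old marginals), and inference is run on this slice only.
 *
 * Date: 17.06.11
 */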

class NSlide2S(inferer: NInferer) extends IncrementalMLNInferer {
  /**
   * Marginals of an incremental result: the inferred marginals stored in the state (`state._3`)
   * joined with the marginals induced by the accumulated evidence (`state._4`).
   *
   * @tparam V not used by this implementation
   * @param state an incremental result produced by this inferer
   */
  def marginals[V](state: NSlide2S#TIncRes): DiscreteMarginals[LogicNode] =
    new JointMarginals(state._3, new TruthAssignmentMarginals(state._4))

  /**
   * Add some new information to the network.
   *
   * The MLN is grounded over the signature extended by `newConstants`, but inference is only
   * re-run on the factors that touch at least one new node. The old nodes adjacent to those
   * factors (the border) are augmented via `MLNFactorGraph.augment` using the old marginals.
   *
   * @param oldState     result of the previous step (initial or incremental)
   * @param newConstants constants introduced in this step; currently exactly one is supported
   * @param newEvidence  evidence for this step, combined as `oldEvidence.orElse(newEvidence)`
   * @return the extended MLN, all variable nodes of its ground factor graph, the freshly
   *         inferred marginals of the new nodes joined with the old marginals, and the
   *         combined evidence
   */
  def createIncrementalResult(oldState: NSlide2S#TIncRes,
                              newConstants: Set[Constant],
                              newEvidence: TruthAssignment): NSlide2S#TIncRes = {
    val (mln, nodes, oldMarginals, oldEvidence) = oldState
    val combinedEvidence = oldEvidence.orElse(newEvidence)
    val MarkovLogicNetwork(formulas, oldSignature) = mln

    //todo remove dirty hack: only single-constant increments are supported for now
    assert(newConstants.size == 1,
      "NSlide2S supports exactly one new constant per increment, got " + newConstants.size)

    val newSignature: Signature = oldSignature.addConstants(newConstants)
    val newMLN = MarkovLogicNetwork(formulas, newSignature)
    val newGraph: MLNFactorGraph = MLNFactorGraph.fromMLN(newMLN, combinedEvidence)

    // computed once: used both for the delta below and for the returned state
    val allNodes: Set[LogicNode] = newGraph.variables.toSet
    val newNodes: Set[LogicNode] = allNodes -- nodes

    // the old nodes that share a factor with at least one new node
    // (flatMap on a Set already yields a Set, so no extra conversions are needed)
    val adjacentFactors = newNodes.flatMap(newGraph.factorsOf)
    val borderNodes: Set[LogicNode] = adjacentFactors.flatMap(newGraph.variablesOf) -- newNodes

    // keep only the factors that have at least one variable among the new nodes
    val newFactors: Set[WeightedFormulaFactor] =
      newGraph.factors.filter(f => f.variables.exists(newNodes))

    // the slice consists of the new factors only; its variables are (intended to be)
    // the new nodes plus the border nodes
    val sliceGraph: MLNFactorGraph = new MLNFactorGraph(newFactors, newSignature)

    // attach the knowledge of the old network to the border nodes via their old marginals
    val augmentedGraph = sliceGraph.augment(borderNodes, oldMarginals)
    val newMarginals: DiscreteMarginals[LogicNode] = inferer.infer(augmentedGraph)

    // only the new nodes take freshly inferred marginals; they are joined with the old marginals
    val joinedMarginals = new JointMarginals(newMarginals.filter(newNodes), oldMarginals)

    (newMLN, allNodes, joinedMarginals, combinedEvidence)
  }

  /**
   * Run full inference on the ground factor graph of `mln` under `evidence`.
   *
   * @param mln      the network to ground
   * @param evidence the observed truth assignment
   * @param random   not used by this implementation
   * @return the network, the variable nodes of its ground factor graph, their inferred
   *         marginals and the evidence
   */
  def createInitialResult(mln: MarkovLogicNetwork,
                          evidence: TruthAssignment,
                          random: Random = null): NSlide2S#TIncRes = {
    val graph = MLNFactorGraph.fromMLN(mln, evidence)
    val variableNodes = graph.variables.toSet
    val marginals = inferer.infer(graph)

    (mln, variableNodes, marginals, evidence)
  }

  /**
   * Incremental state: (network, variable nodes of its ground factor graph,
   * marginals for those nodes, accumulated evidence).
   */
  type TIncRes = (MarkovLogicNetwork, Set[LogicNode], DiscreteMarginals[LogicNode], TruthAssignment)
}


