package edu.uulm.scbayes.inference.sampling

import util.Random
import edu.uulm.scbayes.logic._
import edu.uulm.scbayes.logic.cnf._
import edu.uulm.scbayes.mln.factorgraph._
import sampling.{SampleSAT, SATSampler}
import edu.uulm.scbayes.probabilities.DiscreteVariable
import edu.uulm.scbayes.factorgraph.{DiscreteFactorGraph, DiscreteFactor}
import annotation.tailrec

/**
 * Implementation of the MC-SAT sampling algorithm for Markov logic networks (MLNs).
 *
 * Each step selects a random subset of the currently satisfied soft formulas, adds all hard
 * formulas, and asks the SAT sampler for a new assignment satisfying that selection.
 *
 * Date: 28.03.11
 *
 * @param satSampler sampler used to draw assignments that satisfy a CNF
 * @param useUnitPropagation flag handed through unchanged to every `satSampler` call
 * @param useSplitting if true, the required formulas are split into connected components which
 *                     are sampled separately and then merged; otherwise a single CNF is sampled
 */
class MutableMCSatInferer(val satSampler: SATSampler = new SampleSAT(0.5, 0.1, 100000),
                          val useUnitPropagation: Boolean = true,
                          val useSplitting: Boolean = false)
extends MutableMCMCStepper {
  /** Sampler state: (current assignment, soft factors, hard factors, MLN factor graph). */
  type SubRepr[V,F] = (AtomBaseAssignment, Set[WeightedFormulaFactor], Set[WeightedFormulaFactor], MLNFactorGraph)

  /**
   * Performs one MC-SAT step.
   *
   * The `graph` parameter is not consulted; the MLN graph stored inside `state` is used instead.
   * If the SAT sampler finds no assignment, the previous assignment is kept.
   *
   * @param state current sampler state
   * @param graph unused, see above
   * @param random source of randomness for formula selection and SAT sampling
   * @return the successor state; only the assignment component can differ from `state`
   */
  protected def stepStateMCMC[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: MutableMCSatInferer#SubRepr[V, F], graph: DiscreteFactorGraph[V, F], random: Random): MutableMCSatInferer#SubRepr[V, F] = {
    // The local is named mlnGraph so that it does not shadow the (unused) `graph` parameter.
    val (assignment, softFactors, hardFactors, mlnGraph) = state

    // MC-SAT first computes the set of formulas that must be satisfied by the next sample.
    val satisfiedFactors = softFactors.filter(f => f.formula.evaluate(assignment))
    // u * exp(w) > 1  <=>  u > exp(-w): a satisfied soft formula of weight w is kept with
    // probability 1 - exp(-w). All factors in satisfiedFactors are soft.
    val requiredSoftFactors = satisfiedFactors.filter {
      factor => (random.nextDouble() * math.exp(factor.weight)) > 1.0
    }

    val requiredFactors: Set[WeightedFormulaFactor] = requiredSoftFactors ++ hardFactors

    val sampleResult: Option[TruthAssignment] =
      if (useSplitting) {
        val cnfs: Seq[GroundCNF] = buildCNFs(requiredFactors, mlnGraph)
        sampleInterpretationSplit(mlnGraph, cnfs, random)
      } else {
        val cnf = buildCNF(requiredFactors)
        sampleInterpretation(mlnGraph, cnf, random)
      }

    // Best effort: stay at the current assignment when the sampler fails.
    val newAssignment: AtomBaseAssignment = sampleResult.getOrElse(assignment)

    (newAssignment, softFactors, hardFactors, mlnGraph)
  }

  /**
   * Reads the integer-encoded value of variable `v` from the assignment held in `state`.
   *
   * `v` is cast to `LogicNode`; the call fails if the assignment has no value for its atom base.
   */
  protected def valueOfVariable[V <: DiscreteVariable, F <: DiscreteFactor[V]](v: V, state: SubRepr[V, F]): Int = {
    val ln = v.asInstanceOf[LogicNode]
    ln.domain2Int(state._1.baseValue(ln.atomBase).get)
  }

  /**
   * Builds the initial sampler state by sampling an assignment that satisfies all hard formulas.
   *
   * `graph` is cast to `MLNFactorGraph`; `query` is not used here.
   * An assertion fails if the SAT sampler cannot find an assignment for the hard formulas.
   */
  protected def createInitialStateMCMC[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F], query: Set[V], random: Random): MutableMCSatInferer#SubRepr[V, F] = {
    //todo casting sucks
    // NOTE(review): allFormulasPositive presumably rewrites the graph so that no formula has a
    // negative weight -- confirm against MLNFactorGraph.
    val mlnGraph = graph.asInstanceOf[MLNFactorGraph].allFormulasPositive
    val (hardFactors, softFactors) =
      mlnGraph.factors.partition(factor => factor.asInstanceOf[WeightedFormulaFactor].isHard)

    //find an interpretation satisfying the hard clauses
    val hardCNF: GroundCNF = buildCNF(hardFactors.map(_.asInstanceOf[WeightedFormulaFactor]))
    val sat = sampleInterpretation(mlnGraph, hardCNF, random)

    assert(sat.isDefined, "could not find interpretation satisfying hard clauses")

    (sat.get, softFactors.toSet, hardFactors.toSet, mlnGraph).asInstanceOf[SubRepr[LogicNode,WeightedFormulaFactor]]
  }

  /**
   * Use the SATSampler to find a complete interpretation for the given clauses.
   *
   * @return the sampled assignment, or None if the sampler reports none
   */
  private def sampleInterpretation(graph: MLNFactorGraph, requiredClauses: GroundCNF, random: Random): Option[TruthAssignment] =
    satSampler.sampleComplete(requiredClauses, graph.signature, graph.variables.toSet.map((a: LogicNode) => a.atomBase), random, useUnitPropagation)

  /**
   * Samples every CNF separately, merges the partial assignments and completes the merged
   * assignment randomly over the atom bases of all graph variables.
   *
   * @return None as soon as one of the CNFs yields no assignment
   */
  private def sampleInterpretationSplit(graph: MLNFactorGraph, requiredClauses: Seq[GroundCNF], random: Random): Option[TruthAssignment] = {
    val partialAssignment = requiredClauses
      .map(satSampler.samplePartial(_, graph.signature, random, useUnitPropagation))
      .foldLeft(Some(TruthAssignment.empty): Option[TruthAssignment]){
        case (Some(merged), Some(next)) => Some(merged.orElse(next))
        // a single failed component makes the whole sample fail
        case _ => None
      }

    partialAssignment.map(SATSampler.completeRandomized(_, graph.signature, graph.variables.toSet.map((a: LogicNode) => a.atomBase), random))
  }

  /**
   * Construct a GroundCNF from a set of factors by simply appending the clauses.
   */
  def buildCNF(factors: Iterable[WeightedFormulaFactor]): GroundCNF = {
    // breakOut builds the Seq directly, so clauses are not funnelled through the builder of
    // `factors` (which would silently deduplicate them when `factors` is a Set).
    val clauses: Seq[GroundClause] = factors.flatMap(_.formula.clauses)(collection.breakOut)

    GroundCNF.fromClauses(clauses)
  }

  /**
   * Splits `factors` into the connected components induced by `graph.factorAdjacency`
   * (restricted to `factors`) and builds one GroundCNF per component.
   */
  private def buildCNFs(factors: Set[WeightedFormulaFactor], graph: MLNFactorGraph): Seq[GroundCNF] = {
    // Grows `acc` by its neighbours until a fixpoint is reached. The accumulator is extended
    // rather than replaced, so the seed is always part of the result and the size comparison is
    // a valid equality test, whether or not factorAdjacency contains the factor itself.
    @tailrec
    def findComponentOf(acc: Set[WeightedFormulaFactor]): Set[WeightedFormulaFactor] = {
      val neighbours: Set[WeightedFormulaFactor] = acc.flatMap(graph.factorAdjacency(_) intersect factors)
      val expanded: Set[WeightedFormulaFactor] = acc ++ neighbours
      if (expanded.size == acc.size) acc
      else findComponentOf(expanded)
    }

    // Repeatedly carves the component of an arbitrary remaining factor out of `remaining`.
    // Terminates because every component contains its seed, so `remaining` strictly shrinks.
    @tailrec
    def splitIntoComponents(remaining: Set[WeightedFormulaFactor],
                            acc: List[Set[WeightedFormulaFactor]]): List[Set[WeightedFormulaFactor]] = {
      if (remaining.isEmpty) acc
      else {
        val nextComponent = findComponentOf(Set(remaining.head))
        val rest: Set[WeightedFormulaFactor] = remaining -- nextComponent
        splitIntoComponents(rest, nextComponent :: acc)
      }
    }

    val splitFactors: List[Set[WeightedFormulaFactor]] = splitIntoComponents(factors, Nil)
    splitFactors.map(buildCNF(_))
  }
}