package edu.uulm.scbayes.mln.factorgraph

import edu.uulm.scbayes.logic._
import edu.uulm.scbayes.logic.cnf._
import edu.uulm.scbayes.factorgraph.DiscreteFactorGraph
import edu.uulm.scbayes.probabilities.{DiscreteVariable, DiscreteMarginals}
import edu.uulm.scbayes.mln.{MarkovLogicNetwork, MLNWeightedFormula}
import scalaz._
import Scalaz._

/**
 * A factor graph implementation that resembles a MLN: every factor is a weighted formula factor and the
 * variable set handed to the underlying [[DiscreteFactorGraph]] is the union of the variables of all factors.
 * Date: 3/14/11
 *
 * @param factors   the weighted formula factors this graph consists of
 * @param signature the signature that is passed on to every graph derived from this one
 */
class MLNFactorGraph( factors: Set[WeightedFormulaFactor],
                      val signature: Signature)
  extends DiscreteFactorGraph(factors.flatMap(_.variables), factors) {

  /** Replaces every factor that carries a negative weight by its `flip`; all other factors are kept as they are.
    * NOTE(review): presumably `flip` negates both formula and weight - confirm in WeightedFormulaFactor.
    *
    * @return a new graph over the same signature
    */
  def allFormulasPositive: MLNFactorGraph = {
    val (negativeFactors, positiveFactors) = factors.partition(_.oWeight.exists(_ < 0))
    // sanity check: `weight` and `oWeight` must agree on the sign of the selected factors
    assert(negativeFactors.forall(_.weight < 0), "wrongly partitioned")

    val positivizedFactors = negativeFactors.map(_.flip)
    new MLNFactorGraph(positiveFactors ++ positivizedFactors, signature)
  }

  /** @return a new graph over the same signature that only contains the factors accepted by `p`. */
  def filterFactors(p: WeightedFormulaFactor => Boolean): MLNFactorGraph =
    new MLNFactorGraph(factors.filter(p), signature)

  /** Constructs a new MLN that contains additional hints for the given atoms.
    * Only atoms that are variables of this graph are considered.
    *
    * For a predicate node with marginal p = p(a) a single weighted formula is added: (log(p/(1-p)), a) if p > 0.5
    * and (log((1-p)/p), !a) otherwise, so the weight of this hint is never negative.
    * For a function valued node one formula (log(p/(1-p)), a = value) is added for every value of its range, where
    * p is the marginal of that value.
    *
    * NOTE(review): a marginal of exactly 0 or 1 results in an infinite weight; a node that is neither a
    * PredicateNode nor a FunctionValuedNode results in a MatchError.
    *
    * @param atoms     the random variables to create hints for
    * @param marginals the marginals the hint weights are computed from
    * @return a new graph containing the factors of this graph plus the hint factors
    */
  def augment(atoms: Set[LogicNode], marginals: DiscreteMarginals[LogicNode]): MLNFactorGraph = {
    val augmentationFactors: Set[WeightedFormulaFactor] = (atoms.intersect(variables)).flatMap{
      case rv: PredicateNode => {
        val marginal = marginals.marginal(rv, rv.domain2Int(true))
        val atom = GroundPredicate(rv.atomBase.name, rv.atomBase.arguments)
        // always hint towards the more probable truth value
        val hint =
          if(marginal > 0.5)
            MLNWeightedFormula(atom, Some(math.log(marginal/(1-marginal))))
          else
            MLNWeightedFormula(Negation(atom), Some(math.log((1-marginal)/marginal)))
        Seq(WeightedFormulaFactor(hint, signature))
      }
      case rv: FunctionValuedNode => {
        for(value <- rv.getRange) yield {
          val marginal = marginals.marginal(rv, value)
          // the function value is appended as the last argument of the ground predicate
          WeightedFormulaFactor(
            MLNWeightedFormula(GroundPredicate(rv.atomBase.name,rv.atomBase.arguments :+ rv.int2Domain(value)), Some(math.log(marginal/(1-marginal)))),
            signature
          )
        }
      }
    }
    new MLNFactorGraph(factors ++ augmentationFactors, signature)
  }

  /** Applies `reduceFactor` with the given evidence to every factor and keeps the results it yields.
    *
    * @return a new graph over the same signature consisting of the reduced factors
    */
  def reduceWithEvidence(evidence: TruthAssignment): MLNFactorGraph = {
    val reducedFactors: Set[WeightedFormulaFactor] = factors.map(_.reduceFactor(evidence, signature)).flatten
    new MLNFactorGraph(reducedFactors, signature)
  }

  /** @return the parseable form of the signature followed by one line per factor under a `//formulas` header. */
  def toParseableString: String = {
    signature.toParseableString + "\n\n//formulas\n" + factors.map(_.toParseableString).mkString("\n")
  }

  /** Renders this graph as an undirected graph in the DOT language.
    * Variables become nodes `v<i>`, factors become box shaped nodes `f<i>` and every factor is connected
    * to each of its variables. The labels are the `toString` representations of the variables and factors.
    */
  def toDOT: String = {
    // Backslashes and double quotes must be escaped, otherwise they end or garble the quoted DOT label.
    def escapeLabel(label: AnyRef): String =
      s"$label".replace("\\", "\\\\").replace("\"", "\\\"")

    val varMap: Map[Int, LogicNode] = variables.zipWithIndex.map(_.swap).toMap
    val varMapRev = varMap.map(_.swap)
    val facMap: Map[Int, WeightedFormulaFactor] = factors.zipWithIndex.map(_.swap).toMap

    val facToVarEdges: Seq[(Int, Int)] = for{
      (fi,f) <- facMap.toSeq
      vi <- f.variables.map(varMapRev)
    } yield (fi, vi)

    val nodes = varMap.toSeq.map{case (i,v) => s"""v$i [label="${escapeLabel(v)}"];"""}.mkString("\n")
    val factorNodes = facMap.toSeq.map{case (i,f) => s"""f$i [label="${escapeLabel(f)}", shape=box];"""}.mkString("\n")
    val edges = facToVarEdges.map{case (fi,vi) => s"f$fi -- v$vi;"}.mkString("\n")

    s"""graph MLN {
      |$nodes
      |$factorNodes
      |$edges
      |}
    """.stripMargin
  }

  /** Renders this graph as a `MARKOV` problem: the number of variables, their domain sizes, the number of
    * factors, one scope line per factor (variable count followed by the variable indices) and finally, per
    * factor, the number of table entries followed by the exponentiated `logValues`.
    *
    * @return A tuple of a uai problem (first) and a mapping from uai variable indices to MLN random variables (second).
    */
  def toUAI: (String, Map[Int, LogicNode]) = {
    val varSeq = variables.toSeq
    val varMap: Map[LogicNode, Int] = varSeq.zipWithIndex.toMap
    val factorSeq = factors.toSeq
    val factorDomains: Seq[String] = factorSeq.map{ f =>
      val scope = f.variables.map(varMap).mkString(" ")
      s"${f.variables.size} $scope"
    }
    val factorValues: Seq[String] = factorSeq.map{ f =>
      // the factors store log values, the file contains plain potentials
      val table = f.logValues.map(math.exp).mkString(" ")
      s"${f.logValues.size}\n$table"
    }
    val uai =
      s"""MARKOV
        |${varSeq.size}
        |${varSeq.map(_.domainSize).mkString(" ")}
        |${factorSeq.size}
        |${factorDomains.mkString("\n")}
        |${factorValues.mkString("\n")}
      """.stripMargin
    (uai,varMap.map(_.swap))
  }
}

object MLNFactorGraph {

  /**
   * Creates a factor graph from a set of formulas by first running them through [[alchemySplit]]: the free
   * variables are grounded, each ground formula is converted to CNF and the CNF is split into its clauses,
   * distributing the weight amongst them. The resulting clauses are then passed to [[fromFormulas]].
   */
  def viaAlchemySplit(signature: Signature, formulas: Seq[MLNWeightedFormula]): MLNFactorGraph =
    fromFormulas(alchemySplit(formulas, signature), signature)

  /**
   * Creates a factor graph from the formulas and the signature of a MLN without splitting the conjunctions
   * into clauses.
   *
   * @param evidence the evidence the resulting graph is reduced with; empty by default
   */
  def fromMLN(mln: MarkovLogicNetwork,
              evidence: TruthAssignment = TruthAssignment.empty): MLNFactorGraph =
    fromFormulas(mln.formulas, mln.signature, evidence)

  /**
   * Creates a factor graph from a set of formulas without splitting the conjunctions into clauses. The formulas
   * are grounded by [[groundWithoutSplit]] first.
   *
   * @param evidence the evidence the resulting graph is reduced with; empty by default
   */
  def fromFormulas(formulas: Seq[MLNWeightedFormula],
                   signature: Signature,
                   evidence: TruthAssignment = TruthAssignment.empty): MLNFactorGraph =
    fromGroundFormulas(groundWithoutSplit(formulas, signature), signature, evidence)

  /**
   * Creates one factor per ground formula, builds the graph from them and reduces it with the evidence.
   * An assertion fails if two formulas result in equal factors, as one of them would be dropped silently.
   *
   * @param evidence the evidence the resulting graph is reduced with; empty by default
   */
  def fromGroundFormulas(groundFormulas: Seq[MLNWeightedFormula],
                         signature: Signature,
                         evidence: TruthAssignment = TruthAssignment.empty): MLNFactorGraph = {
    val factors: Seq[WeightedFormulaFactor] = groundFormulas.map(WeightedFormulaFactor(_, signature))

    val distinctFactors = factors.toSet

    assert(factors.size == distinctFactors.size, "we've just discarded a duplicate factor, that's probably not good")

    new MLNFactorGraph(distinctFactors, signature).reduceWithEvidence(evidence)
  }

  /**
   * Grounds a set of formulas as alchemy is doing it. The free variables of each formula are grounded first,
   * then the explicit quantifiers are grounded and the result is converted to CNF. Tautological disjunctions are
   * dropped and the weight of the formula is distributed equally among the remaining clauses, each of which
   * becomes a weighted formula of its own.
   */
  def alchemySplit(_formulas: Seq[MLNWeightedFormula], signature: Signature): Seq[MLNWeightedFormula] = {
    for {
      MLNWeightedFormula(f, w, constraints) <- _formulas
      //ground free variables (as if they were implicitly FORALL quantified)
      groundFormula <- f.groundifyFreeVariables(signature.constants,constraints)
      //ground explicit quantifiers; convert to CNF
      cnf = GroundCNF.formulaToCNF(groundFormula.groundifyQuantifiers(signature.constants))
      //drop tautologies
      validClauses = cnf.children.filter {
        case dj: Disjunction => !FormulaOps.isTautology(dj)
        case _ => true
      }
      //split into clauses, distribute weight equally among the clauses that were not dropped
      clause <- validClauses
    } yield MLNWeightedFormula(clause, w.map(_ / validClauses.size))
  }

  /**
   * Simply grounds the formulas without splitting them. Every ground formula is converted to CNF and keeps the
   * full weight of the formula it stems from; ground formulas whose CNF is a tautology are dropped.
   */
  def groundWithoutSplit(formulas: Seq[MLNWeightedFormula], signature: Signature): Seq[MLNWeightedFormula] = for {
    MLNWeightedFormula(f, w, constraints) <- formulas
    //ground free variables (as if they were implicitly FORALL quantified)
    groundFormula <- f.groundifyFreeVariables(signature.constants, constraints)
    //ground explicit quantifiers; convert to CNF
    cnf: ConjunctiveNormalForm = GroundCNF.formulaToCNF(groundFormula.groundifyQuantifiers(signature.constants)) if !cnf.isTautology
  } yield MLNWeightedFormula(cnf, w)

  /**
   * Monoid instance for factor graphs: appending combines the factor sets and the signatures of both graphs
   * via their own `|+|`; the neutral element is the graph without factors over the empty signature.
   */
  implicit val monoid: Monoid[MLNFactorGraph] = new Monoid[MLNFactorGraph]{
    def append(s1: MLNFactorGraph, s2: => MLNFactorGraph): MLNFactorGraph =
      new MLNFactorGraph(s1.factors |+| s2.factors, s1.signature |+| s2.signature)
    val zero: MLNFactorGraph = new MLNFactorGraph(Set(), Signature.empty)
  }
}

