package edu.uulm.scbayes.mln.factorgraph

import edu.uulm.scbayes.logic._
import edu.uulm.scbayes.logic.cnf._
import collection.immutable.BitSet
import edu.uulm.scbayes.mln.MLNWeightedFormula
import edu.uulm.scbayes.factorgraph.{IndexedFactor, DiscreteFactor}
import java.util.Locale
import edu.uulm.scbayes.util.CrossProductIndexer

/**
 * A factor node that is defined by a ground CNF formula and a (possibly infinite) weight.
 *
 * A soft factor contributes `weight` to the log-factor when its formula is satisfied and 0
 * otherwise. A hard factor (`isHard == true`) ignores `weight`: it contributes 0 when the
 * formula is satisfied and negative infinity when it is violated.
 *
 * Date: 3/14/11
 *
 * @param formula   the ground formula scored by this factor
 * @param weight    log-weight of a satisfied formula; ignored when `isHard` is true
 * @param isHard    whether the formula is a hard constraint (infinite weight)
 * @param variables the logic nodes this factor is defined over; their iteration order (see
 *                  [[assignmentVariables]]) fixes the meaning of each position in an assignment
 */
class WeightedFormulaFactor(
    val formula: GroundCNF,
    val weight: Double,
    val isHard: Boolean,
    override val variables: Set[LogicNode]
) extends DiscreteFactor[LogicNode] with IndexedFactor {

  /** The variables in a fixed order; position `i` of an assignment holds the value of element `i`. */
  val assignmentVariables: IndexedSeq[LogicNode] = variables.toIndexedSeq

  /** The weight as an option: `None` for hard factors, `Some(weight)` for soft ones. */
  def oWeight: Option[Double] = if (isHard) None else Some(weight)

  /** Hard factors render as `formula.`, soft factors as `weight formula` (weight in US locale). */
  override def toString: String =
    if (isHard) "%s.".format(formula)
    else "%g %s".formatLocal(Locale.US, weight, formula)

  /** Returns the same text as [[toString]]. */
  def toParseableString: String = this.toString

  /**
   * Memoizes the truth value of the formula for every joint assignment of
   * [[assignmentVariables]] in a BitSet, keyed by the assignment's cross-product index.
   * As an object it is only built on first access.
   */
  private object LookUp {
    /** Converts between assignments (one value per variable) and flat cross-product indices. */
    val cpi = new CrossProductIndexer(assignmentVariables.map(_.domainSize))

    /** Contains exactly the flat indices of the assignments that satisfy the formula. */
    val evalLookup: BitSet = {
      val builder = BitSet.newBuilder

      // evaluate every assignment once and record the satisfying ones by index
      (0 until cpi.size) foreach { idx =>
        val assign = cpi.index2Seq(idx)
        if (evaluateFormula(assign))
          builder += idx
      }

      builder.result()
    }

    /** Memoized truth value of the formula under the given assignment. */
    def lookUpEval(seq: IndexedSeq[Int]): Boolean = evalLookup(cpi.seq2Index(seq))
  }

  /** Log-factor contributed when the formula evaluates to `b`, taking hardness into account. */
  private def logFactorIfEvaluatesTo(b: Boolean): Double =
    if (b) {
      if (isHard) 0d else weight
    } else {
      if (isHard) Double.NegativeInfinity else 0d
    }

  /**
   * Evaluates the formula given an interpretation as (Int-) values for the variables, aligned
   * with [[assignmentVariables]]. Only `PredicateNode` and `FunctionValuedNode` variables are
   * translated into the truth assignment; any other node kind is ignored.
   */
  private def evaluateFormula(assign: IndexedSeq[Int]): Boolean = {
    val pairing = assignmentVariables.zip(assign)
    val gpMap: Map[PredicateAtomBase, Boolean] = pairing.collect {
      case (pn: PredicateNode, v) => pn.atomBase -> pn.int2Domain(v)
    }(collection.breakOut)

    val fMap: Map[FunctionalAtomBase, Constant] = pairing.collect {
      case (fn: FunctionValuedNode, v) => fn.atomBase -> fn.int2Domain(v)
    }(collection.breakOut)

    formula.evaluate(TruthAssignment(gpMap, fMap))
  }

  /** Log-factor for an assignment aligned with [[assignmentVariables]] (memoized). */
  override def logFactor(assign: IndexedSeq[Int]): Double = logFactorIfEvaluatesTo(LookUp.lookUpEval(assign))

  /** Log-factor for the assignment with the given flat cross-product index. */
  def logFactorByIndex(index: Int): Double = logFactorIfEvaluatesTo(LookUp.evalLookup(index))

  /** Flat cross-product index of an assignment aligned with [[assignmentVariables]]. */
  def assignmentToIndex(assign: IndexedSeq[Int]): Int = LookUp.cpi.seq2Index(assign)

  /** Inverse of [[assignmentToIndex]]. */
  def index2Assignment(index: Int): IndexedSeq[Int] = LookUp.cpi.index2Seq(index)

  /**
   * Conditions the formula on `evidence`.
   *
   * @return a factor over the reduced formula (same weight and hardness) when the reduced
   *         formula has no unconditional value; `None` when it does (presumably the evidence
   *         fully decides it), since such a factor is constant
   * @throws RuntimeException if this factor is hard and the reduced formula is unconditionally false
   */
  def reduceFactor(evidence: TruthAssignment, signature: Signature): Option[WeightedFormulaFactor] = {
    val reducedFormula = CNFOps.reduceWithEvidence(formula, evidence)

    reducedFormula.unconditionalValue match {
      case None =>
        Some(new WeightedFormulaFactor(
          reducedFormula,
          weight,
          isHard,
          WeightedFormulaFactor.logicNodesFromFormula(reducedFormula, signature)))
      case Some(false) if isHard =>
        throw new RuntimeException("hard formula (%s) is violated by evidence".format(formula))
      case Some(_) => None
    }
  }

  /**
   * A soft factor over the negated formula with negated weight. Since w·[f] = w − w·[¬f], the
   * result differs from this factor only by a constant log-offset. Not defined for hard factors.
   * Note that flipping a zero weight yields `-0.0`.
   */
  def flip: WeightedFormulaFactor = {
    assert(!this.isHard, "cannot flip a hard factor")
    new WeightedFormulaFactor(GroundCNF.formulaToGroundCNF(Negation(formula.toFormula)), -this.weight, this.isHard, this.variables)
  }

  /**
   * Two factors are equal when their formulas and hardness agree and, for soft factors, their
   * weights are `==`. The `eq` short-circuit keeps equality reflexive even for a NaN weight.
   */
  override def equals(p1: Any): Boolean = p1 match {
    case wff: WeightedFormulaFactor =>
      (this eq wff) ||
        (this.formula == wff.formula && this.isHard == wff.isHard && (this.isHard || this.weight == wff.weight))
    case _ => false
  }

  // `##` (not Double.hashCode) is consistent with numeric `==`: 0.0 and -0.0 compare equal in
  // `equals`, so they must hash alike. `flip` on a zero-weight factor produces -0.0.
  override def hashCode: Int = formula.hashCode ^ (if (isHard) 0 else weight.##)

  /** The log-factor of every assignment, in the iteration order of the cross-product indexer. */
  def logValues = LookUp.cpi.map(x => logFactor(x.toIndexedSeq))
}

object WeightedFormulaFactor {

  /**
   * Collects the logic nodes (via `LogicNode(base, signature)`) for the bases of all atoms
   * occurring in `f`.
   */
  def logicNodesFromFormula(f: GroundCNF, signature: Signature): Set[LogicNode] = {
    val atomBases = f.atoms.map(_.base)
    atomBases.map(LogicNode(_, signature))(collection.breakOut)
  }

  /**
   * Grounds a weighted MLN formula into a factor. A formula without a weight becomes a hard
   * factor (its stored weight is then 0 and unused); otherwise the weight is taken as is.
   */
  def apply(wf: MLNWeightedFormula, signature: Signature): WeightedFormulaFactor = {
    val groundFormula = GroundCNF.formulaToGroundCNF(wf.formula)
    val softWeight = wf.weight.getOrElse(0d)
    val hard = wf.weight.isEmpty
    val nodes = logicNodesFromFormula(groundFormula, signature)

    new WeightedFormulaFactor(
      formula = groundFormula,
      weight = softWeight,
      isHard = hard,
      variables = nodes
    )
  }
}
