package edu.uulm.scbayes.mln

import edu.uulm.scbayes.probabilities.DiscreteMarginals

import edu.uulm.scbayes.logic._
import TemporalSignature._

import factorgraph._
import scalaz.Monoid

/**
 * A MarkovLogicNetwork is a blueprint for instantiating a Markov network (or factor graph).
 *
 * Date: 20.05.11
 *
 * @param formulas The weighted formulas of the network.
 * @param signature The signature that provides the constants of each sort; it is used for grounding.
 */

case class MarkovLogicNetwork(formulas: Seq[MLNWeightedFormula], signature: Signature){

  /** Constructs a new MLN that contains additional hints for the given atoms.
    *
    * Each hint is an additional weighted ground formula `(w, a)` with `w = log(p / (1 - p))`, where `p` is the
    * marginal of `a` taken from `marginals`. A predicate atom gets one hint for its `true` state. A functional
    * atom `f(args)` gets one hint `f(args, v)` for every value `v` in the range of `f`.
    *
    * NOTE(review): a marginal of exactly 0 or 1 yields a weight of -Infinity or +Infinity. Confirm that
    * downstream code copes with infinite weights.
    *
    * @param atoms The atoms to add hints for.
    * @param marginals Marginals from which the hint weights are computed.
    * @return A copy of this MLN with the hint formulas appended to its formulas.
    */
  def augment(atoms: Seq[AtomBase], marginals: DiscreteMarginals[LogicNode]): MarkovLogicNetwork = {
    val hints = atoms.flatMap{
      case pab: PredicateAtomBase => {
        val rv: PredicateNode = PredicateNode(pab)
        val marginal = marginals.marginal(rv, rv.domain2Int(true))
        Seq(MLNWeightedFormula(GroundPredicate(pab.name, pab.arguments), Some(math.log(marginal / (1 - marginal)))))
      }
      case fab: FunctionalAtomBase => {
        val range = signature.constants(fab.name.targetSignature)
        val rv: FunctionValuedNode = FunctionValuedNode(fab, range)
        // one hint per possible value, expressed as the ground predicate f(args, value)
        for(value <- range) yield {
          val marginal = marginals.marginal(rv, rv.domain2Int(value))
          MLNWeightedFormula(GroundPredicate(fab.name, fab.arguments :+ value), Some(math.log(marginal / (1 - marginal))))
        }
      }
    }
    MarkovLogicNetwork(formulas ++ hints, signature)
  }

  /** Filters out formulas that contain time constants which are not inside the signature's time range,
    * i.e. which are not among the signature's constants of the time sort. */
  def withoutUninstantiatedTimeConstants: MarkovLogicNetwork = {
    val isInstantiated: Constant => Boolean = this.signature.constants(timeSort).contains _
    this.filterFormulas(_.formula.extractTerms.collect{case c: Constant if(c.sort.isTimeSort) => c}.forall(isInstantiated))
  }

  /** @return A copy of this MLN that keeps only the formulas satisfying `p`. */
  def filterFormulas(p: MLNWeightedFormula => Boolean): MarkovLogicNetwork = this.copy(formulas = formulas.filter(p))

  /** @return A copy of this MLN whose signature's time range is set from `begin` to `end`. */
  def timeShift(begin: Int, end: Int): MarkovLogicNetwork = this.copy(signature = signature.setTimeRange(begin, end))

  /** Grounds all formulas of this MLN over its signature into a factor graph, using `evidence`. */
  def toGraph(evidence: TruthAssignment = TruthAssignment.empty): MLNFactorGraph = MLNFactorGraph.fromFormulas(formulas, signature, evidence)

  /** The time indices referenced by time-sorted constants inside `formula`.
    * Time constant names are parsed with `toInt`, so they are expected to be integer literals. */
  private def usedTimeSteps(formula: MLNWeightedFormula): Set[Int] =
    formula.formula.extractTerms.collect{
      case Constant(name, sort) if sort == TemporalSignature.timeSort => name.toInt
    }.toSet

  /** Builds a ground MLNFactorGraph that contains the intra-slice formulas at `time` and the inter-slice formulas
    * between `time - 1` and `time`. Formulas mentioning no time step at all (static formulas) are excluded.
    */
  def buildSliceDynamic(time: Int, evidence: TruthAssignment = TruthAssignment.empty): MLNFactorGraph =
    buildSlice(time, formula => usedTimeSteps(formula).contains(time), evidence)

  /** Like `buildSliceDynamic`, but also keeps ground formulas that mention no time step (static formulas).
    *
    * @return A slice that also includes all static atoms. */
  def buildSliceWithStatic(time: Int, evidence: TruthAssignment = TruthAssignment.empty): MLNFactorGraph =
    buildSlice(time, { formula =>
      val times = usedTimeSteps(formula)
      times.contains(time) || times.isEmpty
    }, evidence)

  /** Produces a two-step [[edu.uulm.scbayes.mln.factorgraph.MLNFactorGraph]] that spans time steps `t - 1` and `t`.
   *
   * @param t The time index of the newer time step.
   * @param p Predicate function that decides whether a given formula shall be included in the result.
   *          All arguments are ground formulas.
   * @param evidence The observations to be used during grounding.
   *
   * @return A two-step factor graph. */
  def buildSlice( t: Int,
                  p: MLNWeightedFormula => Boolean,
                  evidence: TruthAssignment = TruthAssignment.empty): MLNFactorGraph = {
    // set the time range to t - 1 .. t and drop formulas referring to time constants outside of it
    val theMLN = this.timeShift(t - 1, t).withoutUninstantiatedTimeConstants
    val groundFormulas = MLNFactorGraph.groundWithoutSplit(theMLN.formulas, theMLN.signature)
    // keep only the ground formulas selected by `p` (e.g. dropping those that are intra-slice for t - 1)
    val selectedFormulas = groundFormulas.filter(p)
    MLNFactorGraph.fromFormulas(selectedFormulas, theMLN.signature, evidence)
  }

  /** Makes all formulas positively weighted by negating negative formulas.
    *
    * A formula with weight `w < 0` is replaced by its negation with weight `-w`. Formulas without a weight
    * (`None`) and formulas with a non-negative weight are kept. A `NaN` weight is neither negative nor
    * non-negative and is also kept unchanged.
    *
    * @return A MarkovLogicNetwork that contains no formulas with negative weights. */
  def makePositive: MarkovLogicNetwork =
    this.copy(formulas = formulas.map{
      case wf@MLNWeightedFormula(f, Some(w), _) if w < 0 => wf.copy(formula = Negation(f), weight = Some(-w))
      case wf => wf
    })

  /** Renders the signature followed by all formulas, one per line, using their `toParseableString` forms. */
  def toParseableString: String =
    signature.toParseableString + "\n\n//formulas\n" + formulas.map(_.toParseableString).mkString("\n")

  /** A lazily built sequence of two-step slices covering the time steps from `start` to `end`.
    *
    * The first slice spans `start` and `start + 1` and also contains static formulas (see `buildSliceWithStatic`).
    * Every later slice `t` spans `t - 1` and `t` (see `buildSliceDynamic`). Slices are rebuilt each time the
    * returned iterable is traversed.
    *
    * @throws IllegalArgumentException if `start >= end`.
    */
  def temporalSequence(start: Int,
                       end: Int,
                       evidence: TruthAssignment = TruthAssignment.empty): Iterable[MLNFactorGraph] = {
    require(start < end, "cannot build a slice of size 1 or 0")
    new Iterable[MLNFactorGraph]{
      def iterator = ((start+1) to end).iterator.map(t =>
        if(t==start+1)
          buildSliceWithStatic(t,evidence)
        else
          buildSliceDynamic(t,evidence))
    }
  }

  /** @return A copy of this MLN with the constants added to the signature.*/
  def addConstants(cs: Iterable[Constant]): MarkovLogicNetwork = copy(signature = signature.addConstants(cs))
}

object MarkovLogicNetwork {
  /** Monoid instance for MLNs.
    *
    * `append` concatenates the formulas (left operand first) and combines the signatures with `|+|`.
    * `zero` is the MLN without formulas over `Signature.empty`.
    */
  implicit val mlnAsMonoid: Monoid[MarkovLogicNetwork] = new Monoid[MarkovLogicNetwork]{
    import scalaz.Scalaz._
    def append(s1: MarkovLogicNetwork, s2: => MarkovLogicNetwork): MarkovLogicNetwork = {
      // `s2` is a by-name parameter: force it exactly once instead of once per field access
      val right = s2
      MarkovLogicNetwork(s1.formulas ++ right.formulas, s1.signature |+| right.signature)
    }
    val zero: MarkovLogicNetwork = MarkovLogicNetwork(Seq(),Signature.empty)
  }
}

