package edu.uulm.scbayes.inference

import junctiontree.JunctionTreeInference
import org.specs2._
import matcher.Matchers._

import edu.uulm.scbayes.probabilities.{DiscreteMarginals, DiscreteVariable}
import matcher.{MatchResult, Expectable, Matcher}
import edu.uulm.scbayes.mln.factorgraph.MLNFactorGraph
import edu.uulm.scbayes.logic.cnf.GroundCNF
import edu.uulm.scbayes.mln.parsing.{ParseHelpers, PropFormulaParser}
import edu.uulm.scbayes.logic.{TruthAssignment, Formula, Signature}

/**
 * Test support for specs2: small parsers for propositional fixtures, and custom matchers
 * that compare marginal distributions or check membership in a collection.
 */
object TestUtils {

  /** Parses a propositional formula and converts it to a ground CNF.
    *
    * @param s formula text in the syntax accepted by `PropFormulaParser.formula`
    * @return the ground CNF of the parsed formula, paired with the signature returned by the parser
    */
  def parsePropCNF(s: String): (GroundCNF, Signature) = {
    import PropFormulaParser._

    // NOTE(review): behaviour on a parse error is whatever `ParseHelpers.unpack` does — confirm there.
    val (parsedFormula, signature) = ParseHelpers.unpack(parseAll(formula(), s))
    // Only the formula is transformed; the signature is passed through unchanged.
    (GroundCNF.formulaToGroundCNF(parsedFormula), signature)
  }

  /** Parses a propositional truth assignment.
    *
    * @param s assignment text in the syntax accepted by `PropFormulaParser.assignment`
    * @return the parsed assignment after `applyClosedWorldToEverything` has been applied to it
    *         (presumably assigning a default to every atom not mentioned in `s` — confirm)
    */
  def parsePropAssignment(s: String): TruthAssignment = {
    import PropFormulaParser._

    val parseResult = ParseHelpers.unpack(parseAll(assignment(), s))
    parseResult._1.applyClosedWorldToEverything
  }

  /** Matcher that runs an inferer on `graph` and checks its marginals against the marginals
    * produced by `JunctionTreeInference.infer(graph)`. Every variable in `graph.variables` is
    * checked for every value in its range, each within `tolerance`.
    *
    * The reference marginals are computed once, when this matcher is constructed.
    */
  def inferAsJT(graph: MLNFactorGraph, tolerance: Double): Matcher[NInferer] = {
    beCloseMarginals(JunctionTreeInference.infer(graph), tolerance, graph.variables) ^^ ((inferer: NInferer) =>
      inferer.infer(graph) aka "as infered by %s".format(inferer)
      )
  }

  /** Compares a `DiscreteMarginals` against `reference` up to `tolerance`.
    *
    * The check covers every variable in `variables` and every value in that variable's range.
    * An empty `variables` set yields a matcher that always succeeds.
    *
    * @param reference marginals treated as the expected values
    * @param tolerance maximum allowed absolute difference per probability (see specs2 `beCloseTo`)
    * @param variables variables whose marginals are compared
    */
  def beCloseMarginals[V <: DiscreteVariable](reference: DiscreteMarginals[V],
                                              tolerance: Double,
                                              variables: Set[V]): Matcher[DiscreteMarginals[V]] =
    // toSeq: mapping the Set directly would build (and hash) a Set of matchers for no benefit.
    conjunctionOf(variables.toSeq.map(beCloseMarginalsOnVariable(reference, tolerance, _)))

  /** Compares the marginals of a single `variable`, for every value in its range, against
    * `reference` up to `tolerance`. If the range is empty, the matcher always succeeds.
    */
  def beCloseMarginalsOnVariable[V <: DiscreteVariable](reference: DiscreteMarginals[V],
                                                        tolerance: Double,
                                                        variable: V): Matcher[DiscreteMarginals[V]] =
    conjunctionOf(variable.getRange.toSeq.map(beCloseMarginalsOnVariableWithValue(reference, tolerance, variable, (_: Int))))

  /** Checks that `P(variable = value)` is within `tolerance` of the reference value.
    *
    * Failures are labelled `P(<variable>=<value name>)`.
    */
  def beCloseMarginalsOnVariableWithValue[V <: DiscreteVariable](reference: DiscreteMarginals[V],
                                                                 tolerance: Double,
                                                                 variable: V,
                                                                 value: Int): Matcher[DiscreteMarginals[V]] =
    beCloseTo(reference.marginal(variable, value), tolerance) ^^
      {(_: DiscreteMarginals[V]).marginal(variable, value) aka "P(%s=%s)".format(variable,variable.valueName(value))}

  /** Combines `matchers` with `and`.
    *
    * Unlike `reduce`, this does not throw on an empty sequence. It returns a matcher that
    * always succeeds instead.
    */
  private def conjunctionOf[T](matchers: Seq[Matcher[T]]): Matcher[T] =
    matchers.reduceOption(_ and _).getOrElse(trivialSuccess[T]("nothing to compare"))

  /** Matcher that succeeds for every value, reporting `reason` as its message. */
  private def trivialSuccess[T](reason: String): Matcher[T] = new Matcher[T] {
    def apply[S <: T](t: Expectable[S]): MatchResult[S] = result(true, reason, reason, t)
  }

  /** Matcher that succeeds when the tested value is an element of `xs`. */
  def beContainedIn[T](xs: Iterable[T]): Matcher[T] = ContainedInMatcher(xs)

  /** Membership matcher backing `beContainedIn`.
    *
    * Membership is tested via `xs.toSet.contains`, i.e. with `Set` (hash-based) equality.
    */
  case class ContainedInMatcher[T](xs: Iterable[T]) extends Matcher[T] {
    def apply[S <: T](t: Expectable[S]): MatchResult[S] = result(
      xs.toSet.contains(t.value),
      "%s is contained in %s".format(t.value, xs),
      "%s is not contained in %s".format(t.value, xs),
      t
    )
  }
}