package edu.uulm.scbayes.inference.exact

import org.specs2.mutable._
import org.specs2.matcher.DataTables

import edu.uulm.scbayes.experiments.examples.ExampleGraphs
import edu.uulm.scbayes.mln.factorgraph._
import edu.uulm.scbayes.logic.AtomBase

/**
 * Unit tests for the exact counting inference engine.
 *
 * Date: 29.03.11
 */

class CountingInfererTest extends Specification with DataTables {

  "smoker example without evidence must be accurate as exact inference in pyMLN"      !e1 ^
  "smoker with evidence must be accurate as exact inference in pyMLN"                 !e2 ^
  "hard smoker example without evidence must be accurate as exact inference in pyMLN" !e3 ^
  "simple functional example must be accurate as exact inference in pyMLN"           !e4

  /**
   * Turns the result of a Herbrand-base lookup into a [[PredicateNode]].
   *
   * Fails the example with a message naming the offending atom, instead of a bare
   * `NoSuchElementException` (atom missing) or `ClassCastException` (wrong node kind).
   * This is a `def` rather than a `val`, so it is safe to call regardless of when specs2
   * evaluates the examples relative to class-body initialization.
   *
   * @param atom  textual atom that was searched for (used only in error messages)
   * @param found the node built for the first matching atom, if any
   * @return the node, if it is a [[PredicateNode]]
   */
  private def asPredicateNode(atom: String, found: Option[Any]): PredicateNode = found match {
    case Some(pn: PredicateNode) => pn
    case Some(other)             => sys.error("atom " + atom + " does not map to a PredicateNode but to " + other)
    case None                    => sys.error("atom not found in Herbrand base: " + atom)
  }

  /** Smokers example without evidence: P(atom = true) must match pyMLN's exact marginals. */
  def e1 = {
    val graph = ExampleGraphs.smokers
    val inferer = new CountingInferer(graph)

    //probabilities obtained from pyMLN's exact inference
    "atom"                  || "marginal"  |>
    "cancer(PETER)"         !! 0.583313d   |
    "cancer(THOMAS)"        !! 0.583313d   |
    "smokes(PETER)"         !! 0.360569d   |
    "smokes(THOMAS)"        !! 0.360569d   |
    "friends(PETER,PETER)"  !! 0.5d        |
    "friends(THOMAS,THOMAS)"!! 0.5d        |
    "friends(PETER,THOMAS)" !! 0.412053d   |
    "friends(THOMAS,PETER)" !! 0.412053d   | { (atom, probability) =>
      val variable = asPredicateNode(atom, graph.signature.herbrandBase.collectFirst {
        case ab: AtomBase if ab.toString == atom => LogicNode(ab, graph.signature)
      })
      inferer.padEqual(_ => true).marginal(variable, variable.domain2Int(true)) must be closeTo(probability, 0.0001d)
    }
  }

  // Smokers example with evidence. The pyMLN listing below also reports smokes(PETER),
  // friends(PETER,THOMAS) and friends(THOMAS,PETER) at 1.0. Those atoms are not checked here,
  // presumably because the evidence fixes them (TODO confirm against ExampleGraphs).
  def e2 = {
    val graph = ExampleGraphs.smokersWithEvidence
    val inferer: CountingInferer[LogicNode, WeightedFormulaFactor] = new CountingInferer(graph)

    //probabilities obtained from pyMLN's exact inference
    //results:
    //0.731059  Cancer(PETER)
    //0.725032  Cancer(THOMAS)
    //0.500000  Friends(PETER,PETER)
    //1.000000  Friends(PETER,THOMAS)
    //1.000000  Friends(THOMAS,PETER)
    //0.500000  Friends(THOMAS,THOMAS)
    //1.000000  Smokes(PETER)
    //0.973919  Smokes(THOMAS)
    "atom"                  || "marginal"  |>
    "cancer(PETER)"         !! 0.731059d   |
    "cancer(THOMAS)"        !! 0.725032d   |
    "smokes(THOMAS)"        !! 0.973919d   |
    "friends(PETER,PETER)"  !! 0.5d        |
    "friends(THOMAS,THOMAS)"!! 0.5d        | { (atom, probability) =>
      val variable = asPredicateNode(atom, graph.signature.herbrandBase.collectFirst {
        case ab: AtomBase if ab.toString == atom => LogicNode(ab, graph.signature)
      })
      inferer.padEqual(_ => true).marginal(variable, variable.domain2Int(true)) must be closeTo(probability, 0.0001d)
    }
  }

  /** Hard-constraint smokers example without evidence: P(atom = true) must match pyMLN. */
  def e3 = {
    val graph = ExampleGraphs.hardSmokers
    val inferer = new CountingInferer(graph)

    //probabilities obtained from pyMLN's exact inference
    "atom"                  || "marginal"  |>
    "cancer(PETER)"         !! 0.581553d   |
    "cancer(THOMAS)"        !! 0.581553d   |
    "smokes(PETER)"         !! 0.352955d   |
    "smokes(THOMAS)"        !! 0.352955d   |
    "friends(PETER,PETER)"  !! 0.5d        |
    "friends(THOMAS,THOMAS)"!! 0.5d        |
    "friends(PETER,THOMAS)" !! 0.405520d   |
    "friends(THOMAS,PETER)" !! 0.405520d   | { (atom, probability) =>
      val variable = asPredicateNode(atom, graph.signature.herbrandBase.collectFirst {
        case ab: AtomBase if ab.toString == atom => LogicNode(ab, graph.signature)
      })
      inferer.padEqual(_ => true).marginal(variable, variable.domain2Int(true)) must be closeTo(probability, 0.0001d)
    }
  }

  /**
   * Functional example: for each function-valued node and each value in its range, the marginal
   * must match pyMLN's exact inference.
   *
   * NOTE(review): unlike e1–e3, this queries `inferer` directly rather than
   * `inferer.padEqual(_ => true)`. Confirm this is intentional.
   */
  def e4 = {
    val graph = ExampleGraphs.simpleFunctional
    val inferer = new CountingInferer(graph)

    //probabilities obtained from pyMLN's exact inference
    "atom"           || "value"    | "marginal" |>
    "emotion(T1)"    !! "GOOD"     ! 0.106507d  |
    "emotion(T1)"    !! "NEUTRAL"  ! 0.786986d  |
    "emotion(T1)"    !! "BAD"      ! 0.106507d  |
    "emotion(T2)"    !! "GOOD"     ! 0.106507d  |
    "emotion(T2)"    !! "NEUTRAL"  ! 0.786986d  |
    "emotion(T2)"    !! "BAD"      ! 0.106507d  | { (functionBase, a, probability) =>

      val functionNodes = graph.variables.collect { case fn: FunctionValuedNode => fn }
      val variable = functionNodes.find(_.toString == functionBase)
        .getOrElse(sys.error("function-valued node not found: " + functionBase))
      // the inferer works on integer-encoded values; locate the encoding whose domain value prints as `a`
      val value = variable.getRange.find(i => variable.int2Domain(i).toString == a)
        .getOrElse(sys.error("value " + a + " not in range of " + functionBase))
      inferer.marginal(variable, value) must be closeTo(probability, 0.0001d)
    }
  }
}