package edu.uulm.scbayes.inference.sampling

import org.specs2.mutable._

import scala.util.Random

import edu.uulm.scbayes._
import inference.TestUtils._
import inference.exact.CountingInferer
import experiments.examples.ExampleGraphs
import mln.factorgraph._

/**
 * Unit tests for Gibbs sampling.
 *
 * Each example runs `MutableGibbsStepper` on one of the `ExampleGraphs` and compares the
 * sampled marginals with the exact marginals computed by `CountingInferer`.
 *
 * Date: 24.03.11
 */
class GibbsSamplerTest extends Specification with Tags {
  /** Number of Gibbs steps taken before the marginals are read off the chain. */
  private val BurnInIterations = 10000

  /** Fixed RNG seed so that the stochastic examples are reproducible. */
  private val Seed = 42L

  /**
   * Runs a Gibbs chain over all variables of `factorGraph` and returns what
   * `MutableGibbsStepper.marginals` reports for the state reached after `iterations` steps.
   *
   * @param factorGraph the graph to sample from
   * @param iterations  number of `advanceState` applications before the marginals are read
   * @param seed        seed of the random number generator driving the chain
   * @return the marginals computed by `MutableGibbsStepper.marginals` on the final state
   */
  def gibbsSample(factorGraph: MLNFactorGraph,
                  iterations: Int = BurnInIterations,
                  seed: Long = Seed) = {
    val initialState =
      MutableGibbsStepper.createInitialState(factorGraph, factorGraph.variables.toSet, new Random(seed))

    // Iterator.iterate yields the initial state first, so drop(n).next() is the state
    // after exactly n applications of advanceState.
    val chain = Iterator.iterate(initialState)(MutableGibbsStepper.advanceState)

    MutableGibbsStepper.marginals(chain.drop(iterations).next())
  }

  "a simple MLN (Smokers)" should {
    tag("expensive")

    val factorGraph = ExampleGraphs.smokers
    val tolerance = 0.015

    // The description is built from the values actually asserted so it cannot drift from them.
    ("have all marginals exact upto " + tolerance + " after " + BurnInIterations + " iterations") in {
      val sampleMarginals = gibbsSample(factorGraph)
      val exact = new CountingInferer(factorGraph)
      val predicates = factorGraph.variables.collect { case p: PredicateNode => p }

      sampleMarginals must beCloseMarginals(exact, tolerance, predicates)
    }
  }

  "smokers MLN with evidence smokes(PETER)" should {
    tag("expensive")

    val factorGraph = ExampleGraphs.smokersWithEvidence
    val tolerance = 0.02

    ("have all marginals exact upto " + tolerance + " after " + BurnInIterations + " iterations") in {
      val sampleMarginals = gibbsSample(factorGraph)
      val exact = new CountingInferer(factorGraph)

      // NOTE(review): unlike the other examples, this compares *all* variables rather than
      // only PredicateNodes — confirm this is intentional.
      sampleMarginals must beCloseMarginals(exact, tolerance, factorGraph.variables)
    }
  }

  "hard smokers MLN" should {
    tag("expensive")

    val factorGraph = ExampleGraphs.hardSmokers
    val tolerance = 0.02

    ("have all marginals exact upto " + tolerance + " after " + BurnInIterations + " iterations") in {
      val sampleMarginals = gibbsSample(factorGraph)
      val exact = new CountingInferer(factorGraph)
      val predicates = factorGraph.variables.collect { case p: PredicateNode => p }

      sampleMarginals must beCloseMarginals(exact, tolerance, predicates)
    }
  }
}