package edu.uulm.scbayes.inference.mlnsampling

import util.Random
import edu.uulm.scbayes.util._
import edu.uulm.scbayes.logic.TruthAssignment
import edu.uulm.scbayes.inference.sampling.MutableMCSatInferer
import edu.uulm.scbayes.logic.cnf.sampling.SampleSAT
import edu.uulm.scbayes.probabilities.statistics.MarginalVarianceTester
import edu.uulm.scbayes.mln.factorgraph.MLNFactorGraph

/**
 * Uses the MC-SAT MCMC algorithm to draw a sample from an MLN (Markov logic network).
 *
 *
 * Date: 19.05.11
 */

class MCSatSampler(val graph: MLNFactorGraph,
                   val numChains: Int,
                   val maxVariance: Double,
                   val maxSteps: Int,
                   private val rand: Random) {

  /**
   * Runs `numChains` MC-SAT chains on `graph`, each seeded from its own split of `rand`.
   * Every chain first takes a fixed number of burn-in steps. After that, all chains are
   * advanced one step at a time while `maxSteps` has not been reached and either only one
   * chain is used or the maximum marginal variance across chains exceeds `maxVariance`.
   *
   * Note that the burn-in steps are always taken, so the returned step count can exceed
   * `maxSteps` when `maxSteps` is smaller than the burn-in.
   *
   * @throws IllegalArgumentException if `numChains` is not positive
   * @return the truth assignment held by the first chain's final state, and the number of
   *         MC-SAT steps each chain has taken.
   */
  def getSampleAndSteps: (TruthAssignment, Int) = {
    require(numChains > 0, "numChains must be positive, but was " + numChains)

    // Steps taken by every chain before the variance is tested for the first time.
    // NOTE(review): the original comment suggests that testing earlier yields a variance of 0;
    // confirm against MarginalVarianceTester.
    val burnInSteps = 5

    var samplers = rand.split(numChains)
      //.par
      .map { chainRand =>
        val inferer = new MutableMCSatInferer(new SampleSAT(0.5, 0.1, 10000), true, false)
        (inferer, inferer.createInitialState(graph, graph.variables, chainRand))
      }

    // Advances every chain by `steps` MC-SAT steps.
    def stepAll(steps: Int): Unit = {
      for (_ <- 1 to steps) {
        samplers = samplers.map {
          case (inferer, state) => (inferer, inferer.advanceState(state))
        }
      }
    }

    // True while the chains' marginals still disagree by more than `maxVariance`.
    // Written as `> maxVariance` (not `!(<= ...)`) so a NaN variance stops the loop.
    def varianceTooHigh: Boolean =
      MarginalVarianceTester.maxVariance(samplers.map(t => t._1.marginals(t._2)).seq, graph.variables) > maxVariance

    stepAll(burnInSteps)
    var stepsTaken = burnInSteps

    // A single chain always runs until maxSteps, regardless of the variance test.
    // Checking numChains first skips the costly marginals computation in that case.
    while (stepsTaken < maxSteps && (numChains == 1 || varianceTooHigh)) {
      stepAll(1)
      stepsTaken += 1
    }

    // Output the truth assignment of the first chain. The state is accessed positionally.
    // NOTE(review): element _4._1 is assumed to hold the current truth assignment;
    // confirm against MutableMCSatInferer's state type.
    (samplers.head._2._4._1.asInstanceOf[TruthAssignment], stepsTaken)
  }

  /** @return only the truth assignment produced by [[getSampleAndSteps]]. */
  def getSample: TruthAssignment = getSampleAndSteps._1
}