package edu.uulm.scbayes.inference

import util.Random
import edu.uulm.scbayes.probabilities.{DiscreteVariable, DiscreteMarginals}
import edu.uulm.scbayes.probabilities.statistics.MarginalVarianceTester

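/**
 * Runs several chains of one inference algorithm side by side.
 *
 * Every chain has its own inferer state and its own random number generator. All chains
 * are advanced in lock-step, and their marginals can be read per chain or averaged over
 * all chains.
 *
 * Date: 26.05.11
 */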

class ParallelInferer[V <: DiscreteVariable, TInf <: InferenceAlgorithm[V]](
    val numChains: Int,
    val inferer: InferenceAlgorithm[V],
    val random: Random) {

  /** One entry per chain: the inferer state paired with the random number generator of that chain. */
  protected var states: Seq[(inferer.Repr, Random)] = _

  /** Number of calls to `step()` since construction or the last `reset()`. */
  protected var steps: Int = 0

  // Populate `states` right away so the instance is usable after construction.
  reset()

  /**
   * Discards all chain states and starts over with the step counter at zero.
   *
   * A single base seed is drawn from `random`; chain `n` (1-based) gets a fresh generator
   * seeded with `seed + n`, which is also used to create the initial state of that chain.
   */
  def reset(): Unit = {
    steps = 0

    val seed = random.nextInt()

    states = (1 to numChains).map { n =>
      val rng = new Random(seed + n)
      (inferer.createInitialState(rng), rng)
    }
  }

  /** Advances every chain by one step, each with its own generator, and increments the step counter. */
  def step(): Unit = {
    states = states.map { case (repr, rng) => (inferer.computeStep(repr, rng), rng) }
    steps += 1
  }

  /**
   * Perform a certain number of steps.
   *
   * @param steps how many times `step()` is called; nothing happens for values below one.
   *              Note that this parameter shadows the step counter field inside the method.
   */
  def runFor(steps: Int): Unit = (1 to steps).foreach(_ => step())

  /**
   * Marginals as seen by a single chain.
   *
   * The returned object looks up the chain state on every `marginal` call, so it reflects
   * the state at query time, not at the time this method was called.
   *
   * @param chain zero-based index into the chain states.
   */
  def getMarginals(chain: Int): DiscreteMarginals[V] = new DiscreteMarginals[V] {
    def canInfer(v: V): Boolean = inferer.canInfer(v)

    def marginal(rv: V, value: Int): Double = inferer.computeMarginals(rv, value, states(chain)._1)
  }

  /**
   * Marginals averaged over all chains: the per-chain results are summed and divided by `numChains`.
   *
   * Like `getMarginals`, the returned object reads the chain states at query time.
   */
  def getAverageMarginals: DiscreteMarginals[V] = new DiscreteMarginals[V] {
    def canInfer(v: V): Boolean = inferer.canInfer(v)

    def marginal(rv: V, value: Int): Double = {
      val perChain = states.map { case (state, _) => inferer.computeMarginals(rv, value, state) }
      // `.toDouble` is kept so the division stays floating-point whatever numeric type
      // `computeMarginals` yields.
      perChain.sum.toDouble / numChains
    }
  }


  //todo the following two methods should go somewhere else soon
  /**
   * Steps all chains until `maxVariance` is no longer above `maxVarianceThresh` or the step
   * counter reaches `maxSteps`.
   *
   * Does nothing when the query of the inferer is empty. Otherwise at least four steps are
   * always performed (three up front plus one before the first check), even if `maxSteps`
   * is smaller. `maxSteps` is compared against the step counter since the last `reset()`,
   * not against the steps taken by this call alone.
   *
   * @param maxSteps          upper bound for the step counter at which stepping stops.
   * @param maxVarianceThresh stepping stops once `maxVariance` is not greater than this value.
   */
  def runUntilConvergence(maxSteps: Int, maxVarianceThresh: Double): Unit = {
    //todo fix this properly
    if (!inferer.query.isEmpty) {
      runFor(3)

      // Debugging aid, kept from the original loop body (it ran after each step):
      //val variable = inferer.graph.variables.head
      //val value = variable.getRange.head
      //println((0 until numChains).map(getMarginals(_).marginal(variable,value)).mkString("; "))

      // Step once unconditionally, then continue while the stop criterion is not met.
      step()
      while (maxVariance > maxVarianceThresh && steps < maxSteps) {
        step()
      }
    }
  }

  /**
   * Result of `MarginalVarianceTester.maxVariance` for the per-chain marginals of all chains
   * and the query variables of the inferer.
   */
  def maxVariance: Double = {
    val perChainMarginals = (0 until numChains).map(getMarginals)
    MarginalVarianceTester.maxVariance(perChainMarginals, inferer.query.toSeq)
  }

  /** Number of steps performed since construction or the last `reset()`. */
  def getSteps: Int = steps
}