package edu.uulm.scbayes

import factorgraph.{DiscreteFactorGraph, DiscreteFactor}
import probabilities.statistics.MarginalVarianceTester
import probabilities.{DiscreteMarginals, DiscreteVariable}
import scala.util.Random

/**
 * Convenience helpers for running inference with a `SteppingGraphInferer`.
 *
 * Date: 21.11.11
 */
package object inference {

  /**
   * Convenience entry points that turn a [[SteppingGraphInferer]] into an [[NInferer]].
   *
   * This is a named class on purpose: members of an anonymous `new { ... }` are only visible
   * through a structural type, which makes every call a reflective one (and raises the
   * `reflectiveCalls` feature warning on Scala 2.10+).
   *
   * @param st the stepping inferer whose states are created, advanced and turned into marginals
   */
  final class RichSteppingGraphInferer(st: SteppingGraphInferer) {

    /**
     * Builds an [[NInferer]] that creates an initial state with `st.createInitialState`,
     * applies `st.advanceState` exactly `n` times and returns `st.marginals` of the result.
     *
     * @param n      number of advancement steps
     * @param random randomness source handed to `createInitialState`; a fresh [[Random]] by default
     * @return an inferer performing the fixed number of steps on each call to `infer`
     */
    def forSteps(n: Int, random: Random = new Random()): NInferer = new NInferer {
      def infer[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F]): DiscreteMarginals[V] = {
        val initial = st.createInitialState(graph, graph.variables, random)
        // drop(n) skips the initial state and the first n - 1 successors, i.e. yields the n-th advanced state.
        st.marginals(Iterator.iterate(initial)(st.advanceState(_)).drop(n).next())
      }
    }

    /**
     * Builds an [[NInferer]] that wraps `st` in a [[ParallelGraphInferer]] with `numChains`
     * chains and stops at the first state for which `MarginalVarianceTester.maxVariance` of
     * `chainMarginals(state)` does not exceed `maxVariance`.
     *
     * The initial state and up to `maxSteps` advanced states are checked, so `maxSteps`
     * counts advancement steps in the same way as `forSteps(n)` and as the error message says.
     *
     * @param numChains   number of chains passed to `ParallelGraphInferer`
     * @param maxVariance threshold compared against `MarginalVarianceTester.maxVariance`
     * @param maxSteps    maximal number of advancement steps
     * @param rand        randomness source for the initial state; a fresh [[Random]] by default
     * @return an inferer returning the marginals of the first state meeting the threshold
     * @throws RuntimeException if no checked state meets the threshold
     */
    def maxVarianceConvergence(numChains: Int, maxVariance: Double, maxSteps: Int, rand: Random = new Random()): NInferer = new NInferer {
      def infer[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F]): DiscreteMarginals[V] = {
        val parInference = new ParallelGraphInferer(numChains, st)
        // Initial state + maxSteps advanced states; guard so maxSteps + 1 cannot overflow to a
        // negative count (which would make take(...) empty and always fail).
        val statesToCheck = if (maxSteps == Int.MaxValue) maxSteps else maxSteps + 1
        val convergedState = Iterator
          .iterate(parInference.createInitialState(graph, graph.variables.toSet, rand))(parInference.advanceState)
          .take(statesToCheck)
          .find { state =>
            // Written as !(v > max) rather than v <= max to keep the previous handling of NaN.
            // NOTE(review): a NaN variance is therefore accepted as converged — confirm intended.
            !(MarginalVarianceTester.maxVariance(parInference.chainMarginals(state), graph.variables) > maxVariance)
          }
        convergedState
          .map(state => parInference.marginals(state))
          .getOrElse(throw new RuntimeException("inference did not converge after %d steps".format(maxSteps)))
      }
    }
  }

  /** Implicitly enriches a [[SteppingGraphInferer]] with `forSteps` and `maxVarianceConvergence`. */
  implicit def stepperToRichStepper(st: SteppingGraphInferer): RichSteppingGraphInferer =
    new RichSteppingGraphInferer(st)
}