package edu.uulm.scbayes.inference

import util.Random
import edu.uulm.scbayes.probabilities.{MeanMarginals, DiscreteMarginals, DiscreteVariable}
import edu.uulm.scbayes.util._
import edu.uulm.scbayes.factorgraph.{DiscreteFactor, DiscreteFactorGraph}

/**
 * Runs several independent chains of a single [[SteppingGraphInferer]] side by side.
 *
 * The state of this inferer is the sequence of the wrapped inferer's states, one entry per
 * chain. Advancing the state advances every chain concurrently (using parallel collections),
 * and the reported marginals combine the marginals of all chains via [[MeanMarginals]].
 *
 * Date: 15.06.11
 *
 * @param numChains number of chains to run; must be strictly positive
 * @param inferer   the inferer whose state is replicated once per chain
 * @tparam TInferer concrete type of the wrapped inferer
 */
class ParallelGraphInferer[TInferer <: SteppingGraphInferer]
(val numChains: Int,
 val inferer: TInferer)
  extends SteppingGraphInferer {

  // Fail fast: with no chains there is nothing to combine in `marginals`, and the error
  // would otherwise surface far away from the misconfiguration.
  require(numChains > 0, "numChains must be positive, but was " + numChains)

  /** The combined state: the wrapped inferer's per-chain states, in chain order. */
  type TState[V <: DiscreteVariable, F <: DiscreteFactor[V]] = Seq[inferer.TState[V, F]]

  /** Returns the per-chain sub-states held in `state` (identity, since `TState` is that Seq). */
  def extractSubstates[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: this.TState[V, F]): Seq[inferer.TState[V, F]] =
    state

  /** Builds a combined state from per-chain sub-states (identity, since `TState` is that Seq). */
  def packSubstates[V <: DiscreteVariable, F <: DiscreteFactor[V]](subStates: Seq[inferer.TState[V, F]]): this.TState[V, F] =
    subStates

  /**
   * Marginals of the combined state: the per-chain marginals from [[chainMarginals]]
   * wrapped in a [[MeanMarginals]].
   */
  def marginals[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: this.TState[V, F]): DiscreteMarginals[V] =
    new MeanMarginals[V](chainMarginals[V, F](state))

  /**
   * Advances every chain by one step of the wrapped inferer.
   *
   * The chains are stepped concurrently through a parallel collection; `.seq` converts the
   * result back to a sequential Seq so the returned state keeps the chain order.
   */
  def advanceState[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: this.TState[V, F]): this.TState[V, F] =
    state.par.map(inferer.advanceState).seq

  /**
   * Creates one initial sub-state per chain.
   *
   * Each chain receives its own generator obtained from `random.split(numChains)` (an
   * extension from edu.uulm.scbayes.util), so that the chains do not share one Random
   * instance while being advanced in parallel.
   *
   * @param graph  factor graph to run inference on
   * @param query  variables whose marginals are of interest
   * @param random source of randomness, split into one generator per chain
   */
  def createInitialState[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F],
                                                                        query: Set[V],
                                                                        random: Random): this.TState[V, F] =
    random.split(numChains).map(inferer.createInitialState(graph, query, _))

  /** The marginals of each individual chain, in chain order. */
  def chainMarginals[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: TState[V, F]): Seq[DiscreteMarginals[V]] =
    extractSubstates(state).map(inferer.marginals)
}