package edu.uulm.scbayes.inference

import edu.uulm.scbayes.factorgraph.DiscreteFactorGraph
import util.Random
import collection.mutable.HashMap
import edu.uulm.scbayes.probabilities.DiscreteVariable

/**
 * Base class for MCMC inference algorithms that estimate marginal probabilities by counting
 * how often each variable assumes each of its values over the sampling steps.
 *
 * Date: 26.05.11
 */
abstract class MCMCInferenceAlgorithm[V <: DiscreteVariable](graph: DiscreteFactorGraph[V,_],
                                                             query: Set[V])
extends InferenceAlgorithm[V](graph,query) {

  /** Sampler-specific state, defined by the concrete subclass and threaded through [[stepStateMCMC]]. */
  type SubRepr

  /**
   * Complete algorithm state:
   *  - `_1`: visit counts keyed by `(variable, value)`. Every variable of the graph and every
   *    value in its range is pre-populated with 0. The map is mutated in place by [[computeStep]].
   *  - `_2`: number of sampling steps taken so far.
   *  - `_3`: the subclass-specific sampler state.
   */
  final type Repr = (HashMap[(V,Int),Int],Int,SubRepr)

  /**
   * Estimates P(variable = value) as the fraction of sampling steps in which `variable` had `value`.
   *
   * @return the relative frequency. This is NaN when no step has been taken yet (0.0 / 0).
   * @throws NoSuchElementException if `(variable, value)` has no entry in the counts, i.e.
   *         `variable` is not a graph variable or `value` lies outside its range
   */
  final def computeMarginals(variable: V, value: Int, state: Repr): Double =
    state._1((variable, value)).toDouble / state._2

  /**
   * Advances the sampler by one MCMC step and tallies the resulting value of every graph variable.
   *
   * Note that the counts map is updated in place and the same instance is returned,
   * so `oldState` and the returned state share it.
   *
   * @throws NoSuchElementException if [[valueOfVariable]] reports a value outside the variable's range
   */
  final def computeStep(oldState: Repr, random: Random): Repr = {
    val (counts, steps, subrepr) = oldState

    val newSubState = stepStateMCMC(subrepr, random)

    // Counts every variable of the graph (not only the query set).
    // `+=` on a missing key throws, exactly like an explicit lookup would.
    def tally(v: V): Unit =
      counts((v, valueOfVariable(v, newSubState))) += 1

    graph.variables.foreach(tally)

    (counts, steps + 1, newSubState)
  }

  /**
   * Builds the initial state.
   *
   * The counts start at zero for every `(variable, value)` pair of the graph.
   * The step counter starts at 0. The sampler state comes from [[createInitialStateMCMC]].
   */
  final def createInitialState(random: Random): Repr = {
    val counts = new HashMap[(V,Int),Int]

    // Pre-populate so that in-range lookups in computeStep/computeMarginals never miss.
    for (v <- graph.variables; value <- v.getRange)
      counts.put((v, value), 0)

    (counts, 0, createInitialStateMCMC(random))
  }

  /** Creates the initial sampler-specific state; may draw from `random`. */
  protected def createInitialStateMCMC(random: Random): SubRepr

  /** Returns the value `v` holds in `state`; must lie within `v.getRange` or counting fails. */
  protected def valueOfVariable(v: V, state: SubRepr): Int

  /** Performs one MCMC transition from `state` and returns the successor sampler state. */
  protected def stepStateMCMC(state: SubRepr, random :Random): SubRepr

}

