package edu.uulm.scbayes.inference.sampling

import collection.mutable.HashMap
import util.Random
import edu.uulm.scbayes.factorgraph._
import edu.uulm.scbayes.probabilities._
import edu.uulm.scbayes.inference.SteppingGraphInferer

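/**
 * Common functionality for MCMC samplers over discrete factor graphs.
 *
 * Implementations supply the sampler-specific state (`SubRepr`), its initialization
 * (`createInitialStateMCMC`), its transition step (`stepStateMCMC`) and a way to read a variable's
 * value from it (`valueOfVariable`). This trait tallies, for every variable of the graph, how often
 * each of its values has been visited, and estimates marginals as visit frequencies.
 *
 * The count table is mutated in place by each step and shared between successive states, so a
 * state must not be reused once it has been advanced.
 *
 * Date: 3/21/11
 */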

trait MutableMCMCStepper extends SteppingGraphInferer {

  /** Sampler-specific state, defined by the concrete MCMC algorithm. */
  type SubRepr[V <: DiscreteVariable, F <: DiscreteFactor[V]]

  /**
   * Visit statistics: a table mapping `(variable, value)` to the number of recorded states in which
   * the variable had that value, paired with the total number of recorded states.
   */
  final type TMCMC[V] = (HashMap[(V,Int),Int],Int)

  /** Complete stepper state: the graph, the random source, the visit statistics and the sampler state. */
  final type TState[V <: DiscreteVariable, F <: DiscreteFactor[V]] =
    (DiscreteFactorGraph[V,F], Random, TMCMC[V], SubRepr[V,F])


  /**
   * Creates the initial stepper state.
   *
   * Every `(variable, value)` pair of the graph (all variables, not only the query) starts with a
   * count of zero. The initial sampler state is then recorded as the first sample, so the returned
   * statistics already have a step count of 1.
   *
   * @param graph  the factor graph to sample from
   * @param query  forwarded to `createInitialStateMCMC`
   * @param random random source; stored in the state and handed to every subsequent step
   * @return the initial stepper state
   */
  final def createInitialState[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F],
                                                                              query: Set[V],
                                                                              random: Random): TState[V,F] = {
    //create the "empty" hashmap for storing the counts
    val counts = new HashMap[(V,Int),Int]

    for (v <- graph.variables; value <- v.getRange) {
      counts.put((v, value), 0)
    }

    val initialSubstate = createInitialStateMCMC(graph, query, random)
    (graph, random, updateCounts(graph, (counts, 0), initialSubstate), initialSubstate)
  }

  /**
   * Marginals estimated as visit frequencies.
   *
   * The probability of `rv = value` is the number of recorded states in which `rv` had `value`,
   * divided by the number of recorded states.
   *
   * The returned object reads the shared (mutable) count table on every call, but it captures the
   * step count of `state`. It therefore only yields a consistent estimate until `state` is advanced.
   *
   * `marginal` throws `NoSuchElementException` for a variable outside the graph, or for a value not
   * in `rv.getRange`. Use `canInfer` to check a variable first.
   *
   * @param state the stepper state to read the statistics from
   * @return the frequency-based marginals of all graph variables
   */
  final def marginals[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: TState[V,F]): DiscreteMarginals[V] = {
    val (graph, _, mcmcState, _) = state
    val (counts, steps) = mcmcState
    new DiscreteMarginals[V] {
      def marginal(rv: V, value: Int): Double = counts((rv, value)).toDouble / steps

      def canInfer(v: V): Boolean = graph.variables.contains(v)
    }
  }

  /**
   * Performs one MCMC transition and records the resulting sampler state as a new sample.
   *
   * The count table inside `state` is updated in place and shared with the returned state. Treat
   * `state`, and any marginals obtained from it, as stale afterwards.
   *
   * @param state the current stepper state
   * @return the stepper state after one more step
   */
  final def advanceState[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: TState[V,F]): TState[V,F] = {
    val (graph, random, mcmcState, subState) = state

    val newSubstate: SubRepr[V,F] = stepStateMCMC(subState, graph, random)

    //counts is returned mutated
    (graph, random, updateCounts(graph, mcmcState, newSubstate), newSubstate)
  }

  /**
   * Records `subState` as one sample.
   *
   * For every graph variable, the count of its current value (as reported by `valueOfVariable`) is
   * incremented, and the step count is incremented by one.
   *
   * @return the same, now mutated, count table paired with the incremented step count
   */
  private def updateCounts[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V,F],
                                                                          mcmcState: TMCMC[V],
                                                                          subState: SubRepr[V,F]): TMCMC[V] = {
    val (counts, steps) = mcmcState

    def updateCount(v: V): Unit = {
      val key = (v, valueOfVariable(v, subState))
      // Only values from v.getRange were seeded; any other value throws NoSuchElementException here.
      counts.put(key, counts(key) + 1)
    }
    graph.variables.foreach(updateCount)

    (counts, steps + 1)
  }

  /** Creates the sampler-specific initial state for `graph`. */
  protected def createInitialStateMCMC[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V,F],
                                                                                      query: Set[V],
                                                                                      random: Random): SubRepr[V,F]

  /** Returns the value of `v` in the given sampler state; it must be a member of `v.getRange`. */
  protected def valueOfVariable[V <: DiscreteVariable, F <: DiscreteFactor[V]](v: V, state: SubRepr[V,F]): Int

  /** Performs one transition of the Markov chain and returns the new sampler state. */
  protected def stepStateMCMC[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: SubRepr[V,F],
                                                                             graph: DiscreteFactorGraph[V,F],
                                                                             random :Random): SubRepr[V,F]
}

