package edu.uulm.scbayes.inference.sampling

import util.Random
import edu.uulm.scbayes.factorgraph._
import edu.uulm.scbayes.probabilities._
import edu.uulm.scbayes.util._

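/**
 * A Gibbs sampler for discrete factor graphs. Each step resamples every variable of the graph
 * from its conditional distribution and updates the sampler state in place.
 *
 * Date: 3/21/11
 */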


object MutableGibbsStepper extends MutableMCMCStepper {

  /**
   * Performs one Gibbs sweep over `graph.variables`, in their iteration order.
   *
   * Each variable is resampled from the distribution proportional to the product of the
   * factors returned by `graph.factorsOf(variable)`. These are evaluated against the current
   * values of all other variables. The new value is written into `state` before the next
   * variable is visited.
   *
   * @param state  the current assignment; mutated by this call
   * @param graph  the factor graph whose factors define the conditionals
   * @param random source of randomness for the draws
   * @return the same (mutated) `state` instance
   */
  protected def stepStateMCMC[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: MutableGibbsStepper.SubRepr[V, F],
                                                                             graph: DiscreteFactorGraph[V,F],
                                                                             random: Random): MutableGibbsStepper.SubRepr[V, F] = {

    /**
     * Returns the value of the entry selected by `draw`, treating `weighted` as consecutive
     * intervals of length `weight`. `draw` is expected to lie in [0, total weight).
     *
     * Entries whose interval is exhausted are skipped, and this includes zero-weight entries,
     * so an impossible value is never picked while a possible one remains. If floating-point
     * residue survives past every other entry, the last entry is returned instead of indexing
     * past the end. `weighted` must be non-empty.
     */
    def pickByWeight[A](draw: Double, weighted: IndexedSeq[(Double, A)]): A = {
      val lastIdx = weighted.size - 1
      var idx = 0
      var remaining = draw
      while (idx < lastIdx && weighted(idx)._1 <= remaining) {
        remaining -= weighted(idx)._1
        idx += 1
      }
      weighted(idx)._2
    }

    /**
     * Draws a new value for `variable` from its conditional distribution, given the other
     * values in `assignment`. `assignment` is restored to its prior contents before returning.
     *
     * Original note: this method may only be called for non-evidence variables.
     *
     * @return the sampled value, taken from `variable.getRange`
     */
    def flipVariable(variable: V, assignment: HashMapAssignment[V,Int], random: Random): Int = {
      val range = variable.getRange
      require(range.nonEmpty, "cannot resample a variable with an empty range: " + variable)

      // The set of adjacent factors does not depend on the candidate value, so fetch it once.
      val factors: Seq[F] = graph.factorsOf(variable).toSeq

      // Unnormalized log-probability of each candidate value.
      val logScores = range.map { value =>
        assignment.temporaryUpdate(Map(variable -> value))
        // Revert in `finally` so a throwing factor cannot leave the shared state modified.
        try {
          factors.map(_.logFactorFromInterpretation(assignment)).sum
        } finally {
          assignment.revert()
        }
      }

      // Log-sum-exp trick: shift by the largest log score before exponentiating, so that
      // large or very negative log-factors do not overflow to Infinity or underflow to 0.
      // The shift cancels in the normalization, so the sampled distribution is unchanged.
      // Non-finite maxima (e.g. every candidate has log score -Infinity) are left unshifted.
      val maxLog = logScores.max
      val shift =
        if (java.lang.Double.isInfinite(maxLog) || java.lang.Double.isNaN(maxLog)) 0.0
        else maxLog
      val weights = logScores.map(score => math.exp(score - shift))

      val draw = random.nextDouble() * weights.sum
      pickByWeight(draw, weights.zip(range))
    }

    // NOTE(review): flipVariable is documented as non-evidence only, but every variable in
    // graph.variables is resampled here. Confirm the graph never contains evidence variables.
    for (v <- graph.variables) {
      state.update(v, flipVariable(v, state, random))
    }

    // The state is mutated in place; return the same instance.
    state
  }

  /** Reads the current value of `v` from the sampler state. */
  protected def valueOfVariable[V <: DiscreteVariable, F <: DiscreteFactor[V]](v: V, state: MutableGibbsStepper.SubRepr[V, F]): Int = state(v)

  /**
   * Builds an initial state by giving every variable of `graph` a value chosen with
   * `pickRandom` from its range.
   *
   * @param graph  the factor graph whose variables are initialized
   * @param query  not used by this stepper
   * @param random source of randomness for the initial values
   * @return a fresh assignment covering all of `graph.variables`
   */
  protected def createInitialStateMCMC[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F],
                                                                                      query: Set[V],
                                                                                      random: Random): MutableGibbsStepper.SubRepr[V, F] = {
    val assignment = new HashMapAssignment[V,Int]()

    for (v <- graph.variables) {
      assignment.update(v, v.getRange.pickRandom(random))
    }

    assignment
  }

  /** The sampler state: a mutable map from each variable to its current value. */
  type SubRepr[V <: DiscreteVariable, F <: DiscreteFactor[V]] = HashMapAssignment[V,Int]
}