package edu.uulm.scbayes.inference.exact

import collection.mutable.HashMap
import edu.uulm.scbayes.factorgraph._
import edu.uulm.scbayes.util._
import edu.uulm.scbayes.probabilities.{DiscreteVariable, DiscreteMarginals}
import edu.uulm.scbayes.inference.NInferer

/**
 * Infer the exact marginal probabilities of a FactorGraph by summing over all possible interpretations.
 *
 * Date: 29.03.11
 */

/**
 * Exact marginal inference by brute-force enumeration.
 *
 * Every joint assignment (interpretation) of `graph.variables` is enumerated via a
 * [[CrossProductIndexer]]. Its unnormalized weight is `exp` of the sum of all factors'
 * log values for that assignment. The weight is added to the accumulator of every
 * (variable, value) pair occurring in the assignment, and to the overall normalizer.
 * All of this happens eagerly in the constructor, so construction cost grows with the
 * product of all variable domain sizes.
 *
 * NOTE(review): weights are accumulated in linear space. A very large or very small total
 * log factor can overflow to Infinity or underflow to 0, and marginals then become NaN.
 * A log-sum-exp accumulation would avoid this, but it would change the public
 * `weights` / `total_weight` values.
 *
 * @param graph the factor graph whose marginals are computed
 * @tparam V variable type
 * @tparam F factor type
 */
class CountingInferer[V <: DiscreteVariable, F <: DiscreteFactor[V]](val graph: DiscreteFactorGraph[V, F])
  extends DiscreteMarginals[V] {

  /**
   * `weights((v, a))` is the summed weight of all interpretations in which `v` takes
   * value `a`. `total_weight` is the summed weight of all interpretations.
   */
  val (weights, total_weight) = {

    val _weights = new HashMap[(V, Int), Double]
    var _total_weight: Double = 0

    val variables = graph.variables.toSeq
    // The factor collection is the same for every interpretation, and there are
    // exponentially many interpretations, so build the sequence once. `toSeq` is kept
    // because mapping over a set-typed collection could merge equal log values
    // before summation (relevant in case `graph.factors` is a Set).
    val factors = graph.factors.toSeq

    // every (variable, value) pair starts with zero weight
    for (v <- variables; a <- v.getRange)
      _weights.put((v, a), 0)

    val cpi = new CrossProductIndexer(variables.map(_.domainSize))

    cpi.iterator.foreach { interpretation =>
      // map each variable to its value under the current interpretation (from CrossProductIndexer)
      val int_fun: Map[DiscreteVariable, Int] = variables.zip(interpretation).toMap

      // unnormalized weight = exp(sum of log factor values); iterating avoids building
      // an intermediate collection for every interpretation
      val int_eval = math.exp(
        factors.iterator.map(_.logFactorFromInterpretation(int_fun)).sum
      )

      // credit this interpretation's weight to the value each variable takes in it
      variables.foreach { v =>
        val key = (v, int_fun(v))
        _weights.put(key, _weights(key) + int_eval)
      }

      // update normalization
      _total_weight += int_eval
    }

    (_weights, _total_weight)
  }

  /**
   * Exact marginal P(rv = value): the weight of interpretations with `rv = value`
   * divided by the weight of all interpretations.
   *
   * @throws NoSuchElementException if `(rv, value)` was not among the enumerated pairs
   */
  override def marginal(rv: V, value: Int): Double = weights((rv, value)) / total_weight

  /** True iff `v` is one of the graph's variables. */
  def canInfer(v: V): Boolean = graph.variables.contains(v)
}

/**
 * [[NInferer]] entry point for exact, enumeration-based inference.
 */
object CountingInferer extends NInferer {

  /**
   * Builds a [[CountingInferer]] over `graph`. Every marginal is computed while the
   * inferer is constructed, so the returned object answers queries by lookup.
   *
   * @param graph the factor graph to run inference on
   * @return exact marginals of `graph`'s variables
   */
  override def infer[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F]): DiscreteMarginals[V] = {
    val exact = new CountingInferer[V, F](graph)
    exact
  }
}
