package edu.uulm.scbayes.probabilities.statistics

import edu.uulm.scbayes.probabilities.{DiscreteVariable, DiscreteMarginals}

/**
 * Calculates the KL-Divergence between two distributions of the same DiscreteVariable.
 *
 *
 * Date: 30.05.11
 */

object KLDivergence {
  /** Calculates the Kullback-Leibler divergence D_KL(p||q) for a single variable.
   *
   * The value is `\sum_i P(i) \log(P(i) / Q(i))` (natural logarithm), where `i`
   * assumes the possible values of `v` as given by `v.getRange`. Outcomes with
   * `P(i) == 0` contribute nothing to the sum (convention `0 * log 0 = 0`).
   * An outcome with `P(i) > 0` and `Q(i) == 0` contributes positive infinity.
   *
   * If the result is NaN, a diagnostic table of the per-outcome terms is printed
   * to standard output before the NaN is returned.
   *
   * @param v The random variable over whose distribution to compute KL.
   * @param p Typically the "true" distribution.
   * @param q A distribution that might differ from the "true" distribution.
   *
   * @return D_KL(p||q).
   *
   * @throws IllegalArgumentException if `p` or `q` cannot infer `v`.
   */
  def klDivergence[V <: DiscreteVariable](v: V,
                                          p: DiscreteMarginals[V],
                                          q: DiscreteMarginals[V]): Double = {
    // Both marginals are queried below, so both distributions must know v.
    require(p.canInfer(v) && q.canInfer(v), "both distributions must contain the given variable v")

    // Accumulate with foldLeft instead of map(...).sum: mapping over a Set-like
    // range would merge outcomes that happen to yield the same term value.
    val result = v.getRange.foldLeft(0.0) { (acc, i) =>
      val pp: Double = p.marginal(v, i)
      // Skip only this term when P(i) == 0; the remaining outcomes still count.
      if (pp == 0.0) acc
      else acc + pp * math.log(pp / q.marginal(v, i))
    }

    if (result.isNaN) {
      // Diagnostic dump of every term to help locate the offending marginal.
      println("got it: " + v)
      v.getRange.foreach { i =>
        val pp: Double = p.marginal(v, i)
        val pq: Double = q.marginal(v, i)
        println("%s : %f <-> %f = %f".format(i, pp, pq, pp * math.log(pp / pq)))
      }
    }
    result
  }

  /** Symmetric version of the KL divergence, obtained by adding both directions:
   * D_KL(p||q) + D_KL(q||p).
   *
   * @param v The random variable over whose distribution to compute the divergence.
   * @param p The first distribution.
   * @param q The second distribution.
   *
   * @return The sum of the divergences in both directions.
   */
  def klDivergenceSymetric[V <: DiscreteVariable](v: V,
                                          p: DiscreteMarginals[V],
                                          q: DiscreteMarginals[V]): Double =
    klDivergence(v, p, q) + klDivergence(v, q, p)

  /** Sums the per-variable KL divergences D_KL(p||q) over all given variables.
   *
   * @param variables The variables whose single-variable divergences are added up.
   * @param p Typically the "true" distribution.
   * @param q A distribution that might differ from the "true" distribution.
   *
   * @return The sum of `klDivergence(v, p, q)` over every `v` in `variables`;
   *         0.0 for an empty set.
   */
  def mutlivariateKLDistance[V <: DiscreteVariable](variables: Set[V],
                                                    p: DiscreteMarginals[V],
                                                    q: DiscreteMarginals[V]): Double =
    // Fold rather than map: Set.map would produce a Set[Double] and silently
    // drop variables whose divergence equals that of another variable.
    variables.foldLeft(0.0)((acc, v) => acc + klDivergence(v, p, q))

  /** Sums the per-variable symmetric KL divergences over all given variables.
   *
   * @param variables The variables whose symmetric divergences are added up.
   * @param p The first distribution.
   * @param q The second distribution.
   *
   * @return The sum of `klDivergenceSymetric(v, p, q)` over every `v` in
   *         `variables`; 0.0 for an empty set.
   */
  def mutlivariateSymetricKLDistance[V <: DiscreteVariable](variables: Set[V],
                                                            p: DiscreteMarginals[V],
                                                            q: DiscreteMarginals[V]): Double =
    // Fold rather than map, so equal per-variable values are not deduplicated.
    variables.foldLeft(0.0)((acc, v) => acc + klDivergenceSymetric(v, p, q))
}