package edu.uulm.scbayes.probabilities

/**
 * This trait is implemented by classes that are able to compute marginal distributions over discrete variables.
 *
 *
 * Date: 29.03.11
 */

trait DiscreteMarginals[V <: DiscreteVariable] {
  /** @return true if this object is able to provide marginals for `v`. */
  def canInfer(v: V): Boolean

  /**
   * @param rv the variable whose marginal distribution is queried
   * @param value a value of `rv` (e.g. an element of `rv.getRange`)
   * @return the marginal probability of `rv` taking `value`
   */
  def marginal(rv: V, value: Int): Double

  /**
   * Add some equally distributed random variables to this marginal object.
   * Can be used if a factor graph does not infer marginals of nodes that are unconnected to any factor.
   *
   * Note that `vs` only widens `canInfer`: `marginal` falls back to the uniform value
   * `1 / rv.domainSize` for every variable this object cannot infer, whether or not `vs` accepts it.
   *
   * @param vs predicate selecting the additional variables that shall be reported as inferable
   * @return a view on this object that additionally covers the variables accepted by `vs`
   */
  def padEqual(vs: V => Boolean): DiscreteMarginals[V] = new DiscreteMarginals[V] {
    def canInfer(v: V): Boolean = DiscreteMarginals.this.canInfer(v) || vs(v)

    def marginal(rv: V, value: Int): Double =
      if (DiscreteMarginals.this.canInfer(rv))
        DiscreteMarginals.this.marginal(rv, value)
      else
        1d / rv.domainSize
  }

  /**
   * Combine this object with `other` into a `JointMarginals`.
   * NOTE(review): presumably this object takes precedence over `other`; confirm against JointMarginals.
   */
  def orElse(other: DiscreteMarginals[V]) = new JointMarginals(this, other)

  /**
   * Restrict the set of variables this object claims to be able to infer.
   *
   * Only `canInfer` is restricted; `marginal` is forwarded to this object without consulting `pred`.
   *
   * @param pred predicate a variable has to satisfy to be reported as inferable
   * @return a view on this object whose `canInfer` additionally requires `pred`
   */
  def filter(pred: V => Boolean): DiscreteMarginals[V] = new DiscreteMarginals[V] {
    def canInfer(v: V): Boolean = pred(v) && DiscreteMarginals.this.canInfer(v)
    def marginal(rv: V, value: Int): Double = DiscreteMarginals.this.marginal(rv, value)
  }

  /**
   * Render the marginals of the given variables as a "MAR" result string: a line containing `MAR`,
   * followed by a line with the number of variables and then, for each variable in the given order,
   * its domain size and the marginal probability of every value in `getRange`, all separated by
   * single spaces.
   *
   * @param variables the variables to report, in output order
   * @return the result string
   */
  def uaiResultString(variables: Seq[V]): String = {
    val distStrings = variables.map { variable =>
      // Each entry consists of the variable's cardinality followed by its own distribution only.
      val distribution = variable.getRange.map(value => marginal(variable, value))
      variable.domainSize.toString + " " + distribution.mkString(" ")
    }
    "MAR\n" + variables.size + " " + distStrings.mkString(" ")
  }
}

object DiscreteMarginals {
  /**
   * @return A DiscreteMarginals that is defined nowhere: `canInfer` is always false and
   *         `marginal` always throws an IllegalArgumentException.
   */
  def empty[V <: DiscreteVariable]: DiscreteMarginals[V] = new DiscreteMarginals[V] {
    def marginal(rv: V, value: Int): Double = throw new IllegalArgumentException("I have nothing to show")
    def canInfer(v: V): Boolean = false
  }

  /**
   * Pretty-printing helpers for a `DiscreteMarginals` instance.
   * A named class is used instead of an anonymous structural type so that the helper methods
   * are invoked directly rather than via reflection.
   *
   * @param inf the marginals to render
   */
  final class RichDiscreteMarginals[V <: DiscreteVariable](inf: DiscreteMarginals[V]) {
    /**
     * @param variables the variables to render
     * @return one line of the form `P(variable = value) = probability` per variable and
     *         per value in the variable's `getRange`
     */
    def toPrettyStrings(variables: Iterable[V]): Iterable[String] = for(
      rv <- variables;
      a <- rv.getRange
    ) yield "P(%s = %s) = %f".format(rv,a,inf.marginal(rv,a))

    /**
     * Print the marginals of the given variables to standard output,
     * ordered by the string representation of the variables.
     */
    def printMarginals(variables: Iterable[V]): Unit = {
      val sortedVariables = variables.toList.sortBy(_.toString)
      println("Marginals of %s:\n\t%s".format(inf,toPrettyStrings(sortedVariables).mkString("\n\t")))
    }
  }

  /** Implicitly enriches a `DiscreteMarginals` with `toPrettyStrings` and `printMarginals`. */
  implicit def discreteInferer2RichOne[V <: DiscreteVariable](inf: DiscreteMarginals[V]): RichDiscreteMarginals[V] =
    new RichDiscreteMarginals[V](inf)
}