package edu.uulm.scbayes.inference.bp

import util.Random
import edu.uulm.scbayes.probabilities.{DiscreteMarginals, DiscreteVariable}
import edu.uulm.scbayes.factorgraph.messages.DiscreteMessage
import edu.uulm.scbayes.util._
import edu.uulm.scbayes.factorgraph.{FactorGraph, DiscreteFactor}

/**
 * This object offers functions to calculate the messages of the belief propagation algorithm.
 * A message schedule (which messages to recompute, and in which order) has to be implemented separately.
 *
 *
 * Date: 10.05.11
 */

object BeliefPropagation {

  /**
   * Creates a message generator that fills every entry of a message with an independent draw of
   * `rand.nextDouble()` and passes the result through `DiscreteMessage.normalized`.
   *
   * @param rand Source of randomness; it is consumed on every invocation of the returned function.
   * @return A function mapping a variable to a random message with one entry per value in `v.getRange`.
   */
  def randomMessageGenerator(rand: Random): DiscreteVariable => DiscreteMessage = {v =>
    DiscreteMessage.normalized(v.getRange.map(_ => rand.nextDouble()))
  }

  /**
   * Creates messages by assuming maximum entropy (all entries are equal to 1 / domain size).
   */
  def equalMessageGenerator: DiscreteVariable => DiscreteMessage = {v =>
    val size = v.getRange.length
    new DiscreteMessage(Array.fill[Double](size)(1d / size))
  }

  /**
   * Creates initial messages for all edges from the given variables to their adjacent factors.
   *
   * @param graph The factor graph whose edges are used.
   * @param generator Produces the initial message for a variable; called once per (variable, factor) edge.
   * @param _variables The variables whose outgoing messages are created; `null` (the default) means all
   *  variables of the graph.
   * @return A map from (variable, factor) edges to their initial message.
   */
  def createInitialVariableMessages[VN <: DiscreteVariable, FN <: DiscreteFactor[VN]](graph: FactorGraph[VN, FN],
                                                                                      generator: DiscreteVariable => DiscreteMessage,
                                                                                      _variables: Set[VN] = null): Map[(VN, FN), DiscreteMessage] = {
    //null is kept as the "use every variable" marker for backward compatibility with existing callers
    val variables = Option(_variables).getOrElse(graph.variables.toSet)
    val messages = for (v <- variables; f <- graph.factorsOf(v)) yield (v, f) -> generator(v)
    messages.toMap
  }

  /**
   * Creates initial messages for all edges from the given factors to their adjacent variables.
   *
   * @param graph The factor graph whose edges are used.
   * @param generator Produces the initial message for the target variable; called once per (factor, variable) edge.
   * @param _factors The factors whose outgoing messages are created; `null` (the default) means all
   *  factors of the graph.
   * @return A map from (factor, variable) edges to their initial message.
   */
  def createInitialFactorMessages[VN <: DiscreteVariable, FN <: DiscreteFactor[VN]](graph: FactorGraph[VN, FN],
                                                                                    generator: DiscreteVariable => DiscreteMessage,
                                                                                    _factors: Set[FN] = null): Map[(FN, VN), DiscreteMessage] = {
    //null is kept as the "use every factor" marker for backward compatibility with existing callers
    val factors = Option(_factors).getOrElse(graph.factors)
    val messages = for (f <- factors; v <- graph.variablesOf(f)) yield (f, v) -> generator(v)
    messages.toMap
  }

  /**
   * Convenience method to create random initial messages for every edge of the graph, in both directions.
   * Both directions share one generator, so they draw from the same `random` instance.
   *
   * @return A pair (variable-to-factor messages, factor-to-variable messages).
   */
  def createRandomMessages[VN <: DiscreteVariable, FN <: DiscreteFactor[VN]](graph: FactorGraph[VN,FN],
                                                                             random: Random): (Map[(VN, FN), DiscreteMessage],Map[(FN, VN), DiscreteMessage]) = {
    val generator = randomMessageGenerator(random)
    (createInitialVariableMessages(graph, generator), createInitialFactorMessages(graph, generator))
  }

  /**
   * Compute messages from factors to variables given a set of messages from variables to factors and
   * a list of factors to recompute.
   *
   * @param graph The factor graph; used to find the variables adjacent to each factor.
   * @param v2f Maps an edge from a variable node to a factor node onto a discrete message.
   *  Must be defined for every incoming edge of every factor in `factors`.
   * @param factors The factors to recompute.
   *
   * @return The recomputed messages originating from the given factors, keyed by (factor, variable) for
   *  every variable adjacent to each factor.
   */
  def computeFactorMessages[VN <: DiscreteVariable, FN <: DiscreteFactor[VN]](graph: FactorGraph[VN, FN],
                                                                              v2f: Map[(VN, FN), DiscreteMessage],
                                                                              factors: Iterable[FN]): Map[(FN, VN), DiscreteMessage] = {

    /**
     * Compute the single message from `sourceFactor` to `targetVariable`. For each value x of the target
     * variable, it sums exp(logFactor) times the product of all other incoming variable messages over every
     * assignment with the target fixed to x, then normalizes the result.
     */
    def computeF2V(sourceFactor: FN, targetVariable: VN): DiscreteMessage = {
      //independent of the target value, so computed once instead of once per value
      val summationVariables = sourceFactor.assignmentVariables.toIndexedSeq

      //the incoming messages do not depend on the assignment, so look them up once instead of once per
      //assignment; targetVariable's position(s) get None, which excludes its own message (factor 1)
      val incomingMessages: IndexedSeq[Option[DiscreteMessage]] = summationVariables.map { variable =>
        if (variable == targetVariable) None else Some(v2f((variable, sourceFactor)))
      }

      /** The unnormalized message entry for target value x. */
      def calculateMessageValue(x: Int): Double = {
        //fix targetVariable to x; every other variable ranges over its full domain
        val summationVariableRanges: Array[Array[Int]] = summationVariables
          .map(variable => if (variable == targetVariable) Array(x) else variable.getRange)(collection.breakOut)

        //the factor's value times the incoming variable messages for one joint assignment
        def assignmentToValue(assignment: IndexedSeq[Int]): Double = {
          val factorValue = math.exp(sourceFactor.logFactor(assignment))

          //each Int in the assignment is the value of the variable at the same position
          var i = 0
          var product = 1d
          while (i < assignment.size) {
            incomingMessages(i) match {
              case Some(message) => product *= message(assignment(i))
              case None => //position of targetVariable: its own message is excluded
            }
            i += 1
          }
          factorValue * product
        }

        //sum the value over all possible variable assignments
        crossProduct(summationVariableRanges).iterator.foldLeft(0d)(_ + assignmentToValue(_))
      }

      DiscreteMessage.normalized(targetVariable.getRange.map(x => calculateMessageValue(x)))
    }

    val entries = for (f <- factors; v <- graph.variablesOf(f)) yield ((f, v) -> computeF2V(f, v))
    entries.toMap
  }

  /**
   * Compute messages from variable nodes to factor nodes given a set of messages from factors to variables
   * and a list of variables for which to recompute.
   *
   * @param graph The factor graph; used to find the factors adjacent to each variable.
   * @param f2v The old messages from factors to variables.
   *  Must be defined for all messages incoming on the nodes in variables.
   * @param variables The variables to recompute.
   *
   * @return The messages going out of the nodes given in variables, keyed by (variable, factor) for every
   *  factor adjacent to each variable.
   */
  def computeVariableMessages[VN <: DiscreteVariable, FN <: DiscreteFactor[VN]](graph: FactorGraph[VN, FN],
                                                                                f2v: Map[(FN, VN), DiscreteMessage],
                                                                                variables: Iterable[VN]): Map[(VN, FN), DiscreteMessage] = {
    //The following commented-out block is a highly optimized alternative. At the time of writing it was about
    //20% faster than the code in use. It may introduce inaccuracies because it compares Doubles with 0d.

    //    val messages = for( v <- variables ) yield {
    //      //(0,-2) if there are two factors contributing zeros to the product
    //      //(x,>= 0) if f is the only factor contributing zero
    //      //(x,-1) x != 0, no zeros at all
    //
    //      val factors: IndexedSeq[FN] = graph.factorsOf(v).toIndexedSeq
    //
    //      val (cMsg,cZeros): (Array[Double],Array[Int]) = {
    //        val msg = Array.fill(v.domainSize)(1d)
    //        val zeros = Array.fill(v.domainSize)(-1)
    //
    //        var fIndex = 0
    //
    //        while(fIndex < factors.size){
    //          var valueIndex = 0
    //
    //          //process this message
    //          val fm = f2v((factors(fIndex),v))
    //
    //          while(valueIndex < v.domainSize){
    //            //if this value is already -2 there is nothing we can do
    //            if( zeros(valueIndex) != -2 ){
    //              val value = fm(valueIndex)
    //
    //              //we get a zero, check whether it's the first or second one
    //              if(value == 0) {
    //                //first zero
    //                if(zeros(valueIndex) == -1) {
    //                  zeros(valueIndex) = fIndex
    //                  //second zero
    //                } else {
    //                  zeros(valueIndex) = -2
    //                }
    //                //normal case
    //              } else {
    //                msg(valueIndex) *= value
    //              }
    //            }
    //            valueIndex = valueIndex + 1
    //          }
    //
    //          fIndex = fIndex + 1
    //        }
    //
    //        (msg,zeros)
    //      }
    //
    //      //second pass to calculate all messages originating from this variable
    //      factors.view.zipWithIndex.map{ case (f,fIndex) => (v,f) -> {
    //        val msgForF = cMsg.clone
    //        val msgFromF = f2v((f, v))
    //
    //        var valueIndex = 0
    //        while(valueIndex < msgForF.length){
    //          if(cZeros(valueIndex) == -1) {
    //            msgForF(valueIndex) /= msgFromF(valueIndex)
    //          } else if (cZeros(valueIndex) == fIndex ) {
    //            //do nothing, this factor contributes the only zero and we take the value as is
    //          } else if (cZeros(valueIndex) > -1) {
    //            //it's someone else
    //            msgForF(valueIndex) = 0
    //          } else {
    //            //zeros must be -2, two factors contributing zeros
    //            msgForF(valueIndex) = 0
    //          }
    //
    //          valueIndex += 1
    //        }
    //        new DiscreteMessage(msgForF)
    //      }}
    //    }
    //
    //    messages.flatten.toMap

    /**
     * Compute the single message from variable node `v` to its adjacent factor node `f`. The message is the
     * normalized element-wise product of all factor-to-variable messages arriving at `v`, except the one from `f`.
     */
    def computeV2F(v: VN, f: FN): DiscreteMessage = {
      val factors = graph.factorsOf(v) ensuring(_.contains(f), "can only compute messages for adjacent nodes")

      //collect all messages to this node, except the one from f; indexed access is used in the loop below
      val inputMessages: IndexedSeq[DiscreteMessage] =
        factors.toIndexedSeq.filter(_ != f).map(otherFactor => f2v((otherFactor, v)))

      val messageValues = new Array[Double](v.domainSize)
      var value = 0
      while (value < v.domainSize) {
        var product = 1d
        var i = 0
        while (i < inputMessages.size) {
          product *= inputMessages(i)(value)
          i += 1
        }
        messageValues(value) = product
        value += 1
      }

      DiscreteMessage.normalized(messageValues)
    }

    val entries = for (v <- variables; f <- graph.factorsOf(v)) yield ((v, f) -> computeV2F(v, f))
    entries.toMap
  }

  /**
   * Compute marginal probabilities from a set of factor -> variable messages.
   *
   * The marginal of a value is the product of all messages arriving at the variable for that value,
   * divided by the sum of these products over the variable's range. If every product is 0, the result
   * is 0 / 0, i.e. NaN.
   *
   * @param f2v Must be defined for every edge (factor, variable) of the variables that are queried.
   */
  def marginals[VN <: DiscreteVariable, FN <: DiscreteFactor[VN]](graph: FactorGraph[VN, FN],
                                                                  f2v: Map[(FN, VN), DiscreteMessage]): DiscreteMarginals[VN] =
    new DiscreteMarginals[VN] {
      def marginal(rv: VN, value: Int): Double = {
        require(value < rv.domainSize && value >= 0, "value out of domain")

        val inputMessages = graph.factorsOf(rv).toSeq.map(f => f2v((f, rv)))
        //unnormalized belief: product of all incoming messages at value x
        def belief(x: Int): Double = inputMessages.map(msg => msg(x)).product

        val normalization = rv.getRange.map(x => belief(x)).sum
        belief(value) / normalization
      }

      def canInfer(v: VN): Boolean = graph.variables.contains(v)
    }
}
