package edu.uulm.scbayes.inference.bp

import edu.uulm.scbayes.factorgraph.{DiscreteFactor, DiscreteFactorGraph}
import util.Random
import edu.uulm.scbayes.probabilities.{DiscreteMarginals, DiscreteVariable}
import edu.uulm.scbayes.inference.{SteppingGraphInferer, ParallelGraphInferer}

/**
 * Acts like a ParallelInferer using BeliefPropagation. But before each step of BP, the messages between the chains
 * are mixed.
 */

class MixingBP(numChains: Int, val mixingRatio: Double) extends SteppingGraphInferer {
  /** The wrapped inferer that runs `numChains` flooding-BP chains side by side. */
  val parInferer = new ParallelGraphInferer(numChains, FloodingBeliefPropagationStepper)

  /**
   * selfProbability is the weight for a chain to retain its own message; chainProbability is the weight for each
   * foreign chain to get mixed in.
   *
   * Note that these weights do not sum to one on their own (only `numChains - 1` chains are foreign);
   * `mixMaps` normalizes by the sum of the weights it receives, so they act as relative weights.
   */
  val selfProbability: Double = 1 - mixingRatio
  val chainProbability: Double = mixingRatio / numChains

  /**
   * Selects an entry of `domain` by walking along the cumulative weights until `draw` is used up.
   *
   * @param draw a value between zero and the sum of all weights in `domain`
   * @param domain the candidates, each paired with its (unnormalized) weight
   * @return the value of the entry whose weight interval contains `draw`. If `draw` overshoots the total
   *  weight (which can happen through floating point rounding), the last entry is returned.
   */
  def findDraw[A](draw: Double, domain: IndexedSeq[(Double, A)]): A = {
    val lastIdx = domain.size - 1
    var idx = 0
    var remaining = draw
    // Stop at the last entry at the latest: the repeated subtraction may leave a remainder that is marginally
    // larger than the last weight, which would otherwise push the index past the end of `domain`.
    while (idx < lastIdx && domain(idx)._1 < remaining) {
      remaining -= domain(idx)._1
      idx += 1
    }

    domain(idx)._2
  }

  /**
   * Example:
   * input (1,Map(1 -> A)),(9, Map(1 -> B))
   * output will be Map(1 -> A) in 1 out of 10 cases and Map(1 -> B) otherwise.
   *
   * The key set of the result is taken from the first map; every other map is looked up with these keys,
   * so all maps are expected to contain them. An empty input sequence yields an empty map.
   *
   * @param inMaps the candidate maps, each paired with its (unnormalized) weight
   * @param rand source of randomness; one double is drawn per key
   * @return A map that is created by drawing the value of each key at random from the given maps with
   *  a probability given by the first tuple entry for each map.
   */
  def mixMaps[A,B](inMaps: IndexedSeq[(Double,Map[A,B])], rand: Random): Map[A,B] = {
    if (inMaps.isEmpty) {
      Map.empty
    } else {
      val normalization = inMaps.map(_._1).sum
      inMaps.head._2.keys.map{key: A =>
        // an independent draw per key decides which map contributes the value
        val pickedMap = findDraw(normalization * rand.nextDouble(), inMaps)
        key -> pickedMap(key)
      }.toMap
    }
  }

  /**
   * Receives a sequence of states from a BP inferer and returns the mixed states.
   *
   * Each chain keeps its graph; its messages are replaced by a per-key random mix of the messages of all
   * chains, where the chain itself is weighted by `selfProbability` and every other chain by `chainProbability`.
   */
  protected def mixChains[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: Seq[FloodingBeliefPropagationStepper.TState[V,F]],
                                                                         random: Random): Seq[FloodingBeliefPropagationStepper.TState[V,F]] = {
    val chains = state.toIndexedSeq

    //build the mixed messages; each number maps to the modified messages for that chain
    val mixedMessages = (0 until numChains).map { target =>
      //"the chain that gets stuff mixed in" is identified by 'target'
      val weightedMessages = chains.zipWithIndex.map { case ((_, msgs), idx) =>
        (if (target == idx) selfProbability else chainProbability, msgs)
      }
      mixMaps(weightedMessages, random)
    }

    chains.zip(mixedMessages).map { case ((graph, _), mixedMsgs) => (graph, mixedMsgs) }
  }

  /** Delegates to the marginals of the wrapped parallel inferer, ignoring the random generator of the state. */
  def marginals[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: MixingBP#TState[V, F]): DiscreteMarginals[V] =
    parInferer.marginals(state._2)

  /** Mixes the messages of the chains and then advances the wrapped parallel inferer by one step. */
  def advanceState[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: MixingBP#TState[V, F]): MixingBP#TState[V, F] = {
    val (random, parState) = state
    (random, parInferer.advanceState(parInferer.packSubstates(mixChains(parInferer.extractSubstates(parState), random))))
  }


  /**
   * Creates the initial state. The given random generator is kept for mixing; the wrapped parallel inferer
   * receives a new generator seeded from it.
   */
  def createInitialState[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F],
                                                                        query: Set[V],
                                                                        random: Random): MixingBP#TState[V, F] =
    (random, parInferer.createInitialState(graph, query, new Random(random.nextLong())))

  /** Computes the marginals of every single chain separately. */
  def chainMarginals[V <: DiscreteVariable, F <: DiscreteFactor[V]](state: MixingBP#TState[V, F]): Seq[DiscreteMarginals[V]] =
    parInferer.extractSubstates(state._2).map(FloodingBeliefPropagationStepper.marginals)

  /** The state consists of the random generator used for mixing and the state of the wrapped parallel inferer. */
  type TState[V <: DiscreteVariable, F <: DiscreteFactor[V]] = (Random, parInferer.TState[V,F])
}

