package edu.uulm.scbayes.inference.junctiontree

import edu.uulm.scbayes._
import factorgraph.{DiscreteFactor, DiscreteFactorGraph}
import inference.NInferer
import probabilities.{JointMarginals, DiscreteMarginals, DiscreteVariable}
import scalaz._
import edu.uulm.scbayes.util.CrossProductIndexer

/**
 * Exact marginal inference on discrete factor graphs using junction trees.
 *
 * The factor graph is decomposed into a set of junction trees (see `JunctionTree.fromFactorGraph`;
 * presumably one per connected component — TODO confirm). Each tree is calibrated by an upward and
 * a downward message pass carried out in log space. The per-tree marginals are combined with
 * `JointMarginals`.
 */
object JunctionTreeInference extends NInferer {

  /** Computes the marginals of all variables of `graph`; equivalent to `inferMarginals(graph)`. */
  def infer[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V, F]): DiscreteMarginals[V] = inferMarginals(graph)

  /**
   * Adds two numbers that are given in log space, i.e. computes `log(exp(a) + exp(b))`.
   *
   * The larger argument is factored out, so the exponent handed to `math.exp` is never positive.
   * This keeps the result finite even when the arguments differ by more than ~709, the point at
   * which `math.exp` overflows to infinity.
   *
   * @return `log(exp(a) + exp(b))`. `-Infinity` (the log of zero) acts as the neutral element.
   */
  def addLog(a: Double, b: Double): Double =
    if(a.isNegInfinity)
      b
    else if(b.isNegInfinity)
      a
    else {
      // hi + log(1 + exp(lo - hi)) with lo - hi <= 0, so exp cannot overflow
      val hi = math.max(a, b)
      val lo = math.min(a, b)
      hi + math.log1p(math.exp(lo - hi))
    }

  /**
   * Calibrates a single junction tree and returns its variable marginals.
   *
   * First an upward pass computes, for every clique, the message to its parent. Then a downward
   * pass computes the messages from every clique to its children. All messages are normalized and
   * stored as log values. The marginal of a variable is computed on each `marginal` call from the
   * smallest clique that contains it; results are not cached.
   */
  def computeMarginals[V <: DiscreteVariable, F <: DiscreteFactor[V]](tree: JunctionTree[V,F]): DiscreteMarginals[V] = {
    /**
     * Multiplies `factors` and sums out `sumArgs`, yielding a normalized message over `args`.
     *
     * Everything happens in log space: factor values are added, and the summation over all joint
     * assignments of `sumArgs` is done with `addLog`. The returned message holds log values.
     */
    def sumOut[TF[V <: DiscreteVariable] <: DiscreteFactor[V]](args: IndexedSeq[V], sumArgs: Iterable[V], factors: Iterable[TF[V]]): JTMessage[V] = {
      val sumSeq = sumArgs.toIndexedSeq
      val argIndexer = new CrossProductIndexer(args.map(_.domainSize))
      val sumIndexer = new CrossProductIndexer(sumSeq.map(_.domainSize))

      val completeSequence = (args ++ sumSeq).toIndexedSeq

      val messageValues: Array[Double] = argIndexer.iterator.map{ msgAssignment =>
        //now sum out the other values
        sumIndexer.iterator.map{ sumOutAssignment =>
          val assignment = completeSequence.view.zip(msgAssignment ++ sumOutAssignment).toMap
          //multiply the factors <=> sum their log values.
          //A fold (instead of factors.map(...).sum) allocates no intermediate collection and does not
          //collapse equal log values when `factors` happens to be a Set.
          factors.foldLeft(0.0)((logProduct, f) => logProduct + f.logFactorFromInterpretation(assignment))
        }.reduce(addLog(_,_))
      }.toArray
      JTMessage(args, messageValues, true).normalized
    }

    //compute the variables that are shared by a node and its parent and store them inside second tuple entry
    //(the root has no parent, so its intersection is empty)
    val withIntersections = JunctionTree.scand[Set[V],Set[V],(Set[V],Set[V])](tree.variables)(Set.empty[V]){
        case (parentSet,mySubTree) =>
          val myVariables = mySubTree.rootLabel
          val intersectionWithParent = myVariables intersect parentSet
          ((myVariables, intersectionWithParent),Seq.fill(mySubTree.subForest.size)(myVariables))
      }

    //upward pass: each clique sends its parent a message over their shared variables, built from the
    //clique's own factors and the messages it received from its children
    val (withUpwardMessages, cpuUpwardPass) = util.benchmarkCPUTime(
      JunctionTree.upwardPass[(Set[V],Set[V]),JTMessage[V]](withIntersections){
        case ((clique,upwardIntersection), incomingUpwardMessages) =>
          val messageVariables = upwardIntersection.toIndexedSeq
          val sumOutVariables: Seq[V] = (clique -- messageVariables).toIndexedSeq
          val factors = tree.factors(clique) ++ incomingUpwardMessages

          sumOut(messageVariables, sumOutVariables, factors)
      }
    )

    //downward pass: each clique sends every child a message built from its own factors, the upward
    //messages of all *other* children and its own incoming downward message.
    //The root receives no downward message; `null` marks its absence.
    val (withDownwardMessages, cpuDownwardPass) = util.benchmarkCPUTime(
      JunctionTree.downwardPass[(Set[V],Set[V]),JTMessage[V],JTMessage[V]](withUpwardMessages)(null){
        case ((clique, _),incomingUpwardMessages,incomingDownwardMessage) =>
          JunctionTree.mapOthers2[JTMessage[V],JTMessage[V]](incomingUpwardMessages){
            case (receiver, otherUpwardMessages) =>
              //the receiver's upward message is defined over exactly the variables shared with this clique,
              //so its variable order is also the scope of the downward message
              val messageVariables = receiver.variableOrder
              val sumOutVariables: Seq[V] = (clique -- messageVariables).toSeq
              val factors = tree.factors(clique) ++
                otherUpwardMessages ++
                (if(incomingDownwardMessage != null) Seq(incomingDownwardMessage) else Seq())

              sumOut(messageVariables, sumOutVariables, factors)
          }
      }
    )

    /** Depth-first search for the first subtree (including `t` itself) whose root label satisfies `p`. */
    def myTreeFind[A](t: Tree[A])(p: A => Boolean): Option[Tree[A]] = {
      if(p(t.rootLabel))
        Some(t)
      else
        t.subForest.map(subTree => myTreeFind(subTree)(p)).flatten.headOption
    }

    /**
     * Computes the normalized log-space marginal of `v`.
     *
     * Throws if `v` is not part of any clique of this tree (`minBy` on an empty collection); check
     * with `canInfer` first.
     */
    def calcMarginal(v: V): JTMessage[V] = {
      //find (the smallest) clique, that contains v
      val cliquesWithV = withDownwardMessages.flatten.filter(_._1._1.contains(v))
      val node@((clique, _), _, downMessage) = cliquesWithV.minBy(_._1._1.size)
      //the original author notes that scalaz's findChild does not consider the root label, hence this hand-written search
      val subTree = myTreeFind(withDownwardMessages)(_ == node).get

      //sum everything out except v, using the factors inside the clique, the up messages from its children and the down message if not null (for the root)
      val factors = tree.factors(clique) ++ subTree.subForest.map(_.rootLabel._2) ++ Seq(Option(downMessage)).flatten
      sumOut(IndexedSeq(v), clique - v, factors)
    }

    new DiscreteMarginals[V] {
      def canInfer(v: V): Boolean = tree.variables.flatten.exists(_.contains(v))
      //NOTE(review): recomputes the full marginal of rv on every call, i.e. once per queried value
      def marginal(rv: V, value: Int): Double = {
        val message = calcMarginal(rv)
        message(IndexedSeq(value))
      }
    }
  }

  /**
   * Decomposes `factorGraph` into junction trees, calibrates each of them and combines their marginals.
   *
   * @param factorGraph the graph to run inference on
   * @param observe     optional callback (may be `null`). It receives the junction trees and the CPU
   *                    time that `util.benchmarkCPUTime` measured for the tree decomposition.
   */
  def inferMarginals[V <: DiscreteVariable, F <: DiscreteFactor[V]](factorGraph: DiscreteFactorGraph[V,F], observe:(Seq[JunctionTree[V, F]],Long) => Unit = null): DiscreteMarginals[V] = {
    val (junctionTrees: Seq[JunctionTree[V, F]], cpuTreeDecomposition: Long) = util.benchmarkCPUTime(JunctionTree.fromFactorGraph(factorGraph).toSeq)
    if(observe != null) observe(junctionTrees,cpuTreeDecomposition)
    JointMarginals(junctionTrees.map(computeMarginals):_*)
  }
}

/**
 * A message passed between cliques of a junction tree, stored as a dense table over `variableOrder`.
 *
 * @param variableOrder the variables the message is defined over; fixes the layout of `values`
 * @param values        one entry per joint assignment, laid out as given by `indexer`
 * @param logValues     if true, `values` holds natural logarithms of the message entries
 */
case class JTMessage[V <: DiscreteVariable](variableOrder: IndexedSeq[V], values: Seq[Double], logValues: Boolean = false) extends DiscreteFactor[V] {
  /** Maps a joint assignment of `variableOrder` to its position in `values`. */
  val indexer = new CrossProductIndexer(variableOrder.map(_.domainSize))

  require(values.size == indexer.size, "message expected %d values but received a Seq with %d values".format(indexer.size,values.size))

  def assignmentVariables: IndexedSeq[V] = variableOrder

  /** Log of the entry for `assign`; read directly when stored in log space, otherwise derived via `apply`. */
  def logFactor(assign: IndexedSeq[Int]): Double =
    if (!logValues) math.log(apply(assign))
    else values(indexer.seq2Index(assign))

  /** Entry for `v1` in linear space; exponentiated from `logFactor` when stored in log space. */
  override def apply(v1: IndexedSeq[Int]): Double =
    if (logValues) math.exp(logFactor(v1))
    else values(indexer.seq2Index(v1))

  override def toString(): String = {
    val variablePart = variableOrder.mkString(",")
    val valuePart = values.map("%.2e".format(_)).mkString(",")
    val representation = if (logValues) "log Values" else "normal Values"
    "JTMessage(Variables(%s),Values(%s),%s)".format(variablePart, valuePart, representation)
  }

  /**
   * Returns a copy whose entries sum to one, keeping the storage mode (linear or log).
   *
   * In log space the maximum is shifted to zero before exponentiating, so `math.exp` cannot overflow.
   */
  def normalized: JTMessage[V] =
    if (logValues) {
      val maxLog = values.max
      val shifted = values.map(_ - maxLog)
      val logZ = math.log(shifted.view.map(math.exp).sum)
      copy(values = shifted.map(_ - logZ))
    } else {
      val total = values.sum
      copy(values = values.map(_ / total))
    }
}

/**
 * A junction (clique) tree over discrete variables.
 *
 * @param variables tree of cliques; every node holds the variable set of one clique
 * @param factors   returns the factors assigned to a clique, identified by its variable set
 */
case class JunctionTree[V <: DiscreteVariable,F <: DiscreteFactor[V]](variables: Tree[Set[V]], factors: Set[V] => Seq[F]){
  /** The tree width: number of variables in the largest clique, minus one. */
  def treeWidth: Int =
    variables.flatten.foldLeft(0)((widest, clique) => widest max clique.size) - 1
}

object JunctionTree {
  /**
   * Top-down scan of a tree: propagates values from the root towards the leaves.
   *
   * `f` receives the value propagated from the parent (`init` at the root) and the current subtree.
   * It returns the new label for this node plus one value per child; the i-th value is passed to the
   * i-th child. Children are paired with these values via `zip`, so if `f` returns fewer values than
   * there are children, the surplus children are silently dropped from the result.
   *
   * @tparam A node type of the input tree
   * @tparam B type of the values propagated downwards
   * @tparam C node type of the resulting tree
   */
  def scand[A,B,C](tree: Tree[A])(init: B)(f: (B,Tree[A]) => (C,Seq[B])): Tree[C] = {
    val (newVal, childPropagations) = f(init,tree)
    Tree.node(newVal,tree.subForest.zip(childPropagations).map{case (child,childProp) => scand(child)(childProp)(f)})
  }

  /**
   * Bottom-up message pass: computes one message per node from its label and its children's messages.
   *
   * @param computeUpwardMessage receives a node label and the upward messages of its children
   *   (in child order) and returns the node's own upward message.
   *
   * @tparam A Node type of tree.
   * @tparam B Type of upward messages.
   *
   * @return The tree where each node is paired with its outgoing upward message.
   */
  def upwardPass[A,B](tree: Tree[A])(computeUpwardMessage: (A, Seq[B]) => B): Tree[(A,B)] =
    tree.scanr[(A,B)]{case (a, subForests) => (a,computeUpwardMessage(a,subForests.map(_.rootLabel._2)))}

  /**
   * Top-down message pass over a tree that already carries upward messages (e.g. the result of `upwardPass`).
   *
   * @param computeDownwardMessages Receives the node to do the computation (1st arg), the upward messages of its
   *   children (2nd arg) and the downward message from its parent (3rd arg).
   *   Must return the downward messages for its children in the order given by the second argument,
   *   exactly one per child (this is built on `scand`, which drops children without a message).
   * @param init The downward message to pass to the root node.
   *
   * @tparam A Node type of tree.
   * @tparam B Type of upward messages.
   * @tparam C Type of downward messages.
   *
   * @return The tree where each node additionally contains its outgoing upward message
   *  and its incoming downward message (`init` for the root).
   */
  def downwardPass[A,B,C](tree: Tree[(A,B)])(init: C)(computeDownwardMessages: (A,Seq[B],C) => Seq[C]): Tree[(A,B,C)] =
    scand[(A,B),C,(A,B,C)](tree)(init){ case (downwardMessage, treeAB) =>
      val downMessages = computeDownwardMessages(treeAB.rootLabel._1,treeAB.subForest.map(_.rootLabel._2),downwardMessage)
      val ownValue = (treeAB.rootLabel._1,treeAB.rootLabel._2,downwardMessage)
      (ownValue,downMessages)
    }

  /** For every position of `xs`, calls `f` with `xs` minus the element at that position.
    * Returns the results in the order of `xs`. Each call copies `xs`, so the total cost is quadratic in its length. */
  def mapOthers[A,B](xs: Seq[A])(f: Seq[A] => B): Seq[B] = {
    for(
      (x,idx) <- xs.zipWithIndex;
      others = xs.take(idx) ++ xs.drop(idx + 1)
    ) yield f(others)
  }

  /** Returns `xs` without the element at position `idx`. The type parameter `B` is unused. */
  def others[B, A](xs: scala.Seq[A], idx: Int): Seq[A] =  xs.take(idx) ++ xs.drop(idx + 1)

  /** Like `mapOthers`, but `f` also receives the left-out element as its first argument. */
  def mapOthers2[A,B](xs: Seq[A])(f: (A,Seq[A]) => B): Seq[B] = {
    for(
      (x,idx) <- xs.zipWithIndex;
      otherx = others(xs, idx)
    ) yield f(x,otherx)
  }

  /**
   * Decomposes a factor graph into junction trees using `TreeWidth.minDegreeJTs`.
   *
   * Variables are replaced by their position in `graph.variables` for the decomposition and mapped
   * back afterwards. Each resulting `JunctionTree` looks up a clique's factors by the clique's variable
   * set; asking for a set that is not a clique of that tree throws, because the lookup is a plain `Map`.
   * NOTE(review): since the lookup is keyed by variable set, two cliques with identical variable sets
   * in one tree would share a single factor list. This assumes `minDegreeJTs` never emits such
   * duplicates — confirm.
   */
  def fromFactorGraph[V <: DiscreteVariable, F <: DiscreteFactor[V]](graph: DiscreteFactorGraph[V,F]): Set[JunctionTree[V,F]] = {
    import edu.uulm.scbayes.util.TreeWidth

    val var2Idx: Map[V, Int] = graph.variables.zipWithIndex.toMap
    val idx2Var: Map[Int, V] = var2Idx.map(_.swap)
    //each node pairs a clique (as variable indices) with the factors assigned to it
    val decomposition: Seq[Tree[(Set[Int], Seq[F])]] =
      TreeWidth.minDegreeJTs(graph.factors.map(df => (df.variables.map(var2Idx).toSet,df)).toIndexedSeq)

    decomposition.map{tree =>
      val factorMap: Map[Set[V], Seq[F]] = tree.map(sf => sf._1.map(idx2Var) -> sf._2).flatten.toMap
      val cliqueTree: Tree[Set[V]] = tree.map(n => n._1.map(idx2Var))
      JunctionTree[V,F](cliqueTree, factorMap)
    }.toSet
  }
}