package edu.uulm.scbayes.inference.junctiontree

import org.specs2._
import specification.Fragments
import scalaz._
import Scalaz._

/**
 * Exercises an upward pass followed by a downward pass over a tree, using
 * integer addition as the message operation.
 *
 * Both examples assert that every node of the resulting tree holds 35, which
 * is the sum of all labels of `tree1` (3 + 7 + 4 + 12 + 9).
 */
class JunctionTreeTest extends Specification {
  def is: Fragments =
    "test up/down propagation using scanr and scand" ! ((_: Int) must_== 35).foreach(upDownSum(tree1).flatten) ^
    "test up/down propagation using upwardPass and downwardPass" ! ((_: Int) must_== 35).foreach(upDownSum2(tree1).flatten)

  /** Fixture: root 3 with children leaf 7 and node 4; node 4 has leaves 12 and 9. */
  val tree1: Tree[Int] = Tree.node(3,Stream(Tree.leaf(7), Tree.node(4,Stream(Tree.leaf(12),Tree.leaf(9)))))

  /**
   * Up/down propagation built from `scanr` and `JunctionTree.scand`.
   *
   * @param tree the tree whose labels are summed
   * @return a tree in which each node holds its incoming downward message plus
   *         its own upward message; for a correct propagation this is the sum
   *         of all labels of `tree`
   */
  def upDownSum(tree: Tree[Int]): Tree[Int] = {
    // In upProp each node holds (ownValue, upwardMessage), where the upward
    // message is the own value plus the upward messages of all children.
    val upProp = tree.scanr[(Int,Int)]{case (my, children) => (my,my + children.map(_.rootLabel._2).sum)}

    // The step function receives the message coming from the parent together
    // with the current subtree of upProp, and returns the new node value plus
    // one down-going message per child, in child order.
    val downProp = JunctionTree.scand[(Int,Int),Int,Int](upProp)(0){ case (incoming, t) =>
      val result = incoming + t.rootLabel._2
      val upwardMsgs = t.subForest.map(_.rootLabel._2).zipWithIndex
      // Siblings are excluded by position, not by `==`: comparing subtrees for
      // equality would also drop siblings that merely look the same as the
      // child, which makes the message to that child too small.
      val childMessages = upwardMsgs.map { case (_, childIdx) =>
        val otherSum = upwardMsgs.collect { case (msg, idx) if idx != childIdx => msg }.sum
        otherSum + incoming + t.rootLabel._1
      }
      (result, childMessages)
    }
    downProp
  }

  /**
   * Same computation as `upDownSum`, expressed with `JunctionTree.upwardPass`
   * and `JunctionTree.downwardPass`.
   *
   * @param tree the tree whose labels are summed
   * @return a tree whose nodes hold the sum of the second and third component
   *         of the corresponding `downwardPass` node
   */
  def upDownSum2(tree: Tree[Int]): Tree[Int] = {
    val upProp = JunctionTree.upwardPass[Int,Int](tree){case (n, incoming) => n + incoming.sum}
    val downProp = JunctionTree.downwardPass[Int,Int,Int](upProp)(0){case (n, upwardMsgs, downwardMsg) =>
      // For every child: own value, message from above, and the upward messages of the other children.
      JunctionTree.mapOthers(upwardMsgs){otherUpMsgs => n + downwardMsg + otherUpMsgs.sum}
    }
    // NOTE(review): presumably (n, u, d) = (node value, upward message, downward message) — confirm against downwardPass.
    downProp.map{case (n,u,d) => u + d}
  }
}