package edu.uulm.scbayes.inference.sampling.sat

import org.specs2._
import edu.uulm.scbayes.inference.TestUtils._
import specification.Fragments
import util.Random
import edu.uulm.scbayes.logic.cnf.sampling.{SATSampler, RejectionSAT, SampleSAT}
import edu.uulm.scbayes.util._

/**
 * Benchmark-style specification for the SAT samplers.
 *
 * It parses one fixed CNF formula, draws a fixed number of partial samples from
 * several `SATSampler` configurations and prints, per configuration, how many
 * distinct samples were seen, the variance of their hit counts and a throughput
 * figure. No assertions are made on the samples; the specification body simply
 * evaluates to `true` once every sampler has run.
 */
class SATSamplingTest extends Specification {
  def is: Fragments =
  {
    val (cnf,sig) = parsePropCNF(
      "(a v b) ^ (b v ma3=0) ^ (c v d v !a v !mc8=4) ^ (!b v c v e v mb5=2) ^ (c v !d v !ma3=2) ^ (a v d v !mc8=0) ^ (e v a v ma3=1) ^ (!e v f) ^ (!f v c v !b v ma3=0) ^ (e v d)" +
        " ^ (mb5=4 v !ma3=0) ^ (!mb5=2 v c) ^ (mc8=2 v !mb5=1)"
    )
//    val (cnf,sig) = parsePropCNF(
//      "(!ma5=0 v !mb5=0) ^ (!ma5=1 v !mb5=1)"
//    )

    /** Counts how often each distinct element occurs in `xs`. */
    def countDistinct[A](xs: Iterable[A]): Map[A,Int] =
      xs.foldLeft(Map.empty[A,Int].withDefaultValue(0)) { (counts, x) =>
        counts.updated(x, counts(x) + 1)
      }

    /**
     * Draws samples from `sampler` for the formula above and prints a one-line summary.
     *
     * @param desc    label printed in front of the summary line
     * @param sampler sampler under evaluation
     * @param verbose when true, additionally prints every distinct sample with its count
     */
    def printEval(desc: String, sampler: SATSampler, verbose: Boolean = false): Unit = {
      // A fresh generator with a fixed seed per evaluation keeps runs reproducible.
      val rnd: Random = new Random(41)

      val samples: Int = 10000
      val (counts,time) = benchmarkWallTime(
        countDistinct(
          Iterable.fill(samples)(
            // Fail with a message naming the sampler instead of a bare NoSuchElementException.
            sampler.samplePartial(cnf, sig, rnd).getOrElse(
              sys.error("sampler '" + desc + "' did not produce a sample")
            )
          )
        )
      )

      val result = counts.values
      // The statistic printed here is the variance of the per-solution counts, so it is
      // labelled as such (it used to be labelled "sd").
      // NOTE(review): 10e-9 equals 1e-8; whether this really yields samples per millisecond
      // depends on the unit returned by benchmarkWallTime -- confirm before trusting the figure.
      println("%s: %d solutions; %.0f variance; samples/ms %.2f".format(desc, result.size, result.variance, samples/(time * 10e-9)))
      if(verbose) {
        println(counts.mkString("\n"))
      }
    }

    printEval("SampleSAT, ign=0", new SampleSAT(0.5,0.1,1000,0))
    printEval("SampleSAT, p=0.2,t=0.3", new SampleSAT(0.2,0.3,1000))
    printEval("SampleSAT, ign=5", new SampleSAT(0.5,0.1,1000,5))
    printEval("SampleSAT, ign=10", new SampleSAT(0.5,0.1,1000,10))

    printEval("Reject slow", new RejectionSAT(100000,false))
    printEval("Reject fast", new RejectionSAT(100000,true))

    true
  }
}