package edu.uulm.scbayes.logic.cnf.sampling

import util.Random
import edu.uulm.scbayes.util._
import edu.uulm.scbayes.logic._
import cnf.{GroundClause, GroundCNF}

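/**
 * Samples uniformly from the satisfying interpretations of a SAT problem by rejection sampling.
 * The random source is supplied per sampling call.
 *
 * @param maxSamples Give up (returning no sample) after drawing this many random interpretations
 *                   without finding a satisfying one.
 * @param fastAlgorithm If true, use the `ArrayCNFAssignment`-based sampler; otherwise draw complete
 *                      truth assignments and evaluate each one against the CNF.
 * @param reorderClauses If true, reorder the clauses of the CNF before sampling, so that the clauses
 *                       most likely to be unsatisfied are checked first.
 *
 * Date: 30.03.11
 */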

class RejectionSAT(val maxSamples: Int = 1000000,
                   val fastAlgorithm: Boolean = true,
                   val reorderClauses: Boolean = true) extends SATSampler {

  /**
   * Draws one satisfying assignment of `cnf` by rejection sampling.
   *
   * If `reorderClauses` is set, the clauses are first reordered with [[reorderCNF]]. Sampling is then
   * delegated to the optimized or the canonical sampler, depending on `fastAlgorithm`.
   *
   * @return a satisfying assignment, or `None` if none was found within `maxSamples` tries.
   */
  protected def samplePartialIntern(cnf: GroundCNF, signature: ConstantDomain, rnd: Random): Option[TruthAssignment] = {
    val queryCNF = if (reorderClauses) reorderCNF(cnf, signature) else cnf

    if (fastAlgorithm)
      samplePartialInternOptimized(queryCNF, signature, rnd)
    else
      samplePartialInternCanonical(queryCNF, signature, rnd)
  }

  /**
   * Reorder the clauses of the cnf, so that clauses with the highest probability of being unsatisfied
   * under a uniformly random assignment come first.
   *
   * A clause is a disjunction, so it is unsatisfied only when all of its atoms are unsatisfied. The
   * estimate treats the atoms as independent and multiplies the per-atom probabilities `1 - 1/c`. Here
   * `c` is the size of the atom's range: 2 for predicate atoms, and the size of the target domain for
   * functional atoms.
   */
  def reorderCNF(cnf: GroundCNF, signature: ConstantDomain): GroundCNF = {
    def unsatisfactionProbability(clause: GroundClause): Double = {
      //map atoms to their range cardinality
      val cardinalities: Array[Int] = clause.atoms.map {
        case _: PredicateAtom => 2
        case f: FunctionalAtom => signature.constants(f.base.name.targetSignature).size
      }
      //map cardinalities c to the probability of the respective atom not being satisfied (1 - 1/c)
      val probUnsatisfied = cardinalities.map(c => 1 - 1 / c.toDouble)
      probUnsatisfied.product
    }

    // sortBy is ascending by default; reverse the ordering so the most-likely-unsatisfied clauses come first
    new GroundCNF(cnf.clauses.sortBy(unsatisfactionProbability)(Ordering[Double].reverse))
  }

  /**
   * Reference implementation of the sampler.
   *
   * Each try draws a fresh assignment. Every predicate atom base gets a random boolean, and every
   * functional atom base gets a random constant of its target domain. The first assignment that
   * satisfies `cnf` is returned.
   *
   * @return the first satisfying assignment among at most `maxSamples` draws, or `None`.
   */
  def samplePartialInternCanonical(cnf: GroundCNF, signature: ConstantDomain, random: Random): Option[TruthAssignment] = {

    val relevantAtomBases: IndexedSeq[AtomBase] = cnf.atomBases

    def generateRandomAssignment: TruthAssignment = {
      val truth = relevantAtomBases
        .collect { case p: PredicateAtomBase => p -> random.nextBoolean() }
        .toMap
      val funs = relevantAtomBases
        .collect { case f: FunctionalAtomBase => f -> signature.constants(f.name.targetSignature).pickRandom(random) }
        .toMap
      TruthAssignment(truth, funs)
    }

    // lazily draw at most maxSamples assignments and stop at the first satisfying one
    Iterator.continually(generateRandomAssignment).take(maxSamples).find(cnf.evaluate)
  }

  /**
   * Sampler that reuses one mutable assignment, re-randomizing it in place on every try.
   *
   * @return a satisfying assignment found within `maxSamples` tries, or `None`.
   */
  def samplePartialInternOptimized(cnf: GroundCNF, signature: ConstantDomain, random: Random): Option[TruthAssignment] = {

    val interpretation: MutableCNFAssignment = new ArrayCNFAssignment(cnf, signature)

    var tries = 0
    var solutionFound = false
    while (!solutionFound && tries < maxSamples) {
      interpretation.randomize(random)
      // Satisfaction is decided by cnf.evaluate. Previously `interpretation.isSatisfied` was also
      // computed here, but its result was overwritten unused.
      // NOTE(review): isSatisfied may be the cheaper check; switch to it only once it is confirmed to
      // agree with cnf.evaluate.
      solutionFound = cnf.evaluate(interpretation.interpretation)
      tries += 1
    }

    if (solutionFound) Some(interpretation.truthAssignment) else None
  }
}