package edu.uulm.scbayes.logic.cnf.sampling

import edu.uulm.scbayes.logic._
import cnf.{GroundClause, CNFOps, GroundCNF}
import util.Random
import edu.uulm.scbayes.util._
import annotation.tailrec

/**
 * Exposes the functionality to sample a satisfying interpretation from a SAT problem.
 *
 *
 * Date: 30.03.11 */
trait SATSampler {

  /**
   * Draw a sample from the set of satisfying interpretations. The result is only defined for the atoms
   * that are present in cnf.
   *
   * @param cnf The ground CNF problem to find a satisfying interpretation for.
   * @param signature The constant domain of the problem; handed through to the implementation.
   * @param rnd Random generator used for sampling.
   * @return Some TruthAssignment, or None if a satisfying interpretation wasn't found.
   */
  protected def samplePartialIntern(cnf: GroundCNF, signature: ConstantDomain, rnd: Random): Option[TruthAssignment]

  /**
   * Try to split the CNF problem into independent sub problems.
   *
   * Two clauses end up in the same sub problem if they are connected through a chain of shared atoms.
   *
   * @param cnf The problem to split.
   * @return One GroundCNF per group of connected clauses; empty if cnf has no clauses.
   */
  def splitCNF(cnf: GroundCNF): Seq[GroundCNF] = {
    val atomsOfClauses: Map[GroundClause, Set[Atom]] =
      cnf.clauses.map(cl => cl -> cl.atoms.toSet).toMap

    // Inverted index (atom -> clauses mentioning it), built in a single pass over the clauses
    // instead of re-filtering the whole clause list once per atom.
    val clausesOfAtoms: Map[Atom, Set[GroundClause]] =
      atomsOfClauses.foldLeft(Map.empty[Atom, Set[GroundClause]]) {
        case (index, (clause, clauseAtoms)) =>
          clauseAtoms.foldLeft(index) { (acc, atom) =>
            acc.updated(atom, acc.getOrElse(atom, Set.empty[GroundClause]) + clause)
          }
      }

    /* Separates the clauses connected to the first clause of allClauses from the remaining ones.
     * allClauses must not be empty. */
    def findSubCluster(allClauses: Set[GroundClause]): (Set[GroundClause], Set[GroundClause]) = {
      // Alternately expands clauses -> atoms and atoms -> clauses. diffClauses/diffAtoms hold what was
      // added in the previous round; the cluster is complete once a round adds nothing.
      @tailrec
      def grow(diffClauses: Set[GroundClause],
               diffAtoms: Set[Atom],
               clauses: Set[GroundClause],
               atoms: Set[Atom]): Set[GroundClause] =
        if (diffClauses.isEmpty && diffAtoms.isEmpty) clauses
        else {
          val newClauses: Set[GroundClause] = diffAtoms.flatMap(clausesOfAtoms) -- clauses
          val newAtoms: Set[Atom] = diffClauses.flatMap(atomsOfClauses) -- atoms
          grow(newClauses, newAtoms, newClauses ++ clauses, newAtoms ++ atoms)
        }

      val seed = allClauses.head
      val dependentClauses: Set[GroundClause] = grow(Set(seed), Set[Atom](), Set(seed), Set[Atom]())
      (dependentClauses, allClauses -- dependentClauses)
    }

    // Peels off one cluster at a time until no clauses are left.
    @tailrec
    def split(clauses: Set[GroundClause], acc: List[Set[GroundClause]]): List[Set[GroundClause]] = {
      if(clauses.isEmpty) acc
      else {
        val (cluster, rest) = findSubCluster(clauses)
        split(rest, cluster :: acc)
      }
    }

    split(cnf.clauses.toSet, Nil).map(clauses => GroundCNF.fromClauses(clauses.toSeq))
  }

  /**
   * Samples every sub problem produced by splitCNF separately and merges the partial results.
   *
   * @param cnf The problem to sample from.
   * @param signature Handed through to samplePartialIntern.
   * @param rnd Random generator; shared by all sub problems.
   * @return The merged assignment, or None as soon as one sub problem yields no solution.
   */
  protected def samplePartialSplit(cnf: GroundCNF, signature: ConstantDomain, rnd: Random): Option[TruthAssignment] = {
    val subProblems: Seq[GroundCNF] = splitCNF(cnf)
    // Every sub problem is sampled before the results are merged, even if an earlier one failed.
    val subSolutions: Seq[Option[TruthAssignment]] = subProblems.map(samplePartialIntern(_, signature, rnd))

    subSolutions.foldLeft(Some(TruthAssignment.empty): Option[TruthAssignment]){ (merged, next) =>
      for (ta1 <- merged; ta2 <- next) yield ta1.orElse(ta2)
    }
  }

  /**
   * Draw a sample that is defined at least on the atoms of cnf.
   *
   * @param cnf The problem to sample from.
   * @param signature Handed through to samplePartialIntern.
   * @param rnd Random generator used for sampling.
   * @param unitPropagation If true, cnf is first simplified by CNFOps.propagate and the assignment found
   *                        there is combined with the sampled one.
   * @return Some TruthAssignment, or None if a satisfying interpretation wasn't found.
   */
  def samplePartial( cnf: GroundCNF,
                     signature: ConstantDomain,
                     rnd: Random,
                     unitPropagation: Boolean = true ): Option[TruthAssignment] = {
    val (newCNF, newTA) = if(unitPropagation) CNFOps.propagate(cnf) else (cnf, TruthAssignment.empty)
    // NOTE(review): this calls samplePartialIntern directly; samplePartialSplit is not used here.
    samplePartialIntern(newCNF, signature, rnd).map(newTA.orElse(_))
  }

  /**
   * Same as samplePartial, but the result is defined over all given atoms. Atoms left open by the
   * sample are filled in by SATSampler.completeRandomized.
   *
   * @param atoms The atoms the result has to be defined for.
   */
  def sampleComplete(cnf: GroundCNF, signature: ConstantDomain, atoms: Set[AtomBase], rnd: Random, unitPropagation: Boolean = true): Option[TruthAssignment] =
    samplePartial(cnf, signature, rnd, unitPropagation).map(SATSampler.completeRandomized(_, signature, atoms, rnd))
}

object SATSampler {
  /**
   * Fills the gaps of a partial TruthAssignment with random values.
   *
   * @param partial The TruthAssignment to complete.
   * @param domains Supplies the constants a functional atom may be mapped to.
   * @param atoms The atoms the result has to be defined for.
   * @param rnd Random generator to produce the random assignment.
   *
   * @return A TruthAssignment built as partial.orElse(random values for those of the given atoms
   *          that partial leaves undefined). */
  def completeRandomized(partial: TruthAssignment, domains: ConstantDomain, atoms: Set[AtomBase], rnd: Random): TruthAssignment = {
    // Predicate atoms that have no truth value in partial.
    val openPredicates = atoms.collect {
      case p: PredicateAtomBase if !partial.truth.isDefinedAt(p) => p
    }
    // Functional atoms that have no function value in partial.
    val openFunctions = atoms.collect {
      case f: FunctionalAtomBase if !partial.functionValue.isDefinedAt(f) => f
    }

    // Random booleans are drawn first, then one random constant per open functional atom.
    val randomFill = TruthAssignment(
      (for (p <- openPredicates) yield p -> rnd.nextBoolean()).toMap,
      (for (f <- openFunctions) yield f -> domains.constants(f.name.targetSignature).pickRandom(rnd)).toMap
    )

    partial.orElse(randomFill)
  }
}