package edu.uulm.scbayes.logic.cnf.sampling

import util.Random
import edu.uulm.scbayes.util._
import edu.uulm.scbayes.logic._
import edu.uulm.scbayes.logic.cnf._

/**
 * Implementation of SampleSAT algorithm described in [1].
 *
 * This implementation is extended to multivalued atoms.
 *
 * [1] Wei, W.; Erenrich, J. & Selman, B. Towards efficient sampling: Exploiting random walk strategies
 *
 * @tparam C type of clause
 * @tparam A type of atom
 *
 *
 * Date: 25.03.11
 */
class SampleSatState(val cnf: GroundCNF, val domain: ConstantDomain, random: Random)
  extends MutableSATWalker[GroundClause, sss.ChangeItem] {
  import sss._

  /** Current values of the predicate (boolean) atom bases; mutated in place by changes. */
  protected val truth = new HashMapAssignment[PredicateAtomBase, PredicateAtomBase#RangeType]
  /** Current values of the functional (multivalued) atom bases; mutated in place by changes. */
  protected val funs = new HashMapAssignment[FunctionalAtomBase, FunctionalAtomBase#RangeType]

  // Initial assignment. This statement must stay below the declarations of `truth` and `funs`,
  // because it writes into them during construction.
  randomize(random)

  /** @return `true` iff `clause` evaluates to true under the current assignment. */
  def isClauseSatisfied(clause: GroundClause): Boolean = clause.evaluate(MapInterpretation(truth, funs))

  /**
   * @return `true` iff the base of `atom` currently holds the atom's target value.
   *         Throws if the base has no value assigned (`Option.get`).
   */
  def isAtomSatisfied(atom: Atom): Boolean = baseAssignment.baseValue(atom.base).get == atom.target

  /**
   * Proposes changes that each fix one currently unsatisfied atom of `clause`.
   *
   *  - Predicate atom: flip its base's current value. If the atom is true, set `!target`;
   *    otherwise set `target`. A boolean base has only two values, so flipping is the only way to
   *    change the atom's truth, whatever the polarity of the literal.
   *  - Functional atom that is currently true: one change per other constant of the target sort.
   *  - Functional atom that is currently false: set its base to the atom's target value.
   *
   * @param clause the clause whose unsatisfied atoms should be repaired
   * @return candidate changes; applying any single one alters the truth of one of those atoms
   */
  def satisfyingChanges(clause: GroundClause): Seq[ChangeItem] = {
    // Build the interpretation once instead of once per atom/guard.
    val current = this.interpretation
    clause.unsatisfiedAtoms(current).flatMap {
      // Every branch must yield a collection because of flatMap.
      case pa@PredicateAtom(base, value) =>
        // Previously this always proposed `!value`. That is a no-op when the atom is false,
        // because the base then already holds `!value`.
        Seq(PredicateChange(base, if (current(pa)) !value else value))
      // true functional atom: move it to any other value
      case fa@FunctionalAtom(base, value) if current(fa) =>
        (domain.constants(base.name.targetSignature) - value).map(FunctionChange(base, _))
      // false functional atom: choose the target value
      case FunctionalAtom(base, value) =>
        Seq(FunctionChange(base, value))
    }
  }

  /** All ground clauses of the underlying CNF. */
  def clauses: Seq[GroundClause] = cnf.clauses

  /** Writes the value carried by `change` into the current assignment (in place). */
  def applyChange(change: ChangeItem): Unit = {
    change match {
      case PredicateChange(base, value) => truth(base) = value
      case FunctionChange(base, value) => funs(base) = value
    }
  }

  /**
   * Picks a random atom of `cnf` and builds a change that sets its base to a value different
   * from the atom's target: `!target` for predicates, or a random other constant of the target
   * sort for functional atoms.
   *
   * NOTE(review): if the base already holds such a value, the returned change is a no-op. If the
   * target sort has a single constant, `pickRandom` is called on an empty collection. Confirm
   * both are acceptable for the caller.
   */
  def createRandomFlip(random: Random): ChangeItem = cnf.atoms.pickRandom(random) match {
    case PredicateAtom(base, target) => PredicateChange(base, !target)
    case FunctionalAtom(base, target) => FunctionChange(base, (domain.constants(base.name.targetSignature) - target).pickRandom(random))
  }

  /**
   * Net change in the number of satisfied clauses if `change` were applied.
   *
   * The change is applied temporarily and then reverted; the assignment is left as it was, even
   * if clause evaluation throws.
   *
   * @return (#satisfied clauses after) - (#satisfied clauses before)
   */
  def score(change: ChangeItem): Int = {
    // Remember the current value of the affected base so the change can be undone.
    val undo: ChangeItem = change match {
      case PredicateChange(base, _) => PredicateChange(base, baseAssignment.baseValue(base).get)
      case FunctionChange(base, _) => FunctionChange(base, baseAssignment.baseValue(base).get)
    }

    // Previously `count(_ => isSatisfied)`, which ignored the clause and evaluated the whole CNF
    // per clause, so every score came out as 0 or +/- the clause count.
    val pre = countSatisfied
    applyChange(change)
    val post = try countSatisfied finally applyChange(undo)

    post - pre
  }

  /** Number of clauses of `cnf` that are satisfied under the current assignment. */
  def countSatisfied: Int = clauses.count(this.isClauseSatisfied)

  //implement traits

  /** `true` iff the whole CNF evaluates to true under the current assignment. */
  def isSatisfied: Boolean = cnf.evaluate(interpretation)

  /** A fresh interpretation view over the current `truth`/`funs` maps. */
  def interpretation: Interpretation = MapInterpretation(truth, funs)

  /**
   * Assigns a random value to every atom base occurring in `cnf.atoms`. Predicate bases get a
   * random boolean; functional bases get a random constant of their target sort.
   *
   * The existing maps are updated in place; no new maps are created. Bases not occurring in
   * `cnf.atoms` keep their values.
   */
  def randomize(random: Random): Unit = {
    this.cnf.atoms.foreach {
      case p: PredicateAtom => truth.update(p.base, random.nextBoolean())
      case f: FunctionalAtom => funs.update(f.base, domain.constants(f.base.name.targetSignature).pickRandom(random))
    }
  }

  /** The current assignment, converted via `toMap` (presumably an immutable snapshot — verify). */
  def truthAssignment: TruthAssignment = TruthAssignment(this.truth.toMap, this.funs.toMap)

  /** A view that looks up base values in the live `truth`/`funs` maps. */
  def baseAssignment: AtomBaseAssignment = new AtomBaseAssignment {
    def baseValue[A <: AtomBase](base: A): Option[A#RangeType] = base match {
      case pab: PredicateAtomBase => truth.get(pab).map(_.asInstanceOf[A#RangeType])
      case fab: FunctionalAtomBase => funs.get(fab).map(_.asInstanceOf[A#RangeType])
    }
  }

  /** All atom bases that currently have a value assigned. */
  def atomBases: Seq[AtomBase] = (truth.keys ++ funs.keys).toSeq

  /**
   * Sets `ab` to `newVal` in the current assignment.
   * `newVal` is cast to `Boolean` or `Constant` according to the kind of base.
   */
  def setBase[B <: AtomBase](ab: B, newVal: B#RangeType): Unit = {
    ab match {
      case t: PredicateAtomBase => truth(t) = newVal.asInstanceOf[Boolean]
      case t: FunctionalAtomBase => funs(t) = newVal.asInstanceOf[Constant]
    }
  }

  /** All ground atoms of the underlying CNF. */
  def atoms: Seq[Atom] = cnf.atoms
}


package sss {

  /**
   * One move of the local search: give the atom base `base` the new value `value`.
   *
   * The abstract type member `AtomType` ties the type of `value` to the kind of base being changed.
   */
  abstract class ChangeItem {
    type AtomType <: AtomBase

    /** The value that `base` receives when this change is applied. */
    def value: AtomType#RangeType

    /** The atom base whose value is changed. */
    def base: AtomType
  }

  /** Assigns the boolean `value` to the predicate atom base `base`. */
  case class PredicateChange(base: PredicateAtomBase, value: Boolean) extends ChangeItem { type AtomType = PredicateAtomBase }

  /** Assigns the constant `value` to the functional atom base `base`. */
  case class FunctionChange(base: FunctionalAtomBase, value: Constant) extends ChangeItem { type AtomType = FunctionalAtomBase }

}








