package edu.uulm.scbayes.logic.cnf.sampling

import util.Random
import edu.uulm.scbayes.logic._
import cnf.GroundCNF

/**
 * Encapsulates an optimized, mutable assignment to a CNF formula, including multi-valued variables,
 * and provides the move operations required by [[MutableSATWalker]].
 *
 * A move ([[CNFWalkChange]]) touches exactly one atom: it either sets a binary atom (stored in
 * `baBitset`) to a truth value, or sets a multi-valued atom (stored in `mvAssign`) to a value index.
 *
 * @param _cnf The cnf formula to represent. Cannot be changed later.
 * @param _domain A ConstantDomain is needed for the multi-valued variables.
 *
 *
 * Date: 30.03.11
 */
class ArrayCNFWalker(_cnf: GroundCNF, _domain: ConstantDomain) extends ArrayCNFAssignment(_cnf, _domain) with MutableSATWalker[Int,CNFWalkChange] {

  /**
   * Generates every single-atom change that makes the given (unsatisfied) clause true.
   *
   *  - positive binary literal: set the atom to true
   *  - negated binary literal: set the atom to false
   *  - positive multi-valued literal `base = value`: set the atom to `value`
   *  - negated multi-valued literal `base != value`: set the atom to any other value of its domain
   *
   * The assertion at the end only holds when the clause really is unsatisfied. The result may
   * contain the same change more than once (e.g. a positive and a negated literal on the same base).
   *
   * @param unsatClause index of a clause that is currently unsatisfied
   * @return candidate changes, each of which satisfies the clause
   */
  def satisfyingChanges(unsatClause: Int): Seq[CNFWalkChange] = {
    val unsatPosBin = this.clBinPos(unsatClause).map(CNFWalkChange.fromBinBase(_,true))
    val unsatNegBin = this.clBinNeg(unsatClause).map(CNFWalkChange.fromBinBase(_,false))

    val unsatPosMV = this.clMVPosBases(unsatClause).zip(this.clMVPosValues(unsatClause))
      .map{case (base,value) => CNFWalkChange.fromMVBase(base,value)}

    //a negated literal is satisfied by every value except the forbidden one, so generate one change per other value
    //todo don't generate "other values" that also appear negated for this base in this clause
    val unsatNegMV = this.clMVNegBases(unsatClause).zip(this.clMVNegValues(unsatClause)).flatMap{
      case (base,value) => (0 until mvDomainSizes(base)).filterNot(_ == value).map(CNFWalkChange.fromMVBase(base,_))
    }

    val result = (unsatPosBin ++ unsatNegBin ++ unsatPosMV ++ unsatNegMV)
    //for an unsatisfied clause, every multi-valued change must move the atom away from its current value
    assert(unsatPosMV ++ unsatNegMV forall (change => mvAssign(change.base) != change.value))
    result
  }

  /** Indices of all clauses of the formula: `0 until numClauses`. */
  def clauses: Seq[Int] = 0 until numClauses

  /**
   * Applies a change to the assignment in place.
   *
   * @param change binary changes write `baBitset`, multi-valued changes write `mvAssign`
   */
  def applyChange(change: CNFWalkChange): Unit = {
    if(change.isBinary) {
      baBitset(change.base) = change.binaryValue
    } else {
      mvAssign(change.base) = change.value
    }
  }

  /**
   * Draws a random change that alters the current assignment.
   *
   * First decides between a binary and a multi-valued flip, proportionally to the number of bases of
   * each kind. It then picks a base uniformly. A binary atom is negated. A multi-valued atom gets a value
   * drawn uniformly from all values except its current one.
   *
   * @param random source of randomness
   * @return the drawn change (not yet applied)
   * @throws IllegalArgumentException if there are no atoms, or if the drawn multi-valued atom has a
   *                                  domain with fewer than two values (no other value to flip to)
   */
  def createRandomFlip(random: Random): CNFWalkChange = {
    //todo maybe draw uniformly from the possible flips instead of drawing from the possible bases and then again from the flips?
    val numBinaryBases = binaryAtomBases.size
    val numMVBases = mvAtomBases.size
    require(numBinaryBases + numMVBases > 0, "cannot create a random flip: there are no atoms")

    val makeBinaryFlip = random.nextInt(numBinaryBases + numMVBases) < numBinaryBases

    if(makeBinaryFlip) {
      val base = random.nextInt(numBinaryBases)
      CNFWalkChange.fromBinBase(base,!baBitset(base))
    } else {
      val base = random.nextInt(numMVBases)
      val domainSize = mvDomainSizes(base)
      require(domainSize > 1,
        "cannot flip multi-valued atom " + base + ": its domain has only " + domainSize + " value(s)")
      //draw from the domainSize - 1 values that differ from the current one, then skip over the current value
      val newValue = random.nextInt(domainSize - 1)
      val shiftedNewValue = newValue + (if (newValue >= mvAssign(base)) 1 else 0)
      CNFWalkChange.fromMVBase(base, shiftedNewValue)
    }
  }

  /**
   * Counts the clauses satisfied by the current assignment by checking every clause.
   *
   * @return number of satisfied clauses, in `0 to numClauses`
   */
  def countSatisfiedClauses: Int = {
    //plain while loop on purpose: this is evaluated twice for every scored change
    var satisfiedClauses = 0
    var clause = 0
    while(clause < numClauses){
      if(isClauseSatisfied(clause)) satisfiedClauses += 1
      clause += 1
    }

    satisfiedClauses
  }

  /**
   * Computes how many more clauses would be satisfied if `change` were applied.
   *
   * The change is applied temporarily and then reverted. The revert also happens if counting throws, so
   * the assignment is always left as it was. Cost: two full passes over all clauses.
   *
   * @param change the change to evaluate
   * @return satisfied clauses after the change minus satisfied clauses before it (positive = improvement)
   */
  def score(change: CNFWalkChange): Int = {
    //remember the touched atom's current value so the probe can be undone
    val undo = if(change.isBinary)
      CNFWalkChange.fromBinBase(change.base,baBitset(change.base))
    else
      CNFWalkChange.fromMVBase(change.base,mvAssign(change.base))

    val scoreBefore = countSatisfiedClauses
    applyChange(change)

    val scoreAfter =
      try countSatisfiedClauses
      finally applyChange(undo)

    scoreAfter - scoreBefore
  }
}


