package edu.uulm.scbayes.logic.cnf.sampling

import util.Random
import edu.uulm.scbayes.logic._
import cnf.{GroundClause, GroundCNF}
import collection.mutable.BitSet
import collection.immutable.IndexedSeq

/**
 * Encapsulates an optimized, mutable assignment to a CNF, which may also contain multi-valued variables.
 *
 * @param cnf The cnf formula to represent. Cannot be changed later.
 * @param domain A ConstantDomain is needed for the multi-valued variables.
 *
 *
 * Date: 30.03.11
 */
class ArrayCNFAssignment(val cnf: GroundCNF, val domain: ConstantDomain) extends MutableCNFAssignment {
  /** Binary (predicate) atom bases of `cnf`, deduplicated; an atom's array position is its index. */
  protected val binaryAtomBases: Array[PredicateAtomBase] =
    cnf.atomBases.collect { case ab: PredicateAtomBase => ab }.distinct.toArray
  /** Reverse lookup: predicate atom base -> its index in `binaryAtomBases`. */
  protected val baIdx: Map[PredicateAtomBase, Int] = binaryAtomBases.zipWithIndex.toMap

  /** Multi-valued (functional) atom bases of `cnf`, deduplicated; an atom's array position is its index. */
  protected val mvAtomBases: Array[FunctionalAtomBase] =
    cnf.atomBases.collect { case ab: FunctionalAtomBase => ab }.distinct.toArray
  /** Reverse lookup: functional atom base -> its index in `mvAtomBases`. */
  protected val mvIdx: Map[FunctionalAtomBase, Int] = mvAtomBases.zipWithIndex.toMap
  /** First index: multi-valued atom; second index: the constants that atom may take. */
  protected val mvDomains: Array[Array[Constant]] =
    mvAtomBases.map(v => domain.constants(v.name.targetSignature).toArray)
  /** Number of possible values of each multi-valued atom (parallel to `mvDomains`). */
  protected val mvDomainSizes: Array[Int] = mvDomains.map(_.length)

  /*
  The CNF is compiled into flat index arrays. First index: clause number; second index: position
  inside that clause. Entries are indices into `binaryAtomBases`.

    clBinPos(i) - atoms whose being TRUE satisfies clause i
    clBinNeg(i) - atoms whose being FALSE satisfies clause i

  A clause is a disjunction, so clause i is satisfied iff there EXISTS a j with
  clBinPos(i)(j) true, or a j with clBinNeg(i)(j) false, or a multi-valued literal that
  holds (see the multi-valued tables below).
   */

  /** Takes a ground clause and extracts all binary atoms that must evaluate to `extractPositive`
    * for the clause to be satisfied. They are returned as indices into `binaryAtomBases`. */
  protected def extractBinAtoms(extractPositive: Boolean)(cl: GroundClause): Array[Int] = {
    // A plain literal "a = t" is satisfied by a == t.
    val fromPlain = cl.plainAtoms.collect {
      case ba: PredicateAtom if ba.target == extractPositive => ba.base
    }
    // A negated literal "not (a = t)" is satisfied by a == !t.
    val fromNegated = cl.negatedAtoms.collect {
      case ba: PredicateAtom if ba.target != extractPositive => ba.base
    }
    (fromPlain ++ fromNegated).map(baIdx)
  }

  /** Given a sequence of atoms, extracts one tuple per FunctionalAtom: the first element is the
    * index of its function base, the second the index of its target value within that base's domain.
    *
    * A target value that is not in the domain yields value index -1. That value can never be
    * assigned, so a positive literal on it never holds and a negative one always holds.
    *
    * WARNING: This function is not fully optimized. Do not use it for heavy-duty calculations.
    */
  protected def extractMVAtoms(atoms: Array[Atom]): Array[(Int, Int)] = atoms.collect {
    case FunctionalAtom(base, value) =>
      val varIdx = mvIdx(base)
      (varIdx, mvDomains(varIdx).indexWhere(_ == value))
  }

  val numClauses = cnf.clauses.size

  protected val clBinPos: Array[Array[Int]] = cnf.clauses.map(extractBinAtoms(true)).toArray
  protected val clBinNeg: Array[Array[Int]] = cnf.clauses.map(extractBinAtoms(false)).toArray

  /** Builds parallel (base index, value index) tables, one row per clause, from the multi-valued
    * literals that `atomsOf` selects from each clause. */
  private def mvLiteralTables(atomsOf: GroundClause => Array[Atom]): (Array[Array[Int]], Array[Array[Int]]) = {
    val perClause: Array[Array[(Int, Int)]] = cnf.clauses.map(cl => extractMVAtoms(atomsOf(cl))).toArray
    (perClause.map(_.map(_._1)), perClause.map(_.map(_._2)))
  }

  /*
  Bases: first index: clause number; second index: multi-valued atom base number (index into mvAtomBases).
  Values: first index: clause number; second index: index into that base's domain.

  Clause i is (additionally) satisfied if there EXISTS a j with
    mvAssign(clMVPosBases(i)(j)) == clMVPosValues(i)(j), or a j with
    mvAssign(clMVNegBases(i)(j)) != clMVNegValues(i)(j).
   */
  protected val (clMVPosBases, clMVPosValues): (Array[Array[Int]], Array[Array[Int]]) =
    mvLiteralTables(_.plainAtoms)
  protected val (clMVNegBases, clMVNegValues): (Array[Array[Int]], Array[Array[Int]]) =
    mvLiteralTables(_.negatedAtoms)

  /** Indices of clauses that contain no literal at all (and thus can never be satisfied). */
  def emptyClauses: IndexedSeq[Int] = {
    0 until numClauses filter {
      cl => clBinPos(cl).isEmpty && clBinNeg(cl).isEmpty && clMVNegBases(cl).isEmpty && clMVPosBases(cl).isEmpty
    }
  }
  assert(emptyClauses.isEmpty, "there is an empty clause!" + cnf)

  /** Packed truth values of the binary atoms, 64 per word. */
  protected val baAssign = new Array[Long](binaryAtomBases.size / 64 + 1)
  /** Current value of each multi-valued atom, as an index into its row of `mvDomains`. */
  protected val mvAssign = new Array[Int](mvAtomBases.size)
  /** For easy bit-level access into baAssign. `randomize` writes `baAssign` directly, so this relies on
    * the BitSet using that array as its backing storage. */
  protected val baBitset = new BitSet(baAssign)

  /**
   * This returns an Interpretation that reflects the current state of this assignment. It is mutable and
   * will change when this object is mutated.
   *
   * Be careful what you do with this object, since the isDefinedAt method always returns true!
   */
  val interpretation: Interpretation = new Interpretation {
    def apply(v1: Atom): Boolean = v1 match {
      case pa: PredicateAtom => baBitset(baIdx(pa.base)) == v1.target
      case fa: FunctionalAtom => fa.target == valueOfFunctionBase(fa.base)
    }
    def isDefinedAt(x: Atom): Boolean = true
  }

  /** Assigns a new uniformly random value to every binary and multi-valued atom. Unused padding bits
    * in the last word of `baAssign` are randomized too; they are never read by index. */
  def randomize(random: Random): Unit = {
    var word = 0
    while (word < baAssign.length) {
      baAssign(word) = random.nextLong()
      word += 1
    }
    var v = 0
    while (v < mvAssign.length) {
      mvAssign(v) = random.nextInt(mvDomainSizes(v))
      v += 1
    }
  }

  /** The constant currently assigned to the multi-valued atom `fa`. */
  def valueOfFunctionBase(fa: FunctionalAtomBase): Constant = {
    val varIdx = mvIdx(fa)
    mvDomains(varIdx)(mvAssign(varIdx))
  }

  /** True iff at least one literal of clause `clauseIdx` holds under the current assignment. */
  def isClauseSatisfied(clauseIdx: Int): Boolean = {
    // Each helper scans one literal group and reports whether it found a satisfied literal.
    def someBinTrue: Boolean = {
      val atoms = clBinPos(clauseIdx)
      var i = 0
      while (i < atoms.length && !baBitset(atoms(i))) i += 1
      i < atoms.length
    }
    def someBinFalse: Boolean = {
      val atoms = clBinNeg(clauseIdx)
      var i = 0
      while (i < atoms.length && baBitset(atoms(i))) i += 1
      i < atoms.length
    }
    def someMVEqual: Boolean = {
      val bases = clMVPosBases(clauseIdx)
      val values = clMVPosValues(clauseIdx)
      var i = 0
      while (i < bases.length && mvAssign(bases(i)) != values(i)) i += 1
      i < bases.length
    }
    def someMVDifferent: Boolean = {
      val bases = clMVNegBases(clauseIdx)
      val values = clMVNegValues(clauseIdx)
      var i = 0
      while (i < bases.length && mvAssign(bases(i)) == values(i)) i += 1
      i < bases.length
    }

    someBinTrue || someBinFalse || someMVEqual || someMVDifferent
  }

  /** True iff every clause of the CNF is satisfied by the current assignment. */
  def isSatisfied: Boolean = {
    var clauseIdx = 0
    while (clauseIdx < numClauses && isClauseSatisfied(clauseIdx)) clauseIdx += 1
    clauseIdx == numClauses
  }

  /** Index of `value` in the domain of multi-valued atom `varIdx`.
    * @throws IllegalArgumentException if `value` is not in that domain (storing -1 would silently
    *         corrupt clause evaluation and crash later reads). */
  private def domainIndexOf(varIdx: Int, value: Any): Int = {
    val valueIdx = mvDomains(varIdx).indexOf(value)
    require(valueIdx >= 0, "value " + value + " is not in the domain of " + mvAtomBases(varIdx))
    valueIdx
  }

  /** Copies every value `ta` defines for an atom base of this CNF into this assignment. Bases for which
    * `ta` yields no value keep their current value.
    * @throws IllegalArgumentException if a multi-valued atom is given a constant outside its domain. */
  override def set(ta: AtomBaseAssignment): Unit = {
    for (ba <- binaryAtomBases; value <- ta.baseValue(ba)) {
      baBitset(baIdx(ba)) = value
    }
    for (va <- mvAtomBases; value <- ta.baseValue(va)) {
      val index = mvIdx(va)
      mvAssign(index) = domainIndexOf(index, value)
    }
  }

  /** A live view of this assignment: values are read when queried, not copied. Yields None for bases
    * this assignment does not know (any exception during lookup is caught by `allCatch`). */
  def baseAssignment: AtomBaseAssignment = new AtomBaseAssignment {
    def baseValue[A <: AtomBase](base: A): Option[A#RangeType] = util.control.Exception.allCatch[A#RangeType].opt {
      base match {
        case pab: PredicateAtomBase => baBitset(baIdx(pab)).asInstanceOf[A#RangeType]
        case fab: FunctionalAtomBase => valueOfFunctionBase(fab).asInstanceOf[A#RangeType]
      }
    }
  }

  def atomBases: Seq[AtomBase] = binaryAtomBases ++ mvAtomBases

  /** Sets a single atom base to `newVal`.
    * @throws IllegalArgumentException if a multi-valued atom is given a constant outside its domain. */
  def setBase[B <: AtomBase](ab: B, newVal: B#RangeType): Unit = {
    ab match {
      case pab: PredicateAtomBase =>
        baBitset(baIdx(pab)) = newVal.asInstanceOf[Boolean]
      case fab: FunctionalAtomBase =>
        val varIdx = mvIdx(fab)
        mvAssign(varIdx) = domainIndexOf(varIdx, newVal.asInstanceOf[Constant])
    }
  }

  /** This method makes a copy of the current assignment. If you only need an interpretation use `interpretation`. */
  def truthAssignment = new TruthAssignment(
    (0 until binaryAtomBases.size).map(a => binaryAtomBases(a) -> baBitset(a)).toMap,
    (0 until mvAtomBases.size).map(fb => mvAtomBases(fb) -> mvDomains(fb)(mvAssign(fb))).toMap
  )
}

