package edu.uulm.scbayes.logic

import edu.uulm.scbayes.util._
/**
 * This file holds the code for operations on Formula objects like grounding, simplifying, converting to CNF.
 *
 * Date: 17.03.11
 */
class FormulaOps(formula: Formula) {

  /**
   * Ground all quantifiers (existential and forall, including nested ones).
   *
   * A `ForAll` is replaced by the conjunction, an `Existential` by the disjunction, of its body instantiated with
   * every constant admissible for the bound variable (see `getAllConstantsForVariable`). The result is traversed
   * again, so quantifiers nested inside the body get grounded as well.
   *
   * @param constants Maps sorts to their domain objects.
   * @return A formula with all quantifiers grounded.
   * @throws RuntimeException If a bound variable does not appear inside any predicate, so its sort is unknown.
   */
  def groundifyQuantifiers(constants: PartialFunction[Sort, Set[Constant]]): Formula = {
    formula match {
      //the two interesting cases are the quantifiers
      case ForAll(boundVar, f) =>
        Conjunction(
          getAllConstantsForVariable(boundVar, constants)
            .map(c => f.substitute(boundVar, c))
        ).groundifyQuantifiers(constants)
      case Existential(boundVar, f) =>
        Disjunction(
          getAllConstantsForVariable(boundVar, constants)
            .map(c => f.substitute(boundVar, c))
        ).groundifyQuantifiers(constants)

      //simply traverse all other formulas
      case Conjunction(i) => Conjunction(i.map(_.groundifyQuantifiers(constants)))
      case Disjunction(i) => Disjunction(i.map(_.groundifyQuantifiers(constants)))
      case Negation(f) => Negation(f.groundifyQuantifiers(constants))

      //stop recursion at the atoms (same leaf type as the other traversals in this class)
      case p: GAtom => p
    }
  }

  /**
   * Find out the sort of a given term.
   *
   * @param t The term whose sort is requested.
   * @return The unique sort of the argument positions in which `t` appears.
   * @throws RuntimeException If the term does not appear inside a predicate in this formula, or if it appears
   *  in positions of more than one sort.
   */
  def getTypeOf(t: Term): Sort = {
    val types: Set[Sort] = getTypesOf(t)

    types.toSeq match {
      case Seq(s) => s
      case Seq()  => throw new RuntimeException("term (%s) does not appear in formula (%s)".format(t,this.formula))
      case _ => throw new RuntimeException("term (%s) has more than one sort in formula (%s)".format(t,this.formula))
    }
  }

  /** @return a set of all the sorts in whose position the term t appears inside the formula. */
  def getTypesOf(t: Term): Set[Sort] = for(
      (term, sort) <- formula.extractTermsWithTypes.toSet if (term == t)
    ) yield sort

  /**
   * Collects the formula and all of its sub-formulas in depth-first pre-order.
   *
   * @return A list containing every formula object that is rooted below this formula. This list includes the
   *  formula itself as its first element.
   */
  def extract: List[Formula] = formula :: (formula match {
        case Disjunction(i) => i.toList.flatMap(_.extract)
        case Conjunction(i) => i.toList.flatMap(_.extract)
        case Negation(f) => f.extract
        case Existential(v,f) => f.extract
        case ForAll(v,f) => f.extract
        case p: GAtom => Nil
      })

  /** @return Every atom occurring in the formula (duplicates are kept if they occur in different places). */
  def extractPredicates: Seq[GAtom] = formula match {
        case Disjunction(i) => i.toSeq.flatMap(_.extractPredicates)
        case Conjunction(i) => i.toSeq.flatMap(_.extractPredicates)
        case Negation(f) => f.extractPredicates
        case Existential(v,f) => f.extractPredicates
        case ForAll(v,f) => f.extractPredicates
        case p: GAtom => Seq(p)
      }

  /** @return The terms obtained by applying `extractTermsT` to every argument of every atom in the formula. */
  def extractTerms: Seq[Term] = for(
    pred <- extractPredicates;
    term <- pred.parameters;
    t <- term.extractTermsT)
  yield t

  /**
   * Like `extractTerms`, but pairs each extracted term with the sort declared by the predicate's signature for
   * the argument position the term was extracted from.
   */
  def extractTermsWithTypes: Seq[(Term, Sort)] = for(
    pred <- extractPredicates;
    (term, sort) <- pred.predicate.signature.zip(Nil).headOption.map(_ => Nil).getOrElse(pred.parameters.zip(pred.predicate.signature));
    t <- term.extractTermsT)
  yield (t, sort)

  /**
   * Tries to simplify the formula by flattening nested conjunctions/disjunctions and removing double negations.
   *
   * Simplification passes are repeated until a pass performs no further change.
   *
   * @return The simplified formula.
   */
  def simplify: Formula = {
    //recSimplify is recursive and not tail-recursive

    /**
     * One simplification pass over the formula tree.
     *
     * @param _formula Formula to simplify.
     * @param notify Is evaluated on every simplification.
     */
    def recSimplify(_formula: Formula, notify: () => Unit): Formula = {
      //recursively traverse the formula tree and apply simplifications
      _formula match {
        case dj: Disjunction => {
          if(dj.children.exists(_.isInstanceOf[Disjunction])){
            notify()
            FormulaOps.flatten(dj)
          } else {
            Disjunction(dj.children.map(recSimplify(_,notify)))
          }
        }
        case cj: Conjunction => {
          if(cj.children.exists(_.isInstanceOf[Conjunction])){
            notify()
            FormulaOps.flatten(cj)
          } else {
            Conjunction(cj.children.map(recSimplify(_,notify)))
          }
        }
        case Negation(Negation(nf)) => {
          notify()
          recSimplify(nf, notify)
        }
        case Negation(nf) => Negation(recSimplify(nf, notify))
        case Existential(v, f) => Existential(v, recSimplify(f, notify))
        case ForAll(v, f) => ForAll(v, recSimplify(f, notify))
        case p: GAtom => p
      }
    }

    //repeat passes until a fixpoint is reached
    @scala.annotation.tailrec
    def fixpoint(current: Formula): Formula = {
      var changed = false
      val next = recSimplify(current, () => changed = true)
      if (changed) fixpoint(next) else next
    }

    fixpoint(formula)
  }

  /**
   * Substitute all occurrences of a given term by another term.
   *
   * Atoms are always rebuilt as `Predicate` instances with substituted arguments.
   *
   * @param target The term (usually a variable) to substitute; must not be bound by a quantifier in the formula.
   * @param replacement This term is inserted in place of `target`.
   * @return The formula with the substitution applied.
   */
  def substitute(target: Term, replacement: Term): Formula = {
    formula match {
      //interesting case is for predicates
      case p: GAtom =>
        Predicate(p.predicate, p.parameters.map(_.substituteVariable(target, replacement)))
      //simply traverse the rest
      case ForAll(v, f) => {
        //don't allow to substitute bound variables
        assert(v != target)
        ForAll(v, f.substitute(target, replacement))
      }
      case Existential(v, f) => {
        //don't allow to substitute bound variables
        assert(v != target)
        Existential(v, f.substitute(target, replacement))
      }
      case Conjunction(i) => Conjunction(i.map(_.substitute(target, replacement)))
      case Disjunction(i) => Disjunction(i.map(_.substitute(target, replacement)))
      case Negation(f) => Negation(f.substitute(target, replacement))
    }
  }

  /**
   * Returns a version of the formula with all negations moved towards their inner-most
   * position in front of the predicates (negation normal form), using De Morgan's laws,
   * double-negation elimination and quantifier duality.
   */
  def pushNegations: Formula = {
    formula match {
      case n@Negation(a: GAtom) => n
      case Negation(Disjunction(i)) => Conjunction(i.map(Negation(_))).pushNegations
      case Negation(Conjunction(i)) => Disjunction(i.map(Negation(_))).pushNegations
      case Negation(Negation(f)) => f.pushNegations
      case Negation(Existential(v,f)) => ForAll(v,Negation(f).pushNegations)
      case Negation(ForAll(v,f)) => Existential(v,Negation(f).pushNegations)
      case Disjunction(i) => Disjunction(i.map(_.pushNegations))
      case Conjunction(i) => Conjunction(i.map(_.pushNegations))
      case p: GAtom => p
      case Existential(v,f) => Existential(v,f.pushNegations)
      case ForAll(v,f) => ForAll(v,f.pushNegations)
    }
  }

  /**
   * Given an interpretation, that assigns each ground predicate a truth value, return the truth value of
   * this formula.
   *
   * @throws RuntimeException If the formula contains a quantifier or a non-ground atom.
   */
  def evaluate(interpretation: Function[GroundPredicate,Boolean]): Boolean = formula match {
    case gp: GroundPredicate => interpretation(gp)
    case Negation(f) => !f.evaluate(interpretation)
    case Disjunction(fs) => fs.exists(_.evaluate(interpretation))
    case Conjunction(fs) => fs.forall(_.evaluate(interpretation))
    case _ => throw new RuntimeException("this formula can't be evaluated: " + formula)
  }

  /**
   * Returns the set of all free variables of the formula.
   */
  def freeVariables: Set[Variable] = formula match {
    case Negation(f) => f.freeVariables
    case Disjunction(fs) => fs.flatMap(_.freeVariables)
    case Conjunction(fs) => fs.flatMap(_.freeVariables)
    //same leaf type as the other traversals in this class
    case p: GAtom => p.parameters.flatMap(_.getVariables).toSet
    case ForAll(v,f) => f.freeVariables - v
    case Existential(v,f) => f.freeVariables - v
  }

  /**
   * Grounds the free variables of the formula.
   *
   * Every free variable is instantiated with each of its admissible constants (see `getAllConstantsForVariable`)
   * and one formula is produced per combination. Afterwards `Succ`/`FAdd` terms are evaluated with `groundAdd`;
   * groundings for which `groundAdd` returns `None` are dropped.
   *
   * @param constants Maps a sort to all of its constants (its domain).
   * @param inequalities Pairs of free variables that must be bound to different constants; groundings
   *  violating any pair are skipped.
   * @return A list containing one instance of the formula for each admissible grounding of its free variables.
   * @throws NoSuchElementException If an inequality mentions a variable that is not free in the formula.
   */
  def groundifyFreeVariables(constants: PartialFunction[Sort, Set[Constant]],
                             inequalities: Set[(Variable,Variable)] = Set()): List[Formula] = {
    val freeVariables: IndexedSeq[Variable] = formula.freeVariables.toIndexedSeq
    val variableIndices = freeVariables.zipWithIndex.toMap
    //translate the variable pairs into index pairs into a binding
    val inequalityIndices: Set[(Int,Int)] =
      inequalities.map { case (left, right) => (variableIndices(left), variableIndices(right)) }

    val groundings = crossProduct(freeVariables.map(getAllConstantsForVariable(_,constants).toSeq))
      .filter(binding => inequalityIndices.forall(constr => binding(constr._1) != binding(constr._2))) //honour inequalities

    val groundFormulas = groundings.map{ grounding =>
      val substitutions = freeVariables.zip(grounding)
      substitutions.foldLeft(formula){case (f,(v,c)) => f.substitute(v,c)}
    }

    //still left to ground the function succ
    groundFormulas.map(_.groundAdd(constants)).toList.flatten
  }

  /**
   * @param v A variable occurring inside some predicate of the formula.
   * @param constants Maps a sort to its domain.
   * @return The constants belonging to every sort in whose position `v` appears.
   * @throws RuntimeException If `v` does not appear inside any predicate, so its sort cannot be determined.
   */
  def getAllConstantsForVariable(v: Variable, constants: Sort => Set[Constant]): Set[Constant] =
    getTypesOf(v).map(constants).reduceOption(_ intersect _).getOrElse(
      throw new RuntimeException("variable (%s) does not appear in formula (%s)".format(v, this.formula))
    )

  /**
   * Evaluates the arithmetic terms `Succ` and `FAdd` that appear as direct arguments of predicates.
   *
   * `Succ(t)` is rewritten to `FAdd(t, 1)`, nested additions are merged, and an addition applied to a constant
   * is replaced by the constant whose name is the numeric value of the original name plus the offset.
   *
   * @param constants Maps sorts to their domain objects; a computed constant must belong to its sort's domain.
   * @return `Some` of the rewritten formula, or `None` if a computed constant is not in the domain of its sort.
   *  Note that a `NumberFormatException` raised by parsing a non-numeric constant name also yields `None`.
   * @throws RuntimeException If an addition is applied to a term that is not a constant (e.g. a variable).
   */
  def groundAdd(constants: PartialFunction[Sort, Set[Constant]]): Option[Formula] = {

    //todo get rid of the exception throwing
    def replaceAdd(term: Term): Term = {
      term match {
        case Succ(t) => FAdd(t, 1)
        //grounded FAdds
        case FAdd(Constant(number, sort), diff) => {
          val successor = Constant((number.toInt + diff).toString, sort)
          if(constants(sort).contains(successor))
            successor
          else
            //signals "outside of the domain"; caught below
            throw new NumberFormatException
        }
        //nested FAdds
        case FAdd(FAdd(t,diff2),diff1) => FAdd(t, diff1 + diff2)
        case FAdd(Succ(t),d) => FAdd(t, d + 1)
        case FAdd(t,d) => throw new RuntimeException("only numeric constants allowed in Fadd")
        case x => x
      }
    }

    def isArithmetic(term: Term): Boolean = term match {
      case _: FAdd | _: Succ => true
      case _ => false
    }

    try{
      Some(
        //repeated until no predicate argument is an FAdd/Succ any more
        FormulaOps.traversal(formula){
          case Predicate(pd,terms) if terms.exists(isArithmetic) =>
            Predicate(pd, terms.map(replaceAdd))
        }
      )
    }
    catch {
      case _: NumberFormatException => None
    }
  }
}

object FormulaOps{

  /**
   * Performs a single pass over the formula tree, applying `pf` wherever it is defined.
   *
   * The recursive descent stops either at an atom or at a node where `pf` was applied; the result of an
   * application is not visited again during the same pass. There are usually several applications of `pf`
   * per call.
   *
   * @return The transformed formula, paired with `true` if `pf` was applied at least once during the pass,
   *  or `false` if there was no match during the whole traversal.
   */
  private[this] def traversalOnce(formula: Formula)(pf: PartialFunction[Formula,Formula]): (Formula, Boolean) = {

    def recTraversal(fIn: Formula): (Formula, Boolean) = {
      if(pf.isDefinedAt(fIn))
        (pf(fIn), true)
      else
        fIn match {
          case p: GAtom => (p, false)
          case ForAll(v, f) => {val (fproc, flag) = recTraversal(f); (ForAll(v, fproc),flag)}
          case Existential(v, f) => {val (fproc, flag) = recTraversal(f); (Existential(v, fproc), flag)}
          case Conjunction(i) => {val proc = i.map(recTraversal); (Conjunction(proc.map(_._1)),proc.exists(_._2))}
          case Disjunction(i) => {val proc = i.map(recTraversal); (Disjunction(proc.map(_._1)),proc.exists(_._2))}
          case Negation(f) => {val (fproc, flag) = recTraversal(f); (Negation(fproc), flag)}
        }
    }

    recTraversal(formula)
  }

  /**
   * Repeatedly applies `pf` over the formula tree (one `traversalOnce` pass at a time) until a pass finds no
   * node where `pf` is defined.
   *
   * Note: this does not terminate if `pf` keeps being defined on (parts of) its own output.
   *
   * @return The formula after the last pass.
   */
  def traversal(formula: Formula)(pf: PartialFunction[Formula,Formula]): Formula = {
    @scala.annotation.tailrec
    def untilFixpoint(current: Formula): Formula = {
      val (next, applied) = traversalOnce(current)(pf)
      if (applied) untilFixpoint(next) else next
    }

    untilFixpoint(formula)
  }

  /**
   * Applies the distributive law to a disjunction that has conjunctions among its direct children.
   *
   * The result is a conjunction containing one disjunction per way of picking a single element from each child
   * conjunction; every such disjunction also contains all non-conjunction children of `cj`. Conjunctions nested
   * deeper than the direct children are not distributed.
   *
   * @param cj The disjunction to distribute.
   */
  def distribute(cj: Disjunction): Formula = {

    val clauses = cj.children

    //separate the children of the disjunction into conjunctions and rest
    val conjs = clauses.collect{case d: Conjunction => d}
    val rest = clauses -- conjs

    //create every combination of the conjunctions' children with each other
    val conjunction_contents: IndexedSeq[Seq[Formula]] = conjs.map{case Conjunction(i) => i.toIndexedSeq}(collection.breakOut)
    val multiplied_conjunctions = crossProduct(conjunction_contents).iterator

    //append the common elements to every "factor"
    val complete_clauses: Iterator[Disjunction] = multiplied_conjunctions.map(d => Disjunction(rest ++ d.toSet))

    Conjunction(complete_clauses.toSet)
  }

  /**
   * Extracts all disjunctions, that are direct children of the given disjunction and
   * adds their children to the top-level disjunction (one level only).
   */
  def flatten(dj: Disjunction): Disjunction = {
    val nested_djs = dj.children.collect{case f:Disjunction => f}
    val rest = dj.children -- nested_djs
    Disjunction(rest ++ nested_djs.flatMap(_.children))
  }

  /**
   * Extracts all conjunctions, that are direct children of the given conjunction and
   * adds their children to the top-level conjunction (one level only).
   */
  def flatten(dj: Conjunction): Conjunction = {
    val nested_cjs = dj.children.collect{case f:Conjunction => f}
    val rest = dj.children -- nested_cjs
    Conjunction(rest ++ nested_cjs.flatMap(_.children))
  }

  /**
   * Checks whether a clause is trivially true, i.e. contains some atom both negated and un-negated.
   *
   * @param dj A disjunction whose children must all be literals (a `Predicate` or a negated `Predicate`);
   *  this is checked with an assertion.
   */
  def isTautology(dj: Disjunction): Boolean = {

    def isLiteral(f: Formula) = f match {
      case Negation(p: Predicate) => true
      case p: Predicate => true
      case _ => false
    }

    //this must be a disjunction of literals
    assert(dj.children.forall(isLiteral))

    val negated_formulas = dj.children.collect{case Negation(f) => f}
    //is a negated formula contained un-negated?
    negated_formulas.exists(f => dj.children.contains(f))
  }
}
