package edu.uulm.scbayes.logic

import scalaz.Monoid

/** An interpretation is a partial function from atoms to Boolean values. */
trait Interpretation extends PartialFunction[Atom, Boolean]
/** Adapts an arbitrary partial function over atoms to the `Interpretation` trait.
  *
  * @param pf the wrapped partial function; every call is forwarded to it unchanged
  */
case class PFInterpretation(pf: PartialFunction[Atom, Boolean]) extends Interpretation {
  /** Defined exactly where the wrapped function is defined. */
  override def isDefinedAt(x: Atom): Boolean = pf isDefinedAt x

  /** Evaluates the wrapped function on the given atom. */
  override def apply(v1: Atom): Boolean = pf.apply(v1)
}

/** An interpretation that is derived from the values assigned to atom bases. */
trait AtomBaseAssignment extends Interpretation {
  /** Looks up the value assigned to the given atom base.
    *
    * @return the assigned value, or None if the base has no value
    */
  def baseValue[A <: AtomBase](base: A): Option[A#RangeType]

  /** An atom is covered exactly when its base has an assigned value. */
  def isDefinedAt(a: Atom): Boolean = !baseValue(a.base).isEmpty

  /** True iff the base of `a` has an assigned value and that value equals `a.target`.
    * Yields false (rather than failing) when the base has no value.
    */
  def apply(a: Atom): Boolean = baseValue(a.base).exists(_ == a.target)
}


/** An atom base assignment backed by two (possibly mutable) maps.
  *
  * @param truth         values of the predicate atom bases
  * @param functionValue values of the functional atom bases
  */
case class MapInterpretation(truth: collection.Map[PredicateAtomBase, PredicateAtomBase#RangeType],
                             functionValue: collection.Map[FunctionalAtomBase, FunctionalAtomBase#RangeType])
  extends AtomBaseAssignment {

  /** Picks the map that matches the runtime kind of `base` and looks the base up in it.
    *
    * @return the stored value, or None if the selected map has no entry for `base`
    */
  def baseValue[A <: AtomBase](base: A): Option[A#RangeType] = {
    // Select the lookup first; the cast to the caller's range type is shared by both branches.
    val stored: Option[Any] = base match {
      case predicateBase: PredicateAtomBase => truth.lift(predicateBase)
      case functionalBase: FunctionalAtomBase => functionValue.lift(functionalBase)
    }
    stored.map(value => value.asInstanceOf[A#RangeType])
  }
}

/** An immutable assignment of values to predicate and functional atom bases.
  *
  * @param truth         values of the predicate atom bases
  * @param functionValue values of the functional atom bases
  */
case class TruthAssignment(truth: Map[PredicateAtomBase, PredicateAtomBase#RangeType],
                           functionValue: Map[FunctionalAtomBase, FunctionalAtomBase#RangeType])
  extends AtomBaseAssignment {

  /** Apply the closed world assumption to the given predicate.
    *
    * Every atom base in `signature.normalBase` whose name equals `predicate` and that
    * has no entry in `truth` yet is assigned `false`. Existing entries are kept.
    */
  def applyClosedWorld(predicate: AtomDefinition, signature: Signature): TruthAssignment = {
    val newFalseEntries: Set[(PredicateAtomBase, Boolean)] = for(
      atomBase <- signature.normalBase
      if atomBase.name == predicate && !truth.isDefinedAt(atomBase)
    ) yield (atomBase -> false)

    this.copy(truth = truth ++ newFalseEntries)
  }

  /** Installs `false` as the default value of the `truth` map.
    *
    * NOTE(review): a map default only affects `truth.apply`; `baseValue` below reads
    * with `truth.get`, which ignores defaults. Confirm that callers of this method
    * access `truth(...)` directly.
    */
  def applyClosedWorldToEverything: TruthAssignment = this.copy(truth = truth.withDefaultValue(false))

  /** Looks `base` up in `truth` or `functionValue`, depending on its runtime kind.
    *
    * @return the stored value, or None if there is no entry for `base`
    */
  def baseValue[A <: AtomBase](base: A): Option[A#RangeType] = {
    base match {
      case gnp: PredicateAtomBase => truth.get(gnp).map(_.asInstanceOf[A#RangeType])
      case fb: FunctionalAtomBase => functionValue.get(fb).map(_.asInstanceOf[A#RangeType])
    }
  }

  /**
    * @return A TruthAssignment that is constructed by using ta when this assignment is not defined.
    */
  def orElse(ta: TruthAssignment): TruthAssignment =
    // Only entries of `ta` whose key is missing here are added, so the values of this
    // assignment are never overwritten (a plain `truth ++ ta.truth` would let `ta` win).
    TruthAssignment(
      truth ++ ta.truth.filter { case (base, _) => !truth.contains(base) },
      functionValue ++ ta.functionValue.filter { case (base, _) => !functionValue.contains(base) }
    )

  /** Keeps only the entries whose atom base satisfies `p`.
    *
    * Uses a strict `filter` instead of `filterKeys`, so the result is a self-contained
    * map and `p` is evaluated once per entry here rather than on later accesses.
    */
  def filterBase(p: AtomBase => Boolean): TruthAssignment =
    TruthAssignment(
      truth.filter { case (base, _) => p(base) },
      functionValue.filter { case (base, _) => p(base) }
    )

  /** Splits this assignment by the predicate `p` on atom bases.
    *
    * @return a pair (entries whose base satisfies `p`, all remaining entries)
    */
  def partitionBase(p: AtomBase => Boolean): (TruthAssignment, TruthAssignment) = {
    val (tFirst, tSecond) = truth.partition(a => p(a._1))
    val (fFirst, fSecond) = functionValue.partition(a => p(a._1))
    (TruthAssignment(tFirst,fFirst), TruthAssignment(tSecond,fSecond))
  }

  /** Renders this assignment as text, one entry per line, with the lines sorted.
    *
    * Predicate entries are printed as the base itself when true and prefixed with `!`
    * when false; functional entries are printed as `base = value`.
    *
    * @param onlyTrueAtoms if true, predicate entries with value false are omitted
    */
  def asEvidenceString(onlyTrueAtoms: Boolean = true): String = {
    val evidenceList =
      truth.collect{
        case (pab,t) if(!onlyTrueAtoms || t)=> (if(t) "%s" else "!%s").format(pab.toString)
      } ++
      functionValue.map{case (fab,target) => "%s = %s".format(fab.toString, target.toString)}
    evidenceList.toList.sorted.mkString("\n")
  }

  /** Converts every entry into an atom: predicate entries first, then functional entries. */
  def asAtoms: Iterable[Atom] = {
    val preds = truth map {case (k,v) => PredicateAtom(k,v)}
    val funs = functionValue map {case (k,v) => FunctionalAtom(k,v)}
    preds ++ funs
  }
}

/** Factory methods and the scalaz `Monoid` instance for `TruthAssignment`. */
object TruthAssignment {
  /** Builds an assignment from atoms.
    *
    * Predicate atoms populate `truth`, functional atoms populate `functionValue`;
    * atoms matching neither pattern are ignored. Duplicate bases are not checked
    * for contradictions here (the entries are simply passed to `toMap`).
    */
  def fromAtoms(atoms: Iterable[Atom]): TruthAssignment = {
    val truth = atoms.collect{
      case PredicateAtom(base, t) => (base -> t)
    }.toMap
    val funs = atoms.collect{
      case FunctionalAtom(base, t) => (base -> t)
    }.toMap
    TruthAssignment(truth,funs)
  }

  /** Combine a list of TruthAssignment objects into one object.
    *
    * @return Right with the merged assignment, or Left with a message naming the
    *         first contradicting entry if two assignments give different values
    *         to the same atom base.
    */
  def flattenOptional(list: Seq[TruthAssignment]): Either[String,TruthAssignment] = {

    /** @return Left describing the first entry that contradicts an earlier one,
      *         otherwise Right with a map of all entries.
      */
    def buildMap[A, B](in: TraversableOnce[(A, B)]): Either[String,Map[A, B]] = {
      import collection.mutable.HashMap

      val seen = new HashMap[A, B]
      // `find` stops at the first entry whose value differs from the one already
      // recorded for its key, so no non-local `return` out of a closure is needed.
      in.find { case (k, v) => seen.getOrElseUpdate(k, v) != v } match {
        case Some(conflict) => Left("contradiction on %s".format(conflict.toString))
        case None => Right(seen.toMap)
      }
    }

    val flatTruth = list.flatMap(_.truth)
    val flatFunction = list.flatMap(_.functionValue)
    // The functional entries are only merged if the predicate entries were consistent.
    for(
      truthMap <- buildMap(flatTruth).right;
      functionMap <- buildMap(flatFunction).right
    ) yield TruthAssignment(truthMap,functionMap)
  }

  /**
    * Same as flattenOptional, but throws a RuntimeException carrying the
    * contradiction message instead of returning a Left.
    */
  def flatten(list: Seq[TruthAssignment]): TruthAssignment =
    flattenOptional(list).fold(s => throw new RuntimeException(s),identity)

  /** The assignment without any entries. */
  def empty: TruthAssignment = TruthAssignment(Map.empty, Map.empty)

  /** Monoid whose `append` merges two assignments with `flatten`, so it throws a
    * RuntimeException when they contradict each other; `zero` is the empty assignment.
    */
  implicit val taAsMonoid: Monoid[TruthAssignment] = new Monoid[TruthAssignment]{
    def append(s1: TruthAssignment, s2: => TruthAssignment): TruthAssignment = flatten(List(s1,s2))
    val zero: TruthAssignment = TruthAssignment.empty
  }
}
