package edu.uulm.scbayes.mln.parsing

import edu.uulm.scbayes.logic._
import util.parsing.combinator.JavaTokenParsers
import edu.uulm.scbayes.mln.MLNWeightedFormula

/**
  * Combinator parser for first-order logic formulas, MLN files and evidence files.
  *
  * Formula grammar, from loosest to tightest binding:
  * {{{
  * formula     ::= "EXISTS" variable formula
  *               | "FORALL" variable formula
  *               | disjunction [("=>" | "<=>") disjunction]
  * disjunction ::= conjunction {"v" conjunction}
  * conjunction ::= atom {"^" atom}
  * atom        ::= ["!"] (predicate "(" term, ..., term ")" | "(" formula ")")
  * term        ::= "succ" "(" variable ")" | variable "+" integer | variable | constant
  * }}}
  * A predicate usage takes exactly one term per argument sort of its definition.
  *
  * Date: 3/11/11
  */
object FOLParser extends JavaTokenParsers {

  /**
    * Token separator: whitespace, `//` line comments and C-style block comments, in any mix.
    *
    * Every alternative consumes a whole run per repetition (`\s+`, or a reluctant `.*?` up to the
    * first closing delimiter) rather than a single character: java.util.regex recurses once per
    * repetition of an alternation group, so a per-character pattern can overflow the stack on
    * long comments or whitespace runs. `(?s:...)` lets a block comment span lines.
    */
  protected override val whiteSpace = """(\s+|//.*|/\*(?s:.*?)\*/)+""".r

  // Names use `+`, not `*`: an empty match would let `P()` declare one argument of sort ""
  // and let a bare `(...)` parse as the definition of a predicate with an empty name.
  val predname = "[a-zA-Z_]+".r
  val sortname = "[a-zA-Z_]+".r
  val constantLiteral = "[A-Z][a-zA-Z0-9_]*".r
  /** Variables start with a lowercase letter, optionally prefixed with `?`. */
  val variableLiteral = "\\??[a-z][a-zA-Z0-9_]*".r
  val integerLiteral = "-?[0-9]+".r
  val succFunctionLiteral = "succ".r
  val isALiteral = "isA".r

  /** Adds `stripEither` to parsers whose result is an `Either[String, T]`. */
  class StripEitherOps[T](p: Parser[Either[String, T]]) {
    /** Fails with the message of a `Left`; succeeds with the value of a `Right`. */
    def stripEither: Parser[T] = p >> {
      case Left(message) => failure(message)
      case Right(value) => success(value)
    }
  }

  // Returns a named class rather than an anonymous `new { ... }`, whose structural type would
  // turn every `.stripEither` call into a reflective call.
  implicit def withStripEither[T](p: Parser[Either[String, T]]): StripEitherOps[T] = new StripEitherOps(p)

  /**
    * Parses a predicate declaration `name(sort1, ..., sortN)`, optionally with a `!` right before
    * the closing parenthesis, and registers it with `sM.definePredicate`. Whether `!` was present
    * is passed as the boolean flag (presumably "functional" -- TODO confirm). A `Left` returned by
    * `definePredicate` becomes a parse failure.
    */
  def predicateDefinition(implicit sM: SignatureBuilder): Parser[AbstractPredicateDefinition] =
    (predname ~ ("(" ~> repsep(sortname, ",") ~ (opt("!") ^^ (_.isDefined)) <~ ")") ^^ {
      case predName ~ (sorts ~ marked) => sM.definePredicate(predName, sorts, marked)
    }).stripEither

  /** Parses `SubSort isA SuperSort` and records the relation with `sM.addIsA`. */
  def isAStatement(implicit sM: SignatureBuilder): Parser[Unit] = sortname ~ isALiteral ~ sortname ^^ {
    case subSort ~ _ ~ superSort => sM.addIsA(subSort, superSort)
  }

  /**
    * Parses a domain definition, trying the enumerating form (`Sort = {A, B}`) before the
    * range form (`Sort = {1-5}` or `Sort = {1,...,5}`).
    */
  def domainDefinition(implicit sM: SignatureBuilder) = enumeratingDomainDefinition | rangeDomainDefinition

  /**
    * Parses a domain definition that assigns a set of constants as the domain of a sort.
    * Example:
    * {{{
    * Emotion = {Happy, Sad}
    * }}}
    * The sort is looked up or created via `sM.getOrCreateSort` before the right-hand side is parsed.
    */
  def enumeratingDomainDefinition(implicit sM: SignatureBuilder) =
    (sortname ^^ (sM.getOrCreateSort(_))) >> { sort =>
      ("=" ~ "{") ~> rep1sep(constantDefinition(sort), ",") <~ "}" ^^ (constants => sM.defineConstants(sort, constants))
    }

  /**
    * Parses an integer range domain, `Sort = {lower-upper}` or `Sort = {lower,...,upper}`, and
    * passes the sort's name and both bounds to `sM.defineConstantRange`.
    */
  def rangeDomainDefinition(implicit sM: SignatureBuilder): Parser[Unit] =
    (sortname ^^ (sM.getOrCreateSort(_))) >> { sort =>
      "=" ~> ("{" ~> integerLiteral ~ (("-" | ",...,") ~> integerLiteral <~ "}")) ^^ {
        case lower ~ upper => sM.defineConstantRange(sort.name, lower.toInt, upper.toInt)
      }
    }

  /** Parses a variable name into a `Variable`. */
  def variable: Parser[Variable] = variableLiteral ^^ (new Variable(_))

  /**
    * Parses a term of the given sort. Order matters: `p | q` commits to `p` once it succeeds,
    * so `fadd` must precede `variable` or `x + 1` would stop after `x`.
    */
  def term(sort: Sort)(implicit sM: AtomBuilder): Parser[Term] = succ | fadd | variable | constantUsage(sort)

  /**
    * Parses a constant or integer literal used as an argument of sort `sort`, resolving it
    * through `sM.lookupOrCreateConstants`.
    */
  def constantUsage(sort: Sort)(implicit sM: AtomBuilder): Parser[Constant] =
    (constantLiteral | integerLiteral) ^^ ((c: String) => sM.lookupOrCreateConstants((c, sort) :: Nil, Nil, Nil)._1.head)

  /** Parses a constant or integer literal of a domain definition as a new `Constant` of `sort`. */
  def constantDefinition(sort: Sort)(implicit sM: SignatureBuilder): Parser[Constant] =
    (constantLiteral | integerLiteral) ^^ {Constant(_,sort)}

  /** Parses `succ(variable)`. */
  def succ(implicit sM: AtomBuilder): Parser[Succ] = ("succ" ~ "(" ~> variable <~ ")") ^^ (Succ(_))

  /** Parses `variable + integer`. */
  def fadd(implicit sM: AtomBuilder): Parser[FAdd] = variable ~ ( "+" ~> integerLiteral) ^^ {
    case t ~ intLit => FAdd(t,intLit.toInt)
  }

  /** Parses a predicate name and resolves it with `sM.getPredicateByName`; a `Left` is a parse failure. */
  def predicate(implicit sM: AtomBuilder): Parser[AbstractPredicateDefinition] = (predname ^^ (sM.getPredicateByName(_))).stripEither

  /** Runs `parsers` in order, separated by `delim`, collecting their results; `Nil` consumes nothing. */
  def withDelim[T](parsers: List[Parser[T]], delim: String): Parser[List[T]] = parsers match {
    case Nil => success(Nil)
    case head :: Nil => head ^^ (_ :: Nil)
    case head :: tail => (head <~ delim) ~ withDelim(tail, delim) ^^ {case h ~ t => h :: t}
  }

  /**
    * Parses a predicate usage: a defined predicate name followed by a parenthesised argument list
    * with exactly one term per argument sort of that predicate's signature.
    */
  def predicateUsage(implicit sM: AtomBuilder): Parser[Predicate] = (predicate >> ((predDef: AbstractPredicateDefinition) =>
    "(" ~> withDelim(predDef.signature.map(term(_)).toList, ",") <~ ")" ^^ (sM.createPredicateInstanceFixSorts(predDef.name, _))
    )).stripEither

  /** Parses `atom ^ atom ^ ...`; a single operand is returned without a `Conjunction` wrapper. */
  def conjunction(implicit sM: AtomBuilder): Parser[Formula] = rep1sep(atom, "^") ^^ {
    case single :: Nil => single
    case formulas => Conjunction(formulas.toSet)
  }

  /** Parses `conjunction v conjunction v ...`; a single operand is returned without a `Disjunction` wrapper. */
  def disjunction(implicit sM: AtomBuilder): Parser[Formula] = rep1sep(conjunction, "v") ^^ {
    case single :: Nil => single
    case formulas => Disjunction(formulas.toSet)
  }

  /** Parses an optionally negated (`!`) predicate usage or parenthesised formula. */
  def atom(implicit sM: AtomBuilder): Parser[Formula] = opt("!") ~ (predicateUsage | ("(" ~> formula <~ ")")) ^^ {
    case Some(_) ~ f => Negation(f)
    case None ~ f => f
  }

  /** Encodes `premise => conclusion` as `!premise v conclusion`. */
  private def implicationOf(premise: Formula, conclusion: Formula) = Disjunction(Set(Negation(premise), conclusion))

  /** Encodes `left <=> right` as `(left ^ right) v (!left ^ !right)`. */
  private def equivalenceOf(left: Formula, right: Formula) = Disjunction(Set(
    Conjunction(Set(left, right)),
    Conjunction(Set(Negation(left), Negation(right)))
  ))

  /** Parses `disjunction => disjunction`. */
  def implication(implicit sM: AtomBuilder) = disjunction ~ "=>" ~ disjunction ^^ {
    case premise ~ _ ~ conclusion => implicationOf(premise, conclusion)
  }

  /** Parses `disjunction <=> disjunction`. */
  def equivalence(implicit sM: AtomBuilder) = disjunction ~ "<=>" ~ disjunction ^^ {
    case left ~ _ ~ right => equivalenceOf(left, right)
  }

  /** Parses `FORALL variable formula`; the body extends as far as `formula` can parse. */
  def quantForall(implicit sM: AtomBuilder) = "FORALL" ~> variableLiteral ~ formula ^^ {
    case (v ~ f) => ForAll(Variable(v), f)
  }

  /** Parses `EXISTS variable formula`; the body extends as far as `formula` can parse. */
  def quantExists(implicit sM: AtomBuilder) = "EXISTS" ~> variableLiteral ~ formula ^^ {
    case (v ~ f) => Existential(Variable(v), f)
  }

  /**
    * Parses a formula: a quantified formula, or a disjunction optionally followed by `=>` or
    * `<=>` and a second disjunction.
    */
  def formula(implicit sM: AtomBuilder): Parser[Formula] = quantExists | quantForall | connectedDisjunctions

  // The leading disjunction is parsed once and the optional connective afterwards. Trying
  // `implication | equivalence | disjunction` would re-parse it up to three times, and that
  // compounds at every level of parentheses (3^depth work, repeating builder side effects).
  private def connectedDisjunctions(implicit sM: AtomBuilder): Parser[Formula] =
    disjunction ~ opt(("=>" | "<=>") ~ disjunction) ^^ {
      case lhs ~ None => lhs
      case lhs ~ Some("=>" ~ rhs) => implicationOf(lhs, rhs)
      case lhs ~ Some(_ ~ rhs) => equivalenceOf(lhs, rhs)
    }

  /** Parses `variable != variable`. */
  def variableInequality: Parser[(Variable,Variable)] = variable ~ "!=" ~ variable ^^ {case v1 ~ _ ~ v2 => (v1,v2)}
  /** Parses `[v1 != v2, ...]` into a set of variable pairs. */
  def variableInequalities: Parser[Set[(Variable,Variable)]] = "[" ~> repsep(variableInequality,",")<~ "]" ^^ (_.toSet)

  /**
    * Parses one MLN formula: an optional weight (a floating point number or `[-]Infinity`), a
    * formula, an optional trailing `.` and an optional constraint list `[x != y, ...]`.
    * Exactly one of weight and `.` must be present: a weighted formula gets `Some(weight)`, a
    * formula ending in `.` (rigid) gets `None`; any other combination is a parse failure.
    */
  def weightedFormula(implicit sM: AtomBuilder): Parser[MLNWeightedFormula] =
    opt(floatingPointNumber | "-?Infinity".r) ~ formula ~ opt(".") ~
      (opt(variableInequalities) ^^ (_.getOrElse(Set.empty[(Variable, Variable)]))) >> {
        case Some(_) ~ f ~ Some(_) ~ _ => failure("formula \"%s\" can't have weight and be rigid".format(f))
        case None ~ f ~ None ~ _ => failure("formula \"%s\" must be either rigid or have a weight".format(f))
        case Some(w) ~ f ~ None ~ constraints => success(MLNWeightedFormula(f, Some(w.toDouble), constraints))
        case None ~ f ~ Some(_) ~ constraints => success(MLNWeightedFormula(f, None, constraints))
      }

  /**
    * Parses an MLN file: predicate definitions, then isA statements, then domain definitions,
    * then weighted formulas. Definitions are registered with `sM` while parsing; formulas are
    * parsed against `sM.getSignature`. Returns that signature together with the formulas.
    */
  def mlnFileParser(implicit sM: SignatureBuilder): Parser[(Signature, List[MLNWeightedFormula])] =
    rep(predicateDefinition) ~ rep(isAStatement) ~ rep(domainDefinition) ~> rep(weightedFormula(sM.getSignature)) ^^ ((sM.getSignature,_))

  /*                       stuff for evidence files below                               */

  /** Parses a ground atom `pred(arg, ...)` with constant or integer arguments via `sig.buildAtom`. */
  def evidenceAtom(implicit sig: AtomBuilder): Parser[(Atom,Seq[Constant])] = predname ~ ("(" ~> repsep(integerLiteral | constantLiteral, ",") <~ ")") ^^ {
    case pn ~ args => sig.buildAtom(pn, args)
  }

  /**
    * Parses one evidence entry: an optionally negated ground atom, turned into a single-entry
    * `TruthAssignment`, paired with the constants returned alongside it by `sig.buildAtom`.
    * A negated functional atom throws a RuntimeException.
    */
  def evidence(implicit sig: AtomBuilder): Parser[(TruthAssignment,Seq[Constant])] = opt("!") ~ evidenceAtom ^^ {
    // The double parentheses make the tuple a single operand of the `~` infix pattern.
    case Some("!") ~ (((a: PredicateAtom),newConstants)) => (TruthAssignment(Map(a.base -> false), Map()),newConstants)
    case Some("!") ~ (((a: FunctionalAtom),_)) => throw new RuntimeException("negative functional evidence not implemented")
    case None ~ (((a: PredicateAtom),newConstants)) => (TruthAssignment(Map(a.base -> true), Map()),newConstants)
    case None ~ (((a: FunctionalAtom),newConstants)) => (TruthAssignment(Map(), Map(a.base -> a.target)),newConstants)
  }

  /** Parses any number of evidence entries, merging their assignments and deduplicating constants. */
  def multipleEvidence(implicit sig: AtomBuilder): Parser[(TruthAssignment,Seq[Constant])] = rep(evidence) ^^ { entries =>
    val (assignments, constants) = entries.unzip
    (TruthAssignment.flatten(assignments), constants.flatten.distinct)
  }
}