package edu.uulm.scbayes.mln.generator

import util.Random
import edu.uulm.scbayes.mln._
import edu.uulm.scbayes.logic._
import edu.uulm.scbayes.util._

/** Generates random building blocks (signatures, atoms, CNF formulas) for Markov logic networks.
  *
  * All randomness is drawn from `rand`, so passing a seeded instance makes the output reproducible.
  *
  * @param rand source of randomness shared by every generator method
  */
class MLNGenerator(val rand: Random = new Random) {

  /** Picks `num` distinct elements from `elems` uniformly at random (sampling without replacement).
    *
    * Duplicates in `elems` are collapsed first, so `num` must not exceed the number of
    * ''distinct'' elements. The previous rejection-sampling implementation looped forever
    * when that bound was violated but `num <= elems.size` still held.
    *
    * @param elems candidate elements
    * @param num   number of distinct elements to pick; values `<= 0` yield the empty set
    * @return a set of `num` distinct elements of `elems`
    */
  def pickUniqueFromSeq[A](elems: Seq[A], num: Int): Set[A] = {
    val candidates = elems.distinct
    assert(num <= candidates.size,
      "cannot pick %d unique elements out of %d distinct candidates".format(num, candidates.size))
    // A shuffled prefix is a uniform sample without replacement and always terminates.
    rand.shuffle(candidates).take(num).toSet
  }

  /** Picks `num` elements from `elems` uniformly at random, with replacement.
    *
    * The result is built strictly so that every draw from `rand` happens inside this call;
    * a lazily evaluated sequence would defer the draws to whenever the caller first traverses
    * the result and interleave them with other uses of `rand`.
    *
    * @param elems candidate elements
    * @param num   number of picks; values `<= 0` yield an empty sequence
    */
  def pickFromSeq[A](elems: Seq[A], num: Int): Seq[A] =
    Vector.fill(num)(elems.pickRandom(rand))

  /** Generates a random signature.
    *
    * Creates `numSorts` sorts named `sort1 .. sortN` plus a sort named `time`, a random
    * domain of 1 to 20 constants for each ordinary sort, the time constants
    * `t0 .. t<timeSteps>`, and `numPredicates` predicates of each of four kinds:
    * normal (`np`), functional (`nf`), dynamic normal (`dp`) and dynamic functional (`df`).
    * Dynamic predicates take the time sort as their first argument.
    *
    * @param timeSteps     index of the last time constant (the time domain has `timeSteps + 1` elements)
    * @param numSorts      number of ordinary sorts
    * @param numPredicates number of predicates generated for each of the four kinds
    */
  def genSignature(timeSteps: Int, numSorts: Int, numPredicates: Int): Signature = {
    val sorts = (1 to numSorts).map(n => new Sort("sort%d".format(n)))
    val timeSort = new Sort("time")

    val constants = sorts
      .map {
      sort => sort -> ((0 to rand.nextInt(20)).map(n => Constant(n.toString, sort))).toSet
    }
      .toMap

    val timeDomain = timeSort -> (0 to timeSteps).map(n => Constant("t%d".format(n), timeSort)).toSet

    // The requested number of argument sorts is clamped to the number of available sorts;
    // otherwise small values of numSorts make pickUniqueFromSeq fail at random.
    val normalPredicates = (1 to numPredicates)
      .map(n => new NormalPredicateDefinition("np%d".format(n), pickUniqueFromSeq(sorts, rand.nextInt(3) min sorts.size).toIndexedSeq))

    // Functional predicates keep the lower bound of one argument, so this still fails when there are no sorts.
    val functionPredicate = (1 to numPredicates)
      .map(n => new FunctionalPredicateDefinition("nf%d".format(n), pickUniqueFromSeq(sorts, (rand.nextInt(3) min sorts.size) max 1).toIndexedSeq))

    val dynamicNormalPredicates = (1 to numPredicates)
      .map(n => new NormalPredicateDefinition("dp%d".format(n), timeSort +: pickUniqueFromSeq(sorts, rand.nextInt(2) min sorts.size).toIndexedSeq))

    // `rand.nextInt(2) max 1` always evaluates to 1: exactly one non-time argument.
    val dynamicFunctionPredicates = (1 to numPredicates)
      .map(n => new FunctionalPredicateDefinition("df%d".format(n), timeSort +: pickUniqueFromSeq(sorts, rand.nextInt(2) max 1).toIndexedSeq))

    Signature(sorts.toSet + timeSort,
      (normalPredicates ++ functionPredicate ++ dynamicNormalPredicates ++ dynamicFunctionPredicates).toSet,
      constants + timeDomain
    )
  }

  /** Generates a random CNF formula over three randomly chosen atoms of `signature`.
    *
    * Both the number of clauses and the requested clause size are drawn from 1 to 4.
    */
  def genStaticFormula(signature: Signature): Formula =
    buildCNF(genAtoms(signature, 3), rand.nextInt(5) max 1, rand.nextInt(5) max 1)

  /** Attaches a random weight to `f`.
    *
    * With probability 0.3 the weight is `None`; otherwise it is drawn uniformly from `[0, 5)`.
    * NOTE(review): `None` presumably marks a hard (unweighted) formula -- confirm against
    * `MLNWeightedFormula`.
    */
  def weighFormula(f: Formula): MLNWeightedFormula =
    MLNWeightedFormula(f, (if (rand.nextDouble() < 0.3) None else Some(rand.nextDouble() * 5)))

  //
  //  /** Generates a "markovian" formula. Thus state t+1 somehow depends on state t. */
  //  def genDynamicFormula(signature: Signature, rand: Random = new Random): Seq[MLNWeightedFormula] = {
  //
  //  }

  /** Generates `num` atoms over randomly chosen predicates of `signature`, with variables as arguments.
    *
    * Predicates are chosen with replacement. For every sort, as many variables are created as
    * the largest number of occurrences of that sort in any single chosen predicate; the entry
    * for `Sort("time")` is always set to one variable. Each argument position is then filled
    * with a randomly chosen variable of the matching sort.
    *
    * @param signature signature to draw predicates from
    * @param num       number of atoms to generate
    * @param useTime   currently unused; kept for interface compatibility
    */
  def genAtoms(signature: Signature, num: Int, useTime: Boolean = false): Seq[Predicate] = {
    val usedPredicates: Seq[AtomDefinition] = pickFromSeq(signature.predicates.toSeq, num)

    val sorts: Seq[Sort] = (for (pred <- usedPredicates; sort <- pred.signature) yield sort).distinct

    // Every sort in `sorts` occurs in at least one chosen predicate, so `max` is never
    // applied to an empty collection here.
    val maxSortCounts: Map[Sort, Int] =
      sorts.map(s => s -> usedPredicates.map(pred => pred.signature.count(_ == s)).max)
        .toMap + ((Sort("time"), 1))

    // produce the pool of variables available for each sort
    val variables: Map[Sort, Seq[Variable]] = maxSortCounts.map {
      case (s, count) => s -> (1 to count).map(i => new Variable("v%s_%d".format(s.name, i)))
    }

    // fill variables into predicates
    usedPredicates.map(pd => Predicate(pd, pd.signature.map(variables(_).pickRandom(rand)).toList))
  }

  /** Builds a random formula in conjunctive normal form from the given atoms.
    *
    * Each literal is a randomly picked atom, negated with probability 0.5. Clauses and the
    * conjunction are built as sets, so duplicates collapse and the result may contain fewer
    * than `clauseSize` literals per clause or fewer than `numClauses` clauses.
    *
    * @param preds      atoms to draw literals from
    * @param numClauses number of clauses to draw
    * @param clauseSize number of literals to draw per clause
    */
  def buildCNF(preds: Seq[Predicate], numClauses: Int, clauseSize: Int): Formula = {
    def buildClause(): Disjunction =
      Disjunction(Iterator.continually(
        if (rand.nextBoolean()) Negation(preds.pickRandom(rand)) else preds.pickRandom(rand)
      ).take(clauseSize).toSet)

    Conjunction(Iterator.continually(buildClause()).take(numClauses).toSet)
  }


}