package edu.uulm.scbayes.tools

import edu.uulm.scbayes.mln.factorgraph.{MLNFactorGraph, TruthAssignmentMarginals, LogicNode}
import edu.uulm.scbayes.factorgraph.DiscreteFactor
import edu.uulm.scbayes.probabilities.{DiscreteMarginals, DiscreteVariable}
import edu.uulm.scbayes.logic.{TemporalSignature, Constant, TruthAssignment}
import scala.util.Random
import edu.uulm.scbayes.mln.parsing.ParseHelpers
import java.io.{PrintStream, FileOutputStream, OutputStream, File}
import edu.uulm.scbayes.inference.sampling.{MutableMCSatInferer, MutableGibbsStepper}
import edu.uulm.scbayes.inference._
import edu.uulm.scbayes.inference.bp.FloodingBeliefPropagationStepper
import edu.uulm.scbayes.mln.MarkovLogicNetwork
import edu.uulm.scbayes.inference.incremental.NSlide2S
import edu.uulm.scbayes.inference.exact.CountingInferer

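/**
 * Executable for probabilistic inference in Markov logic networks (MLNs).
 * Intended to be similar to Alchemy's `infer` executable.
 *
 * Usage: `infer [options] <mln-file>`; run with `--help` for the list of options.
 */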

object inferMLN {
  type DV = DiscreteVariable
  type DF = DiscreteFactor[DV]

  /** Locale used for case folding and for number formatting in the output, so that results
    * (e.g. the decimal separator) do not depend on the JVM's default locale. */
  private val NeutralLocale = java.util.Locale.ROOT

  /** An inference algorithm that can be selected on the command line. */
  sealed trait Inference {
    /** Creates an inferer parameterized by the given run configuration. */
    def inferer(config: Config): NInferer
  }

  /** Gibbs sampling; runs `mcmcSteps` steps using `randomSeed`. */
  case object Gibbs extends Inference {
    override def inferer(c: Config) = MutableGibbsStepper.forSteps(c.mcmcSteps, c.randomSeed)
  }
  /** Flooding belief propagation; runs `bpSteps` steps using `randomSeed`. */
  case object BP extends Inference {
    override def inferer(config: Config): NInferer = FloodingBeliefPropagationStepper.forSteps(config.bpSteps, config.randomSeed)
  }
  /** MC-SAT; runs `mcmcSteps` steps using `randomSeed`. */
  case object MCSat extends Inference {
    override def inferer(c: Config): NInferer = new MutableMCSatInferer().forSteps(c.mcmcSteps, c.randomSeed)
  }
  /** Exact inference by counting; ignores the configuration. */
  case object Counting extends Inference {
    override def inferer(c: Config): NInferer = CountingInferer
  }

  /** Algorithms selectable via `--algorithm`, keyed by their upper-case name. */
  val algorithms: Map[String, Inference] = Map("COUNT" -> Counting, "BP" -> BP, "MCSAT" -> MCSat, "GIBBS" -> Gibbs)

  /**
   * Looks up an inference algorithm by name, ignoring case.
   *
   * @param description one of the keys of [[algorithms]], in any case
   * @throws NoSuchElementException if `description` names no known algorithm
   */
  def parseAlgorithm(description: String): Inference =
    algorithms.getOrElse(
      description.toUpperCase(NeutralLocale),
      throw new NoSuchElementException(
        "unknown inference algorithm '%s'; expected one of %s".format(description, algorithms.keys.mkString(", "))))

  /**
   * Settings of one run of the tool.
   *
   * @param randomSeed PRNG handed to the sampling-based inferers; seeded with 0 unless `--seed` is given
   * @param output     where results are written; stdout unless `--outfile` is given
   * @param timeSteps  number of incremental steps performed by slice-wise inference (`--slice-dynamic`)
   */
  case class Config(
    mlnFile: File = new File("."),
    closedWorldPredicates: Seq[String] = Seq(),
    randomSeed: Random = new Random(0),
    evidenceFiles: Seq[File] = Seq(),
    algorithm: Inference = Gibbs,
    mcmcSteps: Int = 1000,
    bpSteps: Int = 50,
    useSliceInference: Boolean = false,
    printGraph: Boolean = false,
    outputUAI: Boolean = false,
    noInference: Boolean = false,
    output: PrintStream = new PrintStream(System.out),
    timeSteps: Int = 10
  ) {
    /** A fresh inferer for the selected algorithm. */
    def inferer: NInferer = algorithm.inferer(this)
  }

  val parser = new scopt.OptionParser[Config]("infer") {
    head("infer", "1.0")
    arg[File]("<mln-file>")
      .action { case (f, c) => c.copy(mlnFile = f) }
      .text("file containing the Markov logic network")
    opt[String]('a', "algorithm")
      .validate { a =>
        if (algorithms.contains(a.toUpperCase(NeutralLocale))) success
        else failure("unknown algorithm '" + a + "'; expected one of " + algorithms.keys.mkString(", "))
      }
      .action { case (a, c) => c.copy(algorithm = parseAlgorithm(a)) }
      .text("select algorithm used for inference; one of " + algorithms.keys.mkString(", "))
    opt[File]('e', "evidence")
      // append, so evidence files are kept in command-line order
      .action { case (e, c) => c.copy(evidenceFiles = c.evidenceFiles :+ e) }
      .text("file containing evidence; can be given multiple times")
    opt[File]('o', "outfile")
      .action { case (o, c) => c.copy(output = new PrintStream(new FileOutputStream(o))) }
      .text("set output file; default is stdout")
    opt[Int]("seed")
      .action { case (x, c) => c.copy(randomSeed = new Random(x)) }
      .text("used to seed the PRNG")
    opt[String]("cw")
      .action { case (p, c) => c.copy(closedWorldPredicates = c.closedWorldPredicates :+ p) }
      .text("apply closed-world assumption to this predicate; can be given multiple times")
    opt[Int]("mcmcSteps")
      .validate(s => if (s > 0) success else failure("--mcmcSteps must be positive"))
      .action { case (s, c) => c.copy(mcmcSteps = s) }
      .text("how many samples to take for MCMC inference")
    opt[Int]("bpSteps")
      .validate(s => if (s > 0) success else failure("--bpSteps must be positive"))
      .action { case (s, c) => c.copy(bpSteps = s) }
      .text("how many steps to run belief propagation for")
    opt[Unit]("slice-dynamic")
      .action { case (_, c) => c.copy(useSliceInference = true) }
      .text("use slice-wise inference for temporal models")
    opt[Int]("time-steps")
      .validate(t => if (t >= 0) success else failure("--time-steps must not be negative"))
      .action { case (t, c) => c.copy(timeSteps = t) }
      .text("number of incremental steps for --slice-dynamic; default is 10")
    //alternative actions to inference go here
    opt[Unit]("print-graph")
      .action { case (_, c) => c.copy(printGraph = true, noInference = true) }
      .text("just output graphviz description of ground factor graph and quit")
    opt[Unit]("to-uai")
      .action { case (_, c) => c.copy(outputUAI = true, noInference = true) }
      .text("just output ground factor graph in uai format and quit")
    help("help").text("prints this usage text")
  }

  /**
   * Run a given inference algorithm using the slicing approach.
   * @see Helpers.normalInference
   *
   * The initial result is built from `dmln.timeShift(0,1)` together with `evidence`; each further
   * step adds the next time constant and no additional evidence. Progress is reported on stderr.
   *
   * @param dmln      the (dynamic) MLN to run inference on
   * @param evidence  evidence passed to the initial result
   * @param timeSteps Run the inference over this many time steps.
   * @param inferer   inferer wrapped by the incremental [[NSlide2S]] inferer
   * @return the marginals after the last step
   */
  def NslidingInference(dmln: MarkovLogicNetwork,
                        evidence: TruthAssignment,
                        timeSteps: Int,
                        inferer: NInferer): DiscreteMarginals[LogicNode] = {
    val incrementalInferer = new NSlide2S(inferer)
    val initial = (incrementalInferer.createInitialResult(dmln.timeShift(0, 1), evidence), 1)
    val incrementalInferences = Iterator.iterate(initial) {
      case (state, time) =>
        // stderr, so progress never mixes with results written to stdout
        System.err.println("step: " + time)
        val nextTimePoint = Constant((time + 1).toString, TemporalSignature.timeSort)
        (incrementalInferer.createIncrementalResult(state, Set(nextTimePoint), TruthAssignment.empty), time + 1)
    }.drop(timeSteps)
    incrementalInferer.marginals(incrementalInferences.next()._1)
  }

  /**
   * Entry point: parses the arguments, grounds the MLN with the given evidence and either writes
   * the ground graph (`--print-graph`, `--to-uai`) or the marginal of every ground atom.
   * Exits with status 1 if the arguments cannot be parsed.
   */
  def main(args: Array[String]): Unit = {
    parser.parse(args, Config()) match {
      case None =>
        // scopt has already reported the problem; signal failure to the shell
        sys.exit(1)
      case Some(config) =>
        // the PRNG is config.randomSeed: seeded with 0 unless --seed was given
        val mln = ParseHelpers.loadMLNFile(config.mlnFile.toString).makePositive
        val signature = mln.signature

        val evidencePre = TruthAssignment.flatten(
          config.evidenceFiles.map(f => ParseHelpers.loadEvidenceFile(f.toString, signature)))

        val cwPredicates = config.closedWorldPredicates.map { pString =>
          signature.predicates.find(_.name == pString).getOrElse(
            throw new IllegalArgumentException("unknown predicate for closed-world assumption: " + pString))
        }
        // check before the assumption is applied to the evidence
        require(cwPredicates.forall(!_.isDynamic), "closed world assumption is only allowed for static predicates")

        if (cwPredicates.nonEmpty)
          System.err.println("applying closed-world assumption to %s".format(cwPredicates.mkString(", ")))

        val evidence = cwPredicates.foldLeft(evidencePre) { case (ev, pd) => ev.applyClosedWorld(pd, signature) }

        val graph: MLNFactorGraph = mln.toGraph(evidence)

        //other actions
        if (config.printGraph)
          config.output.println(graph.toDOT)

        if (config.outputUAI) {
          val uAI: (String, Map[Int, LogicNode]) = graph.toUAI
          config.output.println(uAI._1)
          // the index -> node mapping goes to stderr so it cannot corrupt the UAI document
          System.err.println(uAI._2.mkString("\n"))
        }

        if (!config.noInference) {
          val inferenceMarginals =
            if (config.useSliceInference) NslidingInference(mln, evidence, config.timeSteps, config.inferer)
            else config.inferer.infer(graph)

          // NOTE(review): presumably answers atoms the inferer has no marginal for from the evidence — confirm
          val marginals = inferenceMarginals.orElse(new TruthAssignmentMarginals(evidence))

          val probStrings = for {
            atom <- signature.herbrandBase
            variable = LogicNode(atom, signature)
            value <- variable.getRange
            prob = marginals.marginal(variable, value)
          } yield "P(%s = %s) = %f".formatLocal(NeutralLocale, variable, variable.int2Domain(value), prob)

          config.output.println(probStrings.mkString("\n"))
        }
        config.output.flush()
    }
  }
}