{-# LANGUAGE TypeOperators,
             ScopedTypeVariables,
             DataKinds,
             GADTs #-}
----------------------------------------------------------------------------------
module Flux.Typed.FluxGraph where

import Flux.Typed.Boxed
import Flux.Typed.Cat
import Flux.Typed.Graph
import Flux.Typed.Type
import Flux.Typed.Pattern
import Flux.Typed.Var
import Flux.Typed.Unification
import Flux.Typed.ReduceGraph

import Control.Monad
import Data.Proxy
import Data.Type.Equality
----------------------------------------------------------------------------------
-- | Shorthand for the graph-construction monad 'GraphCon' instantiated with
-- 'GraphState'; all graph builders in this module run in it.
type FluxGraphCon = GraphCon GraphState

-- | Wrap a graph fragment with an explicit boundary node.
--
-- * 'withStart' creates a 'NodeStart' node, runs the fragment, connects the
--   start node's port to the fragment's input port and returns the start
--   node's port as the new input.
-- * 'withEnd' creates a 'NodeEnd' node, runs the fragment, connects the
--   fragment's output port to the end node's port and returns the end node's
--   port as the new output.
-- * 'withStartEnd' applies both, with 'withEnd' inside 'withStart'.
--
-- In each case the boundary node is created before the fragment is built.
withStart, withEnd, withStartEnd ::
  (GenType a, GenType b) => FluxGraphCon (Port a, Port b)
                         -> FluxGraphCon (Port a, Port b)
withStart graphCon = do
  startNode         <- mkNode NodeStart
  startPat          <- newPatVar
  (inPort, outPort) <- graphCon
  let startPort = port startNode startPat
  mkEdge_ startPort inPort
  return (startPort, outPort)

withEnd graphCon = do
  endNode           <- mkNode NodeEnd
  endPat            <- newPatVar
  (inPort, outPort) <- graphCon
  let endPort = port endNode endPat
  mkEdge_ outPort endPort
  return (inPort, endPort)

withStartEnd graphCon = withStart (withEnd graphCon)

----------------------------------------------------------------------------------
-- | Run a graph fragment, then re-target its two boundary ports to freshly
-- created invisible nodes:
--
-- * the input port is moved to a 'NodeStartInv' node;
-- * the output port is moved to a 'NodeEndInv' node.
--
-- Only 'portId' is changed; every other port field is kept. No edges are
-- created here.
packSubGraph :: FluxGraphCon (Port a, Port b) -> FluxGraphCon (Port a, Port b)
packSubGraph graphCon = do
  (innerIn, innerOut) <- graphCon
  startInv <- mkNode NodeStartInv
  endInv   <- mkNode NodeEndInv
  return ( innerIn  { portId = startInv }
         , innerOut { portId = endInv } )
----------------------------------------------------------------------------------
-- | Compile a categorical term into a finished 'Graph'. The pipeline is:
--
-- 1. translate the term with 'mkGraph';
-- 2. wrap it with start/end nodes ('withStartEnd');
-- 3. build it with 'constructGraph';
-- 4. unify all edge patterns ('unifyPatterns');
-- 5. simplify the result with 'reduceGraph'.
catToGraph :: forall a b. (GenType a, GenType b) => (a ~> b) -> Graph
catToGraph cat =
  reduceGraph $ unifyPatterns $ constructGraph $ withStartEnd $ mkGraph cat

-- | A fresh pattern variable: a new variable from 'newVar' wrapped in 'PatVar'.
newPatVar :: (GenType a) => FluxGraphCon (Pat a)
newPatVar = fmap PatVar newVar

-- | Allocate a fresh variable name and turn it into a 'Global' variable.
-- The variable's type representation comes from the 'GenType' instance.
newVar :: (GenType a) => FluxGraphCon (Var a)
newVar = do
  fresh <- newVarName
  return (Var genType fresh Global)
----------------------------------------------------------------------------------
-- | Enable the @unCurry@ shortcut in 'mkGraph'. When it applies, a chain of
-- 'Uncurry' steps ending in an arrow becomes a single node with several input
-- fields. When it does not apply, the generic per-constructor translation is
-- used.
optUncurryArrows :: Bool
optUncurryArrows = True

-- | Enable the @apply@ shortcut in 'mkGraph'. When it applies, a cascade of
-- 'Apply' steps becomes one apply node with several input fields. When it
-- does not apply, the generic per-constructor translation is used.
optApply :: Bool
optApply = True
----------------------------------------------------------------------------------
-- | Translate a categorical term into graph-construction actions.
--
-- Each equation creates the nodes and edges for its kind of term. It returns
-- the pair (input port, output port) of the fragment it built, so the caller
-- can wire fragments together (see the ':>>>' equation).
mkGraph :: (GenType a, GenType b) => (a ~> b) -> FluxGraphCon (Port a, Port b)
--------------------------------
-- Two shape-specific translations are tried first: unCurry, then apply. Each
-- is gated by its option flag. If neither yields 'Just', both guards fail and
-- matching falls through to the structural equations further below.
mkGraph f | optUncurryArrows, Just graphCon <- unCurry f = graphCon
          | optApply,         Just graphCon <- apply f   = graphCon
  where

    -- Typing of a single apply step:
    --((f :&&& g) :>>> Apply) :: a ~> b
    --(f :&&& g) :: a ~> (c :-> b, c)
    --f :: a ~> (c :-> b)
    --g :: a ~> c

    {- Try to resolve cascades of apply nodes. They have the form:
       ((((... ((f :&&& gn) :>>> Apply) ...) :&&& g2) :>>> Apply) :&&& g1) :>>> Apply
       Create one apply node as the reification of the output of 'f'. Then connect
       the outputs from g1, ..., gn to the n input fields of the apply node.
       The innermost argument gn is wired to field "in1" and the outermost g1 to
       field "in<n>". The function 'f' itself is wired to field "w" (see the base
       case of apply'). At every level, a fresh pattern node fans the shared
       input out to both the inner term and that level's argument. -}

    apply :: forall a b. (a ~> b) -> Maybe (FluxGraphCon (Port a, Port b))
    apply ((f :&&& g) :>>> Apply) = Just $ do
      a :: Pat a <- newPatVar
      c <- newPatVar
      appNode <- mkNode NodeApp
      patNode <- mkNode (NodePattern a)
      (f_in, f_out, i) <- apply' appNode f
      (g_in, g_out) <- mkGraph g
      mkEdge_ (port patNode a) f_in
      mkEdge_ (port patNode a) g_in
      mkEdge_ g_out ((port appNode c) {portField = Just ("in" ++ show i)})
      -- the outermost argument uses the highest field number, i.e. the total
      -- number of argument fields, which is recorded on the apply node
      modifyNode (\node -> node {nodeIn = Just i}) appNode
      return (port patNode a, f_out)

    apply _ = Nothing

    -- apply' handles the inner levels of the cascade. It returns:
    --   * the input port of the level;
    --   * an output port of the shared apply node carrying a fresh pattern;
    --   * the next free input-field number.
    -- Typing of one recursive step:
    -- ((f :&&& g) :>>> Apply) :: a ~> (c :-> b)
    -- f :&&& g :: a ~> (d :-> (c :-> b),d)
    -- f :: a ~> (d :-> (c :-> b))
    -- g :: a ~> d
    -- apply' appNode f :: (Port a, Port (c :-> b))

    apply' :: forall a b c. (GenType a, GenType b, GenType c) =>
              UniqueId -> (a ~> (c :-> b)) -> FluxGraphCon (Port a, Port b, Int)
    apply' appNode ((f :&&& g) :>>> Apply) = do
      a :: Pat a <- newPatVar
      b :: Pat b <- newPatVar
      c <- newPatVar
      patNode <- mkNode (NodePattern a)
      (f_in, f_out, i) <- apply' appNode f
      (g_in, g_out) <- mkGraph g
      mkEdge_ (port patNode a) f_in
      mkEdge_ (port patNode a) g_in
      mkEdge_ g_out ((port appNode c) {portField = Just ("in" ++ show i)})
      return (port patNode a, f_out {portData = b}, i+1)

    -- Base case: the applied function itself feeds field "w" of the apply node;
    -- the first argument field is "in1".
    apply' appNode f = do
      b <- newPatVar
      (f_in, f_out) <- mkGraph f
      mkEdge_ f_out ((port appNode (portData f_out)) {portField = Just "w"})
      return (f_in, port appNode b, 1)

    {- Try to resolve uncurrying, i.e. categories of the form:
       (f1 :>>> (Uncurry (f2 :>>> (Uncurry (...  (fn :>>> (Uncurry (Arr v))) ... )))))
       Create one node, labeled with 'v'. Translate f1, ..., fn and connect their
       outputs to the n input fields of the node 'v'.
       The innermost step is matched either as 'Uncurry (Arr v)' by unCurry itself
       or as 'fn :>>> Arr v' by unCurry'.
       The result is wrapped in Maybe. 'Nothing' means the term does not have this
       shape, and mkGraph falls through to its generic equations. -}

    unCurry :: forall a b. (a ~> b) -> Maybe (FluxGraphCon (Port a, Port b))
    -- A bare uncurried arrow: one pattern node splits the input pair and feeds
    -- its components to fields "in1" and "in2" of the arrow node.
    unCurry (Uncurry (Arr v)) = return $ do
      a <- newPatVar
      b <- newPatVar
      c <- newPatVar
      n <- mkNode (NodePattern (a :* b))
      nodeArr <- mkNode (NodeArr v)
      modifyNode (\node -> node {nodeIn = Just 2}) nodeArr
      mkEdge_ (port n a) ((port nodeArr a) {portField = Just "in1"})
      mkEdge_ (port n b) ((port nodeArr b) {portField = Just "in2"})
      return (port n (a :* b), port nodeArr c)


    -- special case; can't it be merged with the others?
    -- NOTE(review): 'c :* d <- return (portData f_out)' is a failable pattern.
    -- 'Just' has already been returned at this point, so a non-product output
    -- pattern from 'mkGraph f' cannot fall back to the generic translation.
    -- Examples are the fresh PatVar produced for 'Arr' or 'Id'. The match fails
    -- inside FluxGraphCon instead (presumably through its MonadFail instance;
    -- confirm).
    unCurry (f :>>> (Uncurry (Arr v))) = return $ do
        nodeArr <- mkNode (NodeArr v)
        modifyNode (\node -> node {nodeIn = Just 2}) nodeArr
        (f_in, f_out) <- mkGraph f
        nodePat <- mkNode (NodePattern (portData f_out))
        mkEdge_ f_out (port nodePat (portData f_out))
        c :* d <- return (portData f_out)
        b :: Pat b <- newPatVar
        mkEdge_ (port nodePat c) ((port nodeArr c) {portField = Just "in1"})
        mkEdge_ (port nodePat d) ((port nodeArr d) {portField = Just "in2"})
        return (f_in, port nodeArr b)

    -- General case: unCurry' translates the nested chain. Then f's output pair
    -- is split:
    --   * the first component feeds the chain's pending input port;
    --   * the second goes to the next free input field "in<i>".
    -- That i is the total number of input fields and is recorded as nodeIn.
    -- NOTE(review): 'a :* b <- ...' is failable in the same way as in the
    -- previous equation.
    unCurry (f :>>> (Uncurry g)) = do
      graphCon <- unCurry' g
      return $ do
        (f_in, f_out)  <- mkGraph f
        (inPort, node, i, outPat) <- graphCon
        a :* b <- return $ portData f_out
        mkEdge_ (f_out {portData = a}) inPort
        mkEdge_ (f_out {portData = b}) ((port node b) {portField = Just ("in" ++ show i)})
        modifyNode (\node -> node {nodeIn = Just i}) node
        return (f_in, port node outPat)

    unCurry _ = Nothing

    {- If successful, unCurry' returns a tuple (inPort, node, i, outPat):
       - "inPort" is the input port of a graph that still needs its input, i.e. the
         context/environment; otherwise the graph stays isolated.
       - "node" is the unique id of the node that receives multiple inputs through
         multiple input fields.
       - "i" is the current input field number that must be connected to the node.
       - "outPat" is the form of the pattern for the output port.
       As recursion ascends, "i" is incremented and the appropriate input is linked
       to its respective port of "node". "outPat" is reduced from a complex function
       pattern to one single output. -}

    unCurry' :: forall a b c. (a ~> (b :-> c)) -> Maybe (FluxGraphCon (Port a, UniqueId, Int, Pat c))
    -- We reached an arrow. That means we can create a node and start to ascend from recursion.
    -- The guard requires v to take at least two arguments. Otherwise this equation
    -- is skipped and the catch-all below returns Nothing. nodeIn of the new node is
    -- set by the top-level unCurry once the total field count is known.
    unCurry' (f :>>> Arr v)
      -- f has type a ~> d, whereas v has type d :-> (b :-> c)
      | TypeFun (_ :: Type d) (TypeFun _ _) <- typeOf v = Just $ do
          c :: Pat c <- newPatVar
          d :: Pat d <- newPatVar
          node <- mkNode (NodeArr v)
          (f_in, f_out) <- mkGraph f
          mkEdge_ f_out ((port node d) {portField = Just ("in1")})
          return (f_in, node, 2, c)

    -- Intermediate step: the first component of f's output feeds the inner
    -- chain's pending input, and the second goes to field "in<i>".
    -- NOTE(review): 'a :* b <- ...' is a failable pattern (see unCurry). The
    -- single-alternative 'case' presumably cannot fail, because outPat has a
    -- function type here by the signature; confirm.
    unCurry' (f :>>> (Uncurry g)) = do
      graphCon <- unCurry' g
      return $ do
        (f_in, f_out)  <- mkGraph f
        (inPort, node, i, outPat) <- graphCon
        a :* b <- return $ portData f_out
        mkEdge_ (f_out {portData = a}) inPort
        mkEdge_ (f_out {portData = b}) ((port node b) {portField = Just ("in" ++ show i)})
        case typeOf outPat of
            TypeFun _ (_ :: Type c) -> do
              c  :: Pat c <- newPatVar
              return (f_in, node, i+1, c)

    unCurry' _ = Nothing
--------------------------------
-- Identity: a single NodeId node whose input and output share one pattern.
mkGraph Id = do
  a <- newPatVar
  n <- mkNode NodeId
  return (port n a, port n a)
--------------------------------
-- Quantification: translate the body and wrap its output pattern in PatForAll.
mkGraph (ForAll v f) = do
  (f_in, f_out) <- mkGraph f
  return (f_in, f_out {portData = PatForAll v (portData f_out)})
--------------------------------
-- Sequential composition: f's output port is connected to g's input port.
mkGraph (f :>>> g) = do
  (f_in, f_out) <- mkGraph f
  (g_in, g_out) <- mkGraph g
  mkEdge_ f_out g_in
  return (f_in, g_out)
--------------------------------
-- Projections: a pattern node matching (a :* b) that outputs one component.
mkGraph Fst = do
  a <- newPatVar
  b <- newPatVar
  n <- mkNode (NodePattern (a :* b))
  return (port n (a :* b), port n a)
--------------------------------
mkGraph Snd = do
  a <- newPatVar
  b <- newPatVar
  n <- mkNode (NodePattern (a :* b))
  return (port n (a :* b), port n b)
--------------------------------
-- An arrow becomes a NodeArr node with fresh input and output patterns.
mkGraph (Arr f) = do
  a <- newPatVar
  b <- newPatVar
  n <- mkNode (NodeArr f)
  return (port n a, port n b)
--------------------------------
-- Constants ignore their input; the input pattern 'a' is a fresh variable.
-- Literals and variables become single nodes. An arrow constant (OArr) is
-- translated into its own subgraph, wrapped with start/end nodes and packed.
-- The returned ports are tagged with the subgraph's cluster and carry fresh
-- patterns.
mkGraph (Const c) = do
  a <- newPatVar
  case c of
    OLit l -> do
      n <- mkNode (NodeLiteral l)
      return (port n a, port n (PatLit l))

    OVar v -> do
      b <- newPatVar
      n <- mkNode (NodeVar v)
      return (port n a, port n b)

    OArr f -> do
      subgraphId    <- mkSubgraph
      g             <- newPatVar
      (f_in, f_out) <- withinSubGraph subgraphId (packSubGraph $ withStartEnd (mkGraph f))
      let inport  = f_in  {portCluster = Just subgraphId, portData = a}
          outport = f_out {portCluster = Just subgraphId, portData = g}
      return (inport, outport)
--------------------------------
-- Fan-out: one pattern node copies the input to both f and g. A second
-- pattern node combines their outputs into the pair (b :* c).
mkGraph (f :&&& g) = do
  (f_in, f_out) <- mkGraph f
  (g_in, g_out) <- mkGraph g
  a <- newPatVar
  let b  = portData f_out
      c  = portData g_out
  fanNode     <- mkNode (NodePattern a)
  combineNode <- mkNode (NodePattern (b :* c))
  let fanPort = port fanNode a
  mkEdge_ fanPort f_in
  mkEdge_ fanPort g_in
  mkEdge_ f_out (port combineNode b)
  mkEdge_ g_out (port combineNode c)
  return (fanPort, port combineNode (b :* c))
--------------------------------
-- Parallel composition: split the input pair (a :* b) between f and g and
-- combine their outputs into the pair (c :* d).
mkGraph (f :*** g) = do
  (f_in, f_out) <- mkGraph f
  (g_in, g_out) <- mkGraph g
  a     <- newPatVar
  b     <- newPatVar
  let c  = portData f_out
      d  = portData g_out
  splitNode   <- mkNode (NodePattern (a :* b))
  combineNode <- mkNode (NodePattern (c :* d))
  mkEdge_ (port splitNode a) f_in
  mkEdge_ (port splitNode b) g_in
  mkEdge_ f_out (port combineNode c)
  mkEdge_ g_out (port combineNode d)
  return (port splitNode (a :* b), port combineNode (c :* d))
--------------------------------
-- Application: the function component 'a' of the input pair goes to field "w"
-- of the apply node; the argument 'b' goes to it without a field.
mkGraph Apply = do
  a  <- newPatVar
  b  <- newPatVar
  c  <- newPatVar
  patNode <- mkNode (NodePattern (a :* b))
  appNode <- mkNode NodeApp
  mkEdge_ (port patNode a) ((port appNode a) {portField = Just "w"})
  mkEdge_ (port patNode b) (port appNode b)
  return (port patNode (a :* b), port appNode c)
--------------------------------
-- Currying: f is translated inside a packed subgraph.
--   * Inside it, a pattern node rebuilds the pair (a :* b) for f. The 'b' part
--     arrives through the subgraph's start; the context 'a' comes from a
--     NodeId node outside the subgraph.
--   * The result's output is the subgraph's output port, tagged with its
--     cluster and carrying a fresh pattern.
mkGraph (Curry f) = do
  a <- newPatVar
  b <- newPatVar
  conNode      <- mkNode NodeId
  subgraphId   <- mkSubgraph
  (_, sub_out) <- withinSubGraph subgraphId $ packSubGraph $ withStartEnd $ do
      patNode <- mkNode (NodePattern (a :* b))
      (f_in, f_out) <- mkGraph f
      mkEdge_ (port patNode (a :* b)) f_in
      mkEdge_ (port conNode a) (port patNode a)
      return (port patNode b, f_out)
  g <- newPatVar
  return (port conNode a, sub_out {portCluster = Just subgraphId, portData = g})
--------------------------------
-- Uncurrying: the first input component feeds f. f's output goes to field "w"
-- of an apply node, and the second input component is the apply node's
-- argument.
mkGraph (Uncurry f) = do
  a <- newPatVar
  b <- newPatVar
  c <- newPatVar
  (f_in, f_out) <- mkGraph f
  patNode       <- mkNode (NodePattern (a :* b))
  appNode       <- mkNode NodeApp
  mkEdge_ (port patNode a) f_in
  mkEdge_ f_out ((port appNode (portData f_out)) {portField = Just "w"})
  mkEdge_ (port patNode b) (port appNode b)
  return (port patNode (a :* b), port appNode c)
--------------------------------
-- proc (c,a) -> do
--   let [(p1,f1),...,(pn,fn)] = cs
--   case a of
--      p1 -> f1 -< (c,a)
--      ...
--      pn -> fn -< (c,a)
--------------------------------
mkGraph (Case cs) = do
  context   <- newPatVar
  scrutinee <- newPatVar
  let inPat = context :* scrutinee

  splitNode <- mkNode (NodePattern inPat)
  -- The case-patterns are originally taken from the lambda-expression /
  -- Haskell-code itself, so their variables are not guaranteed to be unique.
  -- The renaming below is currently disabled (commented out) and the patterns
  -- are used unchanged. NOTE(review): confirm this is intended.
  --newPats <- forM cs $ \(cPat, _) -> traverseVars (\v -> if isVar v then newVar else return v) cPat
  let newPats = map fst cs

  caseNode <- mkNode (NodeCase newPats)
  ports    <- mapM (mkGraph . snd) cs

  -- if there is only one possible pattern-match, we can directly unify our input with it
  case newPats of
    [pat] -> mkEdge_ (port splitNode scrutinee) ((port caseNode pat)       {portField = Just "w"})
    _     -> mkEdge_ (port splitNode scrutinee) ((port caseNode scrutinee) {portField = Just "w"})

  -- Branch k (1-based) is wired as follows:
  --   * the case node's field "<casek>" carries that branch's pattern;
  --   * a help node recombines it with the context;
  --   * the help node feeds the branch's translated term.
  outPorts <- sequence [ do
      let field = Just ("<case" ++ show k ++ ">")
          casePort = (port caseNode cPat) {portField = field}
      helpNode <- mkNode (NodePattern (context :* cPat))
      mkEdge_ (port splitNode context) (port helpNode context)
      mkEdge_ casePort (port helpNode cPat)
      mkEdge_ (port helpNode (context :* cPat)) f_in
      return (casePort, f_out)
    | (cPat, ((f_in, f_out), k)) <- zip newPats (zip ports [1 .. length cs]) ]

  -- if there is only one case, we don't need to join possible results
  -- Otherwise a join node receives, on field "<joink>", both branch k's output
  -- and a control edge (TyCtrl patterns) from branch k's case port.
  case outPorts of
    [(_,outPort)] -> return (port splitNode inPat, outPort)
    _             -> do
        outPortPattern <- newPatVar
        joinNode <- mkNode (NodeJoin (length outPorts))
        forM_ (zip outPorts ([1..] :: [Int])) $ \((casePort, outPort), k) -> do
            let field = Just ("<join" ++ show k ++ ">")
                a = portData outPort
            mkEdge_ outPort  ((port joinNode a) {portField = field})
            ctrl1 :: Pat TyCtrl <- newPatVar
            ctrl2 :: Pat TyCtrl <- newPatVar
            mkEdge_ (casePort {portData = ctrl1}) ((port joinNode ctrl2) {portField = field})
        return (port splitNode inPat, port joinNode outPortPattern)
--------------------------------
-- Feedback loop: f maps (a :* c) to (b :* c). The 'c' component of f's output
-- is routed back into the 'c' component of its input. Only 'a' and 'b' are
-- exposed.
mkGraph (Loop f) = do
  a <- newPatVar
  b <- newPatVar
  c <- newPatVar
  let ac = a :* c
      bc = b :* c
  patIn         <- mkNode (NodePattern ac)
  patOut        <- mkNode (NodePattern bc)
  (f_in, f_out) <- mkGraph f
  mkEdge_ (port patIn ac) f_in
  mkEdge_ f_out (port patOut bc)
  mkEdge_ (port patOut c) (port patIn c)
  return (port patIn a, port patOut b)
--------------------------------
-- Explicit pattern rewrite: a single pattern node with input pattern p and
-- output pattern p'.
mkGraph (ExPat p p') = do
  n <- mkNode (NodePattern p)
  return (port n p, port n p')
--------------------------------
----------------------------------------------------------------------------------
-- | Unify the source and destination patterns of all edges in the graph.
-- The resulting substitution is applied to the whole graph with
-- 'applySubstToGraph', which also rewrites the nodes that carry patterns.
--
-- Calls 'error' unless 'unifylist' yields exactly one substitution.
unifyPatterns :: Graph -> Graph
unifyPatterns graph =
  case unifylist srcPats dstPats of
    [subst] -> applySubstToGraph subst graph
    _       -> error "could not unify graph"
  where
    graphEdges :: [Edge]
    graphEdges = allEdges graph

    -- patterns at the source / destination end of each edge, in edge order
    srcPats, dstPats :: [Boxed Pat]
    srcPats = [ patternOf (edgeSrc edge) | edge <- graphEdges ]
    dstPats = [ patternOf (edgeDst edge) | edge <- graphEdges ]

    -- keep only the pattern of a boxed edge endpoint, re-boxed
    -- (presumably the endpoint is a boxed Port; confirm)
    patternOf = boxed (Boxed . portData)

-- | Apply a substitution to the patterns of a graph.
--
-- 'mapGraph' receives a node rewriter, plus a function that applies the
-- substitution to the remaining boxed patterns (presumably the edge/port
-- patterns; confirm against 'mapGraph'). The node rewriter works as follows:
--
-- * 'NodePattern': the pattern is substituted unconditionally.
-- * 'NodeCase': each pattern in the (homogeneously typed) list is substituted.
--   The result is kept only when 'testEquality' shows it has the same type as
--   the original; otherwise the original pattern is kept.
-- * Any other node is left unchanged.
applySubstToGraph :: Subst VarName (Boxed Pat) -> Graph -> Graph
applySubstToGraph subst = mapGraph rewriteNode (subst @@)
  where
    rewriteNode :: NodeType -> NodeType
    rewriteNode node = case node of
      NodePattern pat
        | Boxed pat' <- subst @@ Boxed pat -> NodePattern pat'
      NodeCase (pats :: [Pat t]) ->
        let retype :: Pat t -> Pat t
            retype pat = case subst @@ Boxed pat of
              Boxed pat'
                | Just Refl <- testEquality pat' pat -> pat'
                | otherwise                          -> pat
        in  NodeCase (map retype pats)
      other -> other
--------------------------------------------------------------------------------
