{-# LANGUAGE GADTs #-}
----------------------------------------------------------------------------------
module Flux.Typed.ReduceGraph(reduceGraph) where

import Flux.Typed.Boxed
import Flux.Typed.Graph
import Flux.Typed.Pattern
import Flux.Typed.Type
import Flux.Util

import Data.Type.Equality
import Control.Monad(forM_, mplus)
import Control.Monad.State
----------------------------------------------------------------------------------
-- | compute all sub-patterns, direct and indirect (recursive) ones
allSubPatterns :: GenType a => Pat a -> [Boxed Pat]
allSubPatterns pat = Boxed pat : descend pat
  where
    -- Sub-patterns of the immediate children, left child before right child.
    -- Only 'PatApp', ':*' and 'PatForAll' are descended into; any other
    -- pattern has no children and contributes nothing beyond itself.
    descend :: GenType b => Pat b -> [Boxed Pat]
    descend (PatApp l r)       = allSubPatterns l ++ allSubPatterns r
    descend (l :* r)           = allSubPatterns l ++ allSubPatterns r
    descend (PatForAll _ body) = allSubPatterns body
    descend _                  = []
----------------------------------------------------------------------------------
-- | Simplify a graph by removing its non-essential nodes ('NodeId' and
-- 'NodePattern') and wiring the remaining nodes to each other directly.
--
-- The result is built in two passes:
--
--   1. 'constructSubgraphs' keeps only the essential nodes of the graph and
--      of every subgraph, clears the edges, and for each input port of an
--      essential node searches backwards through the original graph (see
--      'getInputFor') for the port that feeds it.  An edge is created only
--      when the types of both ports are provably equal.
--
--   2. 'finalSweep' removes every node that ends up with neither incoming
--      nor outgoing edges.
reduceGraph :: Graph -> Graph
reduceGraph graph = finalSweep $ constructGraphFrom graph (constructSubgraphs graph)
  where
    -- Remove every node of @g@ that has no incoming and no outgoing edges.
    finalSweep :: Graph -> Graph
    finalSweep g = constructGraphFrom g sweep
      where
        sweep :: GraphCon GraphState ()
        sweep = do
          forM_ (allNodes g) $ \(Node {nodeId = n}) ->
            case (getEdgesTo g n, getEdgesFrom g n) of
              ([],[]) -> removeNode n
              _       -> return ()

    -- Keep only the essential nodes of @g@, clear its edges, and re-create
    -- one edge for every input port of an essential node whose source can be
    -- traced back through the original graph.  Then repeat for each subgraph
    -- of @g@ inside that subgraph's own context.
    constructSubgraphs :: Graph -> GraphCon GraphState ()
    constructSubgraphs g = do
      let essentialNodes = filter (not . isNonEssential . nodeType) (nodes g)
      setNodes essentialNodes
      setEdges []
      forM_ (map nodeId essentialNodes) $ \n ->
        forM_ (getInputPortsFor n) $ \p@(Boxed dst) ->
          case getInputFor p of
            -- only connect the two ports when their types provably coincide
            Just (Boxed src) | Just Refl <- testEquality src dst -> mkEdge_ src dst
            _                                                    -> return ()
      forM_ (subgraphs g) $ \subgraph -> withinSubGraph (graphId subgraph) $
          constructSubgraphs subgraph

    -- Identity and pattern nodes are dropped by 'reduceGraph'; the backward
    -- search in 'getInputFor' looks through them instead of stopping there.
    isNonEssential :: NodeType -> Bool
    isNonEssential NodeId          = True
    isNonEssential (NodePattern _) = True
    isNonEssential _               = False

    -- Destination ports of all edges into node @n@ in the original graph.
    -- Start nodes ('NodeStart', 'NodeStartInv', 'NodeEndInv') are treated as
    -- having no input ports.
    getInputPortsFor :: UniqueId -> [Boxed Port]
    getInputPortsFor n
      | Just n' <- getNode n, nodeType n' `elem` [NodeStart, NodeStartInv, NodeEndInv] = []
      | otherwise = concatMap (\Edge {edgePorts = Ports (_,o)} -> [Boxed o]) (edgesTo n)

    -- Search backwards from the input port @start@ for the port that supplies
    -- its value, following edges of the original graph.
    --
    --  * Ports on non-essential nodes are never returned themselves; the
    --    search continues through them.
    --  * A source port matches when its pattern equals the pattern of
    --    @start@, or when the pattern of @start@ occurs among the
    --    sub-patterns of the source's pattern.  In the latter case the
    --    returned port carries the pattern of @start@.
    --  * Ports of literal and variable nodes are treated as having no input.
    --
    -- The list of visited nodes is threaded through the whole search and a
    -- node is never entered twice, so cycles cannot make the search diverge.
    getInputFor :: Boxed Port -> Maybe (Boxed Port)
    getInputFor (Boxed start@(Port {portData = pat})) = evalState (getInputFor' start) []
      where
        getInputFor' :: GenType a => Port a -> State [UniqueId] (Maybe (Boxed Port))
        getInputFor' p@(Port {portId = n}) = do
          visited <- get
          if n `elem` visited
            then return Nothing
            else do
              modify (n:)
              if isDeadEnd
                then return Nothing
                else foldM f Nothing (predPorts p)
          where
            -- The search never proceeds past a literal or variable node.
            isDeadEnd :: Bool
            isDeadEnd = case nodeType <$> getNode n of
              Just (NodeLiteral _) -> True
              Just (NodeVar _)     -> True
              _                    -> False

            -- Fold step over the predecessor ports.  A later matching port
            -- replaces the accumulated result @x@; a result found through a
            -- non-essential node is preferred over @x@.
            f x (Boxed src@(Port {portData = pat'}))
              | Just n' <- getNode (portId src),
                isNonEssential (nodeType n')         = do y <- getInputFor' src; return (y `mplus` x)

              | Just Refl <- testEquality pat pat',
                pat == pat'                          = return (Just (Boxed src))

              | Boxed pat `elem` allSubPatterns pat' = return (Just (Boxed (src {portData = pat})))

--              | Just n' <- getNode (portId src),
--                NodeCase _ <- nodeType n',
--                Boxed pat `elem` allSubPatterns pat' = return (Just (Boxed (src {portData = pat})))

              | otherwise                            = return x

    getNode   = getNodeFor graph
    edgesTo   = getEdgesTo graph

    -- | Return the predecessor ports of a given port, in the order in which
    -- 'getInputFor' examines them:
    --
    --   1. the source ports of the edges that end at this very port;
    --   2. the source ports of /all/ edges that end at the same node.  This
    --      includes the ports from group 1 a second time.
    --
    -- NOTE(review): the intent was apparently to exclude the directly
    -- connected ports from group 2.  The fold in 'getInputFor' lets later
    -- matches win, so filtering them out would change which source is
    -- chosen.  The current behaviour is therefore kept as-is.
    predPorts :: Port a -> [Boxed Port]
    predPorts p@(Port {portId = n}) =
      concat [ case testEquality p dst of Just Refl | p == dst -> [Boxed src]; _ -> []
             | Edge {edgePorts = Ports (src,dst)} <- edgesTo n ]
          ++ [ Boxed src | Edge {edgePorts = Ports (src,_)} <- edgesTo n ]

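-- Open issues:
-- - Check for loops when searching backwards.  (NOTE(review): getInputFor'
--   already keeps a list of visited nodes; confirm whether this item is
--   still open.)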
----------------------------------------------------------------------------------
