{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
--------------------------------------------------------------------------------
module Flux.Typed.Graph where
--------------------------------------------------------------------------------
import Flux.Typed.Boxed
import Flux.Typed.Var
import Flux.Typed.Type
import Flux.Util

import Data.Map(Map)
import qualified Data.Map as Map
import Data.Type.Equality
import Control.Monad.State
--------------------------------------------------------------------------------
-- | Identifier used for graphs, nodes and edges alike; all of them are handed
-- out by the same counter ('idCounter' in 'GraphState', advanced by 'newId').
type UniqueId = Int
--------------------------------------------------------------------------------
-- | Whether a 'Graph' is the root of a construction or nested inside another.
data GraphType
  = DGraph    -- ^ graph is a top-level directed graph
  | SubGraph  -- ^ graph is a sub-graph
  deriving (Eq, Show)
--------------------------------------------------------------------------------
-- | Representation of a graph: its own id plus its nodes, edges and nested
-- sub-graphs, each stored in a map keyed by the element's id.
data Graph = Graph
  { graphType      :: GraphType
  , graphId        :: UniqueId           -- ^ the graph's unique identifier
  , graphEdges     :: Map UniqueId Edge  -- ^ edges of this graph, keyed by 'edgeId'
  , graphNodes     :: Map UniqueId Node  -- ^ nodes of this graph, keyed by 'nodeId'
  , graphSubGraphs :: Map UniqueId Graph -- ^ a graph can have many subgraphs
  } deriving (Eq, Show)
--------------------------------------------------------------------------------
-- | Payload stored in a 'Node'. Constructors carrying a 'Var', 'Pat' or 'Lit'
-- hide that value's type index, so comparing two payloads needs a runtime
-- type check (done with 'testEquality' in the 'Eq' instance).
data NodeType where
  NodeId       :: NodeType
  NodeArr      :: Var (a :-> b) -> NodeType         -- variable of a ':->' type
  NodeStart    :: NodeType
  NodeEnd      :: NodeType
  NodeApp      :: NodeType
  NodePattern  :: GenType t => Pat t -> NodeType
  NodeLiteral  :: Lit t -> NodeType
  NodeVar      :: Var t -> NodeType
  NodeCase     :: GenType t => [Pat t] -> NodeType  -- all patterns share one type
  NodeJoin     :: Int -> NodeType
  NodeStartInv :: NodeType
  NodeEndInv   :: NodeType
--------------------------------------------------------------------------------
-- 'NodeType' is a GADT with existentially typed fields, so 'Show' has to be
-- derived standalone; this relies on 'Show' instances for 'Var', 'Pat' and 'Lit'.
deriving instance Show NodeType
--------------------------------------------------------------------------------
-- | Structural equality on node payloads.
--
-- Constructors that hide a type index only compare equal when 'testEquality'
-- proves both sides carry the same type. When a guard below fails, matching
-- falls through to the final catch-all equation, which yields 'False'.
instance Eq NodeType where
  NodeId              == NodeId              = True
  NodeStartInv        == NodeStartInv        = True
  NodeEndInv          == NodeEndInv          = True
  NodeArr v1          == NodeArr v2
    | Just Refl <- testEquality v1 v2        = v1 == v2
  NodeStart           == NodeStart           = True
  NodeEnd             == NodeEnd             = True
  NodeApp             == NodeApp             = True
  NodePattern p1      == NodePattern p2
    | Just Refl <- testEquality p1 p2        = p1 == p2
  NodeLiteral l1      == NodeLiteral l2
    | Just Refl <- testEquality l1 l2        = l1 == l2
  NodeVar v1          == NodeVar v2
    | Just Refl <- testEquality v1 v2        = v1 == v2
  -- Two empty case lists have no pattern to compare types with. Without this
  -- equation they would hit the catch-all and compare unequal, making '=='
  -- non-reflexive (and with it the derived Eq of Node and Graph).
  NodeCase []         == NodeCase []         = True
  NodeCase ps1@(p1:_) == NodeCase ps2@(p2:_)
    | Just Refl <- testEquality p1 p2        = ps1 == ps2
  NodeJoin i          == NodeJoin j          = i == j
  _                   == _                   = False
--------------------------------------------------------------------------------
-- | A node of the graph together with its payload.
data Node = Node
  { nodeId   :: UniqueId   -- ^ the node's unique identifier
  , nodeType :: NodeType   -- ^ data associated with the node
  , nodeIn   :: Maybe Int  -- ^ number of input fields
  , nodeOut  :: Maybe Int  -- ^ number of output fields
  } deriving (Eq, Show)
--------------------------------------------------------------------------------
-- | An edge connects a source port to a destination port of the same type.
data Edge = Edge
  { edgeId    :: UniqueId  -- ^ the edge's unique identifier
  , edgePorts :: Ports     -- ^ (source, destination) ports
  } deriving (Eq, Show)

-- | The source port of an edge, with its type index hidden in a 'Boxed'.
edgeSrc :: Edge -> Boxed Port
edgeSrc e = case edgePorts e of
  Ports (from, _) -> Boxed from

-- | The destination port of an edge, with its type index hidden in a 'Boxed'.
edgeDst :: Edge -> Boxed Port
edgeDst e = case edgePorts e of
  Ports (_, to) -> Boxed to
--------------------------------------------------------------------------------
-- | A (source, destination) pair of ports that share a hidden type @a@; the
-- 'GenType' dictionary for @a@ is packed together with the pair.
data Ports where
  Ports :: GenType a => (Port a, Port a) -> Ports

deriving instance Show Ports

-- | Port pairs are equal when both ends have matching types and compare equal.
instance Eq Ports where
  Ports (s1, d1) == Ports (s2, d2) =
    case (testEquality s1 s2, testEquality d1 d2) of
      (Just Refl, Just Refl) -> s1 == s2 && d1 == d2
      _                      -> False
--------------------------------------------------------------------------------
-- | A port exactly specifies the source or the target of an edge.
data Port a = Port
  { portId      :: UniqueId       -- ^ the id of the associated node
  , portData    :: Pat a          -- ^ additional data stored at this port
  , portField   :: Maybe String   -- ^ name referring to a field of the node
  , portCluster :: Maybe UniqueId -- ^ the port can additionally refer to a
                                  --   cluster where the edge will be clipped
  } deriving (Eq, Show)

-- | Two ports have the same type index exactly when their patterns do.
instance TestEquality Port where
  testEquality l r = testEquality (portData l) (portData r)

-- | The type of a port is the type of its pattern.
instance Typed Port where
  typeOf = typeOf . portData

-- | Show a boxed port exactly as the unwrapped port would be shown.
instance Show (Boxed Port) where
  show bp = case bp of
    Boxed inner -> show inner
--------------------------------------------------------------------------------
-- | State threaded through a graph construction: the graph currently being
-- edited plus supplies of fresh identifiers and variable names.
data GraphState = GraphState
  { getGraphFromState :: Graph      -- ^ current graph on which modifications are applied to
  , idCounter         :: UniqueId   -- ^ next identifier to be handed out by 'newId'
  , varPool           :: [VarName]  -- ^ names not yet handed out by 'newVarName'
  }
--------------------------------------------------------------------------------
-- | Nodes of this graph only; nodes of sub-graphs are not included.
nodes :: Graph -> [Node]
nodes g = Map.elems (graphNodes g)

-- | All nodes, including those of every (transitively nested) sub-graph. A
-- graph's own nodes come before the nodes of its sub-graphs (pre-order).
allNodes :: Graph -> [Node]
allNodes g = nodes g ++ (subgraphs g >>= allNodes)

-- | Look up a node by id in a graph and, transitively, in all of its
-- sub-graphs. If several graphs contain the id, the graph itself takes
-- precedence, followed by its sub-graphs in 'allSubgraphs' order.
getNodeFor :: Graph -> UniqueId -> Maybe Node
getNodeFor g i =
  -- Search graph by graph and stop at the first hit. Merging every node map
  -- with 'Map.unions' first would rebuild a map of all nodes on each lookup.
  -- Taking the first hit of this lazy list gives the same left-biased result.
  case [ n | sg <- g : allSubgraphs g, Just n <- [Map.lookup i (graphNodes sg)] ] of
    n : _ -> Just n
    []    -> Nothing

-- | Edges of this graph only; edges of sub-graphs are not included.
edges :: Graph -> [Edge]
edges g = Map.elems (graphEdges g)

-- | All edges, including those of every (transitively nested) sub-graph. A
-- graph's own edges come before the edges of its sub-graphs (pre-order).
allEdges :: Graph -> [Edge]
allEdges g = edges g ++ (subgraphs g >>= allEdges)

-- | Direct sub-graphs of a graph (not their nested sub-graphs).
subgraphs :: Graph -> [Graph]
subgraphs = Map.elems . graphSubGraphs

-- | All transitively nested sub-graphs: first the direct sub-graphs, then the
-- nested sub-graphs of each of them in turn. The graph itself is excluded.
allSubgraphs :: Graph -> [Graph]
allSubgraphs g =
  let children = subgraphs g
  in  children ++ concatMap allSubgraphs children
--------------------------------------------------------------------------------
-- Convenience Functions
--------------------------------------------------------------------------------
-- | Given a graph, construct a function that returns all outgoing edges from a
-- given node. This includes all edges from all subgraphs.
--
-- The edge index is built once per partial application @getEdgesFrom g@, so
-- each query is a map lookup rather than a walk over every edge.
getEdgesFrom :: Graph -> UniqueId -> [Edge]
getEdgesFrom = edgesGroupedBy edgeSrc

-- | Given a graph, construct a function that returns all incoming edges to a
-- given node. This includes all edges from all subgraphs.
--
-- The edge index is built once per partial application @getEdgesTo g@.
getEdgesTo :: Graph -> UniqueId -> [Edge]
getEdgesTo = edgesGroupedBy edgeDst

-- | Group every edge of a graph (including all sub-graphs) by the node id of
-- the endpoint selected by the first argument. Nodes without such edges map
-- to @[]@. Within a group, an edge appearing later in 'allEdges' comes first
-- ('Map.fromListWith' prepends later values).
edgesGroupedBy :: (Edge -> Boxed Port) -> Graph -> UniqueId -> [Edge]
edgesGroupedBy endpoint g = \nid -> Map.findWithDefault [] nid index
  where
    -- bound outside the lambda so it is shared by all queries on this graph
    index = Map.fromListWith (++)
              [ (boxed portId (endpoint e), [e]) | e <- allEdges g ]

-- | Map two functions over all node and port data, recursing into every
-- sub-graph. The port function must return a pattern of the same type as the
-- one it was given; otherwise the affected port is an 'error' value
-- ("could not match types") that fails when forced.
mapGraph :: (NodeType -> NodeType) -> (Boxed Pat -> Boxed Pat) -> Graph -> Graph
mapGraph nf pf = go
  where
    go :: Graph -> Graph
    go gr = gr
      { graphEdges     = Map.map onEdge (graphEdges gr)
      , graphNodes     = Map.map onNode (graphNodes gr)
      , graphSubGraphs = Map.map go (graphSubGraphs gr)
      }

    -- the boxed result must be converted back to the port's own type index
    onPort :: GenType a => Port a -> Port a
    onPort p =
      case pf (Boxed (portData p)) of
        Boxed pat ->
          case testEquality (portData p) pat of
            Just Refl -> p {portData = pat}
            Nothing   ->
              error $ "could not match types:\n" ++ show (typeOf (portData p)) ++
                      "\n" ++ show (typeOf pat)

    onEdge :: Edge -> Edge
    onEdge e = case edgePorts e of
      Ports (src, dst) -> e {edgePorts = Ports (onPort src, onPort dst)}

    onNode :: Node -> Node
    onNode n = n {nodeType = nf (nodeType n)}

-- | Ids of all nodes, edges and sub-graphs, transitively. The id of the graph
-- passed in is not included.
allIds :: Graph -> [UniqueId]
allIds g = concat
  [ nodeId  <$> allNodes g
  , edgeId  <$> allEdges g
  , graphId <$> allSubgraphs g
  ]
--------------------------------------------------------------------------------
-- Smart Constructors
--------------------------------------------------------------------------------
-- | A graph of the given type and id with no nodes, edges or sub-graphs.
emptyGraph :: GraphType -> UniqueId -> Graph
emptyGraph t i = Graph
  { graphType      = t
  , graphId        = i
  , graphEdges     = Map.empty
  , graphNodes     = Map.empty
  , graphSubGraphs = Map.empty
  }

-- | Initial graph state. The root is an empty 'DGraph' with id 0, so fresh ids
-- start at 1. The variable-name supply is infinite:
-- "a" .. "z", then "a0" .. "z0", "a1" .. "z1", and so on.
initialGraphState :: GraphState
initialGraphState = GraphState
  { getGraphFromState = emptyGraph DGraph 0
  , idCounter         = 1
  , varPool           = singleLetters ++ numbered
  }
  where
    singleLetters = [ [c] | c <- ['a' .. 'z'] ]
    numbered      = [ c : show n | n <- [0 ..] :: [Integer], c <- ['a' .. 'z'] ]

-- | A port on node @i@ carrying pattern @p@, with neither a field name nor a
-- cluster reference.
port :: UniqueId -> Pat t -> Port t
port i p = Port i p Nothing Nothing
--------------------------------------------------------------------------------
-- Procedural Graph Construction
--------------------------------------------------------------------------------
-- | A graph construction: a plain 'State' computation over the state @s@.
-- The construction functions in this module require @s@ to be a
-- 'GraphConState'.
type GraphCon s = State s

-- | State types that embed a 'GraphState' and can therefore drive a graph
-- construction. Users can extend the construction state with their own data
-- by providing accessors for the embedded graph-state.
class GraphConState s where
  initialState     :: s                     -- ^ state a construction starts from
  graphStateGetter :: s -> GraphState       -- ^ read the embedded graph-state
  graphStateSetter :: GraphState -> s -> s  -- ^ replace the embedded graph-state

-- | A bare 'GraphState' is its own embedded graph-state.
instance GraphConState GraphState where
  initialState           = initialGraphState
  graphStateGetter s     = s
  graphStateSetter new _ = new

-- | Run a graph construction from 'initialState' and return the graph that is
-- current when it finishes.
constructGraph :: GraphConState s => GraphCon s a -> Graph
constructGraph c = flip evalState initialState $ c >> getGraph

-- | Evaluate a graph construction, starting from a given graph; hence this is
-- for modifying existing graphs. Fresh ids continue after the largest id
-- already in use, counting the graph's own id.
--
-- /Note: It is not possible to supply unique variable names in this setting./
-- Calling 'newVarName' will therefore result in an error.
constructGraphFrom :: GraphConState s => Graph -> GraphCon s a -> Graph
constructGraphFrom g c = evalState (c >> getGraph) (graphStateSetter ini initialState)
  where
    ini = GraphState {
      getGraphFromState = g,
      -- 'allIds' omits the root id. Including it keeps 'maximum' total on a
      -- graph with no nodes, edges or sub-graphs, and keeps fresh ids from
      -- colliding with the graph's own id.
      idCounter         = maximum (graphId g : allIds g) + 1,
      varPool           = error "'constructGraphFrom' can not guarantee unique variable names"
    }

-- | The graph-state embedded in the current construction state.
getGraphState :: GraphConState s => GraphCon s GraphState
getGraphState = gets graphStateGetter

-- | The graph currently being edited.
getGraph :: GraphConState s => GraphCon s Graph
getGraph = gets (getGraphFromState . graphStateGetter)

-- | Apply a function to the embedded graph-state and store the result.
modifyGraphState :: GraphConState s => (GraphState -> GraphState) -> GraphCon s ()
modifyGraphState f = getGraphState >>= modify . graphStateSetter . f

-- | Apply a function to the graph currently being edited.
modifyGraph :: GraphConState s => (Graph -> Graph) -> GraphCon s ()
modifyGraph f =
  modifyGraphState $ \st -> st {getGraphFromState = f (getGraphFromState st)}
--------------------------------------------------------------------------------
-- IDs and Variables
--------------------------------------------------------------------------------
-- | Id of the graph currently being edited.
getGraphId :: GraphConState s => GraphCon s UniqueId
getGraphId = fmap graphId getGraph

-- | Hand out the current value of 'idCounter' and advance the counter by one.
newId :: GraphConState s => GraphCon s UniqueId
newId = do
  counter <- idCounter <$> getGraphState
  modifyGraphState (\st -> st {idCounter = counter + 1})
  pure counter

-- | Take the next name from 'varPool' and remove it from the pool.
-- The pool is split lazily, so an empty or 'error' pool (as installed by
-- 'constructGraphFrom') only fails once the name or the rest is forced.
newVarName :: GraphConState s => GraphCon s VarName
newVarName = do
  st <- getGraphState
  let fresh : remaining = varPool st
  modifyGraphState (\gs -> gs {varPool = remaining})
  return fresh
--------------------------------------------------------------------------------
-- Nodes
--------------------------------------------------------------------------------
-- | Create a node with the given payload in the graph currently being edited
-- and return its fresh id. Input and output counts start out as 'Nothing'.
mkNode :: GraphConState s => NodeType -> GraphCon s UniqueId
mkNode n = do
  i <- newId
  let node = Node i n Nothing Nothing
  modifyGraph (\g -> g {graphNodes = Map.insert i node (graphNodes g)})
  return i

-- | Look up a node by id in the current graph and all of its sub-graphs.
getNode :: GraphConState s => UniqueId -> GraphCon s (Maybe Node)
getNode i = (`getNodeFor` i) <$> getGraph

-- | Replace the nodes of the current graph (not those of its sub-graphs) with
-- the given list, keyed by 'nodeId'. For duplicate ids the last node wins.
setNodes :: GraphConState s => [Node] -> GraphCon s ()
setNodes ns = modifyGraph $ \g ->
  g {graphNodes = Map.fromList (zip (map nodeId ns) ns)}

-- | Apply a function to the node with the given id. The current graph is
-- checked first. Only if it does not contain the id is the search repeated in
-- every sub-graph, recursively. An unknown id leaves the graph unchanged.
modifyNode :: GraphConState s => (Node -> Node) -> UniqueId -> GraphCon s ()
modifyNode f i = modifyGraph update
  where
    update gr
      | i `Map.member` graphNodes gr = gr {graphNodes = Map.adjust f i (graphNodes gr)}
      | otherwise                    = gr {graphSubGraphs = Map.map update (graphSubGraphs gr)}

-- | Delete the node with the given id from the current graph and from every
-- nested sub-graph. Edges referring to the node are left in place.
removeNode :: GraphConState s => UniqueId -> GraphCon s ()
removeNode i = modifyGraph dropNode
  where
    dropNode gr = gr
      { graphNodes     = Map.delete i (graphNodes gr)
      , graphSubGraphs = fmap dropNode (graphSubGraphs gr)
      }
--------------------------------------------------------------------------------
-- Edges
--------------------------------------------------------------------------------
-- | Create an edge from @src@ to @dst@ (ports of the same type) in the graph
-- currently being edited, and return the edge's fresh id.
mkEdge :: (GraphConState s, GenType a) => Port a -> Port a -> GraphCon s UniqueId
mkEdge src dst = do
  i <- newId
  let edge = Edge i (Ports (src, dst))
  modifyGraph (\g -> g {graphEdges = Map.insert i edge (graphEdges g)})
  return i

-- | Like 'mkEdge', but discards the new edge's id.
mkEdge_ :: (GraphConState s, GenType a) => Port a -> Port a -> GraphCon s ()
mkEdge_ src dst = () <$ mkEdge src dst

-- | Replace the edges of the current graph (not those of its sub-graphs) with
-- the given list, keyed by 'edgeId'. For duplicate ids the last edge wins.
setEdges :: GraphConState s => [Edge] -> GraphCon s ()
setEdges es = modifyGraph $ \g ->
  g {graphEdges = Map.fromList (zip (map edgeId es) es)}

-- | Delete the edge with the given id from the current graph and from every
-- nested sub-graph. An unknown id leaves the graph unchanged.
removeEdge :: GraphConState s => UniqueId -> GraphCon s ()
removeEdge i = modifyGraph dropEdge
  where
    dropEdge gr = gr
      { graphEdges     = Map.delete i (graphEdges gr)
      , graphSubGraphs = fmap dropEdge (graphSubGraphs gr)
      }
--------------------------------------------------------------------------------
-- Subgraphs
--------------------------------------------------------------------------------
-- | Add a new, empty 'SubGraph' to the graph currently being edited and return
-- its fresh id (which is also the sub-graph's 'graphId').
mkSubgraph :: GraphConState s => GraphCon s UniqueId
mkSubgraph = do
  sid <- newId
  modifyGraph $ \g ->
    g {graphSubGraphs = Map.insert sid (emptyGraph SubGraph sid) (graphSubGraphs g)}
  pure sid

-- | Like 'mkSubgraph', but discards the new sub-graph's id.
mkSubgraph_ :: GraphConState s => GraphCon s ()
mkSubgraph_ = () <$ mkSubgraph

-- | Perform an action within a sub-graph that is specified by its id.
--
-- The sub-graph becomes the current graph while the action runs. Afterwards,
-- the (possibly modified) sub-graph is written back into the parent, and the
-- parent becomes current again. The parent is otherwise unaltered. Only
-- direct sub-graphs of the current graph are considered. If @gid@ is not one
-- of them, the action runs on the current graph itself.
withinSubGraph :: GraphConState s => UniqueId -> GraphCon s a -> GraphCon s a
withinSubGraph gid action = do
  parent <- getGraph
  case Map.lookup gid (graphSubGraphs parent) of
    Nothing  -> action
    Just sub -> do
      switchTo sub
      result <- action
      sub' <- getGraph
      switchTo (parent {graphSubGraphs = Map.insert gid sub' (graphSubGraphs parent)})
      return result
  where
    switchTo gr = modifyGraphState (\st -> st {getGraphFromState = gr})
--------------------------------------------------------------------------------
