{-# LANGUAGE TypeOperators,
             GADTs,
             RankNTypes,
             KindSignatures,
             StandaloneDeriving,
             MultiParamTypeClasses,
             ScopedTypeVariables,
             DataKinds #-}
--------------------------------------------------------------------------------
module Flux.Typed.Cat((~>)(..), Object(..), exprToCat) where
--------------------------------------------------------------------------------
import Flux.Typed.Literal
import Flux.Typed.Var
import Flux.Typed.Expr
import Flux.Typed.Pattern
import Flux.Typed.Type

import Data.Type.Equality
import Data.Proxy
import Control.Monad
import GHC.TypeLits
--------------------------------------------------------------------------------
-- | Values that a 'Const' morphism can yield, indexed by their type.
data Object :: TypeKind -> * where
    -- | A literal.
    OLit :: GenType t => Lit t -> Object t
    -- | A variable used as an opaque constant.
    OVar :: GenType t => Var t -> Object t
    -- | A morphism used as a first-class function value.
    OArr :: (GenType x, GenType y) => (x ~> y) -> Object (x :-> y)

deriving instance Show (Object t)

-- | Morphisms of the target category, indexed by their source and target
-- types. 'exprToCat' compiles lambda expressions into this representation.
data (~>) :: TypeKind -> TypeKind -> * where
    -- | The identity morphism.
    Id    :: GenType i => i ~> i
    -- | Discards its input and yields the given object.
    Const :: (GenType i, GenType o) => Object o -> i ~> o
    -- | A function-typed variable used directly as a morphism.
    Arr   :: (GenType i, GenType o) => Var (i :-> o) -> i ~> o

    -- | Composition, read left to right: the first morphism runs first.
    (:>>>) :: (GenType i, GenType m, GenType o)
           => (i ~> m) -> (m ~> o) -> i ~> o
    -- | Fan-out: both morphisms receive the same input; results are paired.
    (:&&&) :: (GenType i, GenType o1, GenType o2)
           => (i ~> o1) -> (i ~> o2) -> i ~> TyTuple o1 o2
    -- | Parallel product: each morphism acts on one component of a pair.
    (:***) :: (GenType i1, GenType i2, GenType o1, GenType o2)
           => (i1 ~> o1) -> (i2 ~> o2) -> TyTuple i1 i2 ~> TyTuple o1 o2

    -- | Projection onto the first component of a pair.
    Fst   :: (GenType l, GenType r) => TyTuple l r ~> l
    -- | Projection onto the second component of a pair.
    Snd   :: (GenType l, GenType r) => TyTuple l r ~> r
    -- | Relates an input pattern to an output pattern. 'convert' uses it to
    -- pull a variable out of a constructor pattern ('PatApp'), and 'opt'
    -- rewrites @ExPat p p@ to 'Id'.
    ExPat :: (GenType i, GenType o) => Pat i -> Pat o -> i ~> o

    -- | Case analysis. The input pairs an environment with a scrutinee; each
    -- alternative pairs a pattern for the scrutinee with the morphism for
    -- that branch, which receives the same (environment, scrutinee) pair.
    Case  :: (GenType s, GenType o, GenType e)
          => [(Pat s, TyTuple e s ~> o)] -> TyTuple e s ~> o

    -- | Feedback over the second component @f@; 'convert' uses it for
    -- recursive @let@ bindings. NOTE(review): presumably a trace/fixpoint
    -- combinator -- confirm against whatever consumes these morphisms.
    Loop    :: (GenType i, GenType o, GenType f)
            => (TyTuple i f ~> TyTuple o f) -> i ~> o
    -- | Turn a morphism on pairs into one that produces a function.
    Curry   :: (GenType e, GenType x, GenType o)
            => (TyTuple e x ~> o) -> e ~> (x :-> o)
    -- | Inverse of 'Curry'.
    Uncurry :: (GenType e, GenType x, GenType o)
            => (e ~> (x :-> o)) -> TyTuple e x ~> o
    -- | Apply the function in the first component to the second component.
    Apply   :: (GenType x, GenType o) => TyTuple (x :-> o) x ~> o

    -- | Type abstraction over the symbol @v@ ('convert' emits it for 'TLamE').
    ForAll :: (KnownSymbol v, GenType i, GenType o)
           => Proxy v -> (i ~> o) -> i ~> TyForAll v o

deriving instance Show (a ~> b)
--------------------------------------------------------------------------------
-- | Compile a lambda expression into an optimised morphism from the
-- lambda's argument type to its result type.
--
-- The lambda's parameter becomes the initial environment pattern for
-- 'convert'; the result is then simplified with 'optimize'.
--
-- Only an expression whose outermost constructor is 'LamE' is supported.
-- Anything else is rejected with a descriptive error rather than a bare
-- pattern-match failure.
exprToCat :: Expr (a :-> b) -> (a ~> b)
exprToCat (LamE v e) = optimize (convert (PatVar v) e)
exprToCat _          = error "exprToCat: expected a lambda expression (LamE) at the top level"

-- | Translate an expression into a morphism whose input is the environment
-- described by the pattern @p@.
--
-- The pattern records which variables are in scope and where each one sits
-- inside the input value:
--
-- * a lambda extends it with @p :* PatVar v@;
-- * @letrec@ extends it with a right-nested tuple of the bound names;
-- * each @case@ alternative extends it with the alternative's left-hand side.
--
-- A variable for which 'isVar' holds and that occurs in the pattern is
-- projected out of the input. Every other variable becomes a constant
-- ('Const' of 'OVar').
convert :: forall a b. (GenType a, GenType b) => Pat a -> Expr b -> (a ~> b)
convert _ (LitE l) = Const (OLit l)

convert p (LamE v e)   = Curry $ convert (p :* PatVar v) e
convert p e@AppE{}     = convertApp p e

convert p e@(VarE v)
    | isVar v && v `varOccursIn` p = extract p
    | otherwise                    = Const (OVar v)
    where
        -- Build the projection from a value matching the given pattern to
        -- the value bound to @v@. The right component of a pair is searched
        -- first, so later (inner) bindings shadow earlier ones.
        extract :: forall c. Pat c -> (c ~> b)
        extract (PatVar u)
            | Just Refl <- testEquality u v, u == v = Id

        extract (l :* r)
            | v `varOccursIn` r = Snd :>>> convert r e
            | v `varOccursIn` l = Fst :>>> convert l e

        extract pat@(PatApp l r)
            | Just q <- findPat v r = ExPat pat q
            | Just q <- findPat v l = ExPat pat q

        -- Reaching here means varOccursIn reported the variable but no
        -- projection could be built; fail with a descriptive message
        -- instead of a bare pattern-match failure.
        extract _ = error ("convert: variable " ++ show v
                           ++ " occurs in the pattern but could not be extracted")

        -- Locate the variable pattern for @w@ anywhere inside a pattern,
        -- including inside tuples nested under constructor applications.
        -- Right-hand sub-patterns are searched first.
        findPat :: forall c d. Var c -> Pat d -> Maybe (Pat c)
        findPat w (PatApp l r) = findPat w r `mplus` findPat w l
        findPat w (l :* r)     = findPat w r `mplus` findPat w l
        findPat w q@(PatVar u)
            | Just Refl <- testEquality u w,
              u == w = Just q
        findPat _ _ = Nothing

convert p (LetRecE ls expr) = Loop (convert p' expr :&&& convertBinds p' ls)
    where
        -- The recursive bindings are visible to the body and to each other,
        -- so the environment is extended with a tuple of all of them and the
        -- Loop feeds the tuple of their values back in.
        p' = p :* makeLetPattern ls

        -- Right-nested tuple pattern of the bound names, ending in PatUnit.
        makeLetPattern :: forall c. Binds c -> Pat c
        makeLetPattern NoBind            = PatUnit
        makeLetPattern (Binds x _ binds) = PatVar x :* makeLetPattern binds

        -- Right-nested tuple of the bound right-hand sides, ending in unit.
        convertBinds :: forall c d. (GenType c, GenType d) => Pat c -> Binds d -> (c ~> d)
        convertBinds _ NoBind             = Const (OVar unit)
        convertBinds p2 (Binds _ e binds) = convert p2 e :&&& convertBinds p2 binds

convert _ (CaseE _ []) = error "convert: CaseE e []"
-- Pair the environment with the scrutinee, then dispatch; each branch sees
-- the environment extended with its own left-hand-side pattern.
convert p (CaseE e cs) = (Id :&&& convert p e) :>>> Case [ (lhs, convert (p :* lhs) rhs) | (lhs, rhs) <- cs ]
convert p (TLamE v e)  = ForAll v (convert p e)
--------------------------------------------------------------------------------
-- | Translate an application @f x1 ... xn@ under the environment pattern
-- @pat@.
--
-- The application spine is walked from the innermost head outwards.
--
-- * If the head is a variable satisfying 'isVar' that occurs in @pat@, or is
--   not a variable at all, every argument is applied with 'Apply'.
-- * If the head is any other variable, it is used directly via 'Arr' on the
--   first argument, and later arguments are supplied through 'Uncurry'.
convertApp :: (GenType a, GenType b) => Pat a -> Expr b -> (a ~> b)
convertApp pat expr = case spine pat expr of
    Left  viaArr   -> viaArr
    Right viaApply -> viaApply
  where
    -- Left: the spine's head was lifted with 'Arr';
    -- Right: the head is combined with its arguments through 'Apply'.
    spine :: (GenType c, GenType d) => Pat c -> Expr d -> Either (c ~> d) (c ~> d)
    spine env (AppE fun@(VarE name) arg)
      | isVar name, name `varOccursIn` env
      = Right ((convert env fun :&&& convert env arg) :>>> Apply)
      | otherwise
      = Left (convert env arg :>>> Arr name)
    spine env (AppE fun arg) =
      case spine env fun of
        Right headArrow -> Right ((headArrow :&&& convert env arg) :>>> Apply)
        Left  headArrow -> Left ((Id :&&& convert env arg) :>>> Uncurry headArrow)
    spine env other = Right (convert env other)


--------------------------------------------------------------------------------
-- | Simplify a morphism by running the local rewrite rules of 'opt'
-- bottom-up through 'optimizeWith'.
optimize :: (a ~> b) -> (a ~> b)
optimize arrow = optimizeWith opt arrow

-- | One local simplification step, applied to the root of a morphism only
-- (sub-terms are handled by 'optimizeWith'). A morphism that matches no
-- rule is returned unchanged.
opt :: (a ~> b) -> (a ~> b)
opt arrow = case arrow of
  -- Pairing both projections rebuilds the input pair.
  Fst :&&& Snd -> Id
  -- Composing with the identity on either side is a no-op.
  Id :>>> rest -> rest
  rest :>>> Id -> rest
  -- Curry and Uncurry cancel each other.
  Curry (Uncurry inner) -> inner
  Uncurry (Curry inner) -> inner
  -- Matching a pattern and rebuilding the very same pattern is the identity.
  ExPat lhs rhs
    | Just Refl <- testEquality lhs rhs
    , lhs == rhs
    -> Id
  -- A constant function fed into the variable named "uncurry" becomes a
  -- constant uncurried arrow, provided the TypeFun/TypeTuple shapes of both
  -- variables line up component by component.
  Const (OVar fun) :>>> Arr uncurryVar
    | varName uncurryVar == "uncurry"
    , TypeFun (TypeFun a1 (TypeFun b1 c1)) (TypeFun (TypeTuple a2 b2) c2) <- typeOf uncurryVar
    , TypeFun a3 (TypeFun b3 c3) <- typeOf fun
    , Just Refl <- testEquality a1 a2
    , Just Refl <- testEquality a2 a3
    , Just Refl <- testEquality b1 b2
    , Just Refl <- testEquality b2 b3
    , Just Refl <- testEquality c1 c2
    , Just Refl <- testEquality c2 c3
    -> Const (OArr (Uncurry (Arr fun)))
  _ -> arrow

-- | Rebuild a morphism bottom-up, applying the rewrite @f@ to every node
-- after that node's sub-morphisms have been rewritten.
--
-- Every constructor that carries a sub-morphism is traversed. This includes
-- morphisms nested inside constant arrow objects ('OArr'), the branches of
-- a 'Case', and the body of a 'ForAll'. Leaf morphisms ('Id', 'Arr', 'Fst',
-- 'Snd', 'Apply', 'ExPat', other constants) are handed to @f@ as they are.
optimizeWith :: (forall c d. (c ~> d) -> (c ~> d)) -> (a ~> b) -> (a ~> b)
optimizeWith f x = f $ case x of
  Const (OArr y) -> Const (OArr (optimizeWith f y))
  y :>>> z       -> optimizeWith f y :>>> optimizeWith f z
  y :&&& z       -> optimizeWith f y :&&& optimizeWith f z
  y :*** z       -> optimizeWith f y :*** optimizeWith f z
  Loop y         -> Loop (optimizeWith f y)
  Curry y        -> Curry (optimizeWith f y)
  Uncurry y      -> Uncurry (optimizeWith f y)
  -- A type abstraction wraps a sub-morphism too, so optimise inside it
  -- rather than letting it fall through to the leaf case.
  ForAll v y     -> ForAll v (optimizeWith f y)
  Case cs        -> Case [ (pat, optimizeWith f branch) | (pat, branch) <- cs ]
  y              -> y
--------------------------------------------------------------------------------
