{-# LANGUAGE GADTs,
             ExistentialQuantification,
             MultiParamTypeClasses,
             KindSignatures,
             TypeOperators,
             RankNTypes,
             DataKinds #-}
--------------------------------------------------------------------------------
module Flux.Typed.Expr(Expr(..), Binds(..), mapBindsExpr, simplifyExprWith) where
--------------------------------------------------------------------------------
import Flux.Typed.Literal
import Flux.Typed.Type
import Flux.Typed.Var
import Flux.Typed.Pattern

import Data.List(intersperse)
import Data.Type.Equality
import GHC.TypeLits
import Data.Proxy
--------------------------------------------------------------------------------
-- | A statically typed lambda calculus, indexed by the 'TypeKind' of the
-- expression.
--
-- Available constructs are: variables, literals, lambda abstractions,
-- function application, case-of expressions, (recursive) let expressions,
-- and type abstractions (lambdas whose argument is a type-level name).
data Expr :: TypeKind -> * where
  -- | A variable occurrence.
  VarE    :: (GenType t)                => Var t -> Expr t
  -- | A literal constant.
  LitE    :: (GenType t)                => Lit t -> Expr t
  -- | Lambda abstraction: binds a variable of type @a@ in a body of type @b@,
  -- giving a function of type @a :-> b@.
  LamE    :: (GenType a, GenType b)     => Var a -> Expr b -> Expr (a :-> b)
  -- | Application of a function to an argument of its domain type.
  AppE    :: (GenType a, GenType b)     => Expr (a :-> b) -> Expr a -> Expr b
  -- | @case@ on a scrutinee of type @a@ with (pattern, right-hand side)
  -- alternatives; every right-hand side has the result type @b@.
  -- NOTE: an empty alternative list is constructible, but 'typeOf' raises an
  -- error for it.
  CaseE   :: (GenType a, GenType b)     => Expr a -> [(Pat a, Expr b)] -> Expr b
  -- | A (recursive) let expression: a group of bindings and a body of type @b@.
  LetRecE :: (GenType a, GenType b)     => Binds a -> Expr b -> Expr b
  -- | Type abstraction binding the type-level name @v@ over a body of type @t@.
  TLamE   :: (KnownSymbol v, GenType t) => Proxy (v :: Symbol) -> Expr t -> Expr (TyForAll v t)
--------------------------------------------------------------------------------
-- | Decides whether two expressions have the same type by comparing their
-- computed types ('typeOf'); the expression structure itself is not compared.
-- Inherits the partiality of 'typeOf' (an empty 'CaseE' raises an error).
instance TestEquality Expr where
  testEquality lhs rhs = typeOf lhs `testEquality` typeOf rhs
--------------------------------------------------------------------------------
-- | Structural equality on expressions.
--
-- Sub-expressions whose type is not fixed by @t@ (the argument of an
-- application, the scrutinee of a @case@, the bindings of a @let@) are only
-- compared after 'testEquality' has shown that their types coincide; if it
-- does not, the expressions are unequal. Expressions built with different
-- constructors are unequal.
instance Eq (Expr t) where
  VarE v1       == VarE v2             = v1 == v2
  LitE l1       == LitE l2             = l1 == l2
  LamE v1 e1    == LamE v2 e2          = v1 == v2 && e1 == e2
  TLamE v1 e1   == TLamE v2 e2         = v1 == v2 && e1 == e2
  AppE e1 e2    == AppE e1' e2'
    | Just Refl <- testEquality e2 e2' = e1 == e1' && e2 == e2'
  CaseE e1 cs1  == CaseE e2 cs2
    | Just Refl <- testEquality e1 e2  = e1 == e2 && cs1 == cs2
  LetRecE b1 e1 == LetRecE b2 e2
    | Just Refl <- testEquality b1 b2  = b1 == b2 && e1 == e2
  -- Different constructors, or a failed type-equality guard above. Without
  -- this equation the instance is non-exhaustive and '==' crashes with a
  -- pattern-match failure instead of returning False.
  _             == _                   = False
--------------------------------------------------------------------------------
-- | Computes the type of an expression from its structure.
--
-- Raises an error for a 'CaseE' with no alternatives, since the result type
-- is read off the right-hand side of an alternative.
instance Typed Expr where
  typeOf (VarE v)                 = typeOf v
  typeOf (LitE l)                 = typeOf l
  typeOf (LamE v e)               = TypeFun (typeOf v) (typeOf e)
  -- The result type is the codomain of the applied function's type.
  typeOf (AppE e _)
    | TypeFun _ b <- typeOf e     = b
  -- All alternatives share the result type, so the first one suffices;
  -- matching it directly avoids the partial 'head'.
  typeOf (CaseE _ ((_, rhs) : _)) = typeOf rhs
  typeOf (CaseE _ [])             = error "typeOf: empty list of cases for CaseE"
  typeOf (LetRecE _ e)            = typeOf e
  typeOf (TLamE v e)              = TypeForAll v (typeOf e)
--------------------------------------------------------------------------------
-- | Pretty-printing in a lambda-calculus-like surface syntax.
--
-- An application shows its function at precedence 9 and its argument at 10,
-- so nested applications associate to the left. Lambdas, type lambdas,
-- @case@ and @let@ are parenthesised whenever the context precedence is
-- above 0. Directly nested (type) lambdas are collapsed into one binder list.
instance Show (Expr t) where
  showsPrec _ (VarE var) = shows var
  showsPrec _ (LitE lit) = shows lit
  ------------------------------------------------------------------------------
  showsPrec d (AppE fun arg) =
    showParen (d > 9) (showsPrec 9 fun . showChar ' ' . showsPrec 10 arg)
  ------------------------------------------------------------------------------
  showsPrec d (CaseE scrut alts) =
    showParen (d > 0)
      ( showString "case " . shows scrut
      . showString " of { " . showAlts alts
      . showString " }" )
    where
      -- Alternatives separated by "; "; an empty list shows nothing.
      showAlts :: [(Pat p, Expr r)] -> ShowS
      showAlts []           = id
      showAlts (alt : rest) =
        showAlt alt . foldr (\next k -> showString "; " . showAlt next . k) id rest
      showAlt :: (Pat p, Expr r) -> ShowS
      showAlt (pat, rhs) = shows pat . showString " -> " . shows rhs
  ------------------------------------------------------------------------------
  showsPrec d (LetRecE bs body) =
    showParen (d > 0)
      ( showString "let " . showBinds bs
      . showString " in { " . shows body
      . showString " } " )
    where
      -- Every binding is terminated by "; ".
      showBinds :: Binds b -> ShowS
      showBinds NoBind             = id
      showBinds (Binds v rhs more) =
        shows v . showString " = " . shows rhs . showString "; " . showBinds more
  ------------------------------------------------------------------------------
  showsPrec d (LamE v body) =
    showParen (d > 0) (showChar '\\' . shows v . lamRest body)
    where
      -- Collapse directly nested lambdas into \x y z.body
      lamRest :: forall a. Expr a -> ShowS
      lamRest (LamE v' inner) = showChar ' ' . shows v' . lamRest inner
      lamRest other           = showChar '.' . shows other
  ------------------------------------------------------------------------------
  showsPrec d (TLamE v body) =
    showParen (d > 0) (showString "/\\" . showString (symbolVal v) . tlamRest body)
    where
      -- Collapse directly nested type lambdas the same way.
      tlamRest :: forall a. Expr a -> ShowS
      tlamRest (TLamE v' inner) = showChar ' ' . showString (symbolVal v') . tlamRest inner
      tlamRest other            = showChar '.' . shows other
--------------------------------------------------------------------------------
-- | A group of bindings, each of the form @variable = expression@.
--
-- The index records the type of every binding as a right-nested tuple
-- terminated by unit: binding variables of types @a@, @b@, @c@ and @d@ (in
-- that order) yields an index of the shape @(a,(b,(c,(d,()))))@, i.e. a
-- type-level list of the bound types.
data Binds :: TypeKind -> * where
  -- | The empty group.
  NoBind :: Binds TyUnit
  -- | A variable, the expression bound to it (of the same type @a@), and the
  -- remaining bindings.
  Binds  :: (GenType a, GenType b) => Var a -> Expr a -> Binds b -> Binds (TyTuple a b)
--------------------------------------------------------------------------------
-- | Decides whether two bind groups have the same type index.
--
-- Returns @Just Refl@ when both groups have the same length and
-- 'testEquality' succeeds on each pair of bound variables. Returns 'Nothing'
-- otherwise, including when one group runs out before the other.
instance TestEquality Binds where
  testEquality NoBind NoBind = Just Refl
  testEquality (Binds v1 _ b1) (Binds v2 _ b2) = do
    -- Each bound expression has the same type as its variable, so comparing
    -- the variables suffices. Not comparing the expressions also avoids
    -- forcing the partial 'typeOf' on them (e.g. an empty 'CaseE').
    Refl <- testEquality v1 v2
    Refl <- testEquality b1 b2
    return Refl
  -- Groups of different length have different types. Previously this case
  -- was missing and crashed with a pattern-match failure.
  testEquality _ _ = Nothing
--------------------------------------------------------------------------------
-- | Structural equality of bind groups: equal variables bound to equal
-- expressions, pairwise and in order.
instance Eq (Binds t) where
  (==) NoBind NoBind = True
  (==) (Binds x ex rest) (Binds y ey rest') = and [x == y, ex == ey, rest == rest']
--------------------------------------------------------------------------------
-- | The type of a bind group: the nested tuple of its variables' types,
-- terminated by the unit type.
instance Typed Binds where
  typeOf binds = case binds of
    Binds var _ rest -> TypeTuple (typeOf var) (typeOf rest)
    NoBind           -> TypeUnit
--------------------------------------------------------------------------------
-- | Applies a type-preserving transformation to every right-hand side of a
-- bind group. The bound variables and the group's shape (and therefore its
-- type index) are left unchanged.
mapBindsExpr :: (forall a. Expr a -> Expr a) -> Binds b -> Binds b
mapBindsExpr _ NoBind         = NoBind
mapBindsExpr f (Binds v e bs) = Binds v (f e) (mapBindsExpr f bs)
--------------------------------------------------------------------------------
-- | Rewrites an expression bottom-up with a type-preserving step @f@.
--
-- All sub-expressions are rewritten first, including case scrutinees,
-- case right-hand sides, let bodies and the right-hand sides of let
-- bindings. Then @f@ is applied to the node rebuilt from the rewritten
-- children. The result of @f@ is not traversed again, so repeating the pass
-- until a fixed point is reached is up to the caller.
simplifyExprWith :: (forall a. Expr a -> Expr a) -> Expr a -> Expr a
simplifyExprWith f x = f $ case x of
  LamE v e     -> LamE v (go e)
  AppE e1 e2   -> AppE (go e1) (go e2)
  CaseE e cs   -> CaseE (go e) [(p, go rhs) | (p, rhs) <- cs]
  -- Bind groups reuse the module's shared right-hand-side traversal instead
  -- of a local copy of it.
  LetRecE bs e -> LetRecE (mapBindsExpr go bs) (go e)
  TLamE t e    -> TLamE t (go e)
  -- VarE and LitE have no sub-expressions.
  leaf         -> leaf
  where
    -- The recursive pass with the same step. It is kept polymorphic so it can
    -- be handed to 'mapBindsExpr', which needs a rank-2 argument.
    go :: forall b. Expr b -> Expr b
    go = simplifyExprWith f
--------------------------------------------------------------------------------