{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
--------------------------------------------------------------------------------
module Flux.Typed.CoreToExpr where

import Flux.Typed.Expr
import Flux.Typed.Var
import Flux.Typed.Literal
import Flux.Typed.Boxed
import Flux.Typed.Type

import System.IO.Unsafe


import Data.List(intercalate)
import Control.Applicative
import Data.Type.Equality
import GHC.TypeLits
import qualified Data.ByteString.UTF8 as UTF8

import qualified CoreSyn as GHC(
  Expr(..), CoreExpr, CoreBind, Bind(..), CoreBndr, CoreAlt, AltCon(..))
import qualified Literal as GHC(Literal(..))
import qualified Var as GHC(Var)
import Var hiding (Var, varName, varType)
import qualified Var as GHC(varName, varType)
import Name hiding (varName)
import DataCon
import DataCon as GHC(DataCon)
import Id hiding (Var)
import Type hiding (Type, Var)
import TyCon
import qualified Type as GHC(Type)
import FastString hiding (mkLitString)
import Outputable
import DynFlags
import Coercion(coercionKind)
import Pair
--------------------------------------------------------------------------------
-- | Render a GHC variable as an identifier for the translated program.
--
-- Type variables are named by their unique alone; exported ids keep their
-- bare occurrence name; every other id gets its unique appended
-- (@name_unique@), which disambiguates binders sharing a source name.
varToString :: GHC.Var -> String
varToString v
  | isTypeVar v    = uniq
  | isExportedId v = occ
  | otherwise      = occ ++ "_" ++ uniq
  where
    uniq = show (varUnique v)
    occ  = Name.occNameString (Name.occName (GHC.varName v))

-- | The unqualified source-level occurrence name of a data constructor.
dataConToString :: DataCon.DataCon -> String
dataConToString con = Name.occNameString (Name.occName (DataCon.dataConName con))
--------------------------------------------------------------------------------
-- | Translate a GHC Core literal into our literal representation.
--
-- All machine-sized integral literals (Int, Int64, Word, Word64) map to a
-- primitive int; rational literals are converted with 'fromRational'; string
-- bytes are decoded as UTF-8; the null address and foreign labels become
-- symbols (labels are wrapped in angle brackets).
translateLiteral :: GHC.Literal -> Boxed Lit
translateLiteral lit = case lit of
  GHC.MachChar c        -> mkLitPrimChar c
  GHC.MachStr bytes     -> mkLitPrimString (UTF8.toString bytes)
  GHC.MachNullAddr      -> mkLitSymbol "NULL"
  GHC.MachInt n         -> mkLitPrimInt n
  GHC.MachInt64 n       -> mkLitPrimInt n
  GHC.MachWord n        -> mkLitPrimInt n
  GHC.MachWord64 n      -> mkLitPrimInt n
  GHC.MachFloat q       -> mkLitPrimFloat (fromRational q)
  GHC.MachDouble q      -> mkLitPrimDouble (fromRational q)
  GHC.MachLabel lbl _ _ -> mkLitSymbol ('<' : unpackFS lbl ++ ">")
  GHC.LitInteger n _    -> mkLitInteger n
--------------------------------------------------------------------------------
-- | Translate GHC Core expressions into lambda terms represented by our data type.
--
-- Core constructs without a counterpart in 'Expr' are erased: ticks,
-- arguments and lambda binders of void type ('isVoidTy'), dictionary lambda
-- binders, and dictionary arguments applied to a variable of function type.
-- Casts are handed to 'changeSubType' together with both sides of the
-- coercion. The translated term is finally passed through 'simplifyExpr'.
--
-- Untranslatable input (e.g. a type application to a non-forall type, or a
-- bare @Type@ / @Coercion@ in expression position) yields 'Left' with a
-- message.
translateCore :: DynFlags -> GHC.CoreExpr -> Either String (Boxed Expr)
translateCore dflags e = (\(Boxed e) -> Boxed (simplifyExpr e)) <$> translateCore' 0 e
  where
    -- The type of a variable, with type synonyms expanded and foralls kept.
    translateVarType :: GHC.Var -> Boxed Type
    translateVarType v = translateType True (expandTypeSynonyms $ GHC.varType v)

    -- Translate a GHC type. With forAlls == False, forall binders are dropped
    -- and only their bodies are kept. Type variables are named after the
    -- pretty-printed unique of their binder. A function type whose argument
    -- is void collapses to its result type. Anything not matched by the
    -- cases below is kept as an opaque, pretty-printed type name.
    translateType :: Bool -> GHC.Type -> Boxed Type
    translateType forAlls t
      | Just t'     <- coreView t            = translateType forAlls t'
      | Just (v,t') <- splitForAllTy_maybe t =
          if forAlls
          then mkTypeForAll (showSDoc dflags (ppr $ varUnique v)) (translateType forAlls t')
          else translateType forAlls t'
      -- only a tuple type constructor applied to exactly two arguments is a pair
      | Just (c,[a,b]) <- splitTyConApp_maybe t,
        isTupleTyCon c                       = mkTypeTuple (translateType forAlls a) (translateType forAlls b)
      | Just (a,b)  <- splitFunTy_maybe t,
        isVoidTy a                           = translateType forAlls b
      | Just (a,b)  <- splitFunTy_maybe t    = mkTypeFun (translateType forAlls a) (translateType forAlls b)
      | Just (a,b)  <- splitAppTy_maybe t    = mkTypeApply (translateType forAlls a) (translateType forAlls b)
      | Just v      <- getTyVar_maybe t      = mkTypeVar (showSDoc dflags (ppr $ varUnique v))
      | otherwise                            = mkType (showSDoc dflags (pprType $ expandTypeSynonyms t))

    -- Translate one Core expression. The Int is the nesting depth of case
    -- alternatives; its only consumer is the naming of the wildcard variable
    -- bound by DEFAULT alternatives (see 'wild' in translateAlt).
    translateCore' :: Int -> GHC.CoreExpr -> Either String (Boxed Expr)
    translateCore' _ (GHC.Lit l) = Right $ mkLitE (translateLiteral l)
    -- Data constructor workers become constructor variables carrying their
    -- source arity; other ids are Global when exported and Local otherwise.
    translateCore' _ (GHC.Var v)
      | isDataConWorkId v = Right $ mkVarE $ mkCons (translateVarType v) (varToString v) Global (dataConSourceArity (idDataCon v))
      | otherwise = Right $ mkVarE $ mkVar (translateVarType v) (varToString v) (if isExportedId v then Global else Local)

    -- Dictionary application: when the function translates to a plain
    -- variable of function type, drop the dictionary argument and re-type the
    -- variable with the function's result type. If any guard fails, the
    -- generic application cases below are tried instead.
    translateCore' i (GHC.App e e2)
      | isDictExpr e2, Right (Boxed (VarE v')) <- translateCore' i e,
        TypeFun _ b <- typeOf v' = Right (Boxed (VarE (changeVarType v' b)))
        where
          isDictExpr (GHC.Var v)             = isDictId v
          isDictExpr (GHC.App (GHC.Var v) _) = isDictId v
          isDictExpr (GHC.Type t)            = isDictLikeTy t
          isDictExpr _                       = False

    -- Type application: only valid when the function has a forall type.
    translateCore' i (GHC.App e (GHC.Type t)) = do
      Boxed e' <- translateCore' i e
      let t' = translateType True t
      case typeOf e' of
        TypeForAll _ _ -> e' @! t'
        _              -> Left "translateCore: cannot apply non-forall-type to type"

    -- Arguments of void type carry no value and are dropped.
    translateCore' i (GHC.App e (GHC.Var v)) | isVoidTy (GHC.varType v) = translateCore' i e

    translateCore' i (GHC.App e arg) = do
      e'   <- translateCore' i e
      arg' <- translateCore' i arg
      mkUnifyAppE e' arg'

    -- Type lambdas are kept; void and dictionary binders are erased.
    translateCore' i (GHC.Lam v e)
      | isTypeVar v              = mkTLamE (varToString v) <$> translateCore' i e
      | isVoidTy (GHC.varType v) || isDictId v = translateCore' i e
      | otherwise                = mkLamE (mkVar (translateVarType v) (varToString v) Local) <$> translateCore' i e

    -- Both recursive and non-recursive lets become a letrec.
    translateCore' i (GHC.Let b e) = mkLetRecE <$> translateBinds i b <*> translateCore' i e

    translateCore' i (GHC.Case e _ _ alts) = do
      e' <- translateCore' i e
      -- GHC keeps the DEFAULT alternative, if present, at the head of the
      -- list; move it to the end so that the specific alternatives come
      -- first. Without a DEFAULT the original order is kept.
      let alts' = case alts of
            dflt@(GHC.DEFAULT, _, _) : rest -> rest ++ [dflt]
            _                               -> alts
      cs <- mapM (translateAlt i) alts'
      mkUnifyCaseE e' cs

    -- Casts: the expression is re-typed by changeSubType using the two sides
    -- (source, target) of the coercion's kind.
    translateCore' i (GHC.Cast e c) = do
      Boxed e' <- translateCore' i e
      changeSubType t1 t2 e'
      where
        p = coercionKind c
        t1 = translateType True (pFst p)
        t2 = translateType True (pSnd p)

    translateCore' i (GHC.Tick _ e)   = translateCore' i e
    -- Types and coercions are only meaningful as arguments handled above;
    -- anywhere else they are reported as errors rather than crashing.
    translateCore' _ (GHC.Type _)     = Left "translateCore': encountered 'Type'"
    translateCore' _ (GHC.Coercion _) = Left "translateCore': encountered 'Coercion'"

    -- Translate a binding group into a chain of Binds terminated by mkNoBind.
    translateBinds :: Int -> GHC.CoreBind -> Either String (Boxed Binds)
    translateBinds i (GHC.NonRec v e) = translateBind i v e mkNoBind
    translateBinds _ (GHC.Rec []) = Right mkNoBind
    translateBinds i (GHC.Rec ((v,e):bs)) = do
      bs' <- translateBinds i (GHC.Rec bs)
      translateBind i v e bs'

    -- Prepend one binding to an existing chain. The binder's type is taken
    -- from the translated right-hand side, not from GHC's varType.
    translateBind :: Int -> GHC.CoreBndr -> GHC.CoreExpr -> Boxed Binds -> Either String (Boxed Binds)
    translateBind i v e b = do
      e' <- translateCore' i e
      mkBinds (mkVar (boxedType e') (varToString v) Local) e' b

    -- Translate one case alternative into a (pattern, body) pair. The body is
    -- translated at depth i+1 so nested DEFAULT wildcards get distinct names.
    translateAlt :: Int -> GHC.CoreAlt -> Either String (Boxed Pat, Boxed Expr)
    translateAlt i (altCon, vs, e) = do
      e' <- translateCore' (i+1) e
      case altCon of
        -- Disabled special cases, kept for reference:
        -- GHC.DataAlt con
        --   | '#' <- last (dataConToString con),
        --     [v] <- vs -> Right (mkPatVar (mkVar (translateVarType v) (varToString v) Global), e')
        -- GHC.DataAlt con
        --   | null vs && (dataConToString con == "()") -> Right (mkPatUnit, e')

        -- Constructor pattern: the constructor applied to its value binders
        -- (type-variable binders are skipped).
        GHC.DataAlt con ->
          let c = mkCons (translateType False (dataConRepType con)) (dataConToString con) Global (dataConSourceArity con)
          in do p <- mkConsPat (mkPatVar c) vs; Right (p,e')

        GHC.LitAlt l -> Right (mkPatLit (translateLiteral l), e')
        GHC.DEFAULT  -> Right (mkPatVar wild, e')

      where
        -- Apply a constructor pattern to the alternative's binders, left to right.
        mkConsPat :: Boxed Pat -> [GHC.Var] -> Either String (Boxed Pat)
        mkConsPat p [] = Right p
        mkConsPat p (v:vs)
          | isTypeVar v = mkConsPat p vs
          | otherwise   = do
              p' <- mkUnifyPatApp p (mkPatVar (mkVar (translateVarType v) (varToString v) Local))
              mkConsPat p' vs

        -- Wildcard variable for DEFAULT alternatives: "wild" at depth 0,
        -- "wild<i>" deeper; its type is a type variable of the same name.
        wild :: Boxed Var
        wild = let s = ("wild" ++ if i /= 0 then show i else "") in mkVar (mkTypeVar s) s Local
--------------------------------------------------------------------------------
-- | Peephole simplifications on a translated expression.
--
-- 'simplifyExprWith' is first run with @simple@, which
--
--   * folds a boxing constructor applied to a primitive literal
--     (C\#, I\#, F\#, D\#, unpackCString\#) into the corresponding boxed
--     literal, provided the constructor's result type equals the literal's;
--   * rewrites @($) f x@ into @f x@ when @($)@ has the expected type.
--
-- Compositions are then turned into plain applications by @transformComps@:
-- @(f . g . h) x@ becomes @f (g (h x))@, and an unapplied @f . g . h@ becomes
-- a lambda over a fresh variable named @%x%@. Every rewrite is guarded by
-- type-equality checks on the operator's type; terms that fail them are left
-- unchanged.
simplifyExpr :: Expr a -> Expr a
simplifyExpr = transformComps . simplifyExprWith simple
  where
    -- simplify the boxing of literals:
    simple (AppE (VarE v) (LitE l))
      | varName v == "C#",
        TypeFun _ b <- typeOf v,
        LitPrimChar c <- l,
        let c' = LitChar c,
        Just Refl <- testEquality b (typeOf c') = LitE c'
      | varName v == "I#",
        TypeFun _ b <- typeOf v,
        LitPrimInt i <- l,
        let i' = LitInt i,
        Just Refl <- testEquality b (typeOf i') = LitE i'
      | varName v == "F#",
        TypeFun _ b <- typeOf v,
        LitPrimFloat f <- l,
        let f' = LitFloat f,
        Just Refl <- testEquality b (typeOf f') = LitE f'
      | varName v == "D#",
        TypeFun _ b <- typeOf v,
        LitPrimDouble d <- l,
        let d' = LitDouble d,
        Just Refl <- testEquality b (typeOf d') = LitE d'
      | varName v == "unpackCString#",
        TypeFun _ b <- typeOf v,
        LitPrimString s <- l,
        let s' = LitString s,
        Just Refl <- testEquality b (typeOf s') = LitE s'

    -- simplify ($) operator:
    simple (AppE (AppE (VarE op) e1) e2)
      | "$" <- varName op,
        TypeFun (TypeFun a1 b1) (TypeFun a2 b2) <- typeOf op,
        Just Refl <- testEquality a1 a2,
        Just Refl <- testEquality b1 b2
        = AppE e1 e2

    simple x = x

    -- Transform compositions into normal function applications, recursing
    -- through the whole term. The operands of a rewritten composition (the
    -- functions in the chain and the applied argument) are transformed as
    -- well, so nested compositions such as (f . g) (map (h . k) xs) are
    -- rewritten too. Recursion is always on strictly smaller subterms.
    transformComps :: Expr a -> Expr a
    -- (f . g . h) x => f (g (h x))
    transformComps (AppE (AppE (AppE (VarE op) e1) e2) x)
      | "." <- varName op,
        TypeFun (TypeFun b c) (TypeFun (TypeFun a b') (TypeFun a' c')) <- typeOf op,
        Just Refl <- testEquality a a',
        Just Refl <- testEquality b b',
        Just Refl <- testEquality c c'
        = AppE (transformComps e1) (mkComp e2 (transformComps x))

    --  f . g . h    => \x -> f (g (h x))
    transformComps (AppE (AppE (VarE op) e1) e2)
      | "." <- varName op,
        TypeFun (TypeFun b c) (TypeFun (TypeFun a b') (TypeFun a' c')) <- typeOf op,
        Just Refl <- testEquality a a',
        Just Refl <- testEquality b b',
        Just Refl <- testEquality c c',
        let v = Var {varName = "%x%", varType = a, varScope = Local}
        = LamE v (AppE (transformComps e1) (mkComp e2 (VarE v)))

    -- search the expression in depth for matching patterns
    transformComps (LamE v e)     = LamE v (transformComps e)
    transformComps (AppE e1 e2)   = AppE (transformComps e1) (transformComps e2)
    transformComps (CaseE e cs)   = CaseE (transformComps e) [(p,transformComps e') | (p,e') <- cs]
    transformComps (LetRecE bs e) = LetRecE (mapBindsExpr transformComps bs) (transformComps e)
    transformComps (TLamE t e)    = TLamE t (transformComps e)
    transformComps x              = x

    -- Unfold a right-nested composition chain applied to x:
    -- mkComp (f . (g . h)) x = f (g (h x)). Each function of the chain is
    -- itself passed through transformComps; x is used as given.
    mkComp :: (GenType a, GenType b) => Expr (a :-> b) -> Expr a -> Expr b
    mkComp (AppE (AppE (VarE op) e1) e2) x
      | "." <- varName op,
        TypeFun (TypeFun b c) (TypeFun (TypeFun a b') (TypeFun a' c')) <- typeOf op,
        Just Refl <- testEquality a a',
        Just Refl <- testEquality b b',
        Just Refl <- testEquality c c'
        = AppE (transformComps e1) (mkComp e2 x)
    mkComp e x = AppE (transformComps e) x
--------------------------------------------------------------------------------
