-- Toy C#-ish compiler: CSharpTypeCheck.hs (stage 2, type checking).

{-# LANGUAGE TemplateHaskell, FlexibleInstances #-}
module CSharpTypeCheck (
    typeAlgebra
) where

import qualified Data.Map as M
import qualified Data.List as L
import Data.Maybe
import Control.Lens
import Control.Monad.State
import Control.Applicative
import qualified SymbolTable as ST
import CSharpLex
import CSharpGram
import CSharpAlgebra
import Envs

-- | Per-method state threaded through the stage-2 checker.
--
-- Lenses 'cName', 'fName', 'fType' and 'locals' are generated by the
-- Template Haskell splice that follows.
data MethodEnv = MethodEnv
    { _cName  :: String            -- ^ name of the enclosing class
    , _fName  :: String            -- ^ name of the method being checked (lambdas reuse their enclosing method's name)
    , _fType  :: Type              -- ^ type of the function being checked (built as a 'TypeFunc' in 'fFuncCommon')
    , _locals :: M.Map String Type -- ^ locals in scope (parameters included), keyed by name
    } deriving Show

makeLenses ''MethodEnv

-- Carrier types of 'typeAlgebra': what each syntactic category folds into.
type ProgramFT = Program
type ClassFT = Class
type DeclFT = Decl
-- A member still needs the name of its enclosing class (supplied by
-- 'fClas'); 'Nothing' means the member contributes nothing to the output.
type MemberFT = String -> Maybe Member
-- Statements are checked while threading the per-method environment.
type StatementFT = State MethodEnv Stat
-- Expressions are told whether a value or an address is wanted and yield
-- the typed expression together with its type.
type ExpressionFT = ValueOrAddress -> State MethodEnv (Expr, Type)

-- | Abort with a message tagged as originating in stage 2.
s2error :: String -> a
s2error = error . (++ " in stage 2")
-- | Stage-2 algebra: type-checks every method body and produces the typed
-- tree ('ClassS', 'MemberMT' and the typed statement/expression nodes).
--
-- The 'ProgramEnv' provides the class table used to resolve member
-- variables, methods, @new@ expressions and subclass relations.  Slots for
-- constructs that must no longer appear at this stage are filled with
-- 's2error', so reaching one aborts with a descriptive message.
typeAlgebra :: ProgramEnv
            -> CSharpAlgebra ProgramFT ClassFT MemberFT StatementFT ExpressionFT DeclFT
typeAlgebra penv =
    ( Program
    , ( s2error "not simplified class"
      , fClas
      )
    , ( fMembDecl
      , fMembStaticDecl
      , fMembMeth
      , fMembStaticMeth
      , s2error "simplified statement"
      )
    , ( fStatDecl
      , fStatExpr
      , s2error "void expr statement"
      , fStatIf
      , fStatWhile
      , fStatReturn
      , fStatBlock
      , fStatDelete
      , fStatDeclInit
      )
    , ( ( fExprCon
        , fExprVar
        , fExprOp
        , fExprCall
        , fExprNew
        , fExprMember
        , fExprSMember
        , fExprSCall
        , fExprMCall
        , fExprLCall
        , fExprLambda
        )
      , s2error "typed expression"
      )
    , ( Decl
      , s2error "simplified declaration"
      )
    )
    where
    fClas :: String -> [MemberFT] -> ClassFT
    -- Each member is a function still waiting for the enclosing class name:
    -- feed it in and keep only the members that produced something.
    fClas cn ms = ClassS cn (mapMaybe ($ cn) ms)

    -- Member-variable declarations (instance or static) emit no member here.
    fMembDecl :: DeclFT -> MemberFT
    fMembDecl _ = const Nothing

    fMembStaticDecl :: DeclFT -> MemberFT
    fMembStaticDecl _ = const Nothing

    fMembMeth :: Type -> Token -> [DeclFT] -> StatementFT -> MemberFT
    -- An instance method is checked as a static method whose first parameter
    -- is an explicit @this@ of the enclosing class's type.
    fMembMeth rt nm ps s cn = fMembStaticMeth rt nm (self : ps) s cn
        where self = Decl (objT cn) (Id "this")

    -- Shared checker for method and lambda bodies.
    --
    -- Registers the parameters @ps@ as locals on top of @m@ (empty for
    -- methods, the captured variables for lambdas), runs the body @s@ with
    -- 'fType' set to @TypeFunc rt <parameter types>@, and passes the
    -- function name, the parameter names and the checked body to @a@.
    fFuncCommon :: Type -> String -> String -> [DeclFT] -> M.Map String Type -> StatementFT -> (String -> [String] -> Stat -> a) -> a
    fFuncCommon rt cn fn ps m s a =
        let body = mapM_ addLocal ps >> s
            dtot (Decl t (Id x)) = (t, x)
            (ats, ans) = unzip . map dtot $ ps
            menv = MethodEnv cn fn (TypeFunc rt ats) m
            (stat, menv') = runState body menv
        -- The checks (match, assert, ...) produce no value and exist only in
        -- the lazy State thread, so they never run unless the final
        -- environment is demanded.  Forcing it to WHNF here drives them.
        in menv' `seq` a fn ans stat

    fMembStaticMeth :: Type -> Token -> [DeclFT] -> StatementFT -> MemberFT
    -- A static method starts with no locals besides its own parameters.
    fMembStaticMeth rt (Id fn) ps s cn = Just member
        where member = fFuncCommon rt cn fn ps M.empty s MemberMT

    fStatDecl :: DeclFT -> StatementFT
    -- @var@ needs an initialiser to infer a type from, so it is rejected here.
    fStatDecl (Decl TypeVar (Id s)) =
        error $ "variable " ++ s ++ " declared var but not initialised"
    fStatDecl d@(Decl _ (Id s)) = do
        addLocal d
        return (StatDecl (DeclT s))

    fStatDeclInit :: DeclFT -> ExpressionFT -> StatementFT
    -- The initialiser is checked first.  A @var@ local takes the
    -- initialiser's type; an explicitly typed one must accept it.
    fStatDeclInit decl@(Decl declT (Id name)) expr = do
        (e, t) <- expr Value
        case declT of
            TypeVar -> addLocal (Decl t (Id name))
            _       -> match declT t >> addLocal decl
        return (StatDeclInit (DeclT name) e)

    fStatExpr :: ExpressionFT -> StatementFT
    -- Void expressions get their own statement node.
    fStatExpr expr = fmap toStat (expr Value)
      where
        toStat (e, t)
            | t == TypeVoid = StatVoidExpr e
            | otherwise     = StatExpr e

    fStatIf :: ExpressionFT -> StatementFT -> StatementFT -> StatementFT
    -- The condition must be a bool; each branch runs with the surrounding
    -- environment restored afterwards.
    fStatIf expr s1 s2 = do
        (cond, condT) <- expr Value
        match (primT "bool") condT
        thenS <- ignore s1
        elseS <- ignore s2
        return (StatIf cond thenS elseS)

    fStatWhile :: ExpressionFT -> StatementFT -> StatementFT
    -- Like 'fStatIf' with a single body: bool condition, scoped body.
    fStatWhile expr s1 = do
        (cond, condT) <- expr Value
        match (primT "bool") condT
        fmap (StatWhile cond) (ignore s1)

    fStatReturn :: ExpressionFT -> StatementFT
    -- The returned value must convert to the current function's return type.
    fStatReturn expr = do
        (value, valueT) <- expr Value
        TypeFunc retT _ <- use fType
        match retT valueT
        pure (StatReturn value)

    fStatBlock :: [StatementFT] -> StatementFT
    -- Statements run in order, so locals added by earlier ones are visible
    -- to later ones.
    fStatBlock = fmap StatBlock . sequence

    fStatDelete :: ExpressionFT -> StatementFT
    -- Only types that @null@ converts to (object types, or null itself) may
    -- be deleted.
    fStatDelete expr = do
        (target, targetT) <- expr Value
        match targetT (primT "null")
        pure (StatDelete target)

    fExprCon :: Token -> ExpressionFT
    -- Literals can only be read as values; each maps to its typed literal
    -- node and primitive type.  Any other token is rejected explicitly
    -- instead of failing with a non-exhaustive case.
    fExprCon _ Address = addrError "literal"
    fExprCon tok Value = return $ case tok of
        ConstInt n   -> (ExprIntLT n, primT "int")
        ConstBool b  -> (ExprBoolLT b, primT "bool")
        ConstChar ch -> (ExprCharLT ch, primT "char")
        KeyNull      -> (ExprNullLT, primT "null")
        _            -> error "unexpected token in literal expression"

    fExprVar :: Token -> ExpressionFT
    -- A local may be used as a value or an address; the mode is kept on the node.
    fExprVar (Id x) va = do
        scope <- use locals
        return (ExprVarT x va, M.findWithDefault (error ("unknown local " ++ x)) x scope)

    -- Binary operators.  Assignment needs an address on its left and a
    -- convertible value on its right.  Every other operator reads both sides
    -- as values and takes its result type from 'opTypes'.  (Original note:
    -- everything is promoted to int pretty much immediately; presumably this
    -- refers to code generation, which is not visible here.)
    fExprOp :: Token -> ExpressionFT -> ExpressionFT -> ExpressionFT
    fExprOp _ _ _ Address = addrError "binary operator expression"
    fExprOp (Operator op) lhs rhs Value
        | op == "=" = do
            (target, targetT) <- lhs Address
            (value, valueT) <- rhs Value
            match targetT valueT
            return (ExprOperT "=" target value, targetT)
        | otherwise = do
            (l, lT) <- lhs Value
            (r, rT) <- rhs Value
            return (ExprOperT op l r, opTypes op lT rT)

    fExprCall :: Token -> [ExpressionFT] -> ExpressionFT
    fExprCall _ _ Address = addrError "static method call"
    -- Built-in print: any number of non-void arguments, no result.
    fExprCall (Id "print") xs Value = do
        (aes, ats) <- unzip <$> traverse ($ Value) xs
        assert (TypeVoid `notElem` ats) $ "all arguments of print must not be void, got " ++ show ats
        return (ExprSCallT "print" aes False, TypeVoid)
    -- Any other unqualified call is a static call on the enclosing class.
    fExprCall fn xs va = use cName >>= \cn -> fExprSCall (TypeObj (Id cn)) fn xs va

    fExprNew :: Type -> [ExpressionFT] -> ExpressionFT
    -- Allocation of a known class.  The node records the size of the class's
    -- member-variable table and the class name.  Non-class types get an
    -- explicit error instead of a non-exhaustive-pattern crash.
    fExprNew _ _ Address = addrError "new expression"
    fExprNew t@(TypeObj (Id cn)) xs Value = do
        assert (null xs) "parametrised new not implemented yet"
        let clt = error ("Error in new expression: no such class " ++ cn ++ ".")
                    `fromMaybe` M.lookup cn (penv^.classes)
            nVars = ST.size $ clt^.mvars
        return (ExprNewT nVars cn [], t)
    fExprNew _ _ Value = error "new expression of non-class type"

    fExprMember :: ExpressionFT -> Token -> ExpressionFT
    -- Instance field access: evaluate the object, then look the field up in
    -- its class's member-variable table.  The class must exist, because an
    -- instance of it does.
    fExprMember expr (Id mv) va = do
        (obj, TypeObj (Id cn)) <- expr Value
        let memberVars = (penv^.classes) M.! cn ^. mvars
            notFound = error ("member " ++ cn ++ "::" ++ mv ++ " not found")
            (tp, off) = fromMaybe notFound (ST.lookup mv memberVars)
        return (ExprMemberT obj off va, tp)

    fExprSMember :: Type -> Token -> ExpressionFT
    -- Static field access: look the field up in the class's static table.
    fExprSMember (TypeObj (Id cn)) (Id s) va = return (ExprSMemberT cn off va, tp)
      where
        cls = fromMaybe (error ("class " ++ cn ++ " not found")) (M.lookup cn (penv^.classes))
        (tp, off) = fromMaybe (error ("static member " ++ cn ++ "::" ++ s ++ " not found")) (ST.lookup s (cls^.svars))
    fExprSMember _ _ _ = error "static member of non-class type accessed"

    fExprSCall :: Type -> Token -> [ExpressionFT] -> ExpressionFT
    -- Static call: the method must exist, carry no per-instance entry
    -- (a 'Just' there means it is an instance method), and accept the
    -- argument types.
    fExprSCall (TypeObj (Id cn)) (Id fn) xs Value = do
        (args, argTs) <- unzip <$> traverse ($ Value) xs
        let cls = fromMaybe (error ("no such class " ++ cn)) (M.lookup cn (penv^.classes))
            (TypeFunc rt paramTs, instanceData) =
                fromMaybe (funcError cn fn "not found") (M.lookup fn (cls^.mfuns))
        maybe (return ()) (const (funcError cn fn "is not static")) instanceData
        matchl paramTs argTs
        return (ExprSCallT (cn ++ "_" ++ fn) args (rt /= TypeVoid), rt)
    fExprSCall _ _ _ Address = addrError "static member call"
    fExprSCall _ _ _ _ = error "static member function of non-class type called"

    fExprMCall :: ExpressionFT -> Token -> [ExpressionFT] -> ExpressionFT
    -- Instance call: the method's per-instance entry supplies the offset
    -- stored on the node.  The receiver is checked against the leading
    -- @this@ parameter.
    fExprMCall expr (Id fn) args Value = do
        (obj, objT@(TypeObj (Id cn))) <- expr Value
        (argEs, argTs) <- unzip <$> traverse ($ Value) args
        let methods = (penv^.classes) M.! cn ^. mfuns
            (TypeFunc rt paramTs, dispatch) = fromMaybe (funcError cn fn "not found") (M.lookup fn methods)
            (_, ofs) = fromMaybe (funcError cn fn "is static") dispatch
        matchl paramTs (objT : argTs)
        return (ExprMCallT obj ofs argEs (rt /= TypeVoid), rt)
    fExprMCall _ _ _ Address = addrError "non-static member call"

    fExprLCall :: ExpressionFT -> [ExpressionFT] -> ExpressionFT
    fExprLCall _ _ Address = addrError "lambda call"
    -- Calling a function-typed value: the arguments must fit its parameters.
    fExprLCall expr args Value = do
        (callee, TypeFunc rt paramTs) <- expr Value
        (argEs, argTs) <- unzip <$> traverse ($ Value) args
        matchl paramTs argTs
        pure (ExprLCallT callee argEs (rt /= TypeVoid), rt)

    fExprLambda :: [Token] -> [DeclFT] -> Type -> StatementFT -> ExpressionFT
    fExprLambda _ _ _ _ Address = addrError "lambda expression"
    -- The body is checked like a method body of the enclosing class and
    -- method.  It sees only its parameters and the captured locals, which
    -- keep their types from the current scope.
    fExprLambda cs ps rt s Value = do
        cn <- use cName
        fn <- use fName
        scope <- use locals
        let captureName (Id x) = x
            declType (Decl pt _) = pt
            names = map captureName cs
            captured = M.fromList
                [ (x, M.findWithDefault (error ("undefined local " ++ x)) x scope) | x <- names ]
            lamT = TypeFunc rt (map declType ps)
        return (fFuncCommon rt cn fn ps captured s (\_ -> ExprLambdaT names), lamT)

    -- `convertsTo from to`: may a value of type `from` be used where `to` is
    -- expected?  null converts to any object type.  A class converts to
    -- itself and to every class in its 'parents' list.  Function types are
    -- covariant in the result and contravariant in the parameters.  Anything
    -- else must match exactly.
    convertsTo (TypePrim (StdType "null")) (TypeObj _) = True
    convertsTo (TypeObj (Id x)) (TypeObj (Id y))
        -- Reflexivity is explicit, so it does not depend on whether
        -- 'parents' lists the class itself.
        | x == y    = True
        | otherwise = y `elem` classParents x
      where
        -- Descriptive failure instead of the opaque `M.!` error.
        classParents c = maybe (error ("no such class " ++ c)) (^. parents) (M.lookup c (penv^.classes))
    convertsTo (TypeFunc xrt xs) (TypeFunc yrt ys) = xrt `convertsTo` yrt && ys `convertsTol` xs
    convertsTo xs ys = xs == ys

    -- Monadic check: a no-op when the condition holds, otherwise abort with the message.
    assert p e = if p then return () else error e

    -- Every class named in a type (including inside function types) must exist.
    assertValidType t = case t of
        TypeFunc rt ps -> assertValidType rt >> mapM_ assertValidType ps
        TypeObj (Id x) -> assert (M.member x $ penv^.classes) $ "No such class " ++ x
        _              -> return ()

    -- `match expected actual` fails unless `actual` converts to `expected`.
    match expected actual =
        assert (actual `convertsTo` expected) $ "type mismatch: " ++ show expected ++ " -/-> " ++ show actual
    -- Argument-list version; the lists must also have equal length.
    matchl expected actual =
        assert (actual `convertsTol` expected) $ "type mismatch in arguments: " ++ show expected ++ " -/-> " ++ show actual
    convertsTol froms tos = length froms == length tos && all (uncurry convertsTo) (zip froms tos)

    -- Abort helpers for failed method lookups and for address misuse.
    funcError cn mf msg = error ("function " ++ cn ++ "::" ++ mf ++ " " ++ msg)
    addrError what = error ("trying to take address of " ++ what)

    -- Runs a nested statement (an if/while branch) for its result and its
    -- checks, then restores the environment that was current before it, so
    -- locals declared inside the branch do not leak out.
    --
    -- Control.Monad.State is the lazy State monad.  Failed checks ('match',
    -- 'assert', failed pattern binds) exist only in the state thread.
    -- Overwriting the branch's state with 'put' without ever demanding it
    -- silently dropped every type error inside the branch.  Forcing that
    -- state (to WHNF) as part of the restoring step makes those errors
    -- surface whenever the outer state is demanded.
    ignore :: StatementFT -> StatementFT
    ignore s = do
        env <- get
        result <- s
        inner <- get
        inner `seq` put env
        return result

    -- Declares (or overwrites) a local after checking that its type names
    -- only existing classes.
    addLocal :: DeclFT -> State MethodEnv ()
    addLocal (Decl t (Id x)) = do
        assertValidType t
        locals %= M.insert x t

    -- Shorthands for built-in primitive types and class types, by name.
    primT name = TypePrim (StdType name)
    objT name = TypeObj (Id name)

    -- Result type of a non-assignment binary operator.  @==@ and @!=@ accept
    -- any two types where one converts to the other and yield bool.  Every
    -- other operator must appear in 'opTable' with exactly matching operand
    -- types.
    opTypes o l r
        | o `elem` ["==", "!="] =
            if l `convertsTo` r || r `convertsTo` l
                then primT "bool"
                else error (show l ++ " <-/-> " ++ show r)
        | otherwise =
            fromMaybe (error (o ++ " does not accept " ++ show l ++ " and " ++ show r))
                      (M.lookup (o, l, r) opTable)

    -- (operator, left operand, right operand) -> result, all primitive types.
    -- Built once per 'typeAlgebra' call rather than on every lookup.
    -- Equality operators are handled in 'opTypes' and need no rows here.
    opTable = M.fromList
        [ ((op, primT lt, primT rt), primT res)
        | (op, lt, rt, res) <-
            [ ("&&", "bool", "bool", "bool")
            , ("||", "bool", "bool", "bool")
            , ("+",  "int",  "int",  "int")
            , ("-",  "int",  "int",  "int")
            , ("*",  "int",  "int",  "int")
            , ("/",  "int",  "int",  "int")
            , ("%",  "int",  "int",  "int")
            , ("<=", "int",  "int",  "bool")
            , ("<",  "int",  "int",  "bool")
            , (">=", "int",  "int",  "bool")
            , (">",  "int",  "int",  "bool")
            ]
        ]
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.