-- Source
--
-- Toy C#-ish compiler / CSharpCode.hs
--
-- Full commit
{-# LANGUAGE TemplateHaskell, FlexibleInstances #-}
module CSharpCode (
    codeAlgebra
) where

import Prelude hiding (LT, GT, EQ)
import qualified Data.Map as M
import qualified Data.List as L
import Data.Char
import Data.Maybe
import Data.Function
import Control.Lens
import Control.Monad.State
import Control.Applicative
import qualified SymbolTable as ST
import CSharpGram
import CSharpAlgebra
import SSM
import Envs

-- | Start-up code emitted in front of the compiled classes.
--
-- @heapSize@ is the size used when initialising the heap and @mainLabel@ is
-- the label the program branches to once initialisation is done.  The
-- instruction sequence is split into named phases; concatenated they are the
-- complete bootstrap routine, ending in 'HALT' after the main routine returns.
staticInit :: Int -> String -> Code
staticInit heapSize mainLabel =
    concat [stashHeapPointer, initialiseHeap, relocateStack, enterProgram]
  where
    -- Put the stack pointer somewhere safe.  That'll be our heap.  We'll be
    -- messing around with this pointer quite a bit, so it is kept in
    -- register R7 for the time being.
    stashHeapPointer =
        [ LDR SP
        , LDC 1
        , ADD
        , LDS 0
        , STR R7
        ]

    -- With the pointer stashed away, the heap can be initialised: set the
    -- top to point to the second element, then set the heap size and the
    -- second element.
    initialiseHeap =
        [ LDC 2
        , ADD
        , LDC (heapSize - 2)
        , LDS 0
        , LDR R7
        , SWP
        , LDR R7
        , ADD
        ]

    -- Correct the stack pointer to make sure the distance is precise.
    relocateStack =
        [ LDC 3
        , ADD
        , STR SP
        ]

    -- The heap should be fine now, so we can branch to the user's code.  The
    -- mark pointer is set first; this doesn't seem strictly necessary, but it
    -- should not be anywhere near the bottom of the heap while the program
    -- is running.
    enterProgram =
        [ LDRR MP SP
        , BSR mainLabel
        , HALT
        ]

-- | Conditionally inject a value into a 'MonadPlus': @p `thenReturn` x@ is
-- @return x@ when @p@ holds and 'mzero' otherwise.  The low precedence lets a
-- comparison be written on the left without parentheses.
infixr 1 `thenReturn`
thenReturn :: MonadPlus m => Bool -> a -> m a
thenReturn cond x
    | cond = return x
    | otherwise = mzero

-- | Join the given name components with underscores, e.g.
-- @uJoin ["Foo", "bar", "head"] == "Foo_bar_head"@.
--
-- An empty list yields the empty string (the previous hand-rolled fold used
-- 'head'/'tail' and crashed on it).
uJoin :: [String] -> String
uJoin = L.intercalate "_"

-- | Produce the vtable entry labels for a member-function map.
--
-- Only members carrying a @Just (name, index)@ annotation take part.  The
-- entries are ordered by that index (ties keep ascending key order) and each
-- becomes the label @name_key_head@, the same shape 'fMembMethT' gives its
-- method head labels.
mkVtableEntries :: M.Map String (a, Maybe (String, Int)) -> [String]
mkVtableEntries members =
    [ uJoin [cls, meth, "head"] | (cls, meth, _) <- L.sortBy (compare `on` slotOf) entries ]
  where
    -- Members without a slot annotation are skipped by the failing pattern.
    entries = [ (cls, meth, slot) | (meth, (_, Just (cls, slot))) <- M.toAscList members ]
    slotOf (_, _, slot) = slot

-- Seeing as this thing was becoming rather a mess, I decided to split the fold
-- out into multiple steps.  They are as follows:
--  * Build up the global namespaces and register all class members.
--  * Resolve all the expression types.
--  * Put together the code.
--
-- This split means that in the first and second stages all expressions we have
-- are unannotated, while in the third they are all annotated.

-- State threaded through code generation of a single method (or lambda) body.
--
--  * _cName / _fName: class and function name; used by formatLabel as the
--    prefix of every generated label.
--  * _locals: maps each known variable name to the offset used with
--    LDL / LDLA / STL.
--  * _nextVar: the offset addLocal hands out to the next declared variable.
--  * _nextLabel: counter formatLabel uses to number generated labels.
data MethodEnv = MethodEnv { _cName :: String
                           , _fName :: String
                           , _locals :: M.Map String Int
                           , _nextVar :: Int
                           , _nextLabel :: Int }
                           deriving Show

-- Generates the lenses cName, fName, locals, nextVar and nextLabel.
$(makeLenses ''MethodEnv)

-- Result types of the fold, one per syntactic category.

-- A program still needs the heap size before its code can be produced.
type ProgramFT = Int -> Code
-- A class folds directly to its code.
type ClassFT = Code
-- A member needs the name of its enclosing class to build its labels.
type MemberFT = String -> Code
-- Statements and expressions produce code while updating the method state.
type StatementFT = State MethodEnv Code
type ExpressionFT = State MethodEnv Code
-- A (stripped) declaration is just the declared name.
type DeclFT = String

-- | Abort with a message for a fold case that must not be reached in the
-- code generation stage (stage 3).
s3error :: String -> a
s3error x = error $ x ++ " in stage 3"

-- | The code generation algebra.
--
-- Only the cases for the simplified, type-annotated tree are implemented;
-- every other case is an 's3error' that fires if it is ever evaluated.
--
-- We don't strictly need the whole program environment (just the member
-- functions of all classes would do), but this is more convenient to pass
-- around.
codeAlgebra :: ProgramEnv ->
               CSharpAlgebra ProgramFT
                             ClassFT
                             MemberFT
                             StatementFT
                             ExpressionFT
                             DeclFT
codeAlgebra penv = ( fProgram
                   , (s3error "not simplified class", fClas)
                   , let e = s3error "not simplified member" in (e, e, e, e, fMembMethT)
                   , (fStatDecl,fStatExpr,fStatVoidExpr,fStatIf,fStatWhile,fStatReturn,fStatBlock,fStatDelete,fStatDeclInit)
                   , (s3error "untyped expression",
                     (fExprIntLT,fExprBoolLT,fExprCharLT,fExprNullLT,fExprVarT,fExprOpT,fExprNewT,fExprMemberT,fExprSMemberT,fExprSCallT,fExprMCallT,fExprLCallT,fExprLambdaT)
                     )
                   , (s3error "unstripped decl", fDeclStripped)
                   )
    where
    -- | Assemble the whole program: the bootstrap code from 'staticInit',
    -- which branches to the @main@ method of the program's main class,
    -- followed by the code of every class.
    --
    -- Fails with a descriptive error (rather than a bare 'fromJust' failure)
    -- when the program environment records no main class.
    fProgram :: [ClassFT] -> ProgramFT
    fProgram cs heapSize =
        let label = fromMaybe (s3error "no main class") $ penv^.mainClass
            label' = label ++ "_main_head"
        in staticInit heapSize label' ++ concat cs

    -- | Lay out a class as its vtable, then its static storage, then the
    -- code of its members.
    fClas :: String -> [MemberFT] -> ClassFT
    fClas cn ms = vtable ++ statics ++ code
      where
        cls = (penv^.classes) M.! cn
        -- With TemplateHaskell enabled @($cn)@ is read as a splice, so the
        -- section needs a space after the dollar sign.
        code = concatMap ($ cn) ms
        -- One branch instruction per vtable entry.
        vtable = LABEL (uJoin [cn, "vtable"]) : map BRA (mkVtableEntries $ cls^.mfuns)
        -- We abuse the fact that all types are interchangeable at the
        -- hardware level, so we don't actually have to lay out the static
        -- members in any particular order.
        statics = LABEL (uJoin [cn, "statics"]) : replicate (ST.size $ cls^.svars) NOP

    -- | Generate a method: a head label @class_method_head@ and @LINK 0@,
    -- the body, and a trailing @UNLINK@ / @RET@.
    --
    -- Parameters are registered as locals starting at offset @-argc - 1@;
    -- the counter is then advanced by two before the body is generated.
    fMembMethT :: String -> [DeclFT] -> StatementFT -> MemberFT
    fMembMethT fn ps s cn = prologue ++ code ++ epilogue
      where
        argc = length ps
        initial = MethodEnv cn fn M.empty (-argc - 1) 0
        body = do
            mapM_ addLocal ps
            nextVar += 2
            s
        code = evalState body initial
        prologue = [LABEL $ uJoin [cn, fn, "head"], LINK 0]
        epilogue = [UNLINK, RET]

    -- | A bare declaration only registers the variable; it emits no code.
    fStatDecl :: DeclFT -> StatementFT
    fStatDecl t = addLocal t >> return []

    -- | Declaration with initialiser: register the variable, evaluate the
    -- expression and store the result in the slot just handed out.
    fStatDeclInit :: DeclFT -> ExpressionFT -> StatementFT
    fStatDeclInit t expr = do
        addLocal t
        -- addLocal already advanced the counter, so the new slot is one back.
        slot <- subtract 1 <$> use nextVar
        code <- expr
        return $ code ++ [STL slot]

    -- | Expression statement: evaluate the expression, then emit @pop@ to
    -- discard its result.
    fStatExpr :: ExpressionFT -> StatementFT
    fStatExpr expr = (++ [pop]) <$> expr

    -- | Expression statement whose code is used as is; unlike 'fStatExpr',
    -- nothing is appended.
    fStatVoidExpr :: ExpressionFT -> StatementFT
    fStatVoidExpr = id

    -- | Conditional: condition, @BRF@ to the else label, then-branch, @BRA@
    -- to the end label, else-branch.
    --
    -- Both branches run through 'ignoreAllButLabels', so apart from the
    -- label counter their changes to the method state are dropped.
    fStatIf :: ExpressionFT -> StatementFT -> StatementFT -> StatementFT
    fStatIf expr s1 s2 = do
        thenCode <- ignoreAllButLabels s1
        elseCode <- ignoreAllButLabels s2
        elseLbl <- formatLabel
        endLbl <- formatLabel
        condCode <- expr
        return $ condCode
            ++ [BRF elseLbl]
            ++ thenCode
            ++ [BRA endLbl, LABEL elseLbl]
            ++ elseCode
            ++ [LABEL endLbl]

    -- | Loop with the test placed after the body: jump to the test first,
    -- and branch back to the body label while the test holds.
    fStatWhile :: ExpressionFT -> StatementFT -> StatementFT
    fStatWhile expr s1 = do
        bodyCode <- ignoreAllButLabels s1
        bodyLbl <- formatLabel
        testLbl <- formatLabel
        testCode <- expr
        return $ [BRA testLbl, LABEL bodyLbl]
            ++ bodyCode
            ++ [LABEL testLbl]
            ++ testCode
            ++ [BRT bodyLbl]

    -- | Return: evaluate the expression, store the result in R3 (where the
    -- call sites fetch it from), then unlink and return.
    fStatReturn :: ExpressionFT -> StatementFT
    fStatReturn expr = (++ [STR R3, UNLINK, RET]) <$> expr

    -- | Block: generate the statements in order and wrap them in a pair of
    -- stack adjustments sized by how far the local-variable counter moved
    -- while generating them.
    fStatBlock :: [StatementFT] -> StatementFT
    fStatBlock ss = do
        before <- use nextVar
        body <- concat <$> sequence ss
        after <- use nextVar
        let frame = after - before
        return $ AJS frame : body ++ [AJS (negate frame)]

    -- | Delete: evaluate the expression and call the deallocation routine
    -- at label @__stdlib_deallocate_head@.
    fStatDelete :: ExpressionFT -> StatementFT
    fStatDelete expr = (++ [BSR "__stdlib_deallocate_head"]) <$> expr

    -- | Integer literal: load the constant.
    fExprIntLT :: Int -> ExpressionFT
    fExprIntLT = return . (: []) . LDC

    -- | Boolean literal: true is loaded as -1, false as 0.
    fExprBoolLT :: Bool -> ExpressionFT
    fExprBoolLT b = return [LDC (if b then -1 else 0)]

    -- | Character literal: load the character's code point.
    fExprCharLT :: Char -> ExpressionFT
    fExprCharLT = return . (: []) . LDC . ord

    -- | Null literal: loaded as the constant 0.
    fExprNullLT :: ExpressionFT
    fExprNullLT = pure [LDC 0]

    -- | Local variable access: load either the value ('LDL') or the address
    -- ('LDLA') of the variable's slot, as selected by the second argument.
    --
    -- An unknown variable is reported by name, in line with the error
    -- 'fExprLambdaT' raises for missing captures, instead of the generic
    -- failure of 'M.!'.
    fExprVarT :: String -> ValueOrAddress -> ExpressionFT
    fExprVarT x va = do
        vs <- use locals
        let load Value = LDL
            load Address = LDLA
            slot = fromMaybe (s3error $ "cannot find local variable " ++ x) $ M.lookup x vs
        return [load va slot]

    -- | Binary operators.
    --
    -- For the sake of simplicity, we promote everything to int pretty much
    -- immediately.  Assignment evaluates the right-hand side, duplicates it
    -- and stores through the left-hand side; @&&@ and @||@ short-circuit
    -- with generated labels; every other operator is looked up in 'opCodes'.
    fExprOpT :: String -> ExpressionFT -> ExpressionFT -> ExpressionFT
    fExprOpT op lhs rhs = do
        l <- lhs
        r <- rhs
        case op of
            "="  -> return $ concat [r, [LDS 0], l, [STA 0]]
            "&&" -> shortCircuit l r BRF 0
            "||" -> shortCircuit l r BRT (-1)
            _    -> return $ l ++ r ++ [opCodes M.! op]
      where
        -- Skip the right operand via @branch@ and load @fallback@ as the
        -- result when the left operand already decides the outcome.
        shortCircuit l r branch fallback = do
            skip <- formatLabel
            done <- formatLabel
            return $ l ++ [branch skip] ++ r ++ [BRA done, LABEL skip, LDC fallback, LABEL done]

    -- | Object creation: call the allocation routine with the given size,
    -- then store the address of the class's vtable label through the
    -- resulting pointer.  The argument expressions are not used.
    fExprNewT :: Int -> String -> [ExpressionFT] -> ExpressionFT
    fExprNewT sz cn _ = return $ allocate ++ installVtable
      where
        allocate = [LDC sz, BSR "__stdlib_allocate_head"]
        installVtable = [Ldc $ uJoin [cn, "vtable"], LDS (-1), STA 0]

    -- | Instance member access: add the member offset to the object
    -- expression, dereferencing the result only when a value is requested.
    fExprMemberT :: ExpressionFT -> Int -> ValueOrAddress -> ExpressionFT
    fExprMemberT expr off va = do
        base <- expr
        return $ base ++ [LDC off, ADD] ++ [LDA 0 | va == Value]

    -- | Static member access: add the member offset to the class's
    -- @statics@ label, dereferencing only when a value is requested.
    fExprSMemberT :: String -> Int -> ValueOrAddress -> ExpressionFT
    fExprSMemberT cn off va = return $ address ++ [LDA 0 | va == Value]
      where
        address = [Ldc $ uJoin [cn, "statics"], LDC off, ADD]

    -- | Static call.
    --
    -- @print@ is built in: the arguments are evaluated in reverse order and
    -- one @TRAP 0@ is emitted per argument.  Any other function is called
    -- through its @name_head@ label, after which the arguments are removed
    -- and, when the flag is set, the result is fetched from R3.
    fExprSCallT :: String -> [ExpressionFT] -> Bool -> ExpressionFT
    fExprSCallT fn xs hv = do
        argCode <- sequence xs
        let argc = length xs
        return $ case fn of
            "print" -> concat (reverse argCode) ++ replicate argc (TRAP 0)
            _ -> concat argCode
                ++ [BSR $ uJoin [fn, "head"], AJS (-argc)]
                ++ [LDR R3 | hv]

    -- | Method call through the vtable: evaluate the receiver and the
    -- arguments, reload the receiver from below the arguments, follow it and
    -- jump @2 * off@ past the loaded address.  Afterwards the arguments and
    -- the receiver are removed and, when the flag is set, the result is
    -- fetched from R3.
    fExprMCallT :: ExpressionFT -> Int -> [ExpressionFT] -> Bool -> ExpressionFT
    fExprMCallT expr off args hv = do
        receiver <- expr
        argCode <- concat <$> sequence args
        let argc = length args
            dispatch = [LDS (-argc), LDA 0, LDC (2 * off), ADD, JSR, AJS (-argc - 1)]
        return $ concat [receiver, argCode, dispatch, [LDR R3 | hv]]

    -- | Call of a lambda value: evaluate the callee and the arguments,
    -- reload the callee from below the arguments, follow it to the code
    -- address and jump there.  Afterwards the arguments and the callee are
    -- removed and, when the flag is set, the result is fetched from R3.
    fExprLCallT :: ExpressionFT -> [ExpressionFT] -> Bool -> ExpressionFT
    fExprLCallT expr args hv = do
        callee <- expr
        argCode <- concat <$> sequence args
        let argc = length args
            invoke = [LDS (-argc), LDA 0, JSR, AJS (-argc - 1)]
        return $ concat [callee, argCode, invoke, [LDR R3 | hv]]

    -- | Lambda expression with captured variables @cs@ and parameters @ps@.
    --
    -- The body is generated in a fresh method environment that shares the
    -- class name, function name and label counter of the enclosing one.  The
    -- emitted code jumps over the lambda's own code, then allocates
    -- @length cs + 1@ cells, stores the lambda's label in the first and the
    -- current values of the captured variables in the rest.  On entry the
    -- lambda copies those cells into its first locals.
    fExprLambdaT :: [String] -> [String] -> StatementFT -> ExpressionFT
    fExprLambdaT cs ps s = do
        outer <- get
        let capc = length cs
            argc = length ps
            inner = MethodEnv (outer^.cName) (outer^.fName) M.empty (-argc - 1) (outer^.nextLabel)
            body = do
                mapM_ addLocal ps
                nextVar += 2
                mapM_ addLocal cs
                s
            (bodyCode, inner') = runState body inner
            slotOf v = fromMaybe (error ("cannot find captured variable " ++ v)) $ M.lookup v (outer^.locals)
            -- Step to the next cell and store one captured variable in it.
            store v = [LDC 1, ADD, LDL $ slotOf v, LDS (-1), STA 0]
            -- Step to the next cell and copy it into local @i@.
            unpack i = [LDC 1, ADD, LDS 0, LDA 0, STL i]
            caps = concatMap store cs
            ucaps = LDL (-argc - 2) : concatMap unpack [1 .. capc] ++ [AJS (-1)]
        -- Labels generated for the body must not be handed out again.
        nextLabel .= inner' ^. nextLabel
        lblBody <- formatLabel
        lblAfter <- formatLabel
        return $ concat
            [ [BRA lblAfter, LABEL lblBody, LINK capc]
            , ucaps
            , bodyCode
            , [UNLINK, RET]
            , [LABEL lblAfter, LDC (capc + 1), BSR "__stdlib_allocate_head", Ldc lblBody, LDS (-1), STA 0]
            , caps
            -- Undo the stepping done while storing the captures.
            , [LDC capc, SUB]
            ]

    -- | A stripped declaration is represented by its name alone.
    fDeclStripped :: String -> DeclFT
    fDeclStripped name = name

    -- | Hand out a fresh label of the form @class_function_n@ and advance
    -- the label counter.
    formatLabel :: State MethodEnv String
    formatLabel = do
        cls <- use cName
        fn <- use fName
        n <- use nextLabel
        nextLabel .= n + 1
        return $ uJoin [cls, fn, show n]

    -- | Run a statement, then roll the method state back to what it was
    -- before, keeping only the advanced label counter.
    ignoreAllButLabels :: StatementFT -> StatementFT
    ignoreAllButLabels s = do
        saved <- get
        code <- s
        lbl <- use nextLabel
        put $ set nextLabel lbl saved
        return code

    -- | Register a variable at the next free offset and advance the counter.
    addLocal :: String -> State MethodEnv ()
    addLocal x = do
        slot <- use nextVar
        modify $ over nextVar (+ 1) . over locals (M.insert x slot)

    -- Run @e@ and append the instructions @xs@ to the code it produces.
    evalAppend e xs = fmap (\code -> code ++ xs) e

-- | Table from binary operator symbols to the instruction implementing them.
--
-- 'fExprOpT' handles @&&@ and @||@ itself before consulting this table, so
-- their entries are not reached from there.
opCodes :: M.Map String Instr
opCodes = M.fromList $ arithmetic ++ comparison ++ logical
  where
    arithmetic =
        [ ("+", ADD), ("-", SUB), ("*", MUL), ("/", DIV), ("%", MOD) ]
    comparison =
        [ ("<=", LE), (">=", GE), ("<", LT), (">", GT), ("==", EQ), ("!=", NE) ]
    logical =
        [ ("&&", AND), ("||", OR), ("^", XOR) ]