Source

son_of_blub / part2 / son_of_blub.ml

Full commit
(*
  Obviously there were a couple of issues with inputting instructions for
  the interpreter to run.  First, you were forced to hand code the data
  structure, and second, you were forced to hand-code the variable bindings
  by indicating the frame and offset indices to find the value of the
  variable.  In this installment, we will automatically do the variable
  lookups and get a tiny bit closer to having a 'nice' input format.

  Before that, however, we will get to the point of all this -- running the
  LLVM JIT compiler.  In this installment the compiler will be the
  bare minimum, but the basics will be in place.  We will pass in a function
  to be compiled and get an llvalue back representing the code to be
  executed, and then we will call that function with some arguments and
  get a result.  This is the basic core that will slowly but surely build
  up into something impressive.

  <look at code here>

  Also add a simple let-statement that is like this:
  (let (x y) expr) ==> ((lambda (x) expr) y)

  ((lambda (a x y) (if a x y)) #t 34 37)

  (((lambda (y) (lambda (a x) (if a x y))) 37) #t 34) -- even C can't do that!
  (let (y 37) ((lambda (a x) (if a x y)) #t 34))
*)

module EE = Llvm_executionengine.ExecutionEngine
module GV = Llvm_executionengine.GenericValue

(* A source-level variable, identified only by its name. *)
type variable = Variable of string

(* Abstract syntax of the expressions understood by both the interpreter
   ([eval]) and the code generator ([compile_fn]). *)
type ast =
    Ast_lit of sval  (* literal: an already-evaluated value *)
  | Ast_ref of variable * int * int  (* variable; frame #; offset # *)
  | Ast_cnd of ast * ast * ast  (* if part; then part; else part *)
  | Ast_app of ast * ast array  (* function expression; argument expressions *)
  | Ast_abs of lambda  (* lambda abstraction; evaluates to a closure *)

(* A lambda expression: its body, its formal parameters and, optionally, the
   LLVM function type to compile it with. *)
and lambda = {
  lam_ast : ast;  (* body *)
  lam_params : variable array;  (* formal parameters, in positional order *)
  (* LLVM function type used when compiling this lambda; [compile_fn] raises
     [Jit_failed] when it is [None].  This is a hack because we don't have a
     type system. *)
  lam_lltype : Llvm.lltype option;
}

(* Runtime values produced by the interpreter. *)
and sval =
    Sclosure of sclosure  (* a function value *)
  | Sllvm of Llvm.lltype * GV.t  (* an LLVM generic value tagged with its type *)
  | Sunbound  (* never constructed in this file; presumably marks a slot
                 that has no value yet -- TODO confirm *)

(* A lambda paired with the environment it was created in. *)
and sclosure = {
  close_env : environment;
  close_lam : lambda;
  (* we could either have this as part of the lambda (compiled without a
     local environment) or as part of a closure (compiled in an environment) *)
  (* Result of [compile_fn] as (return type; compiled function).  It is lazy,
     so compilation is only attempted when the closure is first applied. *)
  close_jitcode : (Llvm.lltype * Llvm.llvalue) Lazy.t
}

(* The values bound by one function application; names and values are
   parallel arrays indexed by the offset of an [Ast_ref]. *)
and environment_frame = {
  env_frame_vars : variable array;  (* mostly for debugging *)
  env_frame_vals : sval array
}

(* A stack of frames.  [apply] pushes the frame of the function being applied
   onto the head of the list, so the innermost frame comes first. *)
and environment = environment_frame list

(* LLVM SECTION *)

(* Raised when a lambda cannot be compiled to LLVM code.  [apply] catches it
   and interprets the lambda's body instead. *)
exception Jit_failed
open Llvm


(* Wrap the OCaml int [x] as an interpreter value of LLVM type [t]. *)
let llvm_val_of_int t x =
  let gv = GV.of_int t x in
  Sllvm (t, gv)
;;

(* Interpreter value for a boolean: an i1 holding 1 for true and 0 for
   false. *)
let make_bool value =
  let bit = if value then 1 else 0 in
  llvm_val_of_int Llvm.i1_type bit
;;

(* Interpreter value for an integer, stored as an i64. *)
let make_int n = llvm_val_of_int Llvm.i64_type n ;;

(* Convert the generic value [v], whose LLVM type is [t], into an LLVM
   constant of that type.

   Only i1 and i64 are supported; for any other type this prints "FAIL" and
   raises [Jit_failed].

   The type has to be tested with [=].  Writing [i1_type] as a pattern in a
   [match] introduces a fresh variable that matches every type, which used
   to make the i64 and failure branches unreachable. *)
let llvalue_of_gv t v =
  if t = i1_type then
    const_int t (GV.as_int v)
  else if t = i64_type then
    (* the final [true] is the "signed" flag of [const_of_int64] *)
    const_of_int64 t (GV.as_int64 v) true
  else begin
    Format.printf "FAIL\n%!";
    raise Jit_failed
  end
;;
 

(* The single LLVM module into which every compiled lambda is defined. *)
let cur_module = create_module "helloworld" ;;

(* Execution engine used by [apply] to run functions from [cur_module]. *)
let jit =
  let provider = ModuleProvider.create cur_module in
  EE.create provider
;;



(* Compile [lambda] into an LLVM function defined in [cur_module].

   [env] is the environment the closure was created in.  In a variable
   reference, frame 0 denotes the parameters of [lambda] itself; frame [n]
   with n > 0 denotes element [n - 1] of [env], whose current value is baked
   into the generated code as a constant.

   Returns the pair (return type, function value).

   Raises [Jit_failed] when the lambda cannot be compiled: it has no
   [lam_lltype], it reads an environment slot that does not hold an LLVM
   value, a literal has an unsupported type, or the body contains a construct
   the code generator does not handle (closure literals, applications, nested
   lambdas).  When that happens after the LLVM function has been created, the
   half-built function is removed from the module before the exception
   propagates. *)
let compile_fn (env:environment) (lambda:lambda) =
  let fn_type = match lambda.lam_lltype with
      Some t -> t | None -> raise Jit_failed
  in
  let rettype = return_type fn_type in

  let cur_fn = define_function "lambda" fn_type cur_module in
  let builder = builder_at_end (entry_block cur_fn) in

  (* Emit code for [ast] at [builder].  Returns the builder at which emission
     must continue (it moves to a new block after a conditional) together
     with the value of the expression. *)
  let rec gen_llvm (builder:llbuilder) (ast:ast) : (llbuilder * llvalue) =
    match ast with
        Ast_lit (Sllvm (t, v)) ->
          (builder, llvalue_of_gv t v)
      | Ast_ref (var, 0, offset) ->
          assert (ExtArray.Array.mem var lambda.lam_params);
          (builder, param cur_fn offset)
      | Ast_ref (_, frame, offset) ->
          (* frame 0 is found in the fn params, not in the environment *)
          let v = match (List.nth env (frame-1)).env_frame_vals.(offset) with
              Sllvm (t, v) -> llvalue_of_gv t v
            | Sclosure _ | Sunbound -> raise Jit_failed
          in
          (builder, v)
      | Ast_cnd (pred, cons, alt) ->
          let builder, pred_val = gen_llvm builder pred in
          let test = build_icmp Icmp.Ne pred_val (const_int i1_type 0)
            "test" builder in

          let cons_block = append_block "true_branch" cur_fn in
          let alt_block = append_block "false_branch" cur_fn in
          let join_block = append_block "join_branches" cur_fn in
          ignore (build_cond_br test cons_block alt_block builder);

          let cons_builder = builder_at_end cons_block in
          let cons_builder, v1 = gen_llvm cons_builder cons in
          ignore (build_br join_block cons_builder);

          let alt_builder = builder_at_end alt_block in
          let alt_builder, v2 = gen_llvm alt_builder alt in
          ignore (build_br join_block alt_builder);

          (* A branch that itself contains a conditional finishes in that
             inner conditional's join block, not in [cons_block] or
             [alt_block].  The phi must name the blocks that actually jump
             to [join_block], so ask the builders where they ended up. *)
          let cons_end = insertion_block cons_builder in
          let alt_end = insertion_block alt_builder in

          let join_builder = builder_at_end join_block in
          let return = build_phi [(v1, cons_end); (v2, alt_end)]
            "phi" join_builder in
          (join_builder, return)
      | Ast_lit (Sclosure _ | Sunbound) | Ast_app _ | Ast_abs _ ->
          (* not supported by the code generator: let the caller fall back
             to the interpreter instead of dying with a Match_failure *)
          raise Jit_failed
  in

  try
    let builder, retval = gen_llvm builder lambda.lam_ast in
    ignore (build_ret retval builder);
    (rettype, cur_fn)
  with Jit_failed ->
    (* do not leave an unterminated function behind in the module *)
    delete_function cur_fn;
    raise Jit_failed
;;


(* INTERPRETER SECTION *)

(* [eval env ast] interprets [ast] in the environment [env] and returns its
   value.

   A variable reference [(frame, offset)] reads slot [offset] of element
   [frame] of [env]; frame 0 is the innermost frame.  A lambda evaluates to a
   closure over [env] whose JIT compilation is deferred until it is first
   applied.

   [apply fn args] calls the closure [fn] on the already-evaluated [args].
   It first tries to run the closure's compiled code.  If the closure cannot
   be compiled, or an argument is not an LLVM value that can be handed to the
   compiled code, it interprets the body in the closure's environment
   extended with a new frame holding [args].
   Raises [Invalid_argument] if [fn] is not a closure. *)
let rec eval env = function
    Ast_lit lit -> lit
  | Ast_ref (_, frame, offset) -> (List.nth env frame).env_frame_vals.(offset)
  | Ast_cnd (pred, cons, alt) -> begin
      (* we will just treat it as an i1_type because in general we won't
         be converting the Llvm values into svals *)
      (* also note that this is Scheme-style semantics, where everything
         other than #f is considered #t *)
      match eval env pred with
          Sllvm (t, v) when t = Llvm.i1_type && (GV.as_int v = 0) ->
            eval env alt
        | _ -> eval env cons
    end
  | Ast_app (fn, args) -> apply (eval env fn) (Array.map (eval env) args)
  | Ast_abs lambda ->
      Sclosure { close_env = env;
                 close_lam = lambda;
                 close_jitcode = lazy (compile_fn env lambda) }

and apply fn args =
  match fn with
      Sclosure close -> begin
        try
          let rettype, fn = Lazy.force close.close_jitcode in
          (* Compiled code only accepts LLVM values.  Anything else (for
             example a closure passed as an argument) sends us to the
             interpreter below instead of failing with a Match_failure. *)
          let gv_of_sval = function
              Sllvm (_, v) -> v
            | Sclosure _ | Sunbound -> raise Jit_failed
          in
          let gv_args = Array.map gv_of_sval args in
          let result = EE.run_function fn gv_args jit in
          Sllvm (rettype, result)
        with Jit_failed ->
          let frame = { env_frame_vars = close.close_lam.lam_params;
                        env_frame_vals = args } in
          let env' = frame :: close.close_env in
          eval env' close.close_lam.lam_ast
      end
    | Sllvm _ | Sunbound ->
        invalid_arg "apply: operator is not a closure"
;;


(* TEST SECTION *)

(* Small constructors that make hand-written ASTs less noisy. *)

(* A variable named [name]. *)
let mkvar name = Variable name ;;

(* A literal holding the value [value]. *)
let lit value = Ast_lit value ;;

(* Boolean and integer literals. *)
let lit_bool b = Ast_lit (make_bool b) ;;
let lit_int n = Ast_lit (make_int n) ;;

(* Reference to variable [name] at the given frame and offset. *)
let mkref name frame offset = Ast_ref (Variable name, frame, offset) ;;

(* Conditional: if [test] then [if_true] else [if_false]. *)
let mkcnd test if_true if_false = Ast_cnd (test, if_true, if_false) ;;

(* Lambda with the named [params], the given [body] and an optional LLVM
   function type. *)
let mkabs params body lltype =
  let vars = Array.map mkvar (Array.of_list params) in
  Ast_abs { lam_ast = body; lam_params = vars; lam_lltype = lltype }
;;

(* Application of [fn] to the list of argument expressions [args]. *)
let mkapp fn args =
  let arg_array = Array.of_list args in
  Ast_app (fn, arg_array)
;;


(* Self-test: ((lambda (a x y) (if a x y)) #t 34 47) must evaluate to the
   i64 value 34.  Every variable is a parameter, so all references use
   frame 0. *)
let () =
  let fn_type = function_type i64_type [| i1_type; i64_type; i64_type |] in
  let body = mkcnd (mkref "a" 0 0) (mkref "x" 0 1) (mkref "y" 0 2) in
  let lambda = mkabs ["a"; "x"; "y"] body (Some fn_type) in
  let call = mkapp lambda [lit_bool true; lit_int 34; lit_int 47] in
  match eval [] call with
      Sllvm (t, v) when t = Llvm.i64_type -> assert (GV.as_int v = 34)
    | _ -> assert false
;;

(* Self-test: (((lambda (y) (lambda (a x) (if a x y))) 47) #t 34) must
   evaluate to the i64 value 34.  The inner lambda reads [y] from the
   enclosing frame (frame 1, offset 0); the outer lambda carries no LLVM
   type. *)
let () =
  let inner_type = function_type i64_type [| i1_type; i64_type |] in
  let inner_body = mkcnd (mkref "a" 0 0) (mkref "x" 0 1) (mkref "y" 1 0) in
  let inner = mkabs ["a"; "x"] inner_body (Some inner_type) in
  let outer = mkabs ["y"] inner None in (* can't type this _yet_ *)
  let make_inner = mkapp outer [lit_int 47] in
  let call = mkapp make_inner [lit_bool true; lit_int 34] in
  match eval [] call with
      Sllvm (t, v) when t = Llvm.i64_type -> assert (GV.as_int v = 34)
    | _ -> assert false
;;

(* Dump the module holding the functions compiled by the tests above. *)
let () = dump_module cur_module ;;