Source

blub / blub_llvm.ml

Full commit
open Llvm
open Format
open Blub_types
open Blub_typecheck
open Blub_common
open Blub_environment
open Llvm_executionengine

(* Raised by this module when a Blub type, value or expression cannot be
   translated to (or back from) LLVM; the string describes what failed. *)
exception Jit_failed of string

(* [label prefix] returns a fresh name "prefix.N".  N comes from a single
   counter shared by all calls (starting at 1), so names never repeat even
   across different prefixes.  Used to name LLVM instructions and basic
   blocks. *)
let label =
  let next_id = ref 0 in
  (fun prefix ->
     next_id := !next_id + 1;
     Printf.sprintf "%s.%d" prefix !next_id)
;;

(* Get the LLVM jit-compiler up and running: the single module [m] that
   jit-compiled functions are defined in, a few printf-style format strings
   stored as global constants, and external declarations of [putchar] and
   [printf].
   NOTE the OCaml names [print_int] and [print_newline] shadow the Stdlib
   functions of the same name for the rest of this file. *)
let m = create_module "jithelloworld";;
let print_int = define_global "print_int" (const_stringz "%d") m;;
let print_double = define_global "print_double" (const_stringz "%g") m;;
(* Format string for a bare newline, stored under its own global name. *)
let print_newline = define_global "print_newline" (const_stringz "\n") m;;
let putchar = declare_function "putchar" (function_type i32_type [| i32_type |]) m;;
let llprintf = declare_function "printf" (var_arg_function_type i32_type [| pointer_type i8_type |]) m;;


(* Map a Blub type to the LLVM type used for it in jitted code.
   Function types become *pointers* to an LLVM function type (see
   [function_sig] for the bare signature); bound type variables are followed
   to their binding.
   Raises [Jit_failed] for an unbound type variable; any other unhandled
   type is printed and then trips an assertion. *)
let rec lltype_of_type ty =
  match ty with
  | Bool -> i1_type
  | Int -> i32_type
  | Float64 -> double_type
  | Symbol -> array_type i8_type 0
  | Fun (params, rettype) ->
      let arg_types = Array.map lltype_of_type params in
      let fn_type = function_type (lltype_of_type rettype) arg_types in
      pointer_type fn_type
  | Typevar (r, _) ->
      begin match !r with
      | Some bound -> lltype_of_type bound
      | None -> raise (Jit_failed (sprintf "Unbound type variable"))
      end
  | other ->
      printf "Unhandled type %a\n%!" pp_type other;
      assert false
;;

(* The bare LLVM function type for a Blub [Fun] type.  [lltype_of_type]
   represents functions as *pointers* to this type; [define_function] needs
   the signature itself.
   Bound type variables are followed to their binding, as [lltype_of_type]
   and [construct_sval] already do.  Anything that does not resolve to a
   [Fun] is a programming error and trips an assertion. *)
let rec function_sig = function
    Fun (params, rettype) ->
      let args = Array.map lltype_of_type params in
      function_type (lltype_of_type rettype) args
  | Typevar ({ contents = Some t }, _) ->
      function_sig t
  | _ ->
      assert false (* can only get the function sig of a function *)
;;

(* Convert the GenericValue.t produced by a jitted call back into an sval,
   guided by the Blub type the call returns:
   - Bool: zero is Sfalse, any other integer is Strue;
   - Int / Float64: wrapped as Sint / Sfloat;
   - Fun: not supported at the moment (assertion failure);
   - Typevar: followed to its binding; an unbound one is reported on stdout
     and trips an assertion;
   - anything else raises [Jit_failed]. *)
let rec construct_sval rettype value =
  match rettype with
  | Typevar (r, _) ->
      begin match !r with
      | Some bound -> construct_sval bound value
      | None ->
          printf "Unbound type variable\n%!";
          assert false
      end
  | Bool -> if GenericValue.as_int value = 0 then Sfalse else Strue
  | Int -> Sint (GenericValue.as_int value)
  | Float64 -> Sfloat (GenericValue.as_float (lltype_of_type rettype) value)
  | Fun _ ->
      (* FIXME returning functions (an Sjitfn built from
         GenericValue.as_pointer) is not supported again for the moment. *)
      assert false
  | _ ->
      (* TODO no others supported yet *)
      raise (Jit_failed "Couldn't convert return value")
;;




(* Convert an sval to a GenericValue.t so it can be passed as an argument
   to a jitted function (used by [execute_function]).
   Booleans become 0 / 1 of the Bool LLVM type, Sint an integer of the Int
   LLVM type and Sfloat a float of the Float64 LLVM type.  [Strue] was
   previously missing, so passing a true value tripped the assertion below.
   Any other sval is not supported yet and trips an assertion. *)
let gval_of_sval = function
    Sfalse -> GenericValue.of_int (lltype_of_type Bool) 0
  | Strue -> GenericValue.of_int (lltype_of_type Bool) 1
  | Sint x -> GenericValue.of_int (lltype_of_type Int) x
  | Sfloat x -> GenericValue.of_float (lltype_of_type Float64) x
  | _ -> assert false (* TODO no others supported yet *)
;;


(* Turn an OCaml string into an array of LLVM i8 constants, one per byte,
   in order, suitable for [const_array i8_type].  No terminating NUL is
   added.  Uses [Array.init] rather than the deprecated (and, in OCaml 5,
   removed) [Array.create] plus placeholder [undef] values. *)
let convert_string s =
  Array.init (String.length s) (fun i -> const_int i8_type (Char.code s.[i]))
;;


(* Convert an sval literal into an LLVM constant.  [type_] is the Blub type
   of the value; its LLVM type (via [lltype_of_type]) is computed once,
   before the sval is inspected.
   Booleans become 0 / 1 of that type ([Strue] was previously missing and
   fell through to the error case), Sint / Sfloat the matching integer /
   float constants, and Ssymbol an i8 array holding the symbol's bytes.
   NOTE(review): the Ssymbol array has type [len x i8], which differs from
   [lltype_of_type Symbol] ([0 x i8]) -- confirm this is intended.
   Raises [Jit_failed] for any other sval. *)
let llvalue_of_sval type_ =
  let lltype = lltype_of_type type_ in
  function
    Sfalse -> const_int lltype 0
  | Strue -> const_int lltype 1
  | Sint x -> const_int lltype x
  | Sfloat x -> const_float lltype x
  | Ssymbol s -> const_array i8_type (convert_string s)

(*
  | Sjitclosure Jitcode (code, t) -> 
      (* TODO unify t and type_ *)
      code
  | Sjitclosure Dynload (_, fn_name, t) -> begin
      (* TODO unify t and type_ *)
      match lookup_function fn_name m with
	  Some fn -> fn
	| None -> assert false
    end
*)

  | s ->
      let msg = 
	fprintf str_formatter "Converting %a to llvalue not supported yet" 
	  pp_sval s;
	flush_str_formatter ()
      in
      raise (Jit_failed msg)
;;

(* Fetch the type the typechecker recorded for [ast] in [typeinfo] and
   return it after [simplify].  Propagates [Not_found] from [PMap.find] when
   [ast] has no recorded type. *)
let get_type typeinfo (ast:ast) =
  let recorded = PMap.find ast typeinfo in
  simplify recorded
;;


(* Compile the lambda [ast] into a fresh LLVM function named "lambda" in the
   shared module [m].  Returns [Jitcode (fn, rettype)], where [rettype] is
   the Blub return type of the lambda.

   [typeinfo] maps AST nodes to their types (see [get_type]).  Free
   variables are resolved through [globals] / [env], and their *current*
   values are embedded in the generated code as constants.  Calls are only
   supported when the callee is a primitive with an LLVM generator
   ([prim_llvm]).

   Debug output is printed for every node compiled.  The module is written
   to "bitcode.bc" after every node and once more at the end.

   Raises [Jit_failed] if [ast] is not an [A_abs (_, false, _)] node, or if
   it contains an expression / call this compiler does not handle yet. *)
(* NOTE env probably needs to be passed through gen_llvm *)
let jit_compile globals env ast typeinfo =
  let current_module = m in
  let type_ = get_type typeinfo ast in

  let rettype = match type_ with
      Fun (_, tret) -> tret
    | _ -> assert false
  in

  (* Destructure before defining anything, so that a rejected [ast] does
     not leave an empty "lambda" function behind in the shared module. *)
  let params, body = match ast with
      A_abs (params, false, body) -> params, body
    | _ -> raise (Jit_failed "jit_compile: expected an A_abs (_, false, _) node")
  in

  let current_function = define_function "lambda" (function_sig type_)
    current_module
  in
  (* set_function_call_conv CallConv.fast current_function; *)
  let builder = builder_at_end (entry_block current_function) in

  (* Current value of a free variable, found through [globals] / [env]. *)
  let lookup_var var =
    match Env.find globals env var with
        Gbinding idx -> globals.vals.(idx)
      | Lbinding (f, idx) -> (List.nth env f).vals.(idx)
  in

  (* Render [x] with the Format printer [pp], for error messages. *)
  let render pp x =
    fprintf str_formatter "%a" pp x;
    flush_str_formatter ()
  in

  (* probably need to do some stuff for the function header *)
  let rec gen_llvm (builder:llbuilder) (ast:ast) : (llvalue * llbuilder) =
    let type_ = get_type typeinfo ast in
    printf "compiling %a => type %a\n%!" pp_ast ast pp_type type_;

    (* NOTE shouldn't really do this every time through the loop *)
    ignore (Llvm_bitwriter.write_bitcode_file current_module "bitcode.bc");

    match ast with
        A_lit l -> (llvalue_of_sval type_ l, builder)

      | A_ref var when ExtArray.Array.mem var params ->
          (* [mem] compares structurally, so the index lookup must too: with
             physical equality an equal-but-distinct [var] would pass the
             guard and then make [findi] raise Not_found. *)
          let idx = ExtArray.Array.findi (fun p -> p = var) params in
          printf "GOT TO THIS POINT %d\n%!" idx;
          (param current_function idx, builder)

      | A_ref var ->
          (* In this case, we really hope the variable is not unbound right
             now, because that would be a problem... *)
          (llvalue_of_sval type_ (lookup_var var), builder)

      | A_cnd (pred, cons, alt) ->
          let fn = current_function in
          let pred_val, builder = gen_llvm builder pred in
          let test = build_icmp Icmp.Ne pred_val (const_int i1_type 0)
            (label "test") builder in

          let cons_block = append_block (label "true_branch") fn in
          let alt_block = append_block (label "false_branch") fn in
          let join_block = append_block (label "join_branches") fn in
          ignore (build_cond_br test cons_block alt_block builder);

          (* Compiling a branch can open new blocks (e.g. a nested
             conditional), so each phi incoming edge must name the block the
             branch actually ends in, not the block it started in. *)
          let cons_builder = builder_at_end cons_block in
          let v1, cons_builder = gen_llvm cons_builder cons in
          let cons_end = insertion_block cons_builder in
          ignore (build_br join_block cons_builder);

          let alt_builder = builder_at_end alt_block in
          let v2, alt_builder = gen_llvm alt_builder alt in
          let alt_end = insertion_block alt_builder in
          ignore (build_br join_block alt_builder);

          let join_builder = builder_at_end join_block in
          let return = build_phi [(v1, cons_end); (v2, alt_end)]
            (label "phi") join_builder in
          (return, join_builder)

      | A_seq exprs ->
          (* The value of a sequence is the value of its last expression
             (an undef i32 for an empty sequence). *)
          Array.fold_left
            (fun (_last, builder) expr -> gen_llvm builder expr)
            (undef i32_type, builder) exprs

      | A_app (A_ref var, args) ->
          let fn = lookup_var var in
          let args, builder = Array.fold_right
            (fun arg (result, builder) ->
               let r, builder = gen_llvm builder arg in
               (r :: result, builder))
            args ([], builder)
          in
          let llvalue = match fn with
              Spfn { prim_llvm = Some llvm_gen_fn } ->
                llvm_gen_fn builder args
            | other ->
                raise (Jit_failed
                         ("Cannot jit-compile a call to " ^ render pp_sval other))
          in
          (llvalue, builder)

      | other ->
          raise (Jit_failed
                   ("Cannot jit-compile expression " ^ render pp_ast other))
  in
  let llvalue, builder = gen_llvm builder body in
  ignore (build_ret llvalue builder);
  ignore (Llvm_bitwriter.write_bitcode_file m "bitcode.bc");
  printf "FINISHED LLVM COMPILING\n%!";
  Jitcode (current_function, rettype)
;;

(*
  | A_app (lam, args) -> (
      (* gen_llvm the fn, getting an llvalue; cast to a function and 
	 invoke it *)
      printf "BUILDING APP\n%!";
      let lam, builder = recur builder lam in
      let args, builder = 
	Array.fold_left (fun (acc, builder) arg -> 
			   let arg, builder = recur builder arg in
			   (arg :: acc, builder))
	  ([], builder) args 
      in
      printf "BUILT ARGS\n%!";
      let args = Array.of_list args in
      let result = build_call lam args "" builder in
      printf "BUILT CALL\n%!";
      set_instruction_call_conv CallConv.fast result;
      (* set_instruction_tail_call result; *)
      (result, builder))

  | A_abs _ ->
      (* queue up the lambda to be gen_llvmd *)
      let name = (label "lambda") in
      (gen_llvm_lambda name state ast, builder)

  | A_seq exprs ->
      let tmp = undef i32_type in
      let result = 
	Array.fold_left (fun (last, builder) expr -> 
			   recur builder expr) 
	  (tmp, builder) exprs 
      in
      result

and gen_llvm_lambda name state (ast:ast) =
  match ast with
    Abs (params, body) -> begin	
      (* HACK I have to do this manually because everywhere else
	 I want to have functions wrapped by a pointer *)
      let type_ = get_type state.cd_typeinfo ast in
      let fn = define_function name (function_sig type_) state.cd_module in
      set_function_call_conv CallConv.fast fn;
      
      let builder = builder_at_end (entry_block fn) in
      let frame = 
	{ frame_vars = params;
	  frame_vals = Array.create (Array.length params) Sunbound } 
      in
      let env' = { state.cd_env with 
		     env_locals = frame :: state.cd_env.env_locals } in
      let lambda = { fun_llvm_fn = fn } in
      let state = { state with cd_env = env' } in
      let return, builder = gen_llvm state lambda builder body in
      ignore (build_ret return builder);
      fn
    end

  | _ -> 
      raise (Jit_failed (sprintf "expected lambda, got %s" (string_of_ast ast)))
*)



(* Set up the JIT: one execution engine over the shared module [m].  The
   module's static constructors are run once, when this file is loaded.
   [let () =] (rather than [let _ =]) makes the compiler check that the
   discarded result really is unit. *)
let jit = ExecutionEngine.create (ModuleProvider.create m)
let () = ExecutionEngine.run_static_ctors jit


(* Run a compiled closure on [args] and convert its result back to an sval.
   A [Jitcode] closure carries its LLVM function directly.  A [Dynload]
   closure is looked up by name in the execution engine.  The converted
   result is printed and then returned.
   Raises [Jit_failed] if a [Dynload] function is not known to the JIT;
   [gval_of_sval] / [construct_sval] may also fail while converting. *)
let execute_function closure (args:sval array) =
  (* Execute the function and convert the result *)
  let fn, ret_type = match closure with
      Jitcode (fn, ret_type) -> fn, ret_type
    | Dynload (_, fn_name, ret_type) -> begin
        (* A missing name is reported as Jit_failed instead of a bare
           Match_failure from a partial [let Some f = ...]. *)
        match ExecutionEngine.find_function fn_name jit with
            Some f -> f, ret_type
          | None ->
              raise (Jit_failed
                       ("execute_function: no function named " ^ fn_name
                        ^ " in the JIT"))
      end
  in

  let args = Array.map gval_of_sval args in
  let result = ExecutionEngine.run_function fn args jit in
  let result = construct_sval ret_type result in
  printf "LLVM COMPLETED EXECUTION: %a\n%!" pp_sval result;
  result

   (* Tear down the JIT.
   ExecutionEngine.run_static_dtors jit;
   ExecutionEngine.dispose jit *)
;;

(* turn this code into an LLVM function; execute it if it is a letrec 
let jit_compile (env:environment) (ast:ast) =
  let typeinfo = 
    try
      Blub_typecheck.tcheck env ast 
    with Tcheck_failed msg -> raise (Jit_failed msg)
  in
  let state = { cd_typeinfo = typeinfo;
		cd_env = env;
		cd_module = m } in
  let type_ = try
    get_type typeinfo ast
  with Not_found -> raise (Jit_failed "No type information for fn")
  in
  match ast with
      Abs _ ->
	printf "starting to compile\n%!";
	let llvalue = gen_llvm_lambda (label "lambda") state ast in
	Sjitfn (type_, value_name llvalue)

    | Letrec ( [| v, Abs (params, body) |], expr) ->
	printf "compiling named function\n%!";
	let rettype = get_type typeinfo expr in
	let wrapper = Abs ([||], ast) in
	let typeinfo = (wrapper, Fun ([||], rettype)) :: typeinfo in
	let state = { state with 
			cd_typeinfo = typeinfo } in
	let llvalue = gen_llvm_lambda (label "blah") state wrapper in
	let result = ExecutionEngine.run_function llvalue [||] jit in
	construct_sval rettype result

    | t ->
	raise (Jit_failed 
		 (sprintf "Can only jit-compile functions; got %s" 
		    (string_of_ast t)))

;;

*)