(* Source: blub/blub_typecheck.ml (full commit).
   Type checker for Blub ASTs: infers types by unifying type variables. *)
open Blub_types
open Blub_common
open Format

(* Raised when type checking cannot proceed (failed unification, unbound
   variable, unsupported construct, ...).  Raise sites in this file
   generally print a diagnostic to stdout first; the string payload may be
   empty. *)
exception Tcheck_failed of string

(* Sets of type variables.  Elements are ordered by their integer id
   alone; the ref cell is never compared, so two type variables carrying
   the same id are the same set element. *)
module TypeVarSet = Set.Make (struct
  type t = typevar

  let compare (_, id1) (_, id2) =
    if id1 = id2 then 0 else if id1 < id2 then -1 else 1
end)

(* A type scheme: a type paired with the set of its type variables that
   are generalized, i.e. replaced by fresh variables each time the scheme
   is instantiated. *)
type typeschema = type_ * TypeVarSet.t

(* A type environment: maps each variable in scope to its type scheme. *)
type env = (variable, typeschema) PMap.t

(* Report a non-fatal type problem: writes "WARNING: " followed by [s] and
   a newline to stderr, then flushes.  Does not raise, so checking goes on. *)
let terror (s : string) : unit =
  eprintf "WARNING: %s\n%!" s
;;

(* Look one step through the type variable [tv]: when it has been solved,
   return the type it is bound to (which may itself be another type
   variable -- chains are not chased); otherwise return [tv] as a type. *)
let simplify ((cell, _) as tv) =
  match !cell with
  | None -> Typevar tv
  | Some solved -> solved
;;


(* Follow solved type variables until reaching either a type that is not a
   variable or an unsolved variable. *)
let rec prune (t:type_) : type_ =
  match t with
    Typevar (r, _) -> (match !r with Some t' -> prune t' | None -> t)
  | _ -> t
;;

(* Occurs check: does the type-variable cell [r] appear anywhere in [t],
   looking through solved variables?  Binding [r] to such a [t] would
   create a cyclic type. *)
let rec occurs (r:type_ option ref) (t:type_) : bool =
  match t with
    Int | Bool | Float64 | Symbol -> false
  | Fun (tparams, tret) ->
    Array.fold_left (fun found p -> found || occurs r p) false tparams
    || occurs r tret
  | VarArray elt -> occurs r elt
  | Typevar (r', _) ->
    r' == r || (match !r' with Some t' -> occurs r t' | None -> false)
;;

(* Make [t1] and [t2] equal by solving (assigning) type variables in place.
   Raises Tcheck_failed, after printing a diagnostic, when the types cannot
   be unified: different constructors, function types of different arity,
   or a variable that would have to contain itself. *)
let rec unify (t1:type_) (t2:type_) =
  match t1, t2 with
    Typevar (r, _), _ -> unify_var t2 r
  | _, Typevar (r, _) -> unify_var t1 r
  | Fun (tparams, tret), Fun (tparams', tret') ->
    if Array.length tparams <> Array.length tparams' then begin
      (* Previously only a warning: the mismatch was silently accepted. *)
      printf "Functions can't unify because param lengths differ\n%!";
      raise (Tcheck_failed "")
    end;
    Array.iteri (fun i p -> unify p tparams'.(i)) tparams;
    unify tret tret'
  | VarArray elt, VarArray elt' -> unify elt elt'
  (* Only constant constructors reach this case; they are immediates, so
     physical equality identifies them exactly. *)
  | a, b when a == b -> ()
  | _, _ ->
    printf "Unification between %a and %a failed" pp_type t1 pp_type t2;
    raise (Tcheck_failed "")

(* Unify [t1] with the type variable whose cell is [r]. *)
and unify_var (t1:type_) (r:type_ option ref) =
  match !r with
    Some t -> unify t1 t
  | None ->
    (* Resolve [t1] fully first: binding [r] to a variable that is already
       (transitively) bound to [r] itself would close a cycle. *)
    (match prune t1 with
       Typevar (r1, _) when r1 == r -> ()
     | t1' ->
       if occurs r t1' then begin
         printf "Occurs check failed: a type variable cannot contain itself in %a\n%!"
           pp_type t1';
         raise (Tcheck_failed "")
       end else
         r := Some t1')
;;

(* all unsolved type variables in t *)
(* [unsolved t] is the set of type variables reachable from [t] that are
   still unbound.  Solved variables are looked through to their solution. *)
let rec unsolved (t:type_) : TypeVarSet.t =
  match t with
  | Int | Bool | Float64 | Symbol -> TypeVarSet.empty
  | VarArray elt -> unsolved elt
  | Typevar ((cell, _) as tv) ->
    (match !cell with
     | None -> TypeVarSet.singleton tv
     | Some solved -> unsolved solved)
  | Fun (params, ret) ->
    let from_params =
      Array.fold_left
        (fun acc p -> TypeVarSet.union acc (unsolved p))
        TypeVarSet.empty params
    in
    TypeVarSet.union from_params (unsolved ret)
;;

(* all unsolved type variables mentioned in the type environment *)
(* All unsolved type variables free in the type environment [r].  For each
   scheme (t, bound) this takes the unsolved variables of [t] minus the
   scheme's own generalized variables [bound]: generalized variables are
   replaced by fresh ones whenever the scheme is instantiated, so they must
   not prevent [schema] from generalizing a new type. *)
let env_unsolved (r:env) : TypeVarSet.t =
  PMap.fold
    (fun (t, bound) acc ->
       TypeVarSet.union acc (TypeVarSet.diff (unsolved t) bound))
    r TypeVarSet.empty
;;

(* Not used by the typechecker itself, but for later; does this AST
   contain any unsolved type variables anywhere in it? *)
(* Not used by the typechecker itself, but for later; does this AST
   contain any unsolved type variables anywhere in it?  [tmap] is a type
   map as produced by [tcheck]; the result is the union of the unsolved
   type variables found in the types of the visited nodes.
   Raises Not_found if a visited node has no entry in [tmap]. *)
let rec all_unsolved (tmap:(ast, type_) PMap.t) ast : TypeVarSet.t =
  let type_ = PMap.find ast tmap in
  let union_all exprs =
    Array.fold_right
      (fun expr acc -> TypeVarSet.union (all_unsolved tmap expr) acc)
      exprs TypeVarSet.empty
  in
  match ast with
    A_lit _ | A_ref _ -> unsolved type_
  | A_cnd (e1, e2, e3) -> union_all [| e1; e2; e3 |]
  | A_seq exprs -> union_all exprs
  | A_abs (_, _, body) ->
    TypeVarSet.union (unsolved type_) (all_unsolved tmap body)
  | A_app (fn, args) ->
    (* Previously missing: any application raised Match_failure. *)
    TypeVarSet.union (unsolved type_)
      (TypeVarSet.union (all_unsolved tmap fn) (union_all args))
  | A_let _ | A_set _ | A_callcc _ ->
    (* tcheck rejects these forms, so a tmap it produced has no entry for
       them and PMap.find above raises first; listed for exhaustiveness. *)
    unsolved type_
;;

(* given an environment r, produce a schema for t that identifies all
   the new type variables in t, which can be arbitrarily substituted for *)
(* Generalize [t] with respect to the environment [r]: the scheme
   quantifies over every unsolved type variable of [t] that is not among
   the variables [env_unsolved] reports for [r]. *)
let schema (t:type_) (r:env) : typeschema =
  let generalized = TypeVarSet.diff (unsolved t) (env_unsolved r) in
  (t, generalized)
;;

(* a type just like t except that every type variable in tvs has been
   consistently replaced by a fresh type variable. *)
(* A type just like t except that every type variable in tvs has been
   consistently replaced by a fresh type variable.

   One fresh variable is created per element of [tvs] (in increasing id
   order) and looked up by id, the same identity TypeVarSet uses.  Solved
   type variables are followed to their solutions, so generalized
   variables nested inside a solution are replaced too.  Unsolved
   variables that are not in [tvs] are returned unchanged (shared). *)
let instantiate ((t, tvs):typeschema) : type_ =
  if TypeVarSet.is_empty tvs then
    (* Monomorphic scheme: nothing to replace. *)
    t
  else begin
    let fresh : (int * type_) list =
      TypeVarSet.fold (fun (_, id) acc -> (id, typevar ()) :: acc) tvs []
    in
    let rec inst (ty:type_) : type_ =
      match ty with
        Int | Bool | Float64 | Symbol -> ty
      | Fun (tparams, tret) -> Fun (Array.map inst tparams, inst tret)
      | VarArray elt -> VarArray (inst elt)
      | Typevar (r, id) ->
        (match !r with
           Some solved -> inst solved
         | None -> (try List.assoc id fresh with Not_found -> ty))
    in
    inst t
  end
;;

(* Type of a runtime value; used for literals and for the values of
   variables already bound when checking starts.
   - A vector has type VarArray of its element type.  An empty vector gets
     a fresh element type variable.  If element types differ, a warning is
     printed via [terror] and the last element's type is used.
   - Unbound values get a fresh type variable.
   - Any other kind of value prints "Unsupported sval ..." and gets a
     fresh type variable. *)
let rec type_of_sval = function
    Sfalse | Strue -> Bool
  | Sint _ -> Int
  | Sfloat _ -> Float64
  | Ssymbol _ -> Symbol
  | Svector v ->
      if Array.length v = 0 then
        VarArray (typevar ())
      else
        VarArray
          (Array.fold_left
             (fun prev elt ->
                let t = type_of_sval elt in
                (* Structural, not physical, comparison: every call builds a
                   fresh VarArray value, so [!=] flagged every nested
                   vector as mismatched. *)
                if t <> prev then
                  terror "Creating vector with mismatched element types";
                t)
             (type_of_sval v.(0))
             v)
  | Sunbound -> typevar ()
  | x ->
      printf "Unsupported sval %a\n%!" pp_sval x;
      typevar ()

(* Merge two type maps by adding every binding of [map1] into [map2];
   PMap.add replaces existing bindings, so [map1]'s entry wins on a key
   present in both. *)
let tmap_union map1 map2 =
  PMap.foldi (fun key typ acc -> PMap.add key typ acc) map1 map2
;;

(* [tcheck globals env e] infers a type for every sub-expression of [e].

   [globals] is a frame and [env] a list of frames; each frame has parallel
   arrays [vars] and [vals].  Every such variable starts with the fully
   generalized type of its current value.  Returns a map from each checked
   AST node to its inferred type.

   Raises Tcheck_failed (after printing a diagnostic) when unification
   fails, a variable is unbound, a sequence is empty, or the expression
   uses a construct not supported yet (A_let, A_set, A_callcc, and A_abs
   whose second field is true).

   NOTE(review): the result map is keyed by the AST nodes themselves; if
   PMap compares keys structurally, structurally identical sub-expressions
   (e.g. two occurrences of the same variable) share one entry and the
   last one merged wins.  Confirm whether callers depend on per-occurrence
   types. *)
let tcheck globals env (e:ast) : (ast, type_) PMap.t =
  let rec tcheck (r:env) (e:ast) : (ast, type_) PMap.t =
    (* printf "tcheck: %a\n%!" pp_ast e; *)
    match e with
      A_lit l -> PMap.add e (type_of_sval l) PMap.empty

    | A_ref var ->
      let ts =
        try PMap.find var r
        with Not_found ->
          printf "Variable %a not found in environment\n%!" pp_var var;
          raise (Tcheck_failed "")
      in
      (* Each use of a variable gets its own instance of the scheme. *)
      PMap.add e (instantiate ts) PMap.empty

    | A_cnd (e1, e2, e3) ->
      let t2 = tcheck r e2 in
      let t3 = tcheck r e3 in
      let t1 = tcheck r e1 in
      let tmap = tmap_union t1 t2 in
      let tmap = tmap_union tmap t3 in
      (* The test must be boolean; both branches must agree, and their
         common type is the type of the conditional. *)
      unify Bool (PMap.find e1 tmap);
      unify (PMap.find e2 tmap) (PMap.find e3 tmap);
      PMap.add e (PMap.find e2 tmap) tmap

    | A_abs (params, false, body) ->
      (* Parameters are monomorphic inside the body: each gets a fresh type
         variable with an empty generalized set. *)
      let tparams = Array.map (fun _ -> typevar ()) params in
      let ts = Array.map (fun t -> (t, TypeVarSet.empty)) tparams in
      let var_ts = Array.mapi (fun i ts -> (params.(i), ts)) ts in
      let env = Array.fold_left (fun env (var, ts) ->
                                   PMap.add var ts env) r var_ts in
      let tmap = tcheck env body in
      let function_type = Fun (tparams, PMap.find body tmap) in
      PMap.add e function_type tmap

    | A_app (fn, args) ->
      let tmap = Array.fold_right
        (fun arg acc -> tmap_union (tcheck r arg) acc)
        args PMap.empty
      in
      let tmap = tmap_union (tcheck r fn) tmap in
      let arg_types = Array.map (fun arg -> PMap.find arg tmap) args in
      let ret_type =
        match PMap.find fn tmap with
          Fun (tparams, tret) ->
          if Array.length tparams <> Array.length args then begin
            printf "Function expecting %d arguments applied to %d\n%!"
              (Array.length tparams) (Array.length args);
            raise (Tcheck_failed "")
          end;
          Array.iteri (fun i tp -> unify tp arg_types.(i)) tparams;
          tret
        | Typevar _ as fn_type ->
          (* The callee is not yet known to be a function (e.g. it is a
             lambda parameter): constrain it to be a function from the
             argument types to a fresh result type. *)
          let tret = typevar () in
          unify fn_type (Fun (arg_types, tret));
          tret
        | t ->
          printf "App on unexpected type: %a" pp_type t;
          raise (Tcheck_failed "")
      in
      PMap.add e ret_type tmap

    | A_seq exprs ->
      let n = Array.length exprs in
      if n = 0 then begin
        printf "Cannot type an empty sequence\n%!";
        raise (Tcheck_failed "")
      end;
      let tmap = Array.fold_right
        (fun expr tmap -> tmap_union (tcheck r expr) tmap)
        exprs PMap.empty
      in
      (* A sequence has the type of its last expression. *)
      PMap.add e (PMap.find exprs.(n - 1) tmap) tmap
(*
      | A_let ( [| (var, (A_abs (params, false, body) as t)) |], e2, LT_letrec) ->
          let tparams = Array.map (fun _ -> typevar ()) params in
          let tret = typevar () in
          let tf = Fun (tparams, tret) in

          Hashtbl.add typedict t tf;

          let var_ts = Array.mapi 
            (fun i tp -> (params.(i), (tp, TypeVarSet.empty))) tparams in
          
          let r' = Array.fold_left 
            (fun env (var, ts) -> PMap.add var ts env) r var_ts
          in
          let r'' = PMap.add var (tf, TypeVarSet.empty) r' in
          let tbody = tcheck r'' body in
          unify tbody tret;

          let env = PMap.add var (schema tf r) r in
          tcheck env e2
*)

    | A_let _ | A_set _ | A_callcc _ | A_abs _ ->
      printf "sorry, Shawn is lazy and hasn't got this stuff working again yet!!";
      raise (Tcheck_failed "")
(*
      | Let ((var, _), (Abs (params, body) as f), e2) -> (
          let tparams = Array.map (fun _ -> typevar ()) params in
          let tret = typevar () in
          let tf = Fun (tparams, tret) in

          Hashtbl.add typedict f tf;

          let var_ts = Array.mapi 
            (fun i tp -> (params.(i), (tp, TypeVarSet.empty))) tparams in
          
          let r' = Array.fold_left 
            (fun env (var, ts) -> PMap.add var ts env) r var_ts
          in
          let r'' = PMap.add var (tf, TypeVarSet.empty) r' in
          let tbody = tcheck r'' body in
          unify tbody tret;

          printf "HA HA HA   %s   %!" (string_of_type tf);
          (* NOTE the original does this to the original env r, not r'' *)
          let env = PMap.add var (schema tf r) r in
          tcheck env e2
        )

      | Let ((var, _), e1, e2) ->
          tcheck (vardeclcheck var e1 r) e2
*)

  (* Bind [var] to the generalized type of [e] in [r]; currently only
     referenced from the commented-out let handling above. *)
  and vardeclcheck var e r =
    let tmap = tcheck r e in
    let typ = PMap.find e tmap in
    PMap.add var (schema typ r) r
  in

  (* Seed the environment with the current runtime values: every variable
     gets the fully generalized type of its value.  Frames of [env] are
     added first-to-last and [globals] last, each via PMap.add, so when a
     variable is bound more than once the binding added last wins.
     NOTE(review): this lets globals take precedence over frame bindings
     of the same variable; confirm it matches the evaluator's scoping. *)
  let env = List.fold_left
    (fun env frame ->
       let frame = Array.mapi
         (fun idx var -> (var, frame.vals.(idx))) frame.vars in
       Array.fold_right
         (fun (var, value) env ->
            let ts = schema (type_of_sval value) PMap.empty in
            PMap.add var ts env) frame env)
    PMap.empty env
  in

  let globals = Array.mapi (fun idx var ->
                              (var, globals.vals.(idx))) globals.vars in

  let env = Array.fold_right
    (fun (var, value) env ->
       let ts = schema (type_of_sval value) PMap.empty in
       PMap.add var ts env)
    globals env
  in

  tcheck env e