(* Source: blub / blub_typecheck.ml (full commit) *)
open Blub_types
open Blub_common
open Format

(* Raised when type checking cannot continue.  Every raise site in this
   file prints its diagnostic first and then raises with an empty
   string payload. *)
exception Tcheck_failed of string

(* Sets of type variables.  Ordering looks only at the second component
   of a type variable (its integer id); the first component is ignored. *)
module TypeVarSet = Set.Make (struct
  type t = typevar

  let compare (_, id_a) (_, id_b) =
    if id_a = id_b then 0 else if id_a < id_b then -1 else 1
end)

(* A type schema: a type together with the set of its type variables
   that [instantiate] replaces with fresh variables. *)
type typeschema = type_ * TypeVarSet.t

(* Typing environment: maps each variable to its type schema. *)
type env = (variable, typeschema) PMap.t

(* Report the type problem [s] as a warning line on the error formatter
   and flush it.  This does not raise: the caller keeps going. *)
let terror s = eprintf "WARNING: %s\n%!" s
;;

(* Resolve a type variable by ONE step: if it has been solved, return
   the type it is bound to (without resolving that type any further);
   otherwise return the variable itself wrapped as a type. *)
let simplify_tv ((cell, _) as tv) =
  match !cell with
  | None -> Typevar tv
  | Some bound -> bound
;;

(* Follow a chain of solved type variables to its end.  The result is
   either an unsolved type variable or a non-variable type; only the
   outermost layer is resolved, nested components are left untouched. *)
let rec simplify ty =
  match ty with
  | Typevar (cell, _) -> (
      match !cell with
      | None -> ty
      | Some bound -> simplify bound)
  | _ -> ty
;;



(* [tv_occurs_in r t] is true when the unsolved type-variable cell [r]
   appears anywhere inside [t], looking through solved variables.
   Cells are compared physically: identity of the mutable cell is what
   makes two type variables "the same" here. *)
let rec tv_occurs_in (r:type_ option ref) (t:type_) : bool =
  match simplify t with
  | Typevar (r', _) -> r' == r
  | Fun (tparams, tret) ->
      Array.fold_left
        (fun found p -> found || tv_occurs_in r p)
        (tv_occurs_in r tret) tparams
  | VarArray elt -> tv_occurs_in r elt
  | _ -> false
;;

(* [unify t1 t2] makes [t1] and [t2] equal by destructively solving
   type variables.

   Both sides are first resolved through already-solved variables, so
   the variable cases below only ever see UNSOLVED variables.  Without
   that step, unifying two variables that are already aliased (a := b)
   in the opposite order would bind b := a and create a cycle.

   Function types with different parameter counts only produce a
   warning through [terror]; nothing is unified in that case and no
   exception is raised.

   Raises [Tcheck_failed] (after printing a diagnostic) when the two
   types have incompatible shapes or when a variable would have to
   contain itself. *)
let rec unify (t1:type_) (t2:type_) =
  match simplify t1, simplify t2 with
  | Typevar (r, _), other | other, Typevar (r, _) -> unify_var other r
  | Fun (tparams, tret), Fun (tparams', tret') ->
      if Array.length tparams <> Array.length tparams' then
        terror "Functions can't unify because param lengths differ"
      else begin
        Array.iteri (fun i p -> unify p tparams'.(i)) tparams;
        unify tret tret'
      end
  | VarArray elt1, VarArray elt2 -> unify elt1 elt2
  (* Physical equality is deliberate: every boxed shape was handled
     above, so only constant constructors (immediate values) can be
     equal here. *)
  | a, b when a == b -> ()
  | _, _ ->
      printf "Unification between %a and %a failed" pp_type t1 pp_type t2;
      raise (Tcheck_failed "")

(* [unify_var t1 r] unifies [t1] with the type variable whose cell is
   [r].  A solved cell delegates to [unify] on its contents.  An
   unsolved cell is bound to the resolved form of [t1], unless [t1] is
   that very variable (nothing to do) or contains it (occurs check
   failure, which would otherwise build an infinite type). *)
and unify_var (t1:type_) (r:type_ option ref) =
  match !r with
  | Some bound -> unify t1 bound
  | None -> (
      match simplify t1 with
      | Typevar (r1, _) when r1 == r -> ()
      | resolved ->
          if tv_occurs_in r resolved then begin
            printf "Occurs check failed: a type variable cannot be unified with %a, which contains it" pp_type resolved;
            raise (Tcheck_failed "")
          end else
            r := Some resolved)
;;

(* The set of type variables inside [t] that have not been solved yet.
   Solved variables are looked through: the variables of the type they
   are bound to are collected instead. *)
let rec unsolved (t:type_) : TypeVarSet.t =
  match t with
  | Typevar (cell, id) -> (
      match !cell with
      | Some bound -> unsolved bound
      | None -> TypeVarSet.singleton (cell, id))
  | VarArray elt -> unsolved elt
  | Fun (tparams, tret) ->
      let from_params =
        Array.fold_left
          (fun seen p -> TypeVarSet.union seen (unsolved p))
          TypeVarSet.empty tparams
      in
      TypeVarSet.union from_params (unsolved tret)
  | Int | Bool | Float64 | Symbol -> TypeVarSet.empty
;;

(* The union of the unsolved type variables of every type bound in the
   environment [r].  Only the type of each schema is inspected; its
   set of generic variables is ignored. *)
let env_unsolved (r:env) : TypeVarSet.t =
  let collect (ty, _) seen = TypeVarSet.union seen (unsolved ty) in
  PMap.fold collect r TypeVarSet.empty
;;

(* Not called by the checker itself; intended for later passes.
   Collects the unsolved type variables found in the types recorded in
   [tmap] for [ast] and its sub-expressions.

   Raises Not_found when a visited node has no entry in [tmap].

   NOTE(review): there is no arm for A_let, A_set or A_callcc, which
   tcheck does match on; reaching one of them here raises
   Match_failure. *)
let rec all_unsolved (tmap:(ast, type_) PMap.t) ast : TypeVarSet.t =
  (* Looked up eagerly, so a node missing from [tmap] fails even for
     the forms that never use their own type. *)
  let node_type = PMap.find ast tmap in
  let below = all_unsolved tmap in
  (* Union of the results for every expression in [exprs], starting
     from [seed] and visiting the array from its last element. *)
  let gather exprs seed =
    Array.fold_right
      (fun expr acc -> TypeVarSet.union (below expr) acc)
      exprs seed
  in
  match ast with
  | A_app (fn, args) ->
      let from_children = gather args (below fn) in
      TypeVarSet.union (unsolved node_type) from_children
  | A_abs (_, _, body) ->
      TypeVarSet.union (unsolved node_type) (below body)
  | A_seq exprs -> gather exprs TypeVarSet.empty
  | A_cnd (e1, e2, e3) ->
      TypeVarSet.union
        (TypeVarSet.union (below e1) (below e2))
        (below e3)
  | A_lit _ | A_ref _ -> unsolved node_type

;;

(* Build the schema of [t] relative to the environment [r]: the generic
   variables are the unsolved variables of [t] that are not also
   unsolved somewhere in [r].  Those are the ones that may later be
   replaced by arbitrary fresh variables. *)
let schema (t:type_) (r:env) : typeschema =
  let in_type = unsolved t in
  let in_env = env_unsolved r in
  (t, TypeVarSet.diff in_type in_env)
;;

(* [instantiate (t, tvs)] returns a type shaped like [t] in which every
   type variable of [tvs] has been consistently replaced by a fresh
   type variable; all other unsolved variables are kept as they are.

   Solved variables are resolved with [simplify] before each step, so
   generic variables that sit underneath a solved variable are renamed
   too, and a variable of [tvs] that has been solved since the schema
   was built is treated as the type it now stands for.

   Variables are identified by their integer id, the same notion of
   identity that TypeVarSet uses. *)
let instantiate ((t, tvs):typeschema) : type_ =
  (* One fresh variable per generic variable, allocated up front so the
     same generic variable always maps to the same fresh one. *)
  let renaming : (typevar * type_) list =
    TypeVarSet.fold
      (fun (tv:typevar) acc -> (tv, typevar ()) :: acc)
      tvs []
  in
  let inst_var (tv:typevar) : type_ =
    let (_, id) = tv in
    let rec lookup = function
      | [] -> Typevar tv
      | ((_, generic_id), fresh) :: rest ->
          if generic_id = id then fresh else lookup rest
    in
    lookup renaming
  in
  let rec inst (t:type_) =
    match simplify t with
    | Fun (tparams, tret) -> Fun (Array.map inst tparams, inst tret)
    | VarArray elt -> VarArray (inst elt)
    | Typevar tv -> inst_var tv
    | (Int | Bool | Float64 | Symbol) as base -> base
  in
  inst t
;;

(* The type of a runtime value.

   - Booleans, integers, floats and symbols map to their base types.
   - An empty vector gets a fresh element type variable.
   - A non-empty vector takes the type of its LAST element; each
     element is typed exactly once, and a warning is emitted through
     [terror] whenever two neighbouring element types differ.  Types
     are compared structurally, so equal boxed types (for example two
     nested vectors of integers) do not warn.  Element types that
     contain distinct fresh type variables still compare as different.
   - An unbound value gets a fresh type variable.
   - A primitive function with a declared type yields that type.
   - Anything else is reported on stdout and given a fresh variable. *)
let rec type_of_sval = function
  | Sfalse | Strue -> Bool
  | Sint _ -> Int
  | Sfloat _ -> Float64
  | Ssymbol _ -> Symbol
  | Svector v ->
      if Array.length v = 0 then
        VarArray (typevar ())
      else begin
        let elt_types = Array.map type_of_sval v in
        let elt_type =
          Array.fold_left
            (fun prev t ->
               if prev <> t then
                 terror "Creating vector with mismatched element types";
               t)
            elt_types.(0) elt_types
        in
        VarArray elt_type
      end
  | Sunbound -> typevar ()
  | Spfn { prim_type = Some t } -> t
  | x ->
      printf "Unsupported sval %a\n%!" pp_sval x;
      typevar ()

(* Result of type checking: maps an AST node to the type computed for
   it. *)
type tmap = (ast, type_) PMap.t

(* Merge two maps by adding every binding of [map1] into [map2]. *)
let tmap_union map1 map2 =
  PMap.foldi (fun key ty acc -> PMap.add key ty acc) map1 map2
;;


(* [tcheck globals env e] type-checks the expression [e] and returns
   the map from each checked sub-expression to its type.

   [globals] is a single frame and [env] a list of frames; a frame
   provides parallel [vars] and [vals] arrays.  The initial typing
   environment binds every frame variable to the schema of the type of
   its current value.  Globals are added last.
   NOTE(review): presumably PMap.add replaces an existing binding, in
   which case a global wins over a frame variable with the same key --
   confirm against the PMap implementation in use.

   Raises [Tcheck_failed] (after printing a diagnostic) for an unknown
   variable, an application of a non-function, a failed unification, or
   an expression form that is not supported yet.  An application whose
   argument count differs from the callee's parameter count trips the
   [assert]. *)
let tcheck globals env (e:ast) : tmap =
  let rec tcheck (r:env) (e:ast) : type_ * tmap =
    (* printf "tcheck: %a\n%!" pp_ast e; *)
    let type_, tmap = match e with
      | A_lit l -> (type_of_sval l, PMap.empty)
      | A_ref var -> begin
          try
            let t = instantiate (PMap.find var r) in
            (t, PMap.empty)
          with Not_found ->
            printf "Variable %a not found in environment\n%!" pp_var var;
            raise (Tcheck_failed "")
        end

      | A_cnd (e1, e2, e3) ->
          let t2, tmap2 = tcheck r e2 in
          let t3, tmap3 = tcheck r e3 in
          let t1, tmap1 = tcheck r e1 in
          let tmap = tmap_union tmap1 (tmap_union tmap2 tmap3) in

          (* The condition must be boolean and both branches must
             agree; the conditional has the type of its branches. *)
          unify Bool t1;
          unify t2 t3;
          (t2, tmap)

      | A_abs (params, false, body) ->
          (* Each parameter gets a fresh, non-generic type variable. *)
          let tparams = Array.map (fun _ -> typevar ()) params in
          let ts = Array.map (fun t -> (t, TypeVarSet.empty)) tparams in
          let var_ts = Array.mapi (fun i ts -> (params.(i), ts)) ts in
          let env = Array.fold_left (fun env (var, ts) ->
                                       PMap.add var ts env) r var_ts in
          let tret, tmap = tcheck env body in
          let function_type = Fun (tparams, tret) in
          (function_type, tmap)

      | A_app (fn, args) -> begin
          let arg_typeinfo = Array.map (tcheck r) args in
          let fn_type, fn_tmap = tcheck r fn in

          (* Resolve solved type variables first so that a callee whose
             type is a variable already bound to a function type is
             accepted.  The application has the callee's RETURN type. *)
          let tret = match simplify fn_type with
            | Fun (tparams, tret) ->
                assert ((Array.length tparams) = (Array.length args));
                Array.iteri
                  (fun i (arg_type, _) ->
                     unify tparams.(i) arg_type)
                  arg_typeinfo;
                tret

            | t ->
                printf "App on unexpected type: %a" pp_type t;
                raise (Tcheck_failed "")
          in

          (* Keep the callee's sub-expression types as well as those of
             the arguments. *)
          let tmap = Array.fold_right (fun (_, tmap) acc ->
                                         tmap_union tmap acc)
            arg_typeinfo fn_tmap
          in
          tret, tmap
        end

      | A_seq exprs ->
          let expr_typeinfo = Array.map (tcheck r) exprs in
          let tmap = Array.fold_right
            (fun (_, tmap) acc -> tmap_union tmap acc)
            expr_typeinfo PMap.empty
          in
          (* A sequence has the type of its last expression.  An empty
             sequence makes this index out of bounds. *)
          let last_expr_t, _ = expr_typeinfo.(Array.length exprs - 1) in
          last_expr_t, tmap

      (* very specialized support for a single let-binding of a
         function with one parameter *)
      | A_let ( [| (f, A_abs ([| x |], false, e1)) |],
                e2, LT_letrec) ->
          (* FunDecl part: check the body with f and x bound to
             non-generic types, then generalize f for the let body. *)
          let t1 = typevar () in
          let t2 = typevar () in
          let tf = Fun ( [| t1 |], t2 ) in
          let r' = PMap.add f (tf, TypeVarSet.empty) r in
          let r'' = PMap.add x (t1, TypeVarSet.empty) r' in
          let te, tmap = tcheck r'' e1 in
          unify te t2;
          let r = PMap.add f (schema tf r) r in

          (* Let part *)
          let t_e2, tmap' = tcheck r e2 in
          let tmap = tmap_union tmap tmap' in
          t_e2, tmap

      (* An assignment takes the type of the assigned expression; the
         variable's own type is not consulted. *)
      | A_set (_, expr) -> tcheck r expr

      | A_let _ | A_callcc _ | A_abs _ ->
          printf "sorry, Shawn is lazy and hasn't got this stuff working again yet!!";
          raise (Tcheck_failed "")

    in
    (* Record the type of this node on top of its children's types. *)
    type_, PMap.add e type_ tmap
  in

  (* Add the types of everything in the environment *)

  let env = List.fold_left
    (fun env frame ->
       let frame = Array.mapi
         (fun idx var -> (var, frame.vals.(idx))) frame.vars in
       Array.fold_right
         (fun (var, value) env ->
            let ts = schema (type_of_sval value) PMap.empty in
            PMap.add var ts env) frame env)
    PMap.empty env
  in

  let globals = Array.mapi (fun idx var ->
                              (var, globals.vals.(idx))) globals.vars in

  let env = Array.fold_right
    (fun (var, value) env ->
       let ts = schema (type_of_sval value) PMap.empty in
       PMap.add var ts env)
    globals env
  in

  let _, tmap = tcheck env e in
  tmap