(* Source: pa_ovisitor/pa/tctools.ml *)

open Camlp4
open PreCast
open Pa_type_conv
open Ast

(** Default location used by every quotation and AST constructor in this
    module: generated code is attributed to the ghost location. *)
let _loc : Loc.t = Loc.ghost

(** { 6 Error } *)

(** [errorf loc fmt ...] formats a message the way [Printf.sprintf] does and
    raises [Failure "<loc>: <message>"], where [<loc>] is
    [Loc.to_string loc]. *)
let errorf loc fmt =
  let fail_with_loc msg =
    failwith (Loc.to_string loc ^ ": " ^ msg)
  in
  Printf.ksprintf fail_with_loc fmt

(** { 6 Tools } *)

(** [from_to f t] returns the ascending list of integers [[f; f+1; ...; t]],
    or [[]] when [f > t].  The list is built back-to-front with an
    accumulator, so it runs in constant stack space even for long ranges. *)
let from_to f t =
  let rec loop acc i =
    if i < f then acc else loop (i :: acc) (i - 1)
  in
  loop [] t

(** Infix alias for [from_to]: [1 -- 3] is [[1; 2; 3]]. *)
let (--) = from_to

(** { 6 Idents and Paths } *)

(** [mk_idents pref n] creates the lowercase idents [pref1], ..., [prefn]
    (an empty list when [n < 1]). *)
let mk_idents : string -> int -> ident list = fun pref n ->
  let ident_of_index i = IdLid (_loc, pref ^ string_of_int i) in
  List.map ident_of_index (1 -- n)

(** [change_id f id] renames the trailing lowercase component of a dotted
    identifier: [A.x] becomes [A.y] where [y = f x].  The function descends
    only into the right-hand side of each dot.  A trailing component that is
    not lowercase (uppercase, application, antiquotation) is left unchanged. *)
let rec change_id f id =
  match id with
  | IdLid (loc, name) -> IdLid (loc, f name)
  | IdAcc (loc, prefix, rest) -> IdAcc (loc, prefix, change_id f rest)
  | other -> other

(** [label_of_path id] flattens a dotted path into a single lowercase label.
    Components are joined with ["_dot_"] and module names are uncapitalized,
    so [Abcc.x] becomes [abcc_dot_x].  Applications and antiquotations are
    not supported and raise [Assert_failure]. *)
let rec label_of_path id =
  match id with
  | IdLid (_, s) -> s
  | IdUid (_, s) -> String.uncapitalize s
  | IdAcc (_, a, b) -> String.concat "_dot_" [label_of_path a; label_of_path b]
  | _ -> assert false

(** [expr_of_id id] wraps the identifier [id] as an expression. *)
let expr_of_id : ident -> expr = fun id -> ExId (_loc, id)

(** [patt_of_id id] wraps the identifier [id] as a pattern. *)
let patt_of_id : ident -> patt = fun id -> PaId (_loc, id)

(** [convert_path ss] builds a dotted identifier from its components, e.g.
    [["M"; "N"; "x"]] becomes [M.N.x].

    A component starting with a lowercase letter or ['_'] becomes a lowercase
    ident; one starting with an uppercase letter becomes an uppercase ident.
    No well-formedness check is made, so [a.B.c.D] is accepted as-is.

    Raises [Assert_failure] on an empty list, on an empty component, or on a
    component starting with any other character. *)
let convert_path ss =
  let create = function
    | "" -> assert false
    | x ->
        (* OCaml lowercase identifiers may also start with an underscore. *)
        (match x.[0] with
         | 'a'..'z' | '_' -> <:ident< $lid:x$ >>
         | 'A'..'Z' -> <:ident< $uid:x$ >>
         | _ -> assert false)
  in
  let rec concat = function
    | [x] -> x
    | x :: xs -> <:ident< $x$ . $concat xs$ >>
    | [] -> assert false
  in
  concat (List.map create ss)

(** [name_of_ident id] returns the bare name of a single-component identifier,
    lowercase or uppercase.  Dotted paths, applications and antiquotations
    raise [Assert_failure]. *)
let name_of_ident id =
  match id with
  | IdLid (_, n) -> n
  | IdUid (_, n) -> n
  | _ -> assert false

(** [strip_locs_of_ident id] rebuilds [id] with every location replaced by
    the module-wide ghost location [_loc], so that identifiers can be compared
    structurally.  Antiquotations raise [Assert_failure]. *)
let rec strip_locs_of_ident id =
  match id with
  | IdLid (_, s) -> IdLid (_loc, s)
  | IdUid (_, s) -> IdUid (_loc, s)
  | IdAcc (_, a, b) -> IdAcc (_loc, strip_locs_of_ident a, strip_locs_of_ident b)
  | IdApp (_, a, b) -> IdApp (_loc, strip_locs_of_ident a, strip_locs_of_ident b)
  | IdAnt _ -> assert false

(** [same_idents id1 id2] is [true] when the two identifiers are structurally
    equal once their source locations are ignored. *)
let same_idents id1 id2 =
  let stripped1 = strip_locs_of_ident id1 in
  let stripped2 = strip_locs_of_ident id2 in
  stripped1 = stripped2

(** { 6 Tvars } *)

(** [tvar_var_name tv] is the variable name ["__tv_<name>"] used to stand for
    the type variable ['<name>].  Any other type raises [Assert_failure]. *)
let tvar_var_name = function
  | <:ctyp<'$tv$>> -> "__tv_" ^ tv
  | _ -> assert false

(** [patt_of_tvar tv] creates a pattern variable for a type variable [tv]. *)
let patt_of_tvar : ctyp -> patt = fun tv ->
  <:patt< $lid:tvar_var_name tv$ >>

(** [expr_of_tvar tv] creates an expression variable for a type variable [tv]. *)
let expr_of_tvar : ctyp -> expr = fun tv ->
  <:expr< $lid:tvar_var_name tv$ >>


(** { 6 Creators } *)

(** [create_patt_app f args] builds the curried pattern application
    [f arg1 ... argn] as left-nested [PaApp] nodes.  Variant constructors are
    curried in P4, so this is how a constructor pattern with arguments is
    built.  When [args] is empty, [f] is returned unchanged. *)
let create_patt_app : patt -> patt list -> patt = fun f patts ->
  let rec apply acc = function
    | [] -> acc
    | p :: rest -> apply (PaApp (_loc, acc, p)) rest
  in
  apply f patts

(** [create_expr_app f args] builds the curried application [f arg1 ... argn]
    as left-nested [ExApp] nodes.  Variant constructors are curried in P4, so
    this also builds constructor applications.  When [args] is empty, [f] is
    returned unchanged.  This is a variant of Gen.create_expr_app. *)
let create_expr_app : expr -> expr list -> expr = fun f exprs ->
  let rec apply acc = function
    | [] -> acc
    | e :: rest -> apply (ExApp (_loc, acc, e)) rest
  in
  apply f exprs


(** [create_top_let rec_ binds] builds the structure item
    [let b1 and ... and bn], or [let rec ...] when [rec_] is [true].

    The bindings are joined as [b1 and (b2 and (... and bn))], the same shape
    [concat_let_bindings] produces.  An empty list yields a [let] over
    [BiNil]. *)
let create_top_let : bool -> binding list -> str_item = fun rec_ binds ->
  (* Do not terminate the chain with an extra [BiNil]: the last binding is
     the innermost node.  A trailing empty binding may be printed as a
     dangling [and]. *)
  let rec join = function
    | [] -> BiNil _loc
    | [b] -> b
    | b :: bs -> BiAnd (_loc, b, join bs)
  in
  let binds = join binds in
  if rec_ then <:str_item< let rec $binds$ >>
  else <:str_item< let $binds$ >>


(** { 6 Concatenations } *)

(** [gen_concat_items nil cons xs] joins [xs] into the right-nested chain
    [cons x1 (cons x2 (... (cons x(n-1) xn)))].  The last element is the
    innermost value, so [nil] is returned only when [xs] is empty. *)
let gen_concat_items : 'a -> ('a -> 'a -> 'a) -> 'a list -> 'a =
  fun nil cons xs ->
    match List.rev xs with
    | [] -> nil
    | last :: rev_init ->
        List.fold_left (fun acc x -> cons x acc) last rev_init

(** Join class-body items with [CrSem]; an empty list gives [CrNil]. *)
let concat_class_str_items items =
  gen_concat_items (CrNil _loc) (fun a b -> CrSem (_loc, a, b)) items

(** Join let-bindings with [BiAnd]; an empty list gives [BiNil]. *)
let concat_let_bindings binds =
  gen_concat_items (BiNil _loc) (fun a b -> BiAnd (_loc, a, b)) binds

(** Join structure items with [StSem]; an empty list gives [StNil]. *)
let concat_str_items items =
  gen_concat_items (StNil _loc) (fun a b -> StSem (_loc, a, b)) items

(** Join signature items with [SgSem]; an empty list gives [SgNil]. *)
let concat_sig_items items =
  gen_concat_items (SgNil _loc) (fun a b -> SgSem (_loc, a, b)) items



(** { 6 Strippers } *)

(** [strip_field_flags cty] repeatedly peels outer [mutable] ([TyMut]) and
    [private] ([TyPrv]) wrappers off a field type and returns what remains. *)
let rec strip_field_flags cty =
  match cty with
  | TyMut (_, inner) -> strip_field_flags inner
  | TyPrv (_, inner) -> strip_field_flags inner
  | _ -> cty

(** [strip_ident_loc id] rebuilds [id] with every location replaced by the
    module-wide ghost location [_loc].  Unlike [strip_locs_of_ident], it keeps
    antiquotations instead of rejecting them. *)
let rec strip_ident_loc : ident -> ident = fun id ->
  let strip = strip_ident_loc in
  match id with
  | IdLid (_, n) -> IdLid (_loc, n)
  | IdUid (_, n) -> IdUid (_loc, n)
  | IdAnt (_, n) -> IdAnt (_loc, n)
  | IdAcc (_, a, b) -> IdAcc (_loc, strip a, strip b)
  | IdApp (_, a, b) -> IdApp (_loc, strip a, strip b)

(** { 6 Deconstruction } *)

(** [split_by_comma e] flattens a comma-separated expression into the list of
    its components, in left-to-right order.  Empty components are dropped;
    any other expression becomes a one-element list. *)
let split_by_comma e =
  let rec collect e acc =
    match e with
    | <:expr< $e1$, $e2$ >> -> collect e1 (collect e2 acc)
    | <:expr< >> -> acc
    | _ -> e :: acc
  in
  collect e []

(** [deconstr_tydef tp] classifies the right-hand side of a type definition.
    Any outer [private] ([TyPrv]) wrappers are removed first.  The result is
    one of:
    - [`Nil loc]: [TyNil], i.e. no right-hand side;
    - [`Mani (loc, ty, rhs)]: a manifest [ty = rhs], with [rhs] itself
      deconstructed recursively;
    - [`Record (loc, fields)]: each field is [(loc, label_ident, field_type)].
      Field types are returned unchanged, so no [strip_field_flags] is applied;
    - [`Sum (loc, cases)]: each constructor is [(loc, ident, arg_types)].  For
      [C : args -> res] cases only the argument types are kept;
    - [`Variant (loc, cases)]: [[= ...]], [[> ...]] or [[< ...]] polymorphic
      variants, each case being [(loc, tag, arg_types)];
    - [`Alias (loc, ty)]: any other type.
    Record fields or constructors of an unexpected shape raise
    [Assert_failure]. *)
let rec deconstr_tydef tp =
  (* Peel [private] markers, which may be nested. *)
  let rec strip_private = function
    | TyPrv (_, ctyp) -> strip_private ctyp
    | ctyp -> ctyp
  in
  match strip_private tp with
  | TyNil loc -> `Nil loc 
  | TyMan (loc, ctyp, ctyp') -> `Mani (loc, ctyp, deconstr_tydef ctyp') 
  | TyRec (loc ,ctyp) ->
      (* The inner [loc] shadows the record's: each field keeps its own. *)
      let fields = List.map (function
        | TyCol (loc, TyId(_, lab_id), ctyp) -> loc, lab_id, ctyp
        | _ -> assert false) (list_of_ctyp ctyp [] )
      in
      `Record (loc, fields) 
  | TySum (loc, ctyp) ->
      (* Argument types are split with [list_of_ctyp]; the result type of a
         [C : args -> res] case is discarded. *)
      let cases = List.map (function
        | <:ctyp@loc< $id:id$ : $ctyp$ -> $_ctyp'$ >> -> loc, id, list_of_ctyp ctyp []
        | <:ctyp@loc< $id:id$ of $ctyp$ >> -> loc, id, list_of_ctyp ctyp []
        | <:ctyp@loc< $id:id$ >> -> loc, id, []
        | _ -> assert false) (list_of_ctyp ctyp []) in
      `Sum (loc, cases) 
  (* NOTE(review): the combined [< ... > ...] form is not matched here and
     falls through to `Alias below; confirm that this is intended. *)
  | TyVrnEq (loc, ctyp)
  | TyVrnSup (loc, ctyp)
  | TyVrnInf (loc, ctyp) ->
      let cases = List.map (function
        | <:ctyp@loc< `$idstr$ of $ctyp$ >> -> loc, idstr, list_of_ctyp ctyp []
        | <:ctyp@loc< `$idstr$ >> -> loc, idstr, []
        | _ -> assert false) (list_of_ctyp ctyp [])
      in
      `Variant (loc, cases)
  | ctyp -> `Alias (loc_of_ctyp ctyp, ctyp)

(** [deconstr_variant_type ty] splits an exact polymorphic variant type
    [[ = `A of t1 | `B ]] into [(loc, cases)].  Each case is
    [(loc, tag, arg_types)].  Only the [[ = ... ]] form is accepted; any other
    type, or a case of an unexpected shape, raises [Assert_failure]. *)
let deconstr_variant_type ty =
  match ty with
  | TyVrnEq (loc, row) ->
      (* Inside [case_of_ctyp], [loc] is rebound to the location of each case. *)
      let case_of_ctyp = function
        | <:ctyp@loc< `$idstr$ of $ctyp$ >> -> loc, idstr, list_of_ctyp ctyp []
        | <:ctyp@loc< `$idstr$ >> -> loc, idstr, []
        | _ -> assert false
      in
      loc, List.map case_of_ctyp (list_of_ctyp row [])
  | _ -> assert false

(** [deconstr_object_type ty] splits an object type into
    [(ty_loc, fields, row_flag)].  Each field is
    [(loc_of_method_name, method_ident, method_type)].  Raises [Assert_failure]
    if [ty] is not an object type or a field has an unexpected shape. *)
let deconstr_object_type ty =
  match ty with
  | TyObj (ty_loc, body, flag) ->
      (* Uses the location of the method name, not of the whole field. *)
      let field_of_ctyp = function
        | TyCol (_, TyId (id_loc, id), field_ty) -> id_loc, id, field_ty
        | _ -> assert false
      in
      ty_loc, List.map field_of_ctyp (list_of_ctyp body []), flag
  | _ -> assert false

(** [type_definitions_are_recursive rec_ tds] is [true] iff [rec_] is set and
    [Gen.type_is_recursive name tds] holds for at least one declaration in
    [tds].  [tds] is not inspected when [rec_] is [false].  A non-[TyDcl]
    element raises [Assert_failure]. *)
let type_definitions_are_recursive rec_ tds =
  let decl_is_recursive = function
    | TyDcl (_, name, _, _, _) -> Gen.type_is_recursive name tds
    | _ -> assert false
  in
  rec_ && List.exists decl_is_recursive (list_of_ctyp tds [])



(** { 6 Type construction } *)

(** [create_param_type params name] builds the type application
    [(p1, ..., pn) name] as left-nested [TyApp] nodes over the lowercase type
    constructor [name].  With no parameters the result is just [name]. *)
let create_param_type : ctyp list -> string -> ctyp = fun params name ->
  let rec apply ty = function
    | [] -> ty
    | p :: ps -> apply (TyApp (_loc, ty, p)) ps
  in
  apply <:ctyp< $lid: name$ >> params

(** [create_for_all params ty] builds the explicitly polymorphic type
    [p1 ... pn . ty].  The variables are chained with left-nested [TyApp]
    nodes in the order given.  With no parameters, [ty] is returned
    unchanged. *)
let create_for_all : ctyp list -> ctyp -> ctyp = fun params ty ->
  match params with
  | [] -> ty
  | x :: xs ->
      (* fold_left keeps the order p1 p2 ... pn; a fold_right here would
         emit p1 pn ... p2. *)
      TyPol (_loc, List.fold_left (fun st p -> TyApp (_loc, st, p)) x xs, ty)

(** AST mapper that replaces every type variable ['s] with the lowercase type
    constructor [tyvar__s].  Sub-terms are mapped first (bottom-up), and the
    resulting [TyId] node keeps the variable's location. *)
class untyvar = object
  inherit map as super
  method! ctyp ty =
    match super#ctyp ty with
    | TyQuo (loc, name) -> TyId (loc, <:ident< $lid: "tyvar__" ^ name $>>)
    | mapped -> mapped
end

(** [create_type_quant params ty] builds the quantification
    [type a ... z . ty'] from the type variables ['a ... 'z] in [params].
    Every type variable ['x], in [params] and in [ty], is renamed to the type
    constructor [tyvar__x] (see [untyvar]).  The quantified names are kept in
    the order given.  When [params] is empty, [ty] is returned unchanged and
    not renamed. *)
let create_type_quant : ctyp list -> ctyp -> ctyp = fun params ty ->
  let untyvar = new untyvar in
  match List.map untyvar#ctyp params with
  | [] -> ty
  | x :: xs ->
      (* fold_left keeps the order a b ... z; a fold_right here would emit
         a z ... b. *)
      TyTypePol (_loc,
                 List.fold_left (fun st p -> TyApp (_loc, st, p)) x xs,
                 untyvar#ctyp ty)

(** [create_object_type poly fields] builds an object type with one method
    [lab : ctyp] per element of [fields].  The row flag is [RvRowVar] when
    [poly] is [true] and [RvNil] otherwise. *)
let create_object_type : bool -> (string * ctyp) list -> ctyp = fun poly fields ->
  let method_of (lab, ctyp) =
    TyCol (_loc, TyId (_loc, <:ident< $lid:lab$ >>), ctyp)
  in
  let methods = List.map method_of fields in
  let row = if poly then RvRowVar else RvNil in
  TyObj (_loc, tySem_of_list methods, row)

(** [create_tuple [e1; ...; en]] builds the tuple expression [(e1, ..., en)].
    A single expression is returned as-is, and [[]] raises [Assert_failure].
    Do not use this for variant creations, since variants are curried in P4. *)
let create_tuple : expr list -> expr = fun exprs ->
  match exprs with
  | [] -> assert false
  | [single] -> single
  | _ :: _ :: _ -> ExTup (_loc, exCom_of_list exprs)

(** [create_patt_tuple [p1; ...; pn]] builds the tuple pattern [(p1, ..., pn)].
    A single pattern is returned as-is, and [[]] raises [Assert_failure]. *)
let create_patt_tuple : patt list -> patt = fun patts ->
  match patts with
  | [] -> assert false
  | [single] -> single
  | _ :: _ :: _ -> PaTup (_loc, paCom_of_list patts)

(** [create_list [e1; ...; en]] builds the list expression [[e1; ...; en]]
    as a chain of [::] applications ending in [[]]. *)
let create_list : expr list -> expr = fun exprs ->
  List.fold_right (fun e tail -> <:expr< $e$ :: $tail$ >>) exprs <:expr< [] >>

(** [create_patt_list [p1; ...; pn]] builds the list pattern [[p1; ...; pn]]
    as a chain of [::] applications ending in [[]]. *)
let create_patt_list : patt list -> patt = fun patts ->
  List.fold_right (fun p tail -> <:patt< $p$ :: $tail$ >>) patts <:patt< [] >>

(** [create_record [(l1, e1); ...; (ln, en)]] builds the record expression
    [{ l1 = e1; ...; ln = en }].  The [ExNil] slot means the result is not a
    [{ e with ... }] update. *)
let create_record : (ident * expr) list -> expr = fun label_exprs ->
  let field_of (label, e) = RbEq (_loc, label, e) in
  let fields = rbSem_of_list (List.map field_of label_exprs) in
  ExRec (_loc, fields, ExNil _loc)

(** [create_object [(l1, e1); ...; (ln, en)]] builds the immediate object
    [object method l1 = e1; ...; method ln = en end] with no self pattern. *)
let create_object : (string * expr) list -> expr = fun label_exprs ->
  let method_of (l, e) = <:class_str_item< method $l$ = $e$ >> in
  let body = crSem_of_list (List.map method_of label_exprs) in
  ExObj (_loc, PaNil _loc, body)


(** { 6 Misc } *)

(** Helper functions that make it easier to construct programs in the
    Original syntax. *)

(** [make_class_eq ?loc ?virt vars ident clexpr] builds the class binding
    [[virtual] [vars] ident = clexpr].  [vars] are the class type parameters.
    [loc] defaults to the ghost [_loc], and [virt] (default [false]) marks the
    class as virtual. *)
let make_class_eq ?(loc=_loc) ?(virt=false) vars ident clexpr =
  let virtual_flag = if virt then ViVirtual else ViNil in
  let class_head = CeCon (loc, virtual_flag, ident, tyCom_of_list vars) in
  CeEq (loc, class_head, clexpr)

(** [make_class defs] builds the structure item [class d1 and ... and dn]. *)
let make_class class_defs =
  let defs = ceAnd_of_list class_defs in
  StCls (_loc, defs)