(* Source: ocaml-lib / sumonad / sumonad.ml *)


(** Streams whose tails may be suspended behind a thunk.

    A [Lazy] node holds a plain closure: it is re-run every time the node is
    forced (there is no memoization). *)
module Stream =
struct
  type 'a t =
    | Nil
    | Single of 'a
    | Cons of 'a * 'a t
    | Lazy of (unit -> 'a t)

  (** [is_empty str] forces leading [Lazy] nodes until a concrete node is
      reached and tells whether that node is [Nil]. *)
  let rec is_empty (str : 'a t) : bool =
    match str with
    | Lazy thunk -> is_empty (thunk ())
    | Nil -> true
    | Single _ | Cons _ -> false

  (** [concat (left, right)] is the stream of [left]'s elements followed by
      [right]'s. Suspended parts of [left] stay suspended. *)
  let rec concat ((left, right) : 'a t * 'a t) : 'a t =
    match left, right with
    | Nil, _ -> right
    | _, Nil -> left
    | Single x, _ -> Cons (x, right)
    | Cons (x, tl), _ -> Cons (x, concat (tl, right))
    | Lazy thunk, _ -> Lazy (fun () -> concat (thunk (), right))

  (** [read str] returns the first element together with the remaining
      stream, or [None] when the stream is empty. *)
  let rec read (str : 'a t) : ('a * 'a t) option =
    match str with
    | Lazy thunk -> read (thunk ())
    | Nil -> None
    | Single x -> Some (x, Nil)
    | Cons (x, tl) -> Some (x, tl)

  (** [fold f init str] is a left fold of [f] over every element of [str]. *)
  let rec fold (f : 'a -> 'b -> 'a) (init : 'a) (str : 'b t) : 'a =
    match read str with
    | Some (x, rest) -> fold f (f init x) rest
    | None -> init

  (** [of_list l] builds a stream by pushing each element of [l] in turn at
      the front, so the stream enumerates [l] in reverse order. *)
  let of_list l =
    let rec push acc = function
      | [] -> acc
      | x :: tl -> push (Cons (x, acc)) tl
    in
    push Nil l

  (** [of_hashtbl ht] builds a stream of the [(key, value)] bindings visited
      by [Hashtbl.fold] on [ht]. *)
  let of_hashtbl ht =
    let prepend key value acc = Cons ((key, value), acc) in
    Hashtbl.fold prepend ht Nil
end

(** Outcome of one computation branch: a value ([Result]) or an error
    ([Error]). *)
type ('a,'e) either = Result of 'a | Error of 'e

(** A computation: a function from an input state of type ['s] to a stream of
    outcomes, each outcome being paired with a state. *)
type ('a,'e,'s) t = 's -> (('a,'e) either * 's) Stream.t

(** [return x] is the computation with exactly one outcome, [Result x],
    paired with the unchanged input state. *)
let return (x : 'a) : ('a,'e,'s) t =
  fun state ->
    let outcome = Result x in
    Stream.Single (outcome, state)

(** [bind m k] runs [m] on the input state and, for each of its outcomes,
    continues with [k] on the value and state of that outcome. Error outcomes
    are passed through without calling [k]. *)
let rec bind (m : ('a,'e,'s) t) (k : 'a -> ('b,'e,'s) t) : ('b,'e,'s) t =
  fun s -> bind_aux (m s) k

(* Applies the continuation [k] to every outcome of an already produced
   stream. The tail of a [Cons] is processed behind a [Lazy] node, so later
   outcomes are only computed on demand. *)
and bind_aux ms k =
  match ms with
    | Stream.Lazy thunk -> bind_aux (thunk ()) k
    | Stream.Nil -> Stream.Nil
    | Stream.Single (outcome, state) -> bind_either outcome state k
    | Stream.Cons ((outcome, state), rest) ->
      let continue_rest () = bind_aux rest k in
      let head = bind_either outcome state k in
      Stream.concat (head, Stream.Lazy continue_rest)

(* Continues a single outcome: a result is fed to [k] with its state, an
   error is re-emitted as is with its state. *)
and bind_either x s' k =
  match x with
    | Error e -> Stream.Single (Error e, s')
    | Result v -> k v s'

(** [fail] is the computation with no outcome at all, whatever the state. *)
let fail : ('a,'e,'s) t =
  fun _ -> Stream.Nil

(** [error msg] is the computation whose single outcome is [Error msg],
    paired with the unchanged input state. *)
let error msg : ('a,'e,'s) t =
  fun state ->
    let outcome = Error msg in
    Stream.Single (outcome, state)

(** [mplus m1 m2] yields the outcomes of [m1] followed by those of [m2], both
    run on the same input state. [m2] is only run when the stream is consumed
    past the outcomes of [m1]. *)
let mplus (m1 : ('a,'e,'s) t) (m2 : ('a,'e,'s) t) : ('a,'e,'s) t =
  fun s ->
    let first = m1 s in
    let second = Stream.Lazy (fun () -> m2 s) in
    Stream.concat (first, second)

(** [cut m1 k2 m3] runs [m1] on the input state. If [m1] has at least one
    outcome (result or error), all its outcomes are continued with [k2] as
    [bind] would do; otherwise the computation falls back to [m3] on the same
    input state.

    The head of [m1]'s stream is forced exactly once: [Stream.Lazy] nodes are
    plain closures that are re-run on each forcing, so testing emptiness and
    then binding the untouched stream would compute the first outcome (and
    perform its side effects) twice. *)
let cut (m1 : ('a,'e,'s) t) (k2 : 'a -> ('b,'e,'s) t) (m3 : ('b,'e,'s) t) : ('b,'e,'s) t =
  fun s ->
    match Stream.read (m1 s) with
      | None -> m3 s
      | Some (first, rest) ->
        (* Rebuild a stream whose head is already computed. *)
        bind_aux (Stream.Cons (first, rest)) k2

(** [guard cond] has a single unit result (state unchanged) when [cond] holds
    and no outcome otherwise. *)
let guard (cond : bool) : (unit,'e,'s) t =
  fun s ->
    if not cond
    then Stream.Nil
    else Stream.Single (Result (), s)

(** [ifthenelse cond m1 m2] behaves as [m1] when [cond] holds and as [m2]
    otherwise. *)
let ifthenelse (cond : bool) (m1 : ('a,'e,'s) t) (m2 : ('a,'e,'s) t) : ('a,'e,'s) t =
  fun s ->
    let chosen = if cond then m1 else m2 in
    chosen s

(** [succeeds m] has a single unit result, with the input state unchanged,
    when [m] has at least one outcome on that state; it has no outcome when
    [m] has none. *)
let succeeds (m : ('a,'e,'s) t) : (unit,'e,'s) t =
  fun s ->
    match Stream.is_empty (m s) with
      | true -> Stream.Nil
      | false -> Stream.Single (Result (), s)

(** [fails m] has a single unit result, with the input state unchanged, when
    [m] has no outcome on that state; it has no outcome as soon as [m] has
    one. *)
let fails (m : ('a,'e,'s) t) : (unit,'e,'s) t =
  fun s ->
    match Stream.is_empty (m s) with
      | true -> Stream.Single (Result (), s)
      | false -> Stream.Nil

(** [aggreg f init m] groups the [(key, value)] results of [m] by key and
    folds the values of each group with [f], starting from [init]. It yields
    one [(key, aggregate)] result per distinct key.

    The whole stream of [m] is consumed before anything is produced. Error
    outcomes of [m] are ignored, and every produced result is paired with the
    input state [s], not with the states reached by [m]. *)
let aggreg (f : 'c -> 'b -> 'c) (init : 'c) (m : ('a * 'b,'e,'s) t) : ('a * 'c,'e,'s) t =
  fun s ->
    let groups = Hashtbl.create 13 in
    (* Folds one outcome into the table; the unit accumulator is only there
       to fit the signature of [Stream.fold]. *)
    let accumulate () (outcome, _) =
      match outcome with
        | Error _ -> ()
        | Result (key, value) ->
          let current = try Hashtbl.find groups key with Not_found -> init in
          Hashtbl.replace groups key (f current value)
    in
    Stream.fold accumulate () (m s);
    Hashtbl.fold
      (fun key total out -> Stream.Cons ((Result (key, total), s), out))
      groups Stream.Nil

(** [get_state] returns the current state as its single result and leaves it
    unchanged. *)
let get_state : ('s,'e,'s) t =
  fun current -> Stream.Single (Result current, current)

(** [set_state s2] discards the input state and has a single unit result
    paired with [s2]. *)
let set_state (s2 : 's) : (unit,'e,'s) t =
  fun _ -> Stream.Single (Result (), s2)

(** [update modif] applies [modif] to the current state. On [Result s'] it
    has a unit result paired with [s']; on [Error e] it has the error outcome
    [e] paired with the unchanged state. *)
let update (modif : 's -> ('s,'e) either) : (unit,'e,'s) t =
  fun s ->
    let outcome, next_state =
      match modif s with
        | Error e -> (Error e, s)
        | Result s' -> (Result (), s')
    in
    Stream.Single (outcome, next_state)

(** [view access] has as single outcome whatever [access] returns on the
    current state (result or error), paired with the unchanged state. *)
let view (access : 's -> ('a,'e) either) : ('a,'e,'s) t =
  fun s ->
    let outcome = access s in
    Stream.Single (outcome, s)

(** [effect action] calls [action] on the current state for its side effect,
    then has a single unit result paired with the unchanged state.

    NOTE(review): [effect] appears to be a reserved keyword in recent OCaml
    releases (5.3 and later, to be confirmed); the name is kept here because
    renaming it would break callers. *)
let effect (action : 's -> unit) : (unit,'e,'s) t =
  fun s ->
    let () = action s in
    Stream.Single (Result (), s)


(** [kiter f k m s] runs [m] on state [s] and applies [f] to its results in
    stream order, stopping once [k] results have been handled or the stream
    is exhausted. Error outcomes are skipped and do not count towards [k].
    Since the loop only stops on [k = 0], a negative [k] never stops it
    early. *)
let rec kiter (f : 'a -> unit) (k : int) (m : ('a,'e,'s) t) (s : 's) : unit =
  kiter_aux f k (m s)

(* Same as [kiter], on an already produced stream of outcomes. *)
and kiter_aux f k str =
  if k <> 0 then begin
    match Stream.read str with
      | None -> ()
      | Some ((Result v, _), rest) -> f v; kiter_aux f (k-1) rest
      | Some ((Error _, _), rest) -> kiter_aux f k rest
  end

(** [klist k m s] runs [m] on state [s] and returns, in stream order, the
    list of its first [k] results (fewer if the stream is exhausted). Error
    outcomes are skipped and do not count towards [k]. Since the loop only
    stops on [k = 0], a negative [k] collects every result. *)
let rec klist (k : int) (m : ('a,'e,'s) t) (s : 's) : 'a list =
  klist_aux k (m s)

(* Same as [klist], on an already produced stream of outcomes. *)
and klist_aux k str =
  if k = 0
  then []
  else
    match Stream.read str with
      | None -> []
      | Some ((Result v, _), rest) -> v :: klist_aux (k-1) rest
      | Some ((Error _, _), rest) -> klist_aux k rest

(** [fold f acc m s] runs [m] on state [s] and folds [f] over its results in
    stream order, starting from [acc].

    The fold stops at the first error outcome and returns the accumulator
    computed so far; results located after that error are not visited.

    @return the final accumulator. (The result type used to be annotated
    [unit], which silently forced ['acc] to be [unit].) *)
let rec fold (f : 'acc -> 'a -> 'acc) (acc : 'acc) (m : ('a,'e,'s) t) (s : 's) : 'acc =
  iter_aux f acc (m s)

(* Same as [fold], on an already produced stream of outcomes. *)
and iter_aux f acc str =
  match Stream.read str with
    | None -> acc
    | Some ((Result v, _), rest) -> iter_aux f (f acc v) rest
    | Some ((Error _, _), _) -> acc

(** [exists f m s] runs [m] on state [s] and tells whether [f] holds for at
    least one of its results. The stream is consumed only up to the first
    result satisfying [f]. Error outcomes are skipped. *)
let rec exists (f : 'a -> bool) (m : ('a,'e,'s) t) (s : 's) : bool =
  exists_aux f (m s)

(* Same as [exists], on an already produced stream of outcomes. *)
and exists_aux f str =
  match Stream.read str with
    | None -> false
    | Some ((Result v, _), rest) ->
      if f v then true else exists_aux f rest
    | Some ((Error _, _), rest) -> exists_aux f rest

(** [forall f m s] runs [m] on state [s] and tells whether [f] holds for all
    of its results. The stream is consumed only up to the first result that
    does not satisfy [f]. Error outcomes are skipped. *)
let rec forall (f : 'a -> bool) (m : ('a,'e,'s) t) (s : 's) : bool =
  forall_aux f (m s)

(* Same as [forall], on an already produced stream of outcomes. *)
and forall_aux f str =
  match Stream.read str with
    | None -> true
    | Some ((Result v, _), rest) ->
      if f v then forall_aux f rest else false
    | Some ((Error _, _), rest) -> forall_aux f rest


(** Computations over integers. *)
module Int =
struct
  (** [range a b] yields the integers from [a] to [b] inclusive, in
      increasing order, each paired with the unchanged input state. It has
      no outcome when [a > b].

      The definition is abstracted over the state so that building
      [range a b] is constant-time: the recursive call is a partial
      application that is only unfolded when the stream is consumed.
      Handling [a = b] separately avoids computing [a+1], which would wrap
      around when [b = max_int]. *)
  let rec range (a : int) (b : int) : (int,'e,'s) t =
    fun s ->
      if a > b then fail s
      else if a = b then return a s
      else mplus (return a) (range (a+1) b) s
end

(** Computations over lists. *)
module List =
struct
  (** [choose l] yields each element of [l] in order as a separate result,
      each paired with the unchanged input state. It has no outcome on the
      empty list.

      The definition is abstracted over the state so that building
      [choose l] is constant-time: the recursive call is a partial
      application that is only unfolded when the stream is consumed, instead
      of recursing over the whole list up front. *)
  let rec choose (l : 'a list) : ('a,'e,'s) t =
    fun s ->
      match l with
        | [] -> fail s
        | x::rest -> mplus (return x) (choose rest) s

  (** [map f l] applies [f] to the elements of [l] from left to right,
      threading the state through [bind], and returns the list of the values
      produced, in the same order as [l]. *)
  let rec map (f : 'a -> ('b,'e,'s) t) : 'a list -> ('b list,'e,'s) t = function
    | [] -> return []
    | x::lx -> bind (f x) (fun y -> bind (map f lx) (fun ly -> return (y::ly)))
end