Source

ocaml-lib / monad / logicMonad.ml


(* streams as implementation of MonadPlus *)

module Stream =
struct
  (** Possibly infinite streams of values.  [Concat] and [Lazy] let
      appends and interleavings be represented without traversing the
      left operand or forcing a suspended computation. *)
  type 'a t =
    | Nil
    | Single of 'a
    | Cons of 'a * 'a t
    | Concat of 'a t * 'a t
    | Lazy of 'a t Lazy.t

  (** Smart constructor for appending two streams: drops an empty
      operand and turns a singleton left operand into a [Cons] cell. *)
  let make_concat (front, back) =
    match front, back with
    | Nil, _ -> back
    | _, Nil -> front
    | Single x, _ -> Cons (x, back)
    | _, _ -> Concat (front, back)

  (** The one-element stream. *)
  let return (x : 'a) : 'a t = Single x

  (** [bind str k] replaces every element [x] of [str] by the stream
      [k x], appending the results in order.  Suspended ([Lazy]) parts
      of [str] stay suspended; all other parts are expanded eagerly. *)
  let rec bind (str : 'a t) (k : 'a -> 'b t) : 'b t =
    match str with
    | Nil -> Nil
    | Single x -> k x
    | Cons (hd, tl) ->
      (* [k hd] must run before the tail is bound *)
      let from_hd = k hd in
      make_concat (from_hd, bind tl k)
    | Concat (front, back) ->
      let from_front = bind front k in
      make_concat (from_front, bind back k)
    | Lazy susp -> Lazy (lazy (bind (Lazy.force susp) k))

  (** The empty stream. *)
  let mzero = Nil

  (** Append: all elements of [front], then all elements of [back]. *)
  let mplus front back = make_concat (front, back)

  (** [split str] returns the first element of [str] together with the
      rest of the stream, or [None] when [str] is empty.  Suspensions
      are forced as needed to reach that first element. *)
  let rec split (str : 'a t) : ('a * 'a t) option =
    match str with
    | Nil -> None
    | Single x -> Some (x, Nil)
    | Cons (x, rest) -> Some (x, rest)
    | Concat (front, back) ->
      (match split front with
       | Some (x, front_rest) -> Some (x, make_concat (front_rest, back))
       | None -> split back)
    | Lazy susp -> split (Lazy.force susp)

  (** Fair merge: alternates between the elements of the two streams,
      starting with [front].  Each recursive step is suspended with
      [lazy], so the remainder is only computed when demanded. *)
  let rec interleave front back =
    match split front with
    | Some (x, front_rest) -> Cons (x, Lazy (lazy (interleave back front_rest)))
    | None -> back

  (** Left fold over every element of the stream, in stream order.
      Does not return on an infinite stream. *)
  let fold (f : 'acc -> 'a -> 'acc) (init : 'acc) (str : 'a t) : 'acc =
    let rec loop acc rest =
      match split rest with
      | None -> acc
      | Some (x, rest') -> loop (f acc x) rest'
    in
    loop init str
end

(* types for MonadPlus(MonadState(MonadError(Id))) *)

(* Outcome of one branch of a computation: a value or an error.  Errors
   are ordinary stream elements (see [raise_] and [yield]), so they sit
   alongside successful results rather than aborting the computation.
   NOTE: the [Error] constructor shadows the standard [result] type's
   [Error] for the rest of this file. *)
type ('a,'e) either = Result of 'a | Error of 'e

(* A computation maps an input state to a stream of branches; each
   branch pairs its outcome with the state reached on that branch. *)
type ('a,'e,'s) t = 's -> (('a,'e) either * 's) Stream.t

(** [lift k] extends the continuation [k] to a whole branch.  A
    successful branch feeds its value and state to [k]; an error branch
    is passed through unchanged (same error, same state) as a one-element
    stream. *)
let lift (k : 'a -> ('b,'e,'s) t) : ('a,'e) either * 's -> (('b,'e) either * 's) Stream.t =
  fun (outcome, state) ->
    match outcome with
    | Error e -> Stream.return (Error e, state)
    | Result x -> k x state

(* implementation for Monad *)

(** [return x] has exactly one successful branch, with result [x] and the
    state left unchanged. *)
let return (x : 'a) : ('a,'e,'s) t =
  fun state -> Stream.return (Result x, state)

(** [bind m k] runs [m] from the incoming state and continues every
    successful branch with [k], applied to that branch's value and
    state.  Error branches of [m] are propagated unchanged (see [lift]).
    The definition is not recursive, so it carries no [rec] flag
    (an unused [rec] triggers warning 39). *)
let bind (m : ('a,'e,'s) t) (k : 'a -> ('b,'e,'s) t) : ('b,'e,'s) t =
  fun s -> Stream.bind (m s) (lift k)

(* implementation for MonadPlus *)

(** The computation with no branches at all; the input state is ignored. *)
let mzero : ('a,'e,'s) t =
  fun _ -> Stream.mzero

(** [mplus m1 m2] runs both alternatives from the same input state; all
    branches of [m1] come before those of [m2]. *)
let mplus (m1 : ('a,'e,'s) t) (m2 : ('a,'e,'s) t) : ('a,'e,'s) t =
  fun state -> Stream.mplus (m1 state) (m2 state)

(** [mguard cond] has one successful [()] branch (state unchanged) when
    [cond] holds, and no branch at all otherwise. *)
let mguard (cond : bool) : (unit,'e,'s) t =
  if cond
  then (fun state -> Stream.return (Result (), state))
  else (fun _ -> Stream.mzero)

(** [msplit str] finds the first successful branch of [str] and returns
    [Some (value, state, rest)], where [rest] is the stream after that
    branch.  Error branches that precede it are silently skipped; [rest]
    may still contain errors.  Returns [None] when [str] has no
    successful branch. *)
let rec msplit str =
  match Stream.split str with
  | Some ((Result x, s1), rest) -> Some (x, s1, rest)
  | Some ((Error _, _), rest) -> msplit rest
  | None -> None

(** [ifte m1 k1 m2]: when [m1] has at least one successful branch, the
    result is [k1] applied to each successful result of [m1] (with that
    branch's state).  Otherwise the result is [m2] run from the original
    state.  Error branches of [m1] that precede its first success are
    discarded by [msplit], as are all of them when there is no success.
    Error branches after the first success are propagated via [lift]. *)
let ifte (m1 : ('a,'e,'s) t) (k1 : 'a -> ('b,'e,'s) t) (m2 : ('b,'e,'s) t) : ('b,'e,'s) t =
  fun s ->
    let outcome = msplit (m1 s) in
    match outcome with
    | None -> m2 s
    | Some (first, s_first, others) ->
      let continue_ = lift k1 in
      Stream.mplus (k1 first s_first) (Stream.bind others continue_)

(** [once m] keeps only the first successful branch of [m] (its value and
    state).  Error branches before it are discarded; if [m] has no
    successful branch, the result is empty. *)
let once (m : ('a,'e,'s) t) : ('a,'e,'s) t =
  fun s ->
    match msplit (m s) with
    | Some (x, s1, _) -> Stream.return (Result x, s1)
    | None -> Stream.mzero

(** [mplus_fair m1 m2] runs both alternatives from the same state and
    alternates their branches (starting with [m1]) instead of listing all
    of [m1] first. *)
let mplus_fair (m1 : ('a,'e,'s) t) (m2 : ('a,'e,'s) t) : ('a,'e,'s) t =
  fun state -> Stream.interleave (m1 state) (m2 state)

(** [bind_fair m k] is a fair variant of [bind]: the streams that [k]
    produces for successive results of [m] are interleaved rather than
    concatenated.  An infinite stream from one continuation therefore
    does not prevent results of later continuations from appearing.
    Fairness is applied at every step, not only to the first result.
    Error branches of [m] are propagated in place with their state, as
    [bind] does, instead of being dropped when they precede the first
    successful result. *)
let bind_fair (m : ('a,'e,'s) t) (k : 'a -> ('b,'e,'s) t) : ('b,'e,'s) t =
  (* Walk [m]'s stream one branch at a time.  The recursive step is
     suspended so that only the part of the result that is demanded is
     expanded. *)
  let rec go str =
    match Stream.split str with
    | None -> Stream.mzero
    | Some ((Result x, s1), rest) ->
      Stream.interleave (k x s1) (Stream.Lazy (lazy (go rest)))
    | Some ((Error e, s1), rest) ->
      Stream.Cons ((Error e, s1), Stream.Lazy (lazy (go rest)))
  in
  fun s -> go (m s)

(** [fold_aux f acc str n] folds [f] over the successful branches of
    [str], in stream order, starting from [acc]; error branches are
    skipped (via [msplit]).  [n] is how many more results may be
    consumed: [0] stops immediately, a negative [n] means no bound. *)
let rec fold_aux f acc str n =
  match n with
  | 0 -> acc
  | _ ->
    (match msplit str with
     | None -> acc
     | Some (x, _, rest) ->
       let remaining = if n > 0 then n - 1 else -1 in
       fold_aux f (f acc x) rest remaining)

(** [fold ?limit f init m] runs [m] from the current state and folds [f]
    over its successful results, starting from [init].  The final
    accumulator is returned as a single successful branch in the
    ORIGINAL state; state changes made by [m] are discarded.  With a
    non-negative [limit], at most [limit] results are folded; an absent
    or negative [limit] means no bound. *)
let fold ?limit (f : 'acc -> 'a -> 'acc) (init : 'acc) (m : ('a,'e,'s) t) : ('acc,'e,'s) t =
  fun s ->
    let budget =
      match limit with
      | Some n when n >= 0 -> n
      | Some _ | None -> -1
    in
    Stream.return (Result (fold_aux f init (m s) budget), s)

(* implementation of MonadError *)

(** [raise_ e] produces a single error branch carrying [e], with the
    state unchanged.  It is an ordinary branch: alternatives combined
    with it through [mplus] still run. *)
let raise_ (e : 'e) : ('a,'e,'s) t =
  fun state -> Stream.return (Error e, state)

(** [catch m k] runs [m] and replaces each error branch [Error e] (in
    state [s1]) by the branches of [k e s1].  Successful branches pass
    through unchanged.  Not recursive, so no [rec] flag (warning 39). *)
let catch (m : ('a,'e,'s) t) (k : 'e -> ('a,'e,'s) t) : ('a,'e,'s) t =
  fun s ->
    Stream.bind (m s) (fun branch ->
      match branch with
      | (Result _, _) -> Stream.return branch
      | (Error e, s1) -> k e s1)

(* implementation of MonadState *)

(** [get] returns the current state as its result and leaves it unchanged. *)
let get : ('s,'e,'s) t =
  fun state -> Stream.return (Result state, state)

(** [put s2] discards the current state and replaces it with [s2]. *)
let put (s2 : 's) : (unit,'e,'s) t =
  fun _ -> Stream.return (Result (), s2)

(** [modify f] replaces the current state by [f] applied to it. *)
let modify (f : 's -> 's) : (unit,'e,'s) t =
  fun state ->
    let state' = f state in
    Stream.return (Result (), state')

(** [local s1 m] runs [m] from the state [s1], whose type may differ from
    the surrounding state.  On every branch it produces, the caller's
    state is restored and only the outcome is kept. *)
let local (s1 : 's1) (m : ('a,'e,'s1) t) : ('a,'e,'s) t =
  fun s ->
    let restore (outcome, _inner_state) = Stream.return (outcome, s) in
    Stream.bind (m s1) restore

(* composed monadic operations *)

(** [yield e] emits an error branch carrying [e] followed by a successful
    [()] branch, both in the current state.  It reports [e] without
    stopping the computation. *)
let yield e =
  let report = raise_ e in
  let resume = return () in
  mplus report resume

(** [succeeds m] yields [()] once if [m] has a successful branch, and
    nothing otherwise.  On success, the state is the one reached by
    [m]'s first successful branch.  Error branches do not count as
    success. *)
let succeeds m =
  let on_success _ = return () in
  ifte (once m) on_success mzero

(** [fails m] yields [()] once, in the original state, if [m] has no
    successful branch, and nothing otherwise. *)
let fails m =
  let on_success _ = mzero in
  ifte (once m) on_success (return ())

(** [exists m pred] succeeds once if some successful result of [m]
    satisfies [pred]. *)
let exists m pred =
  let keep x = mguard (pred x) in
  succeeds (bind m keep)

(** [for_all m pred] succeeds once if no successful result of [m]
    violates [pred]. *)
let for_all m pred =
  let counterexample x = mguard (not (pred x)) in
  fails (bind m counterexample)

(** [bagof ?limit m] collects the successful results of [m] (at most
    [limit] of them when [limit] is non-negative) into a list, returned
    as a single branch.  Each new result is consed onto the accumulator,
    so the list is in REVERSE order of discovery. *)
let bagof ?limit m =
  let push acc x = x :: acc in
  fold ?limit push [] m

(* running monads *)

(** [run m s] runs [m] from state [s] and returns the value of its first
    successful branch; error branches are skipped.
    @raise Not_found if [m] has no successful branch. *)
let run (m : ('a,'e,'s) t) (s : 's) : 'a =
  match msplit (m s) with
  | Some (x, _, _) -> x
  | None -> raise Not_found

(** [test m s] tells whether [m], run from state [s], has at least one
    successful branch; error branches do not count. *)
let test (m : ('a,'e,'s) t) (s : 's) : bool =
  match msplit (m s) with
  | Some _ -> true
  | None -> false

(* primitive monads for syntax extension *)

module PA_atom =
struct
  (* Re-exports of the logic-monad primitives under the names used by the
     syntax extension.  Apart from renaming ([fail], [guard]), only
     [exists] and [for_all] change: they take the predicate first.  Every
     right-hand side below is a plain identifier, so the bindings stay
     fully polymorphic. *)
  let return = return
  let fail = mzero
  let guard = mguard
  let yield = yield

  let local = local
  let once = once
  let succeeds = succeeds
  let fails = fails
  let bagof = bagof

  let fold = fold
  (* [iter] runs [f] on each successful result for its side effect *)
  let iter ?limit f m = fold ?limit (fun () x -> f x) () m
  let exists pred m = exists m pred
  let for_all pred m = for_all m pred
end

(* useful monads *)

module Int =
struct
  (** [range a b] has one successful branch per integer [a], [a+1], ...,
      [b], in increasing order, each in the current state; it is empty
      when [a > b].  The recursion stops at [a = b] instead of computing
      [a+1], so [range a max_int] terminates rather than wrapping around
      to [min_int] and looping forever. *)
  let rec range (a : int) (b : int) : (int,'e,'s) t =
    bind (mguard (a <= b)) (fun () ->
      if a = b then return a
      else mplus (return a) (range (a + 1) b))
end

module List =
struct
  (** [choose l] has one successful branch per element of [l], in list
      order, each in the current state; it is [mzero] for [[]].
      (Inside this struct the module being defined is not yet in scope,
      so [List] below is the standard library.) *)
  let choose (l : 'a list) : ('a,'e,'s) t =
    List.fold_right (fun x rest -> mplus (return x) rest) l mzero

  (** [map f l] applies the computation [f] to each element of [l], from
      left to right, threading the state.  It produces one list of
      results per combination of branches. *)
  let rec map (f : 'a -> ('b,'e,'s) t) : 'a list -> ('b list,'e,'s) t = function
    | [] -> return []
    | x :: rest ->
      let cons_rest y = bind (map f rest) (fun ys -> return (y :: ys)) in
      bind (f x) cons_rest
end