Source

ocaml-lib / diet.ml

Full commit
(*
   Copyright © 2001, Olivier Andrieu.
   Code licensed under GNU Library Public License (LGPL).

   Implements the Discrete Interval Encoding Tree (diet), a data structure
   for sets of integers (and, through the [Make] functor, for sets over any
   discrete ordered type).

   Reference:
     Martin Erwig.
     Diets for Fat Sets.
     Journal of Functional Programming, Vol. 8, No. 6, 627-632, 1998
     http://www.cs.orst.edu/~erwig/papers/Diet_JFP98.pdf
     http://www.cs.orst.edu/~erwig/diet/
*)

(* A diet over [int]: either empty, or a node carrying the inclusive interval
   [lo .. hi] plus the subtree of smaller intervals (left) and the subtree of
   larger intervals (right). *)
type t =
  | Empty
  | Node of int * int * t * t (* lo, hi, left, right *)

(* The empty set. *)
let empty : t = Empty

(* [mem t v] tells whether [v] lies inside one of the intervals of [t].
   The search goes left when [v] is below the node interval, right when it is
   above, and succeeds when it falls within [lo .. hi]. *)
let rec mem t v =
  match t with
  | Empty -> false
  | Node (lo, hi, left, right) ->
      if v < lo then mem left v
      else if v > hi then mem right v
      else true

(* [split_max t] detaches the rightmost interval (the one holding the largest
   elements) of the non-empty tree [t] and returns [(lo, hi, rest)], where
   [rest] is [t] without that interval.
   @raise Invalid_argument if [t] is [Empty]; callers only use it on
   non-empty trees. *)
let rec split_max = function
  | Empty -> invalid_arg "Diet.split_max: empty tree"
  | Node (x, y, l, Empty) ->
      (x, y, l)
  | Node (x, y, l, r) ->
      let (u, v, r') = split_max r in
      (u, v, Node (x, y, l, r'))

(* [join_left t] fuses the root interval of [t] with the largest interval of
   its left subtree when that interval ends right before the root begins
   (i.e. [succ hi' = lo]); otherwise [t] is returned unchanged.  [add] calls
   it after lowering the root lower bound by one.  An empty tree has nothing
   to join and is returned as is. *)
let join_left = function
  | Empty -> Empty
  | Node (_, _, Empty, _) as t -> t
  | Node (x, y, l, r) as t ->
      let (x', y', l') = split_max l in
      if succ y' = x then Node (x', y, l', r) else t

(* [split_min t] detaches the leftmost interval (the one holding the smallest
   elements) of the non-empty tree [t] and returns [(lo, hi, rest)], where
   [rest] is [t] without that interval.
   @raise Invalid_argument if [t] is [Empty]; callers only use it on
   non-empty trees. *)
let rec split_min = function
  | Empty -> invalid_arg "Diet.split_min: empty tree"
  | Node (x, y, Empty, r) ->
      (x, y, r)
  | Node (x, y, l, r) ->
      let (u, v, l') = split_min l in
      (u, v, Node (x, y, l', r))

(* [join_right t] fuses the root interval of [t] with the smallest interval
   of its right subtree when that interval starts right after the root ends
   (i.e. [succ hi = lo']); otherwise [t] is returned unchanged.  [add] calls
   it after raising the root upper bound by one.  An empty tree has nothing
   to join and is returned as is. *)
let join_right = function
  | Empty -> Empty
  | Node (_, _, _, Empty) as t -> t
  | Node (x, y, l, r) as t ->
      let (x', y', r') = split_min r in
      if succ y = x' then Node (x, y', l, r') else t

(* [add t v] returns [t] with [v] inserted.
   - If [v] sits immediately below a node interval, that interval is widened
     downwards and possibly fused with its left neighbour ([join_left]).
   - If [v] sits immediately above, the interval is widened upwards and
     possibly fused with its right neighbour ([join_right]).
   - Otherwise the insertion descends into the matching subtree; a value
     already covered leaves the tree untouched. *)
let rec add t v =
  match t with
  | Empty -> Node (v, v, Empty, Empty)
  | Node (lo, hi, left, right) ->
      if v < lo then
        (if v + 1 = lo then join_left (Node (v, hi, left, right))
         else Node (lo, hi, add left v, right))
      else if v > hi then
        (if hi + 1 = v then join_right (Node (lo, v, left, right))
         else Node (lo, hi, left, add right v))
      else t

(* [merge (l, r)] joins two trees where every element of [l] is smaller than
   every element of [r].  If either side is empty the other is returned;
   otherwise the largest interval of [l] is hoisted to become the new root. *)
let merge (l, r) =
  match l, r with
  | _, Empty -> l
  | Empty, _ -> r
  | _ ->
      let (lo, hi, rest) = split_max l in
      Node (lo, hi, rest, r)

(* [remove t v] returns [t] without [v].  Within the interval [lo .. hi]
   holding [v]:
   - a singleton interval is dropped and its subtrees merged;
   - a bound is shrunk by one;
   - an interior value splits the interval in two, the upper half becoming
     the root of the new right subtree. *)
let rec remove t v =
  match t with
  | Empty -> Empty
  | Node (lo, hi, left, right) ->
      if v < lo then Node (lo, hi, remove left v, right)
      else if v > hi then Node (lo, hi, left, remove right v)
      else if v = lo then
        (if lo = hi then merge (left, right)
         else Node (succ lo, hi, left, right))
      else if v = hi then Node (lo, pred hi, left, right)
      else Node (lo, pred v, left, Node (succ v, hi, Empty, right))
	  
(* [iter f t] applies [f] to every element of [t] in increasing order:
   the left subtree first, then each value of the node interval from [lo] to
   [hi] inclusive, then the right subtree. *)
let rec iter f t =
  match t with
  | Empty -> ()
  | Node (lo, hi, left, right) ->
      iter f left;
      for k = lo to hi do
        f k
      done;
      iter f right

(* [fold f i t] threads an accumulator starting at [i] through every element
   of [t] in increasing order: [f] is applied to the left subtree values,
   then to each value of the node interval, then to the right subtree. *)
let rec fold f i t =
  match t with
  | Empty -> i
  | Node (lo, hi, left, right) ->
      let acc = ref (fold f i left) in
      for k = lo to hi do
        acc := f !acc k
      done;
      fold f !acc right



(* ================================================== *)
(*    FUNCTORIAL IMPLEMENTATION                       *)
(* ================================================== *)

(* Element type accepted by [Make]: a totally ordered type whose elements
   also have well-defined neighbours. *)
module type ORD =
  sig
    (* Element type of the sets built by [Make]. *)
    type t

    (* Total order on elements (negative / zero / positive result). *)
    val compare : t -> t -> int

    (* Expected to return the immediate predecessor and successor of an
       element with respect to [compare]: [Make] fuses two intervals exactly
       when [succ] of one upper bound compares equal to the next lower bound,
       and uses [pred] / [succ] to split an interval around a removed value. *)
    val pred : t -> t
    val succ : t -> t
  end

(* Operations offered by a diet, a set of discrete values stored as a binary
   search tree of disjoint intervals. *)
module type DIET =
  sig
    type elt
    type t

    (* The empty set. *)
    val empty  : t

    (* [mem t v] is [true] when [v] belongs to [t]. *)
    val mem    : t -> elt -> bool

    (* [add t v] is [t] with [v] added. *)
    val add    : t -> elt -> t

    (* [remove t v] is [t] without [v]. *)
    val remove : t -> elt -> t

    (* [iter f t] applies [f] to every element of [t] in increasing order. *)
    val iter   : (elt -> unit) -> t -> unit

    (* [fold f init t] folds [f] over the elements of [t] in increasing
       order, starting from [init]. *)
    val fold   : ('a -> elt -> 'a) -> 'a -> t -> 'a
  end

(* [Make (O)] builds a diet over the elements of [O].

   Each [Node (x, y, l, r)] stores the inclusive interval [x .. y] (with
   [x <= y]); [l] holds smaller intervals and [r] larger ones.  [add] fuses
   intervals that become adjacent, so stored intervals never touch.

   Element equality is always decided with [O.compare]: polymorphic [=] is
   not used on [O.t] because it need not agree with the supplied order (and
   it raises on functional values). *)
module Make =
  functor (O : ORD) ->
  struct
    type elt = O.t
    type t = | Empty | Node of elt * elt * t * t

    let empty = Empty

    (* [mem t v]: binary search for an interval containing [v]. *)
    let rec mem t v =
      match t with
      | Empty -> false
      | Node (x, _, l, _) when O.compare v x < 0 -> mem l v
      | Node (_, y, _, r) when O.compare v y > 0 -> mem r v
      | _ -> true

    (* [split_max t] detaches the rightmost (largest) interval of the
       non-empty tree [t] and returns [(lo, hi, rest)].
       @raise Invalid_argument if [t] is [Empty]. *)
    let rec split_max = function
      | Empty -> invalid_arg "Diet.Make.split_max: empty tree"
      | Node (x, y, l, Empty) -> (x, y, l)
      | Node (x, y, l, r) ->
          let (u, v, r') = split_max r in
          (u, v, Node (x, y, l, r'))

    (* [join_left t] fuses the root interval with the largest interval of the
       left subtree when the latter ends right before the root begins;
       otherwise [t] is returned unchanged. *)
    let join_left = function
      | Empty -> Empty
      | Node (_, _, Empty, _) as t -> t
      | Node (x, y, l, r) as t ->
          let (x', y', l') = split_max l in
          if O.compare (O.succ y') x = 0 then Node (x', y, l', r) else t

    (* [split_min t] detaches the leftmost (smallest) interval of the
       non-empty tree [t] and returns [(lo, hi, rest)].
       @raise Invalid_argument if [t] is [Empty]. *)
    let rec split_min = function
      | Empty -> invalid_arg "Diet.Make.split_min: empty tree"
      | Node (x, y, Empty, r) -> (x, y, r)
      | Node (x, y, l, r) ->
          let (u, v, l') = split_min l in
          (u, v, Node (x, y, l', r))

    (* [join_right t] fuses the root interval with the smallest interval of
       the right subtree when the latter starts right after the root ends;
       otherwise [t] is returned unchanged. *)
    let join_right = function
      | Empty -> Empty
      | Node (_, _, _, Empty) as t -> t
      | Node (x, y, l, r) as t ->
          let (x', y', r') = split_min r in
          if O.compare (O.succ y) x' = 0 then Node (x, y', l, r') else t

    (* [add t v] inserts [v].  When [v] extends a node interval by one on
       either side, the interval is widened and then possibly fused with its
       neighbour through [join_left] / [join_right]; a value already covered
       leaves the tree untouched. *)
    let rec add t v =
      match t with
      | Empty -> Node (v, v, Empty, Empty)
      | Node (x, y, l, r) when O.compare v x < 0 ->
          if O.compare (O.succ v) x = 0
          then join_left (Node (v, y, l, r))
          else Node (x, y, add l v, r)
      | Node (x, y, l, r) when O.compare v y > 0 ->
          if O.compare (O.succ y) v = 0
          then join_right (Node (x, v, l, r))
          else Node (x, y, l, add r v)
      | _ -> t

    (* [merge (l, r)] joins two trees whose elements in [l] are all smaller
       than those in [r], hoisting the largest interval of [l] to the root. *)
    let merge = function
      | (l, Empty) -> l
      | (Empty, r) -> r
      | (l, r) ->
          let (u, v, l') = split_max l in
          Node (u, v, l', r)

    (* [remove t v] deletes [v]: a singleton interval disappears, a bound is
       shrunk by one, and an interior value splits the interval in two (the
       upper part becomes the root of the new right subtree). *)
    let rec remove t v =
      match t with
      | Empty -> Empty
      | Node (x, y, l, r) when O.compare v x < 0 ->
          Node (x, y, remove l v, r)
      | Node (x, y, l, r) when O.compare v y > 0 ->
          Node (x, y, l, remove r v)
      | Node (x, y, l, r) when O.compare v x = 0 ->
          if O.compare x y = 0
          then merge (l, r)
          else Node (O.succ x, y, l, r)
      | Node (x, y, l, r) when O.compare v y = 0 ->
          Node (x, O.pred y, l, r)
      | Node (x, y, l, r) ->
          Node (x, O.pred v, l, Node (O.succ v, y, Empty, r))

    (* [iter f t] applies [f] to every element in increasing order, both
       interval bounds included. *)
    let rec iter f = function
      | Empty -> ()
      | Node (x, y, l, r) ->
          iter f l;
          (* Visit [x .. y] inclusively: the bound is tested only after [f]
             has seen the current element, so [y] itself is not skipped. *)
          let rec span i =
            f i;
            if O.compare i y < 0 then span (O.succ i)
          in
          span x;
          iter f r

    (* [fold f i t] threads an accumulator starting at [i] through every
       element in increasing order, both interval bounds included. *)
    let rec fold f i = function
      | Empty -> i
      | Node (x, y, l, r) ->
          let acc = fold f i l in
          (* Same inclusive walk over [x .. y] as in [iter]. *)
          let rec span acc j =
            let acc = f acc j in
            if O.compare j y < 0 then span acc (O.succ j) else acc
          in
          let acc = span acc x in
          fold f acc r
  end