(*
Copyright 1992-1996 Stephen Adams.
This software may be used freely provided that:
1. This copyright notice is attached to any copy, derived work,
or work including all or part of this software.
2. Any derived work must contain a prominent notice stating that
it has been altered from the original.
*)
(* Address: Electronics & Computer Science
University of Southampton
Southampton SO9 5NH
Great Britain
E-mail: sra@ecs.soton.ac.uk
Comments:
1. The implementation is based on Binary search trees of Bounded
Balance, similar to Nievergelt & Reingold, SIAM J. Computing
2(1), March 1973. The main advantage of these trees is that
they keep the size of the tree in the node, giving a constant
time size operation.
2. The bounded balance criterion is simpler than N&R's alpha.
Simply, one subtree must not have more than `weight' times as
many elements as the opposite subtree. Rebalancing is
guaranteed to reinstate the criterion for weight>2.23, but
the occasional incorrect behaviour for weight=2 is not
detrimental to performance.
3. There are two implementations of union. The default,
hedge_union, is much more complex and usually 20% faster. I
am not sure that the performance increase warrants the
complexity (and time it took to write), but I am leaving it
in for the competition. It is derived from the original
union by replacing the split_lt(gt) operations with a lazy
version. The `obvious' version is called old_union.
*)
(* NOTE: altered from the original, as required by the copyright notice
   at the top of this file.  Changes: the two rarely-used rebalancing
   clauses of T' had been garbled into the unbound identifier `lnrn'
   (the text between a less-than and a greater-than sign was lost) and
   are restored; fromList uses the SML Basis List.foldl instead of the
   pre-Basis List.fold; unused pattern bindings were dropped; comments
   were added. *)
structure B (*: INTSET*) =
struct
  local
    type T = int
    val lt : T * T -> bool = op <

    (* weight is a parameter to the rebalancing process: rebalancing is
       triggered when one subtree is roughly `weight' times as large as
       its sibling. *)
    val weight : int = 3

    (* E is the empty set.  T (v, n, l, r) holds element v; n is the
       number of elements in the whole subtree, every element of l is
       less than v and every element of r is greater than v. *)
    datatype Set = E | T of T * int * Set * Set

    (* Number of elements, read from the node in constant time. *)
    fun size E = 0
      | size (T (_, n, _, _)) = n

    (* N (v, l, r) builds a node with a correct size field but performs
       no rebalancing; it is equivalent to
       T (v, 1 + size l + size r, l, r). *)
    fun N (v, E, E) = T (v, 1, E, E)
      | N (v, E, r as T (_, n, _, _)) = T (v, n + 1, E, r)
      | N (v, l as T (_, n, _, _), E) = T (v, n + 1, l, E)
      | N (v, l as T (_, n, _, _), r as T (_, m, _, _)) =
          T (v, n + m + 1, l, r)

    (* Tree rotations on a (v, l, r) triple.  Each raises Match when the
       subtree it rotates does not have the required shape; T' selects
       each rotation only for shapes it accepts. *)
    fun single_L (a, x, T (b, _, y, z)) = N (b, N (a, x, y), z)
      | single_L _ = raise Match
    fun single_R (b, T (a, _, x, y), z) = N (a, x, N (b, y, z))
      | single_R _ = raise Match
    fun double_L (a, w, T (c, _, T (b, _, x, y), z)) =
          N (b, N (a, w, x), N (c, y, z))
      | double_L _ = raise Match
    fun double_R (c, T (a, _, w, T (b, _, x, y)), z) =
          N (b, N (a, w, x), N (c, y, z))
      | double_R _ = raise Match

    (* T' (v, l, r) builds a node like N, but first performs at most one
       single or double rotation to restore the balance criterion between
       l and r.  Small trees are handled by explicit shape cases; the
       general case compares subtree sizes against weight. *)
    fun T' (v, E, E) = T (v, 1, E, E)
      | T' (v, E, r as T (_, _, E, E)) = T (v, 2, E, r)
      | T' (v, l as T (_, _, E, E), E) = T (v, 2, l, E)
      | T' (p as (_, E, T (_, _, T _, E))) = double_L p
      | T' (p as (_, T (_, _, E, T _), E)) = double_R p
      (* these cases almost never happen with small weight *)
      | T' (p as (_, E, T (_, _, T (_, ln, _, _), T (_, rn, _, _)))) =
          if ln < rn then single_L p else double_L p
      | T' (p as (_, T (_, _, T (_, ln, _, _), T (_, rn, _, _)), E)) =
          if ln > rn then single_R p else double_R p
      | T' (p as (_, E, T (_, _, E, _))) = single_L p
      | T' (p as (_, T (_, _, _, E), E)) = single_R p
      | T' (p as (v, l as T (_, ln, ll, lr), r as T (_, rn, rl, rr))) =
          if rn >= weight * ln then
            (* right is too big: a single rotation when its outer (right)
               child is strictly larger, otherwise a double rotation *)
            if size rl < size rr then single_L p else double_L p
          else if ln >= weight * rn then
            (* left is too big: mirror image of the case above *)
            if size lr < size ll then single_R p else double_R p
          else T (v, ln + rn + 1, l, r)

    (* Insert x; the set is returned unchanged if x is already present. *)
    fun add (E, x) = T (x, 1, E, E)
      | add (set as T (v, _, l, r), x) =
          if lt (x, v) then T' (v, add (l, x), r)
          else if lt (v, x) then T' (v, l, add (r, x))
          else set

    (* concat3 (l, v, r): the set l U {v} U r.  Requires every element of
       l to be less than v and every element of r to be greater than v.
       Descends into the much larger tree until the two sides are within
       a factor of weight, then joins them with N. *)
    fun concat3 (E, v, r) = add (r, v)
      | concat3 (l, v, E) = add (l, v)
      | concat3 (l as T (v1, n1, l1, r1), v, r as T (v2, n2, l2, r2)) =
          if weight * n1 < n2 then T' (v2, concat3 (l, v, l2), r2)
          else if weight * n2 < n1 then T' (v1, l1, concat3 (r1, v, r))
          else N (v, l, r)

    (* Elements of the set strictly less than x. *)
    fun split_lt (E, _) = E
      | split_lt (T (v, _, l, r), x) =
          if lt (x, v) then split_lt (l, x)
          else if lt (v, x) then concat3 (l, v, split_lt (r, x))
          else l

    (* Elements of the set strictly greater than x. *)
    fun split_gt (E, _) = E
      | split_gt (T (v, _, l, r), x) =
          if lt (v, x) then split_gt (r, x)
          else if lt (x, v) then concat3 (split_gt (l, x), v, r)
          else r

    (* min: the smallest element; raises Match on E.
       delete' (l, r): joins two sets whose elements are separated (every
       element of l below every element of r), as left over when the root
       between them is removed; the minimum of r becomes the new root.
       delmin: removes the smallest element; raises Match on E. *)
    fun min (T (v, _, E, _)) = v
      | min (T (_, _, l, _)) = min l
      | min E = raise Match
    and delete' (E, r) = r
      | delete' (l, E) = l
      | delete' (l, r) =
          let val min_elt = min r
          in T' (min_elt, l, delmin r)
          end
    and delmin (T (_, _, E, r)) = r
      | delmin (T (v, _, l, r)) = T' (v, delmin l, r)
      | delmin E = raise Match

    (* concat (s1, s2): the union of s1 and s2, requiring every element of
       s1 to be less than every element of s2. *)
    fun concat (E, s2) = s2
      | concat (s1, E) = s1
      | concat (t1 as T (v1, n1, l1, r1), t2 as T (v2, n2, l2, r2)) =
          if weight * n1 < n2 then T' (v2, concat (t1, l2), r2)
          else if weight * n2 < n1 then T' (v1, l1, concat (r1, t2))
          else T' (min t2, t1, delmin t2)

    (* fold (f, base, set) applies f to the elements from largest to
       smallest, threading the accumulator; with f = op:: this yields the
       elements in ascending order. *)
    fun fold (f, base, set) =
      let
        fun fold' (acc, E) = acc
          | fold' (acc, T (v, _, l, r)) = fold' (f (v, fold' (acc, r)), l)
      in
        fold' (base, set)
      end
  in
    val empty = E

    fun singleton x = T (x, 1, E, E)

    local
      (* trim (lo, hi, s): the topmost subtree of s whose root lies
         strictly between lo and hi, or E if there is none. *)
      fun trim (_, _, E) = E
        | trim (lo, hi, s as T (v, _, l, r)) =
            if lt (lo, v) then
              if lt (v, hi) then s else trim (lo, hi, l)
            else trim (lo, hi, r)

      (* uni_bd (s1, s2, lo, hi): union of s1 with the elements of s2
         lying strictly between lo and hi.
         inv: every element of s1, and the root of s2 (thanks to trim),
         lie strictly between lo and hi. *)
      fun uni_bd (s, E, _, _) = s
        | uni_bd (E, T (v, _, l, r), lo, hi) =
            concat3 (split_gt (l, lo), v, split_lt (r, hi))
        | uni_bd (T (v, _, l1, r1), s2, lo, hi) =
            concat3 (uni_bd (l1, trim (lo, v, s2), lo, v),
                     v,
                     uni_bd (r1, trim (v, hi, s2), v, hi))

      (* All the other versions of uni and trim are specializations of
         the two functions above with lo = -infinity and/or
         hi = +infinity. *)
      fun trim_lo (_, E) = E
        | trim_lo (lo, s as T (v, _, _, r)) =
            if lt (lo, v) then s else trim_lo (lo, r)

      fun trim_hi (_, E) = E
        | trim_hi (hi, s as T (v, _, l, _)) =
            if lt (v, hi) then s else trim_hi (hi, l)

      fun uni_hi (s, E, _) = s
        | uni_hi (E, T (v, _, l, r), hi) =
            concat3 (l, v, split_lt (r, hi))
        | uni_hi (T (v, _, l1, r1), s2, hi) =
            concat3 (uni_hi (l1, trim_hi (v, s2), v),
                     v,
                     uni_bd (r1, trim (v, hi, s2), v, hi))

      fun uni_lo (s, E, _) = s
        | uni_lo (E, T (v, _, l, r), lo) =
            concat3 (split_gt (l, lo), v, r)
        | uni_lo (T (v, _, l1, r1), s2, lo) =
            concat3 (uni_bd (l1, trim (lo, v, s2), lo, v),
                     v,
                     uni_lo (r1, trim_lo (v, s2), v))

      fun uni (s, E) = s
        | uni (E, s) = s
        | uni (T (v, _, l1, r1), s2) =
            concat3 (uni_hi (l1, trim_hi (v, s2), v),
                     v,
                     uni_lo (r1, trim_lo (v, s2), v))
    in
      val hedge_union = uni
    end

    (* The straightforward union: split s2 around the root of s1 and
       recurse on both halves. *)
    fun old_union (E, s2) = s2
      | old_union (s1, E) = s1
      | old_union (T (v, _, l, r), s2) =
          let
            val l2 = split_lt (s2, v)
            val r2 = split_gt (s2, v)
          in
            concat3 (old_union (l, l2), v, old_union (r, r2))
          end

    (* The old_union version is about 20% slower than
       hedge_union in most cases *)
    val union = hedge_union
    (* val union = old_union *)

    val add = add

    (* difference (s1, s2): the elements of s1 that are not in s2.  s1 is
       split around each root v of s2; the splits drop v itself. *)
    fun difference (E, _) = E
      | difference (s, E) = s
      | difference (s, T (v, _, l, r)) =
          let
            val l2 = split_lt (s, v)
            val r2 = split_gt (s, v)
          in
            concat (difference (l2, l), difference (r2, r))
          end

    (* Membership test by binary search down the tree. *)
    fun member (x, set) =
      let
        fun mem E = false
          | mem (T (v, _, l, r)) =
              if lt (x, v) then mem l
              else if lt (v, x) then mem r
              else true
      in
        mem set
      end

    (* intersection (s1, s2): s1 is split around each root v of s2, and v
       is kept only when it is also a member of s1.
       (A simpler alternative: difference (a, difference (a, b)).) *)
    fun intersection (E, _) = E
      | intersection (_, E) = E
      | intersection (s, T (v, _, l, r)) =
          let
            val l2 = split_lt (s, v)
            val r2 = split_gt (s, v)
          in
            if member (v, s) then
              concat3 (intersection (l2, l), v, intersection (r2, r))
            else
              concat (intersection (l2, l), intersection (r2, r))
          end

    (* The elements in ascending order. *)
    fun members set = fold (op ::, [], set)

    (* Number of elements, in constant time. *)
    fun cardinality s = size s

    (* Remove x; if x is absent the result has the same elements. *)
    fun delete (E, _) = E
      | delete (T (v, _, l, r), x) =
          if lt (x, v) then T' (v, delete (l, x), r)
          else if lt (v, x) then T' (v, l, delete (r, x))
          else delete' (l, r)

    (* Build a set from a list; duplicate elements are ignored.  Uses the
       SML '97 Basis List.foldl (the pre-Basis List.fold is gone). *)
    fun fromList l = List.foldl (fn (x, s) => add (s, x)) E l

    type intset = Set
  end
end
(* Export B under the INTSET signature (declared elsewhere).  The
   ascription is transparent (`:' rather than `:>'), so types that
   IntSet shares with B are not made abstract. *)
structure IntSet : INTSET = B;