# equivalence / src / Data / UnionFind / STT.hs

```haskell
-- | An implementation of Tarjan's UNION-FIND algorithm. (Robert E
-- Tarjan. \"Efficiency of a Good But Not Linear Set Union Algorithm\", JACM
-- 22(2), 1975)
--
-- The algorithm implements three operations efficiently (all in amortised,
-- effectively constant time; Tarjan proves an inverse-Ackermann bound):
--
-- 1. Check whether two elements are in the same equivalence class.
--
-- 2. Create a union of two equivalence classes.
--
-- 3. Retrieve the members of an equivalence class.
--
-- The implementation is based on mutable references. Each
-- equivalence class has exactly one member that serves as its
-- representative element. Every element either is the representative
-- element of its equivalence class or points to another element in
-- the same equivalence class. Equivalence testing thus consists of
-- following the pointers to the representative elements and then
-- comparing these for identity.
--
-- The algorithm performs lazy path compression. That is, whenever we
-- walk along a path of length greater than 1 we automatically update the
-- pointers along the path to point directly to the representative
-- element. Consequently, future lookups from those elements will have a
-- path length of at most 1.
--
module Data.UnionFind.STT
  ( emptyPartition
  , equate
  , equivalent
  , equivalenceClass
  , Partition
  ) where

import Control.Monad.ST.Trans
import Control.Monad
import Data.Map (Map)
import qualified Data.Map as Map

-- | A handle on one element of a partition: a mutable reference to the
-- element's node. Two handles are equal exactly when they refer to the
-- same 'STRef'.
newtype Entry s a = Entry (STRef s (EntryData s a)) deriving (Eq)

-- | A node of the union-find forest.
data EntryData s a = EntryData
  { entryParent :: Maybe (Entry s a)
    -- ^ Parent in the forest; 'Nothing' for the representative of a class.
  , entryClass :: [a]
    -- ^ Members of the class. Only kept up to date on representatives.
  , entryWeight :: !Int
    -- ^ Number of members, used for union by weight. Only kept up to date
    -- on representatives. Strict so that merged weights are not left as
    -- thunks inside the reference.
  , entryValue :: a
    -- ^ The value this entry was created for.
  }

-- | A partition of values of type @a@ into equivalence classes, living in
-- the state thread @s@. Values that have never been equated with anything
-- are treated as singleton classes.
newtype Partition s a = Partition
  { entries :: STRef s (Map a (Entry s a))
    -- ^ The entry for every value that has been registered so far.
  }

-- | Apply a pure function to the contents of an 'STRef'.
modifySTRef :: (Monad m) => STRef s a -> (a -> a) -> STT s m ()
modifySTRef r f = readSTRef r >>= (writeSTRef r . f)

-- | Create a partition in which no values have been equated yet.
emptyPartition :: Monad m => STT s m (Partition s a)
emptyPartition = liftM Partition $ newSTRef Map.empty

-- | @representative' entry@ returns the representative of @entry@'s
-- equivalence class, or 'Nothing' if @entry@ is itself the representative
-- of its class.
--
-- This function performs the path compression: each entry on the path
-- whose parent is not already the representative is re-pointed directly
-- at the representative.
representative' :: Monad m => Entry s a -> STT s m (Maybe (Entry s a))
representative' (Entry e) = do
  ed <- readSTRef e
  case entryParent ed of
    Nothing -> return Nothing
    Just parent -> do
      mparent' <- representative' parent
      case mparent' of
        -- The parent is already the representative; nothing to compress.
        Nothing -> return $ Just parent
        Just parent' -> do
          writeSTRef e ed{entryParent = Just parent'}
          return (Just parent')

-- | Amortised, effectively constant time. @representative entry@ returns
-- the representative of @entry@'s equivalence class, which is @entry@
-- itself when it has no parent.
--
-- This function performs the path compression.
representative :: Monad m => Entry s a -> STT s m (Entry s a)
representative entry = do
  mrepr <- representative' entry
  return $ case mrepr of
    Nothing -> entry
    Just repr -> repr

-- | Look up the entry for @val@, creating and registering a fresh
-- singleton entry if the value has not been seen before.
getEntry' :: (Monad m, Ord a) => Partition s a -> a -> STT s m (Entry s a)
getEntry' (Partition mref) val = do
  m <- readSTRef mref
  case Map.lookup val m of
    Just entry -> return entry
    Nothing -> do
      e <- newSTRef EntryData
        { entryParent = Nothing
        , entryClass = [val]
        , entryWeight = 1
        , entryValue = val
        }
      let entry = Entry e
      writeSTRef mref (Map.insert val entry m)
      return entry

-- | Look up the entry for @val@ without creating one.
getEntry :: (Monad m, Ord a) => Partition s a -> a -> STT s m (Maybe (Entry s a))
getEntry (Partition mref) val = liftM (Map.lookup val) (readSTRef mref)

-- | Merge the equivalence classes of @x@ and @y@, registering either value
-- if it has not been seen before.
equate :: (Monad m, Ord a) => Partition s a -> a -> a -> STT s m ()
equate part x y = do
  ex <- getEntry' part x
  ey <- getEntry' part y
  equate' ex ey

-- | Union by weight: the representative of the lighter class is attached
-- beneath the representative of the heavier one. On a tie, @x@'s
-- representative wins. Does nothing if both entries are already in the
-- same class.
equate' :: (Monad m, Ord a) => Entry s a -> Entry s a -> STT s m ()
equate' x y = do
  repx@(Entry rx) <- representative x
  repy@(Entry ry) <- representative y
  when (rx /= ry) $ do
    dx@EntryData{entryWeight = wx, entryClass = chx} <- readSTRef rx
    dy@EntryData{entryWeight = wy, entryClass = chy} <- readSTRef ry
    -- NOTE(review): the merged list is always @chx ++ chy@ to preserve the
    -- existing element order. Depending on argument order, repeated merges
    -- into a large class can build left-nested appends that are slow to
    -- traverse.
    let merged = chx ++ chy
        total = wx + wy
    -- Only representatives' class lists are ever read again, so the
    -- absorbed root drops its stale list instead of keeping it reachable.
    if wx >= wy
      then do
        writeSTRef ry dy{entryParent = Just repx, entryClass = []}
        writeSTRef rx dx{entryWeight = total, entryClass = merged}
      else do
        writeSTRef rx dx{entryParent = Just repy, entryClass = []}
        writeSTRef ry dy{entryWeight = total, entryClass = merged}

-- | List the members of @val@'s equivalence class. A value that has never
-- been passed to 'equate' forms the singleton class @[val]@.
equivalenceClass :: (Monad m, Ord a) => Partition s a -> a -> STT s m [a]
equivalenceClass p val = do
  mentry <- getEntry p val
  case mentry of
    Nothing -> return [val]
    Just entry -> equivalenceClass' entry

-- | Read the member list stored on the representative of @entry@'s class.
equivalenceClass' :: (Monad m) => Entry s a -> STT s m [a]
equivalenceClass' entry = do
  Entry e <- representative entry
  ed <- readSTRef e
  return $ entryClass ed

-- | Amortised, effectively constant time. Return @True@ if both values
-- belong to the same equivalence class.
equivalent :: (Monad m, Ord a) => Partition s a -> a -> a -> STT s m Bool
equivalent p v1 v2 = do
  found1 <- getEntry p v1
  found2 <- getEntry p v2
  decide found1 found2
  where
    -- Both values are registered: compare their representatives.
    decide (Just e1) (Just e2) = equivalent' e1 e2
    -- Neither value is registered, so each is alone in its own class;
    -- they are equivalent only when they are the same value.
    decide Nothing Nothing = return (v1 == v2)
    -- Exactly one value is registered: they are in different classes.
    decide _ _ = return False

-- | Decide whether two entries share a representative. Resolving each
-- representative may rewrite parent pointers (path compression).
equivalent' :: (Monad m, Ord a) => Entry s a -> Entry s a -> STT s m Bool
equivalent' e1 e2 = do
  root1 <- representative e1
  root2 <- representative e2
  return (root1 == root2)