Commits

Patrick Bahr committed 7a01644

generalised equivalenceClass function to a general equivalence class description function

  • Participants
  • Parent commits 8bbb400

Comments (0)

Files changed (2)

File src/Data/UnionFind/Monad.hs

      runPartitionT
      ) where
 
-import Data.UnionFind.STT hiding (equate, equivalent, equivalenceClass)
+import Data.UnionFind.STT hiding (equate, equivalent, classDesc)
 import qualified Data.UnionFind.STT  as S
 
  
 
 
 
-newtype PartitionT s v m a = PartitionT {unPartitionT :: ReaderT (Partition s v) (STT s m) a}
-type PartitionM s v = PartitionT s v Identity
+newtype PartitionT s c v m a = PartitionT {unPartitionT :: ReaderT (Partition s c v) (STT s m) a}
+type PartitionM s c v = PartitionT s c v Identity
 
-instance (Monad m) => Monad (PartitionT s v m) where
+instance (Monad m) => Monad (PartitionT s c v m) where
     PartitionT m >>= f = PartitionT (m >>= (unPartitionT . f))
     return = PartitionT . return
 
-instance MonadTrans (PartitionT s v) where
+instance MonadTrans (PartitionT s c v) where
     lift = PartitionT . lift . lift
 
-instance (MonadReader r m) => MonadReader r (PartitionT s v m) where
+instance (MonadReader r m) => MonadReader r (PartitionT s c v m) where
     ask = PartitionT $ lift ask
     local f (PartitionT (ReaderT m)) = PartitionT $ ReaderT $ (\ r -> local f (m r))
 
-instance (Monoid w, MonadWriter w m) => MonadWriter w (PartitionT s v m) where
+instance (Monoid w, MonadWriter w m) => MonadWriter w (PartitionT s c v m) where
     tell w = PartitionT $ tell w
     listen (PartitionT m) = PartitionT $ listen m
     pass (PartitionT m) = PartitionT $ pass m
 
-instance (MonadState st m) => MonadState st (PartitionT s v m) where
+instance (MonadState st m) => MonadState st (PartitionT s c v m) where
     get = PartitionT get
     put s = PartitionT $ put s
 
-instance (MonadError e m) => MonadError e (PartitionT s v m) where
+instance (MonadError e m) => MonadError e (PartitionT s c v m) where
     throwError e = lift $ throwError e
     catchError (PartitionT m) f = PartitionT $ catchError m (unPartitionT . f)
     
 
-runPartitionT :: (Monad m) => (forall s. PartitionT s v m a) -> m a
-runPartitionT m = runST $ do
-  p <- emptyPartition
+runPartitionT :: (Monad m) => (v -> c) -> (c -> c -> c) -> (forall s. PartitionT s c v m a) -> m a
+runPartitionT mk com m = runST $ do
+  p <- emptyPartition mk com
   (`runReaderT` p) $ unPartitionT m
 
 
-class (Monad m, Ord v) => MonadPartition v m | m -> v where
+class (Monad m, Ord v) => MonadPartition c v m | m -> v, m -> c where
     equivalent :: v -> v -> m Bool
-    equivalenceClass :: v -> m [v]
+    classDesc :: v -> m c
     equate :: v -> v -> m ()
 
-instance (Monad m, Ord v) => MonadPartition v (PartitionT s v m) where
+instance (Monad m, Ord v) => MonadPartition c v (PartitionT s c v m) where
     equivalent x y = PartitionT $ do
       part <- ask
       lift $ S.equivalent part x y
 
-    equivalenceClass x = PartitionT $ do
+    classDesc x = PartitionT $ do
       part <- ask
-      lift $ S.equivalenceClass part x
+      lift $ S.classDesc part x
            
     equate x y = PartitionT $ do
       part <- ask
       lift $ S.equate part x y
 
-instance (MonadPartition v m, MonadTrans t, Monad (t m)) => MonadPartition v (t m) where
+instance (MonadPartition c v m, MonadTrans t, Monad (t m)) => MonadPartition c v (t m) where
     equivalent x y = lift $ equivalent x y
-    equivalenceClass = lift . equivalenceClass
+    classDesc = lift . classDesc
     equate x y = lift $ equate x y

File src/Data/UnionFind/STT.hs

   ( emptyPartition
   , equate
   , equivalent
-  , equivalenceClass
+  , classDesc
   , Partition
   )
 where
 import Data.Map (Map)
 import qualified Data.Map as Map
 
-newtype Entry s a = Entry (STRef s (EntryData s a))
+newtype Entry s c a = Entry (STRef s (EntryData s c a))
     deriving (Eq)
 
-data EntryData s a = EntryData {
-      entryParent :: Maybe (Entry s a),
-      entryClass :: [a],
+data EntryData s c a = Node {
+      entryParent :: Entry s c a,
+      entryValue :: a
+    }
+                     | Root {
+      entryDesc :: c,
       entryWeight :: Int,
       entryValue :: a
     }
 
-data Partition s a = Partition {
-      entries :: STRef s (Map a (Entry s a))
+data Partition s c a = Partition {
+      entries :: STRef s (Map a (Entry s c a)),
+      singleDesc :: a -> c,
+      combDesc :: c -> c -> c
       }
 
 modifySTRef :: (Monad m) => STRef s a -> (a -> a) -> STT s m ()
 modifySTRef r f = readSTRef r >>= (writeSTRef r . f)
 
 
-emptyPartition :: Monad m => STT s m (Partition s a)
-emptyPartition = liftM Partition $ newSTRef Map.empty
+emptyPartition :: Monad m => (a -> c) -> (c -> c -> c) -> STT s m (Partition s c a)
+emptyPartition mk com = do 
+  es <- newSTRef Map.empty
+  return Partition {entries = es, singleDesc = mk, combDesc = com}
 
 
 -- | /O(1)/. @repr point@ returns the representative point of
 -- representative of its class.
 --
 -- This method performs the path compresssion.
-representative' :: Monad m => Entry s a -> STT s m (Maybe (Entry s a))
+representative' :: Monad m => Entry s c a -> STT s m (Maybe (Entry s c a))
 representative' (Entry e) = do
   ed <- readSTRef e
-  case entryParent ed of
-    Nothing -> return Nothing
-    Just parent -> do
+  case ed of
+    Root {} -> return Nothing
+    Node { entryParent = parent} -> do
       mparent' <- representative' parent
       case mparent' of
         Nothing -> return $ Just parent
-        Just parent' -> writeSTRef e ed{entryParent = Just parent'} >> return (Just parent')
+        Just parent' -> writeSTRef e ed{entryParent = parent'} >> return (Just parent')
 
 
 -- | /O(1)/. @repr point@ returns the representative point of
 -- @point@'s equivalence class.
 --
 -- This method performs the path compresssion.
-representative :: Monad m => Entry s a -> STT s m (Entry s a)
+representative :: Monad m => Entry s c a -> STT s m (Entry s c a)
 representative entry = do
   mrepr <- representative' entry
   case mrepr of
     Just repr -> return repr
 
 
-getEntry' :: (Monad m, Ord a) => Partition s a -> a -> STT s m (Entry s a)
-getEntry' (Partition mref) val = do
+getEntry' :: (Monad m, Ord a) => Partition s c a -> a -> STT s m (Entry s c a)
+getEntry' Partition {entries = mref, singleDesc = mkDesc} val = do
   m <- readSTRef mref
   case Map.lookup val m of
     Nothing -> do
-      e <- newSTRef EntryData
-            { entryParent = Nothing,
-              entryClass = [val],
+      e <- newSTRef Root
+            { entryDesc = mkDesc val,
               entryWeight = 1,
               entryValue = val
             }
     Just entry -> return entry
 
 
-getEntry :: (Monad m, Ord a) => Partition s a -> a -> STT s m (Maybe (Entry s a))
-getEntry (Partition mref) val = do
+getEntry :: (Monad m, Ord a) => Partition s c a -> a -> STT s m (Maybe (Entry s c a))
+getEntry Partition { entries = mref} val = do
   m <- readSTRef mref
   case Map.lookup val m of
     Nothing -> return Nothing
     Just entry -> return $ Just entry
 
-equate :: (Monad m, Ord a) => Partition s a -> a -> a -> STT s m ()
+equate :: (Monad m, Ord a) => Partition s c a -> a -> a -> STT s m ()
 equate part x y = do
   ex <- getEntry' part x
   ey <- getEntry' part  y
-  equate' ex ey
+  equate' part ex ey
 
-equate' :: (Monad m, Ord a) => Entry s a -> Entry s a -> STT s m ()
-equate' x y = do
+equate' :: (Monad m, Ord a) => Partition s c a -> Entry s c a -> Entry s c a -> STT s m ()
+equate' Partition {combDesc = mkDesc} x y = do
   repx@(Entry rx) <- representative x
   repy@(Entry ry) <- representative y
   when (rx /= ry) $ do
-    dx@EntryData{entryWeight = wx, entryClass = chx} <- readSTRef rx
-    dy@EntryData{entryWeight = wy, entryClass = chy} <- readSTRef ry
+    dx@Root{entryWeight = wx, entryDesc = chx, entryValue = vx} <- readSTRef rx
+    dy@Root{entryWeight = wy, entryDesc = chy, entryValue = vy} <- readSTRef ry
     if  wx >= wy
       then do
-        writeSTRef ry dy{entryParent = Just repx}
-        writeSTRef rx dx{entryWeight = wx + wy, entryClass = chx ++ chy}
+        writeSTRef ry Node {entryParent = repx, entryValue = vy}
+        writeSTRef rx dx{entryWeight = wx + wy, entryDesc = mkDesc chx chy}
       else do
-       writeSTRef rx dx{entryParent = Just repy}
-       writeSTRef ry dy{entryWeight = wx + wy, entryClass = chx ++ chy}
+       writeSTRef rx Node {entryParent = repy, entryValue = vx}
+       writeSTRef ry dy{entryWeight = wx + wy, entryDesc = mkDesc chx chy}
 
-equivalenceClass :: (Monad m, Ord a) => Partition s a -> a -> STT s m [a]
-equivalenceClass p val = do
+classDesc :: (Monad m, Ord a) => Partition s c a -> a -> STT s m c
+classDesc p val = do
   mentry <- getEntry p val
   case mentry of
-    Nothing -> return [val]
-    Just entry -> equivalenceClass' entry
+    Nothing -> return $ singleDesc p val
+    Just entry -> classDesc' entry
 
-equivalenceClass' :: (Monad m) => Entry s a -> STT s m [a]
-equivalenceClass' entry = do
+classDesc' :: (Monad m) => Entry s c a -> STT s m c
+classDesc' entry = do
   Entry e <- representative entry
-  ed <- readSTRef e
-  return $ entryClass ed
+  liftM entryDesc $ readSTRef e
 
 -- | /O(1)/. Return @True@ if both points belong to the same
 -- | equivalence class.
-equivalent :: (Monad m, Ord a) => Partition s a -> a -> a -> STT s m Bool
+equivalent :: (Monad m, Ord a) => Partition s c a -> a -> a -> STT s m Bool
 equivalent p v1 v2 = do
   me1 <- getEntry p v1
   me2 <- getEntry p v2
     _ -> return False
     
 
-equivalent' :: (Monad m, Ord a) => Entry s a -> Entry s a -> STT s m Bool
+equivalent' :: (Monad m, Ord a) => Entry s c a -> Entry s c a -> STT s m Bool
 equivalent' e1 e2 = liftM2 (==) (representative e1) (representative e2)