-- Source: iotransaction / IoTransaction.hs

{-# LANGUAGE DeriveDataTypeable #-}

{-|
This module provides facilities for building IO actions in such a way that, if one IO action in a sequence
throws an exception, the effects of the previous actions in that sequence will be undone.
-}
module IoTransaction(runUndoable, execUndoable, makeUndoable, doAction, addUndoer, rollback, UndoableIO) where

import qualified System.Posix.Directory as Dir
import qualified System.Posix.Files as Files
import qualified System.Posix.Types as PTypes
import qualified System.FilePath.Posix as PPath
import qualified System.IO as SysIo
import qualified Control.Exception as C
import Data.Typeable.Internal

-- | An undoable computation in the monad @m@.  Running the wrapped action
-- yields the computed value paired with an action in @m@ (the "undoer")
-- that reverts the computation's effects.
--
-- Declared as a @newtype@ rather than @data@: the single-field wrapper then
-- has no runtime representation, and matching on 'Do' never forces the
-- wrapped value.
newtype UndoableM m a = Do (m (a, m ()))

-- | Unwrap an undoable computation, exposing the underlying action that
-- produces the result together with its undoer.
runUndoableM :: UndoableM m a -> m (a, m ())
runUndoableM undoable = case undoable of
    Do action -> action

-- | Run an undoable computation and return only its result.
--
-- The undoer produced by the computation is discarded and never run, so the
-- effects of a successful run are kept.  If the computation throws, the
-- exception propagates to the caller.
execUndoableM :: Monad m => UndoableM m a -> m a
execUndoableM u = do
    -- The undoer is deliberately ignored; a wildcard avoids binding an
    -- unused name.
    (val, _) <- runUndoableM u
    return val

-- | Monads that can throw and catch 'C.Exception's.  The undoable machinery
-- in this module is written against this class rather than directly against
-- 'IO'.
class Monad m => ExceptionalMonad m where
    -- | Run an action, handling exceptions of type @ex@ that it raises.
    catchM :: C.Exception ex => m b -> (ex -> m b) -> m b
    -- | Throw an exception from within the monad.
    throwM :: C.Exception ex => ex -> m b

-- | 'IO' throws and catches through "Control.Exception".
instance ExceptionalMonad IO where
    catchM action handler = C.catch action handler
    -- Inside IO, C.throwIO (rather than C.throw) is the right way to throw:
    -- the exception is raised when this action runs, in sequence with the
    -- surrounding IO actions.
    throwM ex = C.throwIO ex

-- | Mapping over an undoable computation transforms its result and keeps
-- its undoer unchanged.
instance Monad m => Functor (UndoableM m) where
    fmap f (Do op) = Do $ do
        (val, undo) <- op
        return (f val, undo)

-- | Defined in terms of the 'Monad' instance, so '<*>' has exactly the
-- same rollback behaviour as '>>='.  (Required as a superclass of 'Monad'
-- since GHC 7.10.)
instance ExceptionalMonad m => Applicative (UndoableM m) where
    pure val = Do $ return (val, return ())
    mf <*> mx = mf >>= \f -> fmap f mx

instance ExceptionalMonad m => Monad (UndoableM m) where
    {-
    We combine two undoable computations by making a new one that executes
    the first one and then tries to execute the second.  If the first
    fails, the combination fails.  If the second fails, the combination
    runs the first one's undoer and then rethrows.  The new computation
    returns the value returned by the second, along with an undoer that
    runs the second's undoer and then the first's.  The undoers therefore
    behave as a stack: later effects, which may depend on earlier ones, are
    undone first.  This also makes the undo order independent of how binds
    are associated.
    -}
    -- (>>=) :: UndoableM m a -> (a -> UndoableM m b) -> UndoableM m b
    Do op >>= f = Do $ do
        (val, undo) <- op
        (val', undo') <- runUndoableM (f val) `catchM` (\e -> undo >> throwM (e :: C.SomeException))
        -- LIFO: revert the second computation before the first.
        return (val', undo' >> undo)

    return = pure

-- | Pair an action with the action that reverses it.
--
-- The undoer is only recorded here.  It runs if a later step in the
-- sequence fails, or if the caller runs the undoer it gets back.  If @op@
-- itself throws, no undoer is recorded for it.
makeUndoableM :: ExceptionalMonad m => m a -> m () -> UndoableM m a
makeUndoableM op undo = Do (op >>= \result -> return (result, undo))

-- | Lift an action that needs no undoing; its undoer does nothing.
doActionM :: ExceptionalMonad m => m a -> UndoableM m a
doActionM action = makeUndoableM action noUndo
  where
    noUndo = return ()

-- | Record an undoer without performing any action now.
addUndoerM :: ExceptionalMonad m => m () -> UndoableM m ()
addUndoerM = makeUndoableM (return ())

-- | Exception thrown by 'rollbackM' to abort an undoable computation on
-- purpose.
data ManualUndo
    = ManualUndo
    deriving (Typeable, Show)

-- | Uses the default 'C.Exception' methods.
instance C.Exception ManualUndo where

-- | Abort the enclosing undoable computation by throwing 'ManualUndo'.
--
-- No extra handler is needed here.  The handler that '>>=' (in the 'Monad'
-- instance for 'UndoableM') installs around each later step catches the
-- exception, runs the undoers recorded by the earlier steps, and rethrows
-- it to the caller.
rollbackM :: ExceptionalMonad m => UndoableM m ()
rollbackM = makeUndoableM abort nothingToUndo
  where
    abort = throwM ManualUndo
    nothingToUndo = return ()

{-|
An "undoable action" is a wrapper for an IO action (the "doer") that combines it with another
IO action (the "undoer") that undoes the effects of the first one.

Undoable actions are monads, and when sequenced together they act like transactions involving IO operations.
As undoable actions are sequenced together, their doers are also sequenced together and their undoers
are combined into a stack.  When the doers are executed, if one of them throws an exception, the undoers
added to the stack so far are executed in reverse order, then the exception is rethrown, and
no other doers (or undoers) are executed.  If no exception is thrown, none of the undoers are executed.

The synonym is eta-reduced (@UndoableM IO@ rather than @UndoableM IO a@) so that it can also be
used unapplied, e.g. as a type argument or in a class constraint.  Fully applied uses are unaffected.
-}
type UndoableIO = UndoableM IO

-- | Run an undoable IO action.  Returns its result together with the
-- combined undoer of every step, which the caller may run later to revert
-- a completed transaction.
runUndoable :: UndoableIO a -> IO (a, IO ())
runUndoable = runUndoableM

-- | Run an undoable IO action and keep only its result; the combined undoer
-- is discarded.
execUndoable :: UndoableIO a -> IO a
execUndoable = execUndoableM

-- | Build an undoable action from a doer and the undoer that reverses it.
makeUndoable :: IO a  -- ^ The "doer": the action to perform.
             -> IO () -- ^ The "undoer": an action that undoes the effect of the other one.
             -> UndoableIO a
makeUndoable = makeUndoableM

-- | Make an undoable action without any undoer.
--
-- The action's undoer is a no-op (@return ()@), so rolling back does
-- nothing extra on its behalf.
doAction :: IO a -- ^ The "doer": the action to perform.
         -> UndoableIO a
doAction = doActionM

-- | Push an undoer onto the undoer stack without performing any action now.
addUndoer :: IO () -- ^ The "undoer": an action that will be added to the undoer stack.
          -> UndoableIO ()
addUndoer = addUndoerM

-- | Stop execution: the undoers already on the stack are run, and an
-- exception is thrown.  The exception's type (@ManualUndo@) is not exported
-- by this module, so callers can only catch it generically, e.g. as
-- 'C.SomeException'.
rollback :: UndoableIO ()
rollback = rollbackM

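{-
UndoableIO satisfies the first two monad laws.  Whether it satisfies the third (associativity)
has not been established: the derivation below is an unfinished sketch.  It uses older names:
`runUio` corresponds to `runUndoableM`, and `ioError e` stands for rethrowing with `throwM`.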

Monad Laws:
1. return x >>= f    ==  f x
2. mv >>= return     ==  mv
3. (mv >>= f) >>= g  ==  mv >>= (\x -> (f x >>= g))

let gK = \(v, u) -> runUio (g v) `C.catch` (\e -> u >> ioError e)
let fK = \(v, u) -> runUio (f v) `C.catch` (\e -> u >> ioError e)
let J = \u -> \(val', undoIo') -> return (val', (u >> undoIo'))

// (Do io >>= f) >>= g  ==  Do io >>= (\x -> (f x >>= g))
\x -> (f x >>= g) == \x -> Do $ runUio (f x) >>= \(val, undoIo) -> gK (val, undoIo) >>= J undoIo

(Do io >>= f) >>= g == Do io' >>= g
    where io' = io >>= \(val, undoIo) -> fK (val, undoIo) >>= J undoIo
          
Do io' >>= g == Do $ (io >>= \(val, undoIo) -> fK (val, undoIo) >>= J undoIo) >>= \(val, undoIo) -> gK (val, undoIo) >>= J undoIo

Do io >>= (\x -> (f x >>= g)) == Do io >>= (\x -> Do $ runUio (f x) >>= \(val, undoIo) -> gK (val, undoIo) >>= J undoIo)
    == Do $ do (v, u) <- io
               (v', u') <- runUio (Do $ runUio (f v) >>= \(val, undoIo) -> gK (val, undoIo) >>= J undoIo) `C.catch` (\e -> u >> ioError e)
               return (v', (u >> u'))
    == Do $ io >>= \(v, u) -> 
-}