{-# LANGUAGE RankNTypes, TypeFamilies, FlexibleContexts, ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Pipes.Safe
(
SafeT
, runSafeT
, runSafeP
, ReleaseKey
, MonadSafe(..)
, onException
, finally
, bracket
, bracket_
, bracketOnError
, module Control.Monad.Catch
, module Control.Exception
) where
import Control.Applicative (Applicative(pure, (<*>)))
import Control.Exception(Exception(..), SomeException(..))
import qualified Control.Monad.Catch as C
import Control.Monad.Catch
( MonadCatch(..)
, MonadThrow(..)
, MonadMask(..)
, mask_
, uninterruptibleMask_
, catchAll
, catchIOError
, catchJust
, catchIf
, Handler(..)
, catches
, handle
, handleAll
, handleIOError
, handleJust
, handleIf
, try
, tryJust
, Exception(..)
, SomeException
)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Trans.Class (MonadTrans(lift))
import qualified Control.Monad.Catch.Pure as E
import qualified Control.Monad.Trans.Identity as I
import qualified Control.Monad.Trans.Reader as R
import qualified Control.Monad.Trans.RWS.Lazy as RWS
import qualified Control.Monad.Trans.RWS.Strict as RWS'
import qualified Control.Monad.Trans.State.Lazy as S
import qualified Control.Monad.Trans.State.Strict as S'
import qualified Control.Monad.Trans.Writer.Lazy as W
import qualified Control.Monad.Trans.Writer.Strict as W'
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import qualified Data.Map as M
import Data.Monoid (Monoid)
import Pipes (Proxy, Effect, Effect', runEffect)
import Pipes.Internal (unsafeHoist, Proxy(..))
import Pipes.Lift (liftCatchError)
-- | State shared between the masked chunks run by 'liftMask' and the
-- @unmask@ function it hands out.
--
-- The value is 'Unmasked' when no masked chunk is running. Otherwise it is
-- 'Masked', holding the @restore@ function supplied by the base masking
-- primitive for the chunk currently running.
data Restore m = Unmasked | Masked (forall x . m x -> m x)
-- | Lift a masking primitive of the base monad (e.g. 'mask' or
-- 'uninterruptibleMask') to 'Proxy'.
--
-- Only base-monad actions are masked. Each run of consecutive 'M' steps is
-- executed inside one call to @maskFunction@. 'Request', 'Respond' and
-- 'Pure' steps are rebuilt around the recursively masked continuation.
--
-- The @restore@ function supplied by @maskFunction@ is stashed in an
-- 'IORef'. This lets the @unmask@ function passed to @k@ apply it to the
-- base actions of the proxy it wraps.
liftMask
    :: forall m a' a b' b r . (MonadIO m, MonadCatch m)
    => (forall s . ((forall x . m x -> m x) -> m s) -> m s)
    -> ((forall x . Proxy a' a b' b m x -> Proxy a' a b' b m x)
        -> Proxy a' a b' b m r)
    -> Proxy a' a b' b m r
liftMask maskFunction k = do
    -- Holds the @restore@ of the masked chunk currently running; it is
    -- reset to 'Unmasked' after each chunk completes normally.
    ioref <- liftIO (newIORef Unmasked)
    -- unmask: read the stashed @restore@ and, if present, apply it to every
    -- base action of @p@ via 'unsafeHoist'.
    -- loop: walk the proxy, masking each run of adjacent base actions.
    let unmask
            :: forall y . (Monad m)
            => Proxy a' a b' b m y -> Proxy a' a b' b m y
        unmask p = do
            mRestore <- liftIO (readIORef ioref)
            case mRestore of
                Unmasked -> p
                Masked restore -> do
                    r <- unsafeHoist restore p
                    -- NOTE(review): runs an empty action under @restore@;
                    -- presumably an unmasked point at which pending
                    -- asynchronous exceptions get delivered -- confirm.
                    lift $ restore $ return ()
                    return r
        loop p = case p of
            Request a' fa -> Request a' (loop . fa )
            Respond b fb' -> Respond b (loop . fb')
            M m0 -> M $ maskFunction $ \restore -> do
                liftIO $ writeIORef ioref (Masked restore)
                -- Keep executing base actions for as long as they produce
                -- further 'M' steps, so they all share this masked region.
                let loop' m = do
                        p' <- m
                        case p' of
                            M m' -> loop' m'
                            _ -> return p'
                p' <- loop' m0
                liftIO $ writeIORef ioref Unmasked
                return (loop p')
            Pure r -> Pure r
    loop (k unmask)
-- | Throwing from a 'Proxy' throws in the base monad.
instance (MonadThrow m) => MonadThrow (Proxy a' a b' b m) where
    throwM e = lift (throwM e)
-- | Catching in a 'Proxy' reuses the base monad's 'C.catch', lifted to
-- proxies by 'liftCatchError'.
instance (MonadCatch m) => MonadCatch (Proxy a' a b' b m) where
    catch p handler = liftCatchError C.catch p handler
-- | Both masking variants lift the base monad's primitive with 'liftMask'.
instance (MonadMask m, MonadIO m) => MonadMask (Proxy a' a b' b m) where
    mask k = liftMask mask k
    uninterruptibleMask k = liftMask uninterruptibleMask k
-- | Table of finalizers registered inside a 'SafeT' computation.
data Finalizers m = Finalizers
    { _nextKey :: !Integer
      -- ^ key to hand out on the next registration (incremented each time)
    , _finalizers :: !(M.Map Integer (m ()))
      -- ^ finalizers still registered, indexed by their key
    }
-- | Monad transformer that tracks registered finalizers.
--
-- It is a 'R.ReaderT' whose environment is an 'IORef' to the current
-- 'Finalizers' table.
newtype SafeT m r = SafeT { unSafeT :: R.ReaderT (IORef (Finalizers m)) m r }
-- | Map over the result.
--
-- Defined through '>>=' of the underlying reader so that the instance only
-- needs @Monad m@.
instance (Monad m) => Functor (SafeT m) where
    fmap f m = SafeT (unSafeT m >>= \r -> return (f r))
-- | Run the function computation first, then the argument computation, and
-- apply the one result to the other.
instance (Monad m) => Applicative (SafeT m) where
    pure x = SafeT (return x)
    mf <*> mx = SafeT $
        unSafeT mf >>= \f ->
        unSafeT mx >>= \x ->
        return (f x)
-- | Sequencing passes the same reader environment (the finalizer 'IORef')
-- to both computations.
instance (Monad m) => Monad (SafeT m) where
    return = pure
    m >>= f = SafeT (unSafeT m >>= unSafeT . f)
-- | Lift 'IO' through the underlying reader.
instance (MonadIO m) => MonadIO (SafeT m) where
    liftIO = SafeT . liftIO
-- | Throw via the underlying reader.
instance MonadThrow m => MonadThrow (SafeT m) where
    throwM = SafeT . throwM
-- | Catch in the underlying reader, unwrapping the handler's result.
instance (MonadCatch m) => MonadCatch (SafeT m) where
    catch m f = SafeT (C.catch (unSafeT m) (unSafeT . f))
-- | Masking is delegated to the underlying reader.
--
-- The reader-level @restore@ it supplies is re-wrapped so that it acts on
-- 'SafeT' computations.
instance (MonadMask m) => MonadMask (SafeT m) where
    mask k = SafeT $ mask $ \restore ->
        unSafeT (k (SafeT . restore . unSafeT))
    uninterruptibleMask k = SafeT $ uninterruptibleMask $ \restore ->
        unSafeT (k (SafeT . restore . unSafeT))
-- | Lift a base action through the underlying reader.
instance MonadTrans SafeT where
    lift = SafeT . lift
-- | Run a 'SafeT' computation, then run every finalizer that is still
-- registered, newest first.
--
-- The finalizers form the release step of 'C.bracket', so they run whether
-- the computation returns or throws.
--
-- They are chained with 'C.finally', so a finalizer that throws does not
-- stop the remaining (older) ones from running. The exception propagates
-- once they have all run; if several throw, the exception from the last one
-- to run wins.
runSafeT :: (MonadMask m, MonadIO m) => SafeT m r -> m r
runSafeT m = C.bracket
    (liftIO $ newIORef $! Finalizers 0 M.empty)
    (\ioref -> do
        Finalizers _ fs <- liftIO (readIORef ioref)
        -- Keys increase with every registration, so descending key order
        -- is newest-first.
        --
        -- @foldr C.finally@ runs each finalizer and then the rest, even if
        -- it throws. A plain 'mapM' would stop at the first exception and
        -- leak the remaining resources.
        foldr C.finally (return ()) (map snd (M.toDescList fs)) )
    (R.runReaderT (unSafeT m))
-- | Run a 'SafeT' 'Effect' as an 'Effect'' over the base monad.
--
-- The effect is executed with 'runEffect', and its finalizers are run by
-- 'runSafeT'.
--
-- The argument is bound explicitly because 'Effect'' hides a @forall@ to
-- the right of the arrow. Under GHC 9.0+ simplified subsumption, a
-- point-free composition cannot be checked against such a signature,
-- whereas the eta-expanded form works on every compiler.
runSafeP :: (MonadMask m, MonadIO m) => Effect (SafeT m) r -> Effect' m r
runSafeP e = lift (runSafeT (runEffect e))
-- | Handle returned by 'register'; pass it to 'release' to unregister the
-- finalizer.
--
-- It wraps the 'Integer' key the finalizer is stored under.
newtype ReleaseKey = ReleaseKey { unlock :: Integer }
-- | Monads in which finalizers can be registered now and run later.
class (MonadCatch m, MonadMask m, MonadIO m, MonadIO (Base m)) => MonadSafe m where
    -- | The monad that finalizers (and 'liftBase'd actions) run in.
    type Base (m :: * -> *) :: * -> *

    -- | Lift an action from the 'Base' monad.
    liftBase :: Base m r -> m r

    -- | Register a finalizer and return a key that 'release' accepts.
    --
    -- For 'SafeT', finalizers that are never released are run by
    -- 'runSafeT' when it finishes, newest first.
    register :: Base m () -> m ReleaseKey

    -- | Unregister the finalizer identified by the key so it will not run.
    release :: ReleaseKey -> m ()
-- | 'SafeT' keeps finalizers in the 'IORef' it carries.
--
-- 'register' stores a finalizer under a fresh key; 'release' deletes it.
instance (MonadIO m, MonadCatch m, MonadMask m) => MonadSafe (SafeT m) where
    type Base (SafeT m) = m

    liftBase m = lift m

    register io = SafeT R.ask >>= \ioref -> liftIO $ do
        Finalizers key table <- readIORef ioref
        writeIORef ioref $! Finalizers (key + 1) (M.insert key io table)
        return (ReleaseKey key)

    release (ReleaseKey key) = SafeT R.ask >>= \ioref -> liftIO $ do
        Finalizers next table <- readIORef ioref
        writeIORef ioref $! Finalizers next (M.delete key table)
-- | Finalizer operations pass straight through a 'Proxy' to its base monad.
instance (MonadSafe m) => MonadSafe (Proxy a' a b' b m) where
    type Base (Proxy a' a b' b m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through 'I.IdentityT'.
instance (MonadSafe m) => MonadSafe (I.IdentityT m) where
    type Base (I.IdentityT m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through 'E.CatchT'.
instance (MonadSafe m) => MonadSafe (E.CatchT m) where
    type Base (E.CatchT m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through 'R.ReaderT'.
instance (MonadSafe m) => MonadSafe (R.ReaderT i m) where
    type Base (R.ReaderT i m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through the lazy 'S.StateT'.
instance (MonadSafe m) => MonadSafe (S.StateT s m) where
    type Base (S.StateT s m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through the strict 'S''.StateT'.
instance (MonadSafe m) => MonadSafe (S'.StateT s m) where
    type Base (S'.StateT s m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through the lazy 'W.WriterT'.
instance (MonadSafe m, Monoid w) => MonadSafe (W.WriterT w m) where
    type Base (W.WriterT w m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through the strict 'W''.WriterT'.
instance (MonadSafe m, Monoid w) => MonadSafe (W'.WriterT w m) where
    type Base (W'.WriterT w m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through the lazy 'RWS.RWST'.
instance (MonadSafe m, Monoid w) => MonadSafe (RWS.RWST i w s m) where
    type Base (RWS.RWST i w s m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
-- | Finalizer operations pass straight through the strict 'RWS''.RWST'.
instance (MonadSafe m, Monoid w) => MonadSafe (RWS'.RWST i w s m) where
    type Base (RWS'.RWST i w s m) = Base m
    liftBase b = lift (liftBase b)
    register fin = lift (register fin)
    release key = lift (release key)
onException :: (MonadSafe m) => m a -> Base m b -> m a
m1 `onException` io = do
key <- register (io >> return ())
r <- m1
release key
return r
finally :: (MonadSafe m) => m a -> Base m b -> m a
m1 `finally` after = bracket_ (return ()) after m1
-- | Acquire a resource with @before@, use it with @action@, and free it
-- with @after@.
--
-- Acquisition and the normal-path release run masked; only @action@ runs
-- with the caller's masking state restored.
--
-- While @action@ runs, @after@ is registered as a finalizer. If @action@
-- does not return normally, the resource is therefore still freed when the
-- registered finalizers are run.
bracket :: (MonadSafe m) => Base m a -> (a -> Base m b) -> (a -> m c) -> m c
bracket before after action = mask $ \unmasked -> do
    resource <- liftBase before
    result <- unmasked (action resource) `onException` after resource
    liftBase (after resource) >>= \_ -> return result
-- | Variant of 'bracket' whose release and body ignore the acquired value.
bracket_ :: (MonadSafe m) => Base m a -> Base m b -> m c -> m c
bracket_ before after action = bracket before (const after) (const action)
-- | Like 'bracket', but @after@ is only a safety net.
--
-- It stays registered as a finalizer while @action@ runs, and is released,
-- without running, once @action@ returns normally.
bracketOnError
    :: (MonadSafe m) => Base m a -> (a -> Base m b) -> (a -> m c) -> m c
bracketOnError before after action = mask $ \unmasked ->
    liftBase before >>= \resource ->
    unmasked (action resource) `onException` after resource