{-# LANGUAGE CPP, RankNTypes #-}
module Control.Monad.Morph (
MFunctor(..),
generalize,
MMonad(..),
MonadTrans(lift),
squash,
(>|>),
(<|<),
(=<|),
(|>=)
) where
import Control.Monad.Trans.Class (MonadTrans(lift))
import qualified Control.Monad.Trans.Error as E
import qualified Control.Monad.Trans.Identity as I
import qualified Control.Monad.Trans.List as L
import qualified Control.Monad.Trans.Maybe as M
import qualified Control.Monad.Trans.Reader as R
import qualified Control.Monad.Trans.RWS.Lazy as RWS
import qualified Control.Monad.Trans.RWS.Strict as RWS'
import qualified Control.Monad.Trans.State.Lazy as S
import qualified Control.Monad.Trans.State.Strict as S'
import qualified Control.Monad.Trans.Writer.Lazy as W'
import qualified Control.Monad.Trans.Writer.Strict as W
import Data.Monoid (Monoid, mappend)
import Data.Functor.Compose (Compose (Compose))
import Data.Functor.Identity (runIdentity)
import Data.Functor.Product (Product (Pair))
#if MIN_VERSION_transformers(0,3,0)
import Control.Applicative.Backwards (Backwards (Backwards))
import Control.Applicative.Lift (Lift (Pure, Other))
#endif
import Control.Exception (try, IOException)
import Control.Monad ((=<<), (>=>), (<=<), join)
import Data.Functor.Identity (Identity)
-- | Functors over the base monad of a monad transformer (or of another
-- @(* -> *) -> * -> *@ type constructor).
--
-- 'hoist' takes a transformation of type @forall a . m a -> n a@ and uses it
-- to turn a @t m b@ into a @t n b@.
--
-- NOTE(review): instances are presumably expected to obey the functor laws
-- (@hoist id = id@, @hoist (f . g) = hoist f . hoist g@); nothing in this
-- class enforces that.
class MFunctor t where
    -- | Swap the base monad @m@ of @t@ for @n@ using the given
    -- transformation.
    hoist
        :: Monad m
        => (forall a . m a -> n a)  -- ^ transformation of the base monad
        -> t m b                    -- ^ computation over the old base monad
        -> t n b
-- Unwrap the 'E.ErrorT', transform the base action (which yields an
-- @Either e b@), and wrap it back up.
instance MFunctor (E.ErrorT e) where
    hoist nat = E.ErrorT . nat . E.runErrorT
-- 'I.IdentityT' is a plain wrapper, so hoisting just transforms the wrapped
-- action.
instance MFunctor I.IdentityT where
    hoist nat = I.IdentityT . nat . I.runIdentityT
-- Transform the base action that produces the list of results.
instance MFunctor L.ListT where
    hoist nat = L.ListT . nat . L.runListT
-- Transform the base action that produces the optional result.
instance MFunctor M.MaybeT where
    hoist nat = M.MaybeT . nat . M.runMaybeT
-- For every environment value, transform the base action the reader
-- would have run.
instance MFunctor (R.ReaderT r) where
    hoist nat m = R.ReaderT (nat . R.runReaderT m)
-- For every environment and input state, transform the base action that
-- yields the (result, state, log) triple.
instance MFunctor (RWS.RWST r w s) where
    hoist nat m = RWS.RWST (\r -> nat . RWS.runRWST m r)
-- Same as the lazy RWST instance, over the strict variant's constructor.
instance MFunctor (RWS'.RWST r w s) where
    hoist nat m = RWS'.RWST (\r -> nat . RWS'.runRWST m r)
-- For every input state, transform the base action that yields the
-- (result, state) pair.
instance MFunctor (S.StateT s) where
    hoist nat m = S.StateT (nat . S.runStateT m)
-- Same as the lazy StateT instance, over the strict variant's constructor.
instance MFunctor (S'.StateT s) where
    hoist nat m = S'.StateT (nat . S'.runStateT m)
-- Transform the base action that yields the (result, log) pair.
instance MFunctor (W.WriterT w) where
    hoist nat = W.WriterT . nat . W.runWriterT
-- Transform the base action that yields the (result, log) pair.
instance MFunctor (W'.WriterT w) where
    hoist nat = W'.WriterT . nat . W'.runWriterT
-- The inner functor of a 'Compose' plays the role of the "base monad":
-- apply the transformation under the outer functor @f@.
instance Functor f => MFunctor (Compose f) where
    hoist nat (Compose outer) = Compose mapped
      where
        mapped = fmap nat outer
-- Only the second component of the pair is transformed; the first (of the
-- fixed functor @f@) is kept unchanged.
instance MFunctor (Product f) where
    hoist nat prod = case prod of
        Pair left right -> Pair left (nat right)
#if MIN_VERSION_transformers(0,3,0)
-- Transform the wrapped action.
instance MFunctor Backwards where
    hoist nat wrapped = case wrapped of
        Backwards action -> Backwards (nat action)
-- Only the 'Other' case holds an action to transform; a 'Pure' value is
-- rebuilt as-is.
instance MFunctor Lift where
    hoist nat lifted = case lifted of
        Other action -> Other (nat action)
        Pure value   -> Pure value
#endif
-- | Lift a pure 'Identity' computation into any monad by extracting its
-- value and 'return'ing it.  Typically used as the argument to 'hoist'.
generalize :: Monad m => Identity a -> m a
generalize ident = return (runIdentity ident)
-- | Transformers that can absorb a transformation which itself produces a
-- layer of @t@.
--
-- 'embed' takes a transformation @forall a . m a -> t n a@ and applies it to
-- the base monad of a @t m b@, merging the resulting two @t@ layers into one.
class (MFunctor t, MonadTrans t) => MMonad t where
    -- | Replace the base monad @m@ with @t n@, then collapse the extra layer.
    embed
        :: Monad n
        => (forall a . m a -> t n a)  -- ^ transformation producing a @t@ layer
        -> t m b
        -> t n b
-- | Collapse two stacked copies of the same transformer into one, by
-- embedding with the identity transformation.
squash :: (Monad m, MMonad t) => t (t m) a -> t m a
squash nested = embed id nested
-- All four operators share precedence 2.  The ones that read left to
-- right ('>|>') and the 'embed' alias ('=<|') associate to the right;
-- their mirror images associate to the left.
infixr 2 >|>
infixr 2 =<|
infixl 2 <|<
infixl 2 |>=
-- | Left-to-right composition of layer-producing transformations: apply @f@,
-- then 'embed' @g@ into its result.
(>|>)
    :: (Monad m3, MMonad t)
    => (forall a . m1 a -> t m2 a)
    -> (forall b . m2 b -> t m3 b)
    ->             m1 c -> t m3 c
(>|>) f g = embed g . f
-- | Right-to-left composition of layer-producing transformations: the
-- mirror image of '>|>' (apply @f@ first, then 'embed' @g@).
(<|<)
    :: (Monad m3, MMonad t)
    => (forall b . m2 b -> t m3 b)
    -> (forall a . m1 a -> t m2 a)
    ->             m1 c -> t m3 c
(<|<) g f = embed g . f
-- | Operator form of 'embed', transformation on the left.
(=<|) :: (Monad n, MMonad t) => (forall a . m a -> t n a) -> t m b -> t n b
nat =<| tm = embed nat tm
-- | Operator form of 'embed' with the arguments flipped: the computation
-- comes first, the transformation second.
(|>=) :: (Monad n, MMonad t) => t m b -> (forall a . m a -> t n a) -> t n b
(|>=) tm nat = embed nat tm
-- Run @f@ on the inner base computation, giving a nested
-- @Either e (Either e b)@.  An outer 'Left' (from @f@'s layer) is returned
-- as-is; otherwise the inner 'Either' is the result.
instance (E.Error e) => MMonad (E.ErrorT e) where
    embed f m = E.ErrorT (E.runErrorT (f (E.runErrorT m)) >>= return . flatten)
      where
        flatten = either Left id
-- 'I.IdentityT' adds no structure, so unwrapping and applying @f@ already
-- yields a single layer.
instance MMonad I.IdentityT where
    embed f = f . I.runIdentityT
-- Running @f@ on the inner base computation yields a list of lists;
-- concatenate them into one list of results.
instance MMonad L.ListT where
    embed f m = L.ListT (L.runListT (f (L.runListT m)) >>= return . concat)
-- Running @f@ on the inner base computation yields a nested
-- @Maybe (Maybe b)@.  An outer 'Nothing' stays 'Nothing'; otherwise the
-- inner 'Maybe' is the result.
instance MMonad M.MaybeT where
    embed f m = M.MaybeT (M.runMaybeT (f (M.runMaybeT m)) >>= return . flatten)
      where
        flatten = maybe Nothing id
-- The same environment is passed to both layers: first to the original
-- reader, then to the reader produced by @f@.
instance MMonad (R.ReaderT r) where
    embed f m = R.ReaderT $ \env ->
        let inner = R.runReaderT m env
        in  R.runReaderT (f inner) env
-- 'W' is imported from "Control.Monad.Trans.Writer.Strict", so this is the
-- /strict/ writer.  The nested pairs must therefore be matched strictly (no
-- irrefutable @~@ pattern), mirroring the strict writer's own bind.  An
-- irrefutable pattern here gives lazy-writer semantics, which belongs to the
-- instance for "Control.Monad.Trans.Writer.Lazy" instead.
--
-- @f@ is run on the inner writer's base computation, giving
-- @((a, w1), w2)@.  Here @w1@ is the inner writer's log and @w2@ is the log
-- added by @f@'s layer.  They are combined in that order.
instance (Monoid w) => MMonad (W.WriterT w) where
    embed f m = W.WriterT (do
        ((a, w1), w2) <- W.runWriterT (f (W.runWriterT m))
        return (a, mappend w1 w2) )
-- 'W'' is imported from "Control.Monad.Trans.Writer.Lazy", so this is the
-- /lazy/ writer.  The nested pairs are matched with an irrefutable @~@
-- pattern, mirroring the lazy writer's own bind.  A strict match here would
-- force the (result, log) pairs that the lazy writer is meant to leave
-- unevaluated until demanded.
--
-- @f@ is run on the inner writer's base computation, giving
-- @((a, w1), w2)@.  Here @w1@ is the inner writer's log and @w2@ is the log
-- added by @f@'s layer.  They are combined in that order.
instance (Monoid w) => MMonad (W'.WriterT w) where
    embed f m = W'.WriterT (do
        ~((a, w1), w2) <- W'.runWriterT (f (W'.runWriterT m))
        return (a, mappend w1 w2) )