{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE UndecidableInstances #-}
#ifndef MIN_VERSION_mtl
#define MIN_VERSION_mtl(x,y,z) 1
#endif
module Control.Monad.Trans.Free.Church
(
FT(..)
, F, free, runF
, toFT, fromFT
, iterT
, iterTM
, hoistFT
, transFT
, cutoff
, improve
, fromF, toF
, retract
, iter
, iterM
, MonadFree(..)
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Identity
import Control.Monad.Trans.Class
import Control.Monad.IO.Class
import Control.Monad.Reader.Class
import Control.Monad.Writer.Class
import Control.Monad.State.Class
import Control.Monad.Error.Class
import Control.Monad.Cont.Class
import Control.Monad.Free.Class
import Control.Monad.Trans.Free (FreeT(..), FreeF(..), Free)
import qualified Control.Monad.Trans.Free as FreeT
import Data.Foldable (Foldable)
import qualified Data.Foldable as F
import Data.Traversable (Traversable)
import qualified Data.Traversable as T
import Data.Monoid
import Data.Functor.Bind hiding (join)
import Data.Function
-- | Church-encoded free monad transformer.
--
-- A value is a function that, for any result type @r@, takes a
-- continuation for final values (@a -> m r@) and a continuation that
-- collapses one layer of the functor @f@ (@f (m r) -> m r@), and yields
-- an @m r@.
newtype FT f m a = FT
  { runFT :: forall r. (a -> m r) -> (f (m r) -> m r) -> m r
  }
-- | Equality is decided on the 'FreeT' representations obtained with
-- 'fromFT'.
instance (Functor f, Monad m, Eq (FreeT f m a)) => Eq (FT f m a) where
  lhs == rhs = fromFT lhs == fromFT rhs
-- | Ordering is decided on the 'FreeT' representations obtained with
-- 'fromFT'.
instance (Functor f, Monad m, Ord (FreeT f m a)) => Ord (FT f m a) where
  compare lhs rhs = compare (fromFT lhs) (fromFT rhs)
-- | Mapping pre-composes the function onto the value continuation; the
-- layer continuation is passed through untouched.
instance Functor (FT f m) where
  fmap g m = FT (\kp kf -> runFT m (kp . g) kf)
-- | '<.>' is the same operation as '<*>' from the 'Applicative' instance.
instance Apply (FT f m) where
  ff <.> fa = ff <*> fa
-- | Applicative structure for 'FT'.
--
-- '<*>' runs the function computation first and the argument computation
-- second, which is the order 'ap' produces from the 'Monad' instance.
instance Applicative (FT f m) where
  -- Hand the value straight to the value continuation; the layer
  -- continuation is never used.
  pure a = FT $ \k _ -> k a
  -- fk has to be entered before ak so that the effects and functor layers
  -- of the left operand come first; the function e it produces is then
  -- applied to the value d produced by the right operand.
  FT fk <*> FT ak = FT $ \b fr -> fk (\e -> ak (\d -> b (e d)) fr) fr
-- | '>>-' is the same operation as '>>=' from the 'Monad' instance.
instance Bind (FT f m) where
  m >>- k = m >>= k
-- | Bind runs the first computation with a value continuation that feeds
-- each result to the next step and then runs that step with the original
-- two continuations.
instance Monad (FT f m) where
  return = pure
  m >>= next = FT (\kp kf -> runFT m (\a -> runFT (next a) kp kf) kf)
-- | 'wrap' runs every sub-computation inside the layer with the current
-- continuations and hands the resulting layer to the layer continuation.
instance (Functor f) => MonadFree f (FT f m) where
  wrap layer = FT $ \kp kf -> kf (fmap (\sub -> runFT sub kp kf) layer)
-- | Lifting runs the base action and passes its result to the value
-- continuation; the layer continuation is ignored.
instance MonadTrans (FT f) where
  lift act = FT (\kp _ -> act >>= kp)
-- | Alternatives are taken in the base functor: both sides are run with
-- the same continuations and combined with the base '<|>'.
instance Alternative m => Alternative (FT f m) where
  empty = FT (\_ _ -> empty)
  lhs <|> rhs = FT (\kp kf -> runFT lhs kp kf <|> runFT rhs kp kf)
-- | 'mzero' and 'mplus' are taken from the base monad: both sides are run
-- with the same continuations and combined with the base 'mplus'.
instance MonadPlus m => MonadPlus (FT f m) where
  mzero = FT (\_ _ -> mzero)
  mplus lhs rhs = FT (\kp kf -> mplus (runFT lhs kp kf) (runFT rhs kp kf))
-- | Folding runs the computation at a monoidal result type: each final
-- value is mapped into the monoid, each functor layer is collapsed by
-- combining its sub-results with 'mappend' inside @m@, and the resulting
-- @m@ structure is folded.
instance (Foldable f, Foldable m, Monad m) => Foldable (FT f m) where
  foldMap g (FT k) =
    F.fold
      ( k
          (\a -> return (g a))
          (\layer -> F.foldr (liftM2 mappend) (return mempty) layer)
      )
-- | Traversal runs the computation at result type @g (FT f m b)@ (with @g@
-- the applicative of the traversal), sequences the outer @m@ through @g@,
-- and flattens the remaining @m (FT f m b)@ with @join . lift@.
instance (Monad m, Traversable m, Traversable f) => Traversable (FT f m) where
  traverse g (FT k) = (join . lift) <$> T.sequenceA (k leaf layer)
    where
      -- A final value: apply the effectful function and wrap its result
      -- as a finished computation.
      leaf = return . fmap return . g
      -- A functor layer: pull the applicative out through the inner @m@
      -- and then through @f@, and rebuild the layer with 'wrap'.
      layer = return . fmap (wrap . fmap (join . lift)) . T.sequenceA . fmap T.sequenceA
-- | 'IO' actions are lifted into the base monad and then into 'FT'.
instance (MonadIO m) => MonadIO (FT f m) where
  liftIO io = lift (liftIO io)
  {-# INLINE liftIO #-}
-- | Errors are thrown in the base monad. Catching goes through the 'FreeT'
-- representation: the body and the handler results are converted with
-- 'fromFT', 'catchError' is applied there, and the result is converted
-- back with 'toFT'.
instance (Functor f, MonadError e m) => MonadError e (FT f m) where
  throwError err = lift (throwError err)
  {-# INLINE throwError #-}
  catchError body handler =
    toFT (catchError (fromFT body) (\err -> fromFT (handler err)))
-- | 'callCC' is performed in the base monad. The escape function handed to
-- the callback wraps its argument with 'return', passes it to the
-- base-level continuation and lifts the result; the computation returned
-- from the base 'callCC' is then flattened with @join . lift@.
instance (MonadCont m) => MonadCont (FT f m) where
  callCC f = join . lift . callCC $ \exit ->
    return (f (lift . exit . return))
-- | The environment is read in the base monad; 'local' is applied to the
-- base monad throughout the computation by means of 'hoistFT'.
instance (Functor f, MonadReader r m) => MonadReader r (FT f m) where
  ask = lift ask
  {-# INLINE ask #-}
  local modifyEnv action = hoistFT (local modifyEnv) action
  {-# INLINE local #-}
-- | Output is written in the base monad. 'listen' and 'pass' go through
-- the 'FreeT' representation: convert with 'fromFT', apply the operation
-- there, and convert back with 'toFT'.
instance (Functor f, MonadWriter w m) => MonadWriter w (FT f m) where
  tell out = lift (tell out)
  {-# INLINE tell #-}
  listen action = toFT (listen (fromFT action))
  pass action = toFT (pass (fromFT action))
#if MIN_VERSION_mtl(2,1,1)
  writer = lift . writer
  {-# INLINE writer #-}
#endif
-- | Every state operation is performed in the base monad and lifted.
instance (Functor f, MonadState s m) => MonadState s (FT f m) where
  get = lift get
  {-# INLINE get #-}
  put s = lift (put s)
  {-# INLINE put #-}
#if MIN_VERSION_mtl(2,1,1)
  state = lift . state
  {-# INLINE state #-}
#endif
-- | Convert a 'FreeT' computation into its Church-encoded 'FT' form.
--
-- Each step of the 'FreeT' is run in @m@. A 'Pure' result goes to the
-- value continuation; a 'Free' layer has its sub-computations converted
-- recursively and run with the same continuations before the layer is
-- given to the layer continuation.
toFT :: (Monad m, Functor f) => FreeT f m a -> FT f m a
toFT (FreeT step) = FT $ \ka kfr -> do
  freef <- step
  case freef of
    Pure a -> ka a
    -- 'runFT' is applied to all of its arguments here instead of being
    -- composed point-free: its result type is rank-2, and composing it
    -- with '.' is rejected by GHC 9.0 and later (simplified subsumption).
    Free fb -> kfr (fmap (\sub -> runFT (toFT sub) ka kfr) fb)
-- | Convert a Church-encoded 'FT' computation into a 'FreeT'.
--
-- The computation is run with 'Pure' as the handler for final values and
-- with 'wrap' (at the 'FreeT' type) as the handler for functor layers.
fromFT :: (Monad m, Functor f) => FT f m a -> FreeT f m a
fromFT (FT k) =
  FreeT $
    k
      (\a -> return (Pure a))
      (\layer -> runFreeT (wrap (fmap FreeT layer)))
-- | 'FT' specialised to the 'Identity' base monad. The synonym takes only
-- the functor argument, so @F f@ can itself be used as a type constructor.
type F f = FT f Identity
-- | Run an 'F' computation with pure continuations: one for final values
-- and one that collapses a functor layer of already-computed results.
-- The 'Identity' wrappers used by the underlying 'FT' are added and
-- removed here.
runF :: Functor f => F f a -> (forall r. (a -> r) -> (f r -> r) -> r)
runF (FT m) kp kf =
  runIdentity (m (Identity . kp) (Identity . kf . fmap runIdentity))
-- | Build an 'F' computation from a function over pure continuations.
-- This is the counterpart of 'runF': the 'Identity' wrappers required by
-- the underlying 'FT' are added and removed here.
free :: Functor f => (forall r. (a -> r) -> (f r -> r) -> r) -> F f a
free k = FT $ \kp kf ->
  Identity (k (runIdentity . kp) (runIdentity . kf . fmap Identity))
-- | Tear down an 'FT' computation into the base monad: final values are
-- returned with 'return' and each functor layer is collapsed by the
-- supplied function.
iterT :: (Functor f, Monad m) => (f (m a) -> m a) -> FT f m a -> m a
iterT collapse m = runFT m return collapse
{-# INLINE iterT #-}
-- | Tear down an 'FT' computation into a monad transformer stacked on the
-- base monad. The computation is run in @m@ at result type @t m a@; each
-- layer has its sub-results flattened with @join . lift@ before being
-- collapsed by the supplied function, and the final @m (t m a)@ is
-- flattened the same way.
iterTM :: (Functor f, Monad m, MonadTrans t, Monad (t m)) => (f (t m a) -> t m a) -> FT f m a -> t m a
iterTM collapse (FT m) =
  join . lift $
    m
      (\a -> return (return a))
      (\layer -> return (collapse (fmap (join . lift) layer)))
-- | Change the base monad of an 'FT' computation using a transformation
-- from @m@ to @n@. The computation is run in @m@ at result type @n r@;
-- the transformation is applied to the sub-results of every layer and to
-- the overall result, each followed by 'join'.
hoistFT :: (Monad m, Monad n, Functor f) => (forall a. m a -> n a) -> FT f m b -> FT f n b
hoistFT phi (FT m) = FT $ \kp kf ->
  join . phi $
    m
      (\a -> return (kp a))
      (\layer -> return (kf (fmap (join . phi) layer)))
-- | Change the functor of an 'FT' computation using a transformation from
-- @f@ to @g@: every layer is translated before it reaches the layer
-- continuation.
transFT :: (Monad m, Functor g) => (forall a. f a -> g a) -> FT f m b -> FT g m b
transFT phi m = FT $ \kp kf -> runFT m kp (\layer -> kf (phi layer))
-- | Apply @cutoff@ from "Control.Monad.Trans.Free" to an 'FT' computation
-- by converting to 'FreeT', cutting off there, and converting back. The
-- truncation semantics are those of the 'FreeT' function.
cutoff :: (Functor f, Monad m) => Integer -> FT f m a -> FT f m (Maybe a)
cutoff n m = toFT (FreeT.cutoff n (fromFT m))
-- | Interpret an 'F' computation in its own functor when that functor is
-- a monad: final values are injected with 'return' and layers are
-- flattened with 'join'.
retract :: (Functor f, Monad f) => F f a -> f a
retract program = runF program return join
{-# INLINE retract #-}
-- | Tear down an 'F' computation to a plain value, collapsing each layer
-- with the supplied function. Implemented with 'iterT' over 'Identity'.
iter :: Functor f => (f a -> a) -> F f a -> a
iter phi program =
  runIdentity (iterT (\layer -> Identity (phi (fmap runIdentity layer))) program)
{-# INLINE iter #-}
-- | Tear down an 'F' computation into a monad @m@: the 'Identity' base is
-- first replaced by @m@ with 'hoistFT', then 'iterT' collapses the layers
-- with the supplied function.
iterM :: (Functor f, Monad m) => (f (m a) -> m a) -> F f a -> m a
iterM phi program = iterT phi (hoistFT (return . runIdentity) program)
-- | Convert an 'F' computation into any 'MonadFree' instance for the same
-- functor: final values use 'return' and layers use 'wrap'.
fromF :: (Functor f, MonadFree f m) => F f a -> m a
fromF program = runF program return wrap
{-# INLINE fromF #-}
-- | Convert a 'Free' computation into its Church-encoded 'F' form. This is
-- 'toFT' used at the 'Free' type.
toF :: (Functor f) => Free f a -> F f a
toF program = toFT program
{-# INLINE toF #-}
-- | Turn a computation written against any 'MonadFree' instance into a
-- 'Free' value. The argument is passed to 'fromF', so it is used at the
-- Church-encoded type 'F' and the result is converted to 'Free'.
improve :: Functor f => (forall m. MonadFree f m => m a) -> Free f a
improve program = fromF program
{-# INLINE improve #-}