{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
#ifndef MIN_VERSION_base
#define MIN_VERSION_base(x,y,z) 1
#endif
module Control.Monad.Free.Church
  ( F(..)
  , improve
  , fromF
  , iterM
  , toF
  , retract
  , hoistF
  , MonadFree(..)
  , liftF
  ) where
import Control.Applicative
import Control.Monad as Monad
import Control.Monad.Fix
import Control.Monad.Free hiding (retract, iterM)
import Control.Monad.Reader.Class
import Control.Monad.Writer.Class
import Control.Monad.Cont.Class
import Control.Monad.Trans.Class
import Control.Monad.State.Class
import Data.Foldable
import Data.Functor.Bind
import Prelude hiding (foldr)
-- | A computation over a signature @f@ in continuation-passing form.
--
-- A value of @F f a@ is a function that, for any result type @r@, is given
--
--   * a handler for final values, @a -> r@, and
--
--   * a handler for one layer of @f@ whose holes are already results,
--     @f r -> r@,
--
-- and produces an @r@. 'runF' unwraps the newtype to expose that function.
newtype F f a = F { runF :: forall r. (a -> r) -> (f r -> r) -> r }
-- | Interpret an @'F' f a@ in a monad @m@.
--
-- Final values are injected with 'return'; every layer of @f@ is collapsed
-- by the supplied algebra @phi@, which receives a layer whose holes already
-- hold @m a@ computations.
--
-- No 'Functor' instance is needed for @f@: the body only applies 'runF' to
-- the two handlers and never maps over a layer itself.
iterM :: Monad m => (f (m a) -> m a) -> F f a -> m a
iterM phi xs = runF xs return phi
-- | Mapping pre-composes the function with the final-value handler; the
-- layer handler is passed along unchanged.
instance Functor (F f) where
  fmap h m = F (\done -> runF m (done . h))
-- | '<.>' is defined to be the 'Applicative' '<*>' of 'F'.
instance Apply (F f) where
  mf <.> mx = mf <*> mx
-- | 'pure' feeds its argument straight to the final-value handler and
-- ignores the layer handler. '<*>' runs the left computation first; for each
-- function it yields, the right computation is run with that function
-- composed onto the final-value handler.
instance Applicative (F f) where
  pure x = F (\done _ -> done x)
  mf <*> mx =
    F (\done step -> runF mf (\h -> runF mx (done . h) step) step)
-- | Choice is delegated to the underlying @f@: 'empty' hands @f@'s 'empty'
-- to the layer handler, and '<|>' runs both sides with the same handlers,
-- wraps each result with 'pure' in @f@, combines them with @f@'s '<|>', and
-- passes the combined layer to the layer handler.
--
-- NOTE(review): whether the 'Alternative' laws hold here depends on @f@ and
-- on the layer handler supplied by the consumer; this has not been verified.
instance Alternative f => Alternative (F f) where
  empty = F (\_ step -> step empty)
  l <|> r =
    F (\done step ->
         step (pure (runF l done step) <|> pure (runF r done step)))
-- | '>>-' is defined to be the 'Monad' '>>=' of 'F'.
instance Bind (F f) where
  m >>- k = m >>= k
-- | 'return' coincides with 'pure'. Bind runs the first computation with a
-- final-value handler that, for each produced value, runs the continuation's
-- computation against the caller's original handlers.
instance Monad (F f) where
  return = pure
  m >>= k =
    F (\done step -> runF m (\x -> runF (k x) done step) step)
-- | The fixed point is tied lazily: the function is applied to the value
-- extracted from its own result. Extraction runs the computation with 'id'
-- as the final-value handler and an 'error' thunk as the layer handler, so
-- forcing the fed-back value of a computation that goes through a layer of
-- @f@ forces that error thunk.
instance MonadFix (F f) where
  mfix f =
    fix (\self -> f (runF self id (error "MonadFix (F f): wrap")))
-- | Fold over the final values of an @'F' f a@, using the 'Foldable'
-- instance of @f@ to combine the children of each layer.
--
-- Only 'Foldable' is required of @f@: both methods run the computation at
-- result type @b -> b@ and never map over a layer, so no 'Functor'
-- constraint is needed.
instance Foldable f => Foldable (F f) where
    -- Each final value @a@ becomes the accumulator transformer @f a@; a layer
    -- is collapsed by composing its children's transformers right-to-left,
    -- and the composed transformer is finally applied to the seed @r@.
    foldr f r xs = runF xs f (foldr (.) id) r
    {-# INLINE foldr #-}
#if MIN_VERSION_base(4,6,0)
    -- Left fold: final values become @\\acc -> f acc a@, and within a layer the
    -- children are chained first-to-last by '!>>>', which forces each
    -- intermediate accumulator (via '$!') before handing it to the next child.
    foldl' f z xs = runF xs (flip f) (foldr (!>>>) id) z
      where (!>>>) h g = \r -> g $! h r
    {-# INLINE foldl' #-}
#endif
-- | Choice is delegated to the underlying monad @f@: 'mzero' hands @f@'s
-- 'mzero' to the layer handler, and 'mplus' runs both sides with the same
-- handlers, injects each result with 'return' in @f@, combines them with
-- @f@'s 'mplus', and passes the combined layer to the layer handler.
--
-- NOTE(review): whether the 'MonadPlus' laws hold here depends on @f@ and on
-- the layer handler supplied by the consumer; this has not been verified.
instance MonadPlus f => MonadPlus (F f) where
  mzero = F (\_ step -> step mzero)
  mplus l r =
    F (\done step ->
         step (mplus (return (runF l done step)) (return (runF r done step))))
-- | Lifting embeds a single monadic action as one layer: the final-value
-- handler is mapped over the action's result with 'liftM', and the resulting
-- action is given to the layer handler.
instance MonadTrans F where
  lift act = F (\done step -> step (done `liftM` act))
-- | 'wrap' adds one layer of @f@: every sub-computation inside the layer is
-- run with the caller's handlers (using @f@'s 'fmap'), and the layer of
-- results is then passed to the layer handler.
instance Functor f => MonadFree f (F f) where
  wrap layer =
    F (\done step -> step (fmap (\sub -> runF sub done step) layer))
-- | State operations are those of the underlying monad @m@, embedded with
-- 'lift'.
instance MonadState s m => MonadState s (F m) where
  get = lift get
  put s = lift (put s)
-- | Reader operations are those of the underlying monad @m@. 'ask' is
-- lifted directly; 'local' first collapses the computation into @m@ with
-- 'retract', runs it under @m@'s 'local', and lifts the result back.
instance MonadReader e m => MonadReader e (F m) where
  ask = lift ask
  local f m = lift (local f (retract m))
-- | Writer operations are those of the underlying monad @m@. 'tell' is
-- lifted directly; 'listen' and 'pass' first collapse the computation into
-- @m@ with 'retract', apply @m@'s operation, and lift the result back.
instance MonadWriter w m => MonadWriter w (F m) where
  tell w = lift (tell w)
  pass m = lift (pass (retract m))
  listen m = lift (listen (retract m))
-- | 'callCC' is performed in the underlying monad @m@: the escape
-- continuation obtained from @m@ is lifted into @F m@ before being given to
-- the user function, whose result is collapsed back into @m@ with 'retract';
-- the whole @m@ action is then lifted.
instance MonadCont m => MonadCont (F m) where
  callCC f = lift (callCC (\k -> retract (f (\a -> lift (k a)))))
-- | Collapse an @'F' m a@ over a monad @m@ into a single @m a@: final values
-- are injected with 'return' and each layer is flattened with 'Monad.join'.
retract :: Monad m => F m a -> m a
retract m = runF m return Monad.join
{-# INLINE retract #-}
-- | Change the signature of a computation from @f@ to @g@ using a
-- transformation that works at every element type. The computation is run
-- with the same final-value handler, and each @f@ layer is translated to a
-- @g@ layer before being passed to the caller's layer handler.
hoistF :: (forall x. f x -> g x) -> F f a -> F g a
hoistF nat m = F (\done step -> runF m done (\layer -> step (nat layer)))
-- | Run an @'F' f a@ in any 'MonadFree' instance for @f@: final values are
-- injected with 'return' and each layer is rebuilt with 'wrap'.
fromF :: MonadFree f m => F f a -> m a
fromF m = runF m return wrap
{-# INLINE fromF #-}
-- | Convert a 'Free' tree into its 'F' representation. Given the two
-- handlers, the tree is walked recursively: a 'Pure' leaf is passed to the
-- final-value handler, and a 'Free' node has its children walked (via
-- 'fmap') before the layer of results is passed to the layer handler.
toF :: Functor f => Free f a -> F f a
toF tree =
  F (\done step ->
       let walk node = case node of
             Pure a     -> done a
             Free layer -> step (fmap walk layer)
       in walk tree)
-- | Take a computation written against the 'MonadFree' interface alone,
-- instantiate it at 'F', and then rebuild it as a 'Free' value by running it
-- with 'return' for final values and 'wrap' for layers.
improve :: Functor f => (forall m. MonadFree f m => m a) -> Free f a
improve prog = runF prog return wrap
{-# INLINE improve #-}