{-# LANGUAGE CPP #-}
module Network.Simple.TCP (
connect
, serve
, listen
, accept
, acceptFork
, recv
, send
, bindSock
, connectSock
, closeSock
, NS.withSocketsDo
, HostPreference(..)
, NS.HostName
, NS.ServiceName
, NS.Socket
, NS.SockAddr
) where
import Control.Concurrent (ThreadId, forkIO)
import qualified Control.Exception as E
import qualified Control.Monad.Catch as C
import Control.Monad
import Control.Monad.IO.Class (MonadIO(liftIO))
import qualified Data.ByteString as BS
import Data.List (partition)
import qualified Network.Socket as NS
import Network.Simple.Internal
import qualified Network.Socket.ByteString as NSB
-- | Connect to a TCP server and pass the connected socket and the server's
-- address to the given action.
--
-- The socket is obtained with 'connectSock' and is always closed afterwards
-- with 'silentCloseSock', whether the action returns normally or throws.
connect
  :: (MonadIO m, C.MonadMask m)
  => NS.HostName                          -- ^ Server host name or address.
  -> NS.ServiceName                       -- ^ Server service port.
  -> ((NS.Socket, NS.SockAddr) -> m r)    -- ^ Action using the connection.
  -> m r
connect host port = C.bracket (connectSock host port) release
  where
    -- Only the socket half of the pair needs releasing.
    release (sock, _) = silentCloseSock sock
-- | Start a TCP server that accepts connections forever, handling each one
-- in its own thread via 'acceptFork'.
--
-- The listening socket is created with 'listen', so it is closed when the
-- accept loop ends. Because the loop is 'forever', it only ends through an
-- exception.
serve
  :: MonadIO m
  => HostPreference                       -- ^ Preferred host to bind.
  -> NS.ServiceName                       -- ^ Service port to bind.
  -> ((NS.Socket, NS.SockAddr) -> IO ())  -- ^ Per-connection handler.
  -> m ()
serve hp port k = liftIO . listen hp port $ \(lsock, _) ->
  forever (acceptFork lsock k)
-- | Bind a socket with 'bindSock', put it in listening mode, and pass it to
-- the given action.
--
-- The listen backlog is at least 2048, or 'NS.maxListenQueue' if that is
-- larger. The socket is closed with 'silentCloseSock' when the action
-- finishes or throws.
listen
  :: (MonadIO m, C.MonadMask m)
  => HostPreference                       -- ^ Preferred host to bind.
  -> NS.ServiceName                       -- ^ Service port to bind.
  -> ((NS.Socket, NS.SockAddr) -> m r)    -- ^ Action using the listening
                                          --   socket and its bound address.
  -> m r
listen hp port = C.bracket bindAndListen (silentCloseSock . fst)
  where
    backlog = max 2048 NS.maxListenQueue
    -- Bind first, switch to listening mode, then hand back the bound pair.
    bindAndListen = bindSock hp port >>= \bound@(bsock, _) ->
      bound <$ liftIO (NS.listen bsock backlog)
-- | Accept a single incoming connection on a listening socket and pass the
-- connection socket and the remote address to the given action.
--
-- The connection socket is closed with 'silentCloseSock' once the action
-- finishes, whether normally or by an exception. 'C.bracket' runs the accept
-- with asynchronous exceptions masked. This means an async exception cannot
-- arrive after 'NS.accept' returns but before the cleanup is registered, so
-- the connection socket cannot leak.
accept
  :: (MonadIO m, C.MonadMask m)
  => NS.Socket                            -- ^ Listening and bound socket.
  -> ((NS.Socket, NS.SockAddr) -> m r)    -- ^ Action using the connection.
  -> m r
accept lsock k = C.bracket (liftIO (NS.accept lsock)) (silentCloseSock . fst) k
-- | Accept a single incoming connection on a listening socket and handle it
-- in a newly forked thread.
--
-- When the handler finishes, whether normally or by an exception, the
-- connection socket is closed with 'silentCloseSock'. An exception raised by
-- the handler is then re-thrown inside the forked thread.
--
-- The accept-then-fork sequence runs with asynchronous exceptions masked.
-- As a result, an async exception cannot arrive between 'NS.accept'
-- returning and the forked thread taking ownership of the socket. If forking
-- itself fails, the socket is closed here. The handler runs with the
-- caller's original masking state.
acceptFork
  :: MonadIO m
  => NS.Socket                            -- ^ Listening and bound socket.
  -> ((NS.Socket, NS.SockAddr) -> IO ())  -- ^ Handler for the connection
                                          --   socket and remote address.
  -> m ThreadId
acceptFork lsock k = liftIO $ E.mask $ \restore -> do
  -- Deliberately left masked. A blocking accept should still be
  -- interruptible (per GHC's interruptible-operation rules; confirm for the
  -- platform in use), and staying masked prevents an async exception from
  -- landing right after it returns.
  conn@(csock, _) <- NS.accept lsock
  let finish ea = do
        silentCloseSock csock
        either E.throwIO return ea
  -- 'restore' gives the handler the caller's masking state inside the child
  -- thread. If the fork itself throws, the child never owns the socket, so
  -- close it here.
  E.onException (forkFinally (restore (k conn)) finish)
                (silentCloseSock csock)
-- | Resolve a host and service and return a socket connected to the server,
-- together with the address it connected to.
--
-- Every address returned by 'NS.getAddrInfo' is tried in order, just as
-- 'bindSock' does for binding. The first successful connection wins. If
-- every attempt fails with an 'IOError', the last address's error is
-- re-thrown. An empty address list raises an 'IOError' as well.
--
-- Each candidate socket is closed if its connect attempt fails. The caller
-- is responsible for closing the returned socket, e.g. with 'closeSock'.
connectSock :: MonadIO m
            => NS.HostName -> NS.ServiceName -> m (NS.Socket, NS.SockAddr)
connectSock host port = liftIO $ do
  addrs <- NS.getAddrInfo (Just hints) (Just host) (Just port)
  tryAddrs addrs
  where
    hints = NS.defaultHints { NS.addrFlags = [NS.AI_ADDRCONFIG]
                            , NS.addrSocketType = NS.Stream }
    -- No candidate left. The original partial pattern also failed with an
    -- IOError here.
    tryAddrs [] = ioError (userError "connectSock: no addresses available")
    -- The last candidate's failure propagates unchanged.
    tryAddrs [x] = useAddr x
    tryAddrs (x:xs) = E.catch (useAddr x)
                              (\e -> let _ = e :: IOError in tryAddrs xs)
    useAddr addr = E.bracketOnError (newSocket addr) closeSock $ \sock -> do
      let sockAddr = NS.addrAddress addr
      NS.connect sock sockAddr
      return (sock, sockAddr)
-- | Resolve a passive (server-side) address for the given host preference
-- and service, and return a socket bound to the first usable candidate.
--
-- For 'HostIPv4' and 'HostIPv6', candidates of that family are tried first,
-- while the order is otherwise preserved. Each candidate socket gets
-- 'NS.NoDelay' and 'NS.ReuseAddr' set before binding. A candidate whose
-- setup or bind fails with an 'IOError' is closed and the next one is tried.
-- The last candidate's error propagates. The caller must close the returned
-- socket.
bindSock :: MonadIO m
         => HostPreference -> NS.ServiceName -> m (NS.Socket, NS.SockAddr)
bindSock hp port = liftIO $ do
  candidates <- NS.getAddrInfo (Just hints) (hpHostName hp) (Just port)
  firstBindable (ordered candidates)
  where
    hints = NS.defaultHints { NS.addrFlags = [NS.AI_PASSIVE]
                            , NS.addrSocketType = NS.Stream }
    ordered = case hp of
      HostIPv4 -> prioritize isIPv4addr
      HostIPv6 -> prioritize isIPv6addr
      _        -> id
    -- NOTE(review): presumably unreachable, since getAddrInfo is documented
    -- to throw rather than return an empty list. Confirm for the network
    -- version in use.
    firstBindable [] = error "bindSock: no addresses available"
    firstBindable [addr] = bindTo addr
    firstBindable (addr:rest) = bindTo addr `E.catch` retryWith rest
    retryWith :: [NS.AddrInfo] -> IOError -> IO (NS.Socket, NS.SockAddr)
    retryWith rest _ = firstBindable rest
    bindTo addr = E.bracketOnError (newSocket addr) closeSock $ \sock -> do
      let sockAddr = NS.addrAddress addr
      NS.setSocketOption sock NS.NoDelay 1
      NS.setSocketOption sock NS.ReuseAddr 1
      NS.bindSocket sock sockAddr
      return (sock, sockAddr)
-- | Close a socket in any 'MonadIO' monad.
--
-- It uses 'NS.close' on network >= 2.4.0 and the older 'NS.sClose'
-- otherwise. Exceptions thrown by the underlying close are not caught here;
-- see 'silentCloseSock' for that.
closeSock :: MonadIO m => NS.Socket -> m ()
#if MIN_VERSION_network(2,4,0)
closeSock sock = liftIO (NS.close sock)
#else
closeSock sock = liftIO (NS.sClose sock)
#endif
{-# INLINE closeSock #-}
-- | Receive up to the given number of bytes from a connected socket.
--
-- Returns 'Nothing' when the underlying receive yields an empty chunk (the
-- peer closed its side). Otherwise returns the received bytes in 'Just'.
recv :: MonadIO m => NS.Socket -> Int -> m (Maybe BS.ByteString)
recv sock nbytes = liftIO $ do
  chunk <- NSB.recv sock nbytes
  return $! wrap chunk
  where
    wrap chunk
      | BS.null chunk = Nothing
      | otherwise     = Just chunk
-- | Write the whole 'BS.ByteString' to a connected socket using
-- 'NSB.sendAll'.
send :: MonadIO m => NS.Socket -> BS.ByteString -> m ()
send sock bs = liftIO $ NSB.sendAll sock bs
-- | Create a fresh, unconnected socket using the family, socket type, and
-- protocol recorded in the given 'NS.AddrInfo'.
newSocket :: NS.AddrInfo -> IO NS.Socket
newSocket addr = NS.socket family sockType protocol
  where
    family   = NS.addrFamily addr
    sockType = NS.addrSocketType addr
    protocol = NS.addrProtocol addr
-- | Whether an 'NS.AddrInfo' belongs to the IPv4 ('NS.AF_INET') or the
-- IPv6 ('NS.AF_INET6') address family, respectively.
isIPv4addr, isIPv6addr :: NS.AddrInfo -> Bool
isIPv4addr = (== NS.AF_INET) . NS.addrFamily
isIPv6addr = (== NS.AF_INET6) . NS.addrFamily
-- | Stable reordering that moves every element satisfying the predicate to
-- the front. Relative order is kept within both the matching and the
-- non-matching group.
prioritize :: (a -> Bool) -> [a] -> [a]
prioritize p xs = preferred ++ others
  where
    (preferred, others) = partition p xs
-- | Fork a thread that runs the given action and then passes its outcome
-- (the result, or the exception it threw) to the finalizer.
--
-- The fork happens with asynchronous exceptions masked, and only the action
-- itself runs with the caller's masking state restored. This way the
-- finalizer cannot be skipped by an async exception delivered just after
-- the thread starts.
forkFinally :: IO a -> (Either E.SomeException a -> IO ()) -> IO ThreadId
forkFinally action andThen =
  E.mask $ \unmask -> forkIO $ do
    outcome <- E.try (unmask action)
    andThen outcome
-- | Close a socket with 'closeSock', discarding any 'IOError' the close
-- raises. Exceptions of other types still propagate.
silentCloseSock :: MonadIO m => NS.Socket -> m ()
silentCloseSock sock = liftIO (closeSock sock `E.catch` ignoreIOError)
  where
    ignoreIOError :: IOError -> IO ()
    ignoreIOError _ = return ()