{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.Connection
(
Connection
, connectionID
, ConnectionParams(..)
, TLSSettings(..)
, ProxySettings(..)
, SockSettings
, LineTooLong(..)
, initConnectionContext
, ConnectionContext
, connectFromHandle
, connectTo
, connectionClose
, connectionGet
, connectionGetChunk
, connectionGetChunk'
, connectionGetLine
, connectionPut
, connectionSetSecure
, connectionIsSecure
) where
import Control.Applicative
import Control.Concurrent.MVar
import Control.Monad (join)
import qualified Control.Exception as E
import qualified System.IO.Error as E (mkIOError, eofErrorType)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLS
import System.X509 (getSystemCertificateStore)
import Network.Socks5
import qualified Network as N
import Data.Default.Class
import Data.Data
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as L
import qualified Crypto.Random.AESCtr as RNG
import System.Environment
import System.IO
import qualified Data.Map as M
import Network.Connection.Types
-- | Mutable store of TLS session data keyed by session ID, consulted and
-- updated by 'connectionSessionManager' for session resumption.
type Manager = MVar (M.Map TLS.SessionID TLS.SessionData)
-- | Exception thrown by 'connectionGetLine' when a line exceeds the
-- caller-supplied byte limit.
data LineTooLong = LineTooLong
    deriving (Show, Typeable)

instance E.Exception LineTooLong
-- | A TLS 'TLS.SessionManager' backed by an in-memory map held in an 'MVar'.
--
-- * 'TLS.sessionResume' looks the session ID up in the map.
-- * 'TLS.sessionEstablish' inserts (or replaces) the session data.
-- * 'TLS.sessionInvalidate' removes the session ID.
--
-- The updated map is forced with '$!' before being written back, so a long
-- run of establish/invalidate calls cannot pile up a chain of unevaluated
-- 'M.insert' / 'M.delete' thunks inside the MVar.
connectionSessionManager :: Manager -> TLS.SessionManager
connectionSessionManager mvar = TLS.SessionManager
    { TLS.sessionResume     = \sessionID ->
          withMVar mvar (return . M.lookup sessionID)
    , TLS.sessionEstablish  = \sessionID sessionData ->
          modifyMVar_ mvar (\m -> return $! M.insert sessionID sessionData m)
    , TLS.sessionInvalidate = \sessionID ->
          modifyMVar_ mvar (\m -> return $! M.delete sessionID m)
    }
-- | Create a 'ConnectionContext' that holds the system's certificate store,
-- as returned by 'getSystemCertificateStore'.
initConnectionContext :: IO ConnectionContext
initConnectionContext = do
    store <- getSystemCertificateStore
    return (ConnectionContext store)
-- | Build the TLS client parameters for a connection to @cid@ (host, port).
--
-- * 'TLSSettingsSimple': start from 'TLS.defaultParamsClient' for the
--   host and port, and enable 'TLS.ciphersuite_all'. Use the context's
--   certificate store as the CA store. When
--   'settingDisableCertificateValidation' is set, install a validation
--   cache whose query always answers 'TLS.ValidationCachePass' and whose
--   add action does nothing.
-- * 'TLSSettings': take the caller's parameters as-is and override only
--   'TLS.clientServerIdentification' with the connection's host and port.
--
-- NOTE(review): 'TLS.ciphersuite_all' may include weak suites; consider a
-- stricter default.
makeTLSParams :: ConnectionContext -> ConnectionID -> TLSSettings -> TLS.ClientParams
makeTLSParams cg cid settings =
    case settings of
        TLSSettingsSimple {} ->
            let acceptAll = TLS.ValidationCache (\_ _ _ -> return TLS.ValidationCachePass)
                                                (\_ _ _ -> return ())
                cache
                    | settingDisableCertificateValidation settings = acceptAll
                    | otherwise                                    = def
            in (TLS.defaultParamsClient host service)
                   { TLS.clientSupported = def { TLS.supportedCiphers = TLS.ciphersuite_all }
                   , TLS.clientShared    = def
                         { TLS.sharedCAStore         = globalCertificateStore cg
                         , TLS.sharedValidationCache = cache
                         }
                   }
        TLSSettings p ->
            p { TLS.clientServerIdentification = (host, service) }
  where
    host    = fst cid
    -- the port rendered as text, as both tls APIs expect a ByteString here
    service = BC.pack (show (snd cid))
-- | Read the connection's current backend and run @f@ on it.
--
-- The backend MVar is only read with 'readMVar'; it is not held while
-- @f@ runs.
withBackend :: (ConnectionBackend -> IO a) -> Connection -> IO a
withBackend f conn = f =<< readMVar (connectionBackend conn)
-- | Allocate a 'Connection' for @cid@ over @backend@, with an empty (but
-- open, i.e. 'Just') receive buffer.
connectionNew :: ConnectionID -> ConnectionBackend -> IO Connection
connectionNew cid backend = do
    backendVar <- newMVar backend
    bufferVar  <- newMVar (Just B.empty)
    return (Connection backendVar bufferVar cid)
-- | Wrap an already-open 'Handle' in a 'Connection'.
--
-- If 'connectionUseSecure' is 'Nothing', the handle is used as a plain
-- stream. Otherwise a TLS handshake is run on it first (see
-- 'tlsEstablish'), using parameters from 'makeTLSParams'. The connection
-- ID is the hostname and port taken from the 'ConnectionParams'.
connectFromHandle :: ConnectionContext
                  -> Handle
                  -> ConnectionParams
                  -> IO Connection
connectFromHandle cg h p =
    case connectionUseSecure p of
        Nothing ->
            connectionNew cid (ConnectionStream h)
        Just tlsSettings -> do
            ctx <- tlsEstablish h (makeTLSParams cg cid tlsSettings)
            connectionNew cid (ConnectionTLS ctx)
  where
    cid = (connectionHostname p, connectionPort p)
-- | Open a network connection described by the 'ConnectionParams'.
--
-- How the underlying handle is opened depends on 'connectionUseSocks':
--
-- * 'Nothing': direct connection to the target host and port.
-- * 'OtherProxy' @h p@: connect to @h:p@ and ignore the target host/port.
-- * 'SockSettingsSimple' @h p@: go through the SOCKS server at @h:p@.
-- * 'SockSettingsEnvironment' @v@: read the environment variable @v@
--   (default @SOCKS_SERVER@) as @host@ or @host:port@ (port defaults to
--   1080). If the variable is unset or cannot be parsed, connect directly.
--
-- The handle is then passed to 'connectFromHandle'. If that throws (for
-- example, a failed TLS handshake), the handle is closed before the
-- exception propagates instead of being leaked.
connectTo :: ConnectionContext
          -> ConnectionParams
          -> IO Connection
connectTo cg cParams = do
    conFct <- getConFct (connectionUseSocks cParams)
    E.bracketOnError
        (conFct (connectionHostname cParams) (N.PortNumber $ connectionPort cParams))
        hClose
        (\h -> connectFromHandle cg h cParams)
  where
    getConFct Nothing = return N.connectTo
    getConFct (Just (OtherProxy h p)) = return $ \_ _ -> N.connectTo h (N.PortNumber p)
    getConFct (Just (SockSettingsSimple h p)) = return $ socksConnectTo h (N.PortNumber p)
    getConFct (Just (SockSettingsEnvironment v)) = do
        let name = maybe "SOCKS_SERVER" id v
        evar <- E.try (getEnv name)
        case evar of
            Left (_ :: E.IOException) -> return N.connectTo
            Right var ->
                case parseSocks var of
                    Nothing -> return N.connectTo
                    Just (sHost, sPort) ->
                        return $ socksConnectTo sHost (N.PortNumber $ fromIntegral (sPort :: Int))

    -- "host" -> (host, 1080); "host:port" -> (host, port); anything else -> Nothing
    parseSocks s =
        case break (== ':') s of
            (sHost, "") -> Just (sHost, 1080)
            (sHost, ':':portS) ->
                case reads portS of
                    [(sPort, "")] -> Just (sHost, sPort)
                    _             -> Nothing
            _ -> Nothing
-- | Send @content@ on the connection.
--
-- For a plain stream the bytes are written with 'B.hPut' and the handle is
-- flushed. For TLS they are passed to 'TLS.sendData' as a single-chunk
-- lazy ByteString.
connectionPut :: Connection -> ByteString -> IO ()
connectionPut connection content = withBackend send connection
  where
    send backend = case backend of
        ConnectionStream h -> do
            B.hPut h content
            hFlush h
        ConnectionTLS ctx ->
            TLS.sendData ctx (L.fromChunks [content])
-- | Receive at most @size@ bytes.
--
-- A negative @size@ fails with a user error, and @0@ returns an empty
-- string without touching the connection. Otherwise the result comes from
-- the buffered data or from one fresh backend read, and may be shorter than
-- @size@. Any excess stays buffered. The first EOF yields an empty string;
-- later calls throw an EOF 'IOError'.
connectionGet :: Connection -> Int -> IO ByteString
connectionGet conn size =
    case compare size 0 of
        LT -> fail "Network.Connection.connectionGet: size < 0"
        EQ -> return B.empty
        GT -> connectionGetChunkBase "connectionGet" conn (B.splitAt size)
-- | Return everything that is currently buffered or, if the buffer is empty,
-- one fresh chunk from the backend, leaving nothing buffered.
connectionGetChunk :: Connection -> IO ByteString
connectionGetChunk conn =
    connectionGetChunkBase "connectionGetChunk" conn (flip (,) B.empty)
-- | Like 'connectionGetChunk', but lets the caller split the chunk:
-- @f@ returns the result and the bytes to keep buffered for the next read.
connectionGetChunk' :: Connection -> (ByteString -> (a, ByteString)) -> IO a
connectionGetChunk' conn f = connectionGetChunkBase "connectionGetChunk'" conn f
-- | Buffered read shared by the @connectionGet*@ functions.
--
-- The connection's buffer MVar is held (via 'modifyMVar') for the whole
-- call. If the buffer is non-empty, @f@ is applied to it. Otherwise one
-- chunk is read from the backend ('TLS.recvData', or 'B.hGetSome' with a
-- 16 KiB cap for plain streams) and @f@ is applied to that. @f@ returns the
-- result and the leftover bytes, which become the new buffer.
--
-- An empty backend read is treated as EOF. @f@ is applied to the empty
-- input, and the buffer becomes 'Nothing', so every later call throws an
-- EOF 'IOError' (via 'throwEOF', tagged with @loc@).
connectionGetChunkBase :: String -> Connection -> (ByteString -> (a, ByteString)) -> IO a
connectionGetChunkBase loc conn f =
    modifyMVar (connectionBuffer conn) step
  where
    step Nothing = throwEOF conn loc
    step (Just pending)
        | not (B.null pending) = keepRest pending
        | otherwise = do
            fresh <- withBackend receive conn
            if B.null fresh
                then markClosed fresh
                else keepRest fresh

    receive (ConnectionTLS tlsctx) = TLS.recvData tlsctx
    receive (ConnectionStream h)   = B.hGetSome h (16 * 1024)

    -- the leftover is forced before it is stored back into the MVar
    keepRest input =
        case f input of
            (result, !rest) -> return (Just rest, result)

    -- EOF: give f the empty input, drop its leftover, and mark the buffer closed
    markClosed input =
        case f input of
            (result, _) -> return (Nothing, result)
-- | Get the next line, using ASCII LF (byte 10) as the line terminator.
--
-- The LF is not included in the result. A preceding CR, if any, is kept.
--
-- EOF handling:
--
-- * If the stream ends after part of an unterminated line has been read,
--   that partial line is returned.
-- * If it ends before any byte of the line has been read, an EOF 'IOError'
--   is thrown (see 'throwEOF'). An empty result therefore always means an
--   empty line, never end of input.
--
-- Throws 'LineTooLong' when the line (excluding the LF) would exceed
-- @limit@ bytes. This applies to the final, LF-terminated fragment as well.
connectionGetLine :: Int           -- ^ maximum line length in bytes
                  -> Connection
                  -> IO ByteString
connectionGetLine limit conn = more (throwEOF conn loc) 0 id
  where
    loc = "connectionGetLine"
    lineTooLong = E.throwIO LineTooLong

    -- Accumulate line fragments in a difference list. @eofK@ is the EOF
    -- continuation: it throws while nothing has been read, and is replaced
    -- by "return what we have" once a (necessarily non-empty) fragment has
    -- been accepted.
    more eofK !currentSz !dl =
        getChunk (\s -> let newSz = currentSz + B.length s
                            newDl = dl . (s:)
                        in if newSz > limit
                             then lineTooLong
                             else more (done newDl) newSz newDl)
                 (\s -> if currentSz + B.length s > limit
                          then lineTooLong
                          else done (dl . (s:)))
                 eofK

    done :: ([ByteString] -> [ByteString]) -> IO ByteString
    done dl = return $! B.concat $ dl []

    -- Read one chunk and dispatch it to a continuation:
    --   moreK: no LF in the chunk; the line is incomplete (gets the whole chunk)
    --   doneK: LF found (gets the bytes before it; bytes after it stay buffered)
    --   eofK:  the backend reported end of input
    getChunk :: (ByteString -> IO r)
             -> (ByteString -> IO r)
             -> IO r
             -> IO r
    getChunk moreK doneK eofK =
        join $ connectionGetChunkBase loc conn $ \s ->
            if B.null s
                then (eofK, B.empty)
                else case B.break (== 10) s of
                    (a, b)
                        | B.null b  -> (moreK a, B.empty)
                        | otherwise -> (doneK a, B.tail b)
-- | Throw an EOF 'IOError' whose location is @"Network.Connection." ++ loc@
-- and whose file path is the connection's @host:port@.
throwEOF :: Connection -> String -> IO a
throwEOF conn loc = E.throwIO eofError
  where
    eofError = E.mkIOError E.eofErrorType ("Network.Connection." ++ loc) Nothing (Just endpoint)
    endpoint =
        let (host, port) = connectionID conn
        in concat [host, ":", show port]
-- | Close the connection.
--
-- For TLS, a close notification is sent with 'TLS.bye' and the context is
-- then closed with 'TLS.contextClose'. 'E.finally' makes the context close
-- run even when 'TLS.bye' throws (for example, because the peer has already
-- gone away), so the underlying handle is not leaked. The exception from
-- 'TLS.bye' still propagates. For a plain stream the handle is closed with
-- 'hClose'.
connectionClose :: Connection -> IO ()
connectionClose = withBackend backendClose
  where
    backendClose (ConnectionTLS ctx)  = TLS.bye ctx `E.finally` TLS.contextClose ctx
    backendClose (ConnectionStream h) = hClose h
-- | Upgrade a plain-stream connection to TLS in place.
--
-- The buffer MVar is taken first and the backend MVar second. For a plain
-- stream, a TLS handshake is run on the handle, the backend becomes
-- 'ConnectionTLS', and any buffered bytes are discarded (the buffer is reset
-- to @Just B.empty@). Presumably this is so that plaintext received before
-- the upgrade is never returned as if it came over TLS; confirm with callers.
-- If the connection is already TLS, nothing changes.
connectionSetSecure :: ConnectionContext
                    -> Connection
                    -> TLSSettings
                    -> IO ()
connectionSetSecure cg connection params =
    modifyMVar_ (connectionBuffer connection) $ \oldBuffer ->
        modifyMVar (connectionBackend connection) (upgrade oldBuffer)
  where
    upgrade _ (ConnectionStream h) = do
        ctx <- tlsEstablish h (makeTLSParams cg (connectionID connection) params)
        return (ConnectionTLS ctx, Just B.empty)
    upgrade oldBuffer backend@(ConnectionTLS _) =
        return (backend, oldBuffer)
-- | 'True' when the connection's current backend is TLS, 'False' for a
-- plain stream.
connectionIsSecure :: Connection -> IO Bool
connectionIsSecure = withBackend $ \backend ->
    case backend of
        ConnectionTLS _    -> return True
        ConnectionStream _ -> return False
-- | Create a TLS client context on @handle@ and run the handshake.
--
-- A fresh system-seeded AES-CTR random generator is created for each
-- context. Handshake failures propagate as exceptions from 'TLS.handshake'.
tlsEstablish :: Handle -> TLS.ClientParams -> IO TLS.Context
tlsEstablish handle tlsParams = do
    ctx <- RNG.makeSystem >>= TLS.contextNew handle tlsParams
    TLS.handshake ctx
    return ctx