{-# LANGUAGE DeriveDataTypeable, OverloadedStrings #-}
module Network.TLS.Handshake.Common
( handshakeFailed
, errorToAlert
, unexpected
, newSession
, handshakeTerminate
, sendChangeCipherAndFinish
, recvChangeCipherAndFinish
, RecvState(..)
, runRecvState
, recvPacketHandshake
, onRecvStateHandshake
) where
import Control.Concurrent.MVar
import Network.TLS.Parameters
import Network.TLS.Context.Internal
import Network.TLS.Session
import Network.TLS.Struct
import Network.TLS.IO
import Network.TLS.State hiding (getNegotiatedProtocol)
import Network.TLS.Handshake.Process
import Network.TLS.Handshake.State
import Network.TLS.Record.State
import Network.TLS.Measurement
import Network.TLS.Types
import Network.TLS.Cipher
import Network.TLS.Util
import Data.ByteString.Char8 ()
import Control.Monad.State
import Control.Exception (throwIO)
-- | Abort the current handshake: wrap the given error in 'HandshakeFailed'
-- and throw it as an IO exception.
handshakeFailed :: TLSError -> IO ()
handshakeFailed = throwIO . HandshakeFailed
-- | Build the fatal alert packet to send to the peer for a given TLS error.
--
-- * 'Error_Protocol' errors carry their own alert description, which is
--   forwarded as-is.
-- * 'Error_Packet_unexpected' (what 'unexpected' raises when a message
--   arrives out of order) is reported as @unexpected_message@.
--   RFC 5246 section 7.2.2 reserves @internal_error@ for local faults
--   unrelated to the peer, so it is the wrong answer to a peer that sent an
--   inappropriate message.
-- * Any other error is reported as @internal_error@.
errorToAlert :: TLSError -> Packet
errorToAlert (Error_Protocol (_, _, ad))   = Alert [(AlertLevel_Fatal, ad)]
errorToAlert (Error_Packet_unexpected _ _) = Alert [(AlertLevel_Fatal, UnexpectedMessage)]
errorToAlert _                             = Alert [(AlertLevel_Fatal, InternalError)]
-- | Raise an 'Error_Packet_unexpected' TLS error via 'throwCore'.
--
-- The first argument describes what was actually received. The optional
-- second argument names what was expected; when present it is appended to
-- the error detail as @" expected: <what>"@. Otherwise the detail is empty.
unexpected :: String -> Maybe [Char] -> IO a
unexpected msg expected = throwCore $ Error_Packet_unexpected msg detail
  where
    detail = case expected of
        Nothing   -> ""
        Just what -> " expected: " ++ what
-- | Create the session value for a new handshake.
--
-- When the context's supported parameters enable sessions, a fresh
-- 32-byte session ID is drawn from the context RNG ('getStateRNG').
-- Otherwise the session has no ID.
newSession :: Context -> IO Session
newSession ctx =
    if supportedSession (ctxSupported ctx)
        then fmap (Session . Just) (getStateRNG ctx 32)
        else return (Session Nothing)
-- | Finish a completed handshake.
--
-- If the session carries an ID, its session data is registered with the
-- shared session manager via 'sessionEstablish'. The handshake state is then
-- cleared, the byte counters are reset, and the context is marked as
-- established.
--
-- NOTE(review): 'fromJust' here fails if 'getSessionData' returns
-- 'Nothing', which happens when no master secret is recorded in the
-- handshake state. This presumably cannot happen after a completed
-- handshake; confirm against callers.
handshakeTerminate :: Context -> IO ()
handshakeTerminate ctx = do
    session <- usingState_ ctx getSession
    case session of
        Session (Just sessionId) -> do
            mSessionData <- getSessionData ctx
            let manager = sharedSessionManager (ctxShared ctx)
            sessionEstablish manager sessionId (fromJust "session-data" mSessionData)
        _ -> return ()
    -- Drop the per-handshake state now that the handshake is over.
    modifyMVar_ (ctxHandshake ctx) (const (return Nothing))
    updateMeasure ctx resetBytesCounters
    setEstablished ctx True
    return ()
-- | Send ChangeCipherSpec followed by the Finished handshake message.
--
-- The caller-supplied action runs right after ChangeCipherSpec is sent and
-- before the context is flushed. The Finished payload is the handshake
-- digest for the negotiated version and the given role, obtained from
-- 'getHandshakeDigest'. The context is flushed again after Finished.
sendChangeCipherAndFinish :: IO ()   -- ^ action run after ChangeCipherSpec
                          -> Context
                          -> Role
                          -> IO ()
sendChangeCipherAndFinish betweenCall ctx role = do
    sendPacket ctx ChangeCipherSpec
    betweenCall
    contextFlush ctx
    version <- usingState_ ctx getVersion
    verifyData <- usingHState ctx (getHandshakeDigest version role)
    sendPacket ctx (Handshake [Finished verifyData])
    contextFlush ctx
-- | Receive the peer's ChangeCipherSpec record and then its Finished
-- handshake message, in that order.
--
-- Anything else raises an unexpected-packet error through 'unexpected'.
-- The Finished payload is not inspected here.
recvChangeCipherAndFinish :: Context -> IO ()
recvChangeCipherAndFinish ctx = runRecvState ctx (RecvStateNext onChangeCipher)
  where
    onChangeCipher pkt = case pkt of
        ChangeCipherSpec -> return (RecvStateHandshake onFinished)
        other            -> unexpected (show other) (Just "change cipher")
    onFinished hs = case hs of
        Finished _ -> return RecvStateDone
        other      -> unexpected (show other) (Just "Handshake Finished")
-- | Receive-side state machine, interpreted by 'runRecvState'.
--
-- Each non-final state holds the continuation that consumes the next input
-- and yields the following state.
data RecvState m
    = RecvStateNext (Packet -> m (RecvState m))
      -- ^ expect the next whole record, of any packet type
    | RecvStateHandshake (Handshake -> m (RecvState m))
      -- ^ expect handshake messages, fed one at a time
    | RecvStateDone
      -- ^ nothing further to receive
-- | Receive one record and return the handshake messages it carries.
--
-- A receive error is rethrown with 'throwCore'. A record that is not a
-- handshake record raises an 'Error_Packet_unexpected' TLS error via
-- 'unexpected'.
recvPacketHandshake :: Context -> IO [Handshake]
recvPacketHandshake ctx = do
    pkts <- recvPacket ctx
    case pkts of
        Right (Handshake l) -> return l
        -- Previously this used 'fail', which throws an IO userError instead
        -- of a TLSError. Going through 'unexpected' keeps this failure typed
        -- like every other unexpected-message path in this module.
        Right x             -> unexpected (show x) (Just "handshake")
        Left err            -> throwCore err
-- | Feed a batch of handshake messages through the state machine.
--
-- For each message, the current 'RecvStateHandshake' continuation runs
-- first, and only then is the message passed to 'processHandshake'. The
-- resulting state is returned once the batch is exhausted. A message
-- arriving while the state is not 'RecvStateHandshake' is reported as a
-- spurious handshake via 'unexpected'.
onRecvStateHandshake :: Context -> RecvState IO -> [Handshake] -> IO (RecvState IO)
onRecvStateHandshake ctx = go
  where
    go st [] = return st
    go (RecvStateHandshake handler) (hs:rest) = do
        next <- handler hs
        processHandshake ctx hs
        go next rest
    go _ _ = unexpected "spurious handshake" Nothing
-- | Drive a 'RecvState' machine until it reaches 'RecvStateDone'.
--
-- * 'RecvStateNext' reads one record with 'recvPacket'. A receive error is
--   rethrown with 'throwCore'.
-- * 'RecvStateHandshake' reads a handshake record with
--   'recvPacketHandshake' and feeds it through 'onRecvStateHandshake'.
runRecvState :: Context -> RecvState IO -> IO ()
runRecvState ctx = loop
  where
    loop st = case st of
        RecvStateDone -> return ()
        RecvStateNext onPacket -> do
            received <- recvPacket ctx
            next <- either throwCore onPacket received
            loop next
        RecvStateHandshake _ -> do
            msgs <- recvPacketHandshake ctx
            next <- onRecvStateHandshake ctx st msgs
            loop next
-- | Snapshot the data needed to resume the current session.
--
-- Returns 'Nothing' when the handshake state has no master secret.
-- Otherwise the result combines the negotiated version, the ID of the
-- cipher in the current transmit state, and the master secret.
--
-- NOTE(review): the cipher lookup uses 'fromJust', so forcing the cipher
-- field fails if the transmit state has no cipher.
getSessionData :: Context -> IO (Maybe SessionData)
getSessionData ctx = do
    version <- usingState_ ctx getVersion
    masterSecret <- usingHState ctx (gets hstMasterSecret)
    txState <- readMVar (ctxTxState ctx)
    let mkSessionData secret = SessionData
            { sessionVersion = version
            , sessionCipher  = cipherID $ fromJust "cipher" $ stCipher txState
            , sessionSecret  = secret
            }
    return (fmap mkSessionData masterSecret)