{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE CPP #-}
module Network.Socks5.Command
( establish
, Connect(..)
, Command(..)
, connectIPV4
, connectIPV6
, connectDomainName
, rpc
, rpc_
, sendSerialized
, waitSerialized
) where
import Control.Applicative
import Control.Exception
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import Data.Serialize
import Network.Socket (Socket, PortNumber, HostAddress, HostAddress6)
import Network.Socket.ByteString
import Network.Socks5.Types
import Network.Socks5.Wire
-- | Run the initial hello exchange on @socket@.
--
-- A 'SocksHello' carrying @methods@ is serialized and sent, then a single
-- reply is decoded and the method it carries (as extracted by
-- 'getSocksHelloResponseMethod') is returned.
--
-- Both directions go through the shared 'sendSerialized' and
-- 'waitSerialized' helpers so that the receive chunk size is defined in
-- exactly one place ('getMore'). Decoding failures surface the same way as
-- for every other caller of 'runGetDone'.
establish :: Socket -> [SocksMethod] -> IO SocksMethod
establish socket methods = do
    sendSerialized socket (SocksHello methods)
    getSocksHelloResponseMethod <$> waitSerialized socket
-- | A connect command, identified solely by the destination 'SocksAddress'.
newtype Connect = Connect SocksAddress
    deriving (Show, Eq, Ord)
-- | Types that can be expressed as a 'SocksRequest'.
class Command a where
    -- | Build the wire-level request for a command value.
    toRequest :: a -> SocksRequest

    -- | Try to recover a command value from a request; 'Nothing' when the
    -- request does not correspond to this command type.
    fromRequest :: SocksRequest -> Maybe a
-- | A raw 'SocksRequest' is trivially its own command: both directions are
-- the identity, and 'fromRequest' never fails.
instance Command SocksRequest where
    toRequest request = request
    fromRequest request = Just request
-- | 'Connect' corresponds to a request whose command is
-- 'SocksCommandConnect'; the destination address and port are copied across
-- unchanged in both directions.
instance Command Connect where
    toRequest (Connect (SocksAddress dstAddr dstPort)) =
        SocksRequest
            { requestCommand = SocksCommandConnect
            , requestDstAddr = dstAddr
            , requestDstPort = fromIntegral dstPort
            }

    -- Only requests carrying the connect command can be turned back into a
    -- 'Connect'; anything else yields 'Nothing'.
    fromRequest req
        | requestCommand req == SocksCommandConnect = Just (Connect destination)
        | otherwise = Nothing
      where
        destination = SocksAddress (requestDstAddr req) (requestDstPort req)
-- | Issue a 'Connect' for an IPv4 destination and return the IPv4 address
-- and port found in the reply.
--
-- Failures reported by the peer are thrown by 'rpc_'. If the reply carries
-- a non-IPv4 address the result is an 'error' value; because the check is
-- applied with 'fmap', that error is raised only when the returned pair is
-- forced, not when this action runs.
connectIPV4 :: Socket -> HostAddress -> PortNumber -> IO (HostAddress, PortNumber)
connectIPV4 socket hostaddr port = fmap expectIPV4 (rpc_ socket request)
  where
    request = Connect (SocksAddress (SocksAddrIPV4 hostaddr) port)

    expectIPV4 bound = case bound of
        (SocksAddrIPV4 h, p) -> (h, p)
        _ -> error "ipv4 requested, got something different"
-- | Issue a 'Connect' for an IPv6 destination and return the IPv6 address
-- and port found in the reply.
--
-- Failures reported by the peer are thrown by 'rpc_'. If the reply carries
-- a non-IPv6 address the result is an 'error' value; because the check is
-- applied with 'fmap', that error is raised only when the returned pair is
-- forced, not when this action runs.
connectIPV6 :: Socket -> HostAddress6 -> PortNumber -> IO (HostAddress6, PortNumber)
connectIPV6 socket hostaddr6 port = fmap expectIPV6 (rpc_ socket request)
  where
    request = Connect (SocksAddress (SocksAddrIPV6 hostaddr6) port)

    expectIPV6 bound = case bound of
        (SocksAddrIPV6 h, p) -> (h, p)
        _ -> error "ipv6 requested, got something different"
-- | Issue a 'Connect' for a destination given by name and return whatever
-- address and port the reply carries.
--
-- The name is converted with 'BC.pack', which keeps only the low 8 bits of
-- each character, so @fqdn@ should already be in a byte-safe form.
-- Failures reported by the peer are thrown by 'rpc_'.
connectDomainName :: Socket -> String -> PortNumber -> IO (SocksHostAddress, PortNumber)
connectDomainName socket fqdn port = rpc_ socket (Connect destination)
  where
    destination = SocksAddress (SocksAddrDomainName (BC.pack fqdn)) port
-- | Encode a value with its 'Serialize' instance and write all of the
-- resulting bytes to the socket.
sendSerialized :: Serialize a => Socket -> a -> IO ()
sendSerialized sock = sendAll sock . encode
-- | Read from the socket until one complete value has been decoded with its
-- 'Serialize' instance. Input is pulled in chunks via 'getMore'; decoding
-- problems are reported by 'runGetDone'.
waitSerialized :: Serialize a => Socket -> IO a
waitSerialized = runGetDone get . getMore
-- | Send a command and wait for the matching response.
--
-- The command is converted with 'toRequest' and serialized onto the socket,
-- then one response is decoded. A success reply yields 'Right' with the
-- bind address and port from the response; an error reply yields 'Left'
-- with the 'SocksError' it carries. Decoding problems are not captured in
-- the 'Either'; they are raised by 'runGetDone'.
rpc :: Command a => Socket -> a -> IO (Either SocksError (SocksHostAddress, PortNumber))
rpc socket req = do
    sendSerialized socket (toRequest req)
    interpret <$> waitSerialized socket
  where
    interpret response = case responseReply response of
        SocksReplySuccess ->
            Right (responseBindAddr response, fromIntegral (responseBindPort response))
        SocksReplyError failure ->
            Left failure
-- | Like 'rpc', but an error reply is thrown as an exception with 'throwIO'
-- instead of being returned in a 'Left'.
rpc_ :: Command a => Socket -> a -> IO (SocksHostAddress, PortNumber)
rpc_ socket req = do
    outcome <- rpc socket req
    case outcome of
        Left failure -> throwIO failure
        Right bound -> return bound
-- | Drive an incremental parser to completion.
--
-- @ioget@ is run once to obtain the first chunk, and again every time the
-- parser asks for more input. The decoded value is returned only when the
-- parser finishes having consumed every byte it was given.
--
-- Calls 'error' with the parser message when decoding fails, and with a
-- fixed message when bytes are left over after a successful decode.
runGetDone :: Serialize a => Get a -> IO ByteString -> IO a
runGetDone getter ioget = ioget >>= step . runGetPartial getter
  where
    -- The failure constructor gained an extra field in cereal 0.4.0.
#if MIN_VERSION_cereal(0,4,0)
    step (Fail message _) = error message
#else
    step (Fail message) = error message
#endif
    -- The parser needs more input: fetch another chunk and resume.
    step (Partial resume) = ioget >>= step . resume
    -- Finished: accept the value only when nothing is left unconsumed.
    step (Done value leftover)
        | B.null leftover = return value
        | otherwise = error "got too many bytes while receiving data"
-- | Receive the next chunk of input from the socket, asking for at most
-- 4096 bytes per call.
getMore :: Socket -> IO ByteString
getMore socket = recv socket chunkSize
  where
    chunkSize = 4096