{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
module Network.Wai.Parse
( parseHttpAccept
, parseRequestBody
, RequestBodyType (..)
, getRequestBodyType
, sinkRequestBody
, BackEnd
, lbsBackEnd
, tempFileBackEnd
, tempFileBackEndOpts
, Param
, File
, FileInfo (..)
, parseContentType
#if TEST
, Bound (..)
, findBound
, sinkTillBound
, killCR
, killCRLF
, takeLine
#endif
) where
import qualified Data.ByteString.Search as Search
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Char8 as S8
import Data.Word (Word8)
import Data.Maybe (fromMaybe)
import Data.List (sortBy)
import Data.Function (on, fix)
import System.Directory (removeFile, getTemporaryDirectory)
import System.IO (hClose, openBinaryTempFile)
import Network.Wai
import qualified Network.HTTP.Types as H
import Control.Monad (when, unless)
import Control.Monad.Trans.Resource (allocate, release, register, InternalState, runInternalState)
import Data.IORef
-- | Split a 'S.ByteString' at the first occurrence of the given byte.
--
-- The separator itself is dropped. When the byte does not occur, the whole
-- input is returned as the first component and the second is empty.
breakDiscard :: Word8 -> S.ByteString -> (S.ByteString, S.ByteString)
breakDiscard w s =
    case S.elemIndex w s of
        Nothing -> (s, S.empty)
        Just i -> (S.take i s, S.drop (i + 1) s)
-- | Parse the value of an HTTP @Accept@ header into a list of media ranges,
-- best match first.
--
-- Entries are ordered by descending quality value (@;q=@, defaulting to 1.0
-- when absent or unparsable) and then by descending specificity, measured as
-- the number of semicolons minus the number of @*@ wildcards. Spaces are
-- removed from every entry, and the @;q=...@ suffix is stripped.
parseHttpAccept :: S.ByteString -> [S.ByteString]
parseHttpAccept header =
    [ mediaRange | (mediaRange, _) <- sortBy bestFirst ranked ]
  where
    -- 44 is ','
    ranked = [ withRank (splitQuality item) | item <- S.split 44 header ]

    -- Descending order on (quality, specificity); sortBy is stable, so ties
    -- keep their original relative order.
    bestFirst :: (a, (Double, Int)) -> (a, (Double, Int)) -> Ordering
    bestFirst (_, l) (_, r) = compare r l

    -- 0x3B is ';', 0x2A is '*'
    withRank (range, quality) =
        (range, (quality, S.count 0x3B range - S.count 0x2A range))

    splitQuality raw =
        let compact = S.filter (/= 0x20) raw
            (range, qPart) = S.breakSubstring (S8.pack ";q=") compact
            qText = S.takeWhile (/= 0x3B) (S.drop 3 qPart)
         in (range, parseQ qText)

    parseQ txt =
        case reads (S8.unpack txt) of
            (value, _) : _ -> value
            [] -> 1.0
-- | Backend that collects the whole file content in memory as a lazy
-- 'L.ByteString'.
--
-- The first two arguments are ignored. The popper is called repeatedly
-- until it yields an empty chunk; every non-empty chunk is kept, in order.
lbsBackEnd :: Monad m => ignored1 -> ignored2 -> m S.ByteString -> m L.ByteString
lbsBackEnd _ _ popper = collect []
  where
    -- Chunks are accumulated newest-first and reversed once at the end.
    collect acc = do
        chunk <- popper
        if S.null chunk
            then return (L.fromChunks (reverse acc))
            else collect (chunk : acc)
-- | Backend that saves the file content to a temporary file and returns
-- its path.
--
-- This is 'tempFileBackEndOpts' using the directory reported by
-- 'getTemporaryDirectory' and the file name template @webenc.buf@.
tempFileBackEnd :: InternalState -> ignored1 -> ignored2 -> IO S.ByteString -> IO FilePath
tempFileBackEnd internalState name info popper =
    tempFileBackEndOpts getTemporaryDirectory "webenc.buf" internalState name info popper
-- | Like 'tempFileBackEnd', but with a configurable temporary directory
-- and file name template.
--
-- The temporary file is opened through 'allocate' on the supplied
-- 'InternalState' (with 'hClose' as its cleanup), and removal of the file
-- is registered on that same state. The popper is drained into the file
-- until it yields an empty chunk; the handle is then released and the
-- file path returned.
tempFileBackEndOpts :: IO FilePath -- ^ action yielding the temporary directory
                    -> String -- ^ file name template passed to 'openBinaryTempFile'
                    -> InternalState
                    -> ignored1
                    -> ignored2
                    -> IO S.ByteString -- ^ popper producing the file content
                    -> IO FilePath
tempFileBackEndOpts getTmpDir pattern internalState _ _ popper = do
    (closeKey, (path, h)) <-
        runInternalState (allocate openTemp closeTemp) internalState
    _ <- runInternalState (register (removeFile path)) internalState
    drain h
    -- Close the handle now instead of waiting for the state to be torn down.
    release closeKey
    return path
  where
    openTemp = getTmpDir >>= \dir -> openBinaryTempFile dir pattern
    closeTemp (_, h) = hClose h

    -- Copy chunks to the handle until the popper signals the end of input.
    drain h = do
        chunk <- popper
        unless (S.null chunk) $ S.hPut h chunk >> drain h
-- | Information on an uploaded file. The type parameter @c@ is the type of
-- the stored content, as produced by the 'BackEnd' in use.
data FileInfo c = FileInfo
{ fileName :: S.ByteString
-- ^ Value of the @filename@ attribute of the part's Content-Disposition header.
, fileContentType :: S.ByteString
-- ^ Value of the part's Content-Type header; 'parsePieces' falls back to
-- @application/octet-stream@ when the header is absent.
, fileContent :: c
-- ^ Result returned by the backend for this file (@()@ while the part is
-- still being handed to the backend).
}
deriving (Eq, Show)
-- | Post parameter name and value.
type Param = (S.ByteString, S.ByteString)
-- | Post parameter name and the associated uploaded file information.
type File y = (S.ByteString, FileInfo y)
-- | A file uploading backend. It receives the parameter name, the file
-- information (without content) and a popper for the file's bytes, and
-- returns whatever representation of the content it chooses.
type BackEnd a = S.ByteString
-- ^ parameter name
-> FileInfo ()
-- ^ file name and content type; content not yet available
-> IO S.ByteString
-- ^ popper for the file content; yields an empty chunk once exhausted
-> IO a
-- | The kind of request body that this module can parse: either a
-- URL-encoded form, or multipart form data carrying its boundary string.
data RequestBodyType = UrlEncoded | Multipart S.ByteString
-- | Determine how the body of a request is encoded, based on its
-- @Content-Type@ header.
--
-- Returns 'Nothing' when the header is missing, names a media type other
-- than @application/x-www-form-urlencoded@ or @multipart/form-data@, or is
-- multipart without a @boundary@ parameter.
--
-- Media types and parameter names are case-insensitive, so both are
-- compared after ASCII lower-casing; whitespace around the media type is
-- ignored. The boundary value itself is returned unchanged.
getRequestBodyType :: Request -> Maybe RequestBodyType
getRequestBodyType req = do
    rawType <- lookup "Content-Type" $ requestHeaders req
    let (ctype, attrs) = parseContentType rawType
    classify (normalize ctype) [(lowerAscii k, v) | (k, v) <- attrs]
  where
    classify ctype attrs
        | ctype == "application/x-www-form-urlencoded" = Just UrlEncoded
        | ctype == "multipart/form-data" = fmap Multipart (lookup "boundary" attrs)
        | otherwise = Nothing

    -- 32 is ' '
    normalize = lowerAscii . S.dropWhile (== 32) . fst . S.spanEnd (== 32)

    -- Map 'A'..'Z' (65..90) to 'a'..'z'; every other byte is left alone.
    lowerAscii = S.map (\w -> if w >= 65 && w <= 90 then w + 32 else w)
-- | Parse a @Content-Type@ header value into the media type and its
-- parameters.
--
-- The input is split on semicolons; the first piece is the media type and
-- every following piece is split on its first @=@ into a key and a value.
-- Surrounding spaces are removed from the media type, keys and values, and
-- a value wrapped in double quotes (for example @boundary="abc def"@) is
-- returned without the quotes.
--
-- > parseContentType "text/html; charset=utf-8"
-- >   == ("text/html", [("charset", "utf-8")])
parseContentType :: S.ByteString -> (S.ByteString, [(S.ByteString, S.ByteString)])
parseContentType a =
    (strip ctype, attrs)
  where
    (ctype, b) = S.break (== semicolon) a
    attrs = goAttrs id $ S.drop 1 b

    semicolon = 59
    equals = 61
    space = 32
    quote = 34

    goAttrs front bs
        | S.null bs = front []
        | otherwise =
            let (x, rest) = S.break (== semicolon) bs
             in goAttrs (front . (goAttr x :)) $ S.drop 1 rest

    goAttr bs =
        let (k, v') = S.break (== equals) bs
            v = S.drop 1 v'
         in (strip k, dq $ strip v)

    -- Remove one pair of enclosing double quotes, if present.
    dq s
        | S.length s >= 2 && S.head s == quote && S.last s == quote =
            S.tail $ S.init s
        | otherwise = s

    strip = S.dropWhile (== space) . fst . S.breakEnd (/= space)
-- | Parse the body of a request into its parameters and uploaded files,
-- using the given backend to store file contents.
--
-- When 'getRequestBodyType' does not recognise the request's content type,
-- the body is left unread and two empty lists are returned.
parseRequestBody :: BackEnd y
                 -> Request
                 -> IO ([Param], [File y])
parseRequestBody s r =
    maybe (return ([], [])) parseAs (getRequestBodyType r)
  where
    parseAs bodyType = sinkRequestBody s bodyType (requestBody r)
-- | Parse a request body of a known type, pulling chunks from the given
-- popper, and collect every parameter and file that is found.
--
-- Both result lists are in the order in which the items were encountered.
sinkRequestBody :: BackEnd y
                -> RequestBodyType
                -> IO S.ByteString
                -> IO ([Param], [File y])
sinkRequestBody s r body = do
    -- Items are pushed newest-first and reversed once parsing is complete.
    collected <- newIORef ([], [])
    let record item = atomicModifyIORef collected $ \(params, files) ->
            ( either
                (\param -> (param : params, files))
                (\file -> (params, file : files))
                item
            , ()
            )
    conduitRequestBody s r body record
    (params, files) <- readIORef collected
    return (reverse params, reverse files)
-- | Parse a request body and report each parameter ('Left') or file
-- ('Right') to the supplied callback as it is found.
--
-- A URL-encoded body is read completely into memory and then decoded with
-- 'H.parseSimpleQuery'. A multipart body is handed to 'parsePieces' with
-- the boundary prefixed by @--@.
conduitRequestBody :: BackEnd y
                   -> RequestBodyType
                   -> IO S.ByteString
                   -> (Either Param (File y) -> IO ())
                   -> IO ()
conduitRequestBody backend bodyType rbody add =
    case bodyType of
        UrlEncoded -> do
            raw <- slurp []
            mapM_ (add . Left) (H.parseSimpleQuery raw)
        Multipart bound ->
            parsePieces backend (S.append (S8.pack "--") bound) rbody add
  where
    -- Read chunks until the popper yields an empty one, then join them.
    slurp acc = do
        chunk <- rbody
        if S.null chunk
            then return (S.concat (reverse acc))
            else slurp (chunk : acc)
-- | Read one line (terminated by a line feed) from the source.
--
-- Returns @Just@ the line without its terminating LF and without a
-- trailing CR. Bytes following the LF are pushed back onto the source.
-- If the source is exhausted before a LF is seen, the bytes read so far
-- are pushed back and 'Nothing' is returned.
takeLine :: Source -> IO (Maybe S.ByteString)
takeLine src = scan S.empty
  where
    -- @pending@ holds the bytes of the current line read so far.
    scan pending = do
        chunk <- readSource src
        if S.null chunk
            then do
                leftover src pending
                return Nothing
            else do
                -- 10 is LF
                let (line, rest) = S.break (== 10) (S.append pending chunk)
                case S.uncons rest of
                    Nothing -> scan line
                    Just (_, after) -> do
                        unless (S.null after) $ leftover src after
                        return (Just (killCR line))
-- | Read consecutive lines from the source until an empty line or the end
-- of input is reached. The terminating empty line is consumed but not
-- included in the result.
takeLines :: Source -> IO [S.ByteString]
takeLines src = takeLine src >>= step
  where
    step (Just line)
        | not (S.null line) = fmap (line :) (takeLines src)
    step _ = return []
-- | A chunk producer paired with a buffer: the 'IO' action yields the next
-- chunk of input, and the 'IORef' holds bytes pushed back with 'leftover',
-- which 'readSource' returns before asking the action for more.
data Source = Source (IO S.ByteString) (IORef S.ByteString)
-- | Wrap a chunk-producing action in a 'Source' whose pushback buffer
-- starts out empty.
mkSource :: IO S.ByteString -> IO Source
mkSource f = Source f `fmap` newIORef S.empty
-- | Fetch the next chunk from a 'Source'.
--
-- The pushback buffer is emptied and, if it held any bytes, those are
-- returned; otherwise the underlying action is run for a fresh chunk.
readSource :: Source -> IO S.ByteString
readSource (Source pull buffer) =
    atomicModifyIORef buffer (\held -> (S.empty, held)) >>= preferBuffered
  where
    preferBuffered held
        | S.null held = pull
        | otherwise = return held
-- | Push bytes back onto a 'Source' so that the next 'readSource' returns
-- them. Whatever the pushback buffer held before is overwritten.
leftover :: Source -> S.ByteString -> IO ()
leftover (Source _ buffer) = writeIORef buffer
-- | Parse a multipart body, reporting each part to the callback.
--
-- For every part, the boundary line is skipped and the header lines are
-- read up to the first empty line. The part is then handled according to
-- its Content-Disposition header:
--
-- * with both @name@ and @filename@ attributes, the content is streamed to
--   the backend and reported as a file ('Right');
--
-- * with only a @name@ attribute, the content is collected in memory and
--   reported as a parameter ('Left');
--
-- * otherwise the content is discarded.
--
-- Parsing continues with the next part for as long as the content of the
-- current one was terminated by the boundary, and stops when no header
-- lines are found.
--
-- Header field names are case-insensitive, so they are lower-cased (ASCII)
-- before lookup; @content-disposition:@ is treated like
-- @Content-Disposition:@.
parsePieces :: BackEnd y
            -> S.ByteString -- ^ boundary, including the leading @--@
            -> IO S.ByteString -- ^ popper for the request body
            -> (Either Param (File y) -> IO ())
            -> IO ()
parsePieces sink bound rbody add =
    mkSource rbody >>= loop
  where
    loop src = do
        _boundLine <- takeLine src
        res' <- takeLines src
        unless (null res') $ do
            let ls' = map parsePair res'
            let x = do
                    cd <- lookup contDisp ls'
                    let ct = lookup contType ls'
                    let attrs = parseAttrs cd
                    name <- lookup "name" attrs
                    return (ct, name, lookup "filename" attrs)
            case x of
                Just (mct, name, Just filename) -> do
                    let ct = fromMaybe "application/octet-stream" mct
                        fi0 = FileInfo filename ct ()
                    (wasFound, y) <- sinkTillBound' bound name fi0 sink src
                    add $ Right (name, fi0 { fileContent = y })
                    when wasFound (loop src)
                Just (_ct, name, Nothing) -> do
                    let seed = id
                    let iter front bs = return $ front . (:) bs
                    (wasFound, front) <- sinkTillBound bound iter seed src
                    let bs = S.concat $ front []
                    let x' = (name, bs)
                    add $ Left x'
                    when wasFound (loop src)
                _ -> do
                    -- No usable Content-Disposition: skip the part's content.
                    let seed = ()
                        iter () _ = return ()
                    (wasFound, ()) <- sinkTillBound bound iter seed src
                    when wasFound (loop src)

    -- Lower-case keys, matching the normalisation done by 'parsePair'.
    contDisp = S8.pack "content-disposition"
    contType = S8.pack "content-type"

    -- Split a header line at the first ':' (58) and drop leading spaces (32)
    -- from the value; the field name is lower-cased for lookup.
    parsePair s =
        let (x, y) = breakDiscard 58 s
         in (lowerAscii x, S.dropWhile (== 32) y)

    -- Map 'A'..'Z' (65..90) to 'a'..'z'; every other byte is left alone.
    lowerAscii = S.map (\w -> if w >= 65 && w <= 90 then w + 32 else w)
-- | Result of searching a chunk for the multipart boundary (see 'findBound').
data Bound = FoundBound S.ByteString S.ByteString
-- ^ The boundary occurs in the chunk; carries the bytes before it and the
-- bytes after it.
| NoBound
-- ^ The chunk contains neither the boundary nor a prefix of it at its end.
| PartialBound
-- ^ The end of the chunk matches a prefix of the boundary, so more input
-- is needed to decide.
deriving (Eq, Show)
-- | Look for the boundary @b@ inside the chunk @bs@.
--
-- If the boundary occurs in full, the bytes before and after its first
-- occurrence are returned in 'FoundBound'. Otherwise the tail of the chunk
-- is checked, from the earliest offset at which the boundary could still
-- start, for a match against the beginning of the boundary; such a match
-- that runs off the end of the chunk yields 'PartialBound'. When nothing
-- matches, the result is 'NoBound'.
findBound :: S.ByteString -> S.ByteString -> Bound
findBound b bs =
    case Search.breakOn b bs of
        (before, rest)
            | S.null rest -> foldr tryOffset NoBound [firstOffset .. hayLen - 1]
            | otherwise -> FoundBound before (S.drop patLen rest)
  where
    patLen = S.length b
    hayLen = S.length bs
    firstOffset = max 0 (hayLen - patLen)

    -- Take the first offset whose tail agrees with the boundary; the lazy
    -- fold leaves later offsets unexamined once one matches.
    tryOffset i next
        | tailMatches i = classify i
        | otherwise = next

    -- Compare the boundary with the chunk from offset @i@ over their
    -- common length.
    tailMatches i = and (S.zipWith (==) b (S.drop i bs))

    classify i
        | end > hayLen = PartialBound
        | otherwise = FoundBound (S.take i bs) (S.drop end bs)
      where
        end = i + patLen
-- | Stream the content of the current part, up to the boundary, into the
-- given backend.
--
-- Returns whether the boundary was found, together with the backend's
-- result. The backend has to drain its popper until it yields an empty
-- chunk; otherwise querying the boundary status raises the error from
-- 'wrapTillBound'.
sinkTillBound' :: S.ByteString
               -> S.ByteString
               -> FileInfo ()
               -> BackEnd y
               -> Source
               -> IO (Bool, y)
sinkTillBound' bound name fi sink src =
    wrapTillBound bound src >>= \(popper, boundFound) -> do
        content <- sink name fi popper
        found <- boundFound
        return (found, content)
-- | State of the popper created by 'wrapTillBound'.
data WTB = WTBWorking (S.ByteString -> S.ByteString)
-- ^ Still reading; the function prepends bytes held back from earlier
-- chunks to the next chunk read.
| WTBDone Bool
-- ^ Finished; 'True' if the boundary was found, 'False' if the source ran
-- out first.
-- | Create a popper that yields the bytes of the source up to (but not
-- including) the next boundary, plus an action reporting whether the
-- boundary was found.
--
-- The popper yields an empty chunk once the boundary has been reached or
-- the source is exhausted. The line break directly preceding the boundary
-- is not part of the content. Bytes following the boundary are pushed
-- back onto the source.
--
-- The second action must only be run after the popper has finished; it
-- calls 'error' otherwise.
wrapTillBound :: S.ByteString -- ^ boundary
              -> Source
              -> IO (IO S.ByteString, IO Bool)
wrapTillBound bound src = do
    ref <- newIORef $ WTBWorking id
    return (go ref, final ref)
  where
    final ref = do
        x <- readIORef ref
        case x of
            WTBWorking _ -> error "wrapTillBound did not finish"
            WTBDone y -> return y

    go ref = do
        state <- readIORef ref
        case state of
            WTBDone _ -> return S.empty
            WTBWorking front -> do
                bs <- readSource src
                if S.null bs
                    then do
                        -- Source exhausted: flush whatever was held back.
                        writeIORef ref $ WTBDone False
                        return $ front bs
                    else push $ front bs
      where
        push bs =
            case findBound bound bs of
                FoundBound before after -> do
                    let before' = killCRLF before
                    leftover src after
                    writeIORef ref $ WTBDone True
                    return before'
                NoBound -> do
                    -- Don't emit a trailing line break yet: it may turn out
                    -- to be the one preceding the boundary. The list of
                    -- 'Char's is spelled out because a string literal is
                    -- ambiguous here under OverloadedStrings with a
                    -- Foldable-based 'elem'.
                    let (toEmit, front') =
                            if not (S8.null bs) && S8.last bs `elem` ['\r', '\n']
                                then
                                    let (x, y) = S.splitAt (S.length bs - 2) bs
                                     in (x, S.append y)
                                else (bs, id)
                    writeIORef ref $ WTBWorking front'
                    if S.null toEmit
                        then go ref
                        else return toEmit
                PartialBound -> do
                    -- The chunk ends inside a possible boundary: hold all
                    -- of it back until more input arrives.
                    writeIORef ref $ WTBWorking $ S.append bs
                    go ref
-- | Fold over the chunks of the current part, up to the boundary.
--
-- The step function is applied to the accumulator and each non-empty
-- chunk in turn. Returns whether the boundary was found, together with
-- the final accumulator.
sinkTillBound :: S.ByteString
              -> (x -> S.ByteString -> IO x)
              -> x
              -> Source
              -> IO (Bool, x)
sinkTillBound bound iter seed0 src = do
    (popper, boundFound) <- wrapTillBound bound src
    let feed acc = popper >>= \chunk ->
            if S.null chunk
                then return acc
                else iter acc chunk >>= feed
    result <- feed seed0
    found <- boundFound
    return (found, result)
-- | Parse a semicolon-separated attribute list, such as the value of a
-- Content-Disposition header, into key/value pairs.
--
-- Each piece is split at its first @=@; a piece without @=@ gets an empty
-- value. Leading spaces are dropped from keys and values, and a value
-- enclosed in double quotes is returned without them. This includes the
-- empty quoted string, so @filename=""@ yields an empty value rather than
-- two quote characters.
parseAttrs :: S.ByteString -> [(S.ByteString, S.ByteString)]
parseAttrs = map go . S.split 59 -- semicolon
  where
    tw = S.dropWhile (== 32) -- space
    -- 34 is the double quote; two bytes are enough to hold a quote pair.
    dq s = if S.length s >= 2 && S.head s == 34 && S.last s == 34
                then S.tail $ S.init s
                else s
    go s =
        let (x, y) = breakDiscard 61 s -- equals sign
         in (tw x, dq $ tw y)
-- | Remove a trailing line break from the input.
--
-- If the input ends with LF, that byte is removed, along with a CR
-- directly before it if there is one. Any other input is returned
-- unchanged.
killCRLF :: S.ByteString -> S.ByteString
killCRLF bs =
    -- 10 is LF
    if not (S.null bs) && S.last bs == 10
        then killCR (S.init bs)
        else bs
-- | Remove a single trailing CR from the input, if there is one.
killCR :: S.ByteString -> S.ByteString
killCR bs =
    -- 13 is CR
    if S.isSuffixOf (S.singleton 13) bs
        then S.init bs
        else bs