module Crypto.Padding
(
padPKCS5
, padBlockSize
, putPaddedPKCS5
, unpadPKCS5safe
, unpadPKCS5
, padESP, unpadESP
, padESPBlockSize
, putPadESPBlockSize, putPadESP
) where
import Data.Serialize.Put
import Crypto.Classes
import Crypto.Types
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
-- | Append PKCS#5-style padding to @bs@ for blocks of @len@ bytes and return
-- the padded strict 'B.ByteString'.
--
-- For a non-zero @len@ the pad is @n@ copies of the byte @n@, where
-- @n = len - (B.length bs `rem` len)@. A @len@ of 0 appends the single byte 1.
padPKCS5 :: ByteLength -> B.ByteString -> B.ByteString
padPKCS5 len = runPut . putPaddedPKCS5 len
-- | Serialise @bs@ followed by PKCS#5/PKCS#7-style padding to a multiple of
-- @len@ bytes.
--
-- The pad is @n@ copies of the byte @n@, where
-- @n = len - (B.length bs `rem` len)@. For a positive @len@, @n@ is always
-- in @[1 .. len]@, so an already block-aligned @bs@ gets a full block of
-- padding.
--
-- A @len@ of 0 is special-cased: only the single byte 1 is appended.
--
-- NOTE(review): the pad byte is a 'Word8', so for @len@ > 255 the value
-- written wraps modulo 256. Callers presumably only use block sizes <= 255;
-- confirm.
putPaddedPKCS5 :: ByteLength -> B.ByteString -> Put
putPaddedPKCS5 0 bs = putByteString bs >> putWord8 1
putPaddedPKCS5 len bs = putByteString bs >> putByteString pad
  where
    -- The remainder is strictly smaller in magnitude than a non-zero len, so
    -- padLen can never be 0 here and no special case is needed for it.
    padLen = len - (B.length bs `rem` len)
    pad = B.replicate padLen (fromIntegral padLen)
-- | Pad @bs@ with PKCS#5-style padding to a multiple of the block size of
-- cipher @k@ (taken from 'blockSizeBytes'), returning a strict 'B.ByteString'.
padBlockSize :: BlockCipher k => k -> B.ByteString -> B.ByteString
padBlockSize k bs = runPut (putPaddedBlockSize k bs)
-- | Serialise @bs@ with PKCS#5-style padding, using the block size of cipher
-- @k@ (in bytes) as the padding modulus.
putPaddedBlockSize :: BlockCipher k => k -> B.ByteString -> Put
putPaddedBlockSize k =
  let blockLen = blockSizeBytes `for` k
  in putPaddedPKCS5 blockLen
-- | Strip PKCS#5/PKCS#7-style padding after validating it.
--
-- The last byte @n@ gives the pad length. The result is @Just msg@ only when
-- @n@ is non-zero and the input ends in exactly @n@ copies of the byte @n@.
-- All other inputs yield 'Nothing', including:
--
-- * empty input,
-- * a zero pad byte,
-- * a pad longer than the input,
-- * mismatching pad bytes.
--
-- The block size is not known here, so a pad length larger than the cipher's
-- block size is not rejected.
--
-- NOTE(review): 'B.all' stops at the first mismatching byte, so validation
-- time depends on the padding contents. If the result is observable by an
-- attacker (e.g. after CBC decryption), consider a constant-time check.
unpadPKCS5safe :: B.ByteString -> Maybe B.ByteString
unpadPKCS5safe bs
  | bsLen == 0 = Nothing
  -- PKCS#5 padding never uses a pad byte of 0. Without this check such a byte
  -- would be treated as an empty pad and the input returned unchanged.
  | padLen == 0 = Nothing
  | B.length pad == pLen && B.all (== padLen) pad = Just msg
  | otherwise = Nothing
  where
    bsLen = B.length bs
    padLen = B.last bs -- only forced after the emptiness guard above
    pLen = fromIntegral padLen
    (msg, pad) = B.splitAt (bsLen - pLen) bs
-- | Strip PKCS#5-style padding without validating it.
--
-- The value @n@ of the last byte is taken as the pad length, and the last
-- @n@ bytes are dropped. If @n@ exceeds the input length, the result is
-- empty. Empty input is returned unchanged. Use 'unpadPKCS5safe' when the
-- padding must be checked.
unpadPKCS5 :: B.ByteString -> B.ByteString
unpadPKCS5 bs
  | B.null bs = bs
  | otherwise = B.take (B.length bs - fromIntegral (B.last bs)) bs
-- | Append ESP-style padding (incrementing pad bytes followed by a pad-length
-- byte) for blocks of @i@ bytes, returning a strict 'B.ByteString'.
padESP :: Int -> B.ByteString -> B.ByteString
padESP i = runPut . putPadESP i
-- | Append ESP-style padding to @bs@ using the block size of cipher @k@,
-- returning a strict 'B.ByteString'.
padESPBlockSize :: BlockCipher k => k -> B.ByteString -> B.ByteString
padESPBlockSize k = runPut . putPadESPBlockSize k
-- | Serialise @bs@ with ESP-style padding, using the block size of cipher @k@
-- (in bytes, from 'blockSizeBytes') as the alignment.
putPadESPBlockSize :: BlockCipher k => k -> B.ByteString -> Put
putPadESPBlockSize k bs =
  let blockLen = blockSizeBytes `for` k
  in putPadESP blockLen bs
-- | Serialise @bs@ followed by ESP-style padding for blocks of @l@ bytes.
--
-- The trailer is the pad bytes @1, 2, .., n@ followed by one byte holding
-- @n@, where @n = l - ((B.length bs + 1) `rem` l)@. Message, pad and
-- pad-length byte together therefore fill a whole number of @l@-byte blocks.
-- A block length of 0 appends only a zero pad-length byte. No next-header
-- byte is written.
--
-- NOTE(review): the pad bytes come from the 255-byte 'espPad' table and the
-- length byte is a 'Word8'. Block lengths above 255 would therefore produce
-- an inconsistent trailer; presumably such sizes are never used — confirm.
putPadESP :: Int -> B.ByteString -> Put
putPadESP 0 bs = putByteString bs >> putWord8 0
putPadESP l bs =
  let n = l - ((B.length bs + 1) `rem` l)
  in putByteString bs
       >> putByteString (B.take n espPad)
       >> putWord8 (fromIntegral n)
-- | The ESP padding table: the bytes @1, 2, .., 255@ in order. A pad of
-- length @n@ is the first @n@ bytes of this table.
espPad :: B.ByteString
espPad = B.pack [1 .. 255]
-- | Strip and validate ESP-style padding, as written by 'putPadESP'.
--
-- The last byte @n@ is the pad length, and it must be preceded by the pad
-- bytes @1, 2, .., n@. Returns the message that precedes the pad. Returns
-- 'Nothing' when the input is empty, is too short to hold @n@ pad bytes plus
-- the length byte, or when the pad bytes do not match.
unpadESP :: B.ByteString -> Maybe B.ByteString
unpadESP bs
  | bsLen == 0 = Nothing
  -- The trailer needs pLen pad bytes plus the pad-length byte itself.
  -- Without this check the split point goes negative, 'msg' becomes empty,
  -- and inputs such as [1] or [1,2] are accepted as valid padding.
  | bsLen < pLen + 1 = Nothing
  -- Both compared strings are exactly pLen bytes long here
  -- (pLen <= 255 = B.length espPad).
  -- NOTE(review): 'constTimeEq' comes from the project imports and is
  -- presumably a constant-time comparison — confirm.
  | not (constTimeEq (B.take pLen pad) (B.take pLen espPad)) = Nothing
  | otherwise = Just msg
  where
    bsLen = B.length bs
    padLen = B.last bs -- only forced after the emptiness guard above
    pLen = fromIntegral padLen
    (msg, pad) = B.splitAt (bsLen - (pLen + 1)) bs