{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ViewPatterns #-}
module Crypto.Cipher.Types.BlockIO
( BlockCipherIO(..)
, PtrDest
, PtrSource
, PtrIV
, BufferLength
, onBlock
) where
import Control.Applicative
import Data.Word
import Data.ByteString (ByteString)
import qualified Data.ByteString.Internal as B (fromForeignPtr, memcpy)
import Data.Byteable
import Data.Bits (xor, Bits)
import Foreign.Storable (poke, peek, Storable)
import Crypto.Cipher.Types.Block
import Foreign.Ptr
import Foreign.ForeignPtr (newForeignPtr_)
-- | Pointer to the buffer that output bytes are written to.
type PtrDest = Ptr Word8
-- | Pointer to the buffer that input bytes are read from.
type PtrSource = Ptr Word8
-- | Pointer to an initialization-vector (chaining value) buffer.
type PtrIV = Ptr Word8
-- | Length of a buffer, counted in bytes.
type BufferLength = Word32
-- | Block ciphers that can additionally run directly on raw memory buffers.
--
-- An instance must supply the two ECB primitives. The CBC operations fall
-- back to 'cbcEncryptGeneric' / 'cbcDecryptGeneric' unless overridden.
class BlockCipher cipher => BlockCipherIO cipher where
    -- | Presumably ECB-encrypts @len@ bytes of the source buffer into the
    -- destination buffer; the exact contract is defined by each instance.
    ecbEncryptMutable :: cipher -> PtrDest -> PtrSource -> BufferLength -> IO ()

    -- | Presumably the ECB-decryption counterpart of 'ecbEncryptMutable'.
    ecbDecryptMutable :: cipher -> PtrDest -> PtrSource -> BufferLength -> IO ()

    -- | CBC encryption using the given IV buffer.
    -- Defaults to 'cbcEncryptGeneric'.
    cbcEncryptMutable :: cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
    cbcEncryptMutable c iv dst src len = cbcEncryptGeneric c iv dst src len

    -- | CBC decryption using the given IV buffer.
    -- Defaults to 'cbcDecryptGeneric'.
    cbcDecryptMutable :: cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
    cbcDecryptMutable c iv dst src len = cbcDecryptGeneric c iv dst src len
-- | Generic CBC encryption over raw buffers, built on 'ecbEncryptMutable'.
--
-- For each block: @C_i = E(P_i xor C_{i-1})@, with @C_{-1}@ being the IV.
-- The ciphertext block just written to the destination is what must be
-- chained into the next block, so the per-block step returns the
-- destination pointer (returning the source would chain on plaintext).
--
-- The XOR reads each source byte before writing the destination byte, so
-- destination == source works as long as 'ecbEncryptMutable' accepts
-- identical input/output buffers (which this function always relies on).
cbcEncryptGeneric :: BlockCipherIO cipher => cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
cbcEncryptGeneric cipher = loopBS cipher encryptBlock
  where
    encryptBlock bs iv d s = do
        mutableXor d iv s bs                         -- d := P_i xor C_{i-1}
        ecbEncryptMutable cipher d d (fromIntegral bs) -- d := E(d) = C_i
        return d                                     -- C_i chains into block i+1
-- | Generic CBC decryption over raw buffers, built on 'ecbDecryptMutable'.
--
-- For each whole block: @P_i = D(C_i) xor C_{i-1}@, with @C_{-1}@ being
-- the IV.
--
-- Blocks are processed from last to first. Every @C_{i-1}@ is therefore
-- still intact when block @i@ is handled, which keeps in-place decryption
-- (destination == source) correct. A forward pass would overwrite the
-- ciphertext it needs to chain from.
--
-- Trailing bytes that do not fill a whole block are left untouched. This
-- function never writes through the IV pointer.
cbcDecryptGeneric :: BlockCipherIO cipher => cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
cbcDecryptGeneric cipher iv dst src len
    | bs <= 0 = return ()
    | otherwise = mapM_ decryptBlock [nblocks - 1, nblocks - 2 .. 0]
  where
    bs = blockSize cipher
    nblocks = fromIntegral len `div` bs
    decryptBlock i = do
        let off = i * bs
            d = dst `plusPtr` off
            s = src `plusPtr` off
            -- Block 0 chains from the IV; later blocks chain from the
            -- previous ciphertext block, not yet modified at this point.
            prev = if i == 0 then iv else src `plusPtr` (off - bs)
        ecbDecryptMutable cipher d s (fromIntegral bs) -- d := D(C_i)
        mutableXor d prev d bs                         -- d := D(C_i) xor C_{i-1}
-- | Run a pure 'ByteString' transformation over the source buffer one
-- cipher block at a time, and copy each result into the destination buffer.
--
-- Each block is handed to @f@ as a 'ByteString' that shares memory with the
-- source buffer (no copy is made). @f@ should therefore not retain its
-- argument: those bytes may change or be freed once 'onBlock' returns.
--
-- At most one block's worth of bytes is copied from each result. If @f@
-- returns fewer bytes than the block size, only that many bytes are written
-- and the rest of that destination block is left untouched.
onBlock :: BlockCipherIO cipher
        => cipher
        -> (ByteString -> ByteString)
        -> PtrDest
        -> PtrSource
        -> BufferLength
        -> IO ()
onBlock cipher f dst src len = loopBS cipher transformBlock nullPtr dst src len
  where
    -- The IV slot threaded by 'loopBS' is meaningless here; pass it through.
    transformBlock bs unusedIv d s = do
        srcFPtr <- newForeignPtr_ s
        let out = f (B.fromForeignPtr srcFPtr 0 bs)
            -- Never read past the end of @out@ if @f@ shrank the block.
            copyLen = min bs (byteableLength out)
        withBytePtr out $ \o -> B.memcpy d o (fromIntegral copyLen)
        return unusedIv
-- | Walk the source and destination buffers one cipher block at a time.
--
-- The step function receives the block size, the current chaining value,
-- and the destination/source pointers for the block. It returns the
-- chaining value to use for the next block.
--
-- Only whole blocks are processed. Trailing bytes that do not fill a block
-- are left untouched. The previous version decremented the 'Word32'
-- counter past zero in that case, wrapped around and ran off the end of the
-- buffers. A non-positive block size processes nothing.
loopBS :: BlockCipherIO cipher
       => cipher
       -> (Int -> PtrIV -> PtrDest -> PtrSource -> IO PtrIV)
       -> PtrIV -> PtrDest -> PtrSource -> BufferLength
       -> IO ()
loopBS cipher f iv dst src len = loop iv dst src len
  where
    bs = blockSize cipher
    bsLen = fromIntegral bs :: BufferLength
    loop i d s n
        | bs <= 0 || n < bsLen = return ()
        | otherwise = do
            newIV <- f bs i d s
            loop newIV (d `plusPtr` bs) (s `plusPtr` bs) (n - bsLen)
-- | @mutableXor dst src iv len@ writes @src[k] xor iv[k]@ into @dst[k]@
-- for every @0 <= k < len@. @dst@ may be the same pointer as @src@ or
-- @iv@. A non-positive @len@ writes nothing.
--
-- When @len@ is a multiple of 8 and all three pointers are 8-byte aligned,
-- the work is done one 'Word64' at a time; otherwise it is done byte by
-- byte. The alignment check matters because 'peek'/'poke' may require
-- aligned addresses on some architectures.
mutableXor :: PtrDest -> PtrSource -> PtrIV -> Int -> IO ()
mutableXor dst src iv len
    | len `mod` 8 == 0 && all aligned8 [dst, src, iv] =
        xorLoop (to64 dst) (to64 src) (to64 iv) (len `div` 8) 8
    | otherwise = xorLoop dst src iv len 1
  where
    aligned8 p = alignPtr p 8 == p

    -- XOR @n@ elements, advancing every pointer by @step@ bytes each time.
    xorLoop :: (Bits a, Storable a) => Ptr a -> Ptr a -> Ptr a -> Int -> Int -> IO ()
    xorLoop d s i n step
        | n <= 0 = return ()  -- also stops a negative length from looping forever
        | otherwise = do
            peeksAndPoke d s i
            xorLoop (d `plusPtr` step) (s `plusPtr` step) (i `plusPtr` step) (n - 1) step
-- | Reinterpret a byte pointer as a pointer to 64-bit words.
--
-- The address is unchanged. Callers must ensure that reading a 'Word64'
-- there is valid (enough bytes, suitable alignment).
to64 :: Ptr Word8 -> Ptr Word64
to64 p = castPtr p
-- | XOR the values behind two pointers and store the result at a third.
--
-- Both inputs are read before anything is written, so @dst@ may be the
-- same pointer as either input.
peeksAndPoke :: (Bits a, Storable a) => Ptr a -> Ptr a -> Ptr a -> IO ()
peeksAndPoke dst a b = do
    x <- peek a
    y <- peek b
    poke dst (x `xor` y)