module Crypto.Util where
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeIndex, unsafeUseAsCStringLen)
import Data.Bits (shiftL, shiftR)
import Data.Bits (xor, setBit, shiftR, shiftL)
import Control.Exception (Exception, throw)
import Data.Tagged
import System.IO.Unsafe
import Foreign.C.Types
import Foreign.Ptr
-- | Increment a 'B.ByteString' interpreted as a big-endian unsigned counter.
--
-- The length is preserved: a trailing run of @0xFF@ bytes rolls over to
-- zeros and the byte immediately before that run is incremented. If every
-- byte is @0xFF@ (or the input is empty) the result is all zeros of the
-- same length.
incBS :: B.ByteString -> B.ByteString
incBS bs =
  case B.unsnoc carryPrefix of
    -- Nothing precedes the run of 0xFF bytes: the counter wraps to zero.
    Nothing -> zeros
    -- lastByte is not 0xFF here, so adding one cannot overflow the byte.
    Just (untouched, lastByte) ->
      B.concat [untouched, B.singleton (lastByte + 1), zeros]
  where
    -- Split off the longest all-0xFF suffix; those bytes become zeros.
    (carryPrefix, saturated) = B.spanEnd (== 0xFF) bs
    zeros = B.replicate (B.length saturated) 0
{-# INLINE incBS #-}
-- | Serialize an 'Integer' as big-endian bytes, where @l@ is a length in bits.
--
-- One byte is emitted for each shift amount @l - 8, l - 16, ...@ down to 0,
-- each byte being the low eight bits of @i@ shifted right by that amount.
-- Bits of @i@ above position @l@ are therefore dropped, and when @l@ is not
-- a multiple of 8 the lowest @l `mod` 8@ bits are not emitted either.
-- A length below 8 yields an empty 'B.ByteString'.
i2bs :: Int -> Integer -> B.ByteString
i2bs l i =
  B.pack [fromIntegral (i `shiftR` offset) | offset <- [l - 8, l - 16 .. 0]]
{-# INLINE i2bs #-}
-- | Serialize an 'Integer' as the shortest big-endian byte string.
--
-- Zero is encoded as a single @0x00@ byte. Positive values are encoded
-- without leading zero bytes. A negative argument yields an empty
-- 'B.ByteString', because the loop stops as soon as the remaining value is
-- not positive.
i2bs_unsized :: Integer -> B.ByteString
i2bs_unsized n
  | n == 0 = B.singleton 0
  | otherwise = B.pack (peel n [])
  where
    -- Strip the low byte each round and push it on the front, so the most
    -- significant byte ends up first.
    peel remaining acc
      | remaining <= 0 = acc
      | otherwise = peel (remaining `shiftR` 8) (fromIntegral remaining : acc)
{-# INLINE i2bs_unsized #-}
-- | Unwrap an 'Either': return the 'Right' payload, or raise the 'Left'
-- payload as an exception with 'throw'.
--
-- Because this is pure code, the exception surfaces only when the result
-- is evaluated.
throwLeft :: Exception e => Either e a -> a
throwLeft = either throw id
-- | Extract the value carried by a 'Tagged', ignoring the second argument.
--
-- The second argument is never inspected; it only fixes the tag type @a@
-- at the call site.
for :: Tagged a b -> a -> b
for = const . unTagged
-- | Infix synonym for 'for': @t .::. x@ yields the value carried by @t@,
-- with @x@ passed through to 'for' unchanged.
(.::.) :: Tagged a b -> a -> b
tagged .::. witness = for tagged witness
-- | Compare two strict 'B.ByteString's for equality through the foreign
-- routine 'c_constTimeEq'.
--
-- When the lengths differ the result is 'False' and the foreign routine is
-- not called. Otherwise the result is 'True' exactly when the foreign call
-- returns 0.
--
-- NOTE(review): the name suggests the C routine takes time independent of
-- the buffer contents; its implementation is not visible here — confirm.
-- NOTE(review): 'unsafePerformIO' presumes the C routine only reads the two
-- buffers and has no other side effects — confirm.
-- NOTE(review): the length is narrowed to 'CInt' with 'fromIntegral', so an
-- input longer than 'CInt' can represent would pass a wrapped length.
constTimeEq :: B.ByteString -> B.ByteString -> Bool
constTimeEq s1 s2 = unsafePerformIO (unsafeUseAsCStringLen s1 withFirst)
  where
    -- Borrow the second buffer once the first one is pinned down.
    withFirst (ptrA, lenA) = unsafeUseAsCStringLen s2 (compareBuffers ptrA lenA)

    compareBuffers ptrA lenA (ptrB, lenB)
      | lenA /= lenB = return False
      | otherwise = fmap (== 0) (c_constTimeEq ptrA ptrB (fromIntegral lenA))
-- | Binding to the C function @c_constTimeEq@, imported with the @ccall@
-- convention as an @unsafe@ call.
--
-- Takes two character buffers and a length and returns a 'CInt' status.
-- The only caller visible in this module passes two buffers of equal length
-- together with that length, and treats a return value of 0 as "equal".
-- NOTE(review): the C implementation is not visible here; confirm its
-- contract (in particular that it only reads the buffers).
foreign import ccall unsafe "c_constTimeEq"
  c_constTimeEq :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt
-- | Interpret a 'B.ByteString' as a big-endian unsigned 'Integer'.
--
-- The first byte is the most significant. An empty input yields 0.
bs2i :: B.ByteString -> Integer
bs2i bs = B.foldl' appendByte 0 bs
  where
    -- Make room for one more byte (a left shift by eight bits, i.e. a
    -- multiplication by 256) and add it in.
    appendByte acc byte = acc * 256 + toInteger byte
{-# INLINE bs2i #-}
-- | Byte-wise XOR of two strict 'B.ByteString's.
--
-- The result is as long as the shorter input; surplus bytes of the longer
-- input are ignored.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' a b = B.pack xored
  where
    xored = B.zipWith xor a b
{-# INLINE zwp' #-}
-- | Byte-wise XOR of two lazy 'L.ByteString's.
--
-- The inputs are walked chunk by chunk. At each step the front chunks are
-- cut to their common length, XORed with 'zwp'', and any leftover piece is
-- put back at the front of its own chunk list. The result ends as soon as
-- either input runs out of chunks.
zwp :: L.ByteString -> L.ByteString -> L.ByteString
zwp a b = L.fromChunks (pairUp (L.toChunks a) (L.toChunks b))
  where
    pairUp (x : xs) (y : ys) =
      zwp' xHead yHead : pairUp (requeue xRest xs) (requeue yRest ys)
      where
        common = min (B.length x) (B.length y)
        (xHead, xRest) = B.splitAt common x
        (yHead, yRest) = B.splitAt common y
    pairUp _ _ = []

    -- Put an unconsumed remainder back in front, unless nothing is left.
    requeue leftover rest
      | B.null leftover = rest
      | otherwise = leftover : rest
-- | Take chunks from the front of a list until @i@ bytes have been gathered.
--
-- Whole chunks are kept while they fit; the chunk that reaches the limit is
-- cut with 'B.take' and ends the result. If the list holds fewer than @i@
-- bytes, all chunks are returned. A count of zero or less yields an empty
-- list, and in that case the chunk list is not inspected.
collect :: Int -> [B.ByteString] -> [B.ByteString]
collect i _
  | i <= 0 = [] -- covers negative counts, which used to yield one empty chunk
collect _ [] = []
collect i (b : bs)
  | len < i = b : collect (i - len) bs
  | otherwise = [B.take i b]
  where
    len = B.length b
{-# INLINE collect #-}