{-# LANGUAGE OverloadedStrings #-}
module Crypto.PubKey.RSA.OAEP
(
OAEPParams(..)
, defaultOAEPParams
, encryptWithSeed
, encrypt
, decrypt
, decryptSafer
) where
import Crypto.Random
import Crypto.Types.PubKey.RSA
import Crypto.PubKey.HashDescr
import Crypto.PubKey.MaskGenFunction
import Crypto.PubKey.RSA.Prim
import Crypto.PubKey.RSA.Types
import Crypto.PubKey.RSA (generateBlinder)
import Crypto.PubKey.Internal (and')
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Bits (xor)
-- | Parameters for OAEP encryption and decryption.
data OAEPParams = OAEPParams
    { oaepHash :: HashFunction
      -- ^ Hash function; applied to the label and handed to the mask
      -- generation algorithm.
    , oaepMaskGenAlg :: MaskGenAlgorithm
      -- ^ Mask generation algorithm, instantiated with 'oaepHash'.
    , oaepLabel :: Maybe ByteString
      -- ^ Optional label. Its hash (or the hash of the empty string when
      -- 'Nothing') is embedded in the padded block and checked on decryption.
    }
-- | Build OAEP parameters for the given hash function, using 'mgf1' as the
-- mask generation algorithm and no label.
defaultOAEPParams :: HashFunction -> OAEPParams
defaultOAEPParams hashF = OAEPParams hashF mgf1 Nothing
-- | Encrypt a message using OAEP padding with a caller-supplied seed.
--
-- The seed must be exactly as long as the output of the configured hash
-- function. Fails with 'InvalidParameters' when the key is too small for the
-- hash or the seed has the wrong length, and with 'MessageTooLong' when the
-- message does not fit into the padded block.
encryptWithSeed :: ByteString -- ^ seed
                -> OAEPParams -- ^ OAEP parameters (hash, mask generator, label)
                -> PublicKey  -- ^ public key
                -> ByteString -- ^ message to encrypt
                -> Either Error ByteString
encryptWithSeed seed oaep pk msg
    | keyLen < overhead          = Left InvalidParameters
    | B.length seed /= digestLen = Left InvalidParameters
    | msgLen > keyLen - overhead = Left MessageTooLong
    | otherwise                  = Right (ep pk encoded)
  where
    keyLen    = public_size pk
    msgLen    = B.length msg
    digest    = oaepHash oaep
    maskGen   = oaepMaskGenAlg oaep digest
    lHash     = digest (maybe B.empty id (oaepLabel oaep))
    digestLen = B.length lHash
    -- room taken by the two hash-sized fields plus the 0x00 and 0x01 bytes
    overhead  = 2 * digestLen + 2

    xorBytes a b = B.pack (B.zipWith xor a b)

    -- data block: label hash || zero padding || 0x01 || message
    zeroPad    = B.replicate (keyLen - msgLen - overhead) 0
    dataBlock  = B.concat [lHash, zeroPad, B.singleton 0x1, msg]
    -- the data block is masked from the seed, then the seed is masked from
    -- the already-masked data block
    maskedDB   = xorBytes dataBlock (maskGen seed (keyLen - digestLen - 1))
    maskedSeed = xorBytes seed (maskGen maskedDB digestLen)
    -- encoded message: 0x00 || maskedSeed || maskedDB
    encoded    = B.concat [B.singleton 0x0, maskedSeed, maskedDB]
-- | Encrypt a message using OAEP, drawing the seed from the supplied random
-- generator.
--
-- The seed length is the output length of the configured hash function.
-- Returns the result of 'encryptWithSeed' together with the updated generator.
encrypt :: CPRG g
        => g          -- ^ random number generator
        -> OAEPParams -- ^ OAEP parameters
        -> PublicKey  -- ^ public key
        -> ByteString -- ^ message to encrypt
        -> (Either Error ByteString, g)
encrypt g oaep pk msg =
    let digestLen  = B.length (oaepHash oaep B.empty)
        (seed, g') = cprgGenerate digestLen g
     in (encryptWithSeed seed oaep pk msg, g')
-- | Remove the OAEP padding from an encoded message.
--
-- Returns the recovered message, or 'MessageNotRecognized' when the label
-- hash, the 0x01 separator or the leading 0x00 byte does not match.
unpad :: OAEPParams -- ^ OAEP parameters
      -> Int        -- ^ key size, used to size the data-block mask
      -> ByteString -- ^ encoded message: leading byte, masked seed, masked data block
      -> Either Error ByteString
unpad oaep k em
    | paddingSuccess = Right payload
    | otherwise      = Left MessageNotRecognized
  where
    digest    = oaepHash oaep
    maskGen   = oaepMaskGenAlg oaep digest
    lHash     = digest (maybe B.empty id (oaepLabel oaep))
    digestLen = B.length lHash

    xorBytes a b = B.pack (B.zipWith xor a b)

    -- split into: leading byte || maskedSeed || maskedDB
    (leadByte, rest)       = B.splitAt 1 em
    (maskedSeed, maskedDB) = B.splitAt digestLen rest
    -- undo the masks: first recover the seed, then the data block
    seed      = xorBytes maskedSeed (maskGen maskedDB digestLen)
    dataBlock = xorBytes maskedDB (maskGen seed (k - digestLen - 1))
    -- data block: label hash || zero bytes || 0x01 || message
    (lHash', afterHash)  = B.splitAt digestLen dataBlock
    afterZeros           = B.dropWhile (== 0) afterHash
    (separator, payload) = B.splitAt 1 afterZeros

    -- NOTE(review): the checks are combined with and' rather than chained
    -- guards; presumably so every check is evaluated -- confirm against
    -- Crypto.PubKey.Internal.
    paddingSuccess = and' [ lHash'    == lHash
                          , separator == "\x01"
                          , leadByte  == "\x00"
                          ]
-- | Decrypt a ciphertext using OAEP.
--
-- Fails with 'MessageSizeIncorrect' when the ciphertext length differs from
-- the key size, with 'InvalidParameters' when the key is too small for the
-- configured hash, and otherwise with whatever 'unpad' reports.
decrypt :: Maybe Blinder -- ^ optional blinder, handed to the RSA primitive
        -> OAEPParams    -- ^ OAEP parameters
        -> PrivateKey    -- ^ private key
        -> ByteString    -- ^ ciphertext
        -> Either Error ByteString
decrypt blinder oaep pk cipher
    | wrongSize   = Left MessageSizeIncorrect
    | keyTooSmall = Left InvalidParameters
    | otherwise   = unpad oaep keyLen (dp blinder pk cipher)
  where
    keyLen      = private_size pk
    digestLen   = B.length (oaepHash oaep B.empty)
    wrongSize   = B.length cipher /= keyLen
    keyTooSmall = keyLen < 2 * digestLen + 2
-- | Decrypt a ciphertext using OAEP with a blinder freshly generated from the
-- supplied random generator.
--
-- Returns the result of 'decrypt' together with the updated generator.
decryptSafer :: CPRG g
             => g          -- ^ random number generator
             -> OAEPParams -- ^ OAEP parameters
             -> PrivateKey -- ^ private key
             -> ByteString -- ^ ciphertext
             -> (Either Error ByteString, g)
decryptSafer rng oaep pk cipher =
    let (blinder, rng') = generateBlinder rng (private_n pk)
     in (decrypt (Just blinder) oaep pk cipher, rng')