{-# LANGUAGE DeriveDataTypeable #-}
module Crypto.Types.PubKey.RSA
    ( PublicKey(..)
    , PrivateKey(..)
    , KeyPair(..)
    , private_size
    , private_n
    , toPublicKey
    , toPrivateKey
    ) where
import Data.Data
import Data.ASN1.Types
import Data.Bits (shiftL, shiftR, complement, testBit, (.&.))
import Data.Word (Word8)
-- | An RSA public key: modulus, public exponent and the modulus size.
data PublicKey = PublicKey
    { public_size :: Int      -- ^ modulus size in bytes; the decoder sets it to the smallest @k >= 1@ with @256^k > public_n@
    , public_n    :: Integer  -- ^ public modulus
    , public_e    :: Integer  -- ^ public exponent
    } deriving (Show,Read,Eq,Data,Typeable)
instance ASN1Object PublicKey where
    -- Prepend SEQUENCE { n, e } to the rest of the output stream.
    toASN1 key rest =
        Start Sequence
            : IntVal (public_n key)
            : IntVal (public_e key)
            : End Sequence
            : rest

    -- Accept exactly SEQUENCE { INTEGER, INTEGER } at the head of the
    -- stream and hand back whatever follows it.
    fromASN1 (Start Sequence : IntVal rawN : IntVal ex : End Sequence : rest) =
        Right (key, rest)
      where
        -- 'toPositive' turns a modulus that arrived as a negative INTEGER
        -- into the unsigned value of its two's-complement bytes.
        n = toPositive rawN
        key = PublicKey
            { public_size = byteCount 1 0x100
            , public_n    = n
            , public_e    = ex
            }
        -- Smallest k >= 1 whose bound 256^k is strictly greater than n.
        byteCount k bound
            | bound > n = k
            | otherwise = byteCount (k + 1) (bound * 0x100)

    fromASN1 _ =
        Left "fromASN1: RSA.PublicKey: unexpected format"
-- | An RSA private key: the matching public key plus the private
-- components, listed in the same order as this module's ASN.1 encoding
-- (the PKCS#1 RSAPrivateKey field order). The relationships between the
-- fields are not validated by this module.
data PrivateKey = PrivateKey
    { private_pub  :: PublicKey -- ^ matching public key (modulus, exponent, size)
    , private_d    :: Integer   -- ^ private exponent
    , private_p    :: Integer   -- ^ first prime factor of the modulus
    , private_q    :: Integer   -- ^ second prime factor of the modulus
    , private_dP   :: Integer   -- ^ CRT exponent for @p@ (per PKCS#1: @d mod (p-1)@)
    , private_dQ   :: Integer   -- ^ CRT exponent for @q@ (per PKCS#1: @d mod (q-1)@)
    , private_qinv :: Integer   -- ^ CRT coefficient (per PKCS#1: @q^-1 mod p@)
    } deriving (Show,Read,Eq,Data,Typeable)
-- | Size in bytes of the private key's modulus (its 'public_size').
private_size :: PrivateKey -> Int
private_size = public_size . private_pub

-- | The private key's public modulus.
private_n :: PrivateKey -> Integer
private_n = public_n . private_pub

-- | The private key's public exponent.
--
-- NOTE: not in this module's export list.
private_e :: PrivateKey -> Integer
private_e = public_e . private_pub
instance ASN1Object PrivateKey where
    -- Prepend the key as a version-0 SEQUENCE of nine INTEGERs, in the
    -- PKCS#1 RSAPrivateKey order: version, n, e, d, p, q, dP, dQ, qinv.
    toASN1 privKey = \xs -> Start Sequence
                          : IntVal 0
                          : IntVal (public_n $ private_pub privKey)
                          : IntVal (public_e $ private_pub privKey)
                          : IntVal (private_d privKey)
                          : IntVal (private_p privKey)
                          : IntVal (private_q privKey)
                          : IntVal (private_dP privKey)
                          : IntVal (private_dQ privKey)
                          : IntVal (private_qinv privKey)
                          : End Sequence
                          : xs
    -- Inverse of 'toASN1'. Only version 0 is accepted; any other version
    -- or shape yields 'Left'. The stream after the sequence is returned.
    fromASN1 (Start Sequence
             : IntVal 0
             : IntVal rawN
             : IntVal e
             : IntVal d
             : IntVal p1
             : IntVal p2
             : IntVal pexp1
             : IntVal pexp2
             : IntVal pcoef
             : End Sequence
             : xs) = Right (privKey, xs)
      where
        -- Normalise the modulus the same way the 'PublicKey' decoder does,
        -- so both decoders agree on a modulus encoded as a negative INTEGER
        -- (a non-negative modulus is left unchanged).
        n = toPositive rawN
        -- Smallest k >= 1 whose bound 256^k is strictly greater than n.
        byteCount k bound
            | bound > n = k
            | otherwise = byteCount (k + 1) (bound * 0x100)
        privKey = PrivateKey
            { private_pub  = PublicKey { public_size = byteCount 1 0x100
                                       , public_n    = n
                                       , public_e    = e
                                       }
            , private_d    = d
            , private_p    = p1
            , private_q    = p2
            , private_dP   = pexp1
            , private_dQ   = pexp2
            , private_qinv = pcoef
            }
    fromASN1 _ =
        Left "fromASN1: RSA.PrivateKey: unexpected format"
-- | A key pair, represented by its private key (which carries the public
-- key in 'private_pub'). Its ASN.1 encoding is that of the wrapped
-- 'PrivateKey'.
newtype KeyPair = KeyPair PrivateKey
    deriving (Show,Read,Eq,Data,Typeable)
instance ASN1Object KeyPair where
    -- Serialised exactly like the private key it wraps.
    toASN1 (KeyPair inner) = toASN1 inner
    -- Decode a 'PrivateKey' and wrap it; decoding errors pass through
    -- unchanged.
    fromASN1 stream = case fromASN1 stream of
        Left err            -> Left err
        Right (inner, rest) -> Right (KeyPair inner, rest)
-- | The public half of a key pair.
toPublicKey :: KeyPair -> PublicKey
toPublicKey kp = case kp of
    KeyPair secret -> private_pub secret
-- | Unwrap a key pair to its private key.
toPrivateKey :: KeyPair -> PrivateKey
toPrivateKey kp = let KeyPair secret = kp in secret
-- | Reinterpret a negative 'Integer' as the unsigned value of its minimal
-- big-endian two's-complement byte encoding. Non-negative values are
-- returned unchanged.
--
-- Used by the decoders on the RSA modulus, presumably to tolerate encoders
-- that wrote a positive modulus without DER's leading zero byte (so it
-- reads back as negative).
--
-- >>> toPositive (-1)
-- 255
-- >>> toPositive (-129)
-- 65407
toPositive :: Integer -> Integer
toPositive int
    | int >= 0  = int
    | otherwise = int + wrap 0x100
  where
    -- Smallest m = 256^k (k >= 1) with -m/2 <= int, i.e. int fits in k
    -- bytes of two's complement. Adding m gives those same k bytes read
    -- as an unsigned number. Total (no 'head'), and there is no lazy
    -- accumulator.
    wrap m
        | 2 * int >= negate m = m
        | otherwise           = wrap (m * 0x100)