{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE NumericUnderscores #-}

-- |
-- Module: Crypto.KDF.PBKDF
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- A pure PBKDF2 (password-based key derivation
-- function) implementation, as specified by
-- [RFC2898](https://datatracker.ietf.org/doc/html/rfc2898).

module Crypto.KDF.PBKDF (
    -- * HMAC synonym
    HMAC

    -- * PBKDF2
  , derive
  )where

import Control.Monad (guard)
import Data.Bits ((.>>.), (.&.))
import qualified Data.Bits as B
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Builder.Extra as BE
import Data.Word (Word32, Word64)

-- NB the following synonym exists only to make the haddocks more
--    readable

-- | The shape of a HMAC function: the key is the first argument, the
--   message the second, and the result is the MAC digest.
--
--   (Strictly, RFC2898 only asks for a two-argument "pseudorandom
--   function"; in practice that will usually be a HMAC function.)
--
--   For example, a SHA256 HMAC has exactly this type:
--
--   >>> import qualified Crypto.Hash.SHA256 as SHA256
--   >>> :t SHA256.hmac
--   SHA256.hmac :: BS.ByteString -> BS.ByteString -> BS.ByteString
--   >>> SHA256.hmac "my HMAC key" "my HMAC input"
--   <256-bit MAC>
type HMAC = BS.ByteString -> BS.ByteString -> BS.ByteString

-- Short alias for 'fromIntegral', used for the many integral-type
-- conversions below. Narrowing conversions between fixed-width word
-- types wrap, keeping only the low-order bits.
fi :: (Integral a, Num b) => a -> b
fi x = fromIntegral x
{-# INLINE fi #-}

-- Encode a 32-bit word as exactly four bytes, most significant byte
-- first (big-endian). 'derive' appends this to the salt as the block
-- index.
ser32 :: Word32 -> BS.ByteString
ser32 w = BS.pack (fmap octet [24, 16, 8, 0])
  where
    -- the byte located k bits above the least-significant end of w
    octet k = fromIntegral ((w .>>. k) .&. 0xff)
{-# INLINE ser32 #-}

-- Byte-by-byte exclusive-or of two bytestrings. If the inputs differ in
-- length, the result is only as long as the shorter one.
xor :: BS.ByteString -> BS.ByteString -> BS.ByteString
xor lhs rhs = BS.packZipWith B.xor lhs rhs
{-# INLINE xor #-}

-- | Derive a key from a secret via the PBKDF2 key derivation function.
--
--   Returns 'Nothing' when the parameters are invalid: an iteration
--   count of zero, a requested key length of zero, or a PRF whose output
--   is empty. (RFC2898 requires both the iteration count and the derived
--   key length to be positive integers.)
--
--   >>> :set -XOverloadedStrings
--   >>> import qualified Crypto.Hash.SHA256 as SHA256
--   >>> import qualified Data.ByteString as BS
--   >>> import qualified Data.ByteString.Base16 as B16
--   >>> let Just key = derive SHA256.hmac "passwd" "salt" 1 64
--   >>> BS.take 16 (B16.encode key)
--   "55ac046e56e3089f"
derive
  :: HMAC          -- ^ pseudo-random function (HMAC)
  -> BS.ByteString -- ^ password
  -> BS.ByteString -- ^ salt
  -> Word64        -- ^ iteration count (must be positive)
  -> Word32        -- ^ bytelength of derived key (positive; max 0xffff_ffff * hlen)
  -> Maybe BS.ByteString -- ^ derived key
derive prf p s c dklen = do
    -- With c == 0 or dklen == 0, the counters in 'f' and 'loop' would
    -- not reach their stop values until they wrapped around, which
    -- effectively never terminates. With hlen == 0 the block count is
    -- undefined.
    guard (c > 0 && dklen > 0 && hlen > 0)
    -- RFC2898 5.2 step 1: dkLen <= (2^32 - 1) * hLen. The bound is
    -- computed in Word64 so it cannot wrap around as it would in Word32.
    guard (fi dklen <= (0xffff_ffff * fi hlen :: Word64))
    pure (loop mempty 1)
  where
    -- digest length of the PRF, probed with an empty key and input
    !hlen = BS.length (prf mempty mempty)

    -- Number of hlen-byte blocks, i.e. ceiling (dklen / hlen). This uses
    -- integer division in Word64, so there is no floating point and no
    -- Word32 overflow. It is deliberately lazy: it must not be forced
    -- before the guards have established hlen > 0.
    l :: Word32
    l = let d = fi dklen :: Word64
            h = fi hlen  :: Word64
        in  fi ((d + h - 1) `quot` h)

    -- Bytes kept from the final block. Because (l - 1) * hlen < dklen,
    -- this Word32 arithmetic does not wrap.
    r :: Int
    r = fi (dklen - (l - 1) * fi hlen)

    -- T_i = U_1 xor U_2 xor ... xor U_c,
    -- where U_1 = PRF(P, S || INT(i)) and U_j = PRF(P, U_{j-1})
    f :: Word32 -> BS.ByteString
    f !i =
      let go :: Word64 -> BS.ByteString -> BS.ByteString -> BS.ByteString
          go j !acc !las
            | j == c    = acc
            | otherwise =
                let u    = prf p las
                    nacc = acc `xor` u
                in  go (j + 1) nacc u

          org = prf p (s <> ser32 i)

      in  go 1 org org
    {-# INLINE f #-}

    -- concatenate T_1 .. T_l, truncating T_l to r bytes
    loop :: BSB.Builder -> Word32 -> BS.ByteString
    loop !acc !i
      | i == l =
          let fin = BS.take r (f i)
              out = acc <> BSB.byteString fin
          in  BS.toStrict $
                if   dklen <= 128
                then BE.toLazyByteStringWith
                       (BE.safeStrategy 128 BE.smallChunkSize) mempty out
                else BSB.toLazyByteString out
      | otherwise =
          loop (acc <> BSB.byteString (f i)) (i + 1)