{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: Crypto.KDF.HMAC
-- Copyright: (c) 2024 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- A pure HKDF implementation, as specified by
-- [RFC5869](https://datatracker.ietf.org/doc/html/rfc5869).

module Crypto.KDF.HMAC (
    -- * HMAC synonym
    HMAC

    -- * HMAC-based KDF
  , derive

    -- internals
  , extract
  , expand
  , HMACEnv
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Internal as BI
import Data.Word (Word64)

-- Terse alias for 'fromIntegral', used for the numeric conversions in
-- this module (e.g. a block counter to a 'Word8', a 'Word64' length to
-- an 'Int'). Note that, like 'fromIntegral', it wraps silently when the
-- target type is too narrow, so callers must bounds-check first.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}

-- NB following synonym really only exists to make haddocks more
--    readable

-- | A HMAC function, taking a key as the first argument and the input
--   value as the second, producing a MAC digest.
--
--   'derive' determines the digest length (/hashlen/) by applying the
--   supplied function to an empty key and an empty input and measuring
--   the result.
--
--   >>> import qualified Crypto.Hash.SHA256 as SHA256
--   >>> :t SHA256.hmac
--   SHA256.hmac :: BS.ByteString -> BS.ByteString -> BS.ByteString
--   >>> SHA256.hmac "my HMAC key" "my HMAC input"
--   <256-bit MAC>
type HMAC = BS.ByteString -> BS.ByteString -> BS.ByteString

-- HMAC function and its associated outlength
--
-- The first field is the HMAC function itself. The second field is the
-- bytelength of the digest it produces (/hashlen/); 'extract' uses it
-- to size the all-zero default salt, and 'expand' uses it to bound the
-- requested output length and to compute the number of blocks.
--
-- Both fields are strict, and the length is unpacked into the
-- constructor.
data HMACEnv = HMACEnv
                 !HMAC
  {-# UNPACK #-} !Int

-- HKDF-Extract: condense the input keying material into a pseudorandom
-- key by MAC'ing it under the salt.
--
-- An empty salt is replaced by /hashlen/ zero bytes before being used
-- as the HMAC key.
extract
  :: HMACEnv
  -> BS.ByteString  -- ^ salt
  -> BS.ByteString  -- ^ input keying material
  -> BS.ByteString  -- ^ pseudorandom key
extract (HMACEnv hmac hashlen) salt@(BI.PS _ _ l) ikm
  -- the salt's length is read directly from the 'BI.PS' pattern
  | l == 0    = hmac (BS.replicate hashlen 0x00) ikm
  | otherwise = hmac salt ikm
{-# INLINE extract #-}

-- HKDF-Expand: stretch a pseudorandom key into the requested number of
-- bytes of output keying material.
--
-- Blocks are produced iteratively as
--
--   T(j) = hmac prk (T(j - 1) <> info <> byte j),   T(0) = empty
--
-- for j = 1 .. n, where n = ceiling (len / hashlen); the concatenation
-- of the blocks is then truncated to the requested length.
--
-- Returns 'Nothing' when the requested length exceeds 255 * hashlen.
expand
  :: HMACEnv
  -> BS.ByteString  -- ^ optional context and application-specific info
  -> Word64         -- ^ bytelength of output keying material
  -> BS.ByteString  -- ^ pseudorandom key
  -> Maybe BS.ByteString  -- ^ output keying material
expand (HMACEnv hmac hashlen) info outlen prk
    -- The bound is checked at 'Integer' width, *before* narrowing the
    -- 'Word64' to an 'Int'. Narrowing first would wrap requests of 2^63
    -- bytes or more to a non-positive 'Int', which would slip past the
    -- check and yield an empty key instead of 'Nothing'.
    | toInteger outlen > 255 * toInteger hashlen = Nothing
    | otherwise = pure (BS.take len (go (1 :: Int) mempty mempty))
  where
    -- only forced once the guard above has bounded outlen by
    -- 255 * hashlen, so the narrowing cannot wrap for any realistic
    -- digest length
    len :: Int
    len = fi outlen

    -- number of blocks required: ceiling (len / hashlen), computed in
    -- integer arithmetic. A non-positive hashlen can only reach this
    -- point with len == 0, for which no blocks are needed.
    n :: Int
    n | hashlen <= 0 = 0
      | otherwise    = (len + hashlen - 1) `quot` hashlen

    -- j: 1-based block counter; acc: blocks accumulated so far;
    -- prev: the previous block, T(j - 1)
    go :: Int -> BSB.Builder -> BS.ByteString -> BS.ByteString
    go !j acc !prev
      | j > n = BS.toStrict (BSB.toLazyByteString acc)
      | otherwise =
          -- j <= n <= 255 here, so the counter fits in a single byte
          let block = hmac prk (prev <> info <> BS.singleton (fi j))
          in  go (succ j) (acc <> BSB.byteString block) block
{-# INLINE expand #-}

-- | Derive a key from a secret, via a HMAC-based key derivation
--   function.
--
--   The /salt/ and /info/ arguments are optional to the KDF, and may
--   be simply passed as 'mempty'. An empty salt will be replaced by
--   /hashlen/ zero bytes.
--
--   Returns 'Nothing' if the requested output length exceeds
--   255 * /hashlen/.
--
--   >>> import qualified Crypto.Hash.SHA256 as SHA256
--   >>> derive SHA256.hmac "my public salt" mempty 64 "my secret input"
--   <64-byte output keying material>
derive
  :: HMAC          -- ^ HMAC function
  -> BS.ByteString -- ^ salt
  -> BS.ByteString -- ^ optional context and application-specific info
  -> Word64        -- ^ bytelength of output keying material (<= 255 * hashlen)
  -> BS.ByteString -- ^ input keying material
  -> Maybe BS.ByteString -- ^ output keying material
derive hmac salt info len = expand env info len . extract env salt
  where
    -- /hashlen/ is measured by MAC'ing an empty input under an empty key
    env :: HMACEnv
    env = HMACEnv hmac (BS.length (hmac mempty mempty))