{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: Crypto.Hash.SHA512.Lazy
-- Copyright: (c) 2024 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Pure SHA-512 and HMAC-SHA512 implementations for lazy ByteStrings,
-- as specified by RFCs
-- [6234](https://datatracker.ietf.org/doc/html/rfc6234) and
-- [2104](https://datatracker.ietf.org/doc/html/rfc2104).

module Crypto.Hash.SHA512.Lazy (
  -- * SHA-512 message digest functions
    hash_lazy

  -- * SHA512-based MAC functions
  , hmac_lazy
  ) where

import Crypto.Hash.SHA512.Internal
import qualified Data.Bits as B
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Builder.Extra as BE
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Internal as BLI
import Data.Word (Word64)
import Foreign.ForeignPtr (plusForeignPtr)

-- preliminary utils

-- keystroke saver: a short alias for 'fromIntegral', converting any
-- integral value into any numeric type
fi :: (Integral a, Num b) => a -> b
fi x = fromIntegral x
{-# INLINE fi #-}

-- utility types for more efficient ByteString management

-- strict pair of two strict ByteStrings; both fields are strict and
-- marked UNPACK so they can be stored inline in the constructor
data SSPair =
  SSPair {-# UNPACK #-} !BS.ByteString {-# UNPACK #-} !BS.ByteString

-- strict pair of a strict ByteString (unpacked) and a lazy ByteString
data SLPair = SLPair
  {-# UNPACK #-} !BS.ByteString
  !BL.ByteString

-- split a strict ByteString at byte offset n with no bounds checking.
-- neither half is copied: the prefix reuses the original pointer with
-- a shorter length, and the suffix points n bytes further into the same
-- buffer.
--
-- invariant (the caller must guarantee it):
--   0 <= n <= l, where l is the length of the input
unsafe_splitAt :: Int -> BS.ByteString -> SSPair
unsafe_splitAt n (BI.BS fp len) = SSPair prefix suffix
  where
    prefix = BI.BS fp n
    suffix = BI.BS (plusForeignPtr fp n) (len - n)

-- variant of Data.ByteString.Lazy.splitAt 128 that returns the initial
-- component as a strict ByteString (held unpacked in SLPair).
--
-- when the first 128 bytes span several chunks, those chunks are
-- appended into one strict ByteString; when fewer than 128 bytes are
-- available, the entire input becomes the strict half and the lazy half
-- is empty.
splitAt128 :: BL.ByteString -> SLPair
splitAt128 = take_from (128 :: Int) where
  take_from :: Int -> BL.ByteString -> SLPair
  take_from _ BLI.Empty = SLPair mempty BLI.Empty
  take_from want (BLI.Chunk chunk@(BI.PS _ _ have) rest)
    | want < have =
        -- want < BS.length chunk, so unsafe_splitAt's invariant holds
        -- and the leftover piece of this chunk is non-empty
        case unsafe_splitAt want chunk of
          SSPair front back -> SLPair front (BLI.Chunk back rest)
    | otherwise =
        -- this whole chunk belongs to the prefix; collect the remaining
        -- (want - have) bytes from the chunks that follow
        case take_from (want - have) rest of
          SLPair more remainder -> SLPair (chunk <> more) remainder

-- builder realization strategies

-- realize a Builder as one strict ByteString: run it to a lazy
-- ByteString, then flatten the resulting chunks
to_strict :: BSB.Builder -> BS.ByteString
to_strict builder = BL.toStrict (BSB.toLazyByteString builder)

-- message padding and parsing (SHA-384/SHA-512 variant)
-- https://datatracker.ietf.org/doc/html/rfc6234#section-4.2

-- k such that (l + 1 + k) mod 128 = 112, where l is a message length in
-- bytes: the number of zero bytes that follow the 0x80 marker byte so
-- that exactly 16 bytes remain in the final 128-byte block. the result
-- is always in [0, 127].
--
-- Word64 subtraction wraps modulo 2^64, which is a multiple of 128, so
-- reducing the (possibly wrapped) difference mod 128 yields the same
-- value as exact integer arithmetic.
sol :: Word64 -> Word64
sol l = (111 - l) `mod` 128

-- RFC 6234 4.2 (lazy)
--
-- append SHA-512 padding to a lazy message. the input chunks are passed
-- through unchanged while their total byte length is counted. once they
-- are exhausted, a single strict chunk is emitted containing:
--
--   * the 0x80 marker byte,
--   * 'sol l' zero bytes,
--   * the message length in bits as a 128-bit big-endian integer.
--
-- the padded length is therefore a multiple of 128 bytes.
pad_lazy :: BL.ByteString -> BL.ByteString
pad_lazy (BL.toChunks -> m) = BL.fromChunks (walk 0 m) where
  -- l is the running byte count of the chunks emitted so far
  walk :: Word64 -> [BS.ByteString] -> [BS.ByteString]
  walk !l bs = case bs of
    (c:cs) -> c : walk (l + fi (BS.length c)) cs
    []     -> [padding l]

  -- the bit length 8 * l needs up to 67 bits, so it is split across the
  -- high and low 64-bit words of the length field. a plain l * 8 with a
  -- hard-coded zero high word would silently wrap for messages of 2^61
  -- bytes or more.
  padding :: Word64 -> BS.ByteString
  padding l = to_strict $
       BSB.word8 0x80
    <> BSB.byteString (BS.replicate (fi (sol l)) 0x00)
    <> BSB.word64BE (l `B.shiftR` 61)
    <> BSB.word64BE (l `B.shiftL` 3)

-- | Compute a condensed representation of a lazy bytestring via
--   SHA-512.
--
--   The 512-bit output digest is returned as a strict bytestring.
--
--   >>> hash_lazy "lazy bytestring input"
--   "<strict 512-bit message digest>"
hash_lazy :: BL.ByteString -> BS.ByteString
hash_lazy bl = cat final
  where
    -- pad the message, then feed it to the compression function in
    -- 128-byte pieces split off by splitAt128, starting from the
    -- initial register state
    final = absorb iv (pad_lazy bl)

    absorb :: Registers -> BL.ByteString -> Registers
    absorb !regs remaining
      | BL.null remaining = regs
      | otherwise = case splitAt128 remaining of
          SLPair block rest -> absorb (unsafe_hash_alg regs block) rest

-- HMAC -----------------------------------------------------------------------
-- https://datatracker.ietf.org/doc/html/rfc2104#section-2

-- strict (key, key length) pair; both fields are unpacked
data KeyAndLen =
  KeyAndLen {-# UNPACK #-} !BS.ByteString {-# UNPACK #-} !Int

-- | Produce a message authentication code for a lazy bytestring, based
--   on the provided (strict, bytestring) key, via SHA-512.
--
--   The 512-bit MAC is returned as a strict bytestring.
--
--   Per RFC 2104, the key /should/ be a minimum of 64 bytes long. Keys
--   exceeding 128 bytes in length will first be hashed (via SHA-512).
--
--   >>> hmac_lazy "strict bytestring key" "lazy bytestring input"
--   "<strict 512-bit MAC>"
hmac_lazy
  :: BS.ByteString -- ^ key
  -> BL.ByteString -- ^ text
  -> MAC
hmac_lazy mk@(BI.PS _ _ l) text = MAC (sha512_strict outer_input)
  where
    -- bring the key to exactly one 128-byte block: keys longer than 128
    -- bytes are replaced by their 64-byte digest, then zero-filled
    !(KeyAndLen k lk)
      | l > 128   = KeyAndLen (sha512_strict mk) 64
      | otherwise = KeyAndLen mk l

    block_key :: BS.ByteString
    block_key = k <> BS.replicate (128 - lk) 0x00

    -- H((K xor 0x36..) || text), hashed lazily so text is streamed
    inner_digest :: BS.ByteString
    inner_digest =
      hash_lazy (BL.fromStrict (BS.map (B.xor 0x36) block_key) <> text)

    -- (K xor 0x5C..) || inner digest, hashed strictly below
    outer_input :: BS.ByteString
    outer_input = BS.map (B.xor 0x5C) block_key <> inner_digest

    -- SHA-512 over a strict ByteString, used for the outer message and
    -- for over-long keys
    sha512_strict :: BS.ByteString -> BS.ByteString
    sha512_strict bs = cat (absorb iv (pad_strict bs))

    -- compress the padded input one 128-byte block at a time.
    -- pad_strict's output length is a multiple of 128, so a non-empty
    -- buffer holds at least 128 bytes and unsafe_splitAt's invariant
    -- (0 <= n <= length) is met
    absorb :: Registers -> BS.ByteString -> Registers
    absorb !regs buf
      | BS.null buf = regs
      | otherwise = case unsafe_splitAt 128 buf of
          SSPair block rest -> absorb (unsafe_hash_alg regs block) rest

    -- message || 0x80 || (sol len) zero bytes || 128-bit big-endian
    -- bit length, whose high 64 bits are written as zero
    pad_strict :: BS.ByteString -> BS.ByteString
    pad_strict m@(BI.PS _ _ (fi -> len))
        | len < 256 = realize_small padded
        | otherwise = to_strict padded
      where
        padded :: BSB.Builder
        padded = BSB.byteString m
              <> BSB.word8 0x80
              <> BSB.byteString (BS.replicate (fi (sol len)) 0x00)
              <> BSB.word64BE 0x00
              <> BSB.word64BE (len * 8)

        -- builds into a 256-byte first buffer; the usual outer message
        -- (128-byte padded key + 64-byte digest = 192 bytes) pads to
        -- exactly 256 bytes
        realize_small :: BSB.Builder -> BS.ByteString
        realize_small =
            BL.toStrict
          . BE.toLazyByteStringWith
              (BE.safeStrategy 256 BE.smallChunkSize) mempty