{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UnliftedNewtypes #-}

-- |
-- Module: Crypto.Hash.SHA256
-- Copyright: (c) 2024 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- SHA-256 and HMAC-SHA256 implementations for
-- strict and lazy ByteStrings, as specified by RFCs
-- [6234](https://datatracker.ietf.org/doc/html/rfc6234) and
-- [2104](https://datatracker.ietf.org/doc/html/rfc2104).
--
-- The 'hash' and 'hmac' functions will use primitive instructions from
-- the ARM cryptographic extensions via FFI if they're available, and
-- will otherwise use a pure Haskell implementation.

module Crypto.Hash.SHA256 (
  -- * SHA-256 message digest functions
    hash
  , Lazy.hash_lazy

  -- * SHA256-based MAC functions
  , MAC(..)
  , hmac
  , Lazy.hmac_lazy
  ) where

import qualified Data.Bits as B
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import Data.Word (Word64)
import Crypto.Hash.SHA256.Arm
import Crypto.Hash.SHA256.Internal
import qualified Crypto.Hash.SHA256.Lazy as Lazy

-- utils ---------------------------------------------------------------------

-- | Convert between integral and numeric types; a terse alias for
--   'fromIntegral' used for length/offset conversions in this module.
{-# INLINE fi #-}
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral

-- hash ----------------------------------------------------------------------

-- | Compute a condensed representation of a strict bytestring via
--   SHA-256.
--
--   The 256-bit output digest is returned as a strict bytestring.
--   When 'sha256_arm_available' holds, the digest is computed by the
--   FFI-backed ARM implementation; otherwise the message is run through
--   the pure Haskell compression pipeline.
--
--   >>> hash "strict bytestring input"
--   "<strict 256-bit message digest>"
hash :: BS.ByteString -> BS.ByteString
hash msg =
  if sha256_arm_available
    then hash_arm msg
    else cat (process msg)

-- | Hash a message starting from a given register state.
--
--   Every complete 64-byte block of the message is compressed into the
--   running state, left to right. The trailing partial block (possibly
--   empty) is then padded and compressed as one or two final blocks.
--
--   The 'Word64' argument is a byte count added to the message's own
--   length when building the padding. It accounts for data already
--   absorbed into the initial state, e.g. 64 when 'hmac' has compressed
--   its inner-pad block beforehand.
process_with :: Registers -> Word64 -> BS.ByteString -> Registers
process_with acc0 el m@(BI.PS _ _ l) = finish (absorb acc0 0)
  where
    -- compress each full 64-byte block of the message, starting at
    -- byte offset 'off'; stop once fewer than 64 bytes remain
    absorb :: Registers -> Int -> Registers
    absorb !regs !off
      | off + 64 > l = regs
      | otherwise    =
          absorb (block_hash regs (parse_block m off)) (off + 64)

    -- pad the leftover bytes and compress the one or two resulting
    -- blocks. Per RFC 6234 the padding appends a 0x80 byte and a 64-bit
    -- length, so a tail of 56 or more bytes no longer fits in a single
    -- block (presumably 'unsafe_padding' returns 128 bytes in that case;
    -- verify against Crypto.Hash.SHA256.Internal)
    finish :: Registers -> Registers
    finish !regs
        | tlen >= 56 =
            block_hash (block_hash regs (parse_block padded 0))
                       (parse_block padded 64)
        | otherwise  = block_hash regs (parse_block padded 0)
      where
        !trailing@(BI.PS _ _ tlen) = BU.unsafeDrop (l - l `rem` 64) m
        !padded = unsafe_padding trailing (el + fi l)

-- | Run a complete message through the pure SHA-256 pipeline, starting
--   from the initial register state 'iv' with no previously absorbed
--   bytes.
process :: BS.ByteString -> Registers
process msg = process_with (iv ()) 0 msg

-- hmac ----------------------------------------------------------------------

-- | An HMAC key paired with its length in bytes, as prepared by 'hmac'
--   (either the caller's key or its 32-byte digest). Both fields are
--   strict and unpacked.
data KeyAndLen =
  KeyAndLen {-# UNPACK #-} !BS.ByteString {-# UNPACK #-} !Int

-- | Produce a message authentication code for a strict bytestring,
--   based on the provided (strict, bytestring) key, via SHA-256.
--
--   The 256-bit MAC is returned as a strict bytestring wrapped in 'MAC'.
--
--   Per RFC 2104, the key /should/ be a minimum of 32 bytes long. Keys
--   exceeding 64 bytes in length will first be hashed (via SHA-256).
--
--   >>> hmac "strict bytestring key" "strict bytestring input"
--   "<strict 256-bit MAC>"
hmac
  :: BS.ByteString -- ^ key
  -> BS.ByteString -- ^ text
  -> MAC
hmac key@(BI.PS _ _ keylen) text =
    if sha256_arm_available
      then
        -- NOTE(review): 'hash_arm_with' is assumed to compute
        -- H(ipad <> text), with 64 bytes counted as already absorbed;
        -- verify against Crypto.Hash.SHA256.Arm
        let !inner = hash_arm_with ipad 64 text
        in  MAC (hash_arm (opad <> inner))
      else
        -- compress the inner-pad block once, then continue over the
        -- text with 64 bytes already accounted for: H(ipad <> text)
        let !state = block_hash (iv ()) (parse_block ipad 0)
            !inner = cat (process_with state 64 text)
        in  MAC (hash (opad <> inner))
  where
    -- keys longer than the 64-byte block are replaced by their digest
    !(KeyAndLen k klen)
      | keylen > 64 = KeyAndLen (hash key) 32
      | otherwise   = KeyAndLen key keylen
    -- right-pad the key with zero bytes to exactly one block
    !block_key = k <> BS.replicate (64 - klen) 0x00
    -- RFC 2104 inner (0x36) and outer (0x5C) pads
    !ipad      = BS.map (B.xor 0x36) block_key
    !opad      = BS.map (B.xor 0x5C) block_key