{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module: Lightning.Protocol.BOLT4.Prim
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Low-level cryptographic primitives for BOLT4 onion routing.

module Lightning.Protocol.BOLT4.Prim (
    -- * Types
    SharedSecret(..)
  , DerivedKey(..)
  , BlindingFactor(..)

    -- * Key derivation
  , deriveRho
  , deriveMu
  , deriveUm
  , derivePad
  , deriveAmmag

    -- * Shared secret computation
  , computeSharedSecret

    -- * Blinding factor computation
  , computeBlindingFactor

    -- * Key blinding
  , blindPubKey
  , blindSecKey

    -- * Stream generation
  , generateStream

    -- * HMAC operations
  , computeHmac
  , verifyHmac
  ) where

import qualified Crypto.Cipher.ChaCha20 as ChaCha
import qualified Crypto.Curve.Secp256k1 as Secp256k1
import qualified Crypto.Hash.SHA256 as SHA256
import Data.Bits (xor, (.|.))
import qualified Data.ByteString as BS
import qualified Data.List as L
import Data.Word (Word8, Word32)
import qualified Numeric.Montgomery.Secp256k1.Scalar as S

-- | Shared secret obtained from ECDH; 'computeSharedSecret' produces it as
-- a 32-byte SHA256 digest.
--
-- NOTE(review): the derived 'Show' instance prints the raw secret bytes,
-- so values of this type should not be logged.
newtype SharedSecret = SharedSecret BS.ByteString
  deriving (Show, Eq)

-- | Key derived from a 'SharedSecret' via HMAC-SHA256 (one of rho, mu, um,
-- pad or ammag); 32 bytes when produced by the @derive*@ functions.
newtype DerivedKey = DerivedKey BS.ByteString
  deriving (Show, Eq)

-- | Blinding factor used to update the ephemeral key from hop to hop; a
-- 32-byte SHA256 digest when produced by 'computeBlindingFactor'.
newtype BlindingFactor = BlindingFactor BS.ByteString
  deriving (Show, Eq)

-- Key derivation ------------------------------------------------------------

-- | Derive the @rho@ key from a shared secret; it keys the obfuscation
-- stream produced by 'generateStream'.
--
-- @rho = HMAC-SHA256(key="rho", data=shared_secret)@
deriveRho :: SharedSecret -> DerivedKey
deriveRho = deriveKey "rho"
{-# INLINE deriveRho #-}

-- | Derive the @mu@ key from a shared secret; it is the HMAC key used by
-- 'computeHmac' for packet integrity.
--
-- @mu = HMAC-SHA256(key="mu", data=shared_secret)@
deriveMu :: SharedSecret -> DerivedKey
deriveMu = deriveKey "mu"
{-# INLINE deriveMu #-}

-- | Derive the @um@ key from a shared secret; it keys the HMAC on returned
-- error messages.
--
-- @um = HMAC-SHA256(key="um", data=shared_secret)@
deriveUm :: SharedSecret -> DerivedKey
deriveUm = deriveKey "um"
{-# INLINE deriveUm #-}

-- | Derive the @pad@ key from a shared secret; it is used when generating
-- filler bytes.
--
-- @pad = HMAC-SHA256(key="pad", data=shared_secret)@
derivePad :: SharedSecret -> DerivedKey
derivePad = deriveKey "pad"
{-# INLINE derivePad #-}

-- | Derive the @ammag@ key from a shared secret; it keys the stream (see
-- 'generateStream') used to obfuscate returned errors.
--
-- @ammag = HMAC-SHA256(key="ammag", data=shared_secret)@
deriveAmmag :: SharedSecret -> DerivedKey
deriveAmmag = deriveKey "ammag"
{-# INLINE deriveAmmag #-}

-- Internal helper shared by the @derive*@ functions.
--
-- Computes HMAC-SHA256 using the key-type label (e.g. @"rho"@) as the HMAC
-- key and the shared secret as the message, and wraps the tag as a
-- 'DerivedKey'.
deriveKey :: BS.ByteString -> SharedSecret -> DerivedKey
deriveKey !label (SharedSecret !secret) =
  case SHA256.hmac label secret of
    SHA256.MAC !tag -> DerivedKey tag
{-# INLINE deriveKey #-}

-- Shared secret computation -------------------------------------------------

-- | Compute the ECDH shared secret between a secret key and a public key.
--
-- The secret key is parsed with 'Secp256k1.roll32' and multiplied into the
-- public key. The resulting point is serialized (compressed, 33 bytes) and
-- hashed with SHA256.
--
-- Returns 'Nothing' if the secret key cannot be parsed or
-- 'Secp256k1.mul' rejects it.
computeSharedSecret
  :: BS.ByteString         -- ^ 32-byte secret key
  -> Secp256k1.Projective  -- ^ public key
  -> Maybe SharedSecret
computeSharedSecret !secBs !pub =
  case Secp256k1.roll32 secBs >>= Secp256k1.mul pub of
    Nothing -> Nothing
    Just point ->
      -- force the serialization and digest before wrapping in 'Just'
      let !encoded = Secp256k1.serialize_point point
          !digest  = SHA256.hash encoded
      in  Just (SharedSecret digest)
{-# INLINE computeSharedSecret #-}

-- Blinding factor -----------------------------------------------------------

-- | Compute the blinding factor used to update the ephemeral key.
--
-- @blinding_factor = SHA256(ephemeral_pubkey || shared_secret)@
--
-- The ephemeral public key is serialized with 'Secp256k1.serialize_point'
-- before being concatenated with the shared secret.
computeBlindingFactor
  :: Secp256k1.Projective  -- ^ ephemeral public key
  -> SharedSecret          -- ^ shared secret
  -> BlindingFactor
computeBlindingFactor !ephemeral (SharedSecret !secret) =
  let !preimage = BS.append (Secp256k1.serialize_point ephemeral) secret
      !digest   = SHA256.hash preimage
  in  BlindingFactor digest
{-# INLINE computeBlindingFactor #-}

-- Key blinding --------------------------------------------------------------

-- | Blind a public key by scalar-multiplying it with a blinding factor.
--
-- @new_pubkey = pubkey * blinding_factor@
--
-- Returns 'Nothing' if the blinding factor cannot be parsed by
-- 'Secp256k1.roll32' or is rejected by 'Secp256k1.mul'.
blindPubKey
  :: Secp256k1.Projective        -- ^ public key to blind
  -> BlindingFactor              -- ^ blinding factor
  -> Maybe Secp256k1.Projective  -- ^ blinded public key
blindPubKey !pub (BlindingFactor !bf) =
  Secp256k1.roll32 bf >>= Secp256k1.mul pub
{-# INLINE blindPubKey #-}

-- | Blind a secret key by multiplying it with a blinding factor modulo the
-- curve order.
--
-- @new_seckey = seckey * blinding_factor (mod q)@
--
-- Uses Montgomery multiplication from ppad-fixed. Takes a 32-byte secret key
-- and returns a 32-byte blinded secret key.
--
-- Returns 'Nothing' when:
--
--   * either input is not exactly 32 bytes, or
--   * the product is zero. An all-zero scalar is not a usable secret key,
--     and this matches how 'blindPubKey' and 'computeSharedSecret' reject
--     invalid scalars with 'Nothing'.
blindSecKey
  :: BS.ByteString        -- ^ 32-byte secret key
  -> BlindingFactor       -- ^ blinding factor
  -> Maybe BS.ByteString  -- ^ 32-byte blinded secret key
blindSecKey !secBs (BlindingFactor !bf)
  | BS.length secBs /= 32 = Nothing
  | BS.length bf /= 32 = Nothing
  | otherwise =
      -- lengths are checked above, so the unchecked parses are safe here
      let !secM    = S.to (Secp256k1.unsafe_roll32 secBs)
          !bfM     = S.to (Secp256k1.unsafe_roll32 bf)
          !resultM = S.mul secM bfM
          !out     = Secp256k1.unroll32 (S.retr resultM)
      in  if BS.all (== 0) out
            then Nothing  -- zero product: not a valid secret key
            else Just out
{-# INLINE blindSecKey #-}

-- Stream generation ---------------------------------------------------------

-- | Generate a pseudo-random byte stream of the requested length using
-- ChaCha20.
--
-- The derived key is the ChaCha20 key. The nonce is 12 zero bytes (96 bits)
-- and the initial block counter is 0. The keystream is obtained by
-- encrypting @len@ zero bytes.
--
-- NOTE(review): if 'ChaCha.cipher' returns an error (presumably for a
-- malformed key or nonce), this silently returns @len@ zero bytes. XORing
-- with that result leaves the data unobfuscated, and callers cannot tell
-- it apart from a real stream.
generateStream
  :: DerivedKey     -- ^ rho or ammag key
  -> Int            -- ^ desired length
  -> BS.ByteString
generateStream (DerivedKey !key) !len =
  case ChaCha.cipher key (0 :: Word32) nonce plaintext of
    Left _       -> BS.replicate len 0
    Right stream -> stream
  where
    nonce     = BS.replicate 12 0
    plaintext = BS.replicate len 0
{-# INLINE generateStream #-}

-- HMAC operations -----------------------------------------------------------

-- | Compute the HMAC-SHA256 tag used for packet integrity.
--
-- @hmac = HMAC-SHA256(key=mu, data=hop_payloads || associated_data)@
computeHmac
  :: DerivedKey      -- ^ mu key
  -> BS.ByteString   -- ^ hop_payloads
  -> BS.ByteString   -- ^ associated_data
  -> BS.ByteString   -- ^ 32-byte HMAC
computeHmac (DerivedKey !key) !payloads !assocData =
  case SHA256.hmac key (BS.append payloads assocData) of
    SHA256.MAC !tag -> tag
{-# INLINE computeHmac #-}

-- | Compare an expected HMAC against a computed one.
--
-- Inputs of different lengths are rejected up front, so lengths are not
-- treated as secret. Equal-length inputs are compared byte-by-byte with
-- 'constantTimeEq'.
verifyHmac
  :: BS.ByteString  -- ^ expected
  -> BS.ByteString  -- ^ computed
  -> Bool
verifyHmac !expected !computed =
  BS.length expected == BS.length computed
    && constantTimeEq expected computed
{-# INLINE verifyHmac #-}

-- Equality comparison of two byte strings that does not exit early on the
-- first mismatching byte.
--
-- Inputs of different lengths are unequal; lengths are not treated as
-- secret. For equal lengths, the XOR of every byte pair is OR-ed into an
-- accumulator, which stays zero only if every pair matches.
--
-- OR (not XOR) is required when accumulating. With XOR, differences at
-- different positions can cancel, e.g. [1,1] vs [0,0]. The length check is
-- required because 'BS.zip' truncates to the shorter input.
constantTimeEq :: BS.ByteString -> BS.ByteString -> Bool
constantTimeEq !a !b
  | BS.length a /= BS.length b = False
  | otherwise =
      let step acc (x, y) = acc .|. (x `xor` y)
          !diff = L.foldl' step (0 :: Word8) (BS.zip a b)
      in  diff == 0
{-# INLINE constantTimeEq #-}