{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module: Lightning.Protocol.BOLT4.Blinding
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Route blinding for BOLT4 onion routing.

module Lightning.Protocol.BOLT4.Blinding (
    -- * Types
    BlindedPath(..)
  , BlindedHop(..)
  , BlindedHopData(..)
  , PaymentRelay(..)
  , PaymentConstraints(..)
  , BlindingError(..)

    -- * Path creation
  , createBlindedPath

    -- * Hop processing
  , processBlindedHop

    -- * Key derivation (exported for testing)
  , deriveBlindingRho
  , deriveBlindedNodeId
  , nextEphemeral

    -- * TLV encoding (exported for testing)
  , encodeBlindedHopData
  , decodeBlindedHopData

    -- * Encryption (exported for testing)
  , encryptHopData
  , decryptHopData
  ) where

import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
import qualified Crypto.Curve.Secp256k1 as Secp256k1
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as B
import Data.Word (Word16, Word32, Word64)
import qualified Numeric.Montgomery.Secp256k1.Scalar as S
import Lightning.Protocol.BOLT4.Codec
  ( encodeShortChannelId, decodeShortChannelId
  , encodeTlvStream, decodeTlvStream
  , toStrict, word16BE, word32BE
  , encodeWord64TU, decodeWord64TU
  , encodeWord32TU, decodeWord32TU
  )
import Lightning.Protocol.BOLT4.Prim (SharedSecret(..), DerivedKey(..))
import Lightning.Protocol.BOLT4.Types (ShortChannelId(..), TlvRecord(..))

-- Types ---------------------------------------------------------------------

-- | A blinded route provided by recipient.
--
-- Built by 'createBlindedPath': the introduction node is the public key of
-- the first entry in the input node list, the blinding key is the public
-- key derived from the seed, and there is one 'BlindedHop' per input node,
-- in input order.
data BlindedPath = BlindedPath
  { bpIntroductionNode :: !Secp256k1.Projective  -- ^ First node (unblinded)
  , bpBlindingKey      :: !Secp256k1.Projective  -- ^ E_0, initial ephemeral
  , bpBlindedHops      :: ![BlindedHop]          -- ^ One hop per node, in order
  } deriving (Eq, Show)

-- | A single hop in a blinded path.
--
-- Pairs the blinded node identifier (as produced by 'deriveBlindedNodeId')
-- with the output of 'encryptHopData' for that hop.
data BlindedHop = BlindedHop
  { bhBlindedNodeId :: !BS.ByteString  -- ^ 33 bytes, blinded pubkey
  , bhEncryptedData :: !BS.ByteString  -- ^ Encrypted routing data
  } deriving (Eq, Show)

-- | Data encrypted for each blinded hop (before encryption).
--
-- Every field is optional and corresponds to one TLV record type; a field
-- that is 'Nothing' is simply omitted by 'encodeBlindedHopData'.
data BlindedHopData = BlindedHopData
  { bhdPadding             :: !(Maybe BS.ByteString)  -- ^ TLV 1
  , bhdShortChannelId      :: !(Maybe ShortChannelId) -- ^ TLV 2
  , bhdNextNodeId          :: !(Maybe BS.ByteString)  -- ^ TLV 4, 33-byte pubkey
  , bhdPathId              :: !(Maybe BS.ByteString)  -- ^ TLV 6
  , bhdNextPathKeyOverride :: !(Maybe BS.ByteString)  -- ^ TLV 8
  , bhdPaymentRelay        :: !(Maybe PaymentRelay)   -- ^ TLV 10
  , bhdPaymentConstraints  :: !(Maybe PaymentConstraints) -- ^ TLV 12
  , bhdAllowedFeatures     :: !(Maybe BS.ByteString)  -- ^ TLV 14
  } deriving (Eq, Show)

-- | Payment relay parameters (TLV 10).
--
-- Serialized by 'encodePaymentRelay' in field order: CLTV delta, then
-- proportional fee, then base fee.
data PaymentRelay = PaymentRelay
  { prCltvExpiryDelta  :: {-# UNPACK #-} !Word16  -- ^ CLTV expiry delta
  , prFeeProportional  :: {-# UNPACK #-} !Word32  -- ^ Fee in millionths
  , prFeeBaseMsat      :: {-# UNPACK #-} !Word32  -- ^ Base fee
  } deriving (Eq, Show)

-- | Payment constraints (TLV 12).
--
-- Serialized by 'encodePaymentConstraints' in field order: maximum CLTV
-- expiry, then HTLC minimum.
data PaymentConstraints = PaymentConstraints
  { pcMaxCltvExpiry   :: {-# UNPACK #-} !Word32  -- ^ Maximum CLTV expiry
  , pcHtlcMinimumMsat :: {-# UNPACK #-} !Word64  -- ^ HTLC minimum
  } deriving (Eq, Show)

-- | Errors during blinding operations.
data BlindingError
  = InvalidSeed
    -- ^ The seed given to 'createBlindedPath' is not 32 bytes, or no
    --   ephemeral keypair could be derived from it.
  | EmptyPath
    -- ^ 'createBlindedPath' was given an empty node list.
  | InvalidNodeKey Int
    -- ^ A key-derivation step failed for the hop at this zero-based index
    --   while building a path.
  | DecryptionFailed
    -- ^ 'processBlindedHop' could not decrypt or decode the hop data.
  | InvalidPathKey
    -- ^ 'processBlindedHop' could not compute the shared secret from the
    --   path key, or could not obtain the next path key (either by parsing
    --   the override or by deriving it).
  deriving (Eq, Show)

-- Key derivation ------------------------------------------------------------

-- | Derive rho key for encrypting hop data.
--
-- @rho = HMAC-SHA256(key="rho", data=shared_secret)@
--
-- The raw MAC bytes are wrapped in 'DerivedKey' without further processing.
deriveBlindingRho :: SharedSecret -> DerivedKey
deriveBlindingRho (SharedSecret !ss) =
  let SHA256.MAC !result = SHA256.hmac "rho" ss
  in  DerivedKey result
{-# INLINE deriveBlindingRho #-}

-- | Derive blinded node ID from shared secret and node pubkey.
--
-- @B_i = HMAC256("blinded_node_id", ss_i) * N_i@
--
-- Returns the serialized blinded point, or 'Nothing' if the HMAC output
-- cannot be converted to a scalar or the point multiplication fails.
deriveBlindedNodeId
  :: SharedSecret          -- ^ ss_i
  -> Secp256k1.Projective  -- ^ N_i, the node's public key
  -> Maybe BS.ByteString
deriveBlindedNodeId (SharedSecret !ss) !nodePub = do
  let SHA256.MAC !hmacResult = SHA256.hmac "blinded_node_id" ss
  -- Interpret the MAC as the blinding scalar.
  sk <- Secp256k1.roll32 hmacResult
  blindedPub <- Secp256k1.mul nodePub sk
  pure $! Secp256k1.serialize_point blindedPub
{-# INLINE deriveBlindedNodeId #-}

-- | Compute next ephemeral key pair.
--
-- @e_{i+1} = SHA256(E_i || ss_i) * e_i@
-- @E_{i+1} = SHA256(E_i || ss_i) * E_i@
--
-- Returns 'Nothing' if the secret key is not exactly 32 bytes, if the
-- blinding factor cannot be converted to a scalar, or if the point
-- multiplication fails.
nextEphemeral
  :: BS.ByteString        -- ^ e_i (32-byte secret key)
  -> Secp256k1.Projective -- ^ E_i
  -> SharedSecret         -- ^ ss_i
  -> Maybe (BS.ByteString, Secp256k1.Projective)  -- ^ (e_{i+1}, E_{i+1})
nextEphemeral !secKey !pubKey (SharedSecret !ss)
    -- 'mulSecKey' converts its arguments with the unchecked
    -- 'Secp256k1.unsafe_roll32', so the length must be validated here
    -- before the key is handed over.
  | BS.length secKey /= 32 = Nothing
  | otherwise = do
      let !pubBytes = Secp256k1.serialize_point pubKey
          !blindingFactor = SHA256.hash (pubBytes <> ss)
      bfInt <- Secp256k1.roll32 blindingFactor
      -- Compute e_{i+1} = e_i * blindingFactor (mod q)
      let !newSecKey = mulSecKey secKey blindingFactor
      -- Compute E_{i+1} = E_i * blindingFactor
      newPubKey <- Secp256k1.mul pubKey bfInt
      pure (newSecKey, newPubKey)
{-# INLINE nextEphemeral #-}

-- | Compute the next path key (public key only).
--
-- @E_{i+1} = SHA256(E_i || ss_i) * E_i@
--
-- This is the public half of 'nextEphemeral', for use when the ephemeral
-- secret key is not known. Returns 'Nothing' if the blinding factor cannot
-- be converted to a scalar or the point multiplication fails.
nextPathKey
  :: Secp256k1.Projective -- ^ E_i
  -> SharedSecret         -- ^ ss_i
  -> Maybe Secp256k1.Projective  -- ^ E_{i+1}
nextPathKey !pubKey (SharedSecret !ss) = do
  let !pubBytes = Secp256k1.serialize_point pubKey
      !blindingFactor = SHA256.hash (pubBytes <> ss)
  bfInt <- Secp256k1.roll32 blindingFactor
  Secp256k1.mul pubKey bfInt
{-# INLINE nextPathKey #-}

-- Encryption/Decryption -----------------------------------------------------

-- | Encrypt hop data with ChaCha20-Poly1305.
--
-- Uses rho key and 12-byte zero nonce, empty AAD.
--
-- The hop data is first serialized with 'encodeBlindedHopData'. The result
-- is the ciphertext immediately followed by the MAC, which is the layout
-- 'decryptHopData' expects.
--
-- Note: this function is partial. If the AEAD primitive reports an error,
-- it calls 'error' rather than returning a failure value.
encryptHopData :: DerivedKey -> BlindedHopData -> BS.ByteString
encryptHopData (DerivedKey !rho) !hopData =
  let !plaintext = encodeBlindedHopData hopData
      !nonce = BS.replicate 12 0
  in  case AEAD.encrypt BS.empty rho nonce plaintext of
        Left e -> error $ "encryptHopData: unexpected AEAD error: " ++ show e
        Right (!ciphertext, !mac) -> ciphertext <> mac
{-# INLINE encryptHopData #-}

-- | Decrypt hop data with ChaCha20-Poly1305.
--
-- Uses rho key and 12-byte zero nonce, empty AAD. The input is treated as
-- ciphertext followed by a trailing 16-byte MAC.
--
-- Returns 'Nothing' if the input is shorter than 16 bytes, if AEAD
-- decryption fails, or if the decrypted plaintext does not decode via
-- 'decodeBlindedHopData'.
decryptHopData :: DerivedKey -> BS.ByteString -> Maybe BlindedHopData
decryptHopData (DerivedKey !rho) !encData
  | BS.length encData < 16 = Nothing
  | otherwise = do
      -- Split off the trailing 16-byte MAC.
      let !ciphertext = BS.take (BS.length encData - 16) encData
          !mac = BS.drop (BS.length encData - 16) encData
          !nonce = BS.replicate 12 0
      case AEAD.decrypt BS.empty rho nonce (ciphertext, mac) of
        Left _ -> Nothing
        Right !plaintext -> decodeBlindedHopData plaintext
{-# INLINE decryptHopData #-}

-- TLV Encoding/Decoding -----------------------------------------------------

-- | Encode BlindedHopData to TLV stream.
--
-- Each field that is present becomes one TLV record; absent fields are
-- omitted. Records are emitted in ascending type order
-- (1, 2, 4, 6, 8, 10, 12, 14).
encodeBlindedHopData :: BlindedHopData -> BS.ByteString
encodeBlindedHopData !bhd = encodeTlvStream (buildTlvs bhd)
  where
    -- Zero or one record per field, concatenated in type order.
    buildTlvs :: BlindedHopData -> [TlvRecord]
    buildTlvs (BlindedHopData pad sci nid pid pko pr pc af) =
      let pad'  = maybe [] (\p -> [TlvRecord 1 p]) pad
          sci'  = maybe [] (\s -> [TlvRecord 2 (encodeShortChannelId s)]) sci
          nid'  = maybe [] (\n -> [TlvRecord 4 n]) nid
          pid'  = maybe [] (\p -> [TlvRecord 6 p]) pid
          pko'  = maybe [] (\k -> [TlvRecord 8 k]) pko
          pr'   = maybe [] (\r -> [TlvRecord 10 (encodePaymentRelay r)]) pr
          pc'   = maybe [] (\c -> [TlvRecord 12 (encodePaymentConstraints c)]) pc
          af'   = maybe [] (\f -> [TlvRecord 14 f]) af
      in  pad' ++ sci' ++ nid' ++ pid' ++ pko' ++ pr' ++ pc' ++ af'
{-# INLINE encodeBlindedHopData #-}

-- | Decode TLV stream to BlindedHopData.
--
-- Returns 'Nothing' if the bytes are not a valid TLV stream or if the
-- records cannot be interpreted by 'parseBlindedHopData'.
decodeBlindedHopData :: BS.ByteString -> Maybe BlindedHopData
decodeBlindedHopData !bs = do
  tlvs <- decodeTlvStream bs
  parseBlindedHopData tlvs

-- | Interpret a list of TLV records as 'BlindedHopData'.
--
-- Starts from a value with every field 'Nothing' and fills in one field
-- per recognised record type. If a type occurs more than once, the last
-- occurrence wins.
--
-- Returns 'Nothing' if a structured record (short channel id, payment
-- relay, payment constraints) fails to decode, or if an unknown /even/
-- record type is encountered. Unknown /odd/ types are skipped, following
-- the BOLT 1 "it's OK to be odd" rule.
parseBlindedHopData :: [TlvRecord] -> Maybe BlindedHopData
parseBlindedHopData = go emptyHopData
  where
    emptyHopData :: BlindedHopData
    emptyHopData = BlindedHopData
      Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing

    go :: BlindedHopData -> [TlvRecord] -> Maybe BlindedHopData
    go !bhd [] = Just bhd
    go !bhd (TlvRecord typ val : rest) = case typ of
      1  -> go bhd { bhdPadding = Just val } rest
      2  -> do
        sci <- decodeShortChannelId val
        go bhd { bhdShortChannelId = Just sci } rest
      4  -> go bhd { bhdNextNodeId = Just val } rest
      6  -> go bhd { bhdPathId = Just val } rest
      8  -> go bhd { bhdNextPathKeyOverride = Just val } rest
      10 -> do
        pr <- decodePaymentRelay val
        go bhd { bhdPaymentRelay = Just pr } rest
      12 -> do
        pc <- decodePaymentConstraints val
        go bhd { bhdPaymentConstraints = Just pc } rest
      14 -> go bhd { bhdAllowedFeatures = Just val } rest
      _
        -- Unknown even types are mandatory-to-understand: fail the parse.
        | even typ  -> Nothing
        -- Unknown odd types are optional: skip them.
        | otherwise -> go bhd rest

-- PaymentRelay encoding/decoding --------------------------------------------

-- | Encode PaymentRelay.
--
-- Format: 2-byte cltv_delta BE, 4-byte fee_prop BE, tu32 fee_base
--
-- The base fee comes last because it uses the variable-length truncated
-- encoding produced by 'encodeWord32TU'.
encodePaymentRelay :: PaymentRelay -> BS.ByteString
encodePaymentRelay (PaymentRelay !cltv !feeProp !feeBase) = toStrict $
  B.word16BE cltv <>
  B.word32BE feeProp <>
  B.byteString (encodeWord32TU feeBase)
{-# INLINE encodePaymentRelay #-}

-- | Decode PaymentRelay.
--
-- Expects the layout written by 'encodePaymentRelay': a 2-byte big-endian
-- CLTV delta, a 4-byte big-endian proportional fee, and the remaining
-- bytes as a truncated base fee.
--
-- Returns 'Nothing' if fewer than 6 bytes are supplied or if the trailing
-- bytes are rejected by 'decodeWord32TU'.
decodePaymentRelay :: BS.ByteString -> Maybe PaymentRelay
decodePaymentRelay !bs
  | BS.length bs < 6 = Nothing
  | otherwise = do
      let !cltv = word16BE (BS.take 2 bs)
          !feeProp = word32BE (BS.take 4 (BS.drop 2 bs))
          -- Everything after the fixed 6-byte prefix is the base fee.
          !feeBaseBytes = BS.drop 6 bs
      feeBase <- decodeWord32TU feeBaseBytes
      Just (PaymentRelay cltv feeProp feeBase)
{-# INLINE decodePaymentRelay #-}

-- PaymentConstraints encoding/decoding --------------------------------------

-- | Encode PaymentConstraints.
--
-- Format: 4-byte max_cltv BE, tu64 htlc_min
--
-- The HTLC minimum comes last because it uses the variable-length
-- truncated encoding produced by 'encodeWord64TU'.
encodePaymentConstraints :: PaymentConstraints -> BS.ByteString
encodePaymentConstraints (PaymentConstraints !maxCltv !htlcMin) = toStrict $
  B.word32BE maxCltv <>
  B.byteString (encodeWord64TU htlcMin)
{-# INLINE encodePaymentConstraints #-}

-- | Decode PaymentConstraints.
--
-- Expects the layout written by 'encodePaymentConstraints': a 4-byte
-- big-endian maximum CLTV expiry, and the remaining bytes as a truncated
-- HTLC minimum.
--
-- Returns 'Nothing' if fewer than 4 bytes are supplied or if the trailing
-- bytes are rejected by 'decodeWord64TU'.
decodePaymentConstraints :: BS.ByteString -> Maybe PaymentConstraints
decodePaymentConstraints !bs
  | BS.length bs < 4 = Nothing
  | otherwise = do
      let !maxCltv = word32BE (BS.take 4 bs)
          -- Everything after the fixed 4-byte prefix is the HTLC minimum.
          !htlcMinBytes = BS.drop 4 bs
      htlcMin <- decodeWord64TU htlcMinBytes
      Just (PaymentConstraints maxCltv htlcMin)
{-# INLINE decodePaymentConstraints #-}

-- Shared secret computation -------------------------------------------------

-- | Compute shared secret from ECDH.
--
-- Multiplies the public key by the secret scalar, serializes the resulting
-- point, and hashes the serialization with SHA256.
--
-- Returns 'Nothing' if the secret key bytes cannot be converted to a
-- scalar or if the point multiplication fails.
computeSharedSecret
  :: BS.ByteString         -- ^ 32-byte secret key
  -> Secp256k1.Projective  -- ^ Public key
  -> Maybe SharedSecret
computeSharedSecret !secBs !pub = do
  sec <- Secp256k1.roll32 secBs
  ecdhPoint <- Secp256k1.mul pub sec
  let !compressed = Secp256k1.serialize_point ecdhPoint
      !ss = SHA256.hash compressed
  pure $! SharedSecret ss
{-# INLINE computeSharedSecret #-}

-- Path creation -------------------------------------------------------------

-- | Create a blinded path from a seed and list of nodes with their data.
--
-- The seed is used directly as the initial ephemeral secret key e_0; its
-- public key E_0 becomes the path's blinding key. The first node in the
-- list is the (unblinded) introduction node.
--
-- Fails with:
--
-- * 'InvalidSeed' if the seed is not 32 bytes or no keypair can be derived
--   from it;
-- * 'EmptyPath' if the node list is empty;
-- * 'InvalidNodeKey' (with the hop index) if key derivation fails for a
--   hop.
createBlindedPath
  :: BS.ByteString  -- ^ 32-byte random seed for ephemeral key
  -> [(Secp256k1.Projective, BlindedHopData)]  -- ^ Nodes with their data
  -> Either BlindingError BlindedPath
createBlindedPath !seed !nodes
  | BS.length seed /= 32 = Left InvalidSeed
  | otherwise = case nodes of
      [] -> Left EmptyPath
      ((introNode, _) : _) -> do
        -- (e_0, E_0) = keypair from seed
        e0 <- maybe (Left InvalidSeed) Right (Secp256k1.roll32 seed)
        e0Pub <- maybe (Left InvalidSeed) Right
                   (Secp256k1.mul Secp256k1._CURVE_G e0)
        -- Process all hops
        hops <- processHops seed e0Pub nodes 0
        Right (BlindedPath introNode e0Pub hops)

-- | Build the blinded hops for a list of nodes, threading the ephemeral
-- keypair from one hop to the next.
--
-- For each node this computes the shared secret, derives the rho key and
-- blinded node id, encrypts the hop data, and then advances the ephemeral
-- keypair with 'nextEphemeral'. Any failure at a hop is reported as
-- 'InvalidNodeKey' carrying that hop's index.
processHops
  :: BS.ByteString  -- ^ Current e_i
  -> Secp256k1.Projective  -- ^ Current E_i
  -> [(Secp256k1.Projective, BlindedHopData)]
  -> Int  -- ^ Index for error reporting
  -> Either BlindingError [BlindedHop]
processHops _ _ [] _ = Right []
processHops !eKey !ePub ((nodePub, hopData) : rest) !idx = do
  -- ss_i = SHA256(ECDH(e_i, N_i))
  ss <- maybe (Left (InvalidNodeKey idx)) Right
          (computeSharedSecret eKey nodePub)
  -- rho_i = deriveBlindingRho(ss_i)
  let !rho = deriveBlindingRho ss
  -- B_i = deriveBlindedNodeId(ss_i, N_i)
  blindedId <- maybe (Left (InvalidNodeKey idx)) Right
                 (deriveBlindedNodeId ss nodePub)
  -- encrypted_i = encryptHopData(rho_i, data_i)
  let !encData = encryptHopData rho hopData
      !hop = BlindedHop blindedId encData
  -- (e_{i+1}, E_{i+1}) = nextEphemeral(e_i, E_i, ss_i)
  (nextE, nextEPub) <- maybe (Left (InvalidNodeKey idx)) Right
                         (nextEphemeral eKey ePub ss)
  -- Process remaining hops
  restHops <- processHops nextE nextEPub rest (idx + 1)
  Right (hop : restHops)

-- Hop processing ------------------------------------------------------------

-- | Process a blinded hop, returning decrypted data and next path key.
--
-- The next path key is the override carried in the decrypted hop data
-- (TLV 8) when one is present; otherwise it is derived from the current
-- path key and the shared secret.
--
-- Fails with:
--
-- * 'InvalidPathKey' if the shared secret cannot be computed, if the
--   override does not parse as a point, or if the next path key cannot be
--   derived;
-- * 'DecryptionFailed' if the encrypted data cannot be decrypted or
--   decoded.
processBlindedHop
  :: BS.ByteString        -- ^ Node's 32-byte private key
  -> Secp256k1.Projective -- ^ E_i, current path key (blinding point)
  -> BS.ByteString        -- ^ encrypted_data from onion payload
  -> Either BlindingError (BlindedHopData, Secp256k1.Projective)
processBlindedHop !nodeSecKey !pathKey !encData = do
  -- ss = SHA256(ECDH(node_seckey, path_key))
  ss <- maybe (Left InvalidPathKey) Right
          (computeSharedSecret nodeSecKey pathKey)
  -- rho = deriveBlindingRho(ss)
  let !rho = deriveBlindingRho ss
  -- hop_data = decryptHopData(rho, encrypted_data)
  hopData <- maybe (Left DecryptionFailed) Right
               (decryptHopData rho encData)
  -- Compute next path key
  nextKey <- case bhdNextPathKeyOverride hopData of
    Just override ->
      -- Parse override as compressed point
      maybe (Left InvalidPathKey) Right (Secp256k1.parse_point override)
    Nothing ->
      -- E_next = SHA256(path_key || ss) * path_key
      maybe (Left InvalidPathKey) Right (nextPathKey pathKey ss)
  Right (hopData, nextKey)

-- Scalar multiplication -----------------------------------------------------

-- | Multiply two 32-byte scalars mod curve order q.
--
-- Uses Montgomery multiplication from ppad-fixed for efficiency.
--
-- Both arguments are converted with the unchecked
-- 'Secp256k1.unsafe_roll32', so callers must ensure each is exactly 32
-- bytes before calling. The product is converted back out of Montgomery
-- form and serialized with 'Secp256k1.unroll32'.
mulSecKey :: BS.ByteString -> BS.ByteString -> BS.ByteString
mulSecKey !a !b =
  let !aW = Secp256k1.unsafe_roll32 a
      !bW = Secp256k1.unsafe_roll32 b
      -- Move both operands into Montgomery form.
      !aM = S.to aW
      !bM = S.to bW
      !resultM = S.mul aM bM
      -- Convert the product back to its ordinary representation.
      !resultW = S.retr resultM
  in  Secp256k1.unroll32 resultW
{-# INLINE mulSecKey #-}