{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module: Lightning.Protocol.BOLT4.Construct
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Onion packet construction for BOLT4.

module Lightning.Protocol.BOLT4.Construct (
    -- * Types
    Hop(..)
  , Error(..)

    -- * Packet construction
  , construct
  ) where

import Data.Bits (xor)
import qualified Crypto.Curve.Secp256k1 as Secp256k1
import qualified Data.ByteString as BS
import Lightning.Protocol.BOLT4.Codec
import Lightning.Protocol.BOLT4.Prim
import Lightning.Protocol.BOLT4.Types

-- | One entry of a payment route: the node's public key together with the
-- payload that node is meant to receive.
data Hop = Hop
  { hopPubKey  :: !Secp256k1.Projective  -- ^ public key of the hop's node
  , hopPayload :: !HopPayload            -- ^ routing data for this hop
  } deriving (Show, Eq)

-- | Reasons packet construction can be refused.
data Error
  = InvalidSessionKey
    -- ^ session key is not 32 bytes, or was rejected when converted to a
    --   secret key / public key
  | EmptyRoute
    -- ^ route contains no hops
  | TooManyHops
    -- ^ route is longer than the permitted maximum
  | PayloadTooLarge !Int
    -- ^ combined shift size (in bytes) of all hop payloads, which exceeded
    --   the @hop_payloads@ field
  | InvalidHopPubKey !Int
    -- ^ 0-based index of the hop at which shared-secret derivation or key
    --   blinding failed
  deriving (Show, Eq)

-- | Longest route (in hops) that packet construction accepts; longer
-- routes are rejected with 'TooManyHops'.
maxHops :: Int
maxHops = 20
{-# INLINE maxHops #-}

-- | Build an onion packet for a payment route.
--
-- Inputs:
--
-- * a 32-byte random session key, used as the first ephemeral secret key
--   and to derive the initial padding of @hop_payloads@;
-- * the route, first hop first and final destination last;
-- * associated data bound into every hop's HMAC (typically the
--   payment hash).
--
-- On success, returns the packet together with the per-hop shared secrets
-- in route order (kept by the sender for error attribution).
--
-- Failures, checked in this order:
--
-- * 'InvalidSessionKey' if the key is not 32 bytes;
-- * 'EmptyRoute' if the route has no hops;
-- * 'TooManyHops' if the route has more than 'maxHops' hops;
-- * 'InvalidSessionKey' if the key cannot be turned into a key pair;
-- * 'InvalidHopPubKey' if per-hop secret derivation fails;
-- * 'PayloadTooLarge' (carrying the total) if the combined payload shift
--   sizes exceed 'hopPayloadsSize'.
construct
  :: BS.ByteString       -- ^ 32-byte session key (random)
  -> [Hop]               -- ^ route (first hop to final destination)
  -> BS.ByteString       -- ^ associated data
  -> Either Error (OnionPacket, [SharedSecret])
construct !sessionKey !hops !assocData
  | BS.length sessionKey /= 32 = Left InvalidSessionKey
  | null hops                  = Left EmptyRoute
  | routeLen > maxHops         = Left TooManyHops
  | otherwise = do
      -- The session key doubles as the first ephemeral secret: parse it and
      -- derive the matching public key, rejecting keys either step refuses.
      ephSec <- orFail InvalidSessionKey (Secp256k1.roll32 sessionKey)
      ephPub <- orFail InvalidSessionKey (Secp256k1.derive_pub ephSec)

      (secrets, _) <- computeAllSecrets sessionKey ephPub (map hopPubKey hops)

      let encoded    = map (encodeHopPayload . hopPayload) hops
          shiftSizes = map payloadShiftSize encoded
          totalShift = sum shiftSizes
      if totalShift > hopPayloadsSize
        then Left (PayloadTooLarge totalShift)
        else Right ()

      let -- The final hop's secret and size take no part in the filler.
          inner   = routeLen - 1
          filler  = generateFiller (take inner secrets) (take inner shiftSizes)
          -- Deterministic starting contents for hop_payloads, keyed off the
          -- session key.
          padding = generateStream (derivePad (SharedSecret sessionKey))
                                   hopPayloadsSize
          (routing, mac) =
            wrapAllHops secrets encoded filler assocData padding

      Right ( OnionPacket
                { opVersion      = versionByte
                , opEphemeralKey = Secp256k1.serialize_point ephPub
                , opHopPayloads  = routing
                , opHmac         = mac
                }
            , secrets
            )
  where
    routeLen = length hops

    -- Turn a failed Maybe into the given error.
    orFail :: e -> Maybe a -> Either e a
    orFail e = maybe (Left e) Right

-- | Number of bytes an encoded hop payload occupies in @hop_payloads@, which
-- is also how far the field shifts when that hop is wrapped. It is the sum
-- of the BigSize length prefix, the payload itself, and the trailing HMAC.
payloadShiftSize :: BS.ByteString -> Int
payloadShiftSize !payload = prefixLen + bodyLen + hmacSize
  where
    bodyLen   = BS.length payload
    prefixLen = bigSizeLen (fromIntegral bodyLen)
{-# INLINE payloadShiftSize #-}

-- | Derive the shared secret for every hop along the route.
--
-- Starts from the given ephemeral secret key (raw bytes) and public key.
-- For each hop:
--
-- 1. compute the shared secret from the current ephemeral secret and the
--    hop's public key;
-- 2. derive a blinding factor from the current ephemeral public key and
--    that secret;
-- 3. blind both halves of the ephemeral pair before moving on.
--
-- Returns the secrets in route order together with the ephemeral public key
-- left after the last hop. Any failure in step 1 or in either blinding step
-- at position @i@ (0-based) yields @'InvalidHopPubKey' i@.
computeAllSecrets
  :: BS.ByteString
  -> Secp256k1.Projective
  -> [Secp256k1.Projective]
  -> Either Error ([SharedSecret], Secp256k1.Projective)
computeAllSecrets !initSec !initPub = step initSec initPub 0 []
  where
    step
      :: BS.ByteString
      -> Secp256k1.Projective
      -> Int
      -> [SharedSecret]
      -> [Secp256k1.Projective]
      -> Either Error ([SharedSecret], Secp256k1.Projective)
    step !_ !pub !_ !found [] = Right (reverse found, pub)
    step !sec !pub !i !found (nodePub : remaining) = do
      secret <- atHop i (computeSharedSecret sec nodePub)
      let !blinding = computeBlindingFactor pub secret
      sec' <- atHop i (blindSecKey sec blinding)
      pub' <- atHop i (blindPubKey pub blinding)
      step sec' pub' (i + 1) (secret : found) remaining

    -- Lift a Maybe into Either, blaming the hop at the given index.
    atHop :: Int -> Maybe a -> Either Error a
    atHop i = maybe (Left (InvalidHopPubKey i)) Right

-- | Generate the BOLT #4 filler for a route.
--
-- The filler is the block of bytes the final hop will find at the tail of
-- its @hop_payloads@ after every earlier hop has de-obfuscated and
-- left-shifted the field. Computing it up front lets the sender place those
-- bytes before computing the final hop's HMAC, so every HMAC on the route
-- verifies.
--
-- Arguments are the shared secrets and shift sizes of every hop /except/
-- the final one, in route order. If the lists differ in length, the extra
-- elements of the longer one are ignored.
--
-- At step @i@, with the filler accumulated so far of length @L@:
--
-- 1. the filler is extended by @sz_i@ zero bytes;
-- 2. it is XORed with bytes @[hopPayloadsSize - L, hopPayloadsSize + sz_i)@
--    of hop @i@'s @2 * hopPayloadsSize@-byte rho stream.
--
-- The window starts @L@ bytes before the end of the field because, when
-- hop @i@ processes the packet, the existing filler occupies the /last/ @L@
-- bytes of its field, immediately before the zero bytes it appends.
generateFiller :: [SharedSecret] -> [Int] -> BS.ByteString
generateFiller !secrets !sizes = go BS.empty secrets sizes
  where
    go :: BS.ByteString -> [SharedSecret] -> [Int] -> BS.ByteString
    go !filler (ss:sss) (sz:szs) =
      let !extended     = filler <> BS.replicate sz 0
          !stream       = generateStream (deriveRho ss) (2 * hopPayloadsSize)
          -- The previously accumulated filler sits at the tail of this hop's
          -- field, so its stream window begins that many bytes before the
          -- field's end rather than at the end itself.
          !streamOffset = hopPayloadsSize - BS.length filler
          !streamPart   = BS.take (BS.length extended)
                                  (BS.drop streamOffset stream)
          !newFiller    = xorBytes extended streamPart
      in  go newFiller sss szs
    -- Either list exhausted: the filler is complete.
    go !filler _ _ = filler
{-# INLINE generateFiller #-}

-- | Wrap every hop's payload into @hop_payloads@, innermost first.
--
-- Arguments:
--
-- * per-hop shared secrets and encoded payloads, in route order;
-- * the filler from 'generateFiller';
-- * the associated data bound into every HMAC;
-- * the initial 'hopPayloadsSize'-byte contents of the field.
--
-- The route is processed in reverse. The final hop is wrapped first, with
-- an all-zero HMAC and with the final-hop flag set (so 'wrapHop' may apply
-- the filler). Each earlier hop then wraps the field and HMAC produced by
-- the hop after it.
--
-- Returns the outermost @hop_payloads@ and the HMAC for the packet header.
-- An empty route returns the initial contents and a zero HMAC.
wrapAllHops
  :: [SharedSecret]
  -> [BS.ByteString]
  -> BS.ByteString
  -> BS.ByteString
  -> BS.ByteString
  -> (BS.ByteString, BS.ByteString)
wrapAllHops !secrets !payloads !filler !assocData !initPayloads =
    case reverse (zip secrets payloads) of
      [] -> (initPayloads, zeroHmac)
      (finalHop : earlierHops) ->
        -- Only the first wrap (the route's final hop) is flagged as final.
        -- This replaces the former per-iteration length recomputation.
        let (!p0, !h0) = wrapOne True initPayloads zeroHmac finalHop
        in  go p0 h0 earlierHops
  where
    zeroHmac :: BS.ByteString
    zeroHmac = BS.replicate hmacSize 0

    -- Wrap a single (secret, payload) pair around the current field/HMAC.
    wrapOne
      :: Bool
      -> BS.ByteString
      -> BS.ByteString
      -> (SharedSecret, BS.ByteString)
      -> (BS.ByteString, BS.ByteString)
    wrapOne isFinal fieldIn hmacIn (ss, payload) =
      wrapHop ss payload hmacIn fieldIn assocData filler isFinal

    -- Wrap the remaining (non-final) hops, outward toward the first hop.
    go
      :: BS.ByteString
      -> BS.ByteString
      -> [(SharedSecret, BS.ByteString)]
      -> (BS.ByteString, BS.ByteString)
    go !field !hmac [] = (field, hmac)
    go !field !hmac (hop : rest) =
      let (!field', !hmac') = wrapOne False field hmac hop
      in  go field' hmac' rest

-- | Wrap one hop's payload around the current @hop_payloads@.
--
-- Steps:
--
-- 1. build @BigSize(length payload) || payload || hmac@;
-- 2. append the current field, truncated to
--    @hopPayloadsSize - shift@ bytes;
-- 3. XOR the result with the hop's rho stream of 'hopPayloadsSize' bytes;
-- 4. overwrite the tail with the filler, only when the final-hop flag is
--    set and the filler is non-empty.
--
-- Returns the new field and an HMAC computed with the hop's mu key over
-- that field and the associated data.
wrapHop
  :: SharedSecret
  -> BS.ByteString
  -> BS.ByteString
  -> BS.ByteString
  -> BS.ByteString
  -> BS.ByteString
  -> Bool
  -> (BS.ByteString, BS.ByteString)
wrapHop !ss !payload !hmac !hopPayloads !assocData !filler !isFinalHop =
    let !lenPrefix = encodeBigSize (fromIntegral (BS.length payload))
        !shiftBy   = BS.length lenPrefix + BS.length payload + hmacSize
        !plain     = BS.concat
          [ lenPrefix
          , payload
          , hmac
          , BS.take (hopPayloadsSize - shiftBy) hopPayloads
          ]
        !masked    = xorBytes plain
                       (generateStream (deriveRho ss) hopPayloadsSize)
        !routing   = finish masked
        !mac       = computeHmac (deriveMu ss) routing assocData
    in  (routing, mac)
  where
    -- The filler is applied only when the caller flags this as the final
    -- hop and there is filler to apply.
    finish field
      | isFinalHop && not (BS.null filler) = applyFiller field filler
      | otherwise                          = field
{-# INLINE wrapHop #-}

-- | Overwrite the tail of @hop_payloads@ with the filler.
--
-- Keeps the first @hopPayloadsSize - length filler@ bytes of the field and
-- appends the filler after them.
applyFiller :: BS.ByteString -> BS.ByteString -> BS.ByteString
applyFiller !hopPayloads !filler = BS.append (BS.take keep hopPayloads) filler
  where
    keep = hopPayloadsSize - BS.length filler
{-# INLINE applyFiller #-}

-- | Byte-wise XOR of two strings.
--
-- The result is as long as the shorter input; any excess bytes of the
-- longer input are dropped.
xorBytes :: BS.ByteString -> BS.ByteString -> BS.ByteString
xorBytes !lhs !rhs = BS.pack (zipWith xor (BS.unpack lhs) (BS.unpack rhs))
{-# INLINE xorBytes #-}