{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module: Lightning.Protocol.BOLT4.Process
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Onion packet processing for BOLT4.

module Lightning.Protocol.BOLT4.Process (
    -- * Processing
    process

    -- * Rejection reasons
  , RejectReason(..)
  ) where

import Data.Bits (xor)
import qualified Crypto.Curve.Secp256k1 as Secp256k1
import qualified Data.ByteString as BS
import Data.Word (Word8)
import GHC.Generics (Generic)
import Lightning.Protocol.BOLT4.Codec
import Lightning.Protocol.BOLT4.Prim
import Lightning.Protocol.BOLT4.Types

-- | Why an incoming onion packet was rejected by 'process'.
data RejectReason
  = -- | The packet's version byte did not equal @versionByte@; carries
    --   the byte that was found.
    InvalidVersion !Word8
  | -- | The ephemeral public key could not be parsed or used.
    InvalidEphemeralKey
  | -- | The packet HMAC did not verify.
    HmacMismatch
  | -- | The hop payload was malformed; carries a short description.
    InvalidPayload !String
  deriving (Eq, Show, Generic)

-- | Peel one layer off an incoming onion packet.
--
-- Given this node's secret key, the received packet, and the associated
-- data (typically the payment hash), the packet is validated, its HMAC
-- checked, and the hop payload for this node decrypted and decoded.
--
-- The result is 'Left' with a 'RejectReason' on any failure. Otherwise
-- it is either a 'Receive' (the next HMAC marks this node as the final
-- hop) or a 'Forward' carrying the packet for the next hop.
process
  :: BS.ByteString    -- ^ 32-byte secret key of this node
  -> OnionPacket      -- ^ incoming onion packet
  -> BS.ByteString    -- ^ associated data (payment hash)
  -> Either RejectReason ProcessResult
process !secKey !packet !assocData = do
  -- Version check, then recover the sender's ephemeral point.
  validateVersion packet
  ephemeral <- parseEphemeralKey packet

  -- Shared secret between this node and the ephemeral key.
  ss <- orReject InvalidEphemeralKey (computeSharedSecret secKey ephemeral)

  -- Both keys are derived (and forced) before the HMAC is checked.
  let !muKey  = deriveMu ss
      !rhoKey = deriveRho ss

  -- Authenticate the packet before touching its contents.
  if verifyPacketHmac muKey packet assocData
    then pure ()
    else Left HmacMismatch

  -- Decrypt the hop payloads and carve out this hop's portion.
  let !decrypted = decryptPayloads rhoKey (opHopPayloads packet)
  (payloadBytes, nextHmac, remaining) <- extractPayload decrypted

  -- Decode the payload bytes.
  hopPayload <-
    orReject (InvalidPayload "failed to decode TLV")
      (decodeHopPayload payloadBytes)

  -- Either deliver locally or build the packet for the next hop.
  let SharedSecret ssBytes = ss
  if not (isFinalHop nextHmac)
    then do
      -- A failure to build the next packet is reported as a bad key.
      nextPacket <-
        orReject InvalidEphemeralKey
          (prepareForward ephemeral ss remaining nextHmac)

      pure $! Forward $! ForwardInfo
        { fiNextPacket = nextPacket
        , fiPayload = hopPayload
        , fiSharedSecret = ssBytes
        }
    else pure $! Receive $! ReceiveInfo
      { riPayload = hopPayload
      , riSharedSecret = ssBytes
      }
  where
    -- Lift a 'Maybe' into 'Either', using the given reason for 'Nothing'.
    orReject :: RejectReason -> Maybe a -> Either RejectReason a
    orReject reason = maybe (Left reason) Right

-- | Accept the packet only when its version byte equals @versionByte@.
--
-- On mismatch the offending byte is reported via 'InvalidVersion'.
validateVersion :: OnionPacket -> Either RejectReason ()
validateVersion !packet =
  if version == versionByte
    then Right ()
    else Left (InvalidVersion version)
  where
    version = opVersion packet
{-# INLINE validateVersion #-}

-- | Decode the packet's ephemeral public key into a curve point.
--
-- Yields 'InvalidEphemeralKey' when 'Secp256k1.parse_point' rejects the
-- key bytes.
parseEphemeralKey :: OnionPacket -> Either RejectReason Secp256k1.Projective
parseEphemeralKey !packet =
  maybe (Left InvalidEphemeralKey) Right
    (Secp256k1.parse_point (opEphemeralKey packet))
{-# INLINE parseEphemeralKey #-}

-- | Decrypt hop payloads by XORing them with the rho keystream.
--
-- The input is padded with @hopPayloadsSize@ zero bytes and XORed with
-- a keystream of @2 * hopPayloadsSize@ bytes, so the tail of the result
-- is raw keystream that later serves as filler when the buffer is
-- shifted for the next hop.
decryptPayloads
  :: DerivedKey      -- ^ rho key
  -> BS.ByteString   -- ^ hop_payloads (1300 bytes)
  -> BS.ByteString   -- ^ decrypted (2600 bytes, first 1300 useful)
decryptPayloads !rhoKey !payloads = xorBytes keystream padded
  where
    -- Twice the hop_payloads size (2600 bytes).
    !keystream = generateStream rhoKey (2 * hopPayloadsSize)
    -- Zero-extend so the XOR covers the full keystream.
    !padded = payloads <> BS.replicate hopPayloadsSize 0
{-# INLINE decryptPayloads #-}

-- | Byte-wise XOR of two bytestrings.
--
-- Callers are expected to pass inputs of equal length; if they differ,
-- the result is as long as the shorter one.
xorBytes :: BS.ByteString -> BS.ByteString -> BS.ByteString
xorBytes !a !b = BS.pack (zipWith xor (BS.unpack a) (BS.unpack b))
{-# INLINE xorBytes #-}

-- | Split the decrypted buffer into this hop's payload, the next HMAC,
-- and everything that follows.
--
-- The buffer must start with a BigSize length prefix (decoded with
-- 'decodeBigSize'), followed by that many payload bytes and then
-- @hmacSize@ bytes of HMAC.
--
-- Fails with 'InvalidPayload' when the length prefix cannot be decoded,
-- when the declared length is zero or larger than the bytes available
-- after the prefix, or when fewer than @hmacSize@ bytes follow the
-- payload.
extractPayload
  :: BS.ByteString
  -> Either RejectReason (BS.ByteString, BS.ByteString, BS.ByteString)
     -- ^ (payload_bytes, next_hmac, remaining_hop_payloads)
extractPayload !decrypted = do
  -- Parse length prefix
  (declared, afterLen) <-
    maybe (Left (InvalidPayload "invalid length prefix")) Right
      (decodeBigSize decrypted)

  -- Validate the declared length BEFORE narrowing it to Int. A value
  -- that does not fit in Int would otherwise wrap to a negative number,
  -- slip past both checks, and make take/drop treat it as zero.
  if toInteger declared > toInteger (BS.length afterLen)
    then Left (InvalidPayload "payload length exceeds buffer")
    else if declared == 0
      then Left (InvalidPayload "zero-length payload")
      else pure ()

  -- Safe: 0 < declared <= BS.length afterLen, so it fits in Int.
  let !len = fromIntegral declared :: Int
      !(payloadBytes, afterPayload) = BS.splitAt len afterLen

  -- Extract next HMAC
  if BS.length afterPayload < hmacSize
    then Left (InvalidPayload "insufficient bytes for HMAC")
    else do
      -- Everything after the HMAC is returned as-is; truncation to the
      -- hop_payloads size is left to the caller.
      let !(nextHmac, remaining) = BS.splitAt hmacSize afterPayload

      Right (payloadBytes, nextHmac, remaining)

-- | Check the packet's HMAC.
--
-- Recomputes the HMAC with 'computeHmac' from the mu key, the packet's
-- hop payloads, and the associated data, then hands the packet's HMAC
-- and the recomputed value to 'verifyHmac'.
--
-- NOTE(review): the comparison is presumably constant-time; that
-- property lives in 'verifyHmac', not here -- confirm there.
verifyPacketHmac
  :: DerivedKey      -- ^ mu key
  -> OnionPacket     -- ^ packet with HMAC to verify
  -> BS.ByteString   -- ^ associated data
  -> Bool
verifyPacketHmac !muKey !packet !assocData =
  verifyHmac (opHmac packet) expected
  where
    !expected = computeHmac muKey (opHopPayloads packet) assocData
{-# INLINE verifyPacketHmac #-}

-- | Build the onion packet to hand to the next hop.
--
-- The current ephemeral key is blinded with the factor obtained from
-- 'computeBlindingFactor'; the result is 'Nothing' if 'blindPubKey'
-- fails. The new packet carries @versionByte@, the serialized blinded
-- key, at most @hopPayloadsSize@ bytes of the remaining payloads, and
-- the next HMAC.
--
-- Note that 'BS.take' does not pad: if @remaining@ is shorter than
-- @hopPayloadsSize@, the resulting hop payloads are shorter too.
prepareForward
  :: Secp256k1.Projective  -- ^ current ephemeral key
  -> SharedSecret          -- ^ shared secret (for blinding)
  -> BS.ByteString         -- ^ remaining hop_payloads (after shift)
  -> BS.ByteString         -- ^ next HMAC
  -> Maybe OnionPacket
prepareForward !ephemeral !ss !remaining !nextHmac =
  let !blinding = computeBlindingFactor ephemeral ss
  in  case blindPubKey ephemeral blinding of
        Nothing -> Nothing
        Just blinded ->
          let -- Wire encoding of the blinded key.
              !keyBytes = Secp256k1.serialize_point blinded
              -- Keep only the leading hop_payloads-sized window.
              !payloads = BS.take hopPayloadsSize remaining
          in  Just $! OnionPacket
                { opVersion = versionByte
                , opEphemeralKey = keyBytes
                , opHopPayloads = payloads
                , opHmac = nextHmac
                }

-- | Decide whether the current node is the last hop.
--
-- True exactly when the next HMAC equals @hmacSize@ zero bytes.
isFinalHop :: BS.ByteString -> Bool
isFinalHop !hmac = zeroHmac == hmac
  where
    zeroHmac = BS.replicate hmacSize 0
{-# INLINE isFinalHop #-}