{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module: Lightning.Protocol.BOLT4.Error
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Error packet construction and unwrapping for BOLT4 onion routing.
--
-- Failing nodes construct error packets that are wrapped at each
-- intermediate hop on the return path. The origin node unwraps
-- layers to attribute the error to a specific hop.

module Lightning.Protocol.BOLT4.Error (
    -- * Types
    ErrorPacket(..)
  , AttributionResult(..)
  , minErrorPacketSize

    -- * Error construction (failing node)
  , constructError

    -- * Error forwarding (intermediate node)
  , wrapError

    -- * Error unwrapping (origin node)
  , unwrapError
  ) where

import Data.Bits (xor, (.|.))
import Data.Word (Word8, Word16)
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as B
import qualified Data.ByteString.Lazy as BL
import Lightning.Protocol.BOLT4.Codec (encodeFailureMessage, decodeFailureMessage)
import Lightning.Protocol.BOLT4.Prim
import Lightning.Protocol.BOLT4.Types (FailureMessage)

-- | An onion error packet travelling back toward the payment origin.
--
-- Holds the raw, obfuscated bytes. A failing node creates one with
-- 'constructError'; every forwarding node on the return path adds one
-- more layer with 'wrapError'.
newtype ErrorPacket = ErrorPacket BS.ByteString
  deriving (Show, Eq)

-- | What the origin learned when trying to attribute a returned error.
data AttributionResult
  = Attributed {-# UNPACK #-} !Int !FailureMessage
    -- ^ Zero-based index of the hop whose HMAC verified, together with
    -- the failure message decoded from that hop's layer.
  | UnknownOrigin !BS.ByteString
    -- ^ Attribution did not succeed; carries the packet bytes as they
    -- stood when 'unwrapError' gave up.
  deriving (Show, Eq)

-- | Minimum error packet size, in bytes (0x100 = 256, per BOLT4).
minErrorPacketSize :: Int
minErrorPacketSize = 0x100
{-# INLINE minErrorPacketSize #-}

-- Error construction ---------------------------------------------------------

-- | Build the error packet a failing node sends back toward the origin.
--
-- Two keys are derived from the node's shared secret. The @um@ key
-- authenticates the serialised failure ('buildErrorMessage'). The
-- @ammag@ key drives the stream that the finished message is XORed
-- with ('obfuscateError').
constructError
  :: SharedSecret      -- ^ shared secret obtained while processing the packet
  -> FailureMessage    -- ^ failure to report
  -> ErrorPacket
constructError !ss !failure = ErrorPacket (obfuscateError ammagKey authenticated)
  where
    umKey         = deriveUm ss
    ammagKey      = deriveAmmag ss
    authenticated = buildErrorMessage umKey failure
{-# INLINE constructError #-}

-- | Add this node's obfuscation layer to an error received from
-- downstream.
--
-- The forwarding node does not inspect the error. It only XORs the
-- bytes with the stream from its own @ammag@ key and passes the result
-- further back.
wrapError
  :: SharedSecret      -- ^ shared secret of this forwarding node
  -> ErrorPacket       -- ^ error packet received from the next hop
  -> ErrorPacket
wrapError !ss (ErrorPacket !packet) =
  ErrorPacket $ obfuscateError (deriveAmmag ss) packet
{-# INLINE wrapError #-}

-- Error unwrapping -----------------------------------------------------------

-- | Attribute a returned error to the hop that produced it.
--
-- The secrets must be in route order (first hop first). Layers are
-- peeled one at a time. For each hop, the packet is XORed with that
-- hop's @ammag@ stream, and the HMAC at its front is checked with the
-- hop's @um@ key. The first hop whose HMAC verifies is taken as the
-- error's origin.
--
-- Returns 'UnknownOrigin' in two cases:
--
-- * No hop verifies. The result carries the packet after every layer
--   was removed.
-- * The verifying hop's failure message fails to parse. The result
--   carries that hop's deobfuscated bytes.
unwrapError
  :: [SharedSecret]    -- ^ per-hop secrets from construction, in route order
  -> ErrorPacket       -- ^ error packet received by the origin
  -> AttributionResult
unwrapError secrets (ErrorPacket !initialPacket) = peel 0 initialPacket secrets
  where
    peel :: Int -> BS.ByteString -> [SharedSecret] -> AttributionResult
    peel !_ !remaining [] = UnknownOrigin remaining
    peel !hop !layered (ss : later)
      | verifyErrorHmac (deriveUm ss) stripped =
          maybe (UnknownOrigin stripped) (Attributed hop)
                (parseErrorMessage (BS.drop 32 stripped))
      | otherwise = peel (hop + 1) stripped later
      where
        -- remove this hop's layer before checking its HMAC
        stripped = deobfuscateError (deriveAmmag ss) layered

-- Internal functions ---------------------------------------------------------

-- | Build the plaintext (not yet obfuscated) error message.
--
-- Layout, following BOLT4:
--
-- @
-- hmac (32) || failure_len (2) || failuremsg || pad_len (2) || pad
-- @
--
-- @pad@ is all zeros. It is sized so that @failure_len + pad_len >= 256@,
-- which BOLT4 requires of the erring node. Only the failure message and
-- its padding count toward that 256; the HMAC and the two length
-- prefixes are added on top. A failure of up to 256 bytes therefore
-- yields a 292-byte message (32 + 2 + 256 + 2). Longer failures get no
-- padding.
--
-- The leading HMAC is HMAC-SHA256, keyed with @um@, over every byte that
-- follows it.
buildErrorMessage
  :: DerivedKey        -- ^ um key
  -> FailureMessage    -- ^ failure to encode
  -> BS.ByteString     -- ^ HMAC followed by the length-prefixed payload
buildErrorMessage (DerivedKey !umKey) !failure =
  let !encoded = encodeFailureMessage failure
      !msgLen  = BS.length encoded
      -- Pad the failure itself up to 256. Do not subtract the HMAC or
      -- the length fields: doing so under-pads (220 bytes) and breaks
      -- interop with the spec test vectors.
      !padLen  = max 0 (minFailurePlusPad - msgLen)
      -- NOTE(review): fromIntegral wraps if msgLen exceeds 65535; assumes
      -- encoded failure messages are always far smaller than that.
      !payload = toStrict $
           B.word16BE (fromIntegral msgLen)
        <> B.byteString encoded
        <> B.word16BE (fromIntegral padLen)
        <> B.byteString (BS.replicate padLen 0)
      SHA256.MAC !hmac = SHA256.hmac umKey payload
  in  hmac <> payload
  where
    -- BOLT4 lower bound on failure_len + pad_len. The HMAC and the two
    -- length fields are not counted toward it.
    minFailurePlusPad :: Int
    minFailurePlusPad = 256
{-# INLINE buildErrorMessage #-}

-- | XOR a packet with the pseudo-random stream derived from an @ammag@
-- key.
--
-- The stream is requested at exactly the packet's length, so the whole
-- packet is covered.
obfuscateError
  :: DerivedKey        -- ^ ammag key
  -> BS.ByteString     -- ^ packet to obfuscate
  -> BS.ByteString     -- ^ packet with one more layer applied
obfuscateError !ammag !packet =
  xorBytes packet (generateStream ammag (BS.length packet))
{-# INLINE obfuscateError #-}

-- | Strip one obfuscation layer from an error packet.
--
-- XORing twice with the same stream restores the input, so this simply
-- applies 'obfuscateError' again with the same key.
deobfuscateError
  :: DerivedKey        -- ^ ammag key
  -> BS.ByteString     -- ^ obfuscated packet
  -> BS.ByteString     -- ^ packet with one layer removed
deobfuscateError ammag packet = obfuscateError ammag packet
{-# INLINE deobfuscateError #-}

-- | Check the HMAC at the front of a deobfuscated error packet.
--
-- The first 32 bytes are the received HMAC. HMAC-SHA256, keyed with
-- @um@, is recomputed over the remaining bytes and compared using
-- 'constantTimeEq'. A packet shorter than 32 bytes never verifies.
verifyErrorHmac
  :: DerivedKey        -- ^ um key
  -> BS.ByteString     -- ^ deobfuscated packet (HMAC || rest)
  -> Bool
verifyErrorHmac (DerivedKey !umKey) !packet =
  BS.length packet >= 32 && constantTimeEq received expected
  where
    (received, body)    = BS.splitAt 32 packet
    SHA256.MAC expected = SHA256.hmac umKey body
{-# INLINE verifyErrorHmac #-}

-- | Decode the failure message from a deobfuscated packet whose HMAC has
-- already been stripped.
--
-- Expects @failure_len (2) || failuremsg || pad_len (2) || pad@.
--
-- Returns 'Nothing' when:
--
-- * fewer than 4 bytes are present;
-- * the declared @failure_len@ runs past the end of the input;
-- * 'decodeFailureMessage' rejects the message bytes.
--
-- @pad_len@ and the padding are not inspected.
parseErrorMessage
  :: BS.ByteString     -- ^ packet after HMAC (len || msg || pad_len || pad)
  -> Maybe FailureMessage
parseErrorMessage !bs
  | BS.length bs < 4        = Nothing
  | BS.length rest < msgLen = Nothing
  | otherwise               = decodeFailureMessage (BS.take msgLen rest)
  where
    (lenField, rest) = BS.splitAt 2 bs
    msgLen           = fromIntegral (word16BE lenField) :: Int
{-# INLINE parseErrorMessage #-}

-- Helper functions -----------------------------------------------------------

-- | Byte-wise XOR of two ByteStrings.
--
-- The result is as long as the shorter input. Within this module the
-- inputs are a packet and a stream requested at the packet's length.
xorBytes :: BS.ByteString -> BS.ByteString -> BS.ByteString
xorBytes !a !b = BS.pack (zipWith xor (BS.unpack a) (BS.unpack b))
{-# INLINE xorBytes #-}

-- | Compare two ByteStrings for equality without stopping at the first
-- differing byte.
--
-- Inputs of different lengths compare unequal immediately; lengths are
-- not treated as secret. Otherwise the XOR of every byte pair is OR-ed
-- into an accumulator, and the inputs are equal iff it ends at zero.
--
-- OR (not XOR) is essential here. With an XOR accumulator, per-byte
-- differences can cancel, e.g. @[1,1]@ vs @[0,0]@ would compare equal,
-- letting a forged HMAC verify.
--
-- NOTE(review): GHC gives no hard constant-time guarantee; this only
-- avoids the obvious data-dependent early exit.
constantTimeEq :: BS.ByteString -> BS.ByteString -> Bool
constantTimeEq !a !b
  | BS.length a /= BS.length b = False
  | otherwise = go 0 (BS.zip a b)
  where
    go :: Word8 -> [(Word8, Word8)] -> Bool
    go !acc []              = acc == 0
    go !acc ((x, y) : rest) = go (acc .|. (x `xor` y)) rest
{-# INLINE constantTimeEq #-}

-- | Read a big-endian 'Word16' from the first two bytes of the input.
--
-- Partial: 'BS.index' errors if fewer than two bytes are present.
-- Within this module, 'parseErrorMessage' checks the length before
-- calling it.
word16BE :: BS.ByteString -> Word16
word16BE !bs = hi * 256 + lo
  where
    hi, lo :: Word16
    hi = fromIntegral (BS.index bs 0)
    lo = fromIntegral (BS.index bs 1)
{-# INLINE word16BE #-}

-- | Run a 'B.Builder' and collect its output into one strict ByteString.
toStrict :: B.Builder -> BS.ByteString
toStrict builder = BL.toStrict (B.toLazyByteString builder)
{-# INLINE toStrict #-}