{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: Lightning.Protocol.BOLT8
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Encrypted and authenticated transport for the Lightning Network, per
-- [BOLT #8](https://github.com/lightning/bolts/blob/master/08-transport.md).
--
-- This module implements the Noise_XK_secp256k1_ChaChaPoly_SHA256
-- handshake and subsequent encrypted message transport.
--
-- = Handshake
--
-- A BOLT #8 handshake consists of three acts. The /initiator/ knows the
-- responder's static public key in advance and initiates the connection:
--
-- @
-- (msg1, state) <- act1 i_sec i_pub r_pub entropy
-- -- send msg1 (50 bytes) to responder
-- -- receive msg2 (50 bytes) from responder
-- (msg3, result) <- act3 state msg2
-- -- send msg3 (66 bytes) to responder
-- let session = 'session' result
-- @
--
-- The /responder/ receives the connection and authenticates the initiator:
--
-- @
-- -- receive msg1 (50 bytes) from initiator
-- (msg2, state) <- act2 r_sec r_pub entropy msg1
-- -- send msg2 (50 bytes) to initiator
-- -- receive msg3 (66 bytes) from initiator
-- result <- finalize state msg3
-- let session = 'session' result
-- @
--
-- = Message Transport
--
-- After a successful handshake, use 'encrypt' and 'decrypt' to exchange
-- messages. Each returns an updated 'Session' that must be used for the
-- next operation (per BOLT #8 each key rotates after 1000 uses; a
-- message consumes two nonces, so that is every 500 messages):
--
-- @
-- -- sender
-- (ciphertext, session') <- 'encrypt' session plaintext
--
-- -- receiver
-- (plaintext, session') <- 'decrypt' session ciphertext
-- @
--
-- = Message Framing
--
-- BOLT #8 runs over a byte stream, so callers often need to deal with
-- partial buffers. Use 'decrypt_frame' when you have exactly one frame,
-- or 'decrypt_frame_partial' to handle incremental reads and return how
-- many bytes are still needed.
--
-- Maximum plaintext size is 65535 bytes.

module Lightning.Protocol.BOLT8 (
    -- * Keys
    Sec
  , Pub
  , keypair
  , parse_pub
  , serialize_pub

    -- * Handshake (initiator)
  , act1
  , act3

    -- * Handshake (responder)
  , act2
  , finalize

    -- * Session
  , Session
  , HandshakeState
  , Handshake(..)
  , encrypt
  , decrypt
  , decrypt_frame
  , decrypt_frame_partial
  , FrameResult(..)

    -- * Errors
  , Error(..)
  ) where

import Control.Monad (guard, unless)
import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
import qualified Crypto.Curve.Secp256k1 as Secp256k1
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.KDF.HMAC as HKDF
import Data.Bits (unsafeShiftR, (.&.))
import qualified Data.ByteString as BS
import Data.Word (Word16, Word64)
import GHC.Generics (Generic)

-- types ---------------------------------------------------------------------

-- | Secret key (32 bytes).
--
--   Wraps the raw key bytes. The constructor is not exported; values are
--   produced by 'keypair', which checks the length before wrapping.
--   No 'Show' instance is derived for this type.
newtype Sec = Sec BS.ByteString
  deriving (Eq, Generic)

-- | Public key, stored as a secp256k1 point.
--
--   'serialize_pub' renders it in 33-byte compressed form and 'parse_pub'
--   reads that form back.
newtype Pub = Pub Secp256k1.Projective

-- Equality is decided on the serialized encodings of the two points
-- rather than on the underlying 'Secp256k1.Projective' values.
instance Eq Pub where
  (Pub a) == (Pub b) =
    Secp256k1.serialize_point a == Secp256k1.serialize_point b

-- Renders as the literal prefix "Pub " followed by the 'show' of the
-- serialized point bytes.
instance Show Pub where
  show (Pub p) = "Pub " ++ show (Secp256k1.serialize_point p)

-- | Handshake errors.
data Error =
    InvalidKey
    -- ^ Ephemeral entropy did not yield a keypair, or an ECDH step failed.
  | InvalidPub
    -- ^ A public key could not be parsed, or the handshake state lacks
    --   the remote static key.
  | InvalidMAC
    -- ^ An AEAD encryption or decryption step of the handshake failed.
  | InvalidVersion
    -- ^ The leading version byte of a handshake message was not @0x00@.
  | InvalidLength
    -- ^ A handshake message did not have the expected length.
  | DecryptionFailed
    -- ^ Not produced by the handshake functions; presumably reported by
    --   the post-handshake message decryption.
  deriving (Eq, Show, Generic)

-- | Result of attempting to decrypt a frame from a partial buffer.
data FrameResult =
    NeedMore {-# UNPACK #-} !Int
    -- ^ More bytes needed; the 'Int' is the minimum additional bytes required.
  | FrameOk !BS.ByteString !BS.ByteString !Session
    -- ^ Successfully decrypted: plaintext, remainder, updated session.
  | FrameError !Error
    -- ^ Decryption failed with the given error.
  deriving Generic

-- | Post-handshake session state.
--
--   Holds an independent key, nonce counter and chaining key for each
--   direction. 'act3' and 'finalize' create it with both nonces at zero
--   and both chaining keys set to the final handshake chaining key.
data Session = Session {
    sess_sk  :: {-# UNPACK #-} !BS.ByteString  -- ^ send key (32 bytes)
  , sess_sn  :: {-# UNPACK #-} !Word64         -- ^ send nonce
  , sess_sck :: {-# UNPACK #-} !BS.ByteString  -- ^ send chaining key
  , sess_rk  :: {-# UNPACK #-} !BS.ByteString  -- ^ receive key (32 bytes)
  , sess_rn  :: {-# UNPACK #-} !Word64         -- ^ receive nonce
  , sess_rck :: {-# UNPACK #-} !BS.ByteString  -- ^ receive chaining key
  }
  deriving Generic

-- | Result of a successful handshake.
--
--   Returned by 'act3' (initiator) and 'finalize' (responder).
data Handshake = Handshake {
    session       :: !Session  -- ^ session state
  , remote_static :: !Pub      -- ^ authenticated remote static pubkey
  }
  deriving Generic

-- | Internal handshake state (exported for benchmarking).
--
--   Created by 'act1' or 'act2' and consumed by 'act3' or 'finalize'
--   respectively. The type is exported without its constructor or fields.
data HandshakeState = HandshakeState {
    hs_h      :: {-# UNPACK #-} !BS.ByteString  -- handshake hash (32 bytes)
  , hs_ck     :: {-# UNPACK #-} !BS.ByteString  -- chaining key (32 bytes)
  , hs_temp_k :: {-# UNPACK #-} !BS.ByteString  -- temp key (32 bytes)
  , hs_e_sec  :: !Sec                           -- ephemeral secret
  , hs_e_pub  :: !Pub                           -- ephemeral public
  , hs_s_sec  :: !Sec                           -- static secret
  , hs_s_pub  :: !Pub                           -- static public
  , hs_re     :: !(Maybe Pub)                   -- remote ephemeral
  , hs_rs     :: !(Maybe Pub)                   -- remote static
  }
  deriving Generic

-- protocol constants --------------------------------------------------------

-- Noise protocol name; 'init_handshake' hashes it to seed both the
-- handshake hash and the chaining key.
_PROTOCOL_NAME :: BS.ByteString
_PROTOCOL_NAME = "Noise_XK_secp256k1_ChaChaPoly_SHA256"

-- Prologue string mixed into the handshake hash by 'init_handshake',
-- right after the protocol name.
_PROLOGUE :: BS.ByteString
_PROLOGUE = "lightning"

-- key operations ------------------------------------------------------------

-- | Derive a keypair from 32 bytes of entropy.
--
--   Returns 'Nothing' if the input is not exactly 32 bytes, or if the
--   secp256k1 library rejects it as a secret key (e.g. zero or >= curve
--   order).
--
--   >>> let ent = BS.replicate 32 0x11
--   >>> case keypair ent of { Just _ -> "ok"; Nothing -> "fail" }
--   "ok"
--   >>> fmap snd (keypair (BS.replicate 31 0x11)) -- wrong length
--   Nothing
keypair :: BS.ByteString -> Maybe (Sec, Pub)
keypair ent = do
  guard (BS.length ent == 32)
  k <- Secp256k1.parse_int256 ent
  p <- Secp256k1.derive_pub k
  -- the secret is kept as the original bytes, not the parsed integer
  pure (Sec ent, Pub p)

-- | Parse a 33-byte compressed public key.
--
--   Returns 'Nothing' if the input is not exactly 33 bytes or does not
--   parse as a curve point.
--
--   >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
--   >>> let bytes = serialize_pub pub
--   >>> case parse_pub bytes of { Just _ -> "ok"; Nothing -> "fail" }
--   "ok"
--   >>> parse_pub (BS.replicate 32 0x00) -- wrong length
--   Nothing
parse_pub :: BS.ByteString -> Maybe Pub
parse_pub bs = do
  -- length is checked here, before the bytes reach the point parser
  guard (BS.length bs == 33)
  p <- Secp256k1.parse_point bs
  pure (Pub p)

-- | Serialize a public key to 33-byte compressed form.
--
--   >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
--   >>> BS.length (serialize_pub pub)
--   33
serialize_pub :: Pub -> BS.ByteString
serialize_pub (Pub p) = Secp256k1.serialize_point p

-- cryptographic primitives --------------------------------------------------

-- bolt8-style ECDH
--
-- Multiplies the public point by the secret scalar and returns the
-- SHA256 digest of the serialized result. Yields 'Nothing' if the
-- secret bytes do not parse or the multiplication fails.
ecdh :: Sec -> Pub -> Maybe BS.ByteString
ecdh (Sec sec) (Pub pub) = do
  k <- Secp256k1.parse_int256 sec
  pt <- Secp256k1.mul pub k
  let compressed = Secp256k1.serialize_point pt
  pure (SHA256.hash compressed)

-- h' = SHA256(h || data)
--
-- Folds 'dat' into the running handshake hash 'h'.
mix_hash :: BS.ByteString -> BS.ByteString -> BS.ByteString
mix_hash h dat = SHA256.hash (h <> dat)

-- Mix key: (ck', k) = HKDF(ck, input_key_material)
--
-- The chaining key is the HKDF salt and the info string is empty. The
-- 64 bytes of output are split into the new chaining key (first 32
-- bytes) and the derived key (last 32 bytes).
--
-- NB HKDF limits output to 255 * hashlen bytes. For SHA256 that's 8160,
-- well above the 64 bytes requested here, so 'Nothing' is impossible.
mix_key :: BS.ByteString -> BS.ByteString -> (BS.ByteString, BS.ByteString)
mix_key ck ikm = case HKDF.derive hmac ck mempty 64 ikm of
    Nothing -> error "ppad-bolt8: internal error, please report a bug!"
    Just output -> BS.splitAt 32 output
  where
    -- unwrap the MAC newtype so HKDF sees a plain ByteString function
    hmac :: BS.ByteString -> BS.ByteString -> BS.ByteString
    hmac k b = case SHA256.hmac k b of
      SHA256.MAC mac -> mac

-- Encrypt with associated data using ChaCha20-Poly1305
--
-- The counter is expanded to a 12-byte nonce via 'encode_nonce'. On
-- success the tag is appended to the ciphertext; any AEAD error is
-- collapsed to 'Nothing'.
encrypt_with_ad
  :: BS.ByteString       -- ^ key (32 bytes)
  -> Word64              -- ^ nonce
  -> BS.ByteString       -- ^ associated data
  -> BS.ByteString       -- ^ plaintext
  -> Maybe BS.ByteString -- ^ ciphertext || mac (16 bytes)
encrypt_with_ad key n ad pt =
  case AEAD.encrypt ad key (encode_nonce n) pt of
    Left _ -> Nothing
    Right (ct, mac) -> Just (ct <> mac)

-- Decrypt with associated data using ChaCha20-Poly1305
--
-- The input must hold at least the 16-byte tag; shorter inputs yield
-- 'Nothing' without calling the AEAD. Any AEAD error is likewise
-- collapsed to 'Nothing'.
decrypt_with_ad
  :: BS.ByteString       -- ^ key (32 bytes)
  -> Word64              -- ^ nonce
  -> BS.ByteString       -- ^ associated data
  -> BS.ByteString       -- ^ ciphertext || mac
  -> Maybe BS.ByteString -- ^ plaintext
decrypt_with_ad key n ad ctmac
  | BS.length ctmac < 16 = Nothing
  | otherwise =
      -- the trailing 16 bytes are the tag, everything before is ciphertext
      let (ct, mac) = BS.splitAt (BS.length ctmac - 16) ctmac
      in case AEAD.decrypt ad key (encode_nonce n) (ct, mac) of
           Left _ -> Nothing
           Right pt -> Just pt

-- Encode nonce as 96-bit value: 4 zero bytes + 8-byte little-endian
encode_nonce :: Word64 -> BS.ByteString
encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n

-- Little-endian 64-bit encoding
--
-- Emits the eight bytes of 'n', least significant first.
encode_le64 :: Word64 -> BS.ByteString
encode_le64 n = BS.pack
  [ fromIntegral (unsafeShiftR n s .&. 0xff)
  | s <- [0, 8 .. 56]  -- every shift is below 64, so unsafeShiftR is fine
  ]

-- Big-endian 16-bit encoding
--
-- Emits the high byte of 'n' followed by the low byte.
encode_be16 :: Word16 -> BS.ByteString
encode_be16 n =
  BS.pack [fromIntegral (unsafeShiftR n 8), fromIntegral (n .&. 0xff)]

-- Big-endian 16-bit decoding
--
-- Accepts exactly two bytes (high byte first); any other length yields
-- 'Nothing'.
decode_be16 :: BS.ByteString -> Maybe Word16
decode_be16 bs = case BS.unpack bs of
  [b0, b1] -> Just (fromIntegral b0 * 0x100 + fromIntegral b1)
  _        -> Nothing

-- handshake -----------------------------------------------------------------

-- Initialize handshake state
--
-- h = SHA256(protocol_name)
-- ck = h
-- h = SHA256(h || prologue)
-- h = SHA256(h || responder_static_pubkey)
--
-- The responder's static key is the supplied remote static for the
-- initiator and the local static for the responder. The temp key starts
-- as 32 zero bytes and the remote ephemeral is unset.
init_handshake
  :: Sec                -- ^ local static secret
  -> Pub                -- ^ local static public
  -> Sec                -- ^ ephemeral secret
  -> Pub                -- ^ ephemeral public
  -> Maybe Pub          -- ^ remote static (initiator knows, responder doesn't)
  -> Bool               -- ^ True if initiator
  -> HandshakeState
init_handshake s_sec s_pub e_sec e_pub m_rs is_initiator =
  let !h0 = SHA256.hash _PROTOCOL_NAME
      !ck = h0
      !h1 = mix_hash h0 _PROLOGUE
      -- Mix in responder's static pubkey
      !h2 = case (is_initiator, m_rs) of
        (True, Just rs)  -> mix_hash h1 (serialize_pub rs)
        (False, Nothing) -> mix_hash h1 (serialize_pub s_pub)
        -- 'act1' passes (True, Just _) and 'act2' passes (False, Nothing),
        -- so the remaining combinations are not reached from those callers
        _ -> h1  -- shouldn't happen
  in HandshakeState {
       hs_h      = h2
     , hs_ck     = ck
     , hs_temp_k = BS.replicate 32 0x00
     , hs_e_sec  = e_sec
     , hs_e_pub  = e_pub
     , hs_s_sec  = s_sec
     , hs_s_pub  = s_pub
     , hs_re     = Nothing
     , hs_rs     = m_rs
     }

-- | Initiator: generate Act 1 message (50 bytes).
--
--   Takes local static key, remote static pubkey, and 32 bytes of
--   entropy for ephemeral key generation.
--
--   Returns the 50-byte Act 1 message and handshake state for Act 3.
--
--   Fails with 'InvalidKey' if the entropy does not yield a keypair or
--   the ECDH step fails, and with 'InvalidMAC' if encryption fails.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let eph_ent = BS.replicate 32 0x12
--   >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   50
act1
  :: Sec                -- ^ local static secret
  -> Pub                -- ^ local static public
  -> Pub                -- ^ remote static public (responder's)
  -> BS.ByteString      -- ^ 32 bytes entropy for ephemeral
  -> Either Error (BS.ByteString, HandshakeState)
act1 s_sec s_pub rs ent = do
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  let !hs0 = init_handshake s_sec s_pub e_sec e_pub (Just rs) True
      !e_pub_bytes = serialize_pub e_pub
      -- h = SHA256(h || e.pub)
      !h1 = mix_hash (hs_h hs0) e_pub_bytes
  -- es = ECDH(e.priv, rs)
  es <- note InvalidKey (ecdh e_sec rs)
  let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
  -- empty plaintext: the output is just the 16-byte tag over h
  c <- note InvalidMAC (encrypt_with_ad temp_k1 0 h1 BS.empty)
  let !h2 = mix_hash h1 c
      -- version byte || ephemeral pubkey || tag
      !msg = BS.singleton 0x00 <> e_pub_bytes <> c
      !hs1 = hs0 {
          hs_h      = h2
        , hs_ck     = ck1
        , hs_temp_k = temp_k1
        }
  pure (msg, hs1)

-- | Responder: process Act 1 and generate Act 2 message (50 bytes).
--
--   Takes local static key and 32 bytes of entropy for ephemeral key,
--   plus the 50-byte Act 1 message from initiator.
--
--   Returns the 50-byte Act 2 message and handshake state for finalize.
--
--   Fails with 'InvalidLength' if the message is not 50 bytes,
--   'InvalidVersion' if its first byte is not @0x00@, 'InvalidPub' if
--   the initiator's ephemeral key does not parse, 'InvalidKey' if the
--   entropy does not yield a keypair or an ECDH step fails, and
--   'InvalidMAC' if the Act 1 tag does not verify or encryption fails.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   50
act2
  :: Sec                -- ^ local static secret
  -> Pub                -- ^ local static public
  -> BS.ByteString      -- ^ 32 bytes entropy for ephemeral
  -> BS.ByteString      -- ^ Act 1 message (50 bytes)
  -> Either Error (BS.ByteString, HandshakeState)
act2 s_sec s_pub ent msg1 = do
  require (BS.length msg1 == 50) InvalidLength
  -- layout: version (1) || initiator ephemeral pubkey (33) || tag (16);
  -- the length check above makes the index safe
  let !version = BS.index msg1 0
      !re_bytes = BS.take 33 (BS.drop 1 msg1)
      !c = BS.drop 34 msg1
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  let !hs0 = init_handshake s_sec s_pub e_sec e_pub Nothing False
      !h1 = mix_hash (hs_h hs0) re_bytes
  -- es = ECDH(s.priv, re)
  es <- note InvalidKey (ecdh s_sec re)
  let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
  -- only the tag check matters here; the plaintext is discarded
  _ <- note InvalidMAC (decrypt_with_ad temp_k1 0 h1 c)
  let !h2 = mix_hash h1 c
      !e_pub_bytes = serialize_pub e_pub
      !h3 = mix_hash h2 e_pub_bytes
  -- ee = ECDH(e.priv, re)
  ee <- note InvalidKey (ecdh e_sec re)
  let !(ck2, temp_k2) = mix_key ck1 ee
  c2 <- note InvalidMAC (encrypt_with_ad temp_k2 0 h3 BS.empty)
  let !h4 = mix_hash h3 c2
      -- version byte || ephemeral pubkey || tag
      !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
      !hs1 = hs0 {
          hs_h      = h4
        , hs_ck     = ck2
        , hs_temp_k = temp_k2
        , hs_re     = Just re
        }
  pure (msg, hs1)

-- | Initiator: process Act 2 and generate Act 3 (66 bytes), completing
--   the handshake.
--
--   Returns the 66-byte Act 3 message and the handshake result.
--
--   Fails with 'InvalidLength' if the message is not 50 bytes,
--   'InvalidVersion' if its first byte is not @0x00@, 'InvalidPub' if
--   the responder's ephemeral key does not parse or the state carries no
--   remote static key, 'InvalidKey' if an ECDH step fails, and
--   'InvalidMAC' if the Act 2 tag does not verify or encryption fails.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   66
act3
  :: HandshakeState     -- ^ state after Act 1
  -> BS.ByteString      -- ^ Act 2 message (50 bytes)
  -> Either Error (BS.ByteString, Handshake)
act3 hs msg2 = do
  require (BS.length msg2 == 50) InvalidLength
  -- layout: version (1) || responder ephemeral pubkey (33) || tag (16);
  -- the length check above makes the index safe
  let !version = BS.index msg2 0
      !re_bytes = BS.take 33 (BS.drop 1 msg2)
      !c = BS.drop 34 msg2
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  let !h1 = mix_hash (hs_h hs) re_bytes
  -- ee = ECDH(e.priv, re)
  ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
  let !(ck1, temp_k2) = mix_key (hs_ck hs) ee
  -- only the tag check matters here; the plaintext is discarded
  _ <- note InvalidMAC (decrypt_with_ad temp_k2 0 h1 c)
  let !h2 = mix_hash h1 c
      !s_pub_bytes = serialize_pub (hs_s_pub hs)
  -- temp_k2 was already used with nonce 0 above, hence nonce 1 here
  c3 <- note InvalidMAC (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
  let !h3 = mix_hash h2 c3
  -- se = ECDH(s.priv, re)
  se <- note InvalidKey (ecdh (hs_s_sec hs) re)
  let !(ck2, temp_k3) = mix_key ck1 se
  t <- note InvalidMAC (encrypt_with_ad temp_k3 0 h3 BS.empty)
  -- initiator takes the first HKDF half as send key, the second as receive
  let !(sk, rk) = mix_key ck2 BS.empty
      -- version byte || encrypted static pubkey (49) || tag (16)
      !msg = BS.singleton 0x00 <> c3 <> t
      !sess = Session {
          sess_sk  = sk
        , sess_sn  = 0
        , sess_sck = ck2
        , sess_rk  = rk
        , sess_rn  = 0
        , sess_rck = ck2
        }
  rs <- note InvalidPub (hs_rs hs)
  let !result = Handshake {
          session       = sess
        , remote_static = rs
        }
  pure (msg, result)

-- | Responder: process Act 3 (66 bytes) and complete the handshake.
--
--   Returns the handshake result with authenticated remote static pubkey.
--
--   Fails with 'InvalidLength' if the message is not 66 bytes,
--   'InvalidVersion' if its first byte is not @0x00@, 'InvalidMAC' if
--   either ciphertext fails to authenticate, 'InvalidPub' if the
--   decrypted static key does not parse, and 'InvalidKey' if the ECDH
--   step fails.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> let Right (msg3, _) = act3 i_hs msg2
--   >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
--   "ok"
finalize
  :: HandshakeState     -- ^ state after Act 2
  -> BS.ByteString      -- ^ Act 3 message (66 bytes)
  -> Either Error Handshake
finalize hs msg3 = do
  require (BS.length msg3 == 66) InvalidLength
  -- layout: version (1) || encrypted static pubkey (49) || tag (16);
  -- the length check above makes the index safe
  let !version = BS.index msg3 0
      !c = BS.take 49 (BS.drop 1 msg3)
      !t = BS.drop 50 msg3
  require (version == 0x00) InvalidVersion
  -- the stored temp key was used with nonce 0 in Act 2, hence nonce 1
  rs_bytes <- note InvalidMAC (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
  rs <- note InvalidPub (parse_pub rs_bytes)
  let !h1 = mix_hash (hs_h hs) c
  -- se = ECDH(e.priv, rs)
  se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
  let !(ck1, temp_k3) = mix_key (hs_ck hs) se
  -- only the tag check matters here; the plaintext is discarded
  _ <- note InvalidMAC (decrypt_with_ad temp_k3 0 h1 t)
  -- responder swaps order (receives what initiator sends)
  let !(rk, sk) = mix_key ck1 BS.empty
      !sess = Session {
          sess_sk  = sk
        , sess_sn  = 0
        , sess_sck = ck1
        , sess_rk  = rk
        , sess_rn  = 0
        , sess_rck = ck1
        }
      !result = Handshake {
          session       = sess
        , remote_static = rs
        }
  pure result

-- message encryption --------------------------------------------------------

-- | Encrypt a message (max 65535 bytes).
--
--   Returns the encrypted packet and updated session. Key rotation
--   is handled automatically at nonce 1000. Each call consumes two
--   nonces: one for the length prefix and one for the body.
--
--   Fails with 'InvalidLength' if the plaintext exceeds 65535 bytes,
--   or 'InvalidMAC' if the underlying AEAD encryption fails.
--
--   Wire format: encrypted_length (2) || MAC (16) || encrypted_body || MAC (16)
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> let Right (_, i_result) = act3 i_hs msg2
--   >>> let sess = session i_result
--   >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
--   39
encrypt
  :: Session
  -> BS.ByteString          -- ^ plaintext (max 65535 bytes)
  -> Either Error (BS.ByteString, Session)
encrypt sess pt = do
  let !len = BS.length pt
  -- the length prefix is a 16-bit big-endian integer
  require (len <= 65535) InvalidLength
  let !len_bytes = encode_be16 (fi len)
  -- encrypt the length prefix under the current sending key and nonce
  lc <- note InvalidMAC (encrypt_with_ad (sess_sk sess) (sess_sn sess)
                           BS.empty len_bytes)
  -- advance the nonce (possibly rotating the key) before the body
  let !(sn1, sck1, sk1) = step_nonce (sess_sn sess) (sess_sck sess) (sess_sk sess)
  bc <- note InvalidMAC (encrypt_with_ad sk1 sn1 BS.empty pt)
  -- advance once more so the next message starts on a fresh nonce
  let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1
      !packet = lc <> bc
      !sess' = sess {
        sess_sk  = sk2
      , sess_sn  = sn2
      , sess_sck = sck2
      }
  pure (packet, sess')

-- | Decrypt a message, requiring an exact packet with no trailing bytes.
--
--   Returns the plaintext and updated session. Key rotation
--   is handled automatically at nonce 1000.
--
--   This is a strict variant that rejects any trailing data with
--   'InvalidLength'. For streaming use cases where you need to handle
--   multiple frames in a buffer, use 'decrypt_frame' instead. Any error
--   produced by 'decrypt_frame' is propagated unchanged.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> let Right (msg3, i_result) = act3 i_hs msg2
--   >>> let Right r_result = finalize r_hs msg3
--   >>> let Right (ct, _) = encrypt (session i_result) "hello"
--   >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
--   "hello"
decrypt
  :: Session
  -> BS.ByteString          -- ^ encrypted packet (exact length required)
  -> Either Error (BS.ByteString, Session)
decrypt sess packet = do
  (pt, remainder, sess') <- decrypt_frame sess packet
  -- exactly one frame is expected; leftover bytes are an error
  require (BS.null remainder) InvalidLength
  pure (pt, sess')

-- | Decrypt a single frame from a buffer, returning the remainder.
--
--   Returns the plaintext, any unconsumed bytes, and the updated session.
--   Key rotation is handled automatically: each frame consumes two
--   nonces (length prefix and body), and a key is rotated once it has
--   been used for 1000 decryptions.
--
--   This is useful for streaming scenarios where multiple messages may
--   be buffered together. The remainder can be passed to the next call
--   to 'decrypt_frame'.
--
--   Fails with 'InvalidLength' if the buffer is shorter than 34 bytes,
--   if the decrypted length prefix cannot be decoded, or if the buffer
--   does not hold the full body; fails with 'InvalidMAC' if either
--   authenticated decryption fails.
--
--   Wire format consumed: encrypted_length (18) || encrypted_body (len + 16)
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> let Right (msg3, i_result) = act3 i_hs msg2
--   >>> let Right r_result = finalize r_hs msg3
--   >>> let Right (ct, _) = encrypt (session i_result) "hello"
--   >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
--   ("hello",True)
decrypt_frame
  :: Session
  -> BS.ByteString          -- ^ buffer containing at least one encrypted frame
  -> Either Error (BS.ByteString, BS.ByteString, Session)
decrypt_frame sess packet = do
  -- smallest possible frame: 18-byte length block plus a 16-byte body MAC
  require (BS.length packet >= 34) InvalidLength
  let !lc = BS.take 18 packet
      !rest = BS.drop 18 packet
  len_bytes <- note InvalidMAC (decrypt_with_ad (sess_rk sess) (sess_rn sess)
                                  BS.empty lc)
  len <- note InvalidLength (decode_be16 len_bytes)
  -- advance the nonce (possibly rotating the key) before the body
  let !(rn1, rck1, rk1) = step_nonce (sess_rn sess) (sess_rck sess) (sess_rk sess)
      -- body ciphertext is the plaintext length plus the 16-byte MAC
      !body_len = fi len + 16
  require (BS.length rest >= body_len) InvalidLength
  let !bc = BS.take body_len rest
      !remainder = BS.drop body_len rest
  pt <- note InvalidMAC (decrypt_with_ad rk1 rn1 BS.empty bc)
  -- advance once more so the next frame starts on a fresh nonce
  let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
      !sess' = sess {
        sess_rk  = rk2
      , sess_rn  = rn2
      , sess_rck = rck2
      }
  pure (pt, remainder, sess')

-- | Decrypt a frame from a partial buffer, indicating when more data needed.
--
--   Unlike 'decrypt_frame', this function handles incomplete buffers
--   gracefully by returning 'NeedMore' with the number of additional
--   bytes required to make progress.
--
--   * If the buffer has fewer than 18 bytes (encrypted length + MAC),
--     returns @'NeedMore' n@ where @n@ is the bytes still needed.
--   * If the length header is complete but the body is incomplete,
--     returns @'NeedMore' n@ with bytes needed for the full frame.
--   * MAC or decryption failures return 'FrameError'.
--   * A complete, valid frame returns 'FrameOk' with plaintext,
--     remainder, and updated session.
--
--   The session is only advanced in the 'FrameOk' case; after
--   'NeedMore', call again with the same session and a longer buffer.
--
--   This is useful for non-blocking I/O where data arrives incrementally.
decrypt_frame_partial
  :: Session
  -> BS.ByteString  -- ^ buffer (possibly incomplete)
  -> FrameResult
decrypt_frame_partial sess buf
  | buflen < 18 = NeedMore (18 - buflen)
  | otherwise =
      let !lc = BS.take 18 buf
          !rest = BS.drop 18 buf
      in case decrypt_with_ad (sess_rk sess) (sess_rn sess) BS.empty lc of
           Nothing -> FrameError InvalidMAC
           Just len_bytes -> case decode_be16 len_bytes of
             Nothing -> FrameError InvalidLength
             Just len ->
               -- body ciphertext is the plaintext length plus the 16-byte MAC
               let !body_len = fi len + 16
                   !(rn1, rck1, rk1) = step_nonce (sess_rn sess)
                                         (sess_rck sess) (sess_rk sess)
               in if BS.length rest < body_len
                    then NeedMore (body_len - BS.length rest)
                    else
                      let !bc = BS.take body_len rest
                          !remainder = BS.drop body_len rest
                      in case decrypt_with_ad rk1 rn1 BS.empty bc of
                           Nothing -> FrameError InvalidMAC
                           Just pt ->
                             let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
                                 !sess' = sess {
                                   sess_rk  = rk2
                                 , sess_rn  = rn2
                                 , sess_rck = rck2
                                 }
                             in FrameOk pt remainder sess'
  where
    !buflen = BS.length buf

-- key rotation --------------------------------------------------------------

-- Advance a (nonce, chaining key, key) triple by one use.
--
-- Key rotation occurs after nonce reaches 1000 (i.e., before using 1000):
-- (ck', k') = HKDF(ck, k), reset nonce to 0. Otherwise only the nonce
-- is incremented and the chaining key and key are returned unchanged.
step_nonce
  :: Word64
  -> BS.ByteString
  -> BS.ByteString
  -> (Word64, BS.ByteString, BS.ByteString)
step_nonce n ck k
  | n + 1 == 1000 =
      let !(ck', k') = mix_key ck k
      in (0, ck', k')
  | otherwise = (n + 1, ck, k)

-- utilities -----------------------------------------------------------------

-- Lift Maybe to Either: 'Nothing' becomes 'Left' with the supplied
-- error, 'Just' becomes 'Right'.
note :: e -> Maybe a -> Either e a
note e = maybe (Left e) Right
{-# INLINE note #-}

-- Require condition or fail: yields 'Right ()' when the condition
-- holds, otherwise 'Left' with the supplied error.
require :: Bool -> e -> Either e ()
require cond e = unless cond (Left e)
{-# INLINE require #-}

-- Shorthand for 'fromIntegral'.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}