{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: Crypto.AEAD.ChaCha20Poly1305
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- A pure AEAD-ChaCha20-Poly1305 implementation, as specified by
-- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).

module Crypto.AEAD.ChaCha20Poly1305 (
    -- * AEAD construction
    encrypt
  , decrypt

    -- testing
  , _poly1305_key_gen
  ) where

import qualified Crypto.Cipher.ChaCha20 as ChaCha20
import qualified Crypto.MAC.Poly1305 as Poly1305
import Data.Bits (xor, (.|.), (.>>.))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BI
import Data.Word (Word64)

-- shorthand for 'fromIntegral'; converts between integral/numeric types
-- (narrowing conversions wrap, exactly as 'fromIntegral' does)
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}

-- little-endian bytestring encoding, using the minimal number of bytes
--
-- zero is encoded as the single byte 0x00; any other value is emitted
-- least-significant byte first, with no trailing zero bytes
unroll :: Word64 -> BS.ByteString
unroll i = case i of
    0 -> BS.singleton 0
    _ -> BS.unfoldr coalg i
  where
    -- emit the low byte and shift it out, stopping once nothing remains
    coalg = \case
      0 -> Nothing
      m -> Just $! (fi m, m .>>. 8)
{-# INLINE unroll #-}

-- little-endian bytestring encoding for 64-bit ints, right-padding with
-- zeros so that the result is at least 8 bytes long
unroll8 :: Word64 -> BS.ByteString
unroll8 w
    | l < 8     = u <> BS.replicate (8 - l) 0
    | otherwise = u
  where
    -- minimal-length encoding, and how many bytes it occupies
    u = unroll w
    l = BS.length u
{-# INLINE unroll8 #-}

-- RFC8439 2.6

-- Derive a one-time Poly1305 key: the first 32 bytes of the ChaCha20
-- block produced for the given key and nonce at block counter 0.
--
-- Throws an 'ErrorCall' if the key is not exactly 32 bytes long. The
-- nonce is passed to 'ChaCha20.block' as-is; its length is not checked
-- here.
_poly1305_key_gen
  :: BS.ByteString -- ^ 256-bit initial keying material
  -> BS.ByteString -- ^ 96-bit nonce
  -> BS.ByteString -- ^ 256-bit key (suitable for poly1305)
_poly1305_key_gen key nonce
  | BS.length key /= 32 = error "ppad-aead (poly1305_key_gen): invalid key"
  | otherwise           = BS.take 32 (ChaCha20.block key 0 nonce)
{-# INLINEABLE _poly1305_key_gen #-}

-- the zero bytes needed to extend the input to a multiple of 16 bytes
--
-- NB returns only the padding, not the padded input; the result is
-- empty when the input length is already a multiple of 16
pad16 :: BS.ByteString -> BS.ByteString
pad16 bs = case BS.length bs `rem` 16 of
  0 -> mempty
  r -> BS.replicate (16 - r) 0
{-# INLINE pad16 #-}

-- RFC8439 2.8

-- | Perform authenticated encryption on a plaintext and some additional
--   authenticated data, given a 256-bit key and 96-bit nonce, using
--   AEAD-ChaCha20-Poly1305.
--
--   Produces a ciphertext and 128-bit message authentication code pair.
--
--   Providing an invalid key or nonce will result in an 'ErrorCall'
--   exception being thrown.
--
--   >>> let key = "don't tell anyone my secret key!"
--   >>> let non = "or my nonce!"
--   >>> let pan = "and here's my plaintext"
--   >>> let aad = "i approve this message"
--   >>> let (cip, mac) = encrypt aad key non pan
--   >>> (cip, mac)
--   <(ciphertext, 128-bit MAC)>
encrypt
  :: BS.ByteString -- ^ arbitrary-length additional authenticated data
  -> BS.ByteString -- ^ 256-bit key
  -> BS.ByteString -- ^ 96-bit nonce
  -> BS.ByteString -- ^ arbitrary-length plaintext
  -> (BS.ByteString, BS.ByteString) -- ^ (ciphertext, 128-bit MAC)
encrypt aad key nonce plaintext
  | BS.length key   /= 32 = error "ppad-aead (encrypt): invalid key"
  | BS.length nonce /= 12 = error "ppad-aead (encrypt): invalid nonce"
  | otherwise =
      let -- one-time poly1305 key, derived at block counter 0
          otk = _poly1305_key_gen key nonce
          -- the keystream for the payload starts at block counter 1
          cip = ChaCha20.cipher key 1 nonce plaintext
          -- MAC input: aad and ciphertext, each followed by its padding
          -- to a 16-byte boundary, then both lengths as 8-byte
          -- little-endian integers
          msg = mconcat [
              aad, pad16 aad
            , cip, pad16 cip
            , unroll8 (fi (BS.length aad))
            , unroll8 (fi (BS.length cip))
            ]
          tag = Poly1305.mac otk msg
      in  (cip, tag)

-- | Decrypt an authenticated ciphertext, given a message authentication
--   code and some additional authenticated data, via a 256-bit key and
--   96-bit nonce.
--
--   Returns 'Nothing' if the MAC fails to validate (including when the
--   supplied MAC is not exactly 16 bytes long).
--
--   Providing an invalid key or nonce will result in an 'ErrorCall'
--   exception being thrown.
--
--   >>> decrypt aad key non (cip, mac)
--   Just "and here's my plaintext"
--   >>> decrypt aad key non (cip, "it's a valid mac")
--   Nothing
decrypt
  :: BS.ByteString                  -- ^ arbitrary-length AAD
  -> BS.ByteString                  -- ^ 256-bit key
  -> BS.ByteString                  -- ^ 96-bit nonce
  -> (BS.ByteString, BS.ByteString) -- ^ (arbitrary-length ciphertext, 128-bit MAC)
  -> Maybe BS.ByteString
decrypt aad key nonce (cip, mac)
  | BS.length key   /= 32 = error "ppad-aead (decrypt): invalid key"
  | BS.length nonce /= 12 = error "ppad-aead (decrypt): invalid nonce"
  | BS.length mac   /= 16 = Nothing
  | otherwise =
      let -- one-time poly1305 key, derived at block counter 0
          otk = _poly1305_key_gen key nonce
          -- MAC input: aad and ciphertext, each followed by its padding
          -- to a 16-byte boundary, then both lengths as 8-byte
          -- little-endian integers
          msg = mconcat [
              aad, pad16 aad
            , cip, pad16 cip
            , unroll8 (fi (BS.length aad))
            , unroll8 (fi (BS.length cip))
            ]
          tag = Poly1305.mac otk msg
      in  if   ct_eq mac tag
          -- the payload keystream starts at block counter 1
          then pure (ChaCha20.cipher key 1 nonce cip)
          else Nothing
  where
    -- Compare two bytestrings without exiting at the first mismatch:
    -- every byte pair is XORed and the differences are OR-ed together,
    -- so the work done does not depend on where (or whether) the inputs
    -- differ. Plain (==) on ByteString may return early, which would
    -- leak how much of a forged MAC was correct.
    ct_eq :: BS.ByteString -> BS.ByteString -> Bool
    ct_eq a b =
         BS.length a == BS.length b
      && BS.foldl' (.|.) 0 (BS.pack (BS.zipWith xor a b)) == 0