{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: Crypto.AEAD.ChaCha20Poly1305
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- A pure AEAD-ChaCha20-Poly1305 implementation, as specified by
-- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).

module Crypto.AEAD.ChaCha20Poly1305 (
    -- * AEAD construction
    encrypt
  , decrypt

    -- * Error information
  , Error(..)

    -- testing
  , _poly1305_key_gen
  ) where

import qualified Crypto.Cipher.ChaCha20 as ChaCha20
import qualified Crypto.MAC.Poly1305 as Poly1305
import Data.Bits ((.>>.), (.|.), xor)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BI
import Data.Word (Word64)

-- shorthand for 'fromIntegral', used for the integral conversions in this
-- module
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}

-- little-endian bytestring encoding: least-significant byte first, with no
-- trailing zero bytes. zero itself is encoded as a single zero byte.
unroll :: Word64 -> BS.ByteString
unroll 0 = BS.singleton 0
unroll w = BS.unfoldr step w
  where
    -- emit the low byte and carry on with the remaining high bits; stop
    -- once nothing but zeros is left
    step 0 = Nothing
    step v = Just $! (fromIntegral v, v .>>. 8)
{-# INLINE unroll #-}

-- little-endian bytestring encoding for 64-bit ints: the output of 'unroll',
-- right-padded with zero bytes whenever it is shorter than eight bytes
unroll8 :: Word64 -> BS.ByteString
unroll8 w
    | missing > 0 = bytes <> BS.replicate missing 0
    | otherwise   = bytes
  where
    bytes   = unroll w
    -- number of zero bytes still needed to reach a length of eight
    missing = 8 - BS.length bytes
{-# INLINE unroll8 #-}

-- RFC8439 2.6

-- | Derive a one-time Poly1305 key: the first 32 bytes of the ChaCha20
--   block produced from the given key and nonce at block counter 0.
--
--   Errors reported by the ChaCha20 block function are translated to this
--   module's 'Error' type.
_poly1305_key_gen
  :: BS.ByteString -- ^ 256-bit initial keying material
  -> BS.ByteString -- ^ 96-bit nonce
  -> Either Error BS.ByteString -- ^ 256-bit key (suitable for poly1305)
_poly1305_key_gen key nonce =
  case ChaCha20.block key 0 nonce of
    Right blk                  -> Right (BS.take 32 blk)
    Left ChaCha20.InvalidKey   -> Left InvalidKey
    Left ChaCha20.InvalidNonce -> Left InvalidNonce
{-# INLINEABLE _poly1305_key_gen #-}

-- the zero bytes needed to extend the argument's length to the next
-- multiple of 16 (empty when it is already a multiple of 16). only the
-- padding is returned, not the padded input.
pad16 :: BS.ByteString -> BS.ByteString
pad16 bs = case BS.length bs `rem` 16 of
  0 -> mempty
  r -> BS.replicate (16 - r) 0
{-# INLINE pad16 #-}

-- | Reasons 'encrypt' or 'decrypt' can fail.
data Error
  = InvalidKey
    -- ^ The key is not 32 bytes long, or was rejected by the underlying
    --   ChaCha20 or Poly1305 primitive.
  | InvalidNonce
    -- ^ The nonce is not 12 bytes long, or was rejected by the underlying
    --   ChaCha20 primitive.
  | InvalidMAC
    -- ^ The supplied MAC is not 16 bytes long, or does not match the tag
    --   computed during decryption.
  deriving (Eq, Show)

-- RFC8439 2.8

-- | Perform authenticated encryption on a plaintext and some additional
--   authenticated data, given a 256-bit key and 96-bit nonce, using
--   AEAD-ChaCha20-Poly1305.
--
--   Produces a ciphertext and 128-bit message authentication code pair.
--
--   Returns 'InvalidKey' if the key is not 32 bytes long and
--   'InvalidNonce' if the nonce is not 12 bytes long.
--
--   >>> let key = "don't tell anyone my secret key!"
--   >>> let non = "or my nonce!"
--   >>> let pan = "and here's my plaintext"
--   >>> let aad = "i approve this message"
--   >>> let Right (cip, mac) = encrypt aad key non pan
--   >>> (cip, mac)
--   <(ciphertext, 128-bit MAC)>
encrypt
  :: BS.ByteString -- ^ arbitrary-length additional authenticated data
  -> BS.ByteString -- ^ 256-bit key
  -> BS.ByteString -- ^ 96-bit nonce
  -> BS.ByteString -- ^ arbitrary-length plaintext
  -> Either Error (BS.ByteString, BS.ByteString) -- ^ (ciphertext, 128-bit MAC)
encrypt aad key nonce plaintext
  | BS.length key   /= 32 = Left InvalidKey
  | BS.length nonce /= 12 = Left InvalidNonce
  | otherwise = do
      -- one-time Poly1305 key, derived from block counter 0
      otk <- _poly1305_key_gen key nonce
      -- the payload itself is encrypted starting at block counter 1
      case ChaCha20.cipher key 1 nonce plaintext of
        Left ChaCha20.InvalidKey   -> Left InvalidKey   -- impossible, but..
        Left ChaCha20.InvalidNonce -> Left InvalidNonce -- ditto
        Right cip -> do
          -- MAC input: aad and ciphertext, each followed by zero padding up
          -- to a multiple of 16 bytes, then both lengths as 8-byte
          -- little-endian integers. built with a single 'mconcat' rather
          -- than chained '<>' so the pieces are copied only once.
          let msg = mconcat
                [ aad
                , pad16 aad
                , cip
                , pad16 cip
                , unroll8 (fi (BS.length aad))
                , unroll8 (fi (BS.length cip))
                ]
          case Poly1305.mac otk msg of
            Nothing  -> Left InvalidKey
            Just tag -> pure (cip, tag)

-- | Decrypt an authenticated ciphertext, given a message authentication
--   code and some additional authenticated data, via a 256-bit key and
--   96-bit nonce.
--
--   The MAC is verified before any decryption takes place. Returns
--   'InvalidKey' if the key is not 32 bytes long, 'InvalidNonce' if the
--   nonce is not 12 bytes long, and 'InvalidMAC' if the supplied MAC is
--   not 16 bytes long or does not match the computed tag.
--
--   >>> decrypt aad key non (cip, mac)
--   Right "and here's my plaintext"
--   >>> decrypt aad key non (cip, "it's a valid mac")
--   Left InvalidMAC
decrypt
  :: BS.ByteString                  -- ^ arbitrary-length AAD
  -> BS.ByteString                  -- ^ 256-bit key
  -> BS.ByteString                  -- ^ 96-bit nonce
  -> (BS.ByteString, BS.ByteString) -- ^ (arbitrary-length ciphertext, 128-bit MAC)
  -> Either Error BS.ByteString
decrypt aad key nonce (cip, mac)
  | BS.length key   /= 32 = Left InvalidKey
  | BS.length nonce /= 12 = Left InvalidNonce
  | BS.length mac   /= 16 = Left InvalidMAC
  | otherwise = do
      -- one-time Poly1305 key, derived from block counter 0
      otk <- _poly1305_key_gen key nonce
      -- MAC input: aad and ciphertext, each followed by zero padding up to
      -- a multiple of 16 bytes, then both lengths as 8-byte little-endian
      -- integers. built with a single 'mconcat' rather than chained '<>'
      -- so the pieces are copied only once.
      let msg = mconcat
            [ aad
            , pad16 aad
            , cip
            , pad16 cip
            , unroll8 (fi (BS.length aad))
            , unroll8 (fi (BS.length cip))
            ]
      case Poly1305.mac otk msg of
        Nothing -> Left InvalidKey
        Just tag
          | ct_eq mac tag -> case ChaCha20.cipher key 1 nonce cip of
              Left ChaCha20.InvalidKey   -> Left InvalidKey
              Left ChaCha20.InvalidNonce -> Left InvalidNonce
              Right v                    -> pure v
          | otherwise -> Left InvalidMAC
  where
    -- tag comparison that always inspects every byte: the xor of each byte
    -- pair is or-ed into an accumulator, which is zero only when all pairs
    -- agree. ByteString's '==' may stop at the first differing byte, and
    -- that early exit can leak how much of a forged tag was correct. the
    -- length check is fine to short-circuit on, as lengths are not secret.
    ct_eq :: BS.ByteString -> BS.ByteString -> Bool
    ct_eq a b =
         BS.length a == BS.length b
      && foldr (.|.) 0 (BS.zipWith xor a b) == 0