{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE UnboxedTuples #-}

-- |
-- Module: Crypto.MAC.Poly1305
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- A pure Poly1305 MAC implementation, as specified by
-- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).

module Crypto.MAC.Poly1305 (
    -- * Poly1305 message authentication code
    mac

    -- testing
  , _poly1305_loop
  , _roll16
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import Data.Word (Word8)
import Data.Word.Limb (Limb(..))
import qualified Data.Word.Limb as L
import Data.Word.Wider (Wider(..))
import qualified Data.Word.Wider as W
import qualified Foreign.Storable as Storable (pokeByteOff)
import qualified GHC.Exts as Exts
import qualified GHC.Word (Word8(..))

-- utilities ------------------------------------------------------------------

-- convert a Word8 to a Limb, zero-extending the byte into the full
-- machine word
limb :: Word8 -> Limb
limb (GHC.Word.W8# (Exts.word8ToWord# -> w)) = Limb w
{-# INLINABLE limb #-}

-- convert a Limb to a Word8, keeping only its least-significant byte
-- (wordToWord8# narrows, discarding the higher bits)
word8 :: Limb -> Word8
word8 (Limb w) = GHC.Word.W8# (Exts.wordToWord8# w)
{-# INLINABLE word8 #-}

-- convert a Limb to a Word8 after right-shifting it by s bits, i.e. extract
-- the byte whose lowest bit sits at bit offset s of the limb
word8s :: Limb -> Exts.Int# -> Word8
word8s l s =
  let !(Limb w) = L.shr# l s
  in  GHC.Word.W8# (Exts.wordToWord8# w)
{-# INLINABLE word8s #-}

-- 128-bit little-endian bytestring decoding
--
-- Bytes 0-7 form the low limb and bytes 8-15 the next limb, least
-- significant byte first. A shorter input is implicitly zero-padded (any
-- index at or past its length reads as zero), bytes beyond index 15 are
-- never read, and the two high limbs of the result are always zero.
_roll16 :: BS.ByteString -> Wider
_roll16 bs@(BI.PS _ _ l) =
  let -- the byte at index i as a Limb, or zero past the end of the input;
      -- the bounds check is what keeps 'BU.unsafeIndex' in range
      byte :: Int -> Limb
      byte i
        | i < l     = limb (BU.unsafeIndex bs i)
        | otherwise = Limb 0##
      {-# INLINE byte #-}

      -- little-endian 64-bit word assembled from bytes [o, o + 8)
      lane :: Int -> Limb
      lane o =
                (byte (o + 7) `L.shl#` 56#)
        `L.or#` (byte (o + 6) `L.shl#` 48#)
        `L.or#` (byte (o + 5) `L.shl#` 40#)
        `L.or#` (byte (o + 4) `L.shl#` 32#)
        `L.or#` (byte (o + 3) `L.shl#` 24#)
        `L.or#` (byte (o + 2) `L.shl#` 16#)
        `L.or#` (byte (o + 1) `L.shl#` 08#)
        `L.or#` byte o
      {-# INLINE lane #-}

      !w0 = lane 0
      !w1 = lane 8
  in  Wider (# w0, w1, Limb 0##, Limb 0## #)
{-# INLINE _roll16 #-}

-- 128-bit little-endian bytestring encoding
--
-- Writes the low limb to bytes 0-7 and the next limb to bytes 8-15, least
-- significant byte first. The two high limbs are ignored, so the 16-byte
-- output encodes the input reduced modulo 2^128.
unroll16 :: Wider -> BS.ByteString
unroll16 (Wider (# w0, w1, _, _ #)) =
  BI.unsafeCreate 16 $ \ptr -> do
    let -- store the eight bytes of a limb, least significant first,
        -- at byte offsets [o, o + 8) of the output buffer
        poke8 :: Int -> Limb -> IO ()
        poke8 o w = do
          Storable.pokeByteOff ptr  o      (word8  w)
          Storable.pokeByteOff ptr (o + 1) (word8s w 08#)
          Storable.pokeByteOff ptr (o + 2) (word8s w 16#)
          Storable.pokeByteOff ptr (o + 3) (word8s w 24#)
          Storable.pokeByteOff ptr (o + 4) (word8s w 32#)
          Storable.pokeByteOff ptr (o + 5) (word8s w 40#)
          Storable.pokeByteOff ptr (o + 6) (word8s w 48#)
          Storable.pokeByteOff ptr (o + 7) (word8s w 56#)
    poke8 0 w0
    poke8 8 w1
{-# INLINABLE unroll16 #-}

-- set high bit for chunk of length l (max 16)
--
-- Produces 2^(8 * l): a single 1 bit immediately above the last byte of an
-- l-byte chunk. The bit lands in limb 0 for l < 8 (via 'W.shl_limb'), in
-- limb 1 for 8 <= l < 16, and at bit 128 otherwise; any l >= 16 therefore
-- yields 2^128.
set_hi :: Int -> Wider
set_hi l
  | l < 8     = W.shl_limb 1 (8 * l)
  | l < 16    =
      -- shift amount within limb 1; only meaningful (non-negative) here
      let !(Exts.I# s) = 8 * (l - 8)
      in  Wider (# Limb 0##, L.shl# (Limb 1##) s, Limb 0##, Limb 0## #)
  | otherwise = Wider (# Limb 0##, Limb 0##, Limb 1##, Limb 0## #)
{-# INLINE set_hi #-}

-- bespoke constant-time 130-bit right shift
--
-- Bits 0-127 (limbs 0 and 1) are shifted out entirely, so only limbs 2 and
-- 3 contribute: the new low limb is limb 2 shifted down by 2 with the low
-- two bits of limb 3 moved into its top, and the new second limb is limb 3
-- shifted down by 2. Fixed shifts only, with no data-dependent branching.
shr130 :: Wider -> Wider
shr130 (Wider (# _, _, l2, l3 #)) =
  let !r0 = L.or# (L.shr# l2 2#) (L.shl# l3 62#)
      !r1 = L.shr# l3 2#
  in  Wider (# r0, r1, Limb 0##, Limb 0## #)
{-# INLINE shr130 #-}

-------------------------------------------------------------------------------

-- clamp r per RFC 8439 section 2.5: in little-endian byte terms, clear the
-- top four bits of bytes 3, 7, 11 and 15 and the bottom two bits of bytes
-- 4, 8 and 12. Because the mask's top hex digit is zero, the result is
-- always below 2^124.
clamp :: Wider -> Wider
clamp r = r `W.and` 0x0ffffffc0ffffffc0ffffffc0fffffff
{-# INLINE clamp #-}

-- | Produce a Poly1305 MAC for the provided message, given the provided
--   key.
--
--   Per RFC8439: the key, which is essentially a /one-time/ key, should
--   be unique, and MUST be unpredictable for each invocation.
--
--   The key must be exactly 256 bits in length; any other length yields
--   'Nothing'. Its first 16 bytes (clamped) form @r@ and its last 16
--   bytes form @s@.
--
--   >>> mac "i'll never use this key again!!!" "a message needing authentication"
--   Just "O'\231Z\224\149\148\246\203[}\210\203\b\200\207"
mac
  :: BS.ByteString       -- ^ 256-bit one-time key
  -> BS.ByteString       -- ^ arbitrary-length message
  -> Maybe BS.ByteString -- ^ 128-bit message authentication code
mac key@(BI.PS _ _ kl) msg
  | kl /= 32  = Nothing
  | otherwise =
      let (clamp . _roll16 -> r, _roll16 -> s) = BS.splitAt 16 key
      in  pure (_poly1305_loop r s msg)

-- p = 2^130 - 5
--
-- mask for the low 130 bits, i.e. 2^130 - 1
mask130 :: Wider
mask130 = 0x3ffffffffffffffffffffffffffffffff
{-# INLINE mask130 #-}

-- partial reduction to [0, 2 ^ 131)
--
-- Splits x as lo + 2^130 * hi and folds the high part back in using
-- 2^130 = 5 (mod p), so the result is congruent to x modulo p. For any
-- 256-bit x, lo < 2^130 and hi < 2^126 (so 5 * hi < 2^129), which keeps
-- the sum below 2^131.
reduce_partial :: Wider -> Wider
reduce_partial x =
  let !low  = x `W.and` mask130
      !high = shr130 x
  in  low + 5 * high
{-# INLINE reduce_partial #-}

-- [0, 2 ^ 131) -> [0, p)
--
-- A further partial reduction brings an input below 2^131 down to h' with
-- h' <= 2^130 + 4. At that point h' is either already in [0, p) or needs a
-- single subtraction of p. Since h' - p = (h' + 5) - 2^130, that candidate
-- is (h' + 5) masked to 130 bits, and h' >= p exactly when h' + 5 carries
-- into bit 130. The choice is made with 'W.select' rather than an if.
-- NOTE(review): assumes W.select a b c returns b when c is set; confirm
-- against Data.Word.Wider.
reduce_full :: Wider -> Wider
reduce_full h =
  let !lo      = h `W.and` mask130
      !hi      = shr130 h
      !h'      = lo + 5 * hi
      !h_5     = h' + 5
      !reduced = h_5 `W.and` mask130
      !carry   = shr130 h_5
      -- nonzero carry <=> h' >= p
      !gte     = W.lt 0 carry
  in  W.select h' reduced gte
{-# INLINE reduce_full #-}

-- Core Poly1305 evaluation (exported for testing).
--
-- r is the clamped first key half and s the second key half, as 'mac'
-- supplies them. The message is consumed in 16-byte chunks (the last one
-- possibly shorter). Each chunk is decoded little-endian, gets a 1 bit set
-- just above its last byte via 'set_hi', and is absorbed as
-- acc <- r * (acc + n) mod p. Once the message is exhausted, the
-- accumulator is fully reduced and s is added. The tag is the low 128 bits
-- of that sum ('unroll16' truncates).
_poly1305_loop :: Wider -> Wider -> BS.ByteString -> BS.ByteString
_poly1305_loop !r !s !msg = go 0 msg
  where
    go :: Wider -> BS.ByteString -> BS.ByteString
    go !acc !bs = case BS.splitAt 16 bs of
      (chunk@(BI.PS _ _ l), rest)
        | l == 0 ->
            let !final = reduce_full (reduce_partial acc)
            in  unroll16 (final + s)
        | otherwise ->
            -- with r clamped (< 2^124), acc < 2^130 + 5 and n < 2^129, the
            -- product stays below 2^255; two partial reductions bring it
            -- back below 2^130 + 5 for the next round
            let !n    = _roll16 chunk `W.or` set_hi l
                !prod = r * (acc + n)
                !nacc = reduce_partial (reduce_partial prod)
            in  go nacc rest
{-# INLINE _poly1305_loop #-}