{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UnboxedTuples #-}

-- |
-- Module: Crypto.Cipher.ChaCha20
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- A pure ChaCha20 implementation, as specified by
-- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).

module Crypto.Cipher.ChaCha20 (
    -- * ChaCha20 stream cipher
    cipher

    -- * ChaCha20 block function
  , block

    -- testing
  , ChaCha(..)
  , _chacha
  , _parse_key
  , _parse_nonce
  , _quarter
  , _quarter_pure
  , _rounds
  ) where

import Control.Monad.ST
import qualified Data.Bits as B
import Data.Bits ((.|.), (.<<.), (.^.))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Foldable (for_)
import qualified Data.Primitive.PrimArray as PA
import Foreign.ForeignPtr
import GHC.Exts
import GHC.Word

-- utils ----------------------------------------------------------------------

-- | Short alias for 'fromIntegral', used for the many numeric
--   conversions between byte and word types in this module.
fi :: (Integral a, Num b) => a -> b
fi x = fromIntegral x
{-# INLINE fi #-}

-- Decode the first four bytes of a strict ByteString as a little-endian
-- Word32 (same byte order as Data.Binary's word32le). No bounds check is
-- performed: the input must contain at least four bytes.
unsafe_word32le :: BS.ByteString -> Word32
unsafe_word32le s =
  let byte i = fromIntegral (BU.unsafeIndex s i) :: Word32
  in       byte 0
      .|. (byte 1 `B.unsafeShiftL`  8)
      .|. (byte 2 `B.unsafeShiftL` 16)
      .|. (byte 3 `B.unsafeShiftL` 24)
{-# INLINE unsafe_word32le #-}

-- Strict pair of a decoded word and the bytes that remain after it.
data WSPair =
  WSPair {-# UNPACK #-} !Word32 {-# UNPACK #-} !BS.ByteString

-- Split the first four bytes off a strict ByteString, decoding them as a
-- little-endian Word32 and returning it together with the remaining
-- bytes (a splitAt-style incremental Word32 parser). No bounds check is
-- performed: the input must hold at least four bytes.
unsafe_parseWsPair :: BS.ByteString -> WSPair
unsafe_parseWsPair (BI.BS fp len) =
  let w    = unsafe_word32le (BI.BS fp 4)
      rest = BI.BS (plusForeignPtr fp 4) (len - 4)
  in  WSPair w rest
{-# INLINE unsafe_parseWsPair #-}

-- chacha quarter round -------------------------------------------------------

-- RFC8439 2.2
--
-- Quarter round applied in place to the four state words at indices
-- i0, i1, i2 and i3: the words are read, passed through 'quarter#', and
-- the results are written back to the same indices.
_quarter
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Int
  -> Int
  -> Int
  -> Int
  -> m ()
_quarter (ChaCha arr) i0 i1 i2 i3 = do
  !(W32# a) <- PA.readPrimArray arr i0
  !(W32# b) <- PA.readPrimArray arr i1
  !(W32# c) <- PA.readPrimArray arr i2
  !(W32# d) <- PA.readPrimArray arr i3
  case quarter# a b c d of
    (# a', b', c', d' #) -> do
      PA.writePrimArray arr i0 (W32# a')
      PA.writePrimArray arr i1 (W32# b')
      PA.writePrimArray arr i2 (W32# c')
      PA.writePrimArray arr i3 (W32# d')
{-# INLINEABLE _quarter #-}

-- Pure, boxed wrapper around 'quarter#': takes the four input words and
-- returns the four output words of one quarter round.
_quarter_pure
  :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32, Word32, Word32)
_quarter_pure (W32# a) (W32# b) (W32# c) (W32# d) =
  case quarter# a b c d of
    (# w, x, y, z #) -> (W32# w, W32# x, W32# y, W32# z)
{-# INLINE _quarter_pure #-}

-- RFC8439 2.1
--
-- The ChaCha quarter round on unboxed words:
--
--   a += b; d ^= a; d <<<= 16
--   c += d; b ^= c; b <<<= 12
--   a += b; d ^= a; d <<<= 8
--   c += d; b ^= c; b <<<= 7
--
-- Additions wrap modulo 2^32.
quarter#
  :: Word32# -> Word32# -> Word32# -> Word32#
  -> (# Word32#, Word32#, Word32#, Word32# #)
quarter# a b c d =
  let -- first half
      !aA = plusWord32# a b
      !dA = rotateL# (xorWord32# d aA) 16#
      !cA = plusWord32# c dA
      !bA = rotateL# (xorWord32# b cA) 12#
      -- second half
      !aB = plusWord32# aA bA
      !dB = rotateL# (xorWord32# dA aB) 8#
      !cB = plusWord32# cA dB
      !bB = rotateL# (xorWord32# bA cB) 7#
  in  (# aB, bB, cB, dB #)
{-# INLINE quarter# #-}

-- Rotate an unboxed Word32 left by i bits. A zero rotation returns the
-- input unchanged; otherwise the word is widened, shifted both ways, and
-- the OR of the halves is narrowed back to 32 bits.
rotateL# :: Word32# -> Int# -> Word32#
rotateL# w 0# = w
rotateL# w i  =
  let !x = word32ToWord# w
  in  wordToWord32#
        ((x `uncheckedShiftL#` i) `or#` (x `uncheckedShiftRL#` (32# -# i)))
{-# INLINE rotateL# #-}

-- key and nonce parsing ------------------------------------------------------

-- A ChaCha20 key held as eight 32-bit words (see '_parse_key').
data Key = Key {
    k0, k1, k2, k3, k4, k5, k6, k7 :: {-# UNPACK #-} !Word32
  }
  deriving (Eq, Show)

-- Parse a strict 256-bit (32-byte) ByteString into a 'Key', decoding
-- eight consecutive little-endian Word32 words.
--
-- Input shorter than 32 bytes is rejected up front: the word parser does
-- no bounds checking and would otherwise read past the end of the
-- buffer. Input longer than 32 bytes is rejected once all eight words
-- have been consumed.
_parse_key :: BS.ByteString -> Key
_parse_key bs
  | BS.length bs < 32 =
      error "ppad-chacha (_parse_key): insufficient bytes"
  | otherwise =
      let !(WSPair k0 t0) = unsafe_parseWsPair bs
          !(WSPair k1 t1) = unsafe_parseWsPair t0
          !(WSPair k2 t2) = unsafe_parseWsPair t1
          !(WSPair k3 t3) = unsafe_parseWsPair t2
          !(WSPair k4 t4) = unsafe_parseWsPair t3
          !(WSPair k5 t5) = unsafe_parseWsPair t4
          !(WSPair k6 t6) = unsafe_parseWsPair t5
          !(WSPair k7 t7) = unsafe_parseWsPair t6
      in  if   BS.null t7
          then Key {..}
          else error "ppad-chacha (_parse_key): bytes remaining"

-- A ChaCha20 nonce held as three 32-bit words (see '_parse_nonce').
data Nonce = Nonce {
    n0, n1, n2 :: {-# UNPACK #-} !Word32
  }
  deriving (Eq, Show)

-- Parse a strict 96-bit (12-byte) ByteString into a 'Nonce', decoding
-- three consecutive little-endian Word32 words.
--
-- Input shorter than 12 bytes is rejected up front: the word parser does
-- no bounds checking and would otherwise read past the end of the
-- buffer. Input longer than 12 bytes is rejected once all three words
-- have been consumed.
_parse_nonce :: BS.ByteString -> Nonce
_parse_nonce bs
  | BS.length bs < 12 =
      error "ppad-chacha (_parse_nonce): insufficient bytes"
  | otherwise =
      let !(WSPair n0 t0) = unsafe_parseWsPair bs
          !(WSPair n1 t1) = unsafe_parseWsPair t0
          !(WSPair n2 t2) = unsafe_parseWsPair t1
      in  if   BS.null t2
          then Nonce {..}
          else error "ppad-chacha (_parse_nonce): bytes remaining"

-- chacha20 block function ----------------------------------------------------

-- Mutable ChaCha state: a primitive array of Word32 words (allocated
-- with 16 entries by '_chacha_alloc'). Equality is that of the
-- underlying mutable array.
newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32)
  deriving (Eq)

-- Allocate a fresh ChaCha state and initialise it from the given key,
-- block counter and nonce.
_chacha
  :: PrimMonad m
  => Key
  -> Word32
  -> Nonce
  -> m (ChaCha (PrimState m))
_chacha key counter nonce =
  _chacha_alloc >>= \st -> st <$ _chacha_set st key counter nonce

-- Allocate an uninitialised 16-word ChaCha state.
_chacha_alloc :: PrimMonad m => m (ChaCha (PrimState m))
_chacha_alloc = ChaCha <$> PA.newPrimArray 16
{-# INLINE _chacha_alloc #-}

-- Write the full initial state into an existing ChaCha state:
--
--   words  0-3  : RFC 8439 constants ("expand 32-byte k")
--   words  4-11 : key words k0..k7
--   word   12   : block counter
--   words 13-15 : nonce words n0..n2
_chacha_set
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Key
  -> Word32
  -> Nonce
  -> m ()
_chacha_set (ChaCha arr) Key {..} counter Nonce {..} = do
  let put = PA.writePrimArray arr
  -- constants
  put 0 0x61707865
  put 1 0x3320646e
  put 2 0x79622d32
  put 3 0x6b206574
  -- key
  put 4 k0
  put 5 k1
  put 6 k2
  put 7 k3
  put 8 k4
  put 9 k5
  put 10 k6
  put 11 k7
  -- counter
  put 12 counter
  -- nonce
  put 13 n0
  put 14 n1
  put 15 n2
{-# INLINEABLE _chacha_set #-}

-- Overwrite the block-counter word (index 12) of a ChaCha state.
_chacha_counter
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Word32
  -> m ()
_chacha_counter (ChaCha arr) = PA.writePrimArray arr 12

-- Two full rounds (eight quarter rounds) applied in place: a column
-- round over the four columns of the 4x4 state, followed by a diagonal
-- round over its four diagonals.
_rounds :: PrimMonad m => ChaCha (PrimState m) -> m ()
_rounds state = do
  -- column round
  q 0 4  8 12
  q 1 5  9 13
  q 2 6 10 14
  q 3 7 11 15
  -- diagonal round
  q 0 5 10 15
  q 1 6 11 12
  q 2 7  8 13
  q 3 4  9 14
  where
    q = _quarter state
{-# INLINEABLE _rounds #-}

-- Produce one 64-byte keystream block from an initialised state. The
-- steps are:
--
--   1. store the counter into word 12;
--   2. snapshot the state;
--   3. run ten double rounds (20 rounds);
--   4. add the snapshot back in word-by-word;
--   5. serialise the result.
--
-- The state is left holding the final (post-addition) words.
_block
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Word32
  -> m BS.ByteString
_block st@(ChaCha arr) ctr = do
  _chacha_counter st ctr
  initial <- PA.freezePrimArray arr 0 16
  let go !n
        | n <= 0    = pure ()
        | otherwise = _rounds st *> go (n - 1)
  go (10 :: Int)
  for_ [0 .. 15 :: Int] $ \ix -> do
    cur <- PA.readPrimArray arr ix
    PA.writePrimArray arr ix (PA.indexPrimArray initial ix + cur)
  serialize st

-- RFC8439 2.3

-- | The ChaCha20 block function. Useful for generating a keystream.
--
--   Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
--   key must be exactly 256 bits, and the nonce exactly 96 bits; any
--   other length results in a call to 'error'.
block
  :: BS.ByteString    -- ^ 256-bit key
  -> Word32           -- ^ 32-bit counter
  -> BS.ByteString    -- ^ 96-bit nonce
  -> BS.ByteString    -- ^ 512-bit keystream
block key counter nonce
  | BS.length key /= 32   = error "ppad-chacha (block): invalid key"
  | BS.length nonce /= 12 = error "ppad-chacha (block): invalid nonce"
  | otherwise = runST $ do
      -- Build the initial state, then delegate to the shared block
      -- routine so the rounds / feed-forward / serialisation logic is
      -- defined in exactly one place.
      state <- _chacha (_parse_key key) counter (_parse_nonce nonce)
      _block state counter

-- Render the 16-word ChaCha state as a 64-byte strict ByteString, each
-- word written as four little-endian bytes, in index order 0..15.
serialize :: PrimMonad m => ChaCha (PrimState m) -> m BS.ByteString
serialize (ChaCha arr) = do
    -- read every state word, lowest index first
    ws <- traverse (PA.readPrimArray arr) [0 .. 15 :: Int]
    -- concatenate the little-endian encodings of all sixteen words
    pure (BS.toStrict (BSB.toLazyByteString (foldMap BSB.word32LE ws)))

-- chacha20 encryption --------------------------------------------------------

-- RFC8439 2.4

-- | The ChaCha20 stream cipher. Generates a keystream and then XORs
--   the supplied input with it; use it both to encrypt plaintext and
--   decrypt ciphertext.
--
--   Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
--   key must be exactly 256 bits, and the nonce exactly 96 bits.
--
--   One keystream block is consumed per 64 bytes of input (the last
--   block may be partial), starting at the supplied counter. The 32-bit
--   counter is never allowed to wrap around: an input needing more than
--   @2^32 - counter@ blocks is rejected with an error instead of
--   silently reusing keystream.
--
--   >>> let key = "don't tell anyone my secret key!"
--   >>> let non = "or my nonce!"
--   >>> let cip = cipher key 1 non "but you can share the plaintext"
--   >>> cip
--   "\192*c\248A\204\211n\130y8\197\146k\245\178Y\197=\180_\223\138\146:^\206\&0\v[\201"
--   >>> cipher key 1 non cip
--   "but you can share the plaintext"
cipher
  :: BS.ByteString    -- ^ 256-bit key
  -> Word32           -- ^ 32-bit counter
  -> BS.ByteString    -- ^ 96-bit nonce
  -> BS.ByteString    -- ^ arbitrary-length plaintext
  -> BS.ByteString    -- ^ ciphertext
cipher raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext
  | kl /= 32  = error "ppad-chacha (cipher): invalid key"
  | nl /= 12  = error "ppad-chacha (cipher): invalid nonce"
  | blocks > available = error "ppad-chacha (cipher): counter overflow"
  | otherwise = runST $ do
      let key = _parse_key raw_key
          non = _parse_nonce raw_nonce
      _cipher key counter non plaintext
  where
    -- number of 64-byte keystream blocks the input needs (ceiling);
    -- computed in Word64 so it cannot overflow even on 32-bit platforms
    blocks :: Word64
    blocks = (fi (BS.length plaintext) + 63) `quot` 64

    -- counter values left before the 32-bit block counter would wrap
    -- (always >= 1, since counter <= 2^32 - 1)
    available :: Word64
    available = 2 ^ (32 :: Int) - fi counter

-- Encrypt/decrypt 'plaintext' by XORing it with the ChaCha20 keystream
-- for the given key and nonce, using block counter values 'counter',
-- 'counter' + 1, ... for successive 64-byte chunks.
_cipher
  :: PrimMonad m
  => Key
  -> Word32
  -> Nonce
  -> BS.ByteString
  -> m BS.ByteString
_cipher key counter nonce plaintext = do
  -- pristine initial state, copied into the scratch state before each block
  ChaCha initial <- _chacha key counter nonce
  -- scratch state that _block works on
  state@(ChaCha s) <- _chacha_alloc

  let go acc !ctr input
        | BS.null input =
            -- NOTE(review): the original carried an "XX" marker here,
            -- presumably flagging this Builder round-trip for later work
            pure (BS.toStrict (BSB.toLazyByteString acc))
        | otherwise = do
            let (chunk, rest) = BS.splitAt 64 input
            PA.copyMutablePrimArray s 0 initial 0 16
            keystream <- _block state ctr
            -- packZipWith stops at the shorter input, so a trailing
            -- partial chunk consumes only a prefix of the keystream
            let out = BS.packZipWith (.^.) chunk keystream
            go (acc <> BSB.byteString out) (ctr + 1) rest

  go mempty counter plaintext
{-# INLINE _cipher #-}