{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE BangPatterns #-}

-- |
-- Module: Crypto.Cipher.ChaCha20.Arm
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- ARM NEON support for the ChaCha20 stream cipher.

module Crypto.Cipher.ChaCha20.Arm (
    chacha20_arm_available
  , block
  , cipher
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BI
import Data.Word (Word8, Word32)
import Foreign.C.Types (CInt(..), CSize(..))
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, plusPtr)
import System.IO.Unsafe (unsafeDupablePerformIO)

-- ffi ------------------------------------------------------------------------

-- Binding to the C symbol @chacha20_block_arm@. As called in this module,
-- the arguments are, in order: pointer to the key bytes, the block
-- counter, pointer to the nonce bytes, and a destination buffer (this
-- module passes a 64-byte buffer). Imported @unsafe@.
foreign import ccall unsafe "chacha20_block_arm"
  c_chacha20_block
    :: Ptr Word8 -> Word32 -> Ptr Word8 -> Ptr Word8 -> IO ()

-- Binding to the C symbol @chacha20_cipher_arm@. As called in this module,
-- the arguments are, in order: pointer to the key bytes, the initial
-- counter, pointer to the nonce bytes, pointer to the input bytes, a
-- destination buffer, and the input length in bytes (the destination is
-- allocated with that same length). Imported @unsafe@.
foreign import ccall unsafe "chacha20_cipher_arm"
  c_chacha20_cipher
    :: Ptr Word8 -> Word32 -> Ptr Word8
    -> Ptr Word8 -> Ptr Word8 -> CSize -> IO ()

-- Binding to the C symbol @chacha20_arm_available@. This module treats
-- any nonzero result as "available".
foreign import ccall unsafe "chacha20_arm_available"
  c_chacha20_arm_available :: IO CInt

-- utilities ------------------------------------------------------------------

-- Shorthand for 'fromIntegral'; used to convert a byte length to 'CSize'
-- at the FFI boundary.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}

-- api ------------------------------------------------------------------------

-- | Are ARM NEON extensions available?
--
--   'True' when the C-side probe returns a nonzero value. The probe is
--   run via 'unsafeDupablePerformIO', and the NOINLINE pragma keeps that
--   call from being inlined at use sites.
chacha20_arm_available :: Bool
chacha20_arm_available =
  unsafeDupablePerformIO c_chacha20_arm_available /= 0
{-# NOINLINE chacha20_arm_available #-}

-- | One 64-byte ChaCha20 keystream block for the given (already-
--   validated) key, counter, and nonce.
--
--   The key and nonce lengths are not inspected here: only the start of
--   each underlying buffer (base pointer plus offset) is handed to the
--   C routine, so callers must have validated the lengths beforehand.
block :: BS.ByteString -> Word32 -> BS.ByteString -> BS.ByteString
block (BI.PS kfp koff _) counter (BI.PS nfp noff _) =
  -- allocate the 64-byte result and let the C routine fill it in
  BI.unsafeCreate 64 $ \dst ->
    -- keep the key and nonce buffers alive for the duration of the call
    withForeignPtr kfp $ \kp0 ->
    withForeignPtr nfp $ \np0 ->
      c_chacha20_block (kp0 `plusPtr` koff)
                       counter
                       (np0 `plusPtr` noff)
                       dst

-- | XOR the plaintext with the ChaCha20 keystream derived from the
--   given (already-validated) key, counter, and nonce.
--
--   The result has the same length as the input. The key and nonce
--   lengths are not inspected here: only the start of each underlying
--   buffer (base pointer plus offset) is handed to the C routine, so
--   callers must have validated the lengths beforehand.
cipher
  :: BS.ByteString -> Word32 -> BS.ByteString -> BS.ByteString
  -> BS.ByteString
cipher (BI.PS kfp koff _) counter (BI.PS nfp noff _)
       (BI.PS pfp poff plen) =
  -- allocate an output buffer as long as the input; the C routine
  -- writes the result into it
  BI.unsafeCreate plen $ \dst ->
    -- keep all three input buffers alive for the duration of the call
    withForeignPtr kfp $ \kp0 ->
    withForeignPtr nfp $ \np0 ->
    withForeignPtr pfp $ \pp0 ->
      c_chacha20_cipher (kp0 `plusPtr` koff)
                        counter
                        (np0 `plusPtr` noff)
                        (pp0 `plusPtr` poff)
                        dst
                        (fi plen)