{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE BangPatterns #-}

-- |
-- Module: Crypto.Hash.SHA256.Arm
-- Copyright: (c) 2024 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- ARM crypto extension support for SHA-256.

module Crypto.Hash.SHA256.Arm (
    sha256_arm_available
  , hash_arm
  , hash_arm_with
  ) where

import Control.Monad (unless, when)
import qualified Data.Bits as B
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import Data.Word (Word8, Word32, Word64)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (poke, peek)
import Crypto.Hash.SHA256.Internal (unsafe_padding)
import System.IO.Unsafe (unsafePerformIO)

-- ffi -----------------------------------------------------------------------

foreign import ccall unsafe "sha256_block_arm"
  c_sha256_block :: Ptr Word32 -> Ptr Word8 -> IO ()

foreign import ccall unsafe "sha256_arm_available"
  c_sha256_arm_available :: IO Int

-- utilities -----------------------------------------------------------------

-- | Shorthand for 'fromIntegral', used for the numeric conversions in
--   this module.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}

-- api -----------------------------------------------------------------------

-- | 'True' when the C-side @sha256_arm_available@ probe returns a nonzero
--   value.
--
--   The probe is an 'IO' action run via 'unsafePerformIO'. The NOINLINE
--   pragma keeps this as a single top-level value so the call is not
--   duplicated at use sites.
--
--   NOTE(review): this assumes the probe's answer cannot change during
--   the life of the process -- confirm against the C implementation.
sha256_arm_available :: Bool
sha256_arm_available = unsafePerformIO c_sha256_arm_available /= 0
{-# NOINLINE sha256_arm_available #-}

-- | Hash a message via 'hash_arm_with', supplying an empty prefix and
--   zero extra length.
hash_arm :: BS.ByteString -> BS.ByteString
hash_arm = hash_arm_with mempty 0

-- | Hash with optional 64-byte prefix and extra length for padding.
--
--   The 32-byte working state lives in a temporary buffer that is seeded
--   by 'poke_iv', updated in place by the foreign block routine, and
--   finally serialised by 'read_state'.
--
--   NOTE(review): the prefix is only tested for emptiness; its length is
--   never checked here. A non-empty prefix is handed straight to the
--   block routine, so callers must honour the "64 bytes or empty"
--   contract themselves.
hash_arm_with
  :: BS.ByteString  -- ^ optional 64-byte prefix (or empty)
  -> Word64         -- ^ extra length to add for padding
  -> BS.ByteString  -- ^ message
  -> BS.ByteString
hash_arm_with prefix el m@(BI.PS fp off l) = unsafePerformIO $
    allocaBytes 32 $ \state -> do
      poke_iv state
      -- process prefix block if provided
      unless (BS.null prefix) $ do
        let BI.PS pfp poff _ = prefix
        BI.unsafeWithForeignPtr pfp $ \src ->
          c_sha256_block state (src `plusPtr` poff)

      -- feed every complete 64-byte block of the message
      go state 0

      -- 'remaining' is the tail left over after the complete blocks, i.e.
      -- the last (l `rem` 64) bytes of the message. The length passed to
      -- 'unsafe_padding' is the message length plus the caller-supplied
      -- extra length.
      let !remaining@(BI.PS _ _ rlen) = BU.unsafeDrop (l - l `rem` 64) m
          BI.PS padfp padoff _ = unsafe_padding remaining (el + fi l)
      BI.unsafeWithForeignPtr padfp $ \src -> do
        c_sha256_block state (src `plusPtr` padoff)
        -- a tail of 56 or more bytes means a second block of the padded
        -- buffer is processed as well
        when (rlen >= 56) $
          c_sha256_block state (src `plusPtr` (padoff + 64))

      read_state state
  where
    -- Apply the block routine at message offsets 0, 64, 128, ... for as
    -- long as a full 64 bytes remain at offset j.
    go :: Ptr Word32 -> Int -> IO ()
    go !state !j
      | j + 64 <= l = do
          BI.unsafeWithForeignPtr fp $ \src ->
            c_sha256_block state (src `plusPtr` (off + j))
          go state (j + 64)
      | otherwise = pure ()

-- arm helpers ---------------------------------------------------------------

-- | Write the eight SHA-256 initial hash words into the state buffer, as
--   consecutive 'Word32' values at byte offsets 0, 4, .., 28.
--
--   The buffer must have room for 32 bytes. Words are stored with 'poke',
--   i.e. in the host's native byte order.
poke_iv :: Ptr Word32 -> IO ()
poke_iv !state = do
  poke state                (0x6a09e667 :: Word32)
  poke (state `plusPtr` 4)  (0xbb67ae85 :: Word32)
  poke (state `plusPtr` 8)  (0x3c6ef372 :: Word32)
  poke (state `plusPtr` 12) (0xa54ff53a :: Word32)
  poke (state `plusPtr` 16) (0x510e527f :: Word32)
  poke (state `plusPtr` 20) (0x9b05688c :: Word32)
  poke (state `plusPtr` 24) (0x1f83d9ab :: Word32)
  poke (state `plusPtr` 28) (0x5be0cd19 :: Word32)

-- | Serialise the eight 'Word32' state words into a fresh 32-byte
--   'BS.ByteString'.
--
--   Each word is read from the state buffer with 'peek' (byte offsets
--   0, 4, .., 28) and written to the output in big-endian byte order at
--   the same offset.
read_state :: Ptr Word32 -> IO BS.ByteString
read_state !state = BI.create 32 $ \out -> do
  h0 <- peek state                :: IO Word32
  h1 <- peek (state `plusPtr` 4)  :: IO Word32
  h2 <- peek (state `plusPtr` 8)  :: IO Word32
  h3 <- peek (state `plusPtr` 12) :: IO Word32
  h4 <- peek (state `plusPtr` 16) :: IO Word32
  h5 <- peek (state `plusPtr` 20) :: IO Word32
  h6 <- peek (state `plusPtr` 24) :: IO Word32
  h7 <- peek (state `plusPtr` 28) :: IO Word32
  poke_word32be out 0 h0
  poke_word32be out 4 h1
  poke_word32be out 8 h2
  poke_word32be out 12 h3
  poke_word32be out 16 h4
  poke_word32be out 20 h5
  poke_word32be out 24 h6
  poke_word32be out 28 h7

-- | Write a 'Word32' as four bytes in big-endian order, starting at byte
--   offset @off@ from pointer @p@.
--
--   The most significant byte goes to @off@ and the least significant to
--   @off + 3@. Each shifted value is narrowed to 'Word8' with 'fi', which
--   keeps only its low eight bits.
poke_word32be :: Ptr Word8 -> Int -> Word32 -> IO ()
poke_word32be !p !off !w = do
  poke (p `plusPtr` off)       (fi (w `B.unsafeShiftR` 24) :: Word8)
  poke (p `plusPtr` (off + 1)) (fi (w `B.unsafeShiftR` 16) :: Word8)
  poke (p `plusPtr` (off + 2)) (fi (w `B.unsafeShiftR` 8) :: Word8)
  poke (p `plusPtr` (off + 3)) (fi w :: Word8)