{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE BangPatterns #-}

-- |
-- Module: Crypto.Hash.SHA512.Arm
-- Copyright: (c) 2024 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- ARM crypto extension support for SHA-512.

module Crypto.Hash.SHA512.Arm (
    sha512_arm_available
  , hash_arm
  , hash_arm_with
  ) where

import Control.Monad (unless, when)
import qualified Data.Bits as B
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import Data.Word (Word8, Word64)
import Foreign.C.Types (CInt)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (poke, peek)
import System.IO.Unsafe (unsafePerformIO)

import Crypto.Hash.SHA512.Internal (unsafe_padding)

-- ffi -----------------------------------------------------------------------

-- | Compress one block into the eight-word state at the first pointer,
--   in place, via the C routine @sha512_block_arm@.
--
--   Every caller in this module passes a block pointer and then advances
--   by 128 bytes per block. NOTE(review): presumably the C side expects 8
--   contiguous 'Word64's of state and 128 readable bytes of input — confirm
--   against cbits.
foreign import ccall unsafe "sha512_block_arm"
  c_sha512_block :: Ptr Word64 -> Ptr Word8 -> IO ()

-- | Ask the C side whether the ARM SHA-512 routine may be used; a nonzero
--   result means yes.
--
--   NOTE(review): assumes the C prototype is @int sha512_arm_available(void)@
--   — confirm in cbits. A C @int@ must be marshalled as 'CInt' rather than
--   'Int'. With 'Int' the caller reads the full 64-bit return register,
--   whose upper half the ABI leaves unspecified for a 32-bit result.
foreign import ccall unsafe "sha512_arm_available"
  c_sha512_arm_available :: IO CInt

-- utilities -----------------------------------------------------------------

-- | Short alias for 'fromIntegral', used for the numeric conversions in
--   this module (e.g. 'Int' lengths to 'Word64', 'Word64' to 'Word8').
fi :: (Integral a, Num b) => a -> b
fi n = fromIntegral n
{-# INLINE fi #-}

-- api -----------------------------------------------------------------------

-- | 'True' when the C probe @sha512_arm_available@ returns a nonzero value,
--   'False' when it returns zero.
--
--   The probe is run through 'unsafePerformIO'. The @NOINLINE@ pragma keeps
--   this a single top-level CAF, so the probe is not inlined and re-run at
--   each use site.
sha512_arm_available :: Bool
sha512_arm_available = unsafePerformIO (fmap (/= 0) c_sha512_arm_available)
{-# NOINLINE sha512_arm_available #-}

-- | SHA-512 digest of a message using the ARM block routine, with no prefix
--   block and no extra padding length.
--
--   NOTE(review): presumably only valid when 'sha512_arm_available' is
--   'True', since it calls straight into the ARM C routine.
hash_arm :: BS.ByteString -> BS.ByteString
hash_arm msg = hash_arm_with BS.empty 0 msg

-- | Hash a message with the ARM SHA-512 block routine. A single 128-byte
--   prefix block may optionally be compressed before the message.
--
--   The padding is computed over a total length of @el + length m@ bytes.
--   A caller supplying a prefix therefore presumably passes @el = 128@ to
--   account for it (TODO confirm at call sites).
--
--   Calls 'error' if the prefix is non-empty but not exactly 128 bytes long.
--   The C routine always consumes a full 128-byte block, so a shorter prefix
--   would be an out-of-bounds read and a longer one would be silently
--   truncated.
hash_arm_with
  :: BS.ByteString  -- ^ optional 128-byte prefix (or empty)
  -> Word64         -- ^ extra length to add for padding
  -> BS.ByteString  -- ^ message
  -> BS.ByteString
hash_arm_with prefix el m@(BI.PS fp off l)
  | not (BS.null prefix) && BS.length prefix /= 128 =
      error "hash_arm_with: prefix must be empty or exactly 128 bytes"
  | otherwise = unsafePerformIO $
      allocaBytes 64 $ \state -> do
        poke_iv state

        -- process prefix block if provided (length checked above)
        unless (BS.null prefix) $ do
          let BI.PS pfp poff _ = prefix
          BI.unsafeWithForeignPtr pfp $ \src ->
            c_sha512_block state (src `plusPtr` poff)

        -- compress every full 128-byte block of the message; the buffer
        -- is pinned once for the whole (terminating) loop, not per block
        BI.unsafeWithForeignPtr fp $ \base ->
          go state (base `plusPtr` off) 0

        -- the trailing (l `rem` 128) bytes go through unsafe_padding; one
        -- padded block is compressed, plus a second when rlen >= 112
        -- (presumably because the 0x80 marker and 16-byte length field of
        -- FIPS 180-4 padding no longer fit in the first block)
        let !remaining@(BI.PS _ _ rlen) = BU.unsafeDrop (l - l `rem` 128) m
            BI.PS padfp padoff _ = unsafe_padding remaining (el + fi l)
        BI.unsafeWithForeignPtr padfp $ \src -> do
          c_sha512_block state (src `plusPtr` padoff)
          when (rlen >= 112) $
            c_sha512_block state (src `plusPtr` (padoff + 128))

        read_state state
  where
    -- compress the block at src+j, then src+j+128, ... while a full
    -- 128-byte block still lies within the message length l
    go :: Ptr Word64 -> Ptr Word8 -> Int -> IO ()
    go !state !src !j
      | j + 128 <= l = do
          c_sha512_block state (src `plusPtr` j)
          go state src (j + 128)
      | otherwise = pure ()

-- arm helpers ---------------------------------------------------------------

-- | Write the SHA-512 initial hash value (FIPS 180-4, section 5.3.5) into
--   the state buffer. Each word is stored in host byte order ('poke') at
--   byte offsets 0, 8, ..., 56, so the buffer needs at least 64 bytes.
poke_iv :: Ptr Word64 -> IO ()
poke_iv !st = mapM_ store (zip [0, 8 ..] iv)
  where
    store :: (Int, Word64) -> IO ()
    store (byte_off, word) = poke (st `plusPtr` byte_off) word

    iv :: [Word64]
    iv =
      [ 0x6a09e667f3bcc908
      , 0xbb67ae8584caa73b
      , 0x3c6ef372fe94f82b
      , 0xa54ff53a5f1d36f1
      , 0x510e527fade682d1
      , 0x9b05688c2b3e6c1f
      , 0x1f83d9abfb41bd6b
      , 0x5be0cd19137e2179
      ]

-- | Build the 64-byte digest from the state buffer.
--
--   All eight host-order words are loaded first. Each word is then written
--   big-endian into its own 8-byte slot of a freshly allocated ByteString.
read_state :: Ptr Word64 -> IO BS.ByteString
read_state !st = BI.create 64 $ \dst -> do
  -- load the eight state words (offsets 0, 8, ..., 56) ...
  hs <- mapM (\i -> peek (st `plusPtr` (8 * i))) [0 .. 7 :: Int] :: IO [Word64]
  -- ... then store word i big-endian at output offset 8*i
  mapM_ (\(i, h) -> poke_word64be dst (8 * i) h) (zip [0 ..] hs)

-- | Store a 'Word64' big-endian at @dst + base@. Byte k (k = 0 is the most
--   significant) goes to @dst + base + k@, so 8 writable bytes are needed.
poke_word64be :: Ptr Word8 -> Int -> Word64 -> IO ()
poke_word64be !dst !base !w = mapM_ put_byte [0 .. 7]
  where
    -- shift by 56, 48, ..., 0 and keep the low byte
    put_byte :: Int -> IO ()
    put_byte k =
      poke (dst `plusPtr` (base + k))
           (fi (w `B.unsafeShiftR` (56 - 8 * k)) :: Word8)