{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UnliftedNewtypes #-}

-- |
-- Module: Data.Word.Wide
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Wide words, consisting of two 'Limb's.

module Data.Word.Wide (
  -- * Wide Words
    Wide(..)

  -- * Construction, Conversion
  , wide
  , to
  , from

  -- * Bit Manipulation
  , or
  , or#
  , and
  , and#
  , xor
  , xor#
  , not
  , not#

  -- * Comparison
  , eq_vartime

  -- * Arithmetic
  , add
  , add_o
  , sub
  , mul
  , neg

  -- * Unboxed Arithmetic
  , add_o#
  , add_w#
  , sub_b#
  , sub_w#
  , mul_w#
  , neg#
  ) where

import Control.DeepSeq
import Data.Bits ((.|.), (.&.), (.<<.), (.>>.))
import qualified Data.Bits as B
import Data.Word.Limb (Limb(..))
import qualified Data.Word.Limb as L
import GHC.Exts
import Prelude hiding (div, mod, or, and, not, quot, rem, recip)

-- utilities ------------------------------------------------------------------

fi :: (Integral a, Num b) => a -> b
fi :: forall a b. (Integral a, Num b) => a -> b
fi = a -> b
forall a b. (Integral a, Num b) => a -> b
fromIntegral
{-# INLINE fi #-}

-- wide words -----------------------------------------------------------------

-- | Little-endian wide words.
data Wide = Wide !(# Limb, Limb #)

instance Show Wide where
  show :: Wide -> String
show = Integer -> String
forall a. Show a => a -> String
show (Integer -> String) -> (Wide -> Integer) -> Wide -> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Wide -> Integer
from

instance Num Wide where
  + :: Wide -> Wide -> Wide
(+) = Wide -> Wide -> Wide
add
  (-) = Wide -> Wide -> Wide
sub
  * :: Wide -> Wide -> Wide
(*) = Wide -> Wide -> Wide
mul
  abs :: Wide -> Wide
abs = Wide -> Wide
forall a. a -> a
id
  fromInteger :: Integer -> Wide
fromInteger = Integer -> Wide
to
  negate :: Wide -> Wide
negate = Wide -> Wide
neg
  signum :: Wide -> Wide
signum Wide
a = case Wide
a of
    Wide (# Limb Word#
0##, Limb Word#
0## #) -> Wide
0
    Wide
_ -> Wide
1

instance NFData Wide where
  rnf :: Wide -> ()
rnf (Wide (# Limb, Limb #)
a) = case (# Limb, Limb #)
a of (# Limb
_, Limb
_ #) -> ()

-- construction / conversion --------------------------------------------------

-- | Construct a 'Wide' word from low and high 'Word's.
wide :: Word -> Word -> Wide
wide :: Word -> Word -> Wide
wide (W# Word#
l) (W# Word#
h) = (# Limb, Limb #) -> Wide
Wide (# Word# -> Limb
Limb Word#
l, Word# -> Limb
Limb Word#
h #)

-- | Convert an 'Integer' to a 'Wide' word.
to :: Integer -> Wide
to :: Integer -> Wide
to Integer
n =
  let !size :: Int
size = Word -> Int
forall b. FiniteBits b => b -> Int
B.finiteBitSize (Word
0 :: Word)
      !mask :: Integer
mask = Word -> Integer
forall a b. (Integral a, Num b) => a -> b
fi (Word
forall a. Bounded a => a
maxBound :: Word) :: Integer
      !(W# Word#
w0) = Integer -> Word
forall a b. (Integral a, Num b) => a -> b
fi (Integer
n Integer -> Integer -> Integer
forall a. Bits a => a -> a -> a
.&. Integer
mask)
      !(W# Word#
w1) = Integer -> Word
forall a b. (Integral a, Num b) => a -> b
fi ((Integer
n Integer -> Int -> Integer
forall a. Bits a => a -> Int -> a
.>>. Int
size) Integer -> Integer -> Integer
forall a. Bits a => a -> a -> a
.&. Integer
mask)
  in  (# Limb, Limb #) -> Wide
Wide (# Word# -> Limb
Limb Word#
w0, Word# -> Limb
Limb Word#
w1 #)

-- | Convert a 'Wide' word to an 'Integer'.
from :: Wide -> Integer
from :: Wide -> Integer
from (Wide (# Limb Word#
a, Limb Word#
b #)) =
      Word -> Integer
forall a b. (Integral a, Num b) => a -> b
fi (Word# -> Word
W# Word#
b) Integer -> Int -> Integer
forall a. Bits a => a -> Int -> a
.<<. (Word -> Int
forall b. FiniteBits b => b -> Int
B.finiteBitSize (Word
0 :: Word))
  Integer -> Integer -> Integer
forall a. Bits a => a -> a -> a
.|. Word -> Integer
forall a b. (Integral a, Num b) => a -> b
fi (Word# -> Word
W# Word#
a)

-- comparison -----------------------------------------------------------------

-- | Compare 'Wide' words for equality in variable time.
eq_vartime :: Wide -> Wide -> Bool
eq_vartime :: Wide -> Wide -> Bool
eq_vartime (Wide (# Limb Word#
a0, Limb Word#
b0 #)) (Wide (# Limb Word#
a1, Limb Word#
b1 #)) =
  Int# -> Bool
isTrue# (Int# -> Int# -> Int#
andI# (Word# -> Word# -> Int#
eqWord# Word#
a0 Word#
a1) (Word# -> Word# -> Int#
eqWord# Word#
b0 Word#
b1))

-- bits -----------------------------------------------------------------------

or_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
or_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
or_w# (# Limb
a0, Limb
a1 #) (# Limb
b0, Limb
b1 #) = (# Limb -> Limb -> Limb
L.or# Limb
a0 Limb
b0, Limb -> Limb -> Limb
L.or# Limb
a1 Limb
b1 #)
{-# INLINE or_w# #-}

or :: Wide -> Wide -> Wide
or :: Wide -> Wide -> Wide
or (Wide (# Limb, Limb #)
a) (Wide (# Limb, Limb #)
b) = (# Limb, Limb #) -> Wide
Wide ((# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
or_w# (# Limb, Limb #)
a (# Limb, Limb #)
b)

and_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
and_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
and_w# (# Limb
a0, Limb
a1 #) (# Limb
b0, Limb
b1 #) = (# Limb -> Limb -> Limb
L.and# Limb
a0 Limb
b0, Limb -> Limb -> Limb
L.and# Limb
a1 Limb
b1 #)
{-# INLINE and_w# #-}

and :: Wide -> Wide -> Wide
and :: Wide -> Wide -> Wide
and (Wide (# Limb, Limb #)
a) (Wide (# Limb, Limb #)
b) = (# Limb, Limb #) -> Wide
Wide ((# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
and_w# (# Limb, Limb #)
a (# Limb, Limb #)
b)

xor_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
xor_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
xor_w# (# Limb
a0, Limb
a1 #) (# Limb
b0, Limb
b1 #) = (# Limb -> Limb -> Limb
L.xor# Limb
a0 Limb
b0, Limb -> Limb -> Limb
L.xor# Limb
a1 Limb
b1 #)
{-# INLINE xor_w# #-}

xor :: Wide -> Wide -> Wide
xor :: Wide -> Wide -> Wide
xor (Wide (# Limb, Limb #)
a) (Wide (# Limb, Limb #)
b) = (# Limb, Limb #) -> Wide
Wide ((# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
xor_w# (# Limb, Limb #)
a (# Limb, Limb #)
b)

not_w# :: (# Limb, Limb #) -> (# Limb, Limb #)
not_w# :: (# Limb, Limb #) -> (# Limb, Limb #)
not_w# (# Limb
a0, Limb
a1 #) = (# Limb -> Limb
L.not# Limb
a0, Limb -> Limb
L.not# Limb
a1 #)
{-# INLINE not_w# #-}

not :: Wide -> Wide
not :: Wide -> Wide
not (Wide (# Limb, Limb #)
w) = (# Limb, Limb #) -> Wide
Wide ((# Limb, Limb #) -> (# Limb, Limb #)
not_w# (# Limb, Limb #)
w)
{-# INLINE not #-}

-- negation -------------------------------------------------------------------

neg#
  :: (# Limb, Limb #) -- ^ argument
  -> (# Limb, Limb #) -- ^ (wrapping) additive inverse
neg# :: (# Limb, Limb #) -> (# Limb, Limb #)
neg# (# Limb, Limb #)
w = (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
add_w# ((# Limb, Limb #) -> (# Limb, Limb #)
not_w# (# Limb, Limb #)
w) (# Word# -> Limb
Limb Word#
1##, Word# -> Limb
Limb Word#
0## #)
{-# INLINE neg# #-}

neg
  :: Wide -- ^ argument
  -> Wide -- ^ (wrapping) additive inverse
neg :: Wide -> Wide
neg (Wide (# Limb, Limb #)
w) = (# Limb, Limb #) -> Wide
Wide ((# Limb, Limb #) -> (# Limb, Limb #)
neg# (# Limb, Limb #)
w)

-- addition, subtraction ------------------------------------------------------

-- | Overflowing addition, computing 'a + b', returning the sum and a
--   carry bit.
add_o#
  :: (# Limb, Limb #)              -- ^ augend
  -> (# Limb, Limb #)              -- ^ addend
  -> (# (# Limb, Limb #), Limb #)  -- ^ (# sum, carry bit #)
add_o# :: (# Limb, Limb #)
-> (# Limb, Limb #) -> (# (# Limb, Limb #), Limb #)
add_o# (# Limb
a0, Limb
a1 #) (# Limb
b0, Limb
b1 #) =
  let !(# Limb
s0, Limb
c0 #) = Limb -> Limb -> (# Limb, Limb #)
L.add_o# Limb
a0 Limb
b0
      !(# Limb
s1, Limb
c1 #) = Limb -> Limb -> Limb -> (# Limb, Limb #)
L.add_c# Limb
a1 Limb
b1 Limb
c0
  in  (# (# Limb
s0, Limb
s1 #), Limb
c1 #)
{-# INLINE add_o# #-}

-- | Overflowing addition on 'Wide' words, computing 'a + b', returning
--   the sum and carry.
add_o
  :: Wide         -- ^ augend
  -> Wide         -- ^ addend
  -> (Wide, Word) -- ^ (sum, carry)
add_o :: Wide -> Wide -> (Wide, Word)
add_o (Wide (# Limb, Limb #)
a) (Wide (# Limb, Limb #)
b) =
  let !(# (# Limb, Limb #)
s, Limb Word#
c #) = (# Limb, Limb #)
-> (# Limb, Limb #) -> (# (# Limb, Limb #), Limb #)
add_o# (# Limb, Limb #)
a (# Limb, Limb #)
b
  in  ((# Limb, Limb #) -> Wide
Wide (# Limb, Limb #)
s, Word# -> Word
W# Word#
c)

-- | Wrapping addition, computing 'a + b'.
add_w#
  :: (# Limb, Limb #) -- ^ augend
  -> (# Limb, Limb #) -- ^ addend
  -> (# Limb, Limb #) -- ^ sum
add_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
add_w# (# Limb, Limb #)
a (# Limb, Limb #)
b =
  let !(# (# Limb, Limb #)
c, Limb
_ #) = (# Limb, Limb #)
-> (# Limb, Limb #) -> (# (# Limb, Limb #), Limb #)
add_o# (# Limb, Limb #)
a (# Limb, Limb #)
b
  in  (# Limb, Limb #)
c
{-# INLINE add_w# #-}

-- | Wrapping addition on 'Wide' words, computing 'a + b'.
add :: Wide -> Wide -> Wide
add :: Wide -> Wide -> Wide
add (Wide (# Limb, Limb #)
a) (Wide (# Limb, Limb #)
b) = (# Limb, Limb #) -> Wide
Wide ((# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
add_w# (# Limb, Limb #)
a (# Limb, Limb #)
b)

-- | Borrowing subtraction, computing 'a - b' and returning the
--   difference with a borrow mask.
sub_b#
  :: (# Limb, Limb #)              -- ^ minuend
  -> (# Limb, Limb #)              -- ^ subtrahend
  -> (# (# Limb, Limb #), Limb #) -- ^ (# difference, borrow mask #)
sub_b# :: (# Limb, Limb #)
-> (# Limb, Limb #) -> (# (# Limb, Limb #), Limb #)
sub_b# (# Limb
a0, Limb
a1 #) (# Limb
b0, Limb
b1 #) =
  let !(# Limb
s0, Limb
c0 #) = Limb -> Limb -> Limb -> (# Limb, Limb #)
L.sub_b# Limb
a0 Limb
b0 (Word# -> Limb
Limb Word#
0##)
      !(# Limb
s1, Limb
c1 #) = Limb -> Limb -> Limb -> (# Limb, Limb #)
L.sub_b# Limb
a1 Limb
b1 Limb
c0
  in  (# (# Limb
s0, Limb
s1 #), Limb
c1 #)
{-# INLINE sub_b# #-}

-- | Wrapping subtraction, computing 'a - b'.
sub_w#
  :: (# Limb, Limb #) -- ^ minuend
  -> (# Limb, Limb #) -- ^ subtrahend
  -> (# Limb, Limb #) -- ^ difference
sub_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
sub_w# (# Limb, Limb #)
a (# Limb, Limb #)
b =
  let !(# (# Limb, Limb #)
c, Limb
_ #) = (# Limb, Limb #)
-> (# Limb, Limb #) -> (# (# Limb, Limb #), Limb #)
sub_b# (# Limb, Limb #)
a (# Limb, Limb #)
b
  in  (# Limb, Limb #)
c
{-# INLINE sub_w# #-}

-- | Wrapping subtraction on 'Wide' words, computing 'a - b'.
sub :: Wide -> Wide -> Wide
sub :: Wide -> Wide -> Wide
sub (Wide (# Limb, Limb #)
a) (Wide (# Limb, Limb #)
b) = (# Limb, Limb #) -> Wide
Wide ((# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
sub_w# (# Limb, Limb #)
a (# Limb, Limb #)
b)

-- multiplication -------------------------------------------------------------

-- | Wrapping multiplication, computing 'a b'.
mul_w#
  :: (# Limb, Limb #) -- ^ multiplicand
  -> (# Limb, Limb #) -- ^ multiplier
  -> (# Limb, Limb #) -- ^ product
mul_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
mul_w# (# Limb
a0, Limb
a1 #) (# Limb
b0, Limb
b1 #) =
  let !(# Limb
p0_lo, Limb
p0_hi #) = Limb -> Limb -> (# Limb, Limb #)
L.mul_c# Limb
a0 Limb
b0
      !(# Limb
p1_lo, Limb
_ #) = Limb -> Limb -> (# Limb, Limb #)
L.mul_c# Limb
a0 Limb
b1
      !(# Limb
p2_lo, Limb
_ #) = Limb -> Limb -> (# Limb, Limb #)
L.mul_c# Limb
a1 Limb
b0
      !(# Limb
s0, Limb
_ #) = Limb -> Limb -> (# Limb, Limb #)
L.add_o# Limb
p0_hi Limb
p1_lo
      !(# Limb
s1, Limb
_ #) = Limb -> Limb -> (# Limb, Limb #)
L.add_o# Limb
s0 Limb
p2_lo
  in  (# Limb
p0_lo, Limb
s1 #)
{-# INLINE mul_w# #-}

-- | Wrapping multiplication on 'Wide' words.
mul :: Wide -> Wide -> Wide
mul :: Wide -> Wide -> Wide
mul (Wide (# Limb, Limb #)
a) (Wide (# Limb, Limb #)
b) = (# Limb, Limb #) -> Wide
Wide ((# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
mul_w# (# Limb, Limb #)
a (# Limb, Limb #)
b)