{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE UnliftedNewtypes #-}

-- |
-- Module: Data.Word.Limb
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- The primitive 'Limb' type, as well as operations on it.

module Data.Word.Limb (
  -- * Limb
    Limb(..)
  , render

  -- * Bit manipulation and representation
  , and#
  , or#
  , not#
  , xor#
  , bits#
  , shl#
  , shl1#
  , shr#
  , shr1#

  -- * Comparison
  , eq#
  , ne#
  , eq_vartime#
  , ne_vartime#
  , nonzero#
  , lt#
  , gt#

  -- * Selection
  , select#
  , cswap#

  -- * Negation

  , neg#

  -- * Arithmetic
  , add_o#
  , add_c#
  , add_w#
  , add_s#

  , sub_b#
  , sub_w#
  , sub_s#

  , mul_c#
  , mul_w#
  , mul_s#

  , mac#
  ) where

import qualified Data.Bits as B
import qualified Data.Choice as C
import GHC.Exts (Word#)
import qualified GHC.Exts as Exts

-- | A 'Limb' is the smallest component of a wider word.
--
--   It wraps a single unboxed machine word ('Word#'). Because its field
--   is unlifted, 'Limb' is itself an unlifted type (this declaration
--   relies on the UnliftedNewtypes extension), so a 'Limb' value is
--   never a lazy thunk.
newtype Limb = Limb Word#

-- | Render a 'Limb' as a 'String', using the 'Show' instance of the
--   boxed 'Word' it wraps.
render :: Limb -> String
render limb = case limb of
  Limb w -> show (Exts.W# w)

-- comparison -----------------------------------------------------------------

-- | Equality comparison, producing a 'C.Choice' (via 'C.eq_word#')
--   rather than a 'Bool'. See 'eq_vartime#' for the 'Bool'-returning,
--   variable-time counterpart.
eq#
  :: Limb
  -> Limb
  -> C.Choice
eq# (Limb x) (Limb y) = C.eq_word# x y
{-# INLINE eq# #-}

-- | Variable-time equality comparison, returning a 'Bool'.
--
--   Built on 'Exts.eqWord#' and converted with 'Exts.isTrue#'. Callers
--   typically branch on the result, so prefer 'eq#' where the comparison
--   must not influence control flow.
eq_vartime#
  :: Limb
  -> Limb
  -> Bool
eq_vartime# (Limb x) (Limb y) = Exts.isTrue# (x `Exts.eqWord#` y)
{-# INLINE eq_vartime# #-}

-- | Inequality comparison: the 'C.Choice' negation of 'eq#'.
ne#
  :: Limb
  -> Limb
  -> C.Choice
ne# l r = C.not# (eq# l r)
{-# INLINE ne# #-}

-- | Variable-time inequality comparison: the 'Bool' negation of
--   'eq_vartime#'.
ne_vartime#
  :: Limb
  -> Limb
  -> Bool
ne_vartime# l r = not $ eq_vartime# l r
{-# INLINE ne_vartime# #-}

-- | Comparison to zero, producing a 'C.Choice' from the underlying word
--   via 'C.from_word_nonzero#'.
nonzero#
  :: Limb
  -> C.Choice
nonzero# (Limb w) = C.from_word_nonzero# w
{-# INLINE nonzero# #-}

-- | Less-than comparison of the underlying words, producing a
--   'C.Choice' via 'C.from_word_lt#'.
lt#
  :: Limb
  -> Limb
  -> C.Choice
lt# (Limb x) (Limb y) = C.from_word_lt# x y
{-# INLINE lt# #-}

-- | Greater-than comparison of the underlying words, producing a
--   'C.Choice' via 'C.from_word_gt#'.
gt#
  :: Limb
  -> Limb
  -> C.Choice
gt# (Limb x) (Limb y) = C.from_word_gt# x y
{-# INLINE gt# #-}

-- selection ------------------------------------------------------------------

-- | Select one of two limbs according to a 'C.Choice', via
--   'C.select_word#': return b if c is truthy, otherwise return a.
--
--   This is the argument convention 'cswap#' depends on: it builds its
--   result as @(# select a b c, select b a c #)@ and yields @(# b, a #)@
--   when c is truthy.
select#
  :: Limb     -- ^ a (returned when c is falsy)
  -> Limb     -- ^ b (returned when c is truthy)
  -> C.Choice -- ^ c
  -> Limb     -- ^ result
select# (Limb a) (Limb b) c = Limb (C.select_word# a b c)
{-# INLINE select# #-}

-- | Conditional swap: return (# b, a #) if c is truthy, otherwise
--   return (# a, b #). Both components are always computed with
--   'C.select_word#', whichever way c goes.
cswap#
  :: Limb             -- ^ a
  -> Limb             -- ^ b
  -> C.Choice         -- ^ c
  -> (# Limb, Limb #) -- ^ result
cswap# (Limb a) (Limb b) c =
  (# Limb (C.select_word# a b c), Limb (C.select_word# b a c) #)
{-# INLINE cswap# #-}

-- bit manipulation -----------------------------------------------------------

-- | Bitwise and of two limbs.
and#
  :: Limb -- ^ a
  -> Limb -- ^ b
  -> Limb -- ^ a & b
and# (Limb x) (Limb y) = Limb (x `Exts.and#` y)
{-# INLINE and# #-}

-- | Bitwise (inclusive) or of two limbs.
or#
  :: Limb -- ^ a
  -> Limb -- ^ b
  -> Limb -- ^ a | b
or# (Limb x) (Limb y) = Limb (x `Exts.or#` y)
{-# INLINE or# #-}

-- | Bitwise complement of a limb.
not#
  :: Limb -- ^ a
  -> Limb -- ^ not a
not# (Limb w) = Limb (Exts.not# w)
{-# INLINE not# #-}

-- | Bitwise exclusive or of two limbs.
xor#
  :: Limb -- ^ a
  -> Limb -- ^ b
  -> Limb -- ^ a ^ b
xor# (Limb x) (Limb y) = Limb (x `Exts.xor#` y)
{-# INLINE xor# #-}

-- | Number of bits required to represent this limb: the word width
--   minus the number of leading zero bits (so a zero limb needs 0 bits).
bits#
  :: Limb -- ^ limb
  -> Int  -- ^ bits required to represent limb
bits# (Limb a) =
  let !w = Exts.W# a
  -- XX unbox? (currently goes through the boxed 'Word' instances)
  in  B.finiteBitSize w - B.countLeadingZeros w
{-# INLINE bits# #-}

-- | Logical bit-shift left, via 'Exts.uncheckedShiftL#'.
--
--   The shift is unchecked: the amount must lie in @[0, word size)@,
--   otherwise the result is undefined.
shl#
  :: Limb       -- ^ limb
  -> Exts.Int#  -- ^ shift amount
  -> Limb       -- ^ result
shl# (Limb x) n = Limb (x `Exts.uncheckedShiftL#` n)
{-# INLINE shl# #-}

-- | Bit-shift left by 1, returning (# shifted, carry #).
--
--   The carry is the bit shifted out of the most significant position,
--   moved down to bit 0 (so it is 0 or 1).
shl1#
  :: Limb
  -> (# Limb, Limb #)
shl1# (Limb w) = case B.finiteBitSize (0 :: Word) of
  Exts.I# width ->
    let !top = width Exts.-# 1#
    in  (# Limb (Exts.uncheckedShiftL# w 1#)
        , Limb (Exts.uncheckedShiftRL# w top)
        #)
{-# INLINE shl1# #-}

-- | Logical bit-shift right, via 'Exts.uncheckedShiftRL#'.
--
--   The shift is unchecked: the amount must lie in @[0, word size)@,
--   otherwise the result is undefined.
shr#
  :: Limb       -- ^ limb
  -> Exts.Int#  -- ^ shift amount
  -> Limb       -- ^ result
shr# (Limb x) n = Limb (x `Exts.uncheckedShiftRL#` n)
{-# INLINE shr# #-}

-- | Bit-shift right by 1, returning (# shifted, carry #).
--
--   The carry is the bit shifted out of bit 0, moved up to the most
--   significant position (so it is 0 or the top bit alone). This makes
--   it ready to be or-ed into the next lower limb of a wider word.
shr1#
  :: Limb
  -> (# Limb, Limb #)
shr1# (Limb w) = case B.finiteBitSize (0 :: Word) of
  Exts.I# width ->
    let !top = width Exts.-# 1#
    in  (# Limb (Exts.uncheckedShiftRL# w 1#)
        , Limb (Exts.uncheckedShiftL# w top)
        #)
{-# INLINE shr1# #-}

-- negation -------------------------------------------------------------------

-- | Wrapping (two's complement) negation, computed as @not x + 1@
--   modulo the word size.
neg#
  :: Limb
  -> Limb
neg# (Limb x) = Limb (Exts.not# x `Exts.plusWord#` 1##)
{-# INLINE neg# #-}

-- addition -------------------------------------------------------------------

-- | Overflowing addition, computing augend + addend, returning the
--   sum (low word) and carry (high word, 0 or 1).
add_o#
  :: Limb             -- ^ augend
  -> Limb             -- ^ addend
  -> (# Limb, Limb #) -- ^ (# sum, carry #)
add_o# (Limb a) (Limb b) =
  -- 'Exts.plusWord2#' yields (# high, low #); the result is reordered
  -- to (# low, high #).
  let !(# hi, lo #) = Exts.plusWord2# a b
  in  (# Limb lo, Limb hi #)
{-# INLINE add_o# #-}

-- | Carrying addition, computing augend + addend + carry, returning
--   the sum (low word) and new carry (high word).
--
--   The two partial carries are combined with 'Exts.plusWord#', not a
--   bitwise or. For a carry-in of 0 or 1 the partial carries can never
--   both be set, so the result is the same either way. For an arbitrary
--   carry-in limb, the total can reach 3 * (2^w - 1), whose high word
--   is 2; only addition reports that correctly. The operation remains
--   branch-free.
add_c#
  :: Limb             -- ^ augend
  -> Limb             -- ^ addend
  -> Limb             -- ^ carry
  -> (# Limb, Limb #) -- ^ (# sum, new carry #)
add_c# (Limb a) (Limb b) (Limb c) =
  let !(# c0, s0 #) = Exts.plusWord2# a b
      !(# c1, s #)  = Exts.plusWord2# s0 c
      -- c0 + c1 <= 2, so this cannot itself overflow
  in  (# Limb s, Limb (Exts.plusWord# c0 c1) #)
{-# INLINE add_c# #-}

-- | Wrapping addition, computing augend + addend modulo the word size
--   (any overflow is discarded).
add_w#
  :: Limb -- ^ augend
  -> Limb -- ^ addend
  -> Limb -- ^ sum
add_w# (Limb x) (Limb y) = Limb (x `Exts.plusWord#` y)
{-# INLINE add_w# #-}

-- | Saturating addition, computing augend + addend and clamping to the
--   all-ones (maximum) word on overflow.
--
--   Note this branches on the overflow flag from 'Exts.addWordC#'.
add_s#
  :: Limb
  -> Limb
  -> Limb
add_s# (Limb a) (Limb b) =
  let !(# s, overflow #) = Exts.addWordC# a b
  in  case overflow of
        0# -> Limb s
        _  -> Limb (Exts.not# 0##)
{-# INLINE add_s# #-}

-- subtraction ----------------------------------------------------------------

-- | Borrowing subtraction, computing minuend - (subtrahend + borrow),
--   returning the difference and new borrow mask.
--
--   The incoming borrow is read as a mask: only its most significant
--   bit is consulted (shifted down to a 0/1 value). The outgoing borrow
--   is in the same form, all ones if the subtraction borrowed and zero
--   otherwise, so it can be chained into the next call.
sub_b#
  :: Limb              -- ^ minuend
  -> Limb              -- ^ subtrahend
  -> Limb              -- ^ borrow
  -> (# Limb, Limb #)  -- ^ (# difference, new borrow #)
sub_b# (Limb m) (Limb n) (Limb mask) = case B.finiteBitSize (0 :: Word) of
  Exts.I# width ->
    -- bin: incoming borrow as a single bit; u0/u1: per-step borrow flags
    let !bin          = Exts.uncheckedShiftRL# mask (width Exts.-# 1#)
        !(# d0, u0 #) = Exts.subWordC# m n
        !(# d, u1 #)  = Exts.subWordC# d0 bin
        -- negating a 1 flag gives -1, i.e. all ones once reinterpreted
        !out          = Exts.int2Word# (Exts.negateInt# (Exts.orI# u0 u1))
    in  (# Limb d, Limb out #)
{-# INLINE sub_b# #-}

-- | Saturating subtraction, computing minuend - subtrahend and clamping
--   to zero on underflow.
--
--   Note this branches on the borrow flag from 'Exts.subWordC#'.
sub_s#
  :: Limb -- ^ minuend
  -> Limb -- ^ subtrahend
  -> Limb -- ^ difference
sub_s# (Limb m) (Limb n) =
  let !(# d, underflow #) = Exts.subWordC# m n
  in  case underflow of
        0# -> Limb d
        _  -> Limb 0##
{-# INLINE sub_s# #-}

-- | Wrapping subtraction, computing minuend - subtrahend modulo the word
--   size (any underflow is discarded).
sub_w#
  :: Limb -- ^ minuend
  -> Limb -- ^ subtrahend
  -> Limb -- ^ difference
sub_w# (Limb m) (Limb n) = Limb (m `Exts.minusWord#` n)
{-# INLINE sub_w# #-}

-- multiplication -------------------------------------------------------------

-- | Widening multiplication, returning the low and high words of the
--   full double-word product.
mul_c#
  :: Limb             -- ^ multiplicand
  -> Limb             -- ^ multiplier
  -> (# Limb, Limb #) -- ^ (# low, high #) product
mul_c# (Limb a) (Limb b) = case Exts.timesWord2# a b of
  -- 'Exts.timesWord2#' yields (# high, low #)
  (# hi, lo #) -> (# Limb lo, Limb hi #)
{-# INLINE mul_c# #-}

-- | Wrapping multiplication, returning only the low word of the product
--   (the high word is discarded).
mul_w#
  :: Limb -- ^ multiplicand
  -> Limb -- ^ multiplier
  -> Limb -- ^ low word of product
mul_w# (Limb x) (Limb y) = Limb (x `Exts.timesWord#` y)
{-# INLINE mul_w# #-}

-- | Saturating multiplication: the low word of the product when the
--   high word is zero, otherwise the all-ones (maximum) word.
--
--   Note this branches on the high word of the product.
mul_s#
  :: Limb -- ^ multiplicand
  -> Limb -- ^ multiplier
  -> Limb -- ^ clamped low word of product
mul_s# (Limb a) (Limb b) =
  let !(# hi, lo #) = Exts.timesWord2# a b
  in  case hi of
        0## -> Limb lo
        _   -> Limb (Exts.not# 0##)
{-# INLINE mul_s# #-}

-- | Multiply-add-carry, computing a * b + m + c and returning it as
--   (# low word, high word #). The high word serves as the new carry.
--
--   The full result always fits in two words, since
--   (2^w - 1)^2 + 2 * (2^w - 1) = 2^(2w) - 1. So the carries propagated
--   into the high word below never overflow.
mac#
  :: Limb              -- ^ a (multiplicand)
  -> Limb              -- ^ b (multiplier)
  -> Limb              -- ^ m (addend)
  -> Limb              -- ^ c (carry)
  -> (# Limb, Limb #)  -- ^ a * b + m + c
mac# (Limb a) (Limb b) (Limb m) (Limb c) =
  let !(# p_hi, p_lo #) = Exts.timesWord2# a b
      -- fold m into the low word, pushing its carry into the high word
      !(# k0, q_lo #)   = Exts.plusWord2# p_lo m
      !q_hi              = Exts.plusWord# p_hi k0
      -- fold c in the same way
      !(# k1, r_lo #)   = Exts.plusWord2# q_lo c
      !r_hi              = Exts.plusWord# q_hi k1
  in  (# Limb r_lo, Limb r_hi #)
{-# INLINE mac# #-}