{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}

-- |
-- Module: Lightning.Protocol.BOLT9.Types
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Baseline types for BOLT #9 feature flags.

module Lightning.Protocol.BOLT9.Types (
    -- * Context
    Context(..)
  , isChannelContext
  , channelParity

    -- * Bit indices
  , BitIndex
  , unBitIndex
  , bitIndex

    -- * Required/optional level
  , FeatureLevel(..)

    -- * Required/optional bits
  , RequiredBit
  , unRequiredBit
  , requiredBit
  , requiredFromBitIndex

  , OptionalBit
  , unOptionalBit
  , optionalBit
  , optionalFromBitIndex

    -- * Feature vectors
  , FeatureVector
  , unFeatureVector
  , empty
  , fromByteString
  , set
  , clear
  , member
  ) where

import Control.DeepSeq (NFData)
import qualified Data.Bits as B
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Word (Word8, Word16)
import GHC.Generics (Generic)

-- Context ------------------------------------------------------------------

-- | The message context in which a feature flag may be presented.
--
-- BOLT #9 tags each feature with the contexts it is allowed to appear
-- in; the short code used for each context in the specification's
-- feature table is repeated on every constructor below:
--
-- * 'Init' (I) - the @init@ message
-- * 'NodeAnn' (N) - @node_announcement@ messages
-- * 'ChanAnn' (C) - @channel_announcement@ messages (normal)
-- * 'ChanAnnOdd' (C-) - @channel_announcement@, always odd (optional)
-- * 'ChanAnnEven' (C+) - @channel_announcement@, always even (required)
-- * 'Invoice' (9) - BOLT #11 invoices
-- * 'Blinded' (B) - the @allowed_features@ field of a blinded path
-- * 'ChanType' (T) - the @channel_type@ field when opening channels
--
-- The derived 'Ord' instance follows the declaration order below.
data Context
  = Init
    -- ^ I: presented in the @init@ message
  | NodeAnn
    -- ^ N: presented in @node_announcement@ messages
  | ChanAnn
    -- ^ C: presented in @channel_announcement@ message
  | ChanAnnOdd
    -- ^ C-: @channel_announcement@, always odd (optional)
  | ChanAnnEven
    -- ^ C+: @channel_announcement@, always even (required)
  | Invoice
    -- ^ 9: presented in BOLT 11 invoices
  | Blinded
    -- ^ B: @allowed_features@ field of a blinded path
  | ChanType
    -- ^ T: @channel_type@ field when opening channels
  deriving (Show, Eq, Ord, Generic)

-- Generic-based default: forcing any constructor fully evaluates it.
instance NFData Context

-- | 'True' exactly for the three @channel_announcement@ contexts:
-- 'ChanAnn' (C), 'ChanAnnOdd' (C-) and 'ChanAnnEven' (C+); 'False' for
-- every other 'Context'.
isChannelContext :: Context -> Bool
isChannelContext ctx = ctx `elem` [ChanAnn, ChanAnnOdd, ChanAnnEven]
{-# INLINE isChannelContext #-}

-- | The parity forced by a channel-announcement context, if any.
--
-- * 'ChanAnnEven' (C+) yields @'Just' 'True'@: the bit must be even.
-- * 'ChanAnnOdd' (C-) yields @'Just' 'False'@: the bit must be odd.
-- * Every other context yields 'Nothing': no parity is forced.
channelParity :: Context -> Maybe Bool
channelParity ctx = case ctx of
  ChanAnnEven -> Just True   -- even (required)
  ChanAnnOdd  -> Just False  -- odd (optional)
  _           -> Nothing
{-# INLINE channelParity #-}

-- FeatureLevel -------------------------------------------------------------

-- | The strength with which a feature is advertised.
--
-- BOLT #9 assigns each feature a pair of adjacent bits. Setting the even
-- member of the pair marks the feature as required (compulsory), while
-- setting the odd member marks it as optional.
--
-- The derived 'Ord' places 'Required' before 'Optional'.
data FeatureLevel
  = Required
    -- ^ The feature is required (even bit set)
  | Optional
    -- ^ The feature is optional (odd bit set)
  deriving (Show, Eq, Ord, Generic)

instance NFData FeatureLevel

-- BitIndex -----------------------------------------------------------------

-- | A zero-based position within a feature vector, where bit 0 is the
-- least significant bit.
--
-- Backed by a 'Word16', so every position from 0 to 65535 can be
-- represented; that range comfortably covers any practical feature flag.
-- Values are built with 'bitIndex' (the constructor is not exported).
newtype BitIndex = BitIndex
  { unBitIndex :: Word16 -- ^ the raw bit position
  }
  deriving (Show, Eq, Ord, Generic)

instance NFData BitIndex

-- | Build a 'BitIndex' from a raw 'Word16'. This is total: every 'Word16'
-- is a valid position, so no validation is performed.
bitIndex :: Word16 -> BitIndex
bitIndex w = BitIndex w
{-# INLINE bitIndex #-}

-- RequiredBit --------------------------------------------------------------

-- | A feature bit advertised as required (compulsory).
--
-- Required bits are always even. The constructor is not exported, so
-- values can only be obtained through 'requiredBit' or
-- 'requiredFromBitIndex', both of which reject odd positions.
newtype RequiredBit = RequiredBit
  { unRequiredBit :: Word16 -- ^ the (even) bit position
  }
  deriving (Show, Eq, Ord, Generic)

instance NFData RequiredBit

-- | Validate a raw position as a 'RequiredBit'. The position must be
--   even (its lowest bit clear); odd positions yield 'Nothing'.
--
--   >>> requiredBit 16
--   Just (RequiredBit {unRequiredBit = 16})
--   >>> requiredBit 17
--   Nothing
requiredBit :: Word16 -> Maybe RequiredBit
requiredBit !w
  | B.testBit w 0 = Nothing
  | otherwise     = Just (RequiredBit w)
{-# INLINE requiredBit #-}

-- | Reinterpret a 'BitIndex' as a 'RequiredBit'. Yields 'Nothing' when
-- the index is odd (same rule as 'requiredBit').
requiredFromBitIndex :: BitIndex -> Maybe RequiredBit
requiredFromBitIndex = requiredBit . unBitIndex
{-# INLINE requiredFromBitIndex #-}

-- OptionalBit --------------------------------------------------------------

-- | A feature bit advertised as optional.
--
-- Optional bits are always odd. The constructor is not exported, so
-- values can only be obtained through 'optionalBit' or
-- 'optionalFromBitIndex', both of which reject even positions.
newtype OptionalBit = OptionalBit
  { unOptionalBit :: Word16 -- ^ the (odd) bit position
  }
  deriving (Show, Eq, Ord, Generic)

instance NFData OptionalBit

-- | Validate a raw position as an 'OptionalBit'. The position must be
--   odd (its lowest bit set); even positions yield 'Nothing'.
--
--   >>> optionalBit 17
--   Just (OptionalBit {unOptionalBit = 17})
--   >>> optionalBit 16
--   Nothing
optionalBit :: Word16 -> Maybe OptionalBit
optionalBit !w
  | B.testBit w 0 = Just (OptionalBit w)
  | otherwise     = Nothing
{-# INLINE optionalBit #-}

-- | Reinterpret a 'BitIndex' as an 'OptionalBit'. Yields 'Nothing' when
-- the index is even (same rule as 'optionalBit').
optionalFromBitIndex :: BitIndex -> Maybe OptionalBit
optionalFromBitIndex = optionalBit . unBitIndex
{-# INLINE optionalFromBitIndex #-}

-- FeatureVector ------------------------------------------------------------

-- | A feature vector backed by a strict 'ByteString'.
--
-- Bytes are stored big-endian: the most significant byte comes first.
-- Bit numbering starts at the least significant bit of the /last/ byte,
-- so bit 0 is position 0 of the final byte, bit 8 is position 0 of the
-- second-to-last byte, and so on.
--
-- The derived 'Eq' and 'Ord' instances compare the raw bytes. Two
-- vectors with the same bits set but a different number of leading zero
-- bytes (e.g. @\"\\x00\\x01\"@ and @\"\\x01\"@) therefore compare as
-- unequal.
newtype FeatureVector = FeatureVector
  { unFeatureVector :: ByteString -- ^ the big-endian encoded bits
  }
  deriving (Show, Eq, Ord, Generic)

instance NFData FeatureVector

-- | The feature vector with no bits set, encoded as zero bytes.
--
--   >>> empty
--   FeatureVector {unFeatureVector = ""}
empty :: FeatureVector
empty = FeatureVector mempty
{-# INLINE empty #-}

-- | Reinterpret raw bytes as a 'FeatureVector'. The bytes are used as-is;
-- no validation or normalization (such as dropping leading zero bytes)
-- is performed.
fromByteString :: ByteString -> FeatureVector
fromByteString bytes = FeatureVector bytes
{-# INLINE fromByteString #-}

-- | Turn on the given bit in a feature vector.
--
--   If the bit lies beyond the current length, the vector is first
--   grown by prepending zero bytes, which keeps the big-endian layout
--   intact. Existing bytes, including any leading zero bytes, are
--   otherwise left as they are.
--
--   >>> set (bitIndex 0) empty
--   FeatureVector {unFeatureVector = "\SOH"}
--   >>> set (bitIndex 8) empty
--   FeatureVector {unFeatureVector = "\SOH\NUL"}
set :: BitIndex -> FeatureVector -> FeatureVector
set (BitIndex idx) (FeatureVector bs) =
    FeatureVector (updateByteAt pos (B.setBit (BS.index padded pos) bit) padded)
  where
    -- which byte (counted from the end) and which bit inside it
    (byteIdx, bit) = fromIntegral idx `quotRem` 8
    -- left-pad with zeros so that byte 'byteIdx' exists; 'BS.replicate'
    -- yields an empty string when no padding is needed
    padded = BS.replicate (byteIdx + 1 - BS.length bs) 0 <> bs
    -- big-endian: the last byte carries bits 0-7
    pos = BS.length padded - 1 - byteIdx
{-# INLINE set #-}

-- | Turn off the given bit in a feature vector.
--
-- A bit beyond the vector's length is already clear, so the vector is
-- returned unchanged in that case. Otherwise the bit is cleared and any
-- leading zero bytes are stripped from the result.
clear :: BitIndex -> FeatureVector -> FeatureVector
clear (BitIndex idx) fv@(FeatureVector bs)
  | byteIdx >= len = fv  -- also covers the empty vector (len == 0)
  | otherwise      =
      let pos     = len - 1 - byteIdx  -- big-endian: index from the end
          cleared = B.clearBit (BS.index bs pos) bit
      in  FeatureVector (stripLeadingZeros (updateByteAt pos cleared bs))
  where
    (byteIdx, bit) = fromIntegral idx `quotRem` 8
    len            = BS.length bs
{-# INLINE clear #-}

-- | Report whether the given bit is set. A bit beyond the vector's length
--   (including any bit of the empty vector) counts as unset.
--
--   >>> member (bitIndex 0) (set (bitIndex 0) empty)
--   True
--   >>> member (bitIndex 1) (set (bitIndex 0) empty)
--   False
member :: BitIndex -> FeatureVector -> Bool
member (BitIndex idx) (FeatureVector bs) =
  let (byteIdx, bit) = fromIntegral idx `quotRem` 8
      len            = BS.length bs
  in  byteIdx < len && B.testBit (BS.index bs (len - 1 - byteIdx)) bit
{-# INLINE member #-}

-- Internal helpers ---------------------------------------------------------

-- | Replace the byte at index @i@ (0-based, from the front) with @w@.
--
-- An index outside @[0, length bs)@ leaves the input unchanged. Negative
-- indices are rejected explicitly: 'BS.splitAt' clamps them to 0, and
-- without the check byte 0 would be silently overwritten.
--
-- The result is assembled with a single 'BS.concat', so only one copy
-- of the data is made.
updateByteAt :: Int -> Word8 -> ByteString -> ByteString
updateByteAt !i !w !bs
  | i < 0 || i >= BS.length bs = bs
  | otherwise =
      BS.concat [BS.take i bs, BS.singleton w, BS.drop (i + 1) bs]
{-# INLINE updateByteAt #-}

-- | Drop every leading @0x00@ byte. An all-zero (or empty) input yields
-- the empty string.
stripLeadingZeros :: ByteString -> ByteString
stripLeadingZeros bytes = snd (BS.span (== 0) bytes)
{-# INLINE stripLeadingZeros #-}