{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}

-- |
-- Module: Numeric.Eproc.Bounded
-- Copyright: (c) 2026 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Two-sided bounded-mean anytime-valid test.
--
-- For samples @x_t@ in @[lo, hi]@, tests @H_0: E[x] = m@ against
-- @H_1: E[x] /= m@.
--
-- Internally two one-sided e-processes are run in parallel: a
-- /positive-direction/ process betting against the alternative
-- @E[x] > m@ (using centred observations @z = x - m@), and a
-- /negative-direction/ process betting against @E[x] < m@ (using
-- @-z@). Each maintains its own log-wealth and bettor state. The
-- test rejects when either side's wealth crosses @2 \/ alpha@; the
-- factor of 2 is the Bonferroni adjustment for the two-sided union.
--
-- The test is /anytime-valid/: under @H_0@ each directional wealth
-- process is a nonnegative supermartingale, so by Ville's inequality
-- each one ever reaches @2 \/ alpha@ with probability at most
-- @alpha \/ 2@. By the union bound, the probability that either side
-- ever crosses the threshold is at most @alpha@ -- regardless of when
-- the user decides to stop streaming samples.
--
-- == Example
--
-- Test @H_0: E[x] = 0.5@ for @x@ in @[0, 1]@ at level @alpha = 1e-3@
-- against a stream with empirical mean @0.8@:
--
-- >>> import Data.List (foldl')
-- >>> let cfg = config 0.5 0.0 1.0 1.0e-3 Newton
-- >>> let xs  = concat (replicate 30 [1, 1, 0, 1, 1, 0, 1, 1, 1, 1])
-- >>> decide cfg (foldl' (update cfg) (initial cfg) xs)
-- Reject

module Numeric.Eproc.Bounded (
  -- * Test configuration and state
    Config
  , State
  , Verdict(..)

  -- * Bettor strategies
  , Bettor(..)

  -- * Construction
  , config
  , initial

  -- * Streaming
  , update
  , decide

  -- * Inspection
  , log_wealth
  , samples
  ) where

import GHC.Exts (Double(D#))
import Numeric.Eproc.Common (Bettor(..), Verdict(..))

-- types ----------------------------------------------------------------------

-- here, the centred observation @z_t@ referenced in
-- "Numeric.Eproc.Common" is @x_t - m@; the per-direction safe-bet
-- ceilings @lambda_max@ are derived from the sample bounds (see
-- 'config').

-- state carried by the bettor of a single direction. each 'Bettor'
-- alternative has its own constructor here, and a 'State' only ever
-- holds the constructor that corresponds to the 'Bettor' stored in
-- its 'Config'.
data BetState
  -- Fixed: the bet is a constant, so there is nothing to track
  = SFixed
  -- Adaptive: running sums feeding the plug-in mean/variance estimate
  | SAdaptive
      {-# UNPACK #-} !Double  -- running sum of z (centred observation)
      {-# UNPACK #-} !Double  -- running sum of z^2 (for online variance)
      {-# UNPACK #-} !Int     -- number of z values folded in
  -- Newton: the current bet plus the squared-gradient accumulator
  | SNewton
      {-# UNPACK #-} !Double  -- bet lambda used for the next sample
      {-# UNPACK #-} !Double  -- running sum of per-step squared gradients

-- | Bounded-mean test configuration. Build with 'config'.
--
--   Carries the bettor strategy, the null mean, the significance
--   level, the precomputed Bonferroni-adjusted log-wealth threshold,
--   and the per-direction safe-bet ceilings (see 'config' for how
--   the latter are derived from the sample bounds).
--
--   NB. each field's Haddock comment uses @-- ^@ and therefore must
--   come /after/ the field it documents; placing it before the field
--   would attach it to the preceding field instead.
data Config = Config {
    cfg_bettor      :: !Bettor
    -- ^ bettor strategy
  , cfg_lam_max_pos :: {-# UNPACK #-} !Double
    -- ^ positive-direction safe-bet ceiling
  , cfg_lam_max_neg :: {-# UNPACK #-} !Double
    -- ^ negative-direction safe-bet ceiling
  , cfg_null_mean   :: {-# UNPACK #-} !Double
    -- ^ null mean @m@
  , cfg_alpha       :: {-# UNPACK #-} !Double
    -- ^ significance level @alpha@
  , cfg_log_thresh  :: {-# UNPACK #-} !Double
    -- ^ rejection threshold @log(2 \/ alpha)@
  }

-- | Streaming test state: obtain one from 'initial', then fold
--   observations through 'update'.
--
--   The positive- and negative-direction e-processes each keep their
--   own running log-wealth and their own bettor state. 'decide'
--   checks both log-wealths against the threshold, while 'log_wealth'
--   reports the larger of the two. What a bettor state contains
--   depends on the 'Bettor' chosen in the 'Config' (running sums, the
--   current bet, and so on).
data State = State {
    st_n         :: {-# UNPACK #-} !Int
    -- ^ number of samples folded in so far
  , st_log_w_pos :: {-# UNPACK #-} !Double
    -- ^ running log-wealth of the positive-direction process
  , st_log_w_neg :: {-# UNPACK #-} !Double
    -- ^ running log-wealth of the negative-direction process
  , st_bet_pos   :: !BetState
    -- ^ bettor state for the positive direction
  , st_bet_neg   :: !BetState
    -- ^ bettor state for the negative direction
  }

-- internal -------------------------------------------------------------------

-- lower bound applied to every per-step wealth factor before its log
-- is taken. without it, a factor at (or below) zero would turn the
-- running log-wealth into -Infinity (or NaN).
-- NB. the constant is spelled as an unboxed MagicHash literal on
--     purpose. an ordinary '1.0e-300' literal desugars to
--     'fromRational (1.0e-300 :: Rational)', which GHC does not
--     constant-fold, so a '$wrationalToDouble' call would remain in
--     the per-step worker.
tiny :: Double
tiny = D# (1.0e-300##)
{-# INLINE tiny #-}

-- starting bettor state for one direction of a fresh test; the
-- constructor returned mirrors the 'Bettor' alternative passed in.
init_bet :: Bettor -> BetState
init_bet (Fixed _) = SFixed
init_bet Adaptive  = SAdaptive 0 0 0
-- the accumulator is seeded slightly above zero so that the first
-- Newton step in 'step_bet' never divides by a zero accumulator
init_bet Newton    = SNewton 0 1.0e-6
{-# INLINE init_bet #-}

-- the bet 'lambda' a direction places on the next observation, given
-- the bettor and that direction's current state. 'lam_max' is the
-- direction-specific safe-bet ceiling computed by 'config'. every
-- strategy's bet is clamped into [0, lam_max], so the direction's
-- wealth factor (1 + lambda * z for its own centred value z) stays
-- positive for any sample inside the configured bounds.
--
--   * Fixed: the user-supplied constant, clamped like the other
--     strategies. left unclamped, a large constant could drive the
--     factor negative; 'update' would then floor it to a tiny
--     positive value, which raises the expected wealth under H_0
--     above 1 and can break the Ville's-inequality guarantee.
--     a negative constant is clamped to 0 (no bet in this direction).
--   * Adaptive: a Kelly-style plug-in mean / (variance + mean^2) built
--     from the running sums; 0 before any sample has been seen.
--   * Newton: the bet chosen by the most recent Newton step in
--     'step_bet' (which already clamps it).
--
-- a state whose constructor does not match the bettor yields a 0 bet.
bet_lambda :: Bettor -> Double -> BetState -> Double
bet_lambda b !lam_max !s = case b of
    Fixed lam -> clamp lam
    Adaptive  -> case s of
      SAdaptive !sm !sm2 !n
        | n == 0    -> 0
        | otherwise ->
            let !nd  = fromIntegral n :: Double
                !mu  = sm / nd
                !mu2 = mu * mu
                -- floor at 0 guards against rounding pushing the
                -- online variance estimate slightly negative
                !var = max 0 (sm2 / nd - mu2)
                !den = var + mu2
                !raw = if den == 0 then 0 else mu / den
            in  clamp raw
      _ -> 0
    Newton    -> case s of
      SNewton !lam _ -> lam
      _              -> 0
  where
    -- restrict a proposed bet to the safe range [0, lam_max]
    clamp :: Double -> Double
    clamp l = max 0 (min lam_max l)
{-# INLINE bet_lambda #-}

-- advance one direction's bettor state after it observes its centred
-- value 'z' ('update' feeds the negative direction the negated value).
--
--   * Fixed: nothing to update.
--   * Adaptive: fold 'z' into the running sum, sum of squares and
--     count.
--   * Newton: one Newton-style step on the per-step loss
--     -log(1 + lambda * z). the gradient -z / (1 + lambda * z) is
--     squared into the accumulator, the bet moves by
--     -gradient / accumulator, and the result is clamped into
--     [0, lam_max].
--
-- if the incoming state's constructor does not match the bettor, a
-- fresh matching state is produced instead (for Adaptive it already
-- includes 'z').
step_bet :: Bettor -> Double -> BetState -> Double -> BetState
step_bet b !lam_max !s !z = case b of
    Fixed _  -> SFixed
    Adaptive -> adaptive s
    Newton   -> newton s
  where
    -- accumulate sufficient statistics for the plug-in bet
    adaptive :: BetState -> BetState
    adaptive (SAdaptive !sm !sm2 !n) =
      SAdaptive (sm + z) (sm2 + z * z) (n + 1)
    adaptive _ =
      SAdaptive z (z * z) 1

    -- one scaled gradient step on the log-wealth loss, then clamp
    newton :: BetState -> BetState
    newton (SNewton !lam !acc) =
      let !denom   = 1 + lam * z
          -- a zero denominator would make the gradient infinite;
          -- treat it as "no information" instead
          !grad    = if denom == 0 then 0 else negate z / denom
          !acc_new = acc + grad * grad
          !lam_new = lam - grad / acc_new
      in  SNewton (max 0 (min lam_max lam_new)) acc_new
    newton _ =
      SNewton 0 1.0e-6
{-# INLINE step_bet #-}

-- construction ---------------------------------------------------------------

-- | Build a 'Config' for the bounded-mean test.
--
--   The two safe-bet ceilings @lambda_max@ keep each direction's
--   wealth factor nonnegative for every observation in @[lo, hi]@:
--
--   * Positive direction, factor @1 + lambda_p * (x - m)@. The most
--     negative centred value is @lo - m@ (at @x = lo@), so the factor
--     stays nonnegative when @lambda_p <= 1 \/ (m - lo)@. Half of
--     that bound is stored, leaving numerical margin as recommended
--     for WSR-style betting.
--
--   * Negative direction, factor @1 - lambda_n * (x - m)@. The
--     largest centred value is @hi - m@ (at @x = hi@), so the bound
--     is @lambda_n <= 1 \/ (hi - m)@. Again, half of it is stored.
--
--   The rejection threshold is stored in log space as
--   @log(2 \/ alpha)@. The factor 2 is the Bonferroni union-bound
--   adjustment for running two one-sided e-processes.
--
--   The arguments are not validated. For example, @m == lo@ gives a
--   positive-direction ceiling of @0.5 \/ 0 = Infinity@.
--
--   >>> let cfg = config 0.5 0.0 1.0 1.0e-3 Newton
config
  :: Double  -- ^ null mean @m@
  -> Double  -- ^ sample lower bound @lo@
  -> Double  -- ^ sample upper bound @hi@
  -> Double  -- ^ significance level @alpha@
  -> Bettor  -- ^ bettor strategy
  -> Config
config !m !lo !hi !alpha !b =
    Config {
        cfg_bettor      = b
      , cfg_null_mean   = m
      , cfg_alpha       = alpha
      , cfg_lam_max_pos = half_inverse (m - lo)
      , cfg_lam_max_neg = half_inverse (hi - m)
      , cfg_log_thresh  = log (2 / alpha)
      }
  where
    -- half of the largest bet a centred excursion of size 'd' allows
    half_inverse :: Double -> Double
    half_inverse d = 0.5 / d
{-# INLINE config #-}

-- | Starting 'State' for a fresh streaming test.
--
--   No samples have been seen yet, so both directional log-wealths
--   are @0@ (wealth @1@). Both bettors begin in the initial state of
--   the 'Bettor' selected by the 'Config'.
--
--   >>> let s0 = initial cfg
initial :: Config -> State
initial cfg = State 0 0 0 fresh fresh
  where
    -- both directions start from the same per-strategy state
    !fresh = init_bet (cfg_bettor cfg)
{-# INLINE initial #-}

-- streaming ------------------------------------------------------------------

-- | Fold one observation into the running 'State'.
--
--   First the observation is centred as @z = x - m@. Next, each
--   directional bettor proposes its bet from its state /before/
--   seeing @x@. Each direction's log-wealth then grows by the log of
--   its wealth factor:
--
--       @log_w' = log_w + log (1 + lambda * z)@
--
--   The negative direction uses @1 - lambda * z@. Finally, the bettor
--   states are advanced with @z@ (positive direction) and @-z@
--   (negative direction). Each wealth factor is floored at a tiny
--   positive value first, so the log stays finite even if a factor
--   reaches (or falls below) zero.
--
--   >>> let s1 = update cfg s0 0.7
update :: Config -> State -> Double -> State
update Config{..} State{..} !x =
    State (st_n + 1) lw_up lw_dn next_up next_dn
  where
    -- centred observation
    !dz      = x - cfg_null_mean

    -- predictable bets, taken from the pre-observation bettor states
    !bet_up  = bet_lambda cfg_bettor cfg_lam_max_pos st_bet_pos
    !bet_dn  = bet_lambda cfg_bettor cfg_lam_max_neg st_bet_neg

    -- per-direction log-wealth after this step (factor floored)
    !lw_up   = st_log_w_pos + log (max tiny (1 + bet_up * dz))
    !lw_dn   = st_log_w_neg + log (max tiny (1 - bet_dn * dz))

    -- bettor states for the next observation
    !next_up = step_bet cfg_bettor cfg_lam_max_pos st_bet_pos dz
    !next_dn = step_bet cfg_bettor cfg_lam_max_neg st_bet_neg (negate dz)
{-# INLINE update #-}

-- | Current 'Verdict' for the running 'State'.
--
--   The verdict is 'Reject' exactly when at least one directional
--   log-wealth has reached the Bonferroni-adjusted threshold
--   @log(2 \/ alpha)@, i.e. when either side's wealth is at least
--   @2 \/ alpha@. Otherwise it is 'Continue'.
--
--   By Ville's inequality, the chance of this ever happening under
--   @H_0@ is at most @alpha@. The bound holds uniformly over all
--   sample sizes, so the verdict may be checked after every sample
--   and the stream stopped at the first 'Reject'.
--
--   >>> decide cfg s0
--   Continue
decide :: Config -> State -> Verdict
decide cfg st
    | crossed (st_log_w_pos st) || crossed (st_log_w_neg st) = Reject
    | otherwise                                              = Continue
  where
    -- has this direction's log-wealth reached the threshold?
    crossed :: Double -> Bool
    crossed lw = lw >= cfg_log_thresh cfg
{-# INLINE decide #-}

-- inspection -----------------------------------------------------------------

-- | Current log-wealth of the test: the larger of the two directional
--   log-wealths.
--
--   This is the natural \"test statistic\". It summarises the evidence
--   against @H_0@ gathered so far, and 'decide' rejects once it
--   reaches @log(2 \/ alpha)@.
--
--   >>> log_wealth s0
--   0.0
log_wealth :: State -> Double
log_wealth st = st_log_w_pos st `max` st_log_w_neg st
{-# INLINE log_wealth #-}

-- | How many observations have been folded into the 'State'.
--
--   >>> samples s0
--   0
samples :: State -> Int
samples State{st_n = n} = n
{-# INLINE samples #-}