Skip to content

Commit

Permalink
fix staterrors, align with combine implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Apr 3, 2024
1 parent 5212812 commit 005eef4
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 19 deletions.
6 changes: 6 additions & 0 deletions src/evermore/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

__all__ = [
"Effect",
"noop",
"unconstrained",
"normal",
"log_normal",
Expand All @@ -27,6 +28,11 @@ class Effect(eqx.Module):
def scale_factor(self, parameter: Parameter, hist: Array) -> SF: ...


class noop(Effect):
    """Effect whose scale factor is neutral: ones (multiplicative), zeros (additive)."""

    def scale_factor(self, parameter: Parameter, hist: Array) -> SF:
        """Return a neutral ``SF`` with the same shape and dtype as ``hist``.

        ``parameter`` is accepted to satisfy the ``Effect`` interface but is
        not read here.
        """
        unit = jnp.ones_like(hist)
        null = jnp.zeros_like(hist)
        return SF(multiplicative=unit, additive=null)


class unconstrained(Effect):
def scale_factor(self, parameter: Parameter, hist: Array) -> SF:
sf = jnp.broadcast_to(parameter.value, hist.shape)
Expand Down
12 changes: 6 additions & 6 deletions src/evermore/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ class where(ModifierBase):
# -> Array([ 5.1593127, 20.281374 , 30.181376 ], dtype=float32)
"""

condition: Array
modifier_true: Modifier
modifier_false: Modifier
condition: Array = eqx.field(static=True)
modifier_true: ModifierLike
modifier_false: ModifierLike

def scale_factor(self, hist: Array) -> SF:
true_sf = self.modifier_true.scale_factor(hist)
Expand Down Expand Up @@ -225,8 +225,8 @@ class mask(ModifierBase):
# -> Array([ 5.049494, 20. , 30.296963], dtype=float32)
"""

where: Array
modifier: Modifier
where: Array = eqx.field(static=True)
modifier: ModifierLike

def scale_factor(self, hist: Array) -> SF:
sf = self.modifier.scale_factor(hist)
Expand Down Expand Up @@ -270,7 +270,7 @@ class transform(ModifierBase):
"""

transform_fn: Callable = eqx.field(static=True)
modifier: Modifier
modifier: ModifierLike

def scale_factor(self, hist: Array) -> SF:
sf = self.modifier.scale_factor(hist)
Expand Down
102 changes: 89 additions & 13 deletions src/evermore/staterror.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

from collections.abc import Callable
from typing import cast

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, PyTree

from evermore.custom_types import ModifierLike
from evermore.effect import noop
from evermore.modifier import Modifier
from evermore.modifier import where as modifier_where
from evermore.parameter import NormalConstrained, PoissonConstrained
from evermore.util import sum_leaves
Expand Down Expand Up @@ -44,10 +47,14 @@ class StatErrors(eqx.Module):
mod(hists["qcd"])
"""

params_global: PyTree
params_per_process: PyTree
gaussians_global: PyTree
gaussians_per_process: PyTree
poissons_per_process: PyTree
hists: PyTree = eqx.field(static=True)
histsw2: PyTree = eqx.field(static=True)
ntot: Array = eqx.field(static=True)
etot: Array = eqx.field(static=True)
threshold: float = eqx.field(static=True)
mask: Array = eqx.field(static=True)

def __init__(
Expand All @@ -56,17 +63,86 @@ def __init__(
histsw2: PyTree,
threshold: float = 10.0,
) -> None:
leaf = jtu.tree_leaves(hists)[0]
self.params_global = NormalConstrained(value=jnp.zeros_like(leaf))
self.params_per_process = jtu.tree_map(
lambda hist: PoissonConstrained(lamb=hist), hists
)
self.ntot = sum_leaves(hists)
self.etot = jnp.sqrt(sum_leaves(histsw2))
assert (
jtu.tree_structure(hists) == jtu.tree_structure(histsw2) # type: ignore[operator]
), "The PyTree structure of hists and histsw2 must be the same!"
self.hists = hists
self.histsw2 = histsw2
self.threshold = threshold

self.ntot = sum_leaves(self.hists)
self.etot = jnp.sqrt(sum_leaves(self.histsw2))
ntot_eff = jnp.round(self.ntot**2 / self.etot**2, decimals=0)
self.mask = ntot_eff > threshold
self.mask = ntot_eff > self.threshold

# setup params
self.gaussians_global = NormalConstrained(value=jnp.zeros_like(self.ntot))
self.gaussians_per_process = jtu.tree_map(
lambda hist: NormalConstrained(value=jnp.zeros_like(hist)), self.hists
)
self.poissons_per_process = jtu.tree_map(
lambda hist: PoissonConstrained(
lamb=cast(Array, jnp.where(hist > 0.0, hist, 1.0)),
value=jnp.zeros_like(hist),
),
self.hists,
)

def get(self, where: Callable) -> ModifierLike:
    """Build the statistical-uncertainty modifier for one process.

    Args:
        where: Callable applied to ``self.hists``, ``self.histsw2`` and the
            per-process parameter PyTrees to select the entries of a
            single process (e.g. ``lambda p: p["qcd"]``).

    Returns:
        A nested ``where`` modifier that picks, per bin, the global
        gaussian, the per-process gaussian or the per-process poisson
        modifier (decision logic described in the comments below).
    """
    # see: https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/pull/929
    # and: https://cms-analysis.github.io/HiggsAnalysis-CombinedLimit/latest/part2/bin-wise-stats/#usage-instructions

    w = where(self.hists)
    w2 = where(self.histsw2)

    # poisson case
    # if w > 0.0, then poisson, else noop (no effect)
    # since w <= 0 leads to NaNs in derivatives, we need to mask them
    poisson_params = where(self.poissons_per_process)
    poisson_noop_mod = Modifier(parameter=poisson_params, effect=noop())
    poisson_mod = modifier_where(
        w > 0.0, poisson_params.poisson(), poisson_noop_mod
    )

    # gaussian case per process
    # if w == 0.0, guard for division by zero
    relerr = jnp.where(w == 0.0, 0.0, jnp.sqrt(w2) / jnp.where(w == 0.0, 1.0, w))
    # gaussians with width 0 also lead to NaNs, so we need to guard this as
    # well: bins with zero relative error get a dummy width of 1.0 (they are
    # routed to the noop modifier below), all other bins keep their width.
    mask = relerr == 0.0
    relerr = jnp.where(mask, 1.0, relerr)
    gauss_params = where(self.gaussians_per_process)
    gauss_noop_mod = Modifier(parameter=gauss_params, effect=noop())
    gauss_mod = modifier_where(
        mask, gauss_noop_mod, gauss_params.normal(width=relerr)
    )

    # gaussian case global
    gauss_global_mod = self.gaussians_global.normal(width=self.etot / self.ntot)

    # combine all, logic as here: https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/main/src/CMSHistErrorPropagator.cc#L320-L434
    #
    # legend:
    # - n_tot_eff: effective number of events summed over all processes per bin
    # - e_tot: error summed over all processes per bin
    # - n_tot: number of events summed over all processes per bin
    # - n_i_eff: effective number of events for process i per bin
    # - e_i: error for process i per bin
    # - n_i: number of events for process i per bin
    # - threshold: threshold for applying gaussian
    #
    # pseudo-code:
    #
    # if n_tot_eff > threshold:
    #     then apply global gaussian(width=e_tot/n_tot)
    # else:
    #     if n_i_eff > threshold or e_i > n_i or n_i <= 0.0:
    #         apply per process gaussian(width=e_i/n_i)
    #     else:
    #         apply per process poisson
    #
    # n_i_eff = n_i**2 / e_i**2, and e_i**2 is w2 (sum of squared weights),
    # mirroring how n_tot_eff is computed from ntot and etot in __init__.
    n_i_eff = w**2 / w2
    per_process_mask = (
        (n_i_eff > self.threshold) | (jnp.sqrt(w2) > w) | (w <= 0)
    )
    return modifier_where(
        self.mask,
        gauss_global_mod,
        modifier_where(per_process_mask, gauss_mod, poisson_mod),
    )

0 comments on commit 005eef4

Please sign in to comment.