From 005eef44bff7d6b1a0beafa65818cc5d16a5c72a Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 3 Apr 2024 16:09:55 +0200 Subject: [PATCH] fix staterrors, align with combine implementation --- src/evermore/effect.py | 6 +++ src/evermore/modifier.py | 12 ++--- src/evermore/staterror.py | 102 +++++++++++++++++++++++++++++++++----- 3 files changed, 101 insertions(+), 19 deletions(-) diff --git a/src/evermore/effect.py b/src/evermore/effect.py index dc97dca..7a85a8f 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -10,6 +10,7 @@ __all__ = [ "Effect", + "noop", "unconstrained", "normal", "log_normal", @@ -27,6 +28,11 @@ class Effect(eqx.Module): def scale_factor(self, parameter: Parameter, hist: Array) -> SF: ... +class noop(Effect): + def scale_factor(self, parameter: Parameter, hist: Array) -> SF: + return SF(multiplicative=jnp.ones_like(hist), additive=jnp.zeros_like(hist)) + + class unconstrained(Effect): def scale_factor(self, parameter: Parameter, hist: Array) -> SF: sf = jnp.broadcast_to(parameter.value, hist.shape) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index cf3e7a7..0bf42e5 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -184,9 +184,9 @@ class where(ModifierBase): # -> Array([ 5.1593127, 20.281374 , 30.181376 ], dtype=float32) """ - condition: Array - modifier_true: Modifier - modifier_false: Modifier + condition: Array = eqx.field(static=True) + modifier_true: ModifierLike + modifier_false: ModifierLike def scale_factor(self, hist: Array) -> SF: true_sf = self.modifier_true.scale_factor(hist) @@ -225,8 +225,8 @@ class mask(ModifierBase): # -> Array([ 5.049494, 20. 
, 30.296963], dtype=float32) """ - where: Array - modifier: Modifier + where: Array = eqx.field(static=True) + modifier: ModifierLike def scale_factor(self, hist: Array) -> SF: sf = self.modifier.scale_factor(hist) @@ -270,7 +270,7 @@ class transform(ModifierBase): """ transform_fn: Callable = eqx.field(static=True) - modifier: Modifier + modifier: ModifierLike def scale_factor(self, hist: Array) -> SF: sf = self.modifier.scale_factor(hist) diff --git a/src/evermore/staterror.py b/src/evermore/staterror.py index 32b8c89..415f391 100644 --- a/src/evermore/staterror.py +++ b/src/evermore/staterror.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Callable +from typing import cast import equinox as eqx import jax.numpy as jnp @@ -8,6 +9,8 @@ from jaxtyping import Array, PyTree from evermore.custom_types import ModifierLike +from evermore.effect import noop +from evermore.modifier import Modifier from evermore.modifier import where as modifier_where from evermore.parameter import NormalConstrained, PoissonConstrained from evermore.util import sum_leaves @@ -44,10 +47,14 @@ class StatErrors(eqx.Module): mod(hists["qcd"]) """ - params_global: PyTree - params_per_process: PyTree + gaussians_global: PyTree + gaussians_per_process: PyTree + poissons_per_process: PyTree + hists: PyTree = eqx.field(static=True) + histsw2: PyTree = eqx.field(static=True) ntot: Array = eqx.field(static=True) etot: Array = eqx.field(static=True) + threshold: float = eqx.field(static=True) mask: Array = eqx.field(static=True) def __init__( @@ -56,17 +63,86 @@ def __init__( histsw2: PyTree, threshold: float = 10.0, ) -> None: - leaf = jtu.tree_leaves(hists)[0] - self.params_global = NormalConstrained(value=jnp.zeros_like(leaf)) - self.params_per_process = jtu.tree_map( - lambda hist: PoissonConstrained(lamb=hist), hists - ) - self.ntot = sum_leaves(hists) - self.etot = jnp.sqrt(sum_leaves(histsw2)) + assert ( + jtu.tree_structure(hists) == 
jtu.tree_structure(histsw2) # type: ignore[operator] + ), "The PyTree structure of hists and histsw2 must be the same!" + self.hists = hists + self.histsw2 = histsw2 + self.threshold = threshold + + self.ntot = sum_leaves(self.hists) + self.etot = jnp.sqrt(sum_leaves(self.histsw2)) ntot_eff = jnp.round(self.ntot**2 / self.etot**2, decimals=0) - self.mask = ntot_eff > threshold + self.mask = ntot_eff > self.threshold + + # setup params + self.gaussians_global = NormalConstrained(value=jnp.zeros_like(self.ntot)) + self.gaussians_per_process = jtu.tree_map( + lambda hist: NormalConstrained(value=jnp.zeros_like(hist)), self.hists + ) + self.poissons_per_process = jtu.tree_map( + lambda hist: PoissonConstrained( + lamb=cast(Array, jnp.where(hist > 0.0, hist, 1.0)), + value=jnp.zeros_like(hist), + ), + self.hists, + ) def get(self, where: Callable) -> ModifierLike: - poisson_mod = where(self.params_per_process).poisson() - normal_mod = self.params_global.normal(width=self.etot / self.ntot) - return modifier_where(self.mask, normal_mod, poisson_mod) + # see: https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/pull/929 + # and: https://cms-analysis.github.io/HiggsAnalysis-CombinedLimit/latest/part2/bin-wise-stats/#usage-instructions + + # poisson case + # if w > 0.0, then poisson, else noop (no effect) + # since w <= 0 leads to NaNs in derivatives, we need to mask them + w = where(self.hists) + poisson_params = where(self.poissons_per_process) + poisson_noop_mod = Modifier(parameter=poisson_params, effect=noop()) + poisson_mod = modifier_where( + w > 0.0, poisson_params.poisson(), poisson_noop_mod + ) + + # gaussian case per process + # if w == 0.0, guard for division by zero + # a width of 0 also leads to NaNs: such bins get a dummy width (1.0) + noop + w2 = where(self.histsw2) + relerr = jnp.where(w == 0.0, 0.0, jnp.sqrt(w2) / jnp.where(w == 0.0, 1.0, w)) + mask = relerr == 0.0 + relerr = jnp.where(mask, 1.0, relerr) + gauss_params = 
where(self.gaussians_per_process) + gauss_noop_mod = Modifier(parameter=gauss_params, effect=noop()) + gauss_mod = modifier_where( + mask, gauss_noop_mod, gauss_params.normal(width=relerr) + ) + + # gaussian case global + gauss_global_mod = self.gaussians_global.normal(width=self.etot / self.ntot) + + # combine all, logic as here: https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/main/src/CMSHistErrorPropagator.cc#L320-L434 + # + # legend: + # - n_tot_eff: effective number of events summed over all processes per bin + # - e_tot: error summed over all processes per bin + # - n_tot: number of events summed over all processes per bin + # - n_i_eff: effective number of events for process i per bin (n_i**2 / e_i**2) + # - e_i: error for process i per bin (sqrt(w2)) + # - n_i: number of events for process i per bin + # - threshold: threshold for applying gaussian + # + # pseudo-code: + # + # if n_tot_eff > threshold: + # then apply global gaussian(width=e_tot/n_tot) + # else: + # if n_i_eff > threshold or e_i > n_i or n_i <= 0.0: + # apply per process gaussian(width=e_i/n_i) + # else: + # apply per process poisson + per_process_mask = ( + ((w**2 / w2) > self.threshold) | (jnp.sqrt(w2) > w) | (w <= 0) + ) + return modifier_where( + self.mask, + gauss_global_mod, + modifier_where(per_process_mask, gauss_mod, poisson_mod), + )