From b0c2e709e94bad572a7099c0033b14949c56cd5e Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Fri, 24 Nov 2023 12:55:10 +0100 Subject: [PATCH] calculated mask only once --- src/dilax/modifier.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/dilax/modifier.py b/src/dilax/modifier.py index b74cb6e..de2c8eb 100644 --- a/src/dilax/modifier.py +++ b/src/dilax/modifier.py @@ -349,18 +349,15 @@ def prepare( """ import equinox as eqx - # create parameters parameters: dict[str, dict[str, Parameter]] = {} staterrors: dict[str, dict[str, eqx.Partial]] = {} for process, _sumw in self.sumw.items(): key = self.key_template.format(process=process) process_parameters = parameters[key] = {} + mask = _sumw < self.threshold for i in range(len(_sumw)): pkey = f"{process}_{i}" - if ( - self.mode == self.Mode.barlow_beeston_lite - and _sumw[i] > self.threshold - ): + if self.mode == self.Mode.barlow_beeston_lite and not mask[i]: # we merge all processes into one parameter # for the barlow-beeston-lite approach where # the bin content is above a certain treshold @@ -375,8 +372,6 @@ def prepare( if self.mode == self.Mode.poisson: kwargs["threshold"] = jnp.inf # inf -> always poisson elif self.mode == self.Mode.barlow_beeston_lite: - mask = _sumw < self.threshold - kwargs["sumw"] = jnp.where( mask, _sumw, sum(jax.tree_util.tree_leaves(self.sumw)) ) @@ -385,6 +380,5 @@ def prepare( self.sumw2[process], sum(jax.tree_util.tree_leaves(self.sumw2)), ) - staterrors[key] = eqx.Partial(staterror, **kwargs) return parameters, staterrors