Skip to content

Commit

Permalink
calculated mask only once
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Nov 24, 2023
1 parent 9eabc95 commit b0c2e70
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions src/dilax/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,18 +349,15 @@ def prepare(
"""
import equinox as eqx

# create parameters
parameters: dict[str, dict[str, Parameter]] = {}
staterrors: dict[str, dict[str, eqx.Partial]] = {}
for process, _sumw in self.sumw.items():
key = self.key_template.format(process=process)
process_parameters = parameters[key] = {}
mask = _sumw < self.threshold
for i in range(len(_sumw)):
pkey = f"{process}_{i}"
if (
self.mode == self.Mode.barlow_beeston_lite
and _sumw[i] > self.threshold
):
if self.mode == self.Mode.barlow_beeston_lite and not mask[i]:
# we merge all processes into one parameter
# for the barlow-beeston-lite approach where
# the bin content is above a certain treshold
Expand All @@ -375,8 +372,6 @@ def prepare(
if self.mode == self.Mode.poisson:
kwargs["threshold"] = jnp.inf # inf -> always poisson
elif self.mode == self.Mode.barlow_beeston_lite:
mask = _sumw < self.threshold

kwargs["sumw"] = jnp.where(
mask, _sumw, sum(jax.tree_util.tree_leaves(self.sumw))
)
Expand All @@ -385,6 +380,5 @@ def prepare(
self.sumw2[process],
sum(jax.tree_util.tree_leaves(self.sumw2)),
)

staterrors[key] = eqx.Partial(staterror, **kwargs)
return parameters, staterrors

0 comments on commit b0c2e70

Please sign in to comment.