diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 75492ff..fbef217 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -2,6 +2,7 @@ import abc import operator +from collections.abc import Callable from functools import reduce from typing import TYPE_CHECKING @@ -24,6 +25,7 @@ "compose", "where", "mask", + "transform", ] @@ -82,7 +84,7 @@ class clip(evm.ModifierBase): min_sf: float = eqx.field(static=True) max_sf: float = eqx.field(static=True) - def scale_factor(self, sumw: Array) -> evm.SF: + def scale_factor(self, sumw: Array) -> evm.custrom_types.SF: sf = self.modifier.scale_factor(sumw) return jtu.tree_map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), sf) @@ -240,6 +242,43 @@ def _mask(true: Array, false: Array) -> Array: ) +class transform(ModifierBase): + """ + Transform the scale factors of a modifier. + + The `transform_fn` is a function that is applied to both, multiplicative and additive scale factors. + + Example: + + .. code-block:: python + + import jax.numpy as jnp + import evermore as evm + + hist = jnp.array([5, 20, 30]) + syst = evm.Parameter(value=0.1) + + norm = syst.lnN(jnp.array([0.9, 1.1])) + + transformed_norm = evm.modifier.transform(jnp.sqrt, norm) + + # apply + transformed_norm(hist) + # -> Array([ 5.024686, 20.098743, 30.148115], dtype=float32) + + # for comparison: + norm(hist) + # -> Array([ 5.049494, 20.197975, 30.296963], dtype=float32) + """ + + transform_fn: Callable = eqx.field(static=True) + modifier: Modifier + + def scale_factor(self, sumw: Array) -> SF: + sf = self.modifier.scale_factor(sumw) + return jtu.tree_map(self.transform_fn, sf) + + class compose(ModifierBase): """ Composition of multiple modifiers, i.e.: `(f ∘ g ∘ h)(hist) = f(hist) * g(hist) * h(hist)`