From 7413bcf73eec929894659ee2d9254e0ce0ddaa7a Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 12:57:43 +0100 Subject: [PATCH] add ModifierBase for custom modifiers and use jax compatible pytree structure for scale factors --- src/evermore/__init__.py | 3 +- src/evermore/custom_types.py | 30 +++++----- src/evermore/effect.py | 46 +++++++--------- src/evermore/modifier.py | 104 ++++++++++++++++++++++++----------- src/evermore/util.py | 7 --- tests/test_parameter.py | 25 ++++----- 6 files changed, 117 insertions(+), 98 deletions(-) diff --git a/src/evermore/__init__.py b/src/evermore/__init__.py index 9b18ef0..cf9f0ca 100644 --- a/src/evermore/__init__.py +++ b/src/evermore/__init__.py @@ -27,6 +27,7 @@ # explicitely expose some classes "Parameter", "Modifier", + "ModifierBase", ] @@ -43,5 +44,5 @@ def __dir__(): sample, util, ) -from evermore.modifier import Modifier # noqa: E402 +from evermore.modifier import Modifier, ModifierBase # noqa: E402 from evermore.parameter import Parameter # noqa: E402 diff --git a/src/evermore/custom_types.py b/src/evermore/custom_types.py index 4c2c27c..49b44db 100644 --- a/src/evermore/custom_types.py +++ b/src/evermore/custom_types.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol from jaxtyping import Array @@ -9,8 +9,19 @@ from evermore.modifier import compose +__all__ = [ + "SF", + "AddOrMul", + "ModifierLike", +] + + AddOrMul = Callable[[Array, Array], Array] -AddOrMulSFs = dict[AddOrMul, Array] + + +class SF(NamedTuple): + multiplicative: Array + additive: Array class Sentinel: @@ -28,21 +39,8 @@ def __repr__(self) -> str: _NoValue: Any = Sentinel("") -@runtime_checkable class ModifierLike(Protocol): - def scale_factor(self, sumw: Array) -> AddOrMulSFs: - """ - Always return a dictionary of scale factors for the sumw array. - Dictionary has to look as follows: - - .. code-block:: python - - import operator - from jaxtyping import Array - - - {operator.mul: Array, operator.add: Array} - """ + def scale_factor(self, sumw: Array) -> SF: ... def __call__(self, sumw: Array) -> Array: diff --git a/src/evermore/effect.py b/src/evermore/effect.py index 8d348b4..36f8834 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -1,15 +1,14 @@ import abc -import operator from typing import TYPE_CHECKING import equinox as eqx import jax.numpy as jnp from jaxtyping import Array, Float -from evermore.custom_types import AddOrMulSFs +from evermore.custom_types import SF from evermore.parameter import Parameter from evermore.pdf import PDF, Flat, Gauss, Poisson -from evermore.util import as1darray, initSF +from evermore.util import as1darray if TYPE_CHECKING: pass @@ -37,7 +36,7 @@ def constraint(self, parameter: Parameter) -> PDF: ... @abc.abstractmethod - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: ... @@ -45,10 +44,9 @@ class unconstrained(Effect): def constraint(self, parameter: Parameter) -> PDF: return Flat() - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: - sf = initSF(shape=parameter.value.shape) - sf[operator.mul] = parameter.value - return sf + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + sf = jnp.broadcast_to(parameter.value, sumw.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) DEFAULT_EFFECT = unconstrained() @@ -65,7 +63,7 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: """ Implementation with (inverse) CDFs is defined as follows: @@ -83,9 +81,8 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: return (parameter.value * self.width) + 1 """ - sf = initSF(shape=parameter.value.shape) - sf[operator.mul] = (parameter.value * self.width) + 1 - return sf + sf = jnp.broadcast_to((parameter.value * self.width) + 1, sumw.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) class shape(Effect): @@ -124,11 +121,9 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: - p = parameter.value - sf = initSF(shape=p.shape) - sf[operator.add] = self.vshift(sf=p, sumw=sumw) - return sf + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + sf = self.vshift(sf=parameter.value, sumw=sumw) + return SF(multiplicative=jnp.ones_like(sumw), additive=sf) # shift = self.vshift(sf=sf, sumw=sumw) # # handle zeros, see: https://github.com/google/jax/issues/5039 # x = jnp.where(sumw == 0.0, 1.0, sumw) @@ -166,7 +161,7 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: """ Implementation with (inverse) CDFs is defined as follows: @@ -184,11 +179,9 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: return jnp.exp(parameter.value * self.interpolate(parameter=parameter)) """ - sf = initSF(shape=parameter.value.shape) - sf[operator.mul] = jnp.exp( - parameter.value * self.interpolate(parameter=parameter) - ) - return sf + interp = self.interpolate(parameter=parameter) + sf = jnp.broadcast_to(jnp.exp(parameter.value * interp), sumw.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) class poisson(Effect): @@ -201,7 +194,6 @@ def constraint(self, parameter: Parameter) -> PDF: assert parameter.value.shape == self.lamb.shape return Poisson(lamb=self.lamb) - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: - sf = initSF(shape=parameter.value.shape) - sf[operator.add] = parameter.value + 1 - return sf + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + sf = jnp.broadcast_to(parameter.value + 1, sumw.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 0e01476..f95b11e 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import operator from functools import reduce from typing import TYPE_CHECKING @@ -7,17 +8,18 @@ import equinox as eqx import jax import jax.numpy as jnp +import jax.tree_util as jtu from jaxtyping import Array -from evermore.custom_types import AddOrMul, AddOrMulSFs, ModifierLike +from evermore.custom_types import SF, AddOrMul, ModifierLike from evermore.effect import DEFAULT_EFFECT from evermore.parameter import Parameter -from evermore.util import initSF if TYPE_CHECKING: from evermore.effect import Effect __all__ = [ + "ModifierBase", "Modifier", "compose", "where", @@ -28,15 +30,67 @@ def __dir__(): return __all__ +class AbstractModifier(eqx.Module): + @abc.abstractmethod + def scale_factor(self: ModifierLike, sumw: Array) -> SF: + ... + + @abc.abstractmethod + def __call__(self: ModifierLike, sumw: Array) -> Array: + ... + + @abc.abstractmethod + def __matmul__(self: ModifierLike, other: ModifierLike) -> compose: + ... + + class ApplyFn(eqx.Module): @jax.named_scope("evm.modifier.ApplyFn") def __call__(self: ModifierLike, sumw: Array) -> Array: sf = self.scale_factor(sumw=sumw) # apply - return sf[operator.mul] * (sumw + sf[operator.add]) + return sf.multiplicative * (sumw + sf.additive) + + +class MatMulCompose(eqx.Module): + def __matmul__(self: ModifierLike, other: ModifierLike) -> compose: + return compose(self, other) + + +class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier): + """ + This serves as a base class for all modifiers. + It automatically implements the __call__ method to apply the scale factors to the sumw array + and the __matmul__ method to compose two modifiers. + + Custom modifiers should inherit from this class and implement the scale_factor method. + + Example: + + .. code-block:: python + + import jax.numpy as jnp + import jax.tree_util as jtu + import evermore as evm + class clip(evm.ModifierBase): + modifier: evm.ModifierBase + min_sf: float + max_sf: float -class Modifier(ApplyFn): + def scale_factor(self, sumw: jnp.ndarray) -> evm.SF: + sf = self.modifier.scale_factor(sumw) + return jtu.tree_map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), sf) + + + parameter = evm.Parameter(value=1.1) + modifier = parameter.unconstrained() + + clipped_modifier = clip(modifier=modifier, min_sf=0.8, max_sf=1.2) + """ + + +class Modifier(ModifierBase): """ Create a new modifier for a given parameter and penalty. @@ -89,14 +143,11 @@ def __init__(self, parameter: Parameter, effect: Effect = DEFAULT_EFFECT) -> Non constraint = self.effect.constraint(parameter=self.parameter) self.parameter._set_constraint(constraint, overwrite=False) - def scale_factor(self, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, sumw: Array) -> SF: return self.effect.scale_factor(parameter=self.parameter, sumw=sumw) - def __matmul__(self, other: ModifierLike) -> compose: - return compose(self, other) - -class where(ApplyFn): +class where(ModifierBase): """ Combine two modifiers based on a condition. @@ -125,21 +176,17 @@ class where(ApplyFn): modifier_true: Modifier modifier_false: Modifier - def scale_factor(self, sumw: Array) -> AddOrMulSFs: - sf = initSF(shape=sumw.shape) - + def scale_factor(self, sumw: Array) -> SF: true_sf = self.modifier_true.scale_factor(sumw) false_sf = self.modifier_false.scale_factor(sumw) - for op in operator.mul, operator.add: - sf.update(jnp.where(self.condition, true_sf[op], false_sf[op])) - return sf + def _where(true: Array, false: Array) -> Array: + return jnp.where(self.condition, true, false) - def __matmul__(self, other: ModifierLike) -> compose: - return compose(self, other) + return jtu.tree_map(_where, true_sf, false_sf) -class compose(ApplyFn): +class compose(ModifierBase): """ Composition of multiple modifiers, i.e.: `(f ∘ g ∘ h)(hist) = f(hist) * g(hist) * h(hist)` It behaves like a single modifier, but it is composed of multiple modifiers; it can be arbitrarly nested. @@ -196,7 +243,7 @@ def __init__(self, *modifiers: ModifierLike) -> None: if isinstance(mod, compose): _modifiers.extend(mod.modifiers) else: - assert isinstance(mod, ModifierLike) + assert isinstance(mod, ModifierBase) _modifiers.append(mod) # by now all modifiers are either modifier or staterror self.modifiers = _modifiers @@ -204,23 +251,16 @@ def __init__(self, *modifiers: ModifierLike) -> None: def __len__(self) -> int: return len(self.modifiers) - def scale_factor(self, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, sumw: Array) -> SF: # collect all multiplicative and additive shifts sfs: dict[AddOrMul, list] = {operator.add: [], operator.mul: []} for m in range(len(self)): mod = self.modifiers[m] _sf = mod.scale_factor(sumw) - for op in operator.add, operator.mul: - sfs[op].append(_sf[op]) + sfs[operator.mul].append(_sf.multiplicative) + sfs[operator.add].append(_sf.additive) - sf = initSF(shape=sumw.shape) # calculate the product with for operator.mul and operator.add - for op, init_val in ( - (operator.mul, jnp.ones_like(sumw)), - (operator.add, jnp.zeros_like(sumw)), - ): - sf[op] = reduce(op, sfs[op], init_val) - return sf - - def __matmul__(self, other: ModifierLike) -> compose: - return compose(self, other) + multiplicative_sf = reduce(operator.mul, sfs[operator.mul], jnp.ones_like(sumw)) + additive_sf = reduce(operator.add, sfs[operator.add], jnp.zeros_like(sumw)) + return SF(multiplicative=multiplicative_sf, additive=additive_sf) diff --git a/src/evermore/util.py b/src/evermore/util.py index b151d70..7736af4 100644 --- a/src/evermore/util.py +++ b/src/evermore/util.py @@ -11,10 +11,7 @@ import jax.tree_util as jtu from jaxtyping import Array, ArrayLike, PyTree -from evermore.custom_types import AddOrMulSFs - __all__ = [ - "initSF", "is_parameter", "sum_leaves", "as1darray", @@ -27,10 +24,6 @@ def __dir__(): return __all__ -def initSF(shape: tuple) -> AddOrMulSFs: - return {operator.add: jnp.zeros(shape), operator.mul: jnp.ones(shape)} - - def is_parameter(leaf: Any) -> bool: from evermore import Parameter diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 277ecc0..5d6ca59 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,12 +1,10 @@ from __future__ import annotations -import operator - import jax.numpy as jnp import pytest import evermore as evm -from evermore.custom_types import _NoValue +from evermore.custom_types import SF, _NoValue from evermore.pdf import Flat, Gauss, Poisson @@ -24,10 +22,9 @@ def test_unconstrained(): u = evm.effect.unconstrained() assert u.constraint(p) == Flat() - assert u.scale_factor(p, jnp.array(1.0)) == { - operator.mul: jnp.array([1.0]), - operator.add: jnp.array([0.0]), - } + assert u.scale_factor(p, jnp.array([1.0])) == SF( + multiplicative=jnp.array([1.0]), additive=jnp.array([0.0]) + ) def test_gauss(): @@ -35,10 +32,9 @@ def test_gauss(): g = evm.effect.gauss(width=jnp.array(1.0)) assert g.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) - assert g.scale_factor(p, jnp.array(1.0)) == { - operator.mul: jnp.array([1.0]), - operator.add: jnp.array([0.0]), - } + assert g.scale_factor(p, jnp.array([1.0])) == SF( + multiplicative=jnp.array([1.0]), additive=jnp.array([0.0]) + ) def test_lnN(): @@ -46,10 +42,9 @@ def test_lnN(): ln = evm.effect.lnN(width=jnp.array([0.9, 1.1])) assert ln.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) - assert ln.scale_factor(p, jnp.array(1.0)) == { - operator.mul: jnp.array([1.0]), - operator.add: jnp.array([0.0]), - } + assert ln.scale_factor(p, jnp.array([1.0])) == SF( + multiplicative=jnp.array([1.0]), additive=jnp.array([0.0]) + ) def test_poisson():