Skip to content

Commit

Permalink
add ModifierBase for custom modifiers and use jax compatible pytree s…
Browse files Browse the repository at this point in the history
…tructure for scale factors
  • Loading branch information
pfackeldey committed Mar 9, 2024
1 parent e5f12a4 commit 7413bcf
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 98 deletions.
3 changes: 2 additions & 1 deletion src/evermore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# explicitely expose some classes
"Parameter",
"Modifier",
"ModifierBase",
]


Expand All @@ -43,5 +44,5 @@ def __dir__():
sample,
util,
)
from evermore.modifier import Modifier # noqa: E402
from evermore.modifier import Modifier, ModifierBase # noqa: E402
from evermore.parameter import Parameter # noqa: E402
30 changes: 14 additions & 16 deletions src/evermore/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol

from jaxtyping import Array

if TYPE_CHECKING:
from evermore.modifier import compose


__all__ = [
"SF",
"AddOrMul",
"ModifierLike",
]


AddOrMul = Callable[[Array, Array], Array]
AddOrMulSFs = dict[AddOrMul, Array]


class SF(NamedTuple):
multiplicative: Array
additive: Array


class Sentinel:
Expand All @@ -28,21 +39,8 @@ def __repr__(self) -> str:
_NoValue: Any = Sentinel("<NoValue>")


@runtime_checkable
class ModifierLike(Protocol):
def scale_factor(self, sumw: Array) -> AddOrMulSFs:
"""
Always return a dictionary of scale factors for the sumw array.
Dictionary has to look as follows:
.. code-block:: python
import operator
from jaxtyping import Array
{operator.mul: Array, operator.add: Array}
"""
def scale_factor(self, sumw: Array) -> SF:
...

def __call__(self, sumw: Array) -> Array:
Expand Down
46 changes: 19 additions & 27 deletions src/evermore/effect.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import abc
import operator
from typing import TYPE_CHECKING

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float

from evermore.custom_types import AddOrMulSFs
from evermore.custom_types import SF
from evermore.parameter import Parameter
from evermore.pdf import PDF, Flat, Gauss, Poisson
from evermore.util import as1darray, initSF
from evermore.util import as1darray

if TYPE_CHECKING:
pass
Expand Down Expand Up @@ -37,18 +36,17 @@ def constraint(self, parameter: Parameter) -> PDF:
...

@abc.abstractmethod
def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs:
def scale_factor(self, parameter: Parameter, sumw: Array) -> SF:
...


class unconstrained(Effect):
def constraint(self, parameter: Parameter) -> PDF:
return Flat()

def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs:
sf = initSF(shape=parameter.value.shape)
sf[operator.mul] = parameter.value
return sf
def scale_factor(self, parameter: Parameter, sumw: Array) -> SF:
sf = jnp.broadcast_to(parameter.value, sumw.shape)
return SF(multiplicative=sf, additive=jnp.zeros_like(sumw))


DEFAULT_EFFECT = unconstrained()
Expand All @@ -65,7 +63,7 @@ def constraint(self, parameter: Parameter) -> PDF:
mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value)
)

def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs:
def scale_factor(self, parameter: Parameter, sumw: Array) -> SF:
"""
Implementation with (inverse) CDFs is defined as follows:
Expand All @@ -83,9 +81,8 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs:
return (parameter.value * self.width) + 1
"""
sf = initSF(shape=parameter.value.shape)
sf[operator.mul] = (parameter.value * self.width) + 1
return sf
sf = jnp.broadcast_to((parameter.value * self.width) + 1, sumw.shape)
return SF(multiplicative=sf, additive=jnp.zeros_like(sumw))


class shape(Effect):
Expand Down Expand Up @@ -124,11 +121,9 @@ def constraint(self, parameter: Parameter) -> PDF:
mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value)
)

def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs:
p = parameter.value
sf = initSF(shape=p.shape)
sf[operator.add] = self.vshift(sf=p, sumw=sumw)
return sf
def scale_factor(self, parameter: Parameter, sumw: Array) -> SF:
sf = self.vshift(sf=parameter.value, sumw=sumw)
return SF(multiplicative=jnp.ones_like(sumw), additive=sf)
# shift = self.vshift(sf=sf, sumw=sumw)
# # handle zeros, see: https://github.com/google/jax/issues/5039
# x = jnp.where(sumw == 0.0, 1.0, sumw)
Expand Down Expand Up @@ -166,7 +161,7 @@ def constraint(self, parameter: Parameter) -> PDF:
mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value)
)

def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs:
def scale_factor(self, parameter: Parameter, sumw: Array) -> SF:
"""
Implementation with (inverse) CDFs is defined as follows:
Expand All @@ -184,11 +179,9 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs:
return jnp.exp(parameter.value * self.interpolate(parameter=parameter))
"""
sf = initSF(shape=parameter.value.shape)
sf[operator.mul] = jnp.exp(
parameter.value * self.interpolate(parameter=parameter)
)
return sf
interp = self.interpolate(parameter=parameter)
sf = jnp.broadcast_to(jnp.exp(parameter.value * interp), sumw.shape)
return SF(multiplicative=sf, additive=jnp.zeros_like(sumw))


class poisson(Effect):
Expand All @@ -201,7 +194,6 @@ def constraint(self, parameter: Parameter) -> PDF:
assert parameter.value.shape == self.lamb.shape
return Poisson(lamb=self.lamb)

def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs:
sf = initSF(shape=parameter.value.shape)
sf[operator.add] = parameter.value + 1
return sf
def scale_factor(self, parameter: Parameter, sumw: Array) -> SF:
sf = jnp.broadcast_to(parameter.value + 1, sumw.shape)
return SF(multiplicative=sf, additive=jnp.zeros_like(sumw))
104 changes: 72 additions & 32 deletions src/evermore/modifier.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from __future__ import annotations

import abc
import operator
from functools import reduce
from typing import TYPE_CHECKING

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array

from evermore.custom_types import AddOrMul, AddOrMulSFs, ModifierLike
from evermore.custom_types import SF, AddOrMul, ModifierLike
from evermore.effect import DEFAULT_EFFECT
from evermore.parameter import Parameter
from evermore.util import initSF

if TYPE_CHECKING:
from evermore.effect import Effect

__all__ = [
"ModifierBase",
"Modifier",
"compose",
"where",
Expand All @@ -28,15 +30,67 @@ def __dir__():
return __all__


class AbstractModifier(eqx.Module):
@abc.abstractmethod
def scale_factor(self: ModifierLike, sumw: Array) -> SF:
...

@abc.abstractmethod
def __call__(self: ModifierLike, sumw: Array) -> Array:
...

@abc.abstractmethod
def __matmul__(self: ModifierLike, other: ModifierLike) -> compose:
...


class ApplyFn(eqx.Module):
@jax.named_scope("evm.modifier.ApplyFn")
def __call__(self: ModifierLike, sumw: Array) -> Array:
sf = self.scale_factor(sumw=sumw)
# apply
return sf[operator.mul] * (sumw + sf[operator.add])
return sf.multiplicative * (sumw + sf.additive)


class MatMulCompose(eqx.Module):
def __matmul__(self: ModifierLike, other: ModifierLike) -> compose:
return compose(self, other)


class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier):
"""
This serves as a base class for all modifiers.
It automatically implements the __call__ method to apply the scale factors to the sumw array
and the __matmul__ method to compose two modifiers.
Custom modifiers should inherit from this class and implement the scale_factor method.
Example:
.. code-block:: python
import jax.numpy as jnp
import jax.tree_util as jtu
import evermore as evm
class clip(evm.ModifierBase):
modifier: evm.ModifierBase
min_sf: float
max_sf: float
class Modifier(ApplyFn):
def scale_factor(self, sumw: jnp.ndarray) -> evm.SF:
sf = self.modifier.scale_factor(sumw)
return jtu.tree_map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), sf)
parameter = evm.Parameter(value=1.1)
modifier = parameter.unconstrained()
clipped_modifier = clip(modifier=modifier, min_sf=0.8, max_sf=1.2)
"""


class Modifier(ModifierBase):
"""
Create a new modifier for a given parameter and penalty.
Expand Down Expand Up @@ -89,14 +143,11 @@ def __init__(self, parameter: Parameter, effect: Effect = DEFAULT_EFFECT) -> Non
constraint = self.effect.constraint(parameter=self.parameter)
self.parameter._set_constraint(constraint, overwrite=False)

def scale_factor(self, sumw: Array) -> AddOrMulSFs:
def scale_factor(self, sumw: Array) -> SF:
return self.effect.scale_factor(parameter=self.parameter, sumw=sumw)

def __matmul__(self, other: ModifierLike) -> compose:
return compose(self, other)


class where(ApplyFn):
class where(ModifierBase):
"""
Combine two modifiers based on a condition.
Expand Down Expand Up @@ -125,21 +176,17 @@ class where(ApplyFn):
modifier_true: Modifier
modifier_false: Modifier

def scale_factor(self, sumw: Array) -> AddOrMulSFs:
sf = initSF(shape=sumw.shape)

def scale_factor(self, sumw: Array) -> SF:
true_sf = self.modifier_true.scale_factor(sumw)
false_sf = self.modifier_false.scale_factor(sumw)

for op in operator.mul, operator.add:
sf.update(jnp.where(self.condition, true_sf[op], false_sf[op]))
return sf
def _where(true: Array, false: Array) -> Array:
return jnp.where(self.condition, true, false)

def __matmul__(self, other: ModifierLike) -> compose:
return compose(self, other)
return jtu.tree_map(_where, true_sf, false_sf)


class compose(ApplyFn):
class compose(ModifierBase):
"""
Composition of multiple modifiers, i.e.: `(f ∘ g ∘ h)(hist) = f(hist) * g(hist) * h(hist)`
It behaves like a single modifier, but it is composed of multiple modifiers; it can be arbitrarly nested.
Expand Down Expand Up @@ -196,31 +243,24 @@ def __init__(self, *modifiers: ModifierLike) -> None:
if isinstance(mod, compose):
_modifiers.extend(mod.modifiers)
else:
assert isinstance(mod, ModifierLike)
assert isinstance(mod, ModifierBase)
_modifiers.append(mod)
# by now all modifiers are either modifier or staterror
self.modifiers = _modifiers

def __len__(self) -> int:
return len(self.modifiers)

def scale_factor(self, sumw: Array) -> AddOrMulSFs:
def scale_factor(self, sumw: Array) -> SF:
# collect all multiplicative and additive shifts
sfs: dict[AddOrMul, list] = {operator.add: [], operator.mul: []}
for m in range(len(self)):
mod = self.modifiers[m]
_sf = mod.scale_factor(sumw)
for op in operator.add, operator.mul:
sfs[op].append(_sf[op])
sfs[operator.mul].append(_sf.multiplicative)
sfs[operator.add].append(_sf.additive)

sf = initSF(shape=sumw.shape)
# calculate the product with for operator.mul and operator.add
for op, init_val in (
(operator.mul, jnp.ones_like(sumw)),
(operator.add, jnp.zeros_like(sumw)),
):
sf[op] = reduce(op, sfs[op], init_val)
return sf

def __matmul__(self, other: ModifierLike) -> compose:
return compose(self, other)
multiplicative_sf = reduce(operator.mul, sfs[operator.mul], jnp.ones_like(sumw))
additive_sf = reduce(operator.add, sfs[operator.add], jnp.zeros_like(sumw))
return SF(multiplicative=multiplicative_sf, additive=additive_sf)
7 changes: 0 additions & 7 deletions src/evermore/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, PyTree

from evermore.custom_types import AddOrMulSFs

__all__ = [
"initSF",
"is_parameter",
"sum_leaves",
"as1darray",
Expand All @@ -27,10 +24,6 @@ def __dir__():
return __all__


def initSF(shape: tuple) -> AddOrMulSFs:
return {operator.add: jnp.zeros(shape), operator.mul: jnp.ones(shape)}


def is_parameter(leaf: Any) -> bool:
from evermore import Parameter

Expand Down
Loading

0 comments on commit 7413bcf

Please sign in to comment.