Skip to content

Commit

Permalink
add custom_types and additive shits
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Dec 15, 2023
1 parent 2bb2255 commit 507c886
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 71 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ port.exclude_lines = [

[tool.mypy]
files = ["src", "tests"]
python_version = "3.8"
python_version = "3.10"
show_error_codes = true
warn_unreachable = true
disallow_untyped_defs = false
Expand Down Expand Up @@ -123,6 +123,7 @@ extend-ignore = [
"PLR", # Design related pylint codes
"E501", # Line too long
# "B006", # converts default args to 'None'
"I002", # isort: "from __future__ import annotations"
]

src = ["src"]
Expand Down
21 changes: 21 additions & 0 deletions src/dilax/custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any, Callable

import jax

ArrayLike = jax.typing.ArrayLike
AddOrMul = Callable[[ArrayLike, ArrayLike], jax.Array]


class Sentinel:
__slots__ = ("repr",)

def __init__(self, repr: str) -> None:
self.repr = repr

def __repr__(self) -> str:
return self.repr

__str__ = __repr__


_NoValue: Any = Sentinel("<NoValue>")
38 changes: 28 additions & 10 deletions src/dilax/effect.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING
import operator
from typing import TYPE_CHECKING, ClassVar

import equinox as eqx
import jax
import jax.numpy as jnp

from dilax.custom_types import AddOrMul, ArrayLike
from dilax.parameter import Parameter
from dilax.pdf import Flat, Gauss, HashablePDF, Poisson
from dilax.util import as1darray

ArrayLike = jax.typing.ArrayLike

if TYPE_CHECKING:
from dilax.parameter import Parameter
from typing import ClassVar as AbstractClassVar
else:
from equinox import AbstractClassVar


__all__ = [
"Effect",
Expand All @@ -30,6 +32,8 @@ def __dir__():


class Effect(eqx.Module):
apply_op: AbstractClassVar[AddOrMul]

@property
@abc.abstractmethod
def constraint(self) -> HashablePDF:
Expand All @@ -41,6 +45,8 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:


class unconstrained(Effect):
apply_op: ClassVar[AddOrMul] = operator.mul

@property
def constraint(self) -> HashablePDF:
return Flat()
Expand All @@ -55,6 +61,8 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
class gauss(Effect):
width: ArrayLike = eqx.field(static=True, converter=as1darray)

apply_op: ClassVar[AddOrMul] = operator.mul

def __init__(self, width: ArrayLike) -> None:
self.width = width

Expand Down Expand Up @@ -87,6 +95,8 @@ class shape(Effect):
up: jax.Array = eqx.field(converter=as1darray)
down: jax.Array = eqx.field(converter=as1darray)

apply_op: ClassVar[AddOrMul] = operator.add

def __init__(
self,
up: jax.Array,
Expand Down Expand Up @@ -120,19 +130,25 @@ def constraint(self) -> HashablePDF:

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
sf = parameter.value
shift = self.vshift(sf=sf, sumw=sumw)
# handle zeros, see: https://github.com/google/jax/issues/5039
x = jnp.where(sumw == 0.0, 1.0, sumw)
return jnp.where(sumw == 0.0, 1.0, (x + shift) / x)
return self.vshift(sf=sf, sumw=sumw)
# shift = self.vshift(sf=sf, sumw=sumw)
# # handle zeros, see: https://github.com/google/jax/issues/5039
# x = jnp.where(sumw == 0.0, 1.0, sumw)
# return jnp.where(sumw == 0.0, shift, (x + shift) / x)


class lnN(Effect):
width: tuple[ArrayLike, ArrayLike] = eqx.field(static=True)

apply_op: ClassVar[AddOrMul] = operator.mul

def __init__(
self,
width: tuple[ArrayLike, ArrayLike],
) -> None:
# given as (down, up)
assert isinstance(width, tuple)
assert len(width) == 2
self.width = width

def interpolate(self, parameter: Parameter) -> jax.Array:
Expand Down Expand Up @@ -179,6 +195,8 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
class poisson(Effect):
lamb: jax.Array = eqx.field(static=True, converter=as1darray)

apply_op: ClassVar[AddOrMul] = operator.mul

def __init__(self, lamb: jax.Array) -> None:
self.lamb = lamb

Expand Down
6 changes: 2 additions & 4 deletions src/dilax/ipy_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp

from dilax.custom_types import ArrayLike
from dilax.model import Model

__all__ = ["interactive"]
Expand All @@ -18,7 +16,7 @@ def interactive(model: Model) -> None:
import ipywidgets as widgets
import matplotlib.pyplot as plt

def slider(v: float | jax.Array) -> widgets.FloatSlider:
def slider(v: ArrayLike) -> widgets.FloatSlider:
return widgets.FloatSlider(min=v - 2, max=v + 2, step=0.01, value=v)

fig, ax = plt.subplots()
Expand Down
4 changes: 1 addition & 3 deletions src/dilax/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING, cast

import equinox as eqx
import jax
import jax.numpy as jnp

from dilax.custom_types import Sentinel, _NoValue
from dilax.model import Model
from dilax.util import Sentinel, _NoValue

__all__ = [
"NLL",
Expand Down
3 changes: 2 additions & 1 deletion src/dilax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import jax.numpy as jnp
import jax.tree_util as jtu

from dilax.custom_types import Sentinel, _NoValue
from dilax.parameter import Parameter
from dilax.util import Sentinel, _NoValue, deep_update
from dilax.util import deep_update

__all__ = [
"Result",
Expand Down
80 changes: 49 additions & 31 deletions src/dilax/modifier.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import abc
import operator
from functools import reduce
from typing import TYPE_CHECKING

import equinox as eqx
import jax
import jax.numpy as jnp

from dilax.custom_types import AddOrMul
from dilax.effect import (
DEFAULT_EFFECT,
gauss,
Expand All @@ -31,11 +34,8 @@ def __dir__():

class ModifierBase(eqx.Module):
@abc.abstractmethod
def scale_factor(self, sumw: jax.Array) -> jax.Array:
...

def __call__(self, sumw: jax.Array) -> jax.Array:
return jnp.atleast_1d(self.scale_factor(sumw=sumw)) * sumw
...


class modifier(ModifierBase):
Expand Down Expand Up @@ -90,6 +90,12 @@ def __init__(
def scale_factor(self, sumw: jax.Array) -> jax.Array:
return self.effect.scale_factor(parameter=self.parameter, sumw=sumw)

def __call__(self, sumw: jax.Array) -> jax.Array:
op = self.effect.apply_op
shift = jnp.atleast_1d(self.scale_factor(sumw=sumw))
shift = jnp.broadcast_to(shift, sumw.shape)
return op(shift, sumw) # type: ignore[call-arg]


class compose(ModifierBase):
"""
Expand All @@ -109,7 +115,7 @@ class compose(ModifierBase):
# create a new parameter and a composition of modifiers
composition = dlx.compose(
dlx.modifier(name="mu", parameter=mu),
dlx.modifier(name="sigma1", parameter=sigma, effect=dlx.effect.lnN(0.1)),
dlx.modifier(name="sigma1", parameter=sigma, effect=dlx.effect.lnN((0.9, 1.1))),
)
# apply the composition
Expand All @@ -118,7 +124,7 @@ class compose(ModifierBase):
# nest compositions
composition = dlx.compose(
composition,
dlx.modifier(name="sigma2", parameter=sigma, effect=dlx.effect.lnN(0.2)),
dlx.modifier(name="sigma2", parameter=sigma, effect=dlx.effect.lnN((0.8, 1.2))),
)
# jit
Expand All @@ -127,46 +133,52 @@ class compose(ModifierBase):
eqx.filter_jit(composition)(jnp.array([10, 20, 30]))
"""

modifiers: tuple[modifier, ...]
names: list[str] = eqx.field(static=True)
modifiers: list[ModifierBase]

def __init__(self, *modifiers: modifier) -> None:
self.modifiers = modifiers

# set names
self.names = []
for m in range(len(self)):
modifier = self.modifiers[m]
if isinstance(modifier, compose):
self.names.extend(modifier.names)
self.modifiers = list(modifiers)
# unroll nested compositions
_modifiers = []
for mod in self.modifiers:
if isinstance(mod, compose):
_modifiers.extend(mod.modifiers)
else:
self.names.append(modifier.name)
assert isinstance(mod, modifier)
_modifiers.append(mod)
self.modifiers = _modifiers

def __check_init__(self):
# check for duplicate names
duplicates = [name for name in self.names if self.names.count(name) > 1]
names = [m.name for m in self.modifiers]
duplicates = {name for name in names if names.count(name) > 1}
if duplicates:
msg = f"Modifiers need to have unique names, got: {duplicates}"
raise ValueError(msg)

def __len__(self) -> int:
return len(self.modifiers)

def scale_factors(self, sumw: jax.Array) -> dict[str, jax.Array]:
sfs = {}
def __call__(self, sumw: jax.Array) -> jax.Array:
def _prep_shift(modifier: ModifierBase, sumw: jax.Array) -> jax.Array:
shift = modifier.scale_factor(sumw=sumw)
shift = jnp.atleast_1d(shift)
return jnp.broadcast_to(shift, sumw.shape)

# collect all multiplicative and additive shifts
shifts: dict[AddOrMul, list] = {operator.mul: [], operator.add: []}
for m in range(len(self)):
modifier = self.modifiers[m]
if isinstance(modifier, compose):
sfs.update(modifier.scale_factors(sumw=sumw))
else:
sf = jnp.atleast_1d(modifier.scale_factor(sumw=sumw))
sfs[modifier.name] = jnp.broadcast_to(sf, sumw.shape)
return sfs

def scale_factor(self, sumw: jax.Array) -> jax.Array:
sfs = jnp.stack(list(self.scale_factors(sumw=sumw).values()))
# calculate the product in log-space for numerical precision
return jnp.exp(jnp.sum(jnp.log(sfs), axis=0))
if modifier.effect.apply_op is operator.mul:
shifts[operator.mul].append(_prep_shift(modifier, sumw))
elif modifier.effect.apply_op is operator.add:
shifts[operator.add].append(_prep_shift(modifier, sumw))

# calculate the product with for operator.mul
_mult_fact = reduce(operator.mul, shifts[operator.mul], 1.0)
# calculate the sum for operator.add
_add_shift = reduce(operator.add, shifts[operator.add], 0.0)
# apply
return _mult_fact * (sumw + _add_shift)


class staterror(ModifierBase):
Expand Down Expand Up @@ -281,6 +293,12 @@ def _mod(
jax.vmap(_gauss_mod)(values, _widths, idxs),
)

def __call__(self, sumw: jax.Array) -> jax.Array:
# both gauss and poisson behave multiplicative
op = operator.mul
sf = self.scale_factor(sumw=sumw)
return op(jnp.atleast_1d(sf), sumw)


class autostaterrors(eqx.Module):
class Mode(eqx.Enumeration):
Expand Down
2 changes: 1 addition & 1 deletion src/dilax/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax
import jaxopt

from dilax.util import Sentinel, _NoValue
from dilax.custom_types import Sentinel, _NoValue

__all__ = [
"JaxOptimizer",
Expand Down
17 changes: 1 addition & 16 deletions src/dilax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax
import jax.numpy as jnp

ArrayLike = jax.typing.ArrayLike
from dilax.custom_types import ArrayLike, Sentinel, _NoValue

__all__ = [
"HistDB",
Expand All @@ -24,21 +24,6 @@ def __dir__():
return __all__


class Sentinel:
__slots__ = ("repr",)

def __init__(self, repr: str) -> None:
self.repr = repr

def __repr__(self) -> str:
return self.repr

__str__ = __repr__


_NoValue: Sentinel = Sentinel("<NoValue>")


class FrozenKeysView(collections.abc.KeysView):
"""FrozenKeysView that does not print values when repr'ing."""

Expand Down
Loading

0 comments on commit 507c886

Please sign in to comment.