Skip to content

Commit

Permalink
better API, add __check_init__ eqx.Module checks
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Nov 24, 2023
1 parent 1a2c7e3 commit fe9f162
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 212 deletions.
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,21 @@ jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(dlx.model.Model):
class MyModel(dlx.Model):
def __call__(
self, processes: dict, parameters: dict[str, dlx.parameter.Parameter]
) -> dlx.model.Result:
res = dlx.model.Result()
self, processes: dict, parameters: dict[str, dlx.Parameter]
) -> dlx.Result:
res = dlx.Result()

# signal
mu_mod = dlx.parameter.modifier(
name="mu", parameter=parameters["mu"], effect=dlx.parameter.unconstrained()
mu_mod = dlx.modifier(
name="mu", parameter=parameters["mu"], effect=dlx.effect.unconstrained()
)
res.add(process="signal", expectation=mu_mod(processes["signal"]))

# background
bkg_mod = dlx.parameter.modifier(
name="sigma", parameter=parameters["sigma"], effect=dlx.parameter.gauss(0.2)
bkg_mod = dlx.modifier(
name="sigma", parameter=parameters["sigma"], effect=dlx.effect.gauss(0.2)
)
res.add(process="background", expectation=bkg_mod(processes["background"]))
return res
Expand All @@ -61,8 +61,8 @@ class MyModel(dlx.model.Model):
# setup model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
"mu": dlx.parameter.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
"sigma": dlx.parameter.Parameter(value=jnp.array([0.0])),
"mu": dlx.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
"sigma": dlx.Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

Expand Down
15 changes: 11 additions & 4 deletions src/dilax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@
# expose public API

__all__ = [
"effect",
"ipy_util",
"likelihood",
"model",
"optimizer",
"parameter",
# "pdf", # this should not be needed in public API
"util",
"__version__",
# explicitely expose some classes
"Model",
"Result",
"Parameter",
"modifier",
"staterror",
"compose",
]


Expand All @@ -32,11 +38,12 @@ def __dir__():


from dilax import ( # noqa: E402
effect,
ipy_util,
likelihood,
model,
optimizer,
parameter,
# pdf, # this should not be needed in public API
util,
)
from dilax.model import Model, Result # noqa: E402
from dilax.parameter import Parameter, compose, modifier, staterror # noqa: E402
179 changes: 179 additions & 0 deletions src/dilax/effect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING

import equinox as eqx
import jax
import jax.numpy as jnp

from dilax.pdf import Flat, Gauss, HashablePDF, Poisson
from dilax.util import as1darray

if TYPE_CHECKING:
from dilax.parameter import Parameter

__all__ = [
"Effect",
"unconstrained",
"gauss",
"lnN",
"poisson",
"shape",
]


def __dir__():
return __all__


class Effect(eqx.Module):
@property
@abc.abstractmethod
def constraint(self) -> HashablePDF:
...

@abc.abstractmethod
def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
...


class unconstrained(Effect):
@property
def constraint(self) -> HashablePDF:
return Flat()

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
return parameter.value


DEFAULT_EFFECT = unconstrained()


class gauss(Effect):
width: jax.Array = eqx.field(static=True, converter=as1darray)

def __init__(self, width: jax.Array) -> None:
self.width = width

@property
def constraint(self) -> HashablePDF:
return Gauss(mean=0.0, width=1.0)

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
"""
Implementation with (inverse) CDFs is defined as follows:
.. code-block:: python
gx = Gauss(mean=1.0, width=self.width) # type: ignore[arg-type]
g1 = Gauss(mean=1.0, width=1.0)
return gx.inv_cdf(g1.cdf(parameter.value + 1))
But we can use the fast analytical solution instead:
.. code-block:: python
return (parameter.value * self.width) + 1
"""
return (parameter.value * self.width) + 1


class shape(Effect):
up: jax.Array = eqx.field(converter=as1darray)
down: jax.Array = eqx.field(converter=as1darray)

def __init__(
self,
up: jax.Array,
down: jax.Array,
) -> None:
self.up = up # +1 sigma
self.down = down # -1 sigma

@eqx.filter_jit
def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array:
factor = sf
dx_sum = self.up + self.down - 2 * sumw
dx_diff = self.up - self.down

# taken from https://github.com/nsmith-/jaxfit/blob/8479cd73e733ba35462287753fab44c0c560037b/src/jaxfit/roofit/combine.py#L173C6-L192
_asym_poly = jnp.array([3.0, -10.0, 15.0, 0.0]) / 8.0

abs_value = jnp.abs(factor)
return 0.5 * (
dx_diff * factor
+ dx_sum
* jnp.where(
abs_value > 1.0,
abs_value,
jnp.polyval(_asym_poly, factor * factor),
)
)

@property
def constraint(self) -> HashablePDF:
return Gauss(mean=0.0, width=1.0)

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
sf = parameter.value
# clip, no negative values are allowed
return jnp.maximum((sumw + self.vshift(sf=sf, sumw=sumw)) / sumw, 0.0)


class lnN(Effect):
width: jax.Array | tuple[jax.Array, jax.Array] = eqx.field(static=True)

def __init__(
self,
width: jax.Array | tuple[jax.Array, jax.Array],
) -> None:
self.width = width

def scale(self, parameter: Parameter) -> jax.Array:
if isinstance(self.width, tuple):
down, up = self.width
scale = jnp.where(parameter.value > 0, up, down)
else:
scale = self.width
return scale

@property
def constraint(self) -> HashablePDF:
return Gauss(mean=0.0, width=1.0)

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
"""
Implementation with (inverse) CDFs is defined as follows:
.. code-block:: python
gx = Gauss(mean=jnp.exp(parameter.value), width=width) # type: ignore[arg-type]
g1 = Gauss(mean=1.0, width=1.0)
return gx.inv_cdf(g1.cdf(parameter.value + 1))
But we can use the fast analytical solution instead:
.. code-block:: python
return jnp.exp(parameter.value * self.scale(parameter=parameter))
"""
return jnp.exp(parameter.value * self.scale(parameter=parameter))


class poisson(Effect):
lamb: jax.Array = eqx.field(static=True, converter=as1darray)

def __init__(self, lamb: jax.Array) -> None:
self.lamb = lamb

@property
def constraint(self) -> HashablePDF:
return Poisson(lamb=self.lamb)

def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
return parameter.value + 1
25 changes: 12 additions & 13 deletions src/dilax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,40 +34,39 @@ def expectation(self) -> jax.Array:
class Model(eqx.Module):
"""
A model describing nuisance parameters, templates (histograms), and how they interact.
It is requires to implement the `evaluate` method, which returns an `EvaluationResult` object.
It is requires to implement the `evaluate` method, which returns an `Result` object.
Example:
.. code-block:: python
import equinox as eqx
import jax
import jax.numpy as jnp
import equinox as eqx
from dilax.model import Model, Result
from dilax.parameter import Parameter, lnN, modifier, unconstrained
import dilax as dlx
# Define a simple model with two processes and two parameters
class MyModel(Model):
def __call__(self, processes: dict, parameters: dict[str, Parameter]) -> Result:
res = Result()
class MyModel(dlx.Model):
def __call__(self, processes: dict, parameters: dict[str, dlx.Parameter]) -> dlx.Result:
res = dlx.Result()
# signal
mu_mod = modifier(name="mu", parameter=parameters["mu"], effect=unconstrained())
mu_mod = dlx.modifier(name="mu", parameter=parameters["mu"], effect=dlx.effect.unconstrained())
res.add(process="signal", expectation=mu_mod(processes["signal"]))
# background
bkg_mod = modifier(name="sigma", parameter=parameters["sigma"], effect=lnN(0.2))
bkg_mod = dlx.modifier(name="sigma", parameter=parameters["sigma"], effect=dlx.effect.lnN(0.2))
res.add(process="background", expectation=bkg_mod(processes["background"]))
return res
# Setup model
processes = {"signal": jnp.array([10]), "background": jnp.array([50])}
parameters = {
"mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
"sigma": Parameter(value=jnp.array([0.0])),
"mu": dlx.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
"sigma": dlx.Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)
Expand All @@ -77,7 +76,7 @@ def __call__(self, processes: dict, parameters: dict[str, Parameter]) -> Result:
# -> Array([60.], dtype=float32)
%timeit model.evaluate().expectation()
# -> 3.05 ms ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# -> 485 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# evaluate the expectation *fast*
@eqx.filter_jit
Expand All @@ -89,7 +88,7 @@ def eval(model) -> jax.Array:
# -> Array([60.], dtype=float32)
%timeit eqx.filter_jit(eval)(model).block_until_ready()
# -> 114 µs ± 327 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# -> 202 µs ± 4.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
"""

processes: dict
Expand Down
Loading

0 comments on commit fe9f162

Please sign in to comment.