From 5aff06660664c6429ad9015feae97312b1800373 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 6 Mar 2024 13:11:59 +0100 Subject: [PATCH] start of restructure: be more close to equinox' philosophy --- docs/conf.py | 2 + examples/dnn_weights_constraint.py | 36 ++ examples/grad_nll.py | 31 -- examples/model.py | 131 +++---- examples/nll_fit.py | 18 - examples/nll_profiling.py | 67 ---- examples/toy_generation.py | 52 +-- pyproject.toml | 3 +- src/evermore/__init__.py | 24 +- src/evermore/custom_types.py | 5 +- src/evermore/effect.py | 44 ++- src/evermore/ipy_util.py | 50 --- src/evermore/likelihood.py | 128 ------- src/evermore/loss.py | 74 ++++ src/evermore/model.py | 205 ----------- src/evermore/modifier.py | 545 ++++++++++++++--------------- src/evermore/optimizer.py | 99 ------ src/evermore/parameter.py | 58 ++- src/evermore/pdf.py | 80 ++++- src/evermore/sample.py | 36 ++ src/evermore/util.py | 268 ++------------ tests/test_optimizer.py | 31 -- tests/test_parameter.py | 27 +- tests/test_util.py | 41 +-- 24 files changed, 690 insertions(+), 1365 deletions(-) create mode 100644 examples/dnn_weights_constraint.py delete mode 100644 src/evermore/ipy_util.py delete mode 100644 src/evermore/likelihood.py create mode 100644 src/evermore/loss.py delete mode 100644 src/evermore/model.py delete mode 100644 src/evermore/optimizer.py create mode 100644 src/evermore/sample.py delete mode 100644 tests/test_optimizer.py diff --git a/docs/conf.py b/docs/conf.py index 7558ac7..0dcab2f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,6 +38,8 @@ "use_edit_page_button": True, } html_context = {"default_mode": "light"} +html_logo = "../assets/favicon.png" +html_favicon = "../assets/favicon.png" extensions = [ "sphinx.ext.autodoc", diff --git a/examples/dnn_weights_constraint.py b/examples/dnn_weights_constraint.py new file mode 100644 index 0000000..f18f145 --- /dev/null +++ b/examples/dnn_weights_constraint.py @@ -0,0 +1,36 @@ +import equinox as eqx +import jax +import jax.numpy as jnp + +import evermore as evm + + +class LinearConstrained(eqx.Module): + weights: evm.Parameter + biases: jax.Array + + def __init__(self, in_size, out_size, key): + self.biases = jax.random.normal(key, (out_size,)) + self.weights = evm.Parameter(value=jax.random.normal(key, (out_size, in_size))) + self.weights.constraints.add(evm.pdf.Gauss(mean=0.0, width=0.5)) + + def __call__(self, x: jax.Array): + return self.weights.value @ x + self.biases + + +@eqx.filter_jit +def loss_fn(model, x, y): + pred_y = jax.vmap(model)(x) + mse = jax.numpy.mean((y - pred_y) ** 2) + constraints = evm.loss.get_param_constraints(model) + # sum them all up for each weight + constraints = jax.tree_util.tree_map(jnp.sum, constraints) + return mse + evm.util.sum_leaves(constraints) + + +batch_size, in_size, out_size = 32, 2, 3 +model = LinearConstrained(in_size, out_size, key=jax.random.PRNGKey(0)) +x = jax.numpy.zeros((batch_size, in_size)) +y = jax.numpy.zeros((batch_size, out_size)) +loss_val = loss_fn(model, x, y) +grads = eqx.filter_grad(loss_fn)(model, x, y) diff --git a/examples/grad_nll.py b/examples/grad_nll.py index c5a3662..e69de29 100644 --- a/examples/grad_nll.py +++ b/examples/grad_nll.py @@ -1,31 +0,0 @@ -from __future__ import annotations - -import equinox as eqx -from jax import config -from model import init_values, model, observation, optimizer - -import evermore as evm - -config.update("jax_enable_x64", True) - -# create negative log likelihood -nll = evm.likelihood.NLL(model=model, observation=observation) - -# fit -params, state = optimizer.fit(fun=nll, init_values=init_values) - -# gradients of nll of fitted model -fast_grad_nll = eqx.filter_jit(eqx.filter_grad(nll)) -grads = fast_grad_nll(params) -# gradients of nll of fitted model only wrt to `mu` -# basically: pass the parameters dict of which you want the gradients -params_ = {k: v for k, v in params.items() if k == "mu"} -grad_mu = fast_grad_nll(params_) - -# hessian + cov_matrix of nll of fitted model -hessian = eqx.filter_jit(evm.likelihood.Hessian(model=model, observation=observation))() - -# covariance matrix of fitted model -covmatrix = eqx.filter_jit( - evm.likelihood.CovMatrix(model=model, observation=observation) -)() diff --git a/examples/model.py b/examples/model.py index f7756b7..be416a0 100644 --- a/examples/model.py +++ b/examples/model.py @@ -1,94 +1,79 @@ from __future__ import annotations +from typing import Any + +import equinox as eqx +import jax import jax.numpy as jnp import evermore as evm -class SPlusBModel(evm.Model): - def __call__(self, processes: dict, parameters: dict) -> evm.Result: - res = evm.Result() +class SPlusBModel(eqx.Module): + mu: evm.Parameter + norm1: evm.Parameter + norm2: evm.Parameter + shape1: evm.Parameter - mu_modifier = evm.modifier( - name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained() - ) - res.add( - process="signal", - expectation=mu_modifier(processes[("signal", "nominal")]), - ) + def __init__(self) -> None: + self.mu = evm.Parameter(value=jnp.array([1.0])) + self.norm1 = evm.Parameter() + self.norm2 = evm.Parameter() + self.shape1 = evm.Parameter() - bkg1_modifier = evm.compose( - evm.modifier( - name="lnN1", - parameter=parameters["norm1"], - effect=evm.effect.lnN((0.9, 1.1)), - ), - evm.modifier( - name="shape1_bkg1", - parameter=parameters["shape1"], - effect=evm.effect.shape( - up=processes[("background1", "shape_up")], - down=processes[("background1", "shape_down")], - ), - ), - ) - res.add( - process="background1", - expectation=bkg1_modifier(processes[("background1", "nominal")]), - ) + def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]: + expectations = {} - bkg2_modifier = evm.compose( - evm.modifier( - name="lnN2", - parameter=parameters["norm2"], - effect=evm.effect.lnN((0.95, 1.05)), - ), - evm.modifier( - name="shape1_bkg2", - parameter=parameters["shape1"], - effect=evm.effect.shape( - up=processes[("background2", "shape_up")], - down=processes[("background2", "shape_down")], - ), - ), + # signal process + sig_mod = self.mu.unconstrained() + expectations["signal"] = sig_mod(hists[("signal", "nominal")]) + + # bkg1 process + bkg1_mod = self.norm1.lnN(width=jnp.array([0.9, 1.1])) @ self.shape1.shape( + up=hists[("bkg1", "shape_up")], + down=hists[("bkg1", "shape_down")], ) - res.add( - process="background2", - expectation=bkg2_modifier(processes[("background2", "nominal")]), + expectations["bkg1"] = bkg1_mod(hists[("bkg1", "nominal")]) + + # bkg2 process + bkg2_mod = self.norm2.lnN(width=jnp.array([0.95, 1.05])) @ self.shape1.shape( + up=hists[("bkg2", "shape_up")], + down=hists[("bkg2", "shape_down")], ) - return res + expectations["bkg2"] = bkg2_mod(hists[("bkg2", "nominal")]) + # return the modified expectations + return expectations -def create_model(): - processes = { - ("signal", "nominal"): jnp.array([3]), - ("background1", "nominal"): jnp.array([10]), - ("background2", "nominal"): jnp.array([20]), - ("background1", "shape_up"): jnp.array([12]), - ("background1", "shape_down"): jnp.array([8]), - ("background2", "shape_up"): jnp.array([23]), - ("background2", "shape_down"): jnp.array([19]), - } - parameters = { - "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)), - "norm1": evm.Parameter(value=jnp.array([0.0])), - "norm2": evm.Parameter(value=jnp.array([0.0])), - "shape1": evm.Parameter(value=jnp.array([0.0])), - } - # return model - return SPlusBModel(processes=processes, parameters=parameters) +model = SPlusBModel() -model = create_model() +hists = { + ("signal", "nominal"): jnp.array([3]), + ("bkg1", "nominal"): jnp.array([10]), + ("bkg2", "nominal"): jnp.array([20]), + ("bkg1", "shape_up"): jnp.array([12]), + ("bkg1", "shape_down"): jnp.array([8]), + ("bkg2", "shape_up"): jnp.array([23]), + ("bkg2", "shape_down"): jnp.array([19]), +} -init_values = model.parameter_values observation = jnp.array([37]) -asimov = model.evaluate().expectation() + +nll = evm.loss.PoissonNLL() + + +@eqx.filter_jit +def loss(model, hists, observation): + expectations = model(hists) + constraints = evm.loss.get_param_constraints(model) + return nll( + expectation=evm.util.sum_leaves(expectations), + observation=observation, + constraint=evm.util.sum_leaves(constraints), + ) -# create optimizer (from `jaxopt`) -optimizer = evm.optimizer.JaxOptimizer.make( - name="LBFGS", - settings={"maxiter": 5, "jit": True, "unroll": True}, -) +loss_val = loss(model, hists, observation) +grads = eqx.filter_grad(loss)(model, hists, observation) diff --git a/examples/nll_fit.py b/examples/nll_fit.py index c983205..e69de29 100644 --- a/examples/nll_fit.py +++ b/examples/nll_fit.py @@ -1,18 +0,0 @@ -from __future__ import annotations - -from jax import config -from model import init_values, model, observation, optimizer - -from evermore.likelihood import NLL - -config.update("jax_enable_x64", True) - - -# create negative log likelihood -nll = NLL(model=model, observation=observation) - -# fit -values, state = optimizer.fit(fun=nll, init_values=init_values) - -# update model with fitted values -fitted_model = model.update(values=values) diff --git a/examples/nll_profiling.py b/examples/nll_profiling.py index 45bf296..e69de29 100644 --- a/examples/nll_profiling.py +++ b/examples/nll_profiling.py @@ -1,67 +0,0 @@ -from __future__ import annotations - -from functools import partial - -import equinox as eqx -import jax -import jax.numpy as jnp -from jax import config -from model import asimov, model, optimizer - -from evermore import Model -from evermore.likelihood import NLL -from evermore.optimizer import JaxOptimizer - -config.update("jax_enable_x64", True) - - -def nll_profiling( - value_name: str, - scan_points: jax.Array, - model: Model, - observation: jax.Array, - optimizer: JaxOptimizer, - fit: bool, -) -> jax.Array: - # define single fit for a fixed parameter of interest (poi) - @partial(jax.jit, static_argnames=("value_name", "optimizer", "fit")) - def fixed_poi_fit( - value_name: str, - scan_point: jax.Array, - model: Model, - observation: jax.Array, - optimizer: JaxOptimizer, - fit: bool, - ) -> jax.Array: - # fix theta into the model - model = model.update(values={value_name: scan_point}) - init_values = model.parameter_values - init_values.pop(value_name, 1) - # minimize - nll = eqx.filter_jit(NLL(model=model, observation=observation)) - if fit: - values, _ = optimizer.fit(fun=nll, init_values=init_values) - else: - values = model.parameter_values - return nll(values=values) - - # vectorise for multiple fixed values (scan points) - fixed_poi_fit_vec = jax.vmap( - fixed_poi_fit, in_axes=(None, 0, None, None, None, None) - ) - return fixed_poi_fit_vec( - value_name, scan_points, model, observation, optimizer, fit - ) - - -# profile the NLL around starting point of `0` -scan_points = jnp.r_[-1.9:2.0:0.1] - -profile_postfit = nll_profiling( - value_name="norm1", - scan_points=scan_points, - model=model, - observation=asimov, - optimizer=optimizer, - fit=True, -) diff --git a/examples/toy_generation.py b/examples/toy_generation.py index 06045ef..e7038d9 100644 --- a/examples/toy_generation.py +++ b/examples/toy_generation.py @@ -1,31 +1,43 @@ -from __future__ import annotations +from typing import Any import equinox as eqx import jax -from jax import config -from model import init_values, model, observation, optimizer +from jaxtyping import Array, PRNGKeyArray +from model import hists, model, observation -from evermore.likelihood import NLL, SampleToy +import evermore as evm -config.update("jax_enable_x64", True) +key = jax.random.PRNGKey(0) +# generate a new model with sampled parameters according to their constraint pdfs +toymodel = evm.sample.toy_module(model, key) -# create negative log likelihood -nll = NLL(model=model, observation=observation) -# fit -values, state = optimizer.fit(fun=nll, init_values=init_values) +# generate new expectation based on the toy model +def toy_expectation( + key: PRNGKeyArray, + module: eqx.Module, + hists: dict[Any, Array], +) -> Array: + toymodel = evm.sample.toy_module(model, key) + expectations = toymodel(hists) + return evm.util.sum_leaves(expectations) -# create sampling method -sample_toy = SampleToy(model=model, observation=observation) -# vectorise and jit -sample_toys = eqx.filter_vmap(in_axes=(None, 0))(eqx.filter_jit(sample_toy)) -sample_toy(values, jax.random.PRNGKey(1234)) +expectation = toy_expectation(key, model, hists) -# sample 10 toys based on fitted parameters -keys = jax.random.split(jax.random.PRNGKey(1234), num=10) -# postfit toys -toys_postfit = sample_toys(values, keys) -# prefit toys -toys_prefit = sample_toys(init_values, keys) + +# generate a new expectations vectorized over many keys +keys = jax.random.split(key, 1000) + +# vectorized toy expectation +toy_expectation_vec = jax.vmap(toy_expectation, in_axes=(0, None, None)) +expectations = toy_expectation_vec(keys, model, hists) + + +# just sample observations with poisson +poisson_obs = evm.pdf.Poisson(observation) +sampled_observation = poisson_obs.sample(key) + +# vectorized sampling +sampled_observations = jax.vmap(poisson_obs.sample)(keys) diff --git a/pyproject.toml b/pyproject.toml index e4ac42b..34223e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,12 +27,11 @@ classifiers = [ dynamic = ["version"] # version is set in src/evermore/__init__.py dependencies = [ "equinox>=0.10.6", # eqx.field - "jaxopt >=0.6", # jaxopt.LGBFGS ] [project.optional-dependencies] test = ["pytest >=6", "pytest-cov >=3"] -dev = ["pytest >=6", "pytest-cov >=3"] +dev = ["pytest >=6", "pytest-cov >=3", "optax", "jaxopt >=0.6"] docs = [ "sphinx>=7.0", "myst_parser>=0.13", diff --git a/src/evermore/__init__.py b/src/evermore/__init__.py index eb11bd6..f9a52d6 100644 --- a/src/evermore/__init__.py +++ b/src/evermore/__init__.py @@ -18,18 +18,15 @@ __all__ = [ "__version__", "effect", - "ipy_util", - "likelihood", - "optimizer", + "loss", "pdf", "util", + "sample", # explicitely expose some classes - "Model", - "Result", "Parameter", "modifier", - "staterror", - "autostaterrors", + # "staterror", + # "autostaterrors", "compose", ] @@ -40,17 +37,18 @@ def __dir__(): from evermore import ( # noqa: E402 effect, - ipy_util, - likelihood, - optimizer, + loss, pdf, + sample, util, ) -from evermore.model import Model, Result # noqa: E402 + +# from evermore.model import Model, Result from evermore.modifier import ( # noqa: E402 - autostaterrors, + # autostaterrors, compose, modifier, - staterror, ) + +# staterror, from evermore.parameter import Parameter # noqa: E402 diff --git a/src/evermore/custom_types.py b/src/evermore/custom_types.py index 28bd629..b901250 100644 --- a/src/evermore/custom_types.py +++ b/src/evermore/custom_types.py @@ -1,10 +1,9 @@ from collections.abc import Callable from typing import Any -import jax +import jaxtyping -ArrayLike = jax.typing.ArrayLike -AddOrMul = Callable[[ArrayLike, ArrayLike], jax.Array] +AddOrMul = Callable[[jaxtyping.ArrayLike, jaxtyping.ArrayLike], jaxtyping.Array] class Sentinel: diff --git a/src/evermore/effect.py b/src/evermore/effect.py index c300bd4..9fb4642 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -3,10 +3,10 @@ from typing import TYPE_CHECKING, ClassVar import equinox as eqx -import jax import jax.numpy as jnp +from jaxtyping import Array, Float -from evermore.custom_types import AddOrMul, ArrayLike +from evermore.custom_types import AddOrMul from evermore.parameter import Parameter from evermore.pdf import Flat, Gauss, HashablePDF, Poisson from evermore.util import as1darray @@ -40,7 +40,7 @@ def constraint(self) -> HashablePDF: ... @abc.abstractmethod - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: ... @@ -51,7 +51,7 @@ class unconstrained(Effect): def constraint(self) -> HashablePDF: return Flat() - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: return parameter.value @@ -59,18 +59,18 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class gauss(Effect): - width: ArrayLike = eqx.field(static=True, converter=as1darray) + width: Array = eqx.field(static=True, converter=as1darray) apply_op: ClassVar[AddOrMul] = operator.mul - def __init__(self, width: ArrayLike) -> None: + def __init__(self, width: Array) -> None: self.width = width @property def constraint(self) -> HashablePDF: return Gauss(mean=0.0, width=1.0) - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: """ Implementation with (inverse) CDFs is defined as follows: @@ -92,20 +92,20 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class shape(Effect): - up: jax.Array = eqx.field(converter=as1darray) - down: jax.Array = eqx.field(converter=as1darray) + up: Array = eqx.field(converter=as1darray) + down: Array = eqx.field(converter=as1darray) apply_op: ClassVar[AddOrMul] = operator.add def __init__( self, - up: jax.Array, - down: jax.Array, + up: Array, + down: Array, ) -> None: self.up = up # +1 sigma self.down = down # -1 sigma - def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array: + def vshift(self, sf: Array, sumw: Array) -> Array: factor = sf dx_sum = self.up + self.down - 2 * sumw dx_diff = self.up - self.down @@ -128,7 +128,7 @@ def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array: def constraint(self) -> HashablePDF: return Gauss(mean=0.0, width=1.0) - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: sf = parameter.value return self.vshift(sf=sf, sumw=sumw) # shift = self.vshift(sf=sf, sumw=sumw) @@ -138,20 +138,18 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class lnN(Effect): - width: tuple[ArrayLike, ArrayLike] = eqx.field(static=True) + width: Float[Array, "2"] = eqx.field(static=True) apply_op: ClassVar[AddOrMul] = operator.mul def __init__( self, - width: tuple[ArrayLike, ArrayLike], + width: Float[Array, "2"], # given as (down, up) ) -> None: - # given as (down, up) - assert isinstance(width, tuple) - assert len(width) == 2 + assert width.shape == (2,) self.width = width - def interpolate(self, parameter: Parameter) -> jax.Array: + def interpolate(self, parameter: Parameter) -> Array: # https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/be488af288361ef101859a398ae618131373cad7/src/ProcessNormalization.cc#L112-L129 x = parameter.value lo, hi = self.width @@ -171,7 +169,7 @@ def interpolate(self, parameter: Parameter) -> jax.Array: def constraint(self) -> HashablePDF: return Gauss(mean=0.0, width=1.0) - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: """ Implementation with (inverse) CDFs is defined as follows: @@ -193,16 +191,16 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class poisson(Effect): - lamb: jax.Array = eqx.field(static=True, converter=as1darray) + lamb: Array = eqx.field(static=True, converter=as1darray) apply_op: ClassVar[AddOrMul] = operator.mul - def __init__(self, lamb: jax.Array) -> None: + def __init__(self, lamb: Array) -> None: self.lamb = lamb @property def constraint(self) -> HashablePDF: return Poisson(lamb=self.lamb) - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: return parameter.value + 1 diff --git a/src/evermore/ipy_util.py b/src/evermore/ipy_util.py deleted file mode 100644 index 6268f8a..0000000 --- a/src/evermore/ipy_util.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Any - -import jax.numpy as jnp - -from evermore.custom_types import ArrayLike -from evermore.model import Model - -__all__ = ["interactive"] - - -def __dir__(): - return __all__ - - -def interactive(model: Model) -> None: - import ipywidgets as widgets - import matplotlib.pyplot as plt - - def slider(v: ArrayLike) -> widgets.FloatSlider: - return widgets.FloatSlider(min=v - 2, max=v + 2, step=0.01, value=v) - - fig, ax = plt.subplots() - - expectation = model.evaluate().expectation() - bins = jnp.arange(expectation.size) - - art = ax.bar(bins, expectation, color="gray") - - @widgets.interact( - **{name: slider(param.value) for name, param in model.parameters.items()} - ) - def update(**kwargs: Any) -> None: - m = model.update(values=kwargs) - res = m.evaluate() - - expectation = res.expectation() - print("Expectation:", expectation) - print("Constraint (logpdf):", m.parameter_constraints()) - - nonlocal art - art.remove() - - art = ax.bar(bins, expectation, color="gray") - - ax.set_xticks(bins) - ax.set_xticklabels(list(map(str, bins))) - ax.set_xlabel(r"Bin #") - ax.set_ylabel(r"S+B model") - plt.tight_layout() - plt.show() diff --git a/src/evermore/likelihood.py b/src/evermore/likelihood.py deleted file mode 100644 index f4c160e..0000000 --- a/src/evermore/likelihood.py +++ /dev/null @@ -1,128 +0,0 @@ -from typing import TYPE_CHECKING, cast - -import equinox as eqx -import jax -import jax.numpy as jnp - -from evermore.custom_types import Sentinel, _NoValue -from evermore.model import Model - -__all__ = [ - "NLL", - "Hessian", - "CovMatrix", - "SampleToy", -] - - -def __dir__(): - return __all__ - - -class BaseModule(eqx.Module): - """ - Base module to hold the `model` and the `observation`. - """ - - model: Model - observation: jax.Array = eqx.field(converter=jnp.asarray) - - def __init__(self, model: Model, observation: jax.Array) -> None: - self.model = model - self.observation = observation - - -class NLL(BaseModule): - """ - Negative log-likelihood (NLL). - """ - - def logpdf(self, *args, **kwargs) -> jax.Array: - return jax.scipy.stats.poisson.logpmf(*args, **kwargs) - - def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array: - if values is _NoValue: - values = self.model.parameter_values - model = self.model.update(values=values) - res = model.evaluate() - nll = jnp.sum( - self.logpdf(self.observation, res.expectation()) - - self.logpdf(self.observation, self.observation), - axis=-1, - ) - # add constraints - constraints = jax.tree_util.tree_leaves(model.parameter_constraints()) - nll += sum(constraints) - nll += model.nll_boundary_penalty() - return -jnp.sum(nll) - - -class Hessian(BaseModule): - """ - Hessian matrix. - """ - - NLL: NLL - - def __init__(self, model: Model, observation: jax.Array) -> None: - super().__init__(model=model, observation=observation) - self.NLL = NLL(model=model, observation=observation) - - def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array: - if values is _NoValue: - values = self.model.parameter_values - if TYPE_CHECKING: - values = cast(dict, values) - hessian = jax.hessian(self.NLL, argnums=0)(values) - hessian, _ = jax.tree_util.tree_flatten(hessian) - hessian = jnp.array(hessian) - new_shape = len(values) - return jnp.reshape(hessian, (new_shape, new_shape)) - - -class CovMatrix(Hessian): - """ - Covariance matrix. - """ - - def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array: - if values is _NoValue: - values = self.model.parameter_values - hessian = super().__call__(values=values) - return jnp.linalg.inv(hessian) - - -class SampleToy(BaseModule): - """ - Sample a toy from the model. - """ - - CovMatrix: CovMatrix - - def __init__(self, model: Model, observation: jax.Array) -> None: - super().__init__(model=model, observation=observation) - self.CovMatrix = CovMatrix(model=model, observation=observation) - - def __call__( - self, - values: dict | Sentinel = _NoValue, - key: jax.Array | Sentinel = _NoValue, - ) -> dict[str, jax.Array]: - if values is _NoValue: - values = self.model.parameter_values - if key is _NoValue: - key = jax.random.PRNGKey(1234) - if TYPE_CHECKING: - key = cast(jax.Array, key) - cov = self.CovMatrix(values=values) - _values, tree_def = jax.tree_util.tree_flatten( - self.model.update(values=values).parameter_values - ) - sampled_values = jax.random.multivariate_normal( - key=key, - mean=jnp.concatenate(_values), - cov=cov, - ) - new_values = jax.tree_util.tree_unflatten(tree_def, sampled_values) - model = self.model.update(values=new_values) - return model.evaluate().expectations diff --git a/src/evermore/loss.py b/src/evermore/loss.py new file mode 100644 index 0000000..2cc487c --- /dev/null +++ b/src/evermore/loss.py @@ -0,0 +1,74 @@ +from collections.abc import Callable + +import equinox as eqx +import jax +import jax.numpy as jnp +from jaxtyping import Array + +from evermore.parameter import Parameter +from evermore.util import _params_map + +__all__ = [ + "get_param_constraints", + "PoissonNLL", +] + + +def __dir__(): + return __all__ + + +def get_param_constraints(module: eqx.Module) -> dict: + constraints = {} + + def _constraint(param: Parameter) -> Array: + if param.constraints: + if len(param.constraints) > 1: + msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" + raise ValueError(msg) + return next(iter(param.constraints)).logpdf(param.value) + return jnp.array([0.0]) + + # constraints from pdfs + constraints["pdfs"] = _params_map(module, _constraint) + # constraints from boundaries + constraints["boundaries"] = _params_map(module, lambda p: p.boundary_penalty) + return constraints + + +class PoissonNLL(eqx.Module): + """ + Poisson negative log-likelihood (NLL). + + Usage: + + .. code-block:: python + + import evermore as evm + + nll = evm.loss.PoissonNLL() + + def loss(model, x, y): + expectation = model(x) + constraints = evm.loss.get_param_constraints(model) + loss = nll(expectation, y, evm.util.sum_leaves(constraints)) + return loss + """ + + @property + def logpdf(self) -> Callable: + return jax.scipy.stats.poisson.logpmf + + @jax.named_scope("evm.loss.PoissonNLL") + def __call__( + self, expectation: Array, observation: Array, constraint: Array + ) -> Array: + # poisson log-likelihood + nll = jnp.sum( + self.logpdf(observation, expectation) + - self.logpdf(observation, observation), + axis=-1, + ) + # add constraint + nll += constraint + return -jnp.sum(nll) diff --git a/src/evermore/model.py b/src/evermore/model.py deleted file mode 100644 index 24e31cf..0000000 --- a/src/evermore/model.py +++ /dev/null @@ -1,205 +0,0 @@ -from __future__ import annotations - -import abc -from typing import TYPE_CHECKING, Any, cast - -import equinox as eqx -import jax -import jax.numpy as jnp -import jax.tree_util as jtu - -from evermore.custom_types import Sentinel, _NoValue -from evermore.parameter import Parameter -from evermore.util import deep_update - -__all__ = [ - "Result", - "Model", -] - - -def __dir__(): - return __all__ - - -class Result(eqx.Module): - expectations: dict[str, jax.Array] - - def __init__(self) -> None: - self.expectations = {} - - def add(self, process: str, expectation: jax.Array) -> Result: - self.expectations[process] = expectation - return self - - def expectation(self) -> jax.Array: - return cast(jax.Array, sum(jtu.tree_leaves(self.expectations))) - - -def _is_parameter(leaf: Any) -> bool: - return isinstance(leaf, Parameter) - - -def _is_none_or_is_parameter(leaf: Any) -> bool: - return leaf is None or _is_parameter(leaf) - - -class Model(eqx.Module): - """ - A model describing nuisance parameters, templates (histograms), and how they interact. - It is requires to implement the `evaluate` method, which returns an `Result` object. - - Example: - - .. code-block:: python - - import equinox as eqx - import jax - import jax.numpy as jnp - - import evermore as evm - - - # Define a simple model with two processes and two parameters - class MyModel(evm.Model): - def __call__(self, processes: dict, parameters: dict) -> evm.Result: - res = evm.Result() - - # signal - mu_mod = evm.modifier(name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()) - res.add(process="signal", expectation=mu_mod(processes["signal"])) - - # background - bkg_mod = evm.modifier(name="sigma", parameter=parameters["sigma"], effect=evm.effect.lnN(0.2)) - res.add(process="background", expectation=bkg_mod(processes["background"])) - return res - - - # Setup model - processes = {"signal": jnp.array([10]), "background": jnp.array([50])} - parameters = { - "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)), - "sigma": evm.Parameter(value=jnp.array([0.0])), - } - - model = MyModel(processes=processes, parameters=parameters) - - # evaluate the expectation - model.evaluate().expectation() - # -> Array([60.], dtype=float32) - - %timeit model.evaluate().expectation() - # -> 485 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) - - # evaluate the expectation *fast* - @eqx.filter_jit - def eval(model) -> jax.Array: - res = model.evaluate() - return res.expectation() - - eqx.filter_jit(eval)(model) - # -> Array([60.], dtype=float32) - - %timeit eqx.filter_jit(eval)(model).block_until_ready() - # -> 202 µs ± 4.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each) - """ - - processes: dict - parameters: dict[str, Parameter] - auxiliary: Any - - def __init__( - self, - processes: dict, - parameters: dict, - auxiliary: Any | Sentinel = _NoValue, - ) -> None: - self.processes = processes - self.parameters = parameters - if auxiliary is _NoValue: - auxiliary = {} - self.auxiliary = auxiliary - - @property - def parameter_values(self) -> dict: - return jtu.tree_map( - lambda l: l.value, # noqa: E741 - self.parameters, - is_leaf=_is_parameter, - ) - - def parameter_constraints(self) -> dict: - def _constraint(param: Parameter) -> jax.Array: - if param.constraints: - if len(param.constraints) > 1: - msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" - raise ValueError(msg) - return next(iter(param.constraints)).logpdf(param.value) - return jnp.array([0.0]) - - return jtu.tree_map( - _constraint, - self.parameters, - is_leaf=_is_parameter, - ) - - def update( - self, - processes: dict | Sentinel = _NoValue, - values: dict | Sentinel = _NoValue, - ) -> Model: - if values is _NoValue: - values = {} - if processes is _NoValue: - processes = {} - - if TYPE_CHECKING: - values = cast(dict, values) - processes = cast(dict, processes) - - # patch original processes with new ones - new_processes = deep_update(self.processes, processes) - - # patch original parameters with new ones - _updates = deep_update( - jtu.tree_map(lambda _: None, self.parameters, is_leaf=_is_parameter), - values, - ) - - def _update_params(update: jax.Array | None, param: Parameter) -> Parameter: - if update is None: - return param - return param.update(value=update) - - new_parameters = jtu.tree_map( - _update_params, - _updates, - self.parameters, - is_leaf=_is_none_or_is_parameter, - ) - - return eqx.tree_at( - lambda t: (t.processes, t.parameters), self, (new_processes, new_parameters) - ) - - def nll_boundary_penalty(self) -> jax.Array: - return cast( - jax.Array, - sum( - jtu.tree_leaves( - jtu.tree_map( - lambda p: p.boundary_penalty, - self.parameters, - is_leaf=_is_parameter, - ) - ) - ), - ) - - @abc.abstractmethod - def __call__(self, processes: dict, parameters: dict) -> Result: - ... - - def evaluate(self) -> Result: - # evaluate the model with its current state - return self(self.processes, self.parameters) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index d837c0b..422effa 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -8,12 +8,11 @@ import equinox as eqx import jax import jax.numpy as jnp +from jaxtyping import Array from evermore.custom_types import AddOrMul from evermore.effect import ( DEFAULT_EFFECT, - gauss, - poisson, ) from evermore.parameter import Parameter @@ -23,8 +22,8 @@ __all__ = [ "modifier", "compose", - "staterror", - "autostaterrors", + # "staterror", + # "autostaterrors", ] @@ -34,7 +33,7 @@ def __dir__(): class ModifierBase(eqx.Module): @abc.abstractmethod - def __call__(self, sumw: jax.Array) -> jax.Array: + def __call__(self, sumw: Array) -> Array: ... @@ -75,27 +74,27 @@ class modifier(ModifierBase): modify(jnp.array([10, 20, 30])) """ - name: str parameter: Parameter effect: Effect - def __init__( - self, name: str, parameter: Parameter, effect: Effect = DEFAULT_EFFECT - ) -> None: - self.name = name + def __init__(self, parameter: Parameter, effect: Effect = DEFAULT_EFFECT) -> None: self.parameter = parameter self.effect = effect self.parameter.constraints.add(self.effect.constraint) - def scale_factor(self, sumw: jax.Array) -> jax.Array: + def scale_factor(self, sumw: Array) -> Array: return self.effect.scale_factor(parameter=self.parameter, sumw=sumw) - def __call__(self, sumw: jax.Array) -> jax.Array: + @jax.named_scope("evm.modifier") + def __call__(self, sumw: Array) -> Array: op = self.effect.apply_op shift = jnp.atleast_1d(self.scale_factor(sumw=sumw)) shift = jnp.broadcast_to(shift, sumw.shape) return op(shift, sumw) # type: ignore[call-arg] + def __matmul__(self, other: modifier) -> compose: + return compose(self, other) + class compose(ModifierBase): """ @@ -133,7 +132,7 @@ class compose(ModifierBase): eqx.filter_jit(composition)(jnp.array([10, 20, 30])) """ - modifiers: list[ModifierBase] + modifiers: list[modifier] def __init__(self, *modifiers: modifier) -> None: self.modifiers = list(modifiers) @@ -147,19 +146,12 @@ def __init__(self, *modifiers: modifier) -> None: _modifiers.append(mod) self.modifiers = _modifiers - def __check_init__(self): - # check for duplicate names - names = [m.name for m in self.modifiers] - duplicates = {name for name in names if names.count(name) > 1} - if duplicates: - msg = f"Modifiers need to have unique names, got: {duplicates}" - raise ValueError(msg) - def __len__(self) -> int: return len(self.modifiers) - def __call__(self, sumw: jax.Array) -> jax.Array: - def _prep_shift(modifier: ModifierBase, sumw: jax.Array) -> jax.Array: + @jax.named_scope("evm.compose") + def __call__(self, sumw: Array) -> Array: + def _prep_shift(modifier: modifier, sumw: Array) -> Array: shift = modifier.scale_factor(sumw=sumw) shift = jnp.atleast_1d(shift) return jnp.broadcast_to(shift, sumw.shape) @@ -181,257 +173,256 @@ def _prep_shift(modifier: ModifierBase, sumw: jax.Array) -> jax.Array: return _mult_fact * (sumw + _add_shift) -class staterror(ModifierBase): - """ - Create a staterror (barlow-beeston) modifier which acts on each bin with a different _underlying_ modifier. - - *Caution*: The instantiation of a `staterror` is not compatible with JAX-transformations (e.g. `jax.jit`)! - - Example: - - .. code-block:: python - - import jax.numpy as jnp - import evermore as evm - - hist = jnp.array([10, 20, 30]) - - p1 = evm.Parameter(value=1.0) - p2 = evm.Parameter(value=0.0) - p3 = evm.Parameter(value=0.0) - - # all bins with bin content below 10 (threshold) are treated as poisson, else gauss - modify = evm.staterror( - parameters={1: p1, 2: p2, 3: p3}, - sumw=hist, - sumw2=hist, - threshold=10.0, - ) - modify(hist) - # -> Array([13.162277, 20. , 30. ], dtype=float32) - - # jit - import equinox as eqx - - fast_modify = eqx.filter_jit(modify) - """ - - name: str = "staterror" - parameters: dict[str, Parameter] - sumw: jax.Array - sumw2: jax.Array - sumw2sqrt: jax.Array - widths: jax.Array - mask: jax.Array - threshold: float - - def __init__( - self, - parameters: dict[str, Parameter], - sumw: jax.Array, - sumw2: jax.Array, - threshold: float, - ) -> None: - self.parameters = parameters - self.sumw = sumw - self.sumw2 = sumw2 - self.sumw2sqrt = jnp.sqrt(sumw2) - self.threshold = threshold - - # calculate width - self.widths = self.sumw2sqrt / self.sumw - - # store if sumw is below threshold - self.mask = self.sumw < self.threshold - - for i, name in enumerate(self.parameters): - param = self.parameters[name] - effect = poisson(self.sumw[i]) if self.mask[i] else gauss(self.widths[i]) - param.constraints.add(effect.constraint) - - def __check_init__(self): - if not len(self.parameters) == len(self.sumw2) == len(self.sumw): - msg = ( - f"Length of parameters ({len(self.parameters)}), " - f"sumw2 ({len(self.sumw2)}) and sumw ({len(self.sumw)}) " - "must be the same." - ) - raise ValueError(msg) - if not self.threshold > 0.0: - msg = f"Threshold must be >= 0.0, got: {self.threshold}" - raise ValueError(msg) - - def scale_factor(self, sumw: jax.Array) -> jax.Array: - from functools import partial - - assert len(sumw) == len(self.parameters) == len(self.sumw2) - - values = jnp.concatenate([param.value for param in self.parameters.values()]) - idxs = jnp.arange(len(sumw)) - - # sumw where mask (poisson) else widths (gauss) - _widths = jnp.where(self.mask, self.sumw, self.widths) - - def _mod( - value: jax.Array, - width: jax.Array, - idx: jax.Array, - effect: Effect, - ) -> jax.Array: - return effect(width).scale_factor( - parameter=Parameter(value=value), - sumw=sumw[idx], - )[0] - - _poisson_mod = partial(_mod, effect=poisson) - _gauss_mod = partial(_mod, effect=gauss) - - # apply - return jnp.where( - self.mask, - jax.vmap(_poisson_mod)(values, _widths, idxs), - jax.vmap(_gauss_mod)(values, _widths, idxs), - ) - - def __call__(self, sumw: jax.Array) -> jax.Array: - # both gauss and poisson behave multiplicative - op = operator.mul - sf = self.scale_factor(sumw=sumw) - return op(jnp.atleast_1d(sf), sumw) - - -class autostaterrors(eqx.Module): - class Mode(eqx.Enumeration): - barlow_beeston_full = ( - "Barlow-Beeston (full) approach: Poisson per process and bin" - ) - poisson_gauss = "Poisson (Gauss) per process and bin if sumw < (>) threshold" - barlow_beeston_lite = "Barlow-Beeston (lite) approach" - - sumw: dict[str, jax.Array] - sumw2: dict[str, jax.Array] - masks: dict[str, jax.Array] - threshold: float - mode: str - key_template: str = eqx.field(static=True) - - def __init__( - self, - sumw: dict[str, jax.Array], - sumw2: dict[str, jax.Array], - threshold: float = 10.0, - mode: str = Mode.barlow_beeston_lite, - key_template: str = "__staterror_{process}__", - ) -> None: - self.sumw = sumw - self.sumw2 = sumw2 - self.masks = {p: _sumw < threshold for p, _sumw in sumw.items()} - self.threshold = threshold - self.mode = mode - self.key_template = key_template - - def __check_init__(self): - if jax.tree_util.tree_structure(self.sumw) != jax.tree_util.tree_structure( - self.sumw2 - ): # type: ignore[operator] - msg = ( - "The structure of `sumw` and `sumw2` needs to be identical, got " - f"`sumw`: {jax.tree_util.tree_structure(self.sumw)}) and " - f"`sumw2`: {jax.tree_util.tree_structure(self.sumw2)})" - ) - raise ValueError(msg) - if not self.threshold > 0.0: - msg = f"Threshold must be >= 0.0, got: {self.threshold}" - raise ValueError(msg) - if not isinstance(self.mode, self.Mode): - msg = f"Mode must be of type {self.Mode}, got: {self.mode}" - raise ValueError(msg) - - def prepare( - self, - ) -> tuple[dict[str, dict[str, Parameter]], dict[str, dict[str, eqx.Partial]]]: - """ - Helper to automatically create parameters used by `staterror` - for the initialisation of a `evm.Model`. - - *Caution*: This function is not compatible with JAX-transformations (e.g. `jax.jit`)! - - Example: - - .. code-block:: python - - import jax.numpy as jnp - import evermore as evm - - sumw = { - "signal": jnp.array([5, 20, 30]), - "background": jnp.array([5, 20, 30]), - } - - sumw2 = { - "signal": jnp.array([5, 20, 30]), - "background": jnp.array([5, 20, 30]), - } - - - auto = evm.autostaterrors( - sumw=sumw, - sumw2=sumw2, - threshold=10.0, - mode=evm.autostaterrors.Mode.barlow_beeston_full, - ) - parameters, staterrors = auto.prepare() - - # barlow-beeston-lite - auto2 = evm.autostaterrors( - sumw=sumw, - sumw2=sumw2, - threshold=10.0, - mode=evm.autostaterrors.Mode.barlow_beeston_lite, - ) - parameters2, staterrors2 = auto2.prepare() - - # materialize: - process = "signal" - pkey = auto.key_template.format(process=process) - modify = staterrors[pkey](parameters[pkey]) - modified_process = modify(sumw[process]) - """ - import equinox as eqx - - parameters: dict[str, dict[str, Parameter]] = {} - staterrors: dict[str, dict[str, eqx.Partial]] = {} - - for process, _sumw in self.sumw.items(): - key = self.key_template.format(process=process) - process_parameters = parameters[key] = {} - mask = self.masks[process] - for i in range(len(_sumw)): - pkey = f"{process}_{i}" - if self.mode == self.Mode.barlow_beeston_lite and not mask[i]: - # we merge all processes into one parameter - # for the barlow-beeston-lite approach where - # the bin content is above a certain threshold - pkey = f"{i}" - process_parameters[pkey] = Parameter(value=jnp.array(0.0)) - # prepare staterror - kwargs = { - "sumw": _sumw, - "sumw2": self.sumw2[process], - "threshold": self.threshold, - } - if self.mode == self.Mode.barlow_beeston_full: - kwargs["threshold"] = jnp.inf # inf -> always poisson - elif self.mode == self.Mode.barlow_beeston_lite: - kwargs["sumw"] = jnp.where( - mask, - _sumw, - sum(jax.tree_util.tree_leaves(self.sumw)), - ) - kwargs["sumw2"] = jnp.where( - mask, - self.sumw2[process], - sum(jax.tree_util.tree_leaves(self.sumw2)), - ) - staterrors[key] = eqx.Partial(staterror, **kwargs) - return parameters, staterrors +# class staterror(ModifierBase): +# """ +# Create a staterror (barlow-beeston) modifier which acts on each bin with a different _underlying_ modifier. + +# *Caution*: The instantiation of a `staterror` is not compatible with JAX-transformations (e.g. `jax.jit`)! + +# Example: + +# .. code-block:: python + +# import jax.numpy as jnp +# import evermore as evm + +# hist = jnp.array([10, 20, 30]) + +# p1 = evm.Parameter(value=1.0) +# p2 = evm.Parameter(value=0.0) +# p3 = evm.Parameter(value=0.0) + +# # all bins with bin content below 10 (threshold) are treated as poisson, else gauss +# modify = evm.staterror( +# parameters={1: p1, 2: p2, 3: p3}, +# sumw=hist, +# sumw2=hist, +# threshold=10.0, +# ) +# modify(hist) +# # -> Array([13.162277, 20. , 30. ], dtype=float32) + +# # jit +# import equinox as eqx + +# fast_modify = eqx.filter_jit(modify) +# """ + +# parameters: dict[str, Parameter] +# sumw: Array +# sumw2: Array +# sumw2sqrt: Array +# widths: Array +# mask: Array +# threshold: float + +# def __init__( +# self, +# parameters: dict[str, Parameter], +# sumw: Array, +# sumw2: Array, +# threshold: float, +# ) -> None: +# self.parameters = parameters +# self.sumw = sumw +# self.sumw2 = sumw2 +# self.sumw2sqrt = jnp.sqrt(sumw2) +# self.threshold = threshold + +# # calculate width +# self.widths = self.sumw2sqrt / self.sumw + +# # store if sumw is below threshold +# self.mask = self.sumw < self.threshold + +# for i, name in enumerate(self.parameters): +# param = self.parameters[name] +# effect = poisson(self.sumw[i]) if self.mask[i] else gauss(self.widths[i]) +# param.constraints.add(effect.constraint) + +# def __check_init__(self): +# if not len(self.parameters) == len(self.sumw2) == len(self.sumw): +# msg = ( +# f"Length of parameters ({len(self.parameters)}), " +# f"sumw2 ({len(self.sumw2)}) and sumw ({len(self.sumw)}) " +# "must be the same." +# ) +# raise ValueError(msg) +# if not self.threshold > 0.0: +# msg = f"Threshold must be >= 0.0, got: {self.threshold}" +# raise ValueError(msg) + +# def scale_factor(self, sumw: Array) -> Array: +# from functools import partial + +# assert len(sumw) == len(self.parameters) == len(self.sumw2) + +# values = jnp.concatenate([param.value for param in self.parameters.values()]) +# idxs = jnp.arange(len(sumw)) + +# # sumw where mask (poisson) else widths (gauss) +# _widths = jnp.where(self.mask, self.sumw, self.widths) + +# def _mod( +# value: Array, +# width: Array, +# idx: Array, +# effect: type[poisson] | type[gauss], +# ) -> Array: +# return effect(width).scale_factor( +# parameter=Parameter(value=value), +# sumw=sumw[idx], +# )[0] + +# _poisson_mod = partial(_mod, effect=poisson) +# _gauss_mod = partial(_mod, effect=gauss) + +# # apply +# return jnp.where( +# self.mask, +# jax.vmap(_poisson_mod)(values, _widths, idxs), +# jax.vmap(_gauss_mod)(values, _widths, idxs), +# ) + +# def __call__(self, sumw: Array) -> Array: +# # both gauss and poisson behave multiplicative +# op = operator.mul +# sf = self.scale_factor(sumw=sumw) +# return op(jnp.atleast_1d(sf), sumw) + + +# class autostaterrors(eqx.Module): +# class Mode(eqx.Enumeration): +# barlow_beeston_full = ( +# "Barlow-Beeston (full) approach: Poisson per process and bin" +# ) +# poisson_gauss = "Poisson (Gauss) per process and bin if sumw < (>) threshold" +# barlow_beeston_lite = "Barlow-Beeston (lite) approach" + +# sumw: dict[str, Array] +# sumw2: dict[str, Array] +# masks: dict[str, Array] +# threshold: float +# mode: str +# key_template: str = eqx.field(static=True) + +# def __init__( +# self, +# sumw: dict[str, Array], +# sumw2: dict[str, Array], +# threshold: float = 10.0, +# mode: str = Mode.barlow_beeston_lite, +# key_template: str = "__staterror_{process}__", +# ) -> None: +# self.sumw = sumw +# self.sumw2 = sumw2 +# self.masks = {p: _sumw < threshold for p, _sumw in sumw.items()} +# self.threshold = threshold +# self.mode = mode +# self.key_template = key_template + +# def __check_init__(self): +# if jax.tree_util.tree_structure(self.sumw) != jax.tree_util.tree_structure( +# self.sumw2 +# ): # type: ignore[operator] +# msg = ( +# "The structure of `sumw` and `sumw2` needs to be identical, got " +# f"`sumw`: {jax.tree_util.tree_structure(self.sumw)}) and " +# f"`sumw2`: {jax.tree_util.tree_structure(self.sumw2)})" +# ) +# raise ValueError(msg) +# if not self.threshold > 0.0: +# msg = f"Threshold must be >= 0.0, got: {self.threshold}" +# raise ValueError(msg) +# if not isinstance(self.mode, self.Mode): +# msg = f"Mode must be of type {self.Mode}, got: {self.mode}" +# raise ValueError(msg) + +# def prepare( +# self, +# ) -> tuple[dict[str, dict[str, Parameter]], dict[str, dict[str, eqx.Partial]]]: +# """ +# Helper to automatically create parameters used by `staterror` +# for the initialisation of a `evm.Model`. + +# *Caution*: This function is not compatible with JAX-transformations (e.g. `jax.jit`)! + +# Example: + +# .. code-block:: python + +# import jax.numpy as jnp +# import evermore as evm + +# sumw = { +# "signal": jnp.array([5, 20, 30]), +# "background": jnp.array([5, 20, 30]), +# } + +# sumw2 = { +# "signal": jnp.array([5, 20, 30]), +# "background": jnp.array([5, 20, 30]), +# } + + +# auto = evm.autostaterrors( +# sumw=sumw, +# sumw2=sumw2, +# threshold=10.0, +# mode=evm.autostaterrors.Mode.barlow_beeston_full, +# ) +# parameters, staterrors = auto.prepare() + +# # barlow-beeston-lite +# auto2 = evm.autostaterrors( +# sumw=sumw, +# sumw2=sumw2, +# threshold=10.0, +# mode=evm.autostaterrors.Mode.barlow_beeston_lite, +# ) +# parameters2, staterrors2 = auto2.prepare() + +# # materialize: +# process = "signal" +# pkey = auto.key_template.format(process=process) +# modify = staterrors[pkey](parameters[pkey]) +# modified_process = modify(sumw[process]) +# """ +# import equinox as eqx + +# parameters: dict[str, dict[str, Parameter]] = {} +# staterrors: dict[str, dict[str, eqx.Partial]] = {} + +# for process, _sumw in self.sumw.items(): +# key = self.key_template.format(process=process) +# process_parameters = parameters[key] = {} +# mask = self.masks[process] +# for i in range(len(_sumw)): +# pkey = f"{process}_{i}" +# if self.mode == self.Mode.barlow_beeston_lite and not mask[i]: +# # we merge all processes into one parameter +# # for the barlow-beeston-lite approach where +# # the bin content is above a certain threshold +# pkey = f"{i}" +# process_parameters[pkey] = Parameter(value=jnp.array(0.0)) +# # prepare staterror +# kwargs = { +# "sumw": _sumw, +# "sumw2": self.sumw2[process], +# "threshold": self.threshold, +# } +# if self.mode == self.Mode.barlow_beeston_full: +# kwargs["threshold"] = jnp.inf # inf -> always poisson +# elif self.mode == self.Mode.barlow_beeston_lite: +# kwargs["sumw"] = jnp.where( +# mask, +# _sumw, +# sum(jax.tree_util.tree_leaves(self.sumw)), +# ) +# kwargs["sumw2"] = jnp.where( +# mask, +# self.sumw2[process], +# sum(jax.tree_util.tree_leaves(self.sumw2)), +# ) +# staterrors[key] = eqx.Partial(staterror, **kwargs) +# return parameters, staterrors diff --git a/src/evermore/optimizer.py b/src/evermore/optimizer.py deleted file mode 100644 index 7fcfbd2..0000000 --- a/src/evermore/optimizer.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable, Hashable -from typing import TYPE_CHECKING, Any, cast - -import equinox as eqx -import jax -import jaxopt - -from evermore.custom_types import Sentinel, _NoValue - -__all__ = [ - "JaxOptimizer", - "Chain", -] - - -def __dir__(): - return __all__ - - -class JaxOptimizer(eqx.Module): - """ - Wrapper around `jaxopt` optimizers to make them hashable. - This allows to pass the optimizer as a parameter to a `jax.jit` function, and setup the optimizer therein. - - Example: - - .. code-block:: python - - optimizer = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5}) - # or, e.g.: optimizer = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10}) - - optimizer.fit(fun=nll, init_values=init_values) - """ - - name: str - _settings: tuple[tuple[str, Hashable], ...] - - def __init__(self, name: str, _settings: tuple[tuple[str, Hashable], ...]) -> None: - self.name = name - self._settings = _settings - - @classmethod - def make( - cls: type[JaxOptimizer], - name: str, - settings: dict[str, Hashable] | Sentinel = _NoValue, - ) -> JaxOptimizer: - if settings is _NoValue: - settings = {} - if TYPE_CHECKING: - settings = cast(dict[str, Hashable], settings) - return cls(name=name, _settings=tuple(settings.items())) - - @property - def settings(self) -> dict[str, Hashable]: - return dict(self._settings) - - def solver_instance(self, fun: Callable) -> jaxopt._src.base.Solver: - return getattr(jaxopt, self.name)(fun=fun, **self.settings) - - def fit( - self, fun: Callable, init_values: dict[str, jax.Array] - ) -> tuple[dict[str, jax.Array], Any]: - values, state = self.solver_instance(fun=fun).run(init_values) - return values, state - - -class Chain(eqx.Module): - """ - Chain multiple optimizers together. - They probably should have the `maxiter` setting set to a value, - in order to have a deterministic runtime behaviour. - - Example: - - .. code-block:: python - - opt1 = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5}) - opt2 = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10}) - - chain = Chain(opt1, opt2) - # first 5 steps are minimized with GradientDescent, then 10 steps with LBFGS - chain.fit(fun=nll, init_values=init_values) - """ - - optimizers: tuple[JaxOptimizer, ...] - - def __init__(self, *optimizers: JaxOptimizer) -> None: - self.optimizers = optimizers - - def fit( - self, fun: Callable, init_values: dict[str, jax.Array] - ) -> tuple[dict[str, jax.Array], Any]: - values = init_values - for optimizer in self.optimizers: - values, state = optimizer.fit(fun=fun, init_values=values) - return values, state diff --git a/src/evermore/parameter.py b/src/evermore/parameter.py index a18adb8..195f239 100644 --- a/src/evermore/parameter.py +++ b/src/evermore/parameter.py @@ -1,12 +1,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import equinox as eqx -import jax import jax.numpy as jnp +from jaxtyping import Array, ArrayLike, Float from evermore.pdf import HashablePDF from evermore.util import as1darray +if TYPE_CHECKING: + from evermore.modifier import modifier + __all__ = [ "Parameter", ] @@ -17,28 +22,57 @@ def __dir__(): class Parameter(eqx.Module): - value: jax.Array = eqx.field(converter=as1darray) - bounds: tuple[jax.Array, jax.Array] = eqx.field( - static=True, converter=lambda x: tuple(map(as1darray, x)) - ) + value: Array = eqx.field(converter=as1darray) + lower: Array = eqx.field(static=True, converter=as1darray) + upper: Array = eqx.field(static=True, converter=as1darray) constraints: set[HashablePDF] = eqx.field(static=True) def __init__( self, - value: jax.Array, - bounds: tuple[jax.Array, jax.Array] = (as1darray(-jnp.inf), as1darray(jnp.inf)), + value: ArrayLike = 0.0, + lower: ArrayLike = -jnp.inf, + upper: ArrayLike = jnp.inf, ) -> None: - self.value = value - self.bounds = bounds + self.value = as1darray(value) + self.lower = as1darray(lower) + self.upper = as1darray(upper) self.constraints: set[HashablePDF] = set() - def update(self, value: jax.Array) -> Parameter: + def update(self, value: Array | Parameter) -> Parameter: + if isinstance(value, Parameter): + value = value.value return eqx.tree_at(lambda t: t.value, self, value) @property - def boundary_penalty(self) -> jax.Array: + def boundary_penalty(self) -> Array: return jnp.where( - (self.value < self.bounds[0]) | (self.value > self.bounds[1]), + (self.value < self.lower) | (self.value > self.upper), jnp.inf, 0, ) + + # shorthands + def unconstrained(self) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.unconstrained()) + + def gauss(self, width: Array) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.gauss(width=width)) + + def lnN(self, width: Float[Array, 2]) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.lnN(width=width)) + + def poisson(self, lamb: Array) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.poisson(lamb=lamb)) + + def shape(self, up: Array, down: Array) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.shape(up=up, down=down)) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index 3cb804b..f2d4542 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -1,10 +1,15 @@ from __future__ import annotations from abc import abstractmethod +from typing import TYPE_CHECKING, Any import equinox as eqx import jax import jax.numpy as jnp +from jaxtyping import Array, PRNGKeyArray + +if TYPE_CHECKING: + from evermore import Parameter __all__ = [ "HashablePDF", @@ -24,19 +29,23 @@ def __hash__(self) -> int: ... @abstractmethod - def logpdf(self, x: jax.Array) -> jax.Array: + def logpdf(self, x: Array) -> Array: + ... + + @abstractmethod + def pdf(self, x: Array) -> Array: ... @abstractmethod - def pdf(self, x: jax.Array) -> jax.Array: + def cdf(self, x: Array) -> Array: ... @abstractmethod - def cdf(self, x: jax.Array) -> jax.Array: + def inv_cdf(self, x: Array) -> Array: ... @abstractmethod - def inv_cdf(self, x: jax.Array) -> jax.Array: + def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: ... @@ -44,19 +53,29 @@ class Flat(HashablePDF): def __hash__(self): return hash(self.__class__) - def logpdf(self, x: jax.Array) -> jax.Array: + def logpdf(self, x: Array) -> Array: return jnp.array([0.0]) - def pdf(self, x: jax.Array) -> jax.Array: + def pdf(self, x: Array) -> Array: return jnp.array([1.0]) - def cdf(self, x: jax.Array) -> jax.Array: + def cdf(self, x: Array) -> Array: return jnp.array([1.0]) - def inv_cdf(self, x: jax.Array) -> jax.Array: + def inv_cdf(self, x: Array) -> Array: msg = "Flat distribution has no inverse CDF." raise ValueError(msg) + def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: + return jax.random.uniform( + key, + parameter.value.shape, + # what should be the ranges? + # +/-jnp.inf leads to nans... + # minval=parameter.lower, + # maxval=parameter.upper, + ) + class Gauss(HashablePDF): mean: float = eqx.field(static=True) @@ -69,44 +88,59 @@ def __init__(self, mean: float, width: float) -> None: def __hash__(self): return hash(self.__class__) ^ hash((self.mean, self.width)) - def logpdf(self, x: jax.Array) -> jax.Array: + def logpdf(self, x: Array) -> Array: logpdf_max = jax.scipy.stats.norm.logpdf( self.mean, loc=self.mean, scale=self.width ) unnormalized = jax.scipy.stats.norm.logpdf(x, loc=self.mean, scale=self.width) return unnormalized - logpdf_max - def pdf(self, x: jax.Array) -> jax.Array: + def pdf(self, x: Array) -> Array: return jax.scipy.stats.norm.pdf(x, loc=self.mean, scale=self.width) - def cdf(self, x: jax.Array) -> jax.Array: + def cdf(self, x: Array) -> Array: return jax.scipy.stats.norm.cdf(x, loc=self.mean, scale=self.width) - def inv_cdf(self, x: jax.Array) -> jax.Array: + def inv_cdf(self, x: Array) -> Array: return jax.scipy.stats.norm.ppf(x, loc=self.mean, scale=self.width) + def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: + return self.mean + self.width * jax.random.normal( + key, + shape=parameter.value.shape, + dtype=parameter.value.dtype, + ) + class Poisson(HashablePDF): - lamb: jax.Array = eqx.field(static=True) + lamb: Array = eqx.field(static=True) - def __init__(self, lamb: jax.Array) -> None: + def __init__(self, lamb: Array) -> None: self.lamb = lamb def __hash__(self): - return hash(self.__class__) ^ hash(str(self.lamb)) # is this a safe hash?? + return hash(self.__class__) - def logpdf(self, x: jax.Array) -> jax.Array: + def __eq__(self, other: Any): # type: ignore[override] + if not isinstance(other, Poisson): + return ValueError(f"Cannot compare Poisson with {type(other)}") + # We need to implement __eq__ explicitely because we have a non-hashable field (lamb). + # Implementing __eq__ is necessary for the `==` operator to work and to ensure that the + # Poisson distribution is correctly added to a python set. + return jnp.all(self.lamb == other.lamb) + + def logpdf(self, x: Array) -> Array: logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb) unnormalized = jax.scipy.stats.poisson.logpmf((x + 1) * self.lamb, mu=self.lamb) return unnormalized - logpdf_max - def pdf(self, x: jax.Array) -> jax.Array: + def pdf(self, x: Array) -> Array: return jax.scipy.stats.poisson.pmf((x + 1) * self.lamb, mu=self.lamb) - def cdf(self, x: jax.Array) -> jax.Array: + def cdf(self, x: Array) -> Array: return jax.scipy.stats.poisson.cdf((x + 1) * self.lamb, mu=self.lamb) - def inv_cdf(self, x: jax.Array) -> jax.Array: + def inv_cdf(self, x: Array) -> Array: # see: https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html?highlight=poisson%20inverse#5.3-Example:-Left-truncated-Poisson def cond_fn(val): n, cdf = val @@ -121,3 +155,11 @@ def body_fn(val): cdf_start = self.cdf(start) n, _ = jax.lax.while_loop(cond_fn, body_fn, (start, cdf_start)) return n.astype(jnp.result_type(int)) + + def sample(self, key: PRNGKeyArray) -> Array: # type: ignore[override] + return jax.random.poisson( + key, + self.lamb, + shape=self.lamb.shape, + dtype=self.lamb.dtype, + ) diff --git a/src/evermore/sample.py b/src/evermore/sample.py new file mode 100644 index 0000000..13d0ec9 --- /dev/null +++ b/src/evermore/sample.py @@ -0,0 +1,36 @@ +from collections.abc import Callable + +import equinox as eqx +import jax +from jaxtyping import Array, PRNGKeyArray, PyTree + +from evermore.util import is_parameter + + +# get the PDFs from the parameters of the model +def toy_module(module: eqx.Module, key: PRNGKeyArray) -> PyTree[Callable]: + from evermore import Parameter + + params_tree = eqx.filter(module, is_parameter, is_leaf=is_parameter) + params_structure = jax.tree_util.tree_structure(params_tree) + n_params = params_structure.num_leaves # type: ignore[attr-defined] + + keys = jax.random.split(key, n_params) + keys_tree = jax.tree_util.tree_unflatten(params_structure, keys) + + def _sample(param: Parameter, key: Parameter) -> Array: + if not param.constraints: + msg = f"Parameter {param} has no constraint pdf, can't sample from it." + raise RuntimeError(msg) + if len(param.constraints) > 1: + msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" + raise ValueError(msg) + pdf = next(iter(param.constraints)) + + # sample new value from the constraint pdf + sampled_param_value = pdf.sample(key.value, param) + + # replace the sampled parameter value and return new parameter + return eqx.tree_at(lambda p: p.value, param, sampled_param_value) + + return jax.tree_util.tree_map(_sample, params_tree, keys_tree, is_leaf=is_parameter) diff --git a/src/evermore/util.py b/src/evermore/util.py index 8ce79b2..db70722 100644 --- a/src/evermore/util.py +++ b/src/evermore/util.py @@ -1,22 +1,24 @@ from __future__ import annotations -import collections -import pprint -from collections.abc import Callable, Hashable, Iterable, Mapping -from typing import TYPE_CHECKING, Any, TypeVar, cast - +import operator +from collections.abc import Callable +from functools import partial +from typing import ( + Any, +) + +import equinox as eqx import jax import jax.numpy as jnp - -from evermore.custom_types import ArrayLike, Sentinel, _NoValue +import jax.tree_util as jtu +from jaxtyping import Array, ArrayLike, PyTree __all__ = [ - "HistDB", - "FrozenDB", + "is_parameter", + "sum_leaves", "as1darray", "dump_hlo_graph", "dump_jaxpr", - "deep_update", ] @@ -24,226 +26,37 @@ def __dir__(): return __all__ -class FrozenKeysView(collections.abc.KeysView): - """FrozenKeysView that does not print values when repr'ing.""" - - def __init__(self, mapping): - super().__init__(mapping) - self._mapping = mapping - - def __repr__(self): - return f"{type(self).__name__}({list(map(_pretty_key, self._mapping.keys()))})" - - __str__ = __repr__ - - -def _pretty_key(key): - if not isinstance(key, frozenset): - key = FrozenDB.keyify(key) - if len(key) == 1: - return next(iter(key)) - return tuple([_pretty_key(k) for k in key]) - - -def _indent(amount: int, s: str) -> str: - """Indents `s` with `amount` spaces.""" - prefix = amount * " " - return "\n".join(prefix + line for line in s.splitlines()) - - -def _pretty_dict(x): - if not isinstance(x, Mapping): - return pprint.pformat(x) - rep = "" - for key, val in x.items(): - rep += f"{_pretty_key(key)!r}: {_pretty_dict(val)},\n" - if rep: - return "{\n" + _indent(2, rep) + "\n}" - return "{}" - - -K = TypeVar("K") -V = TypeVar("V") - - -def _prepare_freeze(xs: Any) -> Any: - """Deep copy unfrozen dicts to make the dictionary FrozenDict safe.""" - if isinstance(xs, FrozenDB): - # we can safely ref share the internal state of a FrozenDict - # because it is immutable. - return xs._dict - if not isinstance(xs, dict): - # return a leaf as is. - return xs - # recursively copy dictionary to avoid ref sharing - return {FrozenDB.keyify(key): _prepare_freeze(val) for key, val in xs.items()} - - -def _check_no_duplicate_keys(keys: Iterable[Hashable]) -> None: - keys = list(keys) - if any(keys.count(x) > 1 for x in keys): - msg = f"Duplicate keys: {tuple(keys)}, this is not allowed!" - raise ValueError(msg) - - -class FrozenDB(Mapping[K, V]): - """An immutable database-like custom dict. - - Example: - - .. code-block:: python - - hists = HistDB( - { - # QCD - ("QCD", "nominal"): jnp.array([1, 1, 1, 1, 1]), - ("QCD", "JES", "Up"): jnp.array([1.5, 1.5, 1.5, 1.5, 1.5]), - ("QCD", "JES", "Down"): jnp.array([0.5, 0.5, 0.5, 0.5, 0.5]), - # DY - ("DY", "nominal"): jnp.array([2, 2, 2, 2, 2]), - ("DY", "JES", "Up"): jnp.array([2.5, 2.5, 2.5, 2.5, 2.5]), - ("DY", "JES", "Down"): jnp.array([0.7, 0.7, 0.7, 0.7, 0.7]), - } - ) - - print(hists) - # -> HistDB({ - # ('QCD', 'nominal'): Array([1, 1, 1, 1, 1], dtype=int32), - # ('QCD', 'Up', 'JES'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32), - # ('QCD', 'Down', 'JES'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32), - # ('DY', 'nominal'): Array([2, 2, 2, 2, 2], dtype=int32), - # ('DY', 'Up', 'JES'): Array([2.5, 2.5, 2.5, 2.5, 2.5], dtype=float32), - # ('DY', 'Down', 'JES'): Array([0.7, 0.7, 0.7, 0.7, 0.7], dtype=float32), - # }) - - print(hists["QCD"]) - # -> HistDB({ - # 'nominal': Array([1, 1, 1, 1, 1], dtype=int32), - # ('Up', 'JES'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32), - # ('Down', 'JES'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32), - # }) - - print(hists["JES"]) - # -> HistDB({ - # ('QCD', 'Up'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32), - # ('QCD', 'Down'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32), - # ('DY', 'Up'): Array([2.5, 2.5, 2.5, 2.5, 2.5], dtype=float32), - # ('DY', 'Down'): Array([0.7, 0.7, 0.7, 0.7, 0.7], dtype=float32), - # }) - - # It's jit-compatible: - def foo(hists): - return (hists["QCD", "nominal"] + 1.2) ** 2 - - print(jax.jit(foo)(hists)) - # -> Array([4.84, 4.84, 4.84, 4.84, 4.84], dtype=float32, weak_type=True) - """ - - __slots__ = ("_dict",) - - if TYPE_CHECKING: - _dict: dict[frozenset, Any] - - @staticmethod - def keyify(keyish: Any) -> frozenset: - if not isinstance(keyish, tuple | list | set | frozenset): - keyish = (keyish,) - _check_no_duplicate_keys(keyish) - keyish = frozenset(keyish) - assert not any(isinstance(key, set) for key in keyish) - return keyish - - def __init__( - self, - xs: Mapping | Sentinel = _NoValue, - __unsafe_skip_copy__: bool = False, - ) -> None: - # make sure the dict is as - if xs is _NoValue: - xs = {} - data = dict(cast(Mapping, xs)) - if __unsafe_skip_copy__: - self._dict = data - else: - self._dict = _prepare_freeze(data) - - def __getitem__(self, key) -> Any: - key = self.keyify(key) - if key in self._dict: - return self._dict[key] - ret = self.__class__({k - key: v for k, v in self.items() if key <= k}) - if not ret: - raise KeyError(key) - return ret - - def __setitem__(self, key, value) -> None: - msg = f"{type(self).__name__} is immutable." - raise ValueError(msg) - - def __contains__(self, key) -> bool: - key = self.keyify(key) - return key in self._dict - - def __len__(self) -> int: - return len(self._dict) - - def __iter__(self): - return iter(self._dict) - - def keys(self) -> FrozenKeysView: - return FrozenKeysView(self._dict) +def is_parameter(leaf: Any) -> bool: + from evermore import Parameter - def values(self): - return self._dict.values() + return isinstance(leaf, Parameter) - def items(self): - for key in self._dict: - yield (key, self[key]) - def only(self, *keys) -> FrozenDB: - return self.__class__({key: self[key] for key in keys}) +K = str +V = Any - def subset(self, *keys) -> FrozenDB: - new = {} - for key in keys: - new.update({k: v for k, v in self.items() if self.keyify(key) <= k}) - return self.__class__(new) - def copy(self) -> FrozenDB: - return self.__class__(self) - - def __repr__(self) -> str: - return f"{type(self).__name__}({_pretty_dict(self._dict)})" - - def as_compact_dict(self): - return {"/".join(sorted(map(str, k))): v for k, v in self.items()} - - -def _flatten(tree): - return (tuple(tree.values()), tuple(tree.keys())) - - -def _make_unflatten(cls: type[FrozenDB]) -> Callable: - def _unflatten(keys, values): - return cls(dict(zip(keys, values, strict=True)), __unsafe_skip_copy__=True) - - return _unflatten +def _filtered_module_map( + module: eqx.Module, + fun: Callable, + filter: Callable, +) -> eqx.Module: + params = eqx.filter(module, filter, is_leaf=filter) + return jtu.tree_map( + fun, + params, + is_leaf=filter, + ) -class HistDB(FrozenDB): - ... +_params_map = partial(_filtered_module_map, filter=is_parameter) -# then we register them with jax as a PyTree -for cls in HistDB, FrozenDB: - jax.tree_util.register_pytree_node( - cls, - _flatten, - _make_unflatten(cls), - ) +def sum_leaves(tree: PyTree) -> Array: + return jtu.tree_reduce(operator.add, tree) -def as1darray(x: ArrayLike) -> jax.Array: +def as1darray(x: ArrayLike) -> Array: """ Converts `x` to a 1d array. @@ -316,20 +129,3 @@ def f(x: jax.Array) -> jax.Array: filepath.write_text(dump_hlo_graph(f, x), encoding='ascii') """ return jax.xla_computation(fun)(*args, **kwargs).as_hlo_dot_graph() - - -def deep_update( - mapping: dict[K, Any], - new_mapping: dict[K, Any], -) -> dict[K, Any]: - updated_mapping = mapping.copy() - for k, v in new_mapping.items(): - if ( - k in updated_mapping - and isinstance(updated_mapping[k], dict) - and isinstance(v, dict) - ): - updated_mapping[k] = deep_update(updated_mapping[k], v) - else: - updated_mapping[k] = v - return updated_mapping diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py deleted file mode 100644 index f328e8e..0000000 --- a/tests/test_optimizer.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from functools import partial - -import jax -import jaxopt -import pytest - -from evermore.optimizer import JaxOptimizer - - -def test_jaxoptimizer(): - opt = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5}) - - assert opt.name == "GradientDescent" - assert opt.settings == {"maxiter": 5} - - assert isinstance(opt.solver_instance(fun=lambda x: x), jaxopt.GradientDescent) - - # jit compatibility - @partial(jax.jit, static_argnums=0) - def f(optimizer): - @jax.jit - def fun(x): - return (x - 2.0) ** 2 - - init_values = 1.0 - values, _ = optimizer.fit(fun=fun, init_values=init_values) - return values - - assert f(opt) == pytest.approx(2.0) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 4b6e3f0..3147288 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -8,11 +8,12 @@ def test_parameter(): - p = evm.Parameter(value=jnp.array(1.0), bounds=(jnp.array(0.0), jnp.array(2.0))) + p = evm.Parameter(value=jnp.array(1.0), lower=jnp.array(0.0), upper=jnp.array(2.0)) assert p.value == 1.0 assert p.update(jnp.array(2.0)).value == 2.0 - assert p.bounds == (0.0, 2.0) + assert p.lower == 0.0 + assert p.upper == 2.0 assert p.boundary_penalty == 0.0 assert p.update(jnp.array(3.0)).boundary_penalty == jnp.inf @@ -42,7 +43,7 @@ def test_gauss(): def test_lnN(): p = evm.Parameter(value=jnp.array(0.0)) - ln = evm.effect.lnN(width=(0.9, 1.1)) + ln = evm.effect.lnN(width=jnp.array([0.9, 1.1])) assert ln.constraint == Gauss(mean=0.0, width=1.0) assert ln.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) @@ -67,21 +68,15 @@ def test_modifier(): norm = evm.Parameter(value=jnp.array(0.0)) # unconstrained effect - m_unconstrained = evm.modifier( - name="mu", parameter=mu, effect=evm.effect.unconstrained() - ) + m_unconstrained = evm.modifier(parameter=mu, effect=evm.effect.unconstrained()) assert m_unconstrained(jnp.array([10])) == pytest.approx(11) # gauss effect - m_gauss = evm.modifier( - name="norm", parameter=norm, effect=evm.effect.gauss(jnp.array(0.1)) - ) + m_gauss = evm.modifier(parameter=norm, effect=evm.effect.gauss(jnp.array(0.1))) assert m_gauss(jnp.array([10])) == pytest.approx(10) # lnN effect - m_lnN = evm.modifier( - name="norm", parameter=norm, effect=evm.effect.lnN(width=(0.9, 1.1)) - ) + m_lnN = evm.modifier(parameter=norm, effect=evm.effect.lnN(width=(0.9, 1.1))) assert m_lnN(jnp.array([10])) == pytest.approx(10) # poisson effect # FIXME @@ -99,13 +94,9 @@ def test_compose(): norm = evm.Parameter(value=jnp.array(0.0)) # unconstrained effect - m_unconstrained = evm.modifier( - name="mu", parameter=mu, effect=evm.effect.unconstrained() - ) + m_unconstrained = evm.modifier(parameter=mu, effect=evm.effect.unconstrained()) # gauss effect - m_gauss = evm.modifier( - name="norm", parameter=norm, effect=evm.effect.gauss(jnp.array(0.1)) - ) + m_gauss = evm.modifier(parameter=norm, effect=evm.effect.gauss(jnp.array(0.1))) # compose m = evm.compose(m_unconstrained, m_gauss) diff --git a/tests/test_util.py b/tests/test_util.py index 75afe94..9e50171 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,46 +2,7 @@ import jax -from evermore.util import FrozenDB, as1darray - - -def get_frozendb(): - return FrozenDB( - { - # QCD - ("a", "b"): 1, - ("a", "d", "e"): 2, - ("a", "d", "f"): 3, - # DY - ("g", "b"): 4, - ("g", "d", "e"): 5, - ("g", "d", "f"): 6, - } - ) - - -def test_frozendb_len(): - db = get_frozendb() - - assert len(db) == 6 - - -def test_frozendb_getitem(): - db = get_frozendb() - - assert db["a"]["b"] == 1 - assert db["a", "b"] == 1 - assert db["b"] == FrozenDB({"a": 1, "g": 4}) - - -def test_frozendb_jitcompatible(): - db = get_frozendb() - - @jax.jit - def fun(db): - return (db["a", "b"] + 1) ** 2 - - assert fun(db) == 4 +from evermore.util import as1darray def test_as1darray():