diff --git a/examples/grad_nll.py b/examples/grad_nll.py
index e69de29..0850829 100644
--- a/examples/grad_nll.py
+++ b/examples/grad_nll.py
@@ -0,0 +1,24 @@
+import equinox as eqx
+import jax.numpy as jnp
+from model import hists, model, observation
+
+import evermore as evm
+
+nll = evm.loss.PoissonNLL()
+
+
+@eqx.filter_jit
+def loss(model, hists, observation):
+    expectations = model(hists)
+    constraints = evm.loss.get_param_constraints(model)
+    loss_val = nll(
+        expectation=evm.util.sum_leaves(expectations),
+        observation=observation,
+    )
+    # add constraint
+    loss_val += evm.util.sum_leaves(constraints)
+    return -jnp.sum(loss_val)
+
+
+loss_val = loss(model, hists, observation)
+grads = eqx.filter_grad(loss)(model, hists, observation)
diff --git a/examples/model.py b/examples/model.py
index be416a0..25cebf8 100644
--- a/examples/model.py
+++ b/examples/model.py
@@ -60,20 +60,4 @@ def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]:
         }
 
 observation = jnp.array([37])
-
-nll = evm.loss.PoissonNLL()
-
-
-@eqx.filter_jit
-def loss(model, hists, observation):
-    expectations = model(hists)
-    constraints = evm.loss.get_param_constraints(model)
-    return nll(
-        expectation=evm.util.sum_leaves(expectations),
-        observation=observation,
-        constraint=evm.util.sum_leaves(constraints),
-    )
-
-
-loss_val = loss(model, hists, observation)
-grads = eqx.filter_grad(loss)(model, hists, observation)
+expectations = model(hists)
diff --git a/examples/nll_fit.py b/examples/nll_fit.py
index e69de29..3d42579 100644
--- a/examples/nll_fit.py
+++ b/examples/nll_fit.py
@@ -0,0 +1,39 @@
+import equinox as eqx
+import jax.numpy as jnp
+import optax
+from model import hists, model, observation
+
+import evermore as evm
+
+optim = optax.sgd(learning_rate=1e-2)
+opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
+
+nll = evm.loss.PoissonNLL()
+
+
+@eqx.filter_jit
+def loss(model, hists, observation):
+    expectations = model(hists)
+    constraints = evm.loss.get_param_constraints(model)
+    loss_val = nll(
+        expectation=evm.util.sum_leaves(expectations),
+        observation=observation,
+    )
+    # add constraint
+    loss_val += evm.util.sum_leaves(constraints)
+    return -jnp.sum(loss_val)
+
+
+@eqx.filter_jit
+def make_step(model, opt_state, hists, observation):
+    # differentiate the full analysis
+    grads = eqx.filter_grad(loss)(model, hists, observation)
+    updates, opt_state = optim.update(grads, opt_state)
+    # apply nuisance parameter updates
+    model = eqx.apply_updates(model, updates)
+    return model, opt_state
+
+
+# minimize model with 1000 steps
+for _ in range(1000):
+    model, opt_state = make_step(model, opt_state, hists, observation)
diff --git a/examples/nll_profiling.py b/examples/nll_profiling.py
index ea37ff5..01271e9 100644
--- a/examples/nll_profiling.py
+++ b/examples/nll_profiling.py
@@ -32,11 +32,13 @@ def loss(diff_model, static_model, hists, observation):
         model = eqx.combine(diff_model, static_model)
         expectations = model(hists)
         constraints = evm.loss.get_param_constraints(model)
-        return nll(
+        loss_val = nll(
             expectation=evm.util.sum_leaves(expectations),
             observation=observation,
-            constraint=evm.util.sum_leaves(constraints),
         )
+        # add constraint
+        loss_val += evm.util.sum_leaves(constraints)
+        return -2 * jnp.sum(loss_val)
 
     @eqx.filter_jit
     def make_step(model, opt_state, events, observation):
@@ -60,7 +62,8 @@ def make_step(model, opt_state, events, observation):
 
 for mu in mus:
     print(f"[for-loop] mu={mu:.2f} - NLL={fixed_mu_fit(jnp.array(mu)):.6f}")
 
+# or vectorized with jax.vmap
 likelihood_scan = jax.vmap(fixed_mu_fit)(mus)
 for mu, nll in zip(mus, likelihood_scan, strict=False):
-    print(f"[vectorized] mu={mu:.2f} - NLL={nll:.6f}")
+    print(f"[jax.vmap] mu={mu:.2f} - NLL={nll:.6f}")
diff --git a/pyproject.toml b/pyproject.toml
index 34223e0..168df55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,8 @@ classifiers = [
 ]
 dynamic = ["version"]  # version is set in src/evermore/__init__.py
 dependencies = [
+    "jax",
+    "jaxtyping",
     "equinox>=0.10.6",  # eqx.field
 ]
 
diff --git a/src/evermore/loss.py b/src/evermore/loss.py
index 2cc487c..e638aa9 100644
--- a/src/evermore/loss.py
+++ b/src/evermore/loss.py
@@ -60,15 +60,10 @@ def logpdf(self) -> Callable:
         return jax.scipy.stats.poisson.logpmf
 
     @jax.named_scope("evm.loss.PoissonNLL")
-    def __call__(
-        self, expectation: Array, observation: Array, constraint: Array
-    ) -> Array:
+    def __call__(self, expectation: Array, observation: Array) -> Array:
         # poisson log-likelihood
-        nll = jnp.sum(
+        return jnp.sum(
             self.logpdf(observation, expectation)
             - self.logpdf(observation, observation),
             axis=-1,
         )
-        # add constraint
-        nll += constraint
-        return -jnp.sum(nll)
diff --git a/src/evermore/sample.py b/src/evermore/sample.py
index 43fc87d..5ea6632 100644
--- a/src/evermore/sample.py
+++ b/src/evermore/sample.py
@@ -20,7 +20,7 @@ def toy_module(module: eqx.Module, key: PRNGKeyArray) -> PyTree[Callable]:
 
     def _sample(param: Parameter, key: Parameter) -> Array:
         if not param.constraints:
-            msg = f"Parameter {param} has no constraint pdf, can't sample from it."
+            msg = f"Parameter {param} has no constraint pdf, can't sample from it. Maybe you need to call the model once to populate all constraints?"
             raise RuntimeError(msg)
         if len(param.constraints) > 1:
             msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}"