Skip to content

Commit

Permalink
polish examples; add jax as dependency; better error in sample function
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Mar 6, 2024
1 parent 892540c commit 0a3a4cf
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 28 deletions.
24 changes: 24 additions & 0 deletions examples/grad_nll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import equinox as eqx
import jax.numpy as jnp
from model import hists, model, observation

import evermore as evm

nll = evm.loss.PoissonNLL()


@eqx.filter_jit
def loss(model, hists, observation):
expectations = model(hists)
constraints = evm.loss.get_param_constraints(model)
loss_val = nll(
expectation=evm.util.sum_leaves(expectations),
observation=observation,
)
# add constraint
loss_val += evm.util.sum_leaves(constraints)
return -jnp.sum(loss_val)


loss_val = loss(model, hists, observation)
grads = eqx.filter_grad(loss)(model, hists, observation)
18 changes: 1 addition & 17 deletions examples/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,4 @@ def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]:
}

observation = jnp.array([37])

nll = evm.loss.PoissonNLL()


@eqx.filter_jit
def loss(model, hists, observation):
expectations = model(hists)
constraints = evm.loss.get_param_constraints(model)
return nll(
expectation=evm.util.sum_leaves(expectations),
observation=observation,
constraint=evm.util.sum_leaves(constraints),
)


loss_val = loss(model, hists, observation)
grads = eqx.filter_grad(loss)(model, hists, observation)
expectations = model(hists)
39 changes: 39 additions & 0 deletions examples/nll_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import equinox as eqx
import jax.numpy as jnp
import optax
from model import hists, model, observation

import evermore as evm

optim = optax.sgd(learning_rate=1e-2)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

nll = evm.loss.PoissonNLL()


@eqx.filter_jit
def loss(model, hists, observation):
expectations = model(hists)
constraints = evm.loss.get_param_constraints(model)
loss_val = nll(
expectation=evm.util.sum_leaves(expectations),
observation=observation,
)
# add constraint
loss_val += evm.util.sum_leaves(constraints)
return -jnp.sum(loss_val)


@eqx.filter_jit
def make_step(model, opt_state, events, observation):
# differentiate full analysis
grads = eqx.filter_grad(loss)(model, events, observation)
updates, opt_state = optim.update(grads, opt_state)
# apply nuisance parameter and DNN weight updates
model = eqx.apply_updates(model, updates)
return model, opt_state


# minimize model with 1000 steps
for _ in range(1000):
model, opt_state = make_step(model, opt_state, hists, observation)
9 changes: 6 additions & 3 deletions examples/nll_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ def loss(diff_model, static_model, hists, observation):
model = eqx.combine(diff_model, static_model)
expectations = model(hists)
constraints = evm.loss.get_param_constraints(model)
return nll(
loss_val = nll(
expectation=evm.util.sum_leaves(expectations),
observation=observation,
constraint=evm.util.sum_leaves(constraints),
)
# add constraint
loss_val += evm.util.sum_leaves(constraints)
return -2 * jnp.sum(loss_val)

@eqx.filter_jit
def make_step(model, opt_state, events, observation):
Expand All @@ -60,7 +62,8 @@ def make_step(model, opt_state, events, observation):
for mu in mus:
print(f"[for-loop] mu={mu:.2f} - NLL={fixed_mu_fit(jnp.array(mu)):.6f}")


# or vectorized!!!
likelihood_scan = jax.vmap(fixed_mu_fit)(mus)
for mu, nll in zip(mus, likelihood_scan, strict=False):
print(f"[vectorized] mu={mu:.2f} - NLL={nll:.6f}")
print(f"[jax.vmap] mu={mu:.2f} - NLL={nll:.6f}")
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ classifiers = [
]
dynamic = ["version"] # version is set in src/evermore/__init__.py
dependencies = [
"jax",
"jaxtyping",
"equinox>=0.10.6", # eqx.field
]

Expand Down
9 changes: 2 additions & 7 deletions src/evermore/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,10 @@ def logpdf(self) -> Callable:
return jax.scipy.stats.poisson.logpmf

@jax.named_scope("evm.loss.PoissonNLL")
def __call__(
self, expectation: Array, observation: Array, constraint: Array
) -> Array:
def __call__(self, expectation: Array, observation: Array) -> Array:
# poisson log-likelihood
nll = jnp.sum(
return jnp.sum(
self.logpdf(observation, expectation)
- self.logpdf(observation, observation),
axis=-1,
)
# add constraint
nll += constraint
return -jnp.sum(nll)
2 changes: 1 addition & 1 deletion src/evermore/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def toy_module(module: eqx.Module, key: PRNGKeyArray) -> PyTree[Callable]:

def _sample(param: Parameter, key: Parameter) -> Array:
if not param.constraints:
msg = f"Parameter {param} has no constraint pdf, can't sample from it."
msg = f"Parameter {param} has no constraint pdf, can't sample from it. Maybe you need to call the model once to populate all constraints?"
raise RuntimeError(msg)
if len(param.constraints) > 1:
msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}"
Expand Down

0 comments on commit 0a3a4cf

Please sign in to comment.