Skip to content

Commit

Permalink
update docs and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Apr 29, 2024
1 parent d8d5b2c commit fa6f340
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
7 changes: 4 additions & 3 deletions docs/binned_likelihood.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ These building blocks include:
- **evm.Modifier**: Modifiers combine **evm.Effects** and **evm.Parameters** to
modify data.

The negative log-likelihood (NLL) function of Eq.{eq}`likelihood` can be implemented with evermore as follows:
The negative log-likelihood (NLL) function of Eq.{eq}`likelihood` can be implemented with evermore as follows (copy & paste the following snippet to start write a _new_ statistical model):

```{code-block} python
from jaxtyping import PyTree, Array
Expand All @@ -45,22 +45,23 @@ import evermore as evm
# ...
# -- NLL definition --
@eqx.filter_jit
def NLL(dynamic_params, static_params, hists, observation):
params = eqx.combine(dynamic_params, static_params)
expectations = model(params, hists)
log_likelihood = evm.loss.PoissonLogLikelihood()
# first product of Eq. 1 (Poisson term)
log_likelihood = evm.loss.PoissonLogLikelihood()
loss_val = log_likelihood(
expectation=evm.util.sum_over_leaves(expectations),
observation=observation,
)
# second product of Eq. 1 (constraint)
constraints = evm.loss.get_log_probs(model)
# for parameters with `.value.size > 1` (jnp.sum the constraints)
constraints = jtu.tree_map(jnp.sum, constraints)
loss_val += evm.util.sum_over_leaves(constraints)
return -jnp.sum(loss_val)
```
Expand Down
24 changes: 21 additions & 3 deletions tests/test_loss.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
from __future__ import annotations

import jax.numpy as jnp

import evermore as evm


def test_PoissonLogLikelihood():
pass
f = evm.loss.PoissonLogLikelihood()
assert f(jnp.array([1.0]), jnp.array([1.0])) == 0.0


def test_get_log_probs():
pass
params = {
"a": evm.NormalParameter(value=0.5),
"b": evm.NormalParameter(),
"c": evm.Parameter(),
}

log_probs = evm.loss.get_log_probs(params)
assert log_probs["a"] == -0.125
assert log_probs["b"] == 0.0
assert log_probs["c"] == 0.0


def test_get_boundary_constraints():
pass
in_bounds_param = evm.Parameter(value=1.0, lower=0.0, upper=2.0)
oob_param = evm.Parameter(value=3.0, lower=0.0, upper=2.0)

assert evm.loss.get_boundary_constraints(in_bounds_param) == 0.0
assert evm.loss.get_boundary_constraints(oob_param) == jnp.inf

0 comments on commit fa6f340

Please sign in to comment.