evermore

Differentiable (binned) likelihoods in JAX.

Installation

python -m pip install evermore

From source:

git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .
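
To confirm that either install route worked, the installed version can be read from the package metadata. This is a minimal sketch using only the Python standard library; `installed_version` is a hypothetical helper name, not part of evermore.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints a version string if evermore is installed in this environment, else None
print(installed_version("evermore"))
```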

Example - Model and Loss Definition

See more in examples/

evermore in a nutshell:

from typing import NamedTuple

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, PyTree

import evermore as evm

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
def model(params: PyTree, hists: dict[str, Array]) -> Array:
    mu_modifier = params.mu.scale()
    syst_modifier = params.syst.scale_log(up=1.1, down=0.9)
    return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


def loss(
    diffable: PyTree,
    static: PyTree,
    hists: dict[str, Array],
    observation: Array,
) -> Array:
    params = eqx.combine(diffable, static)
    expectation = model(params, hists)
    # Poisson log-likelihood of the observation given the expectation
    log_likelihood = evm.pdf.Poisson(lamb=expectation).log_prob(observation)
    # Add parameter constraints from logpdfs
    constraints = evm.loss.get_log_probs(params)
    log_likelihood += evm.util.sum_over_leaves(constraints)
    return -jnp.sum(log_likelihood)


# setup data
hists = {"signal": jnp.array([3]), "bkg": jnp.array([10])}
observation = jnp.array([15])


# define parameters, can be any PyTree of evm.Parameters
class Params(NamedTuple):
    mu: evm.Parameter
    syst: evm.NormalParameter


params = Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))
diffable, static = evm.parameter.partition(params)

# Calculate negative log-likelihood/loss
loss_val = loss(diffable, static, hists, observation)
# gradients of negative log-likelihood w.r.t. diffable parameters
grads = eqx.filter_grad(loss)(diffable, static, hists, observation)
print(f"{grads.mu.value=}, {grads.syst.value=}")
# -> grads.mu.value=Array([-0.46153846]), grads.syst.value=Array([-0.15436207])
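
The gradient quoted for mu can be cross-checked by hand in plain JAX, without evermore. The sketch below rests on two assumptions that are not taken from the evermore source: that the syst modifier leaves the background unscaled at its nominal value 0, and that the Poisson term has the same derivative with respect to the expectation as the standard Poisson log-pmf. `nll_mu` is a hypothetical helper name. Under these assumptions the expectation is mu * 3 + 10, and the derivative of the negative log-likelihood at mu = 1 is -3 * (15 / 13 - 1), roughly -0.4615.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import poisson

# Same toy inputs as in the example above
signal = jnp.array([3.0])
bkg = jnp.array([10.0])
observation = jnp.array([15.0])


def nll_mu(mu):
    # Assumption: with syst at its nominal value, the background is not rescaled
    expectation = mu * signal + bkg
    # Negative Poisson log-likelihood, summed over bins
    return -jnp.sum(poisson.logpmf(observation, expectation))


# Automatic differentiation w.r.t. mu, evaluated at mu = 1
grad_mu = jax.grad(nll_mu)(1.0)

# Closed form: d(NLL)/d(mu) = -sum(signal * (observation / expectation - 1))
analytic = -jnp.sum(signal * (observation / (signal + bkg) - 1.0))

print(grad_mu, analytic)
```

Both numbers should agree with the `grads.mu.value` shown above, up to floating-point precision. The syst gradient is not reproduced here because it depends on how `scale_log` interpolates between `up` and `down`.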

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.
