Skip to content

Commit

Permalink
start of restructure: be more close to equinox' philosophy
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Mar 6, 2024
1 parent 1855676 commit 5aff066
Show file tree
Hide file tree
Showing 24 changed files with 690 additions and 1,365 deletions.
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"use_edit_page_button": True,
}
html_context = {"default_mode": "light"}
html_logo = "../assets/favicon.png"
html_favicon = "../assets/favicon.png"

extensions = [
"sphinx.ext.autodoc",
Expand Down
36 changes: 36 additions & 0 deletions examples/dnn_weights_constraint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm


class LinearConstrained(eqx.Module):
weights: evm.Parameter
biases: jax.Array

def __init__(self, in_size, out_size, key):
self.biases = jax.random.normal(key, (out_size,))
self.weights = evm.Parameter(value=jax.random.normal(key, (out_size, in_size)))
self.weights.constraints.add(evm.pdf.Gauss(mean=0.0, width=0.5))

def __call__(self, x: jax.Array):
return self.weights.value @ x + self.biases


@eqx.filter_jit
def loss_fn(model, x, y):
pred_y = jax.vmap(model)(x)
mse = jax.numpy.mean((y - pred_y) ** 2)
constraints = evm.loss.get_param_constraints(model)
# sum them all up for each weight
constraints = jax.tree_util.tree_map(jnp.sum, constraints)
return mse + evm.util.sum_leaves(constraints)


batch_size, in_size, out_size = 32, 2, 3
model = LinearConstrained(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
loss_val = loss_fn(model, x, y)
grads = eqx.filter_grad(loss_fn)(model, x, y)
31 changes: 0 additions & 31 deletions examples/grad_nll.py
Original file line number Diff line number Diff line change
@@ -1,31 +0,0 @@
from __future__ import annotations

import equinox as eqx
from jax import config
from model import init_values, model, observation, optimizer

import evermore as evm

config.update("jax_enable_x64", True)

# create negative log likelihood
nll = evm.likelihood.NLL(model=model, observation=observation)

# fit
params, state = optimizer.fit(fun=nll, init_values=init_values)

# gradients of nll of fitted model
fast_grad_nll = eqx.filter_jit(eqx.filter_grad(nll))
grads = fast_grad_nll(params)
# gradients of nll of fitted model only wrt to `mu`
# basically: pass the parameters dict of which you want the gradients
params_ = {k: v for k, v in params.items() if k == "mu"}
grad_mu = fast_grad_nll(params_)

# hessian + cov_matrix of nll of fitted model
hessian = eqx.filter_jit(evm.likelihood.Hessian(model=model, observation=observation))()

# covariance matrix of fitted model
covmatrix = eqx.filter_jit(
evm.likelihood.CovMatrix(model=model, observation=observation)
)()
131 changes: 58 additions & 73 deletions examples/model.py
Original file line number Diff line number Diff line change
@@ -1,94 +1,79 @@
from __future__ import annotations

from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm


class SPlusBModel(evm.Model):
def __call__(self, processes: dict, parameters: dict) -> evm.Result:
res = evm.Result()
class SPlusBModel(eqx.Module):
mu: evm.Parameter
norm1: evm.Parameter
norm2: evm.Parameter
shape1: evm.Parameter

mu_modifier = evm.modifier(
name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()
)
res.add(
process="signal",
expectation=mu_modifier(processes[("signal", "nominal")]),
)
def __init__(self) -> None:
self.mu = evm.Parameter(value=jnp.array([1.0]))
self.norm1 = evm.Parameter()
self.norm2 = evm.Parameter()
self.shape1 = evm.Parameter()

bkg1_modifier = evm.compose(
evm.modifier(
name="lnN1",
parameter=parameters["norm1"],
effect=evm.effect.lnN((0.9, 1.1)),
),
evm.modifier(
name="shape1_bkg1",
parameter=parameters["shape1"],
effect=evm.effect.shape(
up=processes[("background1", "shape_up")],
down=processes[("background1", "shape_down")],
),
),
)
res.add(
process="background1",
expectation=bkg1_modifier(processes[("background1", "nominal")]),
)
def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]:
expectations = {}

bkg2_modifier = evm.compose(
evm.modifier(
name="lnN2",
parameter=parameters["norm2"],
effect=evm.effect.lnN((0.95, 1.05)),
),
evm.modifier(
name="shape1_bkg2",
parameter=parameters["shape1"],
effect=evm.effect.shape(
up=processes[("background2", "shape_up")],
down=processes[("background2", "shape_down")],
),
),
# signal process
sig_mod = self.mu.unconstrained()
expectations["signal"] = sig_mod(hists[("signal", "nominal")])

# bkg1 process
bkg1_mod = self.norm1.lnN(width=jnp.array([0.9, 1.1])) @ self.shape1.shape(
up=hists[("bkg1", "shape_up")],
down=hists[("bkg1", "shape_down")],
)
res.add(
process="background2",
expectation=bkg2_modifier(processes[("background2", "nominal")]),
expectations["bkg1"] = bkg1_mod(hists[("bkg1", "nominal")])

# bkg2 process
bkg2_mod = self.norm2.lnN(width=jnp.array([0.95, 1.05])) @ self.shape1.shape(
up=hists[("bkg2", "shape_up")],
down=hists[("bkg2", "shape_down")],
)
return res
expectations["bkg2"] = bkg2_mod(hists[("bkg2", "nominal")])

# return the modified expectations
return expectations

def create_model():
processes = {
("signal", "nominal"): jnp.array([3]),
("background1", "nominal"): jnp.array([10]),
("background2", "nominal"): jnp.array([20]),
("background1", "shape_up"): jnp.array([12]),
("background1", "shape_down"): jnp.array([8]),
("background2", "shape_up"): jnp.array([23]),
("background2", "shape_down"): jnp.array([19]),
}
parameters = {
"mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
"norm1": evm.Parameter(value=jnp.array([0.0])),
"norm2": evm.Parameter(value=jnp.array([0.0])),
"shape1": evm.Parameter(value=jnp.array([0.0])),
}

# return model
return SPlusBModel(processes=processes, parameters=parameters)
model = SPlusBModel()


model = create_model()
hists = {
("signal", "nominal"): jnp.array([3]),
("bkg1", "nominal"): jnp.array([10]),
("bkg2", "nominal"): jnp.array([20]),
("bkg1", "shape_up"): jnp.array([12]),
("bkg1", "shape_down"): jnp.array([8]),
("bkg2", "shape_up"): jnp.array([23]),
("bkg2", "shape_down"): jnp.array([19]),
}

init_values = model.parameter_values
observation = jnp.array([37])
asimov = model.evaluate().expectation()

nll = evm.loss.PoissonNLL()


@eqx.filter_jit
def loss(model, hists, observation):
expectations = model(hists)
constraints = evm.loss.get_param_constraints(model)
return nll(
expectation=evm.util.sum_leaves(expectations),
observation=observation,
constraint=evm.util.sum_leaves(constraints),
)


# create optimizer (from `jaxopt`)
optimizer = evm.optimizer.JaxOptimizer.make(
name="LBFGS",
settings={"maxiter": 5, "jit": True, "unroll": True},
)
loss_val = loss(model, hists, observation)
grads = eqx.filter_grad(loss)(model, hists, observation)
18 changes: 0 additions & 18 deletions examples/nll_fit.py
Original file line number Diff line number Diff line change
@@ -1,18 +0,0 @@
from __future__ import annotations

from jax import config
from model import init_values, model, observation, optimizer

from evermore.likelihood import NLL

config.update("jax_enable_x64", True)


# create negative log likelihood
nll = NLL(model=model, observation=observation)

# fit
values, state = optimizer.fit(fun=nll, init_values=init_values)

# update model with fitted values
fitted_model = model.update(values=values)
67 changes: 0 additions & 67 deletions examples/nll_profiling.py
Original file line number Diff line number Diff line change
@@ -1,67 +0,0 @@
from __future__ import annotations

from functools import partial

import equinox as eqx
import jax
import jax.numpy as jnp
from jax import config
from model import asimov, model, optimizer

from evermore import Model
from evermore.likelihood import NLL
from evermore.optimizer import JaxOptimizer

config.update("jax_enable_x64", True)


def nll_profiling(
value_name: str,
scan_points: jax.Array,
model: Model,
observation: jax.Array,
optimizer: JaxOptimizer,
fit: bool,
) -> jax.Array:
# define single fit for a fixed parameter of interest (poi)
@partial(jax.jit, static_argnames=("value_name", "optimizer", "fit"))
def fixed_poi_fit(
value_name: str,
scan_point: jax.Array,
model: Model,
observation: jax.Array,
optimizer: JaxOptimizer,
fit: bool,
) -> jax.Array:
# fix theta into the model
model = model.update(values={value_name: scan_point})
init_values = model.parameter_values
init_values.pop(value_name, 1)
# minimize
nll = eqx.filter_jit(NLL(model=model, observation=observation))
if fit:
values, _ = optimizer.fit(fun=nll, init_values=init_values)
else:
values = model.parameter_values
return nll(values=values)

# vectorise for multiple fixed values (scan points)
fixed_poi_fit_vec = jax.vmap(
fixed_poi_fit, in_axes=(None, 0, None, None, None, None)
)
return fixed_poi_fit_vec(
value_name, scan_points, model, observation, optimizer, fit
)


# profile the NLL around starting point of `0`
scan_points = jnp.r_[-1.9:2.0:0.1]

profile_postfit = nll_profiling(
value_name="norm1",
scan_points=scan_points,
model=model,
observation=asimov,
optimizer=optimizer,
fit=True,
)
52 changes: 32 additions & 20 deletions examples/toy_generation.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,43 @@
from __future__ import annotations
from typing import Any

import equinox as eqx
import jax
from jax import config
from model import init_values, model, observation, optimizer
from jaxtyping import Array, PRNGKeyArray
from model import hists, model, observation

from evermore.likelihood import NLL, SampleToy
import evermore as evm

config.update("jax_enable_x64", True)
key = jax.random.PRNGKey(0)

# generate a new model with sampled parameters according to their constraint pdfs
toymodel = evm.sample.toy_module(model, key)

# create negative log likelihood
nll = NLL(model=model, observation=observation)

# fit
values, state = optimizer.fit(fun=nll, init_values=init_values)
# generate new expectation based on the toy model
def toy_expectation(
key: PRNGKeyArray,
module: eqx.Module,
hists: dict[Any, Array],
) -> Array:
toymodel = evm.sample.toy_module(model, key)
expectations = toymodel(hists)
return evm.util.sum_leaves(expectations)

# create sampling method
sample_toy = SampleToy(model=model, observation=observation)
# vectorise and jit
sample_toys = eqx.filter_vmap(in_axes=(None, 0))(eqx.filter_jit(sample_toy))

sample_toy(values, jax.random.PRNGKey(1234))
expectation = toy_expectation(key, model, hists)

# sample 10 toys based on fitted parameters
keys = jax.random.split(jax.random.PRNGKey(1234), num=10)
# postfit toys
toys_postfit = sample_toys(values, keys)
# prefit toys
toys_prefit = sample_toys(init_values, keys)

# generate a new expectations vectorized over many keys
keys = jax.random.split(key, 1000)

# vectorized toy expectation
toy_expectation_vec = jax.vmap(toy_expectation, in_axes=(0, None, None))
expectations = toy_expectation_vec(keys, model, hists)


# just sample observations with poisson
poisson_obs = evm.pdf.Poisson(observation)
sampled_observation = poisson_obs.sample(key)

# vectorized sampling
sampled_observations = jax.vmap(poisson_obs.sample)(keys)
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@ classifiers = [
dynamic = ["version"] # version is set in src/evermore/__init__.py
dependencies = [
"equinox>=0.10.6", # eqx.field
"jaxopt >=0.6", # jaxopt.LGBFGS
]

[project.optional-dependencies]
test = ["pytest >=6", "pytest-cov >=3"]
dev = ["pytest >=6", "pytest-cov >=3"]
dev = ["pytest >=6", "pytest-cov >=3", "optax", "jaxopt >=0.6"]
docs = [
"sphinx>=7.0",
"myst_parser>=0.13",
Expand Down
Loading

0 comments on commit 5aff066

Please sign in to comment.