Commit 5aff066
start of restructure: be more close to equinox' philosophy
1 parent: 1855676
Showing 24 changed files with 690 additions and 1,365 deletions.
@@ -0,0 +1,36 @@
New file: a small equinox example in which the weights of a linear layer are evermore Parameters carrying a Gaussian constraint.

import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm


class LinearConstrained(eqx.Module):
    weights: evm.Parameter
    biases: jax.Array

    def __init__(self, in_size, out_size, key):
        self.biases = jax.random.normal(key, (out_size,))
        self.weights = evm.Parameter(value=jax.random.normal(key, (out_size, in_size)))
        self.weights.constraints.add(evm.pdf.Gauss(mean=0.0, width=0.5))

    def __call__(self, x: jax.Array):
        return self.weights.value @ x + self.biases


@eqx.filter_jit
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    mse = jax.numpy.mean((y - pred_y) ** 2)
    constraints = evm.loss.get_param_constraints(model)
    # sum them all up for each weight
    constraints = jax.tree_util.tree_map(jnp.sum, constraints)
    return mse + evm.util.sum_leaves(constraints)


batch_size, in_size, out_size = 32, 2, 3
model = LinearConstrained(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
loss_val = loss_fn(model, x, y)
grads = eqx.filter_grad(loss_fn)(model, x, y)
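The diff stops at a single loss/gradient evaluation; a minimal sketch of how such a model could actually be trained, assuming optax is available (the optimizer choice and learning rate are illustrative, not part of this commit):

import equinox as eqx
import optax

optim = optax.adam(1e-3)
# optimize only the inexact-array leaves (weights, biases, parameter values)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))


@eqx.filter_jit
def train_step(model, opt_state, x, y):
    # loss_fn from above: MSE plus the Gaussian constraint terms
    loss_val, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_val


model, opt_state, loss_val = train_step(model, opt_state, x, y)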
@@ -1,31 +0,0 @@
Deleted file: the old fit / gradient / Hessian example built on the dict-based likelihood API.

from __future__ import annotations

import equinox as eqx
from jax import config
from model import init_values, model, observation, optimizer

import evermore as evm

config.update("jax_enable_x64", True)

# create negative log likelihood
nll = evm.likelihood.NLL(model=model, observation=observation)

# fit
params, state = optimizer.fit(fun=nll, init_values=init_values)

# gradients of nll of fitted model
fast_grad_nll = eqx.filter_jit(eqx.filter_grad(nll))
grads = fast_grad_nll(params)
# gradients of nll of fitted model only wrt `mu`
# basically: pass the parameters dict of which you want the gradients
params_ = {k: v for k, v in params.items() if k == "mu"}
grad_mu = fast_grad_nll(params_)

# hessian + cov_matrix of nll of fitted model
hessian = eqx.filter_jit(evm.likelihood.Hessian(model=model, observation=observation))()

# covariance matrix of fitted model
covmatrix = eqx.filter_jit(
    evm.likelihood.CovMatrix(model=model, observation=observation)
)()
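The removed Hessian and CovMatrix helpers have no direct counterpart in this commit; a minimal sketch of how both could be recovered with plain JAX on the restructured model and loss introduced in the next diff (the names model, loss, hists, observation are taken from there):

import equinox as eqx
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# separate the differentiable arrays from the static structure
params, static = eqx.partition(model, eqx.is_inexact_array)
flat, unravel = ravel_pytree(params)


def loss_flat(flat_params):
    return loss(eqx.combine(unravel(flat_params), static), hists, observation)


# dense Hessian of the NLL; the covariance is its inverse (valid at the minimum)
hess = jax.hessian(loss_flat)(flat)
cov = jnp.linalg.inv(hess)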
@@ -1,94 +1,79 @@
Modified file: the S+B example model, restructured from the evm.Model / evm.Result API into a plain eqx.Module whose parameters are fields and whose modifiers are built from the parameters themselves.

Before:

from __future__ import annotations

import jax.numpy as jnp

import evermore as evm


class SPlusBModel(evm.Model):
    def __call__(self, processes: dict, parameters: dict) -> evm.Result:
        res = evm.Result()

        mu_modifier = evm.modifier(
            name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()
        )
        res.add(
            process="signal",
            expectation=mu_modifier(processes[("signal", "nominal")]),
        )

        bkg1_modifier = evm.compose(
            evm.modifier(
                name="lnN1",
                parameter=parameters["norm1"],
                effect=evm.effect.lnN((0.9, 1.1)),
            ),
            evm.modifier(
                name="shape1_bkg1",
                parameter=parameters["shape1"],
                effect=evm.effect.shape(
                    up=processes[("background1", "shape_up")],
                    down=processes[("background1", "shape_down")],
                ),
            ),
        )
        res.add(
            process="background1",
            expectation=bkg1_modifier(processes[("background1", "nominal")]),
        )

        bkg2_modifier = evm.compose(
            evm.modifier(
                name="lnN2",
                parameter=parameters["norm2"],
                effect=evm.effect.lnN((0.95, 1.05)),
            ),
            evm.modifier(
                name="shape1_bkg2",
                parameter=parameters["shape1"],
                effect=evm.effect.shape(
                    up=processes[("background2", "shape_up")],
                    down=processes[("background2", "shape_down")],
                ),
            ),
        )
        res.add(
            process="background2",
            expectation=bkg2_modifier(processes[("background2", "nominal")]),
        )
        return res


def create_model():
    processes = {
        ("signal", "nominal"): jnp.array([3]),
        ("background1", "nominal"): jnp.array([10]),
        ("background2", "nominal"): jnp.array([20]),
        ("background1", "shape_up"): jnp.array([12]),
        ("background1", "shape_down"): jnp.array([8]),
        ("background2", "shape_up"): jnp.array([23]),
        ("background2", "shape_down"): jnp.array([19]),
    }
    parameters = {
        "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
        "norm1": evm.Parameter(value=jnp.array([0.0])),
        "norm2": evm.Parameter(value=jnp.array([0.0])),
        "shape1": evm.Parameter(value=jnp.array([0.0])),
    }

    # return model
    return SPlusBModel(processes=processes, parameters=parameters)


model = create_model()

init_values = model.parameter_values
observation = jnp.array([37])
asimov = model.evaluate().expectation()

# create optimizer (from `jaxopt`)
optimizer = evm.optimizer.JaxOptimizer.make(
    name="LBFGS",
    settings={"maxiter": 5, "jit": True, "unroll": True},
)

After:

from __future__ import annotations

from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm


class SPlusBModel(eqx.Module):
    mu: evm.Parameter
    norm1: evm.Parameter
    norm2: evm.Parameter
    shape1: evm.Parameter

    def __init__(self) -> None:
        self.mu = evm.Parameter(value=jnp.array([1.0]))
        self.norm1 = evm.Parameter()
        self.norm2 = evm.Parameter()
        self.shape1 = evm.Parameter()

    def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]:
        expectations = {}

        # signal process
        sig_mod = self.mu.unconstrained()
        expectations["signal"] = sig_mod(hists[("signal", "nominal")])

        # bkg1 process
        bkg1_mod = self.norm1.lnN(width=jnp.array([0.9, 1.1])) @ self.shape1.shape(
            up=hists[("bkg1", "shape_up")],
            down=hists[("bkg1", "shape_down")],
        )
        expectations["bkg1"] = bkg1_mod(hists[("bkg1", "nominal")])

        # bkg2 process
        bkg2_mod = self.norm2.lnN(width=jnp.array([0.95, 1.05])) @ self.shape1.shape(
            up=hists[("bkg2", "shape_up")],
            down=hists[("bkg2", "shape_down")],
        )
        expectations["bkg2"] = bkg2_mod(hists[("bkg2", "nominal")])

        # return the modified expectations
        return expectations


model = SPlusBModel()

hists = {
    ("signal", "nominal"): jnp.array([3]),
    ("bkg1", "nominal"): jnp.array([10]),
    ("bkg2", "nominal"): jnp.array([20]),
    ("bkg1", "shape_up"): jnp.array([12]),
    ("bkg1", "shape_down"): jnp.array([8]),
    ("bkg2", "shape_up"): jnp.array([23]),
    ("bkg2", "shape_down"): jnp.array([19]),
}

observation = jnp.array([37])

nll = evm.loss.PoissonNLL()


@eqx.filter_jit
def loss(model, hists, observation):
    expectations = model(hists)
    constraints = evm.loss.get_param_constraints(model)
    return nll(
        expectation=evm.util.sum_leaves(expectations),
        observation=observation,
        constraint=evm.util.sum_leaves(constraints),
    )


loss_val = loss(model, hists, observation)
grads = eqx.filter_grad(loss)(model, hists, observation)
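The old API took gradients wrt a subset of parameters by passing a smaller dict; with an eqx.Module the same thing is expressed by partitioning the model. A minimal sketch, differentiating the loss above only wrt mu (the filter-spec pattern is standard equinox, not evermore API):

import jax

# mark only mu's value as differentiable
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(lambda m: m.mu.value, filter_spec, replace=True)


@eqx.filter_jit
def loss_mu(diff_model, static_model, hists, observation):
    return loss(eqx.combine(diff_model, static_model), hists, observation)


diff_model, static_model = eqx.partition(model, filter_spec)
grads_mu = eqx.filter_grad(loss_mu)(diff_model, static_model, hists, observation)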
@@ -1,18 +0,0 @@
Deleted file: the old fitting script based on evm.likelihood.NLL and the built-in optimizer wrapper.

from __future__ import annotations

from jax import config
from model import init_values, model, observation, optimizer

from evermore.likelihood import NLL

config.update("jax_enable_x64", True)


# create negative log likelihood
nll = NLL(model=model, observation=observation)

# fit
values, state = optimizer.fit(fun=nll, init_values=init_values)

# update model with fitted values
fitted_model = model.update(values=values)
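Both evm.likelihood and the optimizer wrapper are gone after the restructure; a minimal sketch of the equivalent fit against the new loss, assuming optax (a first-order optimizer rather than the previous L-BFGS, with an illustrative step count and learning rate):

import equinox as eqx
import optax

optim = optax.adam(learning_rate=1e-2)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))


@eqx.filter_jit
def make_step(model, opt_state):
    # gradients of the full NLL (Poisson term plus constraints)
    grads = eqx.filter_grad(loss)(model, hists, observation)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state


for _ in range(1000):
    model, opt_state = make_step(model, opt_state)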
@@ -1,67 +0,0 @@
Deleted file: the old NLL-profiling example built on model.update and the jaxopt-backed optimizer.

from __future__ import annotations

from functools import partial

import equinox as eqx
import jax
import jax.numpy as jnp
from jax import config
from model import asimov, model, optimizer

from evermore import Model
from evermore.likelihood import NLL
from evermore.optimizer import JaxOptimizer

config.update("jax_enable_x64", True)


def nll_profiling(
    value_name: str,
    scan_points: jax.Array,
    model: Model,
    observation: jax.Array,
    optimizer: JaxOptimizer,
    fit: bool,
) -> jax.Array:
    # define single fit for a fixed parameter of interest (poi)
    @partial(jax.jit, static_argnames=("value_name", "optimizer", "fit"))
    def fixed_poi_fit(
        value_name: str,
        scan_point: jax.Array,
        model: Model,
        observation: jax.Array,
        optimizer: JaxOptimizer,
        fit: bool,
    ) -> jax.Array:
        # fix theta into the model
        model = model.update(values={value_name: scan_point})
        init_values = model.parameter_values
        init_values.pop(value_name, 1)
        # minimize
        nll = eqx.filter_jit(NLL(model=model, observation=observation))
        if fit:
            values, _ = optimizer.fit(fun=nll, init_values=init_values)
        else:
            values = model.parameter_values
        return nll(values=values)

    # vectorise for multiple fixed values (scan points)
    fixed_poi_fit_vec = jax.vmap(
        fixed_poi_fit, in_axes=(None, 0, None, None, None, None)
    )
    return fixed_poi_fit_vec(
        value_name, scan_points, model, observation, optimizer, fit
    )


# profile the NLL around starting point of `0`
scan_points = jnp.r_[-1.9:2.0:0.1]

profile_postfit = nll_profiling(
    value_name="norm1",
    scan_points=scan_points,
    model=model,
    observation=asimov,
    optimizer=optimizer,
    fit=True,
)
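A profile scan can be rebuilt on the restructured API as well; a minimal sketch that pins norm1 to each scan point and evaluates the loss from the new model.py (a prefit scan only; the per-point refit of the remaining parameters that the removed helper performed is omitted here):

import equinox as eqx
import jax
import jax.numpy as jnp


def scan_loss(scan_point):
    # pin `norm1` to the scan point, leaving all other parameters untouched
    pinned = eqx.tree_at(lambda m: m.norm1.value, model, scan_point)
    return loss(pinned, hists, observation)


scan_points = jnp.r_[-1.9:2.0:0.1]
# one NLL value per scan point; the [:, None] keeps the (1,) value shape
profile_prefit = jax.vmap(scan_loss)(scan_points[:, None])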
@@ -1,31 +1,43 @@
Modified file: the toy-generation example, moved from the fit-then-SampleToy workflow to evm.sample.toy_module plus plain jax.vmap.

Before:

from __future__ import annotations

import equinox as eqx
import jax
from jax import config
from model import init_values, model, observation, optimizer

from evermore.likelihood import NLL, SampleToy

config.update("jax_enable_x64", True)

# create negative log likelihood
nll = NLL(model=model, observation=observation)

# fit
values, state = optimizer.fit(fun=nll, init_values=init_values)

# create sampling method
sample_toy = SampleToy(model=model, observation=observation)
# vectorise and jit
sample_toys = eqx.filter_vmap(in_axes=(None, 0))(eqx.filter_jit(sample_toy))

sample_toy(values, jax.random.PRNGKey(1234))

# sample 10 toys based on fitted parameters
keys = jax.random.split(jax.random.PRNGKey(1234), num=10)
# postfit toys
toys_postfit = sample_toys(values, keys)
# prefit toys
toys_prefit = sample_toys(init_values, keys)

After:

from __future__ import annotations

from typing import Any

import equinox as eqx
import jax
from jaxtyping import Array, PRNGKeyArray
from model import hists, model, observation

import evermore as evm

key = jax.random.PRNGKey(0)

# generate a new model with sampled parameters according to their constraint pdfs
toymodel = evm.sample.toy_module(model, key)


# generate new expectation based on the toy model
def toy_expectation(
    key: PRNGKeyArray,
    module: eqx.Module,
    hists: dict[Any, Array],
) -> Array:
    toymodel = evm.sample.toy_module(model, key)
    expectations = toymodel(hists)
    return evm.util.sum_leaves(expectations)


expectation = toy_expectation(key, model, hists)


# generate new expectations, vectorized over many keys
keys = jax.random.split(key, 1000)

# vectorized toy expectation
toy_expectation_vec = jax.vmap(toy_expectation, in_axes=(0, None, None))
expectations = toy_expectation_vec(keys, model, hists)


# just sample observations with poisson
poisson_obs = evm.pdf.Poisson(observation)
sampled_observation = poisson_obs.sample(key)

# vectorized sampling
sampled_observations = jax.vmap(poisson_obs.sample)(keys)
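Combining the two sampling steps above gives full pseudo-experiments; a minimal sketch (the split of one key into a parameter key and an observation key is an illustrative choice, not part of this diff):

def pseudo_experiment(key):
    param_key, obs_key = jax.random.split(key)
    # toy expectation from parameters drawn from their constraint pdfs
    expectation = toy_expectation(param_key, model, hists)
    # Poisson-fluctuate the resulting expectation
    return evm.pdf.Poisson(expectation).sample(obs_key)


# one pseudo-experiment per key
toys = jax.vmap(pseudo_experiment)(keys)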