new optimizer features; fixed some effects; notebook for visualization of new parameters
pfackeldey committed Sep 27, 2023
1 parent b9d1807 commit 86afe0b
Showing 9 changed files with 335 additions and 93 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -158,4 +158,7 @@ Thumbs.db
 *.swp
 
 # if examples are run
-examples/*.eqx
+examples/*.eqx
+
+test/
+.vscode/
65 changes: 43 additions & 22 deletions examples/model.py
@@ -6,7 +6,7 @@
 
 from dilax.model import Model, Result
 from dilax.optimizer import JaxOptimizer
-from dilax.parameter import Parameter, lnN, modifier, unconstrained
+from dilax.parameter import Parameter, compose, lnN, modifier, shape, unconstrained
 from dilax.util import HistDB
 
 
@@ -18,45 +18,65 @@ def __call__(
     ) -> Result:
         res = Result()
 
+        mu_modifier = modifier(
+            name="mu", parameter=parameters["mu"], effect=unconstrained()
+        )
         res.add(
             process="signal",
-            expectation=modifier(
-                name="mu",
-                parameter=parameters["mu"],
-                effect=unconstrained(),
-            )(processes["signal"]),
+            expectation=mu_modifier(processes["signal", "nominal"]),
         )
 
+        bkg1_modifier = compose(
+            modifier(name="lnN1", parameter=parameters["norm1"], effect=lnN(0.1)),
+            modifier(
+                name="shape1_bkg1",
+                parameter=parameters["shape1"],
+                effect=shape(
+                    up=processes["background1", "shape_up"],
+                    down=processes["background1", "shape_down"],
+                ),
+            ),
+        )
         res.add(
-            process="background",
-            expectation=modifier(
-                name="lnN1",
-                parameter=parameters["norm1"],
-                effect=lnN(0.1),
-            )(processes["background1"]),
+            process="background1",
+            expectation=bkg1_modifier(processes["background1", "nominal"]),
         )
 
+        bkg2_modifier = compose(
+            modifier(name="lnN2", parameter=parameters["norm2"], effect=lnN(0.05)),
+            modifier(
+                name="shape1_bkg2",
+                parameter=parameters["shape1"],
+                effect=shape(
+                    up=processes["background2", "shape_up"],
+                    down=processes["background2", "shape_down"],
+                ),
+            ),
+        )
         res.add(
             process="background2",
-            expectation=modifier(
-                name="lnN2",
-                parameter=parameters["norm1"],
-                effect=lnN(0.05),
-            )(processes["background2"]),
+            expectation=bkg2_modifier(processes["background2", "nominal"]),
         )
         return res
 
 
 def create_model():
     processes = HistDB(
         {
-            "signal": jnp.array([3]),
-            "background1": jnp.array([10]),
-            "background2": jnp.array([20]),
+            ("signal", "nominal"): jnp.array([3]),
+            ("background1", "nominal"): jnp.array([10]),
+            ("background2", "nominal"): jnp.array([20]),
+            ("background1", "shape_up"): jnp.array([12]),
+            ("background1", "shape_down"): jnp.array([8]),
+            ("background2", "shape_up"): jnp.array([23]),
+            ("background2", "shape_down"): jnp.array([19]),
         }
     )
     parameters = {
         "mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
-        "norm1": Parameter(value=jnp.array([0.0]), bounds=(-jnp.inf, jnp.inf)),
-        "norm2": Parameter(value=jnp.array([0.0]), bounds=(-jnp.inf, jnp.inf)),
+        "norm1": Parameter(value=jnp.array([0.0])),
+        "norm2": Parameter(value=jnp.array([0.0])),
+        "shape1": Parameter(value=jnp.array([0.0])),
     }
 
     # return model
@@ -69,6 +89,7 @@ def create_model():
 
 init_values = model.parameter_values
 observation = jnp.array([37])
+asimov = model.evaluate().expectation()
 
 
 # create optimizer (from `jaxopt`)
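As an aside, the new `shape` effect introduced above can be probed on its own, presumably along the lines of what the new `nuisance_parameter.ipynb` notebook visualizes. A minimal sketch, assuming only the `dilax` calls that appear in this diff (the values are the `background1` templates from `create_model`; that 0 recovers the nominal template and ±1 the variations is the usual template-morphing convention, not something this diff states):

```python
import jax.numpy as jnp

from dilax.parameter import Parameter, modifier, shape

# background1 templates, as defined in create_model above
nominal = jnp.array([10.0])
up = jnp.array([12.0])
down = jnp.array([8.0])

# evaluate the shape-morphing modifier at a few nuisance-parameter values;
# conventionally 0.0 gives back the nominal template, +/-1.0 the variations
for theta in (-1.0, 0.0, 1.0):
    mod = modifier(
        name="shape1_bkg1",
        parameter=Parameter(value=jnp.array([theta])),
        effect=shape(up=up, down=down),
    )
    print(theta, mod(nominal))
```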
30 changes: 21 additions & 9 deletions examples/nll_profiling.py
@@ -5,12 +5,12 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from examples.model import asimov, model, optimizer
 from jax.config import config
 
 from dilax.likelihood import NLL
 from dilax.model import Model
 from dilax.optimizer import JaxOptimizer
-from model import model, observation, optimizer
 
 config.update("jax_enable_x64", True)
 
@@ -21,35 +21,47 @@ def nll_profiling(
     model: Model,
     observation: jax.Array,
     optimizer: JaxOptimizer,
+    fit: bool,
 ) -> jax.Array:
     # define single fit for a fixed parameter of interest (poi)
-    @partial(jax.jit, static_argnames=("value_name", "optimizer"))
+    @partial(jax.jit, static_argnames=("value_name", "optimizer", "fit"))
     def fixed_poi_fit(
         value_name: str,
         scan_point: jax.Array,
         model: Model,
         observation: jax.Array,
         optimizer: JaxOptimizer,
+        fit: bool,
     ) -> jax.Array:
         # fix theta into the model
         model = model.update(values={value_name: scan_point})
         init_values = model.parameter_values
         init_values.pop(value_name, 1)
         # minimize
         nll = eqx.filter_jit(NLL(model=model, observation=observation))
-        values, _ = optimizer.fit(fun=nll, init_values=init_values)
+        if fit:
+            values, _ = optimizer.fit(fun=nll, init_values=init_values)
+        else:
+            values = model.parameter_values
         return nll(values=values)
 
     # vectorise for multiple fixed values (scan points)
-    fixed_poi_fit_vec = jax.vmap(fixed_poi_fit, in_axes=(None, 0, None, None, None))
-    return fixed_poi_fit_vec(value_name, scan_points, model, observation, optimizer)
+    fixed_poi_fit_vec = jax.vmap(
+        fixed_poi_fit, in_axes=(None, 0, None, None, None, None)
+    )
+    return fixed_poi_fit_vec(
+        value_name, scan_points, model, observation, optimizer, fit
+    )
 
 
 # profile the NLL around starting point of `0`
-profile = nll_profiling(
-    value_name="norm2",
-    scan_points=jnp.array([-0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3]),
+scan_points = jnp.r_[-1.9:2.0:0.1]
+
+profile_postfit = nll_profiling(
+    value_name="norm1",
+    scan_points=scan_points,
     model=model,
-    observation=observation,
+    observation=asimov,
     optimizer=optimizer,
+    fit=True,
 )
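The new `fit` flag makes the same scan usable for a prefit profile, where the nuisance parameters are held at their initial values instead of being re-minimised at every scan point. A sketch under the assumption that it is called exactly like `profile_postfit` above (`profile_prefit` is an illustrative name):

```python
# same scan as above, but skip the per-point minimisation (fit=False)
profile_prefit = nll_profiling(
    value_name="norm1",
    scan_points=scan_points,
    model=model,
    observation=asimov,
    optimizer=optimizer,
    fit=False,
)
```

Since profiling minimises the NLL over the remaining parameters at each scan point, the postfit curve should lie at or below the prefit one everywhere.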
181 changes: 181 additions & 0 deletions examples/nuisance_parameter.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -86,14 +86,14 @@ disallow_untyped_defs = false
disallow_incomplete_defs = false
check_untyped_defs = true
strict = false
ignore_missing_imports = true


[tool.ruff]
select = [
"E", "F", "W", # flake8
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
56 changes: 38 additions & 18 deletions src/dilax/optimizer.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Hashable
-from typing import Callable
+from typing import Any, Callable
 
 import equinox as eqx
 import jax
@@ -15,23 +15,10 @@ class JaxOptimizer(eqx.Module):
     Example:
     ```
-    optimizer = JaxOptimizer.make(
-        name="ScipyMinimize",
-        settings={"method": "trust-constr"},
-    )
-    # or
-    optimizer = JaxOptimizer.make(
-        name="LBFGS",
-        settings={
-            "maxiter": 30,
-            "tol": 1e-6,
-            "jit": True,
-            "unroll": True,
-        },
-    )
+    optimizer = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5})
+    # or, e.g.: optimizer = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10})
     optimizer.fit(fun=nll, init_values=init_values)
     ```
     """

@@ -55,6 +42,39 @@ def settings(self) -> dict[str, Hashable]:
     def solver_instance(self, fun: Callable) -> jaxopt._src.base.Solver:
         return getattr(jaxopt, self.name)(fun=fun, **self.settings)
 
-    def fit(self, fun: Callable, init_values: dict[str, float]) -> jax.Array:
+    def fit(
+        self, fun: Callable, init_values: dict[str, jax.Array]
+    ) -> tuple[dict[str, jax.Array], Any]:
         values, state = self.solver_instance(fun=fun).run(init_values)
         return values, state
+
+
+class Chain(eqx.Module):
+    """
+    Chain multiple optimizers together.
+    They probably should have the `maxiter` setting set to a value,
+    in order to have a deterministic runtime behaviour.
+    Example:
+    ```
+    opt1 = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5})
+    opt2 = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10})
+    chain = Chain(opt1, opt2)
+    # first 5 steps are minimized with GradientDescent, then 10 steps with LBFGS
+    chain.fit(fun=nll, init_values=init_values)
+    ```
+    """
+
+    optimizers: tuple[JaxOptimizer, ...]
+
+    def __init__(self, *optimizers: JaxOptimizer) -> None:
+        self.optimizers = optimizers
+
+    def fit(
+        self, fun: Callable, init_values: dict[str, jax.Array]
+    ) -> tuple[dict[str, jax.Array], Any]:
+        values = init_values
+        for optimizer in self.optimizers:
+            values, state = optimizer.fit(fun=fun, init_values=values)
+        return values, state
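Putting the new `Chain` together with the example model, a minimal sketch (the optimizer names and `maxiter` values are taken from the docstrings above; wiring them to the example model's NLL is an assumption of mine, not shown in this commit):

```python
import equinox as eqx
from examples.model import asimov, model

from dilax.likelihood import NLL
from dilax.optimizer import Chain, JaxOptimizer

# negative log-likelihood of the example model w.r.t. the Asimov data
nll = eqx.filter_jit(NLL(model=model, observation=asimov))

# warm up with a few GradientDescent steps, then polish with LBFGS;
# bounded maxiter keeps each stage's runtime deterministic
chain = Chain(
    JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5}),
    JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10}),
)
values, state = chain.fit(fun=nll, init_values=model.parameter_values)
```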
