Skip to content

Commit

Permalink
fix nll bug; properly implement asymm lnN interpolation like combine
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Dec 5, 2023
1 parent 280645f commit 5a4462f
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 21 deletions.
8 changes: 6 additions & 2 deletions examples/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def __call__(self, processes: dict, parameters: dict) -> dlx.Result:

bkg1_modifier = dlx.compose(
dlx.modifier(
name="lnN1", parameter=parameters["norm1"], effect=dlx.effect.lnN(0.1)
name="lnN1",
parameter=parameters["norm1"],
effect=dlx.effect.lnN((0.9, 1.1)),
),
dlx.modifier(
name="shape1_bkg1",
Expand All @@ -37,7 +39,9 @@ def __call__(self, processes: dict, parameters: dict) -> dlx.Result:

bkg2_modifier = dlx.compose(
dlx.modifier(
name="lnN2", parameter=parameters["norm2"], effect=dlx.effect.lnN(0.05)
name="lnN2",
parameter=parameters["norm2"],
effect=dlx.effect.lnN((0.95, 1.05)),
),
dlx.modifier(
name="shape1_bkg2",
Expand Down
36 changes: 23 additions & 13 deletions src/dilax/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from dilax.pdf import Flat, Gauss, HashablePDF, Poisson
from dilax.util import as1darray

ArrayLike = jax.typing.ArrayLike

if TYPE_CHECKING:
from dilax.parameter import Parameter

Expand Down Expand Up @@ -51,9 +53,9 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:


class gauss(Effect):
width: jax.Array = eqx.field(static=True, converter=as1darray)
width: ArrayLike = eqx.field(static=True, converter=as1darray)

def __init__(self, width: jax.Array) -> None:
def __init__(self, width: ArrayLike) -> None:
self.width = width

@property
Expand Down Expand Up @@ -124,21 +126,29 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:


class lnN(Effect):
width: jax.Array | tuple[jax.Array, jax.Array] = eqx.field(static=True)
width: tuple[ArrayLike, ArrayLike] = eqx.field(static=True)

def __init__(
self,
width: jax.Array | tuple[jax.Array, jax.Array],
width: tuple[ArrayLike, ArrayLike],
) -> None:
self.width = width

def scale(self, parameter: Parameter) -> jax.Array:
if isinstance(self.width, tuple):
down, up = self.width
scale = jnp.where(parameter.value > 0, up, down)
else:
scale = self.width
return scale
def interpolate(self, parameter: Parameter) -> jax.Array:
# https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/be488af288361ef101859a398ae618131373cad7/src/ProcessNormalization.cc#L112-L129
x = parameter.value
lo, hi = self.width
hi = jnp.log(hi)
lo = jnp.log(lo)
lo = -lo
avg = 0.5 * (hi + lo)
halfdiff = 0.5 * (hi - lo)
twox = x + x
twox2 = twox * twox
alpha = 0.125 * twox * (twox2 * (3 * twox2 - 10.0) + 15.0)
return jnp.where(
jnp.abs(x) >= 0.5, jnp.where(x >= 0, hi, lo), avg + alpha * halfdiff
)

@property
def constraint(self) -> HashablePDF:
Expand All @@ -159,10 +169,10 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array:
.. code-block:: python
return jnp.exp(parameter.value * self.scale(parameter=parameter))
return jnp.exp(parameter.value * self.interpolate(parameter=parameter))
"""
return jnp.exp(parameter.value * self.scale(parameter=parameter))
return jnp.exp(parameter.value * self.interpolate(parameter=parameter))


class poisson(Effect):
Expand Down
8 changes: 5 additions & 3 deletions src/dilax/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array:
values = self.model.parameter_values
model = self.model.update(values=values)
res = model.evaluate()
nll = self.logpdf(self.observation, res.expectation()) - self.logpdf(
self.observation, self.observation
nll = jnp.sum(
self.logpdf(self.observation, res.expectation())
- self.logpdf(self.observation, self.observation),
axis=-1,
)
# add constraints
constraints = jax.tree_util.tree_leaves(model.parameter_constraints())
nll += sum(constraints)
nll += model.nll_boundary_penalty()
return -jnp.sum(nll, axis=-1)
return -jnp.sum(nll)


class Hessian(BaseModule):
Expand Down
4 changes: 3 additions & 1 deletion src/dilax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import jax
import jax.numpy as jnp

ArrayLike = jax.typing.ArrayLike

__all__ = [
"HistDB",
"FrozenDB",
Expand Down Expand Up @@ -256,7 +258,7 @@ class HistDB(FrozenDB):
)


def as1darray(x: float | jax.Array) -> jax.Array:
def as1darray(x: ArrayLike) -> jax.Array:
"""
Converts `x` to a 1d array.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_gauss():

def test_lnN():
p = dlx.Parameter(value=jnp.array(0.0))
ln = dlx.effect.lnN(width=jnp.array(0.1))
ln = dlx.effect.lnN(width=(0.9, 1.1))

assert ln.constraint == Gauss(mean=0.0, width=1.0)
assert ln.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0)
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_modifier():

# lnN effect
m_lnN = dlx.modifier(
name="norm", parameter=norm, effect=dlx.effect.lnN(jnp.array(0.1))
name="norm", parameter=norm, effect=dlx.effect.lnN(width=(0.9, 1.1))
)
assert m_lnN(jnp.array(10)) == pytest.approx(10)

Expand Down

0 comments on commit 5a4462f

Please sign in to comment.