diff --git a/examples/model.py b/examples/model.py index aba8d3d..5c32053 100644 --- a/examples/model.py +++ b/examples/model.py @@ -19,7 +19,9 @@ def __call__(self, processes: dict, parameters: dict) -> dlx.Result: bkg1_modifier = dlx.compose( dlx.modifier( - name="lnN1", parameter=parameters["norm1"], effect=dlx.effect.lnN(0.1) + name="lnN1", + parameter=parameters["norm1"], + effect=dlx.effect.lnN((0.9, 1.1)), ), dlx.modifier( name="shape1_bkg1", @@ -37,7 +39,9 @@ def __call__(self, processes: dict, parameters: dict) -> dlx.Result: bkg2_modifier = dlx.compose( dlx.modifier( - name="lnN2", parameter=parameters["norm2"], effect=dlx.effect.lnN(0.05) + name="lnN2", + parameter=parameters["norm2"], + effect=dlx.effect.lnN((0.95, 1.05)), ), dlx.modifier( name="shape1_bkg2", diff --git a/src/dilax/effect.py b/src/dilax/effect.py index 076d78f..2095b3a 100644 --- a/src/dilax/effect.py +++ b/src/dilax/effect.py @@ -10,6 +10,8 @@ from dilax.pdf import Flat, Gauss, HashablePDF, Poisson from dilax.util import as1darray +ArrayLike = jax.typing.ArrayLike + if TYPE_CHECKING: from dilax.parameter import Parameter @@ -51,9 +53,9 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class gauss(Effect): - width: jax.Array = eqx.field(static=True, converter=as1darray) + width: ArrayLike = eqx.field(static=True, converter=as1darray) - def __init__(self, width: jax.Array) -> None: + def __init__(self, width: ArrayLike) -> None: self.width = width @property @@ -124,21 +126,29 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class lnN(Effect): - width: jax.Array | tuple[jax.Array, jax.Array] = eqx.field(static=True) + width: tuple[ArrayLike, ArrayLike] = eqx.field(static=True) def __init__( self, - width: jax.Array | tuple[jax.Array, jax.Array], + width: tuple[ArrayLike, ArrayLike], ) -> None: self.width = width - def scale(self, parameter: Parameter) -> jax.Array: - if isinstance(self.width, tuple): - down, up = self.width - scale = jnp.where(parameter.value > 0, up, down) - else: - scale = self.width - return scale + def interpolate(self, parameter: Parameter) -> jax.Array: + # https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/be488af288361ef101859a398ae618131373cad7/src/ProcessNormalization.cc#L112-L129 + x = parameter.value + lo, hi = self.width + hi = jnp.log(hi) + lo = jnp.log(lo) + lo = -lo + avg = 0.5 * (hi + lo) + halfdiff = 0.5 * (hi - lo) + twox = x + x + twox2 = twox * twox + alpha = 0.125 * twox * (twox2 * (3 * twox2 - 10.0) + 15.0) + return jnp.where( + jnp.abs(x) >= 0.5, jnp.where(x >= 0, hi, lo), avg + alpha * halfdiff + ) @property def constraint(self) -> HashablePDF: @@ -159,10 +169,10 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: .. code-block:: python - return jnp.exp(parameter.value * self.scale(parameter=parameter)) + return jnp.exp(parameter.value * self.interpolate(parameter=parameter)) """ - return jnp.exp(parameter.value * self.scale(parameter=parameter)) + return jnp.exp(parameter.value * self.interpolate(parameter=parameter)) class poisson(Effect): diff --git a/src/dilax/likelihood.py b/src/dilax/likelihood.py index f5bf087..b0aca47 100644 --- a/src/dilax/likelihood.py +++ b/src/dilax/likelihood.py @@ -47,14 +47,16 @@ def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array: values = self.model.parameter_values model = self.model.update(values=values) res = model.evaluate() - nll = self.logpdf(self.observation, res.expectation()) - self.logpdf( - self.observation, self.observation + nll = jnp.sum( + self.logpdf(self.observation, res.expectation()) + - self.logpdf(self.observation, self.observation), + axis=-1, ) # add constraints constraints = jax.tree_util.tree_leaves(model.parameter_constraints()) nll += sum(constraints) nll += model.nll_boundary_penalty() - return -jnp.sum(nll, axis=-1) + return -jnp.sum(nll) class Hessian(BaseModule): diff --git a/src/dilax/util.py b/src/dilax/util.py index aa4de84..cc7f667 100644 --- a/src/dilax/util.py +++ b/src/dilax/util.py @@ -8,6 +8,8 @@ import jax import jax.numpy as jnp +ArrayLike = jax.typing.ArrayLike + __all__ = [ "HistDB", "FrozenDB", @@ -256,7 +258,7 @@ class HistDB(FrozenDB): ) -def as1darray(x: float | jax.Array) -> jax.Array: +def as1darray(x: ArrayLike) -> jax.Array: """ Converts `x` to a 1d array. diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 276e874..57df052 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -42,7 +42,7 @@ def test_gauss(): def test_lnN(): p = dlx.Parameter(value=jnp.array(0.0)) - ln = dlx.effect.lnN(width=jnp.array(0.1)) + ln = dlx.effect.lnN(width=(0.9, 1.1)) assert ln.constraint == Gauss(mean=0.0, width=1.0) assert ln.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) @@ -80,7 +80,7 @@ def test_modifier(): # lnN effect m_lnN = dlx.modifier( - name="norm", parameter=norm, effect=dlx.effect.lnN(jnp.array(0.1)) + name="norm", parameter=norm, effect=dlx.effect.lnN(width=(0.9, 1.1)) ) assert m_lnN(jnp.array(10)) == pytest.approx(10)