diff --git a/README.md b/README.md index ca50535..4ff1702 100644 --- a/README.md +++ b/README.md @@ -38,21 +38,21 @@ jax.config.update("jax_enable_x64", True) # define a simple model with two processes and two parameters -class MyModel(dlx.model.Model): +class MyModel(dlx.Model): def __call__( - self, processes: dict, parameters: dict[str, dlx.parameter.Parameter] - ) -> dlx.model.Result: - res = dlx.model.Result() + self, processes: dict, parameters: dict[str, dlx.Parameter] + ) -> dlx.Result: + res = dlx.Result() # signal - mu_mod = dlx.parameter.modifier( - name="mu", parameter=parameters["mu"], effect=dlx.parameter.unconstrained() + mu_mod = dlx.modifier( + name="mu", parameter=parameters["mu"], effect=dlx.effect.unconstrained() ) res.add(process="signal", expectation=mu_mod(processes["signal"])) # background - bkg_mod = dlx.parameter.modifier( - name="sigma", parameter=parameters["sigma"], effect=dlx.parameter.gauss(0.2) + bkg_mod = dlx.modifier( + name="sigma", parameter=parameters["sigma"], effect=dlx.effect.gauss(0.2) ) res.add(process="background", expectation=bkg_mod(processes["background"])) return res @@ -61,8 +61,8 @@ class MyModel(dlx.model.Model): # setup model processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])} parameters = { - "mu": dlx.parameter.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)), - "sigma": dlx.parameter.Parameter(value=jnp.array([0.0])), + "mu": dlx.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)), + "sigma": dlx.Parameter(value=jnp.array([0.0])), } model = MyModel(processes=processes, parameters=parameters) diff --git a/src/dilax/__init__.py b/src/dilax/__init__.py index 5ba6744..bc77d56 100644 --- a/src/dilax/__init__.py +++ b/src/dilax/__init__.py @@ -16,14 +16,20 @@ # expose public API __all__ = [ + "effect", "ipy_util", "likelihood", - "model", "optimizer", - "parameter", # "pdf", # this should not be needed in public API "util", "__version__", + # explicitely expose some classes + "Model", + "Result", + "Parameter", + "modifier", + "staterror", + "compose", ] @@ -32,11 +38,12 @@ def __dir__(): from dilax import ( # noqa: E402 + effect, ipy_util, likelihood, - model, optimizer, - parameter, # pdf, # this should not be needed in public API util, ) +from dilax.model import Model, Result # noqa: E402 +from dilax.parameter import Parameter, compose, modifier, staterror # noqa: E402 diff --git a/src/dilax/effect.py b/src/dilax/effect.py new file mode 100644 index 0000000..076d78f --- /dev/null +++ b/src/dilax/effect.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING + +import equinox as eqx +import jax +import jax.numpy as jnp + +from dilax.pdf import Flat, Gauss, HashablePDF, Poisson +from dilax.util import as1darray + +if TYPE_CHECKING: + from dilax.parameter import Parameter + +__all__ = [ + "Effect", + "unconstrained", + "gauss", + "lnN", + "poisson", + "shape", +] + + +def __dir__(): + return __all__ + + +class Effect(eqx.Module): + @property + @abc.abstractmethod + def constraint(self) -> HashablePDF: + ... + + @abc.abstractmethod + def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + ... + + +class unconstrained(Effect): + @property + def constraint(self) -> HashablePDF: + return Flat() + + def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + return parameter.value + + +DEFAULT_EFFECT = unconstrained() + + +class gauss(Effect): + width: jax.Array = eqx.field(static=True, converter=as1darray) + + def __init__(self, width: jax.Array) -> None: + self.width = width + + @property + def constraint(self) -> HashablePDF: + return Gauss(mean=0.0, width=1.0) + + def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + """ + Implementation with (inverse) CDFs is defined as follows: + + .. code-block:: python + + gx = Gauss(mean=1.0, width=self.width) # type: ignore[arg-type] + g1 = Gauss(mean=1.0, width=1.0) + + return gx.inv_cdf(g1.cdf(parameter.value + 1)) + + But we can use the fast analytical solution instead: + + .. code-block:: python + + return (parameter.value * self.width) + 1 + + """ + return (parameter.value * self.width) + 1 + + +class shape(Effect): + up: jax.Array = eqx.field(converter=as1darray) + down: jax.Array = eqx.field(converter=as1darray) + + def __init__( + self, + up: jax.Array, + down: jax.Array, + ) -> None: + self.up = up # +1 sigma + self.down = down # -1 sigma + + @eqx.filter_jit + def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array: + factor = sf + dx_sum = self.up + self.down - 2 * sumw + dx_diff = self.up - self.down + + # taken from https://github.com/nsmith-/jaxfit/blob/8479cd73e733ba35462287753fab44c0c560037b/src/jaxfit/roofit/combine.py#L173C6-L192 + _asym_poly = jnp.array([3.0, -10.0, 15.0, 0.0]) / 8.0 + + abs_value = jnp.abs(factor) + return 0.5 * ( + dx_diff * factor + + dx_sum + * jnp.where( + abs_value > 1.0, + abs_value, + jnp.polyval(_asym_poly, factor * factor), + ) + ) + + @property + def constraint(self) -> HashablePDF: + return Gauss(mean=0.0, width=1.0) + + def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + sf = parameter.value + # clip, no negative values are allowed + return jnp.maximum((sumw + self.vshift(sf=sf, sumw=sumw)) / sumw, 0.0) + + +class lnN(Effect): + width: jax.Array | tuple[jax.Array, jax.Array] = eqx.field(static=True) + + def __init__( + self, + width: jax.Array | tuple[jax.Array, jax.Array], + ) -> None: + self.width = width + + def scale(self, parameter: Parameter) -> jax.Array: + if isinstance(self.width, tuple): + down, up = self.width + scale = jnp.where(parameter.value > 0, up, down) + else: + scale = self.width + return scale + + @property + def constraint(self) -> HashablePDF: + return Gauss(mean=0.0, width=1.0) + + def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + """ + Implementation with (inverse) CDFs is defined as follows: + + .. code-block:: python + + gx = Gauss(mean=jnp.exp(parameter.value), width=width) # type: ignore[arg-type] + g1 = Gauss(mean=1.0, width=1.0) + + return gx.inv_cdf(g1.cdf(parameter.value + 1)) + + But we can use the fast analytical solution instead: + + .. code-block:: python + + return jnp.exp(parameter.value * self.scale(parameter=parameter)) + + """ + return jnp.exp(parameter.value * self.scale(parameter=parameter)) + + +class poisson(Effect): + lamb: jax.Array = eqx.field(static=True, converter=as1darray) + + def __init__(self, lamb: jax.Array) -> None: + self.lamb = lamb + + @property + def constraint(self) -> HashablePDF: + return Poisson(lamb=self.lamb) + + def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + return parameter.value + 1 diff --git a/src/dilax/model.py b/src/dilax/model.py index c268cbe..70a6a6e 100644 --- a/src/dilax/model.py +++ b/src/dilax/model.py @@ -34,31 +34,30 @@ def expectation(self) -> jax.Array: class Model(eqx.Module): """ A model describing nuisance parameters, templates (histograms), and how they interact. - It is requires to implement the `evaluate` method, which returns an `EvaluationResult` object. + It is requires to implement the `evaluate` method, which returns an `Result` object. Example: .. code-block:: python + import equinox as eqx import jax import jax.numpy as jnp - import equinox as eqx - from dilax.model import Model, Result - from dilax.parameter import Parameter, lnN, modifier, unconstrained + import dilax as dlx # Define a simple model with two processes and two parameters - class MyModel(Model): - def __call__(self, processes: dict, parameters: dict[str, Parameter]) -> Result: - res = Result() + class MyModel(dlx.Model): + def __call__(self, processes: dict, parameters: dict[str, dlx.Parameter]) -> dlx.Result: + res = dlx.Result() # signal - mu_mod = modifier(name="mu", parameter=parameters["mu"], effect=unconstrained()) + mu_mod = dlx.modifier(name="mu", parameter=parameters["mu"], effect=dlx.effect.unconstrained()) res.add(process="signal", expectation=mu_mod(processes["signal"])) # background - bkg_mod = modifier(name="sigma", parameter=parameters["sigma"], effect=lnN(0.2)) + bkg_mod = dlx.modifier(name="sigma", parameter=parameters["sigma"], effect=dlx.effect.lnN(0.2)) res.add(process="background", expectation=bkg_mod(processes["background"])) return res @@ -66,8 +65,8 @@ def __call__(self, processes: dict, parameters: dict[str, Parameter]) -> Result: # Setup model processes = {"signal": jnp.array([10]), "background": jnp.array([50])} parameters = { - "mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)), - "sigma": Parameter(value=jnp.array([0.0])), + "mu": dlx.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)), + "sigma": dlx.Parameter(value=jnp.array([0.0])), } model = MyModel(processes=processes, parameters=parameters) @@ -77,7 +76,7 @@ def __call__(self, processes: dict, parameters: dict[str, Parameter]) -> Result: # -> Array([60.], dtype=float32) %timeit model.evaluate().expectation() - # -> 3.05 ms ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) + # -> 485 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) # evaluate the expectation *fast* @eqx.filter_jit @@ -89,7 +88,7 @@ def eval(model) -> jax.Array: # -> Array([60.], dtype=float32) %timeit eqx.filter_jit(eval)(model).block_until_ready() - # -> 114 µs ± 327 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) + # -> 202 µs ± 4.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each) """ processes: dict diff --git a/src/dilax/parameter.py b/src/dilax/parameter.py index e755aa3..2f0731c 100644 --- a/src/dilax/parameter.py +++ b/src/dilax/parameter.py @@ -1,22 +1,25 @@ from __future__ import annotations import abc +from typing import TYPE_CHECKING import equinox as eqx import jax import jax.numpy as jnp -from dilax.pdf import Flat, Gauss, HashablePDF, Poisson +from dilax.effect import ( + DEFAULT_EFFECT, + gauss, + poisson, +) +from dilax.pdf import HashablePDF from dilax.util import as1darray +if TYPE_CHECKING: + from dilax.effect import Effect + __all__ = [ "Parameter", - "Effect", - "unconstrained", - "gauss", - "lnN", - "poisson", - "shape", "modifier", "staterror", "compose", @@ -55,133 +58,6 @@ def boundary_penalty(self) -> jax.Array: ) -class Effect(eqx.Module): - @property - @abc.abstractmethod - def constraint(self) -> HashablePDF: - ... - - @abc.abstractmethod - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: - ... - - -class unconstrained(Effect): - @property - def constraint(self) -> HashablePDF: - return Flat() - - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: - return parameter.value - - -DEFAULT_EFFECT = unconstrained() - - -class gauss(Effect): - width: jax.Array = eqx.field(static=True, converter=as1darray) - - def __init__(self, width: jax.Array) -> None: - self.width = width - - @property - def constraint(self) -> HashablePDF: - return Gauss(mean=0.0, width=1.0) - - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: - # gx = Gauss(mean=1.0, width=self.width) # type: ignore[arg-type] - # g1 = Gauss(mean=1.0, width=1.0) - # return gx.inv_cdf(g1.cdf(parameter.value + 1)) - return (parameter.value * self.width) + 1 # fast analytical solution - - -class shape(Effect): - up: jax.Array = eqx.field(converter=as1darray) - down: jax.Array = eqx.field(converter=as1darray) - - def __init__( - self, - up: jax.Array, - down: jax.Array, - ) -> None: - self.up = up # +1 sigma - self.down = down # -1 sigma - - @eqx.filter_jit - def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array: - factor = sf - dx_sum = self.up + self.down - 2 * sumw - dx_diff = self.up - self.down - - # taken from https://github.com/nsmith-/jaxfit/blob/8479cd73e733ba35462287753fab44c0c560037b/src/jaxfit/roofit/combine.py#L173C6-L192 - _asym_poly = jnp.array([3.0, -10.0, 15.0, 0.0]) / 8.0 - - abs_value = jnp.abs(factor) - return 0.5 * ( - dx_diff * factor - + dx_sum - * jnp.where( - abs_value > 1.0, - abs_value, - jnp.polyval(_asym_poly, factor * factor), - ) - ) - - @property - def constraint(self) -> HashablePDF: - return Gauss(mean=0.0, width=1.0) - - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: - sf = parameter.value - # clip, no negative values are allowed - return jnp.maximum((sumw + self.vshift(sf=sf, sumw=sumw)) / sumw, 0.0) - - -class lnN(Effect): - width: jax.Array | tuple[jax.Array, jax.Array] = eqx.field(static=True) - - def __init__( - self, - width: jax.Array | tuple[jax.Array, jax.Array], - ) -> None: - self.width = width - - def scale(self, parameter: Parameter) -> jax.Array: - if isinstance(self.width, tuple): - down, up = self.width - scale = jnp.where(parameter.value > 0, up, down) - else: - scale = self.width - return scale - - @property - def constraint(self) -> HashablePDF: - return Gauss(mean=0.0, width=1.0) - - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: - # width = self.scale(parameter=parameter) - # g1 = Gauss(mean=1.0, width=1.0) - # gx = Gauss(mean=jnp.exp(parameter.value), width=width) # type: ignore[arg-type] - # return gx.inv_cdf(g1.cdf(parameter.value + 1)) - return jnp.exp( - parameter.value * self.scale(parameter=parameter) - ) # fast analytical solution - - -class poisson(Effect): - lamb: jax.Array = eqx.field(static=True, converter=as1darray) - - def __init__(self, lamb: jax.Array) -> None: - self.lamb = lamb - - @property - def constraint(self) -> HashablePDF: - return Poisson(lamb=self.lamb) - - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: - return parameter.value + 1 - - class ModifierBase(eqx.Module): @abc.abstractmethod def scale_factor(self, sumw: jax.Array) -> jax.Array: @@ -200,31 +76,31 @@ class modifier(ModifierBase): .. code-block:: python import jax.numpy as jnp - from dilax.parameter import modifier, Parameter, unconstrained, lnN, poisson, shape + import dilax as dlx - mu = Parameter(value=1.1, bounds=(0, 100)) - norm = Parameter(value=0.0, bounds=(-jnp.inf, jnp.inf)) + mu = dlx.Parameter(value=1.1, bounds=(0, 100)) + norm = dlx.Parameter(value=0.0, bounds=(-jnp.inf, jnp.inf)) # create a new parameter and a penalty - modify = modifier(name="mu", parameter=mu, effect=unconstrained()) + modify = dlx.modifier(name="mu", parameter=mu, effect=dlx.effect.unconstrained()) # apply the modifier modify(jnp.array([10, 20, 30])) # -> Array([11., 22., 33.], dtype=float32, weak_type=True), # lnN effect - modify = modifier(name="norm", parameter=norm, effect=lnN(0.2)) + modify = dlx.modifier(name="norm", parameter=norm, effect=dlx.effect.lnN(0.2)) modify(jnp.array([10, 20, 30])) # poisson effect hist = jnp.array([10, 20, 30]) - modify = modifier(name="norm", parameter=norm, effect=poisson(hist)) + modify = dlx.modifier(name="norm", parameter=norm, effect=dlx.effect.poisson(hist)) modify(jnp.array([10, 20, 30])) # shape effect up = jnp.array([12, 23, 35]) down = jnp.array([8, 19, 26]) - modify = modifier(name="norm", parameter=norm, effect=shape(up, down)) + modify = dlx.modifier(name="norm", parameter=norm, effect=dlx.effect.shape(up, down)) modify(jnp.array([10, 20, 30])) """ @@ -253,16 +129,16 @@ class staterror(ModifierBase): .. code-block:: python import jax.numpy as jnp - from dilax.parameter import modifier, Parameter, unconstrained, lnN, poisson, shape + import dilax as dlx hist = jnp.array([10, 20, 30]) - p1 = Parameter(value=1.0) - p2 = Parameter(value=0.0) - p3 = Parameter(value=0.0) + p1 = dlx.Parameter(value=1.0) + p2 = dlx.Parameter(value=0.0) + p3 = dlx.Parameter(value=0.0) # all bins with bin content below 10 (threshold) are treated as poisson, else gauss - modify = staterror( + modify = dlx.staterror( parameters=[p1, p2, p3], sumw=hist, sumw2=hist, @@ -271,6 +147,9 @@ class staterror(ModifierBase): modify(hist) # -> Array([13.162277, 20. , 30. ], dtype=float32) + # jit + import equinox as eqx + fast_modify = eqx.filter_jit(modify) """ @@ -290,8 +169,6 @@ def __init__( sumw2: jax.Array, threshold: float, ) -> None: - assert len(parameters) == len(sumw2) == len(sumw) - self.parameters = parameters self.sumw = sumw self.sumw2 = sumw2 @@ -308,6 +185,15 @@ def __init__( effect = poisson(self.sumw[i]) if self.mask[i] else gauss(self.widths[i]) param.constraints.add(effect.constraint) + def __check_init__(self): + if not len(self.parameters) == len(self.sumw2) == len(self.sumw): + msg = ( + f"Length of parameters ({len(self.parameters)}), " + f"sumw2 ({len(self.sumw2)}) and sumw ({len(self.sumw)}) " + "must be the same." + ) + raise ValueError(msg) + def scale_factor(self, sumw: jax.Array) -> jax.Array: from functools import partial @@ -350,27 +236,30 @@ class compose(ModifierBase): .. code-block:: python - from dilax.parameter import modifier, compose, Parameter, unconstrained, lnN + import jax.numpy as jnp + import dilax as dlx - mu = Parameter(value=1.1, bounds=(0, 100)) - sigma = Parameter(value=0.1, bounds=(-100, 100)) + mu = dlx.Parameter(value=1.1, bounds=(0, 100)) + sigma = dlx.Parameter(value=0.1, bounds=(-100, 100)) # create a new parameter and a composition of modifiers - composition = compose( - modifier(name="mu", parameter=mu), - modifier(name="sigma1", parameter=sigma, effect=lnN(0.1)), + composition = dlx.compose( + dlx.modifier(name="mu", parameter=mu), + dlx.modifier(name="sigma1", parameter=sigma, effect=dlx.effect.lnN(0.1)), ) # apply the composition composition(jnp.array([10, 20, 30])) # nest compositions - composition = compose( + composition = dlx.compose( composition, - modifier(name="sigma2", parameter=sigma, effect=lnN(0.2)), + dlx.modifier(name="sigma2", parameter=sigma, effect=dlx.effect.lnN(0.2)), ) # jit + import equinox as eqx + eqx.filter_jit(composition)(jnp.array([10, 20, 30])) """ @@ -389,10 +278,11 @@ def __init__(self, *modifiers: modifier) -> None: else: self.names.append(modifier.name) + def __check_init__(self): # check for duplicate names duplicates = [name for name in self.names if self.names.count(name) > 1] if duplicates: - msg = f"Modifier need to have unique names, got: {duplicates}" + msg = f"Modifiers need to have unique names, got: {duplicates}" raise ValueError(msg) def __len__(self) -> int: diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 4ce72cc..276e874 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -3,20 +3,12 @@ import jax.numpy as jnp import pytest -from dilax.parameter import ( - Parameter, - compose, - gauss, - lnN, - modifier, - poisson, - unconstrained, -) +import dilax as dlx from dilax.pdf import Flat, Gauss, Poisson def test_parameter(): - p = Parameter(value=jnp.array(1.0), bounds=(jnp.array(0.0), jnp.array(2.0))) + p = dlx.Parameter(value=jnp.array(1.0), bounds=(jnp.array(0.0), jnp.array(2.0))) assert p.value == 1.0 assert p.update(jnp.array(2.0)).value == 2.0 @@ -27,8 +19,8 @@ def test_parameter(): def test_unconstrained(): - p = Parameter(value=jnp.array(1.0)) - u = unconstrained() + p = dlx.Parameter(value=jnp.array(1.0)) + u = dlx.effect.unconstrained() assert u.constraint == Flat() assert u.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) @@ -38,8 +30,8 @@ def test_unconstrained(): def test_gauss(): - p = Parameter(value=jnp.array(0.0)) - g = gauss(width=jnp.array(1.0)) + p = dlx.Parameter(value=jnp.array(0.0)) + g = dlx.effect.gauss(width=jnp.array(1.0)) assert g.constraint == Gauss(mean=0.0, width=1.0) assert g.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) @@ -49,8 +41,8 @@ def test_gauss(): def test_lnN(): - p = Parameter(value=jnp.array(0.0)) - ln = lnN(width=jnp.array(0.1)) + p = dlx.Parameter(value=jnp.array(0.0)) + ln = dlx.effect.lnN(width=jnp.array(0.1)) assert ln.constraint == Gauss(mean=0.0, width=1.0) assert ln.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) @@ -58,8 +50,8 @@ def test_lnN(): def test_poisson(): - # p = Parameter(value=jnp.array(0.0)) - po = poisson(lamb=jnp.array(10)) + # p = dlx.Parameter(value=jnp.array(0.0)) + po = dlx.effect.poisson(lamb=jnp.array(10)) assert po.constraint == Poisson(lamb=jnp.array(10)) # assert po.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) # FIXME @@ -71,19 +63,25 @@ def test_shape(): def test_modifier(): - mu = Parameter(value=jnp.array(1.1)) - norm = Parameter(value=jnp.array(0.0)) + mu = dlx.Parameter(value=jnp.array(1.1)) + norm = dlx.Parameter(value=jnp.array(0.0)) # unconstrained effect - m_unconstrained = modifier(name="mu", parameter=mu, effect=unconstrained()) + m_unconstrained = dlx.modifier( + name="mu", parameter=mu, effect=dlx.effect.unconstrained() + ) assert m_unconstrained(jnp.array(10)) == pytest.approx(11) # gauss effect - m_gauss = modifier(name="norm", parameter=norm, effect=gauss(jnp.array(0.1))) + m_gauss = dlx.modifier( + name="norm", parameter=norm, effect=dlx.effect.gauss(jnp.array(0.1)) + ) assert m_gauss(jnp.array(10)) == pytest.approx(10) # lnN effect - m_lnN = modifier(name="norm", parameter=norm, effect=lnN(jnp.array(0.1))) + m_lnN = dlx.modifier( + name="norm", parameter=norm, effect=dlx.effect.lnN(jnp.array(0.1)) + ) assert m_lnN(jnp.array(10)) == pytest.approx(10) # poisson effect # FIXME @@ -97,16 +95,20 @@ def test_modifier(): def test_compose(): - mu = Parameter(value=jnp.array(1.1)) - norm = Parameter(value=jnp.array(0.0)) + mu = dlx.Parameter(value=jnp.array(1.1)) + norm = dlx.Parameter(value=jnp.array(0.0)) # unconstrained effect - m_unconstrained = modifier(name="mu", parameter=mu, effect=unconstrained()) + m_unconstrained = dlx.modifier( + name="mu", parameter=mu, effect=dlx.effect.unconstrained() + ) # gauss effect - m_gauss = modifier(name="norm", parameter=norm, effect=gauss(jnp.array(0.1))) + m_gauss = dlx.modifier( + name="norm", parameter=norm, effect=dlx.effect.gauss(jnp.array(0.1)) + ) # compose - m = compose(m_unconstrained, m_gauss) + m = dlx.compose(m_unconstrained, m_gauss) assert m.names == ["mu", "norm"] assert len(m) == 2