diff --git a/src/evermore/parameter.py b/src/evermore/parameter.py index 7577a63..f52ab92 100644 --- a/src/evermore/parameter.py +++ b/src/evermore/parameter.py @@ -53,6 +53,7 @@ class Parameter(eqx.Module): frozen_parameter = evm.Parameter(value=1.0, frozen=True) """ + name: str | None = eqx.field(static=True, default=None) value: Array = eqx.field(converter=jnp.atleast_1d, default=0.0) lower: Array = eqx.field(converter=jnp.atleast_1d, default=-jnp.inf) upper: Array = eqx.field(converter=jnp.atleast_1d, default=jnp.inf) @@ -190,11 +191,12 @@ def _sample(param: Parameter, key: Parameter) -> Array: # TODO: make this compatible with externally provided Poisson PDFs if isinstance(pdf, Poisson): sampled_value = (sampled_value / pdf.lamb) - 1 - elif param.prior is None: - if not jnp.isfinite(param.lower) and jnp.isfinite(param.upper): + else: + assert param.prior is None, f"Unknown prior type: {param.prior}." + if not jnp.isfinite(param.lower) and not jnp.isfinite(param.upper): msg = f"Can't sample uniform from {param} (no given prior), because of non-finite bounds. " - msg += "Please provide a prior or make the bounds finite." - raise RuntimeError(msg) + msg += "Please provide finite bounds." + raise ValueError(msg) sampled_value = jax.random.uniform( key.value, shape=param.value.shape, diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index c771d33..595b94f 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -27,8 +27,8 @@ def sample(self, key: PRNGKeyArray) -> Array: ... class Normal(PDF): - mean: Array = eqx.field(converter=jnp.asarray) - width: Array = eqx.field(converter=jnp.asarray) + mean: Array = eqx.field(converter=jnp.atleast_1d) + width: Array = eqx.field(converter=jnp.atleast_1d) def log_prob(self, x: Array) -> Array: logpdf_max = jax.scipy.stats.norm.logpdf( @@ -43,7 +43,7 @@ def sample(self, key: PRNGKeyArray) -> Array: class Poisson(PDF): - lamb: Array + lamb: Array = eqx.field(converter=jnp.atleast_1d) def log_prob(self, x: Array) -> Array: logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb) diff --git a/src/evermore/util.py b/src/evermore/util.py index b8efaf4..e413f70 100644 --- a/src/evermore/util.py +++ b/src/evermore/util.py @@ -44,12 +44,14 @@ def sum_over_leaves(tree: PyTree) -> Array: def tree_stack(trees: list[PyTree], broadcast_leaves: bool = False) -> PyTree: """ - Turn an array of evm.Modifier(s) into a evm.Modifier of arrays. + Turns e.g. an array of evm.Modifier(s) into a evm.Modifier of arrays. It is important that the jax.Array(s) of the underlying Arrays have the same shape. Same applies for the effect leaves (e.g. width). However, the effect leaves can be broadcasted to the same shape if broadcast_effect_leaves is set to True. + The stacked PyTree will have the static nodes of the first PyTree in the list. + Example: .. code-block:: python @@ -65,10 +67,11 @@ def tree_stack(trees: list[PyTree], broadcast_leaves: bool = False) -> PyTree: print(evm.util.tree_stack(modifiers)) # -> Modifier( # parameter=NormalParameter( + # name=None, # value=f32[2,1], # <- stacked dimension (2, 1) # lower=f32[2,1], # <- stacked dimension (2, 1) # upper=f32[2,1], # <- stacked dimension (2, 1) - # prior=Normal(mean=f32[2], width=f32[2]), # <- stacked dimension (2,) + # prior=Normal(mean=f32[2,1], width=f32[2,1]), # <- stacked dimension (2,1) # frozen=False # ), # effect=AsymmetricExponential(up=f32[2,1], down=f32[2,1]) # <- stacked dimension (2, 1)