Skip to content

Commit

Permalink
minor improvements, added name attribute to parameters, updated doc s…
Browse files Browse the repository at this point in the history
…trings
  • Loading branch information
pfackeldey committed Apr 25, 2024
1 parent f891203 commit a19f077
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
10 changes: 6 additions & 4 deletions src/evermore/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Parameter(eqx.Module):
frozen_parameter = evm.Parameter(value=1.0, frozen=True)
"""

name: str | None = eqx.field(static=True, default=None)
value: Array = eqx.field(converter=jnp.atleast_1d, default=0.0)
lower: Array = eqx.field(converter=jnp.atleast_1d, default=-jnp.inf)
upper: Array = eqx.field(converter=jnp.atleast_1d, default=jnp.inf)
Expand Down Expand Up @@ -190,11 +191,12 @@ def _sample(param: Parameter, key: Parameter) -> Array:
# TODO: make this compatible with externally provided Poisson PDFs
if isinstance(pdf, Poisson):
sampled_value = (sampled_value / pdf.lamb) - 1
elif param.prior is None:
if not jnp.isfinite(param.lower) and jnp.isfinite(param.upper):
else:
assert param.prior is None, f"Unknown prior type: {param.prior}."
if not jnp.isfinite(param.lower) and not jnp.isfinite(param.upper):
msg = f"Can't sample uniform from {param} (no given prior), because of non-finite bounds. "
msg += "Please provide a prior or make the bounds finite."
raise RuntimeError(msg)
msg += "Please provide finite bounds."
raise ValueError(msg)
sampled_value = jax.random.uniform(
key.value,
shape=param.value.shape,
Expand Down
6 changes: 3 additions & 3 deletions src/evermore/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def sample(self, key: PRNGKeyArray) -> Array: ...


class Normal(PDF):
mean: Array = eqx.field(converter=jnp.asarray)
width: Array = eqx.field(converter=jnp.asarray)
mean: Array = eqx.field(converter=jnp.atleast_1d)
width: Array = eqx.field(converter=jnp.atleast_1d)

def log_prob(self, x: Array) -> Array:
logpdf_max = jax.scipy.stats.norm.logpdf(
Expand All @@ -43,7 +43,7 @@ def sample(self, key: PRNGKeyArray) -> Array:


class Poisson(PDF):
lamb: Array
lamb: Array = eqx.field(converter=jnp.atleast_1d)

def log_prob(self, x: Array) -> Array:
logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb)
Expand Down
7 changes: 5 additions & 2 deletions src/evermore/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ def sum_over_leaves(tree: PyTree) -> Array:

def tree_stack(trees: list[PyTree], broadcast_leaves: bool = False) -> PyTree:
"""
Turn an array of evm.Modifier(s) into a evm.Modifier of arrays.
Turns e.g. an array of evm.Modifier(s) into a evm.Modifier of arrays.
It is important that the jax.Array(s) of the underlying Arrays have the same shape.
Same applies for the effect leaves (e.g. width). However, the effect leaves can be
broadcasted to the same shape if broadcast_effect_leaves is set to True.
The stacked PyTree will have the static nodes of the first PyTree in the list.
Example:
.. code-block:: python
Expand All @@ -65,10 +67,11 @@ def tree_stack(trees: list[PyTree], broadcast_leaves: bool = False) -> PyTree:
print(evm.util.tree_stack(modifiers))
# -> Modifier(
# parameter=NormalParameter(
# name=None,
# value=f32[2,1], # <- stacked dimension (2, 1)
# lower=f32[2,1], # <- stacked dimension (2, 1)
# upper=f32[2,1], # <- stacked dimension (2, 1)
# prior=Normal(mean=f32[2], width=f32[2]), # <- stacked dimension (2,)
# prior=Normal(mean=f32[2,1], width=f32[2,1]), # <- stacked dimension (2,1)
# frozen=False
# ),
# effect=AsymmetricExponential(up=f32[2,1], down=f32[2,1]) # <- stacked dimension (2, 1)
Expand Down

0 comments on commit a19f077

Please sign in to comment.