Skip to content

Commit

Permalink
use for every kind of PyTree
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Mar 15, 2024
1 parent 3299378 commit d7c4583
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 112 deletions.
104 changes: 8 additions & 96 deletions src/evermore/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from evermore.custom_types import SF, ModifierLike
from evermore.effect import DEFAULT_EFFECT
from evermore.parameter import Parameter
from evermore.util import tree_stack

if TYPE_CHECKING:
from evermore.effect import Effect
Expand Down Expand Up @@ -173,7 +174,7 @@ class where(ModifierBase):
hist = jnp.array([5, 20, 30])
syst = evm.Parameter(value=0.1)
norm = syst.lnN(jnp.array([0.9, 1.1]))
norm = syst.lnN(up=jnp.array([1.1]), down=jnp.array([0.9]))
shape = syst.shape(up=jnp.array([7, 22, 31]), down=jnp.array([4, 16, 27]))
# apply norm if hist < 10, else apply shape
Expand All @@ -190,7 +191,7 @@ class where(ModifierBase):
# -> Array([ 5.1593127, 20.281374 , 30.181376 ], dtype=float32)
"""

condition: Array = eqx.field(static=True)
condition: Array
modifier_true: Modifier
modifier_false: Modifier

Expand Down Expand Up @@ -231,7 +232,7 @@ class mask(ModifierBase):
# -> Array([ 5.049494, 20. , 30.296963], dtype=float32)
"""

where: Array = eqx.field(static=True)
where: Array
modifier: Modifier

def scale_factor(self, hist: Array) -> SF:
Expand Down Expand Up @@ -321,7 +322,7 @@ class compose(ModifierBase):
# nest compositions
composition = evm.modifier.compose(
composition,
evm.Modifier(parameter=sigma, effect=evm.effect.lnN(jnp.array([0.8, 1.2]))),
evm.Modifier(parameter=sigma, effect=evm.effect.lnN(up=jnp.array([1.2]), down=jnp.array([0.8]))),
)
# jit
Expand Down Expand Up @@ -355,31 +356,14 @@ def scale_factor(self, hist: Array) -> SF:
additive_sf = jnp.zeros_like(hist)

groups = defaultdict(list)
_sentinel = object()
# first group modifiers into same tree structures
for mod in self.modifiers:
key = jtu.tree_structure(mod) if isinstance(mod, Modifier) else _sentinel
# We have to handle the case where we have a `Modifier` instance, and the case where we have a
# other types of `ModifierBase` instances, e.g. `evm.modifier.where`, `evm.modifier.transform`, etc.
# basically: anything that is not a `Modifier` instance, but inherits from `ModifierBase`.
# It's unclear how to use them in `jax.lax.scan` constructs for now,
# so we just loop over them and calculate the scale factors with python for-loops.
# This is not ideal for compiletime, but it's a start. It is not expected that there
# are many of these in a typical composition.
groups[key].append(mod)

# first do the python for-loops
mods = groups.pop(_sentinel, [])
for mod in mods:
sf = mod.scale_factor(hist)
multiplicative_sf *= sf.multiplicative
additive_sf += sf.additive

groups[jtu.tree_structure(mod)].append(mod)
# then do the `jax.lax.scan` loops
for _, group_mods in groups.items():
# Filter stack for modifiers with same effect cls, and stack them in order to reduce compile time.
# Essentially we are turning an array of modifiers into a single modifier with a stack of scale factors and effect leaves (e.g. `width`).
# Then we can use XLA's loop constructs (e.g.: `jax.lax.scan`) to calculate the scale factors without having to compile the fully unrolled loop.
stack = modifier_stack(group_mods, broadcast_effect_leaves=True) # type: ignore[arg-type]
stack = tree_stack(group_mods, broadcast_leaves=True) # type: ignore[arg-type]
# scan over first axis of stack
dynamic_stack, static_stack = eqx.partition(stack, eqx.is_array)

Expand All @@ -397,75 +381,3 @@ def calc_sf(_hist, _dynamic_stack, _static_stack):
additive_sf += jnp.sum(sf.additive, axis=0)

return SF(multiplicative=multiplicative_sf, additive=additive_sf)


def modifier_stack(
modifiers: list[Modifier], broadcast_effect_leaves: bool = False
) -> Modifier:
"""
Turn an array of `evm.Modifier`(s) into a `evm.Modifier` of arrays.
Caution:
It is important that the `jax.Array`(s) of the underlying `evm.Parameter` have the same shape.
Same applies for the effect leaves (e.g. `width`). However, the effect leaves can be
broadcasted to the same shape if `broadcast_effect_leaves` is set to `True`.
Example:
.. code-block:: python
import evermore as evm
import jax
import jax.numpy as jnp
modifier = [
evm.Parameter().lnN(up=jnp.array([0.9, 0.95]), down=jnp.array([1.1, 1.14])),
evm.Parameter().lnN(up=jnp.array([0.8]), down=jnp.array([1.2])),
]
print(modifier_stack(modifier))
# -> Modifier(
# parameter=Parameter(
# value=f32[2,1], # <- stacked dimension (2, 1)
# lower=f32[1],
# upper=f32[1],
# constraint=Gauss(mean=f32[1], width=f32[1])
# ),
# effect=lnN(up=f32[2,1], down=f32[2,1]) # <- stacked dimension (2, 1)
# )
"""
# If there is only one modifier, we can return it directly
if len(modifiers) == 1:
return modifiers[0]
param_leaves_list = []
effect_leaves_list = []
for modifier in modifiers:
# parameter
param_leaves_list.extend(jtu.tree_leaves(modifier.parameter))
# effect
effect_leaves_list.extend(jtu.tree_leaves(modifier.effect))

stacked_leaves = []

# Parameter:
# Here we are dealing with the evm.Parameter value
# We can not broadcast its leaves as that would change
# the meaning/correlation of the parameter
stacked_leaves.append(jnp.stack(param_leaves_list))

# Effect:
# Here, we are dealing with the effect of the modifier.
# We can broadcast the leaves as they are independent if `broadcast_effect_leaves=True`.
# Effects may have multiple leaves, e.g. `lnN` has `up` and `down`,
# thus we need to stack them separately, so we loop over the leaf groups
grouped_effect_leaves = zip(*effect_leaves_list, strict=False)
for leaves in grouped_effect_leaves:
if broadcast_effect_leaves:
shape = jnp.broadcast_shapes(*[leaf.shape for leaf in leaves])
stacked_leaves.append(
jnp.stack(jtu.tree_map(partial(jnp.broadcast_to, shape=shape), leaves))
)
else:
stacked_leaves.append(jnp.stack(leaves))

# reconstruct the modifier
modifier_structure = jtu.tree_structure(modifiers[0])
return jtu.tree_unflatten(modifier_structure, stacked_leaves)
4 changes: 2 additions & 2 deletions src/evermore/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __dir__():

class Parameter(eqx.Module):
value: Array = eqx.field(converter=as1darray)
lower: Array = eqx.field(static=True, converter=as1darray)
upper: Array = eqx.field(static=True, converter=as1darray)
lower: Array = eqx.field(converter=as1darray)
upper: Array = eqx.field(converter=as1darray)
constraint: PDF | Sentinel = eqx.field(static=True)

def __init__(
Expand Down
6 changes: 3 additions & 3 deletions src/evermore/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def sample(self, key: PRNGKeyArray) -> Array:


class Gauss(PDF):
mean: Array = eqx.field(static=True)
width: Array = eqx.field(static=True)
mean: Array
width: Array

def logpdf(self, x: Array) -> Array:
logpdf_max = jax.scipy.stats.norm.logpdf(
Expand All @@ -83,7 +83,7 @@ def sample(self, key: PRNGKeyArray) -> Array:


class Poisson(PDF):
lamb: Array = eqx.field(static=True)
lamb: Array

def logpdf(self, x: Array) -> Array:
logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb)
Expand Down
55 changes: 55 additions & 0 deletions src/evermore/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
__all__ = [
"is_parameter",
"sum_leaves",
"tree_stack",
"as1darray",
"dump_hlo_graph",
"dump_jaxpr",
Expand Down Expand Up @@ -50,6 +51,60 @@ def sum_leaves(tree: PyTree) -> Array:
return jtu.tree_reduce(operator.add, tree)


def tree_stack(trees: list[PyTree], broadcast_leaves: bool = False) -> PyTree:
"""
Turn an array of `evm.Modifier`(s) into a `evm.Modifier` of arrays.
Caution:
It is important that the `jax.Array`(s) of the underlying `evm.Parameter` have the same shape.
Same applies for the effect leaves (e.g. `width`). However, the effect leaves can be
broadcasted to the same shape if `broadcast_effect_leaves` is set to `True`.
Example:
.. code-block:: python
import evermore as evm
import jax
import jax.numpy as jnp
modifier = [
evm.Parameter().lnN(up=jnp.array([0.9, 0.95]), down=jnp.array([1.1, 1.14])),
evm.Parameter().lnN(up=jnp.array([0.8]), down=jnp.array([1.2])),
]
print(modifier_stack2(modifier))
# -> Modifier(
# parameter=Parameter(
# value=f32[2,1], # <- stacked dimension (2, 1)
# lower=f32[1],
# upper=f32[1],
# constraint=Gauss(mean=f32[1], width=f32[1])
# ),
# effect=lnN(up=f32[2,1], down=f32[2,1]) # <- stacked dimension (2, 1)
# )
"""
# If there is only one modifier, we can return it directly
if len(trees) == 1:
return trees[0]
leaves_list = []
treedef_list = []
for tree in trees:
leaves, treedef = jtu.tree_flatten(tree)
leaves_list.append(leaves)
treedef_list.append(treedef)

grouped_leaves = zip(*leaves_list, strict=False)
stacked_leaves = []
for leaves in grouped_leaves: # type: ignore[assignment]
if broadcast_leaves:
shape = jnp.broadcast_shapes(*[leaf.shape for leaf in leaves])
stacked_leaves.append(
jnp.stack(jtu.tree_map(partial(jnp.broadcast_to, shape=shape), leaves))
)
else:
stacked_leaves.append(jnp.stack(leaves))
return jtu.tree_unflatten(treedef_list[0], stacked_leaves)


def as1darray(x: ArrayLike) -> Array:
"""
Converts `x` to a 1d array.
Expand Down
6 changes: 3 additions & 3 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_gauss():
p = evm.Parameter(value=jnp.array(0.0))
g = evm.effect.gauss(width=jnp.array(1.0))

assert g.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0))
assert isinstance(g.constraint(p), Gauss)
assert g.scale_factor(p, jnp.array([1.0])) == SF(
multiplicative=jnp.array([1.0]), additive=jnp.array([0.0])
)
Expand All @@ -41,7 +41,7 @@ def test_lnN():
p = evm.Parameter(value=jnp.array(0.0))
ln = evm.effect.lnN(up=jnp.array([1.1]), down=jnp.array([0.9]))

assert ln.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0))
assert isinstance(ln.constraint(p), Gauss)
assert ln.scale_factor(p, jnp.array([1.0])) == SF(
multiplicative=jnp.array([1.0]), additive=jnp.array([0.0])
)
Expand All @@ -51,7 +51,7 @@ def test_poisson():
p = evm.Parameter(value=jnp.array(0.0))
po = evm.effect.poisson(lamb=jnp.array(10))

assert po.constraint(p) == Poisson(lamb=jnp.array(10))
assert isinstance(po.constraint(p), Poisson)
# assert po.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) # FIXME
# assert po.scale_factor(p.update(jnp.array(2.0)), jnp.array(1.0)) == pytest.approx(1.1) # FIXME

Expand Down
8 changes: 0 additions & 8 deletions tests/test_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,3 @@ def test_poisson():
assert pdf.pdf(jnp.array(0)) == pytest.approx(0.12510978)
assert pdf.logpdf(jnp.array(-0.5)) == pytest.approx(-1.196003)
assert pdf.cdf(jnp.array(0)) == pytest.approx(0.5830412)


def test_hashable():
assert hash(Flat()) == hash(Flat())
assert hash(Gauss(mean=jnp.array(0.0), width=jnp.array(1.0))) == hash(
Gauss(mean=jnp.array(0.0), width=jnp.array(1.0))
)
assert hash(Poisson(lamb=jnp.array(10))) == hash(Poisson(lamb=jnp.array(10)))

0 comments on commit d7c4583

Please sign in to comment.