Commit

update docs
pfackeldey committed Apr 26, 2024
1 parent 86af5f3 commit b328651
Showing 3 changed files with 147 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/building_blocks.md
@@ -1,3 +1,4 @@
(building-blocks)=
# Building Blocks

## Parameter
40 changes: 40 additions & 0 deletions docs/getting_started.md
@@ -31,3 +31,43 @@ These building blocks include:
varied.
- **evm.Modifier**: Modifiers combine **evm.Effects** and **evm.Parameters** to
  modify data (see the short sketch after this list).

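For instance, a minimal sketch of how these pieces fit together could look as follows; the concrete parameter values and the `scale`/`scale_log` effects are illustrative choices here, not prescribed:

```{code-block} python
import jax.numpy as jnp
import evermore as evm

hist = jnp.array([10.0, 20.0, 30.0])

# a freely floating normalization and a constrained (normal prior) systematic
mu = evm.Parameter(value=1.1)
syst = evm.NormalParameter(value=0.0)

# parameters + effects give modifiers ...
mu_mod = mu.scale()
syst_mod = syst.scale_log(up=1.1, down=0.9)

# ... and modifiers modify data
print(mu_mod(hist))
print(syst_mod(hist))
```
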
The negative log-likelihood (NLL) function of Eq.{eq}`likelihood` can be implemented with evermore as follows:

```{code-block} python
import jax.numpy as jnp
from jaxtyping import PyTree, Array
import equinox as eqx
import evermore as evm

# -- parameter definition --
# params: PyTree[evm.Parameter] = ...
# dynamic_params, static_params = evm.parameter.partition(params)

# -- model definition --
# def model(params: PyTree[evm.Parameter], hists: PyTree[Array]) -> PyTree[Array]:
#     ...

# -- NLL definition --
@eqx.filter_jit
def NLL(dynamic_params, static_params, hists, observation):
    params = eqx.combine(dynamic_params, static_params)
    expectations = model(params, hists)
    log_likelihood = evm.loss.PoissonLogLikelihood()
    # first product of Eq. 1 (Poisson term)
    loss_val = log_likelihood(
        expectation=evm.util.sum_over_leaves(expectations),
        observation=observation,
    )
    # second product of Eq. 1 (constraint terms of the parameters)
    constraints = evm.loss.get_log_probs(params)
    loss_val += evm.util.sum_over_leaves(constraints)
    return -jnp.sum(loss_val)
```

Building the parameters and the model is the key step here. How to construct both is described in <project:#building-blocks>.
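
For orientation, one hypothetical way to fill in the commented-out parameter and model definitions above is sketched below; the names `mu` and `syst`, the single-channel histograms, and the `scale_log` effect are assumptions of this sketch:

```{code-block} python
import jax.numpy as jnp
from jaxtyping import PyTree, Array
import evermore as evm

# hypothetical parameters: a signal strength and one systematic
params: PyTree[evm.Parameter] = {
    "mu": evm.Parameter(value=1.0),
    "syst": evm.NormalParameter(value=0.0),
}
dynamic_params, static_params = evm.parameter.partition(params)

# hypothetical histograms of a single-channel analysis
hists: PyTree[Array] = {
    "signal": jnp.array([3.0]),
    "background": jnp.array([10.0]),
}

def model(params: PyTree[evm.Parameter], hists: PyTree[Array]) -> PyTree[Array]:
    # scale the signal by the (unconstrained) signal strength
    mu_mod = params["mu"].scale()
    # scale the background by a log-normal systematic
    syst_mod = params["syst"].scale_log(up=1.1, down=0.9)
    return {
        "signal": mu_mod(hists["signal"]),
        "background": syst_mod(hists["background"]),
    }
```
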
106 changes: 106 additions & 0 deletions docs/tips_and_tricks.md
@@ -1 +1,107 @@
# Tips and tricks

Here are some advanced tips and tricks.


## Parameter Partitioning

For optimization it is necessary to differentiate only with respect to the meaningful leaves of the PyTree of `evm.Parameters`.
By default, JAX would differentiate w.r.t. every non-static leaf of an `evm.Parameter`, including the prior PDF and its bounds.
Gradients are typically only wanted w.r.t. the `.value` attribute of the `evm.Parameters`. This is solved by splitting
the PyTree of `evm.Parameters` into a differentiable and a non-differentiable part, and then defining the loss function
in terms of both parts. The gradient is then calculated only w.r.t. the differentiable argument, see:

```{code-block} python
from jaxtyping import Array, PyTree
import equinox as eqx
import evermore as evm

# define a PyTree of parameters
params = {
    "a": evm.Parameter(),
    "b": evm.Parameter(),
}

# split the PyTree into the diffable and the static parts
filter_spec = evm.parameter.value_filter_spec(params)
diffable, static = eqx.partition(params, filter_spec)
# or:
# diffable, static = evm.parameter.partition(params)

# the loss's first argument is only the diffable part of the parameter PyTree!
def loss(diffable: PyTree[evm.Parameter], static: PyTree[evm.Parameter], hists: PyTree[Array]) -> Array:
    # combine the diffable and static parts of the parameter PyTree
    parameters = eqx.combine(diffable, static)
    # `parameters` now matches the original `params` PyTree
    # use the parameters to calculate the loss as usual
    ...

grad_loss = eqx.filter_grad(loss)(diffable, static, ...)
```

If you need to exclude further parameters from being optimized, you can either set `frozen=True` on the parameter or flip the corresponding leaf in `filter_spec` from `True` to `False`.
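
A short sketch of both options follows; the use of `eqx.tree_at` to flip a leaf is an assumption of this sketch, any other way of editing the `filter_spec` PyTree works just as well:

```{code-block} python
import equinox as eqx
import evermore as evm

params = {
    "a": evm.Parameter(),
    "b": evm.Parameter(frozen=True),  # option 1: freeze the parameter itself
    "c": evm.Parameter(),
}

filter_spec = evm.parameter.value_filter_spec(params)
# option 2: flip the corresponding `.value` leaf in the filter spec to False,
# here done with eqx.tree_at (an assumption of this sketch)
filter_spec = eqx.tree_at(lambda spec: spec["c"].value, filter_spec, replace=False)

diffable, static = eqx.partition(params, filter_spec)
# leaves excluded this way end up in `static` and receive no gradients
```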


## JAX Transformations

Every component of evermore is compatible with JAX transformations. That means you can `jax.jit`, `jax.vmap`, ... _everything_.
You can, for example, draw many vectorized samples of the parameter values from their prior PDFs:

```{code-block} python
import jax
import evermore as evm

params = {"a": evm.NormalParameter(), "b": evm.NormalParameter()}

# draw 100 vectorized samples from the priors of all parameters
rng_key = jax.random.key(0)
rng_keys = jax.random.split(rng_key, 100)
vec_sample = jax.vmap(evm.parameter.sample, in_axes=(None, 0))
print(vec_sample(params, rng_keys))
# {'a': NormalParameter(
# value=f32[100,1],
# name=None,
# lower=f32[100,1],
# upper=f32[100,1],
# prior=Normal(mean=f32[100,1], width=f32[100,1]),
# frozen=False,
# ),
# 'b': NormalParameter(
# value=f32[100,1],
# name=None,
# lower=f32[100,1],
# upper=f32[100,1],
# prior=Normal(mean=f32[100,1], width=f32[100,1]),
# frozen=False,
# )}
```

Many minimizers from the JAX ecosystem (e.g. `optax`, `optimistix`) are batchable, which allows you to vectorize _full fits_, e.g., for embarrassingly parallel likelihood profiles.
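
As a rough illustration, a minimal sketch of vectorizing many full fits with `optax` could look as follows; the one-parameter model, the fixed number of SGD steps, and the Poisson toy observations are assumptions of this sketch, not part of the evermore API:

```{code-block} python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import evermore as evm

hist = jnp.array([10.0, 20.0, 30.0])
params = {"mu": evm.Parameter(value=1.0)}
diffable, static = evm.parameter.partition(params)

def loss(diffable, static, observation):
    p = eqx.combine(diffable, static)
    expectation = p["mu"].scale()(hist)
    nll = evm.loss.PoissonLogLikelihood()
    return -jnp.sum(nll(expectation=expectation, observation=observation))

optim = optax.sgd(learning_rate=1e-2)

def fit(observation):
    # run a fixed number of gradient steps for one (pseudo-)dataset
    opt_state = optim.init(eqx.filter(diffable, eqx.is_inexact_array))

    def step(carry, _):
        dyn, opt_state = carry
        grads = eqx.filter_grad(loss)(dyn, static, observation)
        updates, opt_state = optim.update(grads, opt_state)
        return (eqx.apply_updates(dyn, updates), opt_state), None

    (dyn, _), _ = jax.lax.scan(step, (diffable, opt_state), None, length=100)
    return eqx.combine(dyn, static)["mu"].value

# vectorize the full fit over many Poisson toy datasets
key = jax.random.key(0)
observations = jax.random.poisson(key, hist, shape=(50, 3)).astype(jnp.float32)
best_fit_mu = jax.vmap(fit)(observations)
```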

## Visualize the Computational Graph

You can visualize the computational graph of a JAX computation as follows:

```{code-block} python
import pathlib
import jax.numpy as jnp
import equinox as eqx
import evermore as evm
param = evm.Parameter(value=1.1)
# create the modifier and JIT it
modify = eqx.filter_jit(param.scale())
# apply the modifier
hist = jnp.array([10, 20, 30])
modify(hist)
# -> Array([11., 22., 33.], dtype=float32, weak_type=True)
# visualize the graph:
filepath = pathlib.Path('graph.gv')
filepath.write_text(evm.util.dump_hlo_graph(modify, hist), encoding='ascii')
```
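
The resulting `graph.gv` file can then be rendered with a Graphviz-compatible viewer (e.g. the `dot` command line tool), assuming Graphviz is installed.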
