Skip to content

Commit

Permalink
Add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Apr 15, 2024
1 parent 9296859 commit 5f108e6
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion pdequinox/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,30 @@ def cycling_dataloader(
epoch_id += 1


def extract_from_ensemble(ensemble, i):
def extract_from_ensemble(ensemble: eqx.Module, i: int):
"""
Given an ensemble of equinox Modules, extract its i-th element.
If you create an ensemble, e.g., with
```python
import equinox as eqx
ensemble = eqx.filter_vmap(
lambda k: eqx.nn.Conv1d(1, 1, 3)
)(jax.random.split(jax.random.PRNGKey(0), 5) ```
its weight arrays have an additional batch/ensemble axis. It cannot be used
natively on its corresponding data. This function extracts the i-th element
of the ensemble.
**Arguments:**
- `ensemble`: eqx.Module. The ensemble of networks.
- `i`: int. The index of the network to be extracted. This can also be a
slice!
"""
params, static = eqx.partition(ensemble, eqx.is_array)
params_extracted = jtu.tree_map(lambda x: x[i], params)
network_extracted = eqx.combine(params_extracted, static)
Expand Down

0 comments on commit 5f108e6

Please sign in to comment.