diff --git a/pdequinox/_utils.py b/pdequinox/_utils.py index 421167b..d5c43a0 100644 --- a/pdequinox/_utils.py +++ b/pdequinox/_utils.py @@ -186,7 +186,30 @@ def cycling_dataloader( epoch_id += 1 -def extract_from_ensemble(ensemble, i): +def extract_from_ensemble(ensemble: eqx.Module, i: int): + """ + Given an ensemble of equinox Modules, extract its i-th element. + + If you create an ensemble, e.g., with + + ```python + + import equinox as eqx + + ensemble = eqx.filter_vmap( + lambda k: eqx.nn.Conv1d(1, 1, 3) + )(jax.random.split(jax.random.PRNGKey(0), 5) ``` + + its weight arrays have an additional batch/ensemble axis. It cannot be used + natively on its corresponding data. This function extracts the i-th element + of the ensemble. + + **Arguments:** + + - `ensemble`: eqx.Module. The ensemble of networks. + - `i`: int. The index of the network to be extracted. This can also be a + slice! + """ params, static = eqx.partition(ensemble, eqx.is_array) params_extracted = jtu.tree_map(lambda x: x[i], params) network_extracted = eqx.combine(params_extracted, static)