From 0a3113addb4fbbd76950005e3bb2316eed58faeb Mon Sep 17 00:00:00 2001 From: Michael O'Brien Date: Mon, 16 Sep 2024 10:42:25 -0400 Subject: [PATCH] bug in more things getting batch dimension --- src/cryojax/data/_relion/_starfile_reading.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/cryojax/data/_relion/_starfile_reading.py b/src/cryojax/data/_relion/_starfile_reading.py index 5223c7c3..2b61e8a1 100644 --- a/src/cryojax/data/_relion/_starfile_reading.py +++ b/src/cryojax/data/_relion/_starfile_reading.py @@ -5,6 +5,7 @@ from typing import Any, Callable, final, Optional import equinox as eqx +import equinox.internal as eqxi import jax import jax.numpy as jnp import mrcfile @@ -266,7 +267,11 @@ def _get_starfile_params( phase_shift, ) ctf = ( - eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(*ctf_params) + eqx.filter_vmap( + make_ctf, + in_axes=(0, 0, 0, None, None, None, 0), + out_axes=eqxi.if_mapped(0), + )(*ctf_params) if defocus_in_angstroms.ndim == 1 else make_ctf(*ctf_params) )