diff --git a/src/cryojax/data/_relion.py b/src/cryojax/data/_relion.py index 35da8a88..d220cab0 100644 --- a/src/cryojax/data/_relion.py +++ b/src/cryojax/data/_relion.py @@ -246,14 +246,16 @@ def _get_starfile_params( jnp.asarray(voltage_in_kilovolts, device=device), ) # ... now the ContrastTransferTheory - make_ctf = lambda params: ContrastTransferFunction( - defocus_in_angstroms=params[0], - astigmatism_in_angstroms=params[1], - astigmatism_angle=params[2], - voltage_in_kilovolts=params[3], - spherical_aberration_in_mm=params[4], - amplitude_contrast_ratio=params[5], - phase_shift=params[6], + make_ctf = ( + lambda defocus, astig, angle, voltage, sph, ac, ps: ContrastTransferFunction( + defocus_in_angstroms=defocus, + astigmatism_in_angstroms=astig, + astigmatism_angle=angle, + voltage_in_kilovolts=voltage, + spherical_aberration_in_mm=sph, + amplitude_contrast_ratio=ac, + phase_shift=ps, + ) ) ctf_params = ( defocus_in_angstroms, @@ -265,9 +267,9 @@ def _get_starfile_params( phase_shift, ) ctf = ( - eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(ctf_params) + eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(*ctf_params) if defocus_in_angstroms.ndim == 1 - else make_ctf(ctf_params) + else make_ctf(*ctf_params) ) if self.get_envelope_function: b_factor, scale_factor = ( @@ -282,14 +284,12 @@ def _get_starfile_params( else jnp.asarray(1.0) ), ) - make_envelope = lambda params: FourierGaussian( - b_factor=params[0], amplitude=params[1] - ) + make_envelope = lambda b, amp: FourierGaussian(b_factor=b, amplitude=amp) envelope_params = (b_factor, scale_factor) envelope = ( - eqx.filter_vmap(make_envelope, in_axes=(0, 0))(envelope_params) + eqx.filter_vmap(make_envelope, in_axes=(0, 0))(*envelope_params) if b_factor.ndim == 1 - else make_envelope(envelope_params) + else make_envelope(*envelope_params) ) else: envelope = None