Skip to content

Commit

Permalink
bug in vmap
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Sep 11, 2024
1 parent 8f686b3 commit 8aaaa9c
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/cryojax/data/_relion.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,16 @@ def _get_starfile_params(
jnp.asarray(voltage_in_kilovolts, device=device),
)
# ... now the ContrastTransferTheory
make_ctf = lambda params: ContrastTransferFunction(
defocus_in_angstroms=params[0],
astigmatism_in_angstroms=params[1],
astigmatism_angle=params[2],
voltage_in_kilovolts=params[3],
spherical_aberration_in_mm=params[4],
amplitude_contrast_ratio=params[5],
phase_shift=params[6],
make_ctf = (
lambda defocus, astig, angle, voltage, sph, ac, ps: ContrastTransferFunction(
defocus_in_angstroms=defocus,
astigmatism_in_angstroms=astig,
astigmatism_angle=angle,
voltage_in_kilovolts=voltage,
spherical_aberration_in_mm=sph,
amplitude_contrast_ratio=ac,
phase_shift=ps,
)
)
ctf_params = (
defocus_in_angstroms,
Expand All @@ -265,9 +267,9 @@ def _get_starfile_params(
phase_shift,
)
ctf = (
eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(ctf_params)
eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(*ctf_params)
if defocus_in_angstroms.ndim == 1
else make_ctf(ctf_params)
else make_ctf(*ctf_params)
)
if self.get_envelope_function:
b_factor, scale_factor = (
Expand All @@ -282,14 +284,12 @@ def _get_starfile_params(
else jnp.asarray(1.0)
),
)
make_envelope = lambda params: FourierGaussian(
b_factor=params[0], amplitude=params[1]
)
make_envelope = lambda b, amp: FourierGaussian(b_factor=b, amplitude=amp)
envelope_params = (b_factor, scale_factor)
envelope = (
eqx.filter_vmap(make_envelope, in_axes=(0, 0))(envelope_params)
eqx.filter_vmap(make_envelope, in_axes=(0, 0))(*envelope_params)
if b_factor.ndim == 1
else make_envelope(envelope_params)
else make_envelope(*envelope_params)
)
else:
envelope = None
Expand Down

0 comments on commit 8aaaa9c

Please sign in to comment.