From 8f686b35f993173a10d78d00206612597d97ba71 Mon Sep 17 00:00:00 2001 From: Michael O'Brien Date: Wed, 11 Sep 2024 17:47:45 -0400 Subject: [PATCH 1/2] fix typechecking error --- src/cryojax/data/_relion.py | 50 +++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/src/cryojax/data/_relion.py b/src/cryojax/data/_relion.py index 86b38eaa..35da8a88 100644 --- a/src/cryojax/data/_relion.py +++ b/src/cryojax/data/_relion.py @@ -219,8 +219,10 @@ def __len__(self) -> int: def _get_starfile_params( self, particle_blocks, optics_group, device ) -> tuple[InstrumentConfig, ContrastTransferTheory, EulerAnglePose]: - defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"], device=device) - + defocus_in_angstroms = ( + jnp.asarray(particle_blocks["rlnDefocusU"], device=device) + + jnp.asarray(particle_blocks["rlnDefocusU"], device=device) + ) / 2 astigmatism_in_angstroms = jnp.asarray( particle_blocks["rlnDefocusV"], device=device ) - jnp.asarray(particle_blocks["rlnDefocusU"], device=device) @@ -244,29 +246,51 @@ def _get_starfile_params( jnp.asarray(voltage_in_kilovolts, device=device), ) # ... now the ContrastTransferTheory - ctf = ContrastTransferFunction( - defocus_in_angstroms=defocus_in_angstroms, - astigmatism_in_angstroms=astigmatism_in_angstroms, - astigmatism_angle=astigmatism_angle, - voltage_in_kilovolts=voltage_in_kilovolts, - spherical_aberration_in_mm=spherical_aberration_in_mm, - amplitude_contrast_ratio=amplitude_contrast_ratio, - phase_shift=phase_shift, + make_ctf = lambda params: ContrastTransferFunction( + defocus_in_angstroms=params[0], + astigmatism_in_angstroms=params[1], + astigmatism_angle=params[2], + voltage_in_kilovolts=params[3], + spherical_aberration_in_mm=params[4], + amplitude_contrast_ratio=params[5], + phase_shift=params[6], + ) + ctf_params = ( + defocus_in_angstroms, + astigmatism_in_angstroms, + astigmatism_angle, + voltage_in_kilovolts, + spherical_aberration_in_mm, + amplitude_contrast_ratio, + phase_shift, + ) + ctf = ( + eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(ctf_params) + if defocus_in_angstroms.ndim == 1 + else make_ctf(ctf_params) ) if self.get_envelope_function: b_factor, scale_factor = ( ( jnp.asarray(particle_blocks["rlnCtfBfactor"], device=device) if "rlnCtfBfactor" in particle_blocks.keys() - else 0.0 + else jnp.asarray(0.0) ), ( jnp.asarray(particle_blocks["rlnCtfScalefactor"], device=device) if "rlnCtfScalefactor" in particle_blocks.keys() - else 1.0 + else jnp.asarray(1.0) ), ) - envelope = FourierGaussian(b_factor=b_factor, amplitude=scale_factor) + make_envelope = lambda params: FourierGaussian( + b_factor=params[0], amplitude=params[1] + ) + envelope_params = (b_factor, scale_factor) + envelope = ( + eqx.filter_vmap(make_envelope, in_axes=(0, 0))(envelope_params) + if b_factor.ndim == 1 + else make_envelope(envelope_params) + ) else: envelope = None transfer_theory = ContrastTransferTheory(ctf, envelope) From 8aaaa9c9cb7a1eb4411480cc2ec1a3f309f02e6d Mon Sep 17 00:00:00 2001 From: Michael O'Brien Date: Wed, 11 Sep 2024 18:04:39 -0400 Subject: [PATCH 2/2] bug in vmap --- src/cryojax/data/_relion.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/cryojax/data/_relion.py b/src/cryojax/data/_relion.py index 35da8a88..d220cab0 100644 --- a/src/cryojax/data/_relion.py +++ b/src/cryojax/data/_relion.py @@ -246,14 +246,16 @@ def _get_starfile_params( jnp.asarray(voltage_in_kilovolts, device=device), ) # ... now the ContrastTransferTheory - make_ctf = lambda params: ContrastTransferFunction( - defocus_in_angstroms=params[0], - astigmatism_in_angstroms=params[1], - astigmatism_angle=params[2], - voltage_in_kilovolts=params[3], - spherical_aberration_in_mm=params[4], - amplitude_contrast_ratio=params[5], - phase_shift=params[6], + make_ctf = ( + lambda defocus, astig, angle, voltage, sph, ac, ps: ContrastTransferFunction( + defocus_in_angstroms=defocus, + astigmatism_in_angstroms=astig, + astigmatism_angle=angle, + voltage_in_kilovolts=voltage, + spherical_aberration_in_mm=sph, + amplitude_contrast_ratio=ac, + phase_shift=ps, + ) ) ctf_params = ( defocus_in_angstroms, @@ -265,9 +267,9 @@ def _get_starfile_params( phase_shift, ) ctf = ( - eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(ctf_params) + eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(*ctf_params) if defocus_in_angstroms.ndim == 1 - else make_ctf(ctf_params) + else make_ctf(*ctf_params) ) if self.get_envelope_function: b_factor, scale_factor = ( @@ -282,14 +284,12 @@ def _get_starfile_params( else jnp.asarray(1.0) ), ) - make_envelope = lambda params: FourierGaussian( - b_factor=params[0], amplitude=params[1] - ) + make_envelope = lambda b, amp: FourierGaussian(b_factor=b, amplitude=amp) envelope_params = (b_factor, scale_factor) envelope = ( - eqx.filter_vmap(make_envelope, in_axes=(0, 0))(envelope_params) + eqx.filter_vmap(make_envelope, in_axes=(0, 0))(*envelope_params) if b_factor.ndim == 1 - else make_envelope(envelope_params) + else make_envelope(*envelope_params) ) else: envelope = None