Skip to content

Commit

Permalink
Merge pull request #269 from mjo22/252-add-envelope-function-to-relio…
Browse files Browse the repository at this point in the history
…ndataset

Fix bugs in `RelionDataset`
  • Loading branch information
mjo22 authored Sep 11, 2024
2 parents 58490b9 + 8aaaa9c commit c681390
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions src/cryojax/data/_relion.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ def __len__(self) -> int:
def _get_starfile_params(
self, particle_blocks, optics_group, device
) -> tuple[InstrumentConfig, ContrastTransferTheory, EulerAnglePose]:
defocus_in_angstroms = jnp.asarray(particle_blocks["rlnDefocusU"], device=device)

defocus_in_angstroms = (
jnp.asarray(particle_blocks["rlnDefocusU"], device=device)
+ jnp.asarray(particle_blocks["rlnDefocusU"], device=device)
) / 2
astigmatism_in_angstroms = jnp.asarray(
particle_blocks["rlnDefocusV"], device=device
) - jnp.asarray(particle_blocks["rlnDefocusU"], device=device)
Expand All @@ -244,29 +246,51 @@ def _get_starfile_params(
jnp.asarray(voltage_in_kilovolts, device=device),
)
# ... now the ContrastTransferTheory
ctf = ContrastTransferFunction(
defocus_in_angstroms=defocus_in_angstroms,
astigmatism_in_angstroms=astigmatism_in_angstroms,
astigmatism_angle=astigmatism_angle,
voltage_in_kilovolts=voltage_in_kilovolts,
spherical_aberration_in_mm=spherical_aberration_in_mm,
amplitude_contrast_ratio=amplitude_contrast_ratio,
phase_shift=phase_shift,
make_ctf = (
lambda defocus, astig, angle, voltage, sph, ac, ps: ContrastTransferFunction(
defocus_in_angstroms=defocus,
astigmatism_in_angstroms=astig,
astigmatism_angle=angle,
voltage_in_kilovolts=voltage,
spherical_aberration_in_mm=sph,
amplitude_contrast_ratio=ac,
phase_shift=ps,
)
)
ctf_params = (
defocus_in_angstroms,
astigmatism_in_angstroms,
astigmatism_angle,
voltage_in_kilovolts,
spherical_aberration_in_mm,
amplitude_contrast_ratio,
phase_shift,
)
ctf = (
eqx.filter_vmap(make_ctf, in_axes=(0, 0, 0, None, None, None, 0))(*ctf_params)
if defocus_in_angstroms.ndim == 1
else make_ctf(*ctf_params)
)
if self.get_envelope_function:
b_factor, scale_factor = (
(
jnp.asarray(particle_blocks["rlnCtfBfactor"], device=device)
if "rlnCtfBfactor" in particle_blocks.keys()
else 0.0
else jnp.asarray(0.0)
),
(
jnp.asarray(particle_blocks["rlnCtfScalefactor"], device=device)
if "rlnCtfScalefactor" in particle_blocks.keys()
else 1.0
else jnp.asarray(1.0)
),
)
envelope = FourierGaussian(b_factor=b_factor, amplitude=scale_factor)
make_envelope = lambda b, amp: FourierGaussian(b_factor=b, amplitude=amp)
envelope_params = (b_factor, scale_factor)
envelope = (
eqx.filter_vmap(make_envelope, in_axes=(0, 0))(*envelope_params)
if b_factor.ndim == 1
else make_envelope(*envelope_params)
)
else:
envelope = None
transfer_theory = ContrastTransferTheory(ctf, envelope)
Expand Down

0 comments on commit c681390

Please sign in to comment.