Skip to content

Commit

Permalink
small name touchup
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Jan 16, 2025
1 parent 4a77ab8 commit 30f6272
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 24 deletions.
4 changes: 2 additions & 2 deletions docs/api/simulator/transfer_theory.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ This documentation describes the elements of . More features are included than d
::: cryojax.simulator.AbstractTransferFunction
options:
members:
- compute_phase_shifts_from_instrument
- compute_aberration_phase_shifts
- __call__

::: cryojax.simulator.ContrastTransferFunction
options:
members:
- __init__
- compute_phase_shifts_from_instrument
- compute_aberration_phase_shifts
- __call__

## Transfer Theories
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class AbstractTransferFunction(Module, strict=True):
"""An abstract base class for a transfer function."""

@abstractmethod
def compute_phase_shifts_from_instrument(
def compute_aberration_phase_shifts(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
Expand Down
2 changes: 1 addition & 1 deletion src/cryojax/simulator/_transfer_theory/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


# Not currently public API
def compute_phase_shifts_with_astigmatism_and_spherical_aberration(
def compute_phase_shifts_with_spherical_aberration(
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
defocus_in_angstroms: Float[Array, ""],
astigmatism_in_angstroms: Float[Array, ""],
Expand Down
27 changes: 16 additions & 11 deletions src/cryojax/simulator/_transfer_theory/contrast_transfer_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .base_transfer_theory import AbstractTransferFunction
from .common_functions import (
compute_phase_shift_from_amplitude_contrast_ratio,
compute_phase_shifts_with_astigmatism_and_spherical_aberration,
compute_phase_shifts_with_spherical_aberration,
)


Expand Down Expand Up @@ -69,13 +69,13 @@ def __init__(
self.phase_shift = jnp.asarray(phase_shift)

@override
def compute_phase_shifts_from_instrument(
def compute_aberration_phase_shifts(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
) -> Float[Array, "y_dim x_dim"]:
"""Compute the frequency-dependent phase shifts due to the instrument.
"""Compute the frequency-dependent phase shifts due to wave aberration.
This is often denoted as $\\chi(\\boldsymbol{q})$ for the in-plane
spatial frequency $\\boldsymbol{q}$.
Expand All @@ -90,8 +90,6 @@ def compute_phase_shifts_from_instrument(
is converted to the wavelength of incident electrons using
the function [`cryojax.constants.convert_keV_to_angstroms`](https://mjo22.github.io/cryojax/api/constants/units/#cryojax.constants.convert_keV_to_angstroms)
"""
# Convert degrees to radians
phase_shift = jnp.deg2rad(self.phase_shift)
astigmatism_angle = jnp.deg2rad(self.astigmatism_angle)
# Convert spherical abberation coefficient to angstroms
spherical_aberration_in_angstroms = self.spherical_aberration_in_mm * 1e7
Expand All @@ -100,15 +98,15 @@ def compute_phase_shifts_from_instrument(
jnp.asarray(voltage_in_kilovolts)
)
# Compute phase shifts for CTF
phase_shifts = compute_phase_shifts_with_astigmatism_and_spherical_aberration(
phase_shifts = compute_phase_shifts_with_spherical_aberration(
frequency_grid_in_angstroms,
self.defocus_in_angstroms,
self.astigmatism_in_angstroms,
astigmatism_angle,
wavelength_in_angstroms,
spherical_aberration_in_angstroms,
)
return phase_shifts - phase_shift
return phase_shifts

@override
def __call__(
Expand All @@ -129,14 +127,21 @@ def __call__(
is converted to the wavelength of incident electrons using
the function [`cryojax.constants.convert_keV_to_angstroms`](https://mjo22.github.io/cryojax/api/constants/units/#cryojax.constants.convert_keV_to_angstroms)
""" # noqa: E501
phase_shifts = self.compute_phase_shifts_from_instrument(
# Convert degrees to radians
aberration_phase_shifts = self.compute_aberration_phase_shifts(
frequency_grid_in_angstroms, voltage_in_kilovolts=voltage_in_kilovolts
)
phase_shifts -= compute_phase_shift_from_amplitude_contrast_ratio(
self.amplitude_contrast_ratio
# Additional phase shifts
phase_shift = jnp.deg2rad(self.phase_shift)
amplitude_contrast_phase_shift = (
compute_phase_shift_from_amplitude_contrast_ratio(
self.amplitude_contrast_ratio
)
)
# Compute the CTF
return jnp.sin(phase_shifts).at[0, 0].set(0.0)
return jnp.sin(
aberration_phase_shifts - (phase_shift + amplitude_contrast_phase_shift)
)


class ContrastTransferTheory(eqx.Module, strict=True):
Expand Down
21 changes: 13 additions & 8 deletions src/cryojax/simulator/_transfer_theory/wave_transfer_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .base_transfer_theory import AbstractTransferFunction
from .common_functions import (
compute_phase_shift_from_amplitude_contrast_ratio,
compute_phase_shifts_with_astigmatism_and_spherical_aberration,
compute_phase_shifts_with_spherical_aberration,
)


Expand Down Expand Up @@ -58,14 +58,13 @@ def __init__(
self.phase_shift = jnp.asarray(phase_shift)

@override
def compute_phase_shifts_from_instrument(
def compute_aberration_phase_shifts(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
) -> Float[Array, "y_dim x_dim"]:
# Convert degrees to radians
phase_shift = jnp.deg2rad(self.phase_shift)
astigmatism_angle = jnp.deg2rad(self.astigmatism_angle)
# Convert spherical abberation coefficient to angstroms
spherical_aberration_in_angstroms = self.spherical_aberration_in_mm * 1e7
Expand All @@ -74,35 +73,41 @@ def compute_phase_shifts_from_instrument(
jnp.asarray(voltage_in_kilovolts)
)
# Compute phase shifts for CTF
phase_shifts = compute_phase_shifts_with_astigmatism_and_spherical_aberration(
phase_shifts = compute_phase_shifts_with_spherical_aberration(
frequency_grid_in_angstroms,
self.defocus_in_angstroms,
self.astigmatism_in_angstroms,
astigmatism_angle,
wavelength_in_angstroms,
spherical_aberration_in_angstroms,
)
return phase_shifts.at[0, 0].add(phase_shift)
return phase_shifts

def __call__(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
) -> Complex[Array, "y_dim x_dim"]:
# Compute phase shifts due to the instrument
phase_shifts = self.compute_phase_shifts_from_instrument(
# Compute aberration phase shifts
aberration_phase_shifts = self.compute_aberration_phase_shifts(
frequency_grid_in_angstroms, voltage_in_kilovolts=voltage_in_kilovolts
)
# Additional phase shifts only impact zero mode
phase_shift = jnp.deg2rad(self.phase_shift)
amplitude_contrast_phase_shift = (
compute_phase_shift_from_amplitude_contrast_ratio(
self.amplitude_contrast_ratio
)
)
# Compute the WTF, correcting for the amplitude contrast and additional phase
# shift in the zero mode
return jnp.exp(-1.0j * phase_shifts.at[0, 0].add(amplitude_contrast_phase_shift))
return jnp.exp(
-1.0j
* aberration_phase_shifts.at[0, 0].add(
phase_shift + amplitude_contrast_phase_shift
)
)


class WaveTransferTheory(eqx.Module, strict=True):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_agree_with_cistem.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_ctf_with_cistem(defocus1, defocus2, asti_angle, kV, cs, ac, pixel_size)
cisTEM_ctf = np.vectorize(
lambda k_sqr, theta: cisTEM_optics.Evaluate(k_sqr, theta)
)(k_sqr.ravel() * pixel_size**2, theta.ravel()).reshape(freqs.shape[0:2])
cisTEM_ctf[0, 0] = 0.0
# cisTEM_ctf[0, 0] = 0.0

# Compute cryojax and cisTEM power spectrum
radial_freqs = jnp.linalg.norm(freqs, axis=-1)
Expand Down

0 comments on commit 30f6272

Please sign in to comment.