Skip to content

Commit

Permalink
only base class needs to compute phase aberration function
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Jan 16, 2025
1 parent 30f6272 commit 2c5059d
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 82 deletions.
5 changes: 2 additions & 3 deletions docs/api/simulator/transfer_theory.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Further, the `ContrastTransferTheory` is a class that takes in a projection imag
projection_image_in_fourier_domain = ...
ctf = ContrastTransferFunction(...)
transfer_theory = cxs.ContrastTransferTheory(ctf)
contrast_in_fourier_domain = transfer_theory.propagate_object_to_detector_plan(projection_image_in_fourier_domain)
contrast_in_fourier_domain = transfer_theory.propagate_object_to_detector_plane(projection_image_in_fourier_domain)
```

This documentation describes the elements of . More features are included than described above, such as the ability to include a user-specified envelope function to the `ContrastTransferTheory`. Much of the code and theory have been adapted from the Grigorieff Lab's CTFFIND4 program.
Expand All @@ -36,7 +36,7 @@ This documentation describes the elements of . More features are included than d

## Transfer Functions

??? abstract "`cryojax.simulator.AbstractTransferFunction`"
???+ abstract "`cryojax.simulator.AbstractTransferFunction`"
::: cryojax.simulator.AbstractTransferFunction
options:
members:
Expand All @@ -47,7 +47,6 @@ This documentation describes the elements of . More features are included than d
options:
members:
- __init__
- compute_aberration_phase_shifts
- __call__

## Transfer Theories
Expand Down
49 changes: 45 additions & 4 deletions src/cryojax/simulator/_transfer_theory/base_transfer_theory.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,61 @@
from abc import abstractmethod

from equinox import Module
import jax.numpy as jnp
from equinox import AbstractVar, Module
from jaxtyping import Array, Complex, Float

from ...constants import convert_keV_to_angstroms
from .common_functions import compute_phase_shifts_with_spherical_aberration


class AbstractTransferFunction(Module, strict=True):
"""An abstract base class for a transfer function."""
"""An abstract base class for a transfer function in cryo-EM."""

defocus_in_angstroms: AbstractVar[Float[Array, ""]]
astigmatism_in_angstroms: AbstractVar[Float[Array, ""]]
astigmatism_angle: AbstractVar[Float[Array, ""]]
spherical_aberration_in_mm: AbstractVar[Float[Array, ""]]
amplitude_contrast_ratio: AbstractVar[Float[Array, ""]]
phase_shift: AbstractVar[Float[Array, ""]]

@abstractmethod
def compute_aberration_phase_shifts(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
) -> Float[Array, "y_dim x_dim"]:
raise NotImplementedError
"""Compute the frequency-dependent phase shifts due to wave aberration.
This is often denoted as $\\chi(\\boldsymbol{q})$ for the in-plane
spatial frequency $\\boldsymbol{q}$.
**Arguments:**
- `frequency_grid_in_angstroms`:
The grid of frequencies in units of inverse angstroms. This can
be computed with [`cryojax.coordinates.make_frequency_grid`](https://mjo22.github.io/cryojax/api/coordinates/making_coordinates/#cryojax.coordinates.make_frequency_grid)
- `voltage_in_kilovolts`:
The accelerating voltage of the microscope in kilovolts. This
is converted to the wavelength of incident electrons using
the function [`cryojax.constants.convert_keV_to_angstroms`](https://mjo22.github.io/cryojax/api/constants/units/#cryojax.constants.convert_keV_to_angstroms)
"""
astigmatism_angle = jnp.deg2rad(self.astigmatism_angle)
# Convert spherical abberation coefficient to angstroms
spherical_aberration_in_angstroms = self.spherical_aberration_in_mm * 1e7
# Get the wavelength
wavelength_in_angstroms = convert_keV_to_angstroms(
jnp.asarray(voltage_in_kilovolts)
)
# Compute phase shifts for CTF
phase_shifts = compute_phase_shifts_with_spherical_aberration(
frequency_grid_in_angstroms,
self.defocus_in_angstroms,
self.astigmatism_in_angstroms,
astigmatism_angle,
wavelength_in_angstroms,
spherical_aberration_in_angstroms,
)
return phase_shifts

@abstractmethod
def __call__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float

from ...constants import convert_keV_to_angstroms
from ...image.operators import FourierOperatorLike
from ...internal import error_if_negative, error_if_not_fractional
from .._instrument_config import InstrumentConfig
from .base_transfer_theory import AbstractTransferFunction
from .common_functions import (
compute_phase_shift_from_amplitude_contrast_ratio,
compute_phase_shifts_with_spherical_aberration,
)
from .common_functions import compute_phase_shift_from_amplitude_contrast_ratio


class ContrastTransferFunction(AbstractTransferFunction, strict=True):
Expand Down Expand Up @@ -68,46 +64,6 @@ def __init__(
self.amplitude_contrast_ratio = error_if_not_fractional(amplitude_contrast_ratio)
self.phase_shift = jnp.asarray(phase_shift)

@override
def compute_aberration_phase_shifts(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
) -> Float[Array, "y_dim x_dim"]:
"""Compute the frequency-dependent phase shifts due to wave aberration.
This is often denoted as $\\chi(\\boldsymbol{q})$ for the in-plane
spatial frequency $\\boldsymbol{q}$.
**Arguments:**
- `frequency_grid_in_angstroms`:
The grid of frequencies in units of inverse angstroms. This can
be computed with [`cryojax.coordinates.make_frequency_grid`](https://mjo22.github.io/cryojax/api/coordinates/making_coordinates/#cryojax.coordinates.make_frequency_grid)
- `voltage_in_kilovolts`:
The accelerating voltage of the microscope in kilovolts. This
is converted to the wavelength of incident electrons using
the function [`cryojax.constants.convert_keV_to_angstroms`](https://mjo22.github.io/cryojax/api/constants/units/#cryojax.constants.convert_keV_to_angstroms)
"""
astigmatism_angle = jnp.deg2rad(self.astigmatism_angle)
# Convert spherical abberation coefficient to angstroms
spherical_aberration_in_angstroms = self.spherical_aberration_in_mm * 1e7
# Get the wavelength
wavelength_in_angstroms = convert_keV_to_angstroms(
jnp.asarray(voltage_in_kilovolts)
)
# Compute phase shifts for CTF
phase_shifts = compute_phase_shifts_with_spherical_aberration(
frequency_grid_in_angstroms,
self.defocus_in_angstroms,
self.astigmatism_in_angstroms,
astigmatism_angle,
wavelength_in_angstroms,
spherical_aberration_in_angstroms,
)
return phase_shifts

@override
def __call__(
self,
Expand Down
30 changes: 0 additions & 30 deletions src/cryojax/simulator/_transfer_theory/wave_transfer_theory.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from typing_extensions import override

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float

from ...constants import convert_keV_to_angstroms
from ...internal import error_if_negative, error_if_not_fractional
from .._instrument_config import InstrumentConfig
from .base_transfer_theory import AbstractTransferFunction
from .common_functions import (
compute_phase_shift_from_amplitude_contrast_ratio,
compute_phase_shifts_with_spherical_aberration,
)


Expand Down Expand Up @@ -57,32 +53,6 @@ def __init__(
self.amplitude_contrast_ratio = error_if_not_fractional(amplitude_contrast_ratio)
self.phase_shift = jnp.asarray(phase_shift)

@override
def compute_aberration_phase_shifts(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
*,
voltage_in_kilovolts: Float[Array, ""] | float = 300.0,
) -> Float[Array, "y_dim x_dim"]:
# Convert degrees to radians
astigmatism_angle = jnp.deg2rad(self.astigmatism_angle)
# Convert spherical abberation coefficient to angstroms
spherical_aberration_in_angstroms = self.spherical_aberration_in_mm * 1e7
# Get the wavelength
wavelength_in_angstroms = convert_keV_to_angstroms(
jnp.asarray(voltage_in_kilovolts)
)
# Compute phase shifts for CTF
phase_shifts = compute_phase_shifts_with_spherical_aberration(
frequency_grid_in_angstroms,
self.defocus_in_angstroms,
self.astigmatism_in_angstroms,
astigmatism_angle,
wavelength_in_angstroms,
spherical_aberration_in_angstroms,
)
return phase_shifts

def __call__(
self,
frequency_grid_in_angstroms: Float[Array, "y_dim x_dim 2"],
Expand Down

0 comments on commit 2c5059d

Please sign in to comment.