Skip to content

Commit

Permalink
first attempt at correct ewald sphere extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 committed Dec 8, 2024
1 parent c95d7bb commit 2e8ed1f
Show file tree
Hide file tree
Showing 17 changed files with 405 additions and 111 deletions.
215 changes: 215 additions & 0 deletions docs/examples/ewald-sphere.ipynb

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions src/cryojax/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
AbstractMultisliceIntegrator as AbstractMultisliceIntegrator,
FFTMultisliceIntegrator as FFTMultisliceIntegrator,
)

# from ..simulator._potential_integrator import (
# EwaldSphereExtraction as EwaldSphereExtraction,
# )
from ..simulator._potential_integrator import (
EwaldSphereExtraction as EwaldSphereExtraction,
)
from ..simulator._scattering_theory import (
AbstractWaveScatteringTheory as AbstractWaveScatteringTheory,
HighEnergyScatteringTheory as HighEnergyScatteringTheory,
Expand Down
2 changes: 1 addition & 1 deletion src/cryojax/simulator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ._scattering_theory import (
AbstractScatteringTheory as AbstractScatteringTheory,
AbstractWeakPhaseScatteringTheory as AbstractWeakPhaseScatteringTheory,
compute_phase_shifts_from_integrated_potential as compute_phase_shifts_from_integrated_potential, # noqa: E501
convert_units_of_integrated_potential as convert_units_of_integrated_potential, # noqa: E501
LinearSuperpositionScatteringTheory as LinearSuperpositionScatteringTheory,
WeakPhaseScatteringTheory as WeakPhaseScatteringTheory,
)
Expand Down
8 changes: 4 additions & 4 deletions src/cryojax/simulator/_imaging_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def render(
):
# Compute the squared wavefunction
fourier_contrast_at_detector_plane = (
self.scattering_theory.compute_fourier_contrast_at_detector_plane(
self.scattering_theory.compute_contrast_spectrum_at_detector_plane(
self.instrument_config, rng_key
)
)
Expand Down Expand Up @@ -278,7 +278,7 @@ def render(
):
theory = self.scattering_theory
fourier_squared_wavefunction_at_detector_plane = (
theory.compute_fourier_squared_wavefunction_at_detector_plane(
theory.compute_intensity_spectrum_at_detector_plane(
self.instrument_config, rng_key
)
)
Expand Down Expand Up @@ -353,7 +353,7 @@ def render(
# Compute the squared wavefunction
theory = self.scattering_theory
fourier_squared_wavefunction_at_detector_plane = (
theory.compute_fourier_squared_wavefunction_at_detector_plane(
theory.compute_intensity_spectrum_at_detector_plane(
self.instrument_config
)
)
Expand All @@ -374,7 +374,7 @@ def render(
# Compute the squared wavefunction
theory = self.scattering_theory
fourier_squared_wavefunction_at_detector_plane = (
theory.compute_fourier_squared_wavefunction_at_detector_plane(
theory.compute_intensity_spectrum_at_detector_plane(
self.instrument_config, keys[0]
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
from equinox import error_if
from jaxtyping import Array, Complex

# from cryojax.coordinates import make_frequency_grid
from cryojax.image import fftn, ifftn

from .._instrument_config import InstrumentConfig
from .._potential_representation import (
AbstractAtomicPotential,
)

# , RealVoxelGridPotential
from .._scattering_theory import compute_phase_shifts_from_integrated_potential
from .._scattering_theory import convert_units_of_integrated_potential
from .base_multislice_integrator import AbstractMultisliceIntegrator


Expand Down Expand Up @@ -134,7 +131,7 @@ def compute_wavefunction_at_exit_plane(
# the slice thickness (TODO: interpolate for different slice thicknesses?)
integrated_potential_per_slice = potential_per_slice * voxel_size
phase_shifts_per_slice = jax.vmap(
compute_phase_shifts_from_integrated_potential, in_axes=[0, None]
convert_units_of_integrated_potential, in_axes=[0, None]
)(integrated_potential_per_slice, instrument_config.wavelength_in_angstroms)
# Compute the transmission function
transmission = jnp.exp(1.0j * phase_shifts_per_slice)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import ClassVar, Optional
from typing_extensions import override

import jax
Expand All @@ -21,6 +21,8 @@ class GaussianMixtureProjection(
):
upsampling_factor: Optional[int]

is_integration_complex: ClassVar[bool] = False

def __init__(self, *, upsampling_factor: Optional[int] = None):
"""**Arguments:**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Generic, Optional, TypeVar

import jax.numpy as jnp
from equinox import AbstractVar, error_if, Module
from equinox import AbstractClassVar, AbstractVar, error_if, Module
from jaxtyping import Array, Complex

from ...image import maybe_rescale_pixel_size
Expand All @@ -23,6 +23,8 @@ class AbstractPotentialIntegrator(Module, Generic[PotentialT], strict=True):
the exit plane.
"""

is_integration_complex: AbstractClassVar[bool]

@abstractmethod
def compute_fourier_integrated_potential(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Using the fourier slice theorem for computing volume projections.
"""

from typing import Optional
from typing import ClassVar, Optional
from typing_extensions import override

import jax.numpy as jnp
Expand Down Expand Up @@ -36,6 +36,8 @@ class FourierSliceExtraction(AbstractVoxelPotentialIntegrator, strict=True):
interpolation_mode: str
interpolation_cval: complex

is_integration_complex: ClassVar[bool] = False

def __init__(
self,
*,
Expand Down Expand Up @@ -207,6 +209,8 @@ class EwaldSphereExtraction(AbstractVoxelPotentialIntegrator, strict=True):
interpolation_mode: str
interpolation_cval: complex

is_integration_complex: ClassVar[bool] = True

def __init__(
self,
*,
Expand Down
4 changes: 3 additions & 1 deletion src/cryojax/simulator/_potential_integrator/nufft_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import math
from typing import Optional
from typing import ClassVar, Optional
from typing_extensions import override

import jax.numpy as jnp
Expand All @@ -23,6 +23,8 @@ class NufftProjection(
pixel_rescaling_method: Optional[str]
eps: float

is_integration_complex: ClassVar[bool] = False

def __init__(
self, *, pixel_rescaling_method: Optional[str] = None, eps: float = 1e-6
):
Expand Down
2 changes: 1 addition & 1 deletion src/cryojax/simulator/_scattering_theory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
AbstractWaveScatteringTheory as AbstractWaveScatteringTheory,
)
from .common_functions import (
compute_phase_shifts_from_integrated_potential as compute_phase_shifts_from_integrated_potential, # noqa: E501
convert_units_of_integrated_potential as convert_units_of_integrated_potential, # noqa: E501
)
from .high_energy_scattering_theory import (
HighEnergyScatteringTheory as HighEnergyScatteringTheory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AbstractScatteringTheory(eqx.Module, strict=True):
"""Base class for a scattering theory."""

@abstractmethod
def compute_fourier_contrast_at_detector_plane(
def compute_contrast_spectrum_at_detector_plane(
self,
instrument_config: InstrumentConfig,
rng_key: Optional[PRNGKeyArray] = None,
Expand All @@ -26,7 +26,7 @@ def compute_fourier_contrast_at_detector_plane(
raise NotImplementedError

@abstractmethod
def compute_fourier_squared_wavefunction_at_detector_plane(
def compute_intensity_spectrum_at_detector_plane(
self,
instrument_config: InstrumentConfig,
rng_key: Optional[PRNGKeyArray] = None,
Expand All @@ -53,7 +53,7 @@ def compute_wavefunction_at_exit_plane(
raise NotImplementedError

@override
def compute_fourier_squared_wavefunction_at_detector_plane(
def compute_intensity_spectrum_at_detector_plane(
self,
instrument_config: InstrumentConfig,
rng_key: Optional[PRNGKeyArray] = None,
Expand All @@ -72,7 +72,7 @@ def compute_fourier_squared_wavefunction_at_detector_plane(
)
wavefunction_at_detector_plane = ifftn(fourier_wavefunction_at_detector_plane)
# ... get the squared wavefunction and return to fourier space
fourier_squared_wavefunction_at_detector_plane = rfftn(
intensity_spectrum_at_detector_plane = rfftn(
(
wavefunction_at_detector_plane * jnp.conj(wavefunction_at_detector_plane)
).real
Expand All @@ -82,10 +82,10 @@ def compute_fourier_squared_wavefunction_at_detector_plane(
instrument_config.padded_frequency_grid_in_angstroms
)

return translational_phase_shifts * fourier_squared_wavefunction_at_detector_plane
return translational_phase_shifts * intensity_spectrum_at_detector_plane

@override
def compute_fourier_contrast_at_detector_plane(
def compute_contrast_spectrum_at_detector_plane(
self,
instrument_config: InstrumentConfig,
rng_key: Optional[PRNGKeyArray] = None,
Expand All @@ -110,7 +110,7 @@ def compute_fourier_contrast_at_detector_plane(
).real
# ... compute the contrast directly from the squared wavefunction
# as C = -1 + psi^2 / 1 + psi^2
fourier_contrast_at_detector_plane = rfftn(
contrast_spectrum_at_detector_plane = rfftn(
(-1 + squared_wavefunction_at_detector_plane)
/ (1 + squared_wavefunction_at_detector_plane)
)
Expand All @@ -119,4 +119,4 @@ def compute_fourier_contrast_at_detector_plane(
instrument_config.padded_frequency_grid_in_angstroms
)

return translational_phase_shifts * fourier_contrast_at_detector_plane
return translational_phase_shifts * contrast_spectrum_at_detector_plane
6 changes: 4 additions & 2 deletions src/cryojax/simulator/_scattering_theory/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from jaxtyping import Array, Float, Inexact


def compute_phase_shifts_from_integrated_potential(
def convert_units_of_integrated_potential(
integrated_potential: Inexact[Array, "y_dim x_dim"],
wavelength_in_angstroms: Float[Array, ""] | float,
) -> Inexact[Array, "y_dim x_dim"]:
"""Given an integrated potential, compute a phase shift distribution.
"""Given an integrated potential, convert units using the interaction
constant. For example, the case of the projection approximation,
compute the phase shift distribution.
!!! info
In the projection approximation in cryo-EM, the phase shifts in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import jax.numpy as jnp
from jaxtyping import Array, Complex, PRNGKeyArray

from ...image import irfftn
from ...image import ifftn, irfftn
from .._instrument_config import InstrumentConfig
from .._potential_integrator import AbstractPotentialIntegrator
from .._solvent import AbstractIce
from .._structural_ensemble import AbstractStructuralEnsemble
from .._transfer_theory import WaveTransferTheory
from .base_scattering_theory import AbstractWaveScatteringTheory
from .common_functions import compute_phase_shifts_from_integrated_potential
from .common_functions import convert_units_of_integrated_potential


class HighEnergyScatteringTheory(AbstractWaveScatteringTheory, strict=True):
Expand Down Expand Up @@ -58,27 +58,55 @@ def compute_wavefunction_at_exit_plane(
) -> Complex[
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}"
]:
# Compute the phase shifts in the exit plane
# Compute the object spectrum in the exit plane
potential = self.structural_ensemble.get_potential_in_lab_frame()
fourier_phase_shifts_at_exit_plane = (
compute_phase_shifts_from_integrated_potential(
if not self.potential_integrator.is_integration_complex:
phase_shift_spectrum_at_exit_plane = convert_units_of_integrated_potential(
self.potential_integrator.compute_fourier_integrated_potential(
potential, instrument_config
),
instrument_config.wavelength_in_angstroms,
)
)

if rng_key is not None:
# Get the potential of the specimen plus the ice
if self.solvent is not None:
fourier_phase_shifts_at_exit_plane = (
self.solvent.compute_fourier_phase_shifts_with_ice(
rng_key, fourier_phase_shifts_at_exit_plane, instrument_config
if rng_key is not None:
# Get the potential of the specimen plus the ice
if self.solvent is not None:
phase_shift_spectrum_at_exit_plane = (
self.solvent.compute_object_spectrum_with_ice(
rng_key,
phase_shift_spectrum_at_exit_plane,
instrument_config,
is_hermitian_symmetric=True,
)
)

return jnp.exp(
1.0j
* irfftn(
phase_shift_spectrum_at_exit_plane, s=instrument_config.padded_shape
)
)
else:
object_spectrum_at_exit_plane = convert_units_of_integrated_potential(
self.potential_integrator.compute_fourier_integrated_potential(
potential, instrument_config
),
instrument_config.wavelength_in_angstroms,
)

return jnp.exp(
1.0j
* irfftn(fourier_phase_shifts_at_exit_plane, s=instrument_config.padded_shape)
)
if rng_key is not None:
# Get the potential of the specimen plus the ice
if self.solvent is not None:
object_spectrum_at_exit_plane = (
self.solvent.compute_object_spectrum_with_ice(
rng_key,
object_spectrum_at_exit_plane,
instrument_config,
is_hermitian_symmetric=False,
)
)

return jnp.exp(
1.0j
* ifftn(object_spectrum_at_exit_plane, s=instrument_config.padded_shape)
)
Loading

0 comments on commit 2e8ed1f

Please sign in to comment.