Skip to content

Commit

Permalink
for now, get rid of attempt to pass a voxel grid to the multislice al…
Browse files Browse the repository at this point in the history
…gorithm
  • Loading branch information
mjo22 committed Dec 6, 2024
1 parent 7a73330 commit 0fd73c2
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 104 deletions.
79 changes: 34 additions & 45 deletions docs/examples/multislice.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,34 @@
from equinox import error_if
from jaxtyping import Array, Complex

from cryojax.coordinates import make_frequency_grid
from cryojax.image import fftn, ifftn, map_coordinates
# from cryojax.coordinates import make_frequency_grid
from cryojax.image import fftn, ifftn

from .._instrument_config import InstrumentConfig
from .._potential_representation import AbstractAtomicPotential, RealVoxelGridPotential
from .._potential_representation import (
AbstractAtomicPotential,
)

# , RealVoxelGridPotential
from .._scattering_theory import compute_phase_shifts_from_integrated_potential
from .base_multislice_integrator import AbstractMultisliceIntegrator


class FFTMultisliceIntegrator(
AbstractMultisliceIntegrator[AbstractAtomicPotential | RealVoxelGridPotential],
AbstractMultisliceIntegrator[AbstractAtomicPotential], # | RealVoxelGridPotential],
strict=True,
):
"""Multislice integrator that steps using successive FFT-based convolutions."""

slice_thickness_in_voxels: int
interpolation_order: int
# interpolation_order: int
options_for_rasterization: dict[str, Any]

def __init__(
self,
slice_thickness_in_voxels: int = 1,
*,
interpolation_order: int = 1,
# interpolation_order: int = 1,
options_for_rasterization: dict[str, Any] = {},
):
"""**Arguments:**
Expand All @@ -51,13 +55,13 @@ def __init__(
"integer greater than or equal to 1."
)
self.slice_thickness_in_voxels = slice_thickness_in_voxels
self.interpolation_order = interpolation_order
# self.interpolation_order = interpolation_order
self.options_for_rasterization = options_for_rasterization

@override
def compute_wavefunction_at_exit_plane(
self,
potential: AbstractAtomicPotential | RealVoxelGridPotential,
potential: AbstractAtomicPotential, # | RealVoxelGridPotential,
instrument_config: InstrumentConfig,
) -> Complex[
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim}"
Expand All @@ -75,24 +79,32 @@ def compute_wavefunction_at_exit_plane(
The wavefunction in the exit plane of the specimen.
""" # noqa: E501
# Rasterize a voxel grid at the given settings
if isinstance(potential, AbstractAtomicPotential):
z_dim, y_dim, x_dim = (
min(instrument_config.padded_shape),
*instrument_config.padded_shape,
)
voxel_size = instrument_config.pixel_size
potential_voxel_grid = potential.as_real_voxel_grid(
(z_dim, y_dim, x_dim), voxel_size, **self.options_for_rasterization
)
else:
# Interpolate volume to new pose at given coordinate system
z_dim, y_dim, x_dim = potential.real_voxel_grid.shape
voxel_size = potential.voxel_size
potential_voxel_grid = _interpolate_voxel_grid_to_rotated_coordinates(
potential.real_voxel_grid,
potential.coordinate_grid_in_pixels,
self.interpolation_order,
)
z_dim, y_dim, x_dim = (
min(instrument_config.padded_shape),
*instrument_config.padded_shape,
)
voxel_size = instrument_config.pixel_size
potential_voxel_grid = potential.as_real_voxel_grid(
(z_dim, y_dim, x_dim), voxel_size, **self.options_for_rasterization
)
# if isinstance(potential, AbstractAtomicPotential):
# z_dim, y_dim, x_dim = (
# min(instrument_config.padded_shape),
# *instrument_config.padded_shape,
# )
# voxel_size = instrument_config.pixel_size
# potential_voxel_grid = potential.as_real_voxel_grid(
# (z_dim, y_dim, x_dim), voxel_size, **self.options_for_rasterization
# )
# else:
# # Interpolate volume to new pose at given coordinate system
# z_dim, y_dim, x_dim = potential.real_voxel_grid.shape
# voxel_size = potential.voxel_size
# potential_voxel_grid = _interpolate_voxel_grid_to_rotated_coordinates(
# potential.real_voxel_grid,
# potential.coordinate_grid_in_pixels,
# self.interpolation_order,
# )
# Initialize multislice geometry
n_slices = z_dim // self.slice_thickness_in_voxels
slice_thickness = voxel_size * self.slice_thickness_in_voxels
Expand Down Expand Up @@ -127,16 +139,20 @@ def compute_wavefunction_at_exit_plane(
# Compute the transmission function
transmission = jnp.exp(1.0j * phase_shifts_per_slice)
# Compute the fresnel propagator (TODO: check numerical factors)
if isinstance(potential, AbstractAtomicPotential):
radial_frequency_grid = jnp.sum(
instrument_config.padded_full_frequency_grid_in_angstroms**2,
axis=-1,
)
else:
radial_frequency_grid = jnp.sum(
make_frequency_grid((y_dim, x_dim), voxel_size, half_space=False) ** 2,
axis=-1,
)
radial_frequency_grid = jnp.sum(
instrument_config.padded_full_frequency_grid_in_angstroms**2,
axis=-1,
)
# if isinstance(potential, AbstractAtomicPotential):
# radial_frequency_grid = jnp.sum(
# instrument_config.padded_full_frequency_grid_in_angstroms**2,
# axis=-1,
# )
# else:
# radial_frequency_grid = jnp.sum(
# make_frequency_grid((y_dim, x_dim), voxel_size, half_space=False) ** 2,
# axis=-1,
# )
fresnel_propagator = jnp.exp(
1.0j
* jnp.pi
Expand All @@ -153,13 +169,15 @@ def compute_wavefunction_at_exit_plane(
# Compute exit wave
exit_wave = jax.lax.fori_loop(0, n_slices, make_step, plane_wave)

return (
exit_wave
if isinstance(potential, AbstractAtomicPotential)
else self._postprocess_exit_wave_for_voxel_potential(
exit_wave, potential, instrument_config
)
)
# return (
# exit_wave
# if isinstance(potential, AbstractAtomicPotential)
# else self._postprocess_exit_wave_for_voxel_potential(
# exit_wave, potential, instrument_config
# )
# )

return exit_wave

def _postprocess_exit_wave_for_voxel_potential(
self,
Expand Down Expand Up @@ -188,19 +206,19 @@ def _postprocess_exit_wave_for_voxel_potential(
return exit_wave


def _interpolate_voxel_grid_to_rotated_coordinates(
real_voxel_grid,
coordinate_grid_in_pixels,
interpolation_order,
):
# Convert to logical coordinates
z_dim, y_dim, x_dim = real_voxel_grid.shape
logical_coordinate_grid = (
coordinate_grid_in_pixels
+ jnp.asarray((x_dim // 2, y_dim // 2, z_dim // 2))[None, None, None, :]
)
# Convert arguments to map_coordinates convention and compute
x, y, z = jnp.transpose(logical_coordinate_grid, axes=[3, 0, 1, 2])
return map_coordinates(
real_voxel_grid, (z, y, x), order=interpolation_order, mode="fill", cval=0.0
)
# def _interpolate_voxel_grid_to_rotated_coordinates(
# real_voxel_grid,
# coordinate_grid_in_pixels,
# interpolation_order,
# ):
# # Convert to logical coordinates
# z_dim, y_dim, x_dim = real_voxel_grid.shape
# logical_coordinate_grid = (
# coordinate_grid_in_pixels
# + jnp.asarray((x_dim // 2, y_dim // 2, z_dim // 2))[None, None, None, :]
# )
# # Convert arguments to map_coordinates convention and compute
# x, y, z = jnp.transpose(logical_coordinate_grid, axes=[3, 0, 1, 2])
# return map_coordinates(
# real_voxel_grid, (z, y, x), order=interpolation_order, mode="fill", cval=0.0
# )

0 comments on commit 0fd73c2

Please sign in to comment.