Skip to content

Commit

Permalink
Bug fix for non-square images
Browse files Browse the repository at this point in the history
  • Loading branch information
mjo22 authored Aug 10, 2024
1 parent f80179f commit a53abcb
Showing 1 changed file with 13 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@ class FFTMultisliceIntegrator(
):
"""Multislice integrator that steps using successive FFT-based convolutions."""

z_depth_in_voxels: int
slice_thickness_in_voxels: int
options_for_rasterization: dict[str, Any]

def __init__(
self,
z_depth_in_voxels: int,
slice_thickness_in_voxels: int = 1,
*,
options_for_rasterization: dict[str, Any],
):
"""**Arguments:**
- `z_depth_in_voxels`:
The z-dimension of the rasterized voxel grid.
- `slice_thickness_in_voxels`:
The number of slices to step through per iteration of the
rasterized voxel grid.
Expand All @@ -42,6 +46,7 @@ def __init__(
"FFTMultisliceIntegrator.slice_thickness_in_voxels must be an "
"integer greater than or equal to 1."
)
self.z_depth_in_voxels = z_depth_in_voxels
self.slice_thickness_in_voxels = slice_thickness_in_voxels
self.options_for_rasterization = options_for_rasterization

Expand All @@ -66,23 +71,23 @@ def compute_wavefunction_at_exit_plane(
The wavefunction in the exit plane of the specimen.
""" # noqa: E501
# Rasterize 3D potential
dim = min(instrument_config.padded_shape)
z_dim = self.z_depth_in_voxels
y_dim, x_dim = instrument_config.padded_shape
pixel_size = instrument_config.pixel_size
potential_voxel_grid = potential.as_real_voxel_grid(
(dim, dim, dim), pixel_size, **self.options_for_rasterization
(z_dim, y_dim, x_dim), pixel_size, **self.options_for_rasterization
)
# Initialize multislice geometry
shape_xy = (dim, dim)
n_slices = dim // self.slice_thickness_in_voxels
n_slices = z_dim // self.slice_thickness_in_voxels
slice_thickness = pixel_size * self.slice_thickness_in_voxels
# Locally average the potential to be at the given slice thickness.
# Thow away some slices equal to the remainder
# `dim % self.slice_thickness_in_voxels`
if self.slice_thickness_in_voxels > 1:
potential_voxel_grid = jnp.mean(
potential_voxel_grid[
: dim - dim % self.slice_thickness_in_voxels, ...
].reshape((self.slice_thickness_in_voxels, n_slices, dim, dim)),
: z_dim - z_dim % self.slice_thickness_in_voxels, ...
].reshape((self.slice_thickness_in_voxels, n_slices, y_dim, x_dim)),
axis=0,
)
# Compute the integrated potential in a given slice interval, multiplying by
Expand All @@ -105,12 +110,12 @@ def compute_wavefunction_at_exit_plane(
* slice_thickness
)
# Prepare for iteration. First, initialize plane wave
plane_wave = jnp.ones(shape_xy, dtype=complex)
plane_wave = jnp.ones((y_dim, x_dim), dtype=complex)
# ... stepping function
make_step = lambda n, last_exit_wave: ifftn(
fftn(transmission[n, :, :] * last_exit_wave) * fresnel_propagator
)
# Compute exit wave
exit_wave = jax.lax.fori_loop(0, n_slices, make_step, plane_wave)

return self._postprocess_exit_wave(exit_wave, instrument_config)
return exit_wave

0 comments on commit a53abcb

Please sign in to comment.