diff --git a/src/cryojax/simulator/_multislice_integrator/fft_multislice_integrator.py b/src/cryojax/simulator/_multislice_integrator/fft_multislice_integrator.py index 7198a9b3..fbcff92b 100644 --- a/src/cryojax/simulator/_multislice_integrator/fft_multislice_integrator.py +++ b/src/cryojax/simulator/_multislice_integrator/fft_multislice_integrator.py @@ -19,17 +19,21 @@ class FFTMultisliceIntegrator( ): """Multislice integrator that steps using successive FFT-based convolutions.""" + z_depth_in_voxels: int slice_thickness_in_voxels: int options_for_rasterization: dict[str, Any] def __init__( self, + z_depth_in_voxels: int, slice_thickness_in_voxels: int = 1, *, options_for_rasterization: dict[str, Any], ): """**Arguments:** + - `z_depth_in_voxels`: + The z-dimension of the rasterized voxel grid. - `slice_thickness_in_voxels`: The number of slices to step through per iteration of the rasterized voxel grid. @@ -42,6 +46,7 @@ def __init__( "FFTMultisliceIntegrator.slice_thickness_in_voxels must be an " "integer greater than or equal to 1." ) + self.z_depth_in_voxels = z_depth_in_voxels self.slice_thickness_in_voxels = slice_thickness_in_voxels self.options_for_rasterization = options_for_rasterization @@ -66,14 +71,14 @@ def compute_wavefunction_at_exit_plane( The wavefunction in the exit plane of the specimen. """ # noqa: E501 # Rasterize 3D potential - dim = min(instrument_config.padded_shape) + z_dim = self.z_depth_in_voxels + y_dim, x_dim = instrument_config.padded_shape pixel_size = instrument_config.pixel_size potential_voxel_grid = potential.as_real_voxel_grid( - (dim, dim, dim), pixel_size, **self.options_for_rasterization + (z_dim, y_dim, x_dim), pixel_size, **self.options_for_rasterization ) # Initialize multislice geometry - shape_xy = (dim, dim) - n_slices = dim // self.slice_thickness_in_voxels + n_slices = z_dim // self.slice_thickness_in_voxels slice_thickness = pixel_size * self.slice_thickness_in_voxels # Locally average the potential to be at the given slice thickness. # Thow away some slices equal to the remainder @@ -81,8 +86,8 @@ def compute_wavefunction_at_exit_plane( if self.slice_thickness_in_voxels > 1: potential_voxel_grid = jnp.mean( potential_voxel_grid[ - : dim - dim % self.slice_thickness_in_voxels, ... - ].reshape((self.slice_thickness_in_voxels, n_slices, dim, dim)), + : z_dim - z_dim % self.slice_thickness_in_voxels, ... + ].reshape((self.slice_thickness_in_voxels, n_slices, y_dim, x_dim)), axis=0, ) # Compute the integrated potential in a given slice interval, multiplying by @@ -105,7 +110,7 @@ def compute_wavefunction_at_exit_plane( * slice_thickness ) # Prepare for iteration. First, initialize plane wave - plane_wave = jnp.ones(shape_xy, dtype=complex) + plane_wave = jnp.ones((y_dim, x_dim), dtype=complex) # ... stepping function make_step = lambda n, last_exit_wave: ifftn( fftn(transmission[n, :, :] * last_exit_wave) * fresnel_propagator @@ -113,4 +118,4 @@ def compute_wavefunction_at_exit_plane( # Compute exit wave exit_wave = jax.lax.fori_loop(0, n_slices, make_step, plane_wave) - return self._postprocess_exit_wave(exit_wave, instrument_config) + return exit_wave