Skip to content

Commit

Permalink
some other changes for making iterations work
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Jun 14, 2024
1 parent e65ec1f commit de38241
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
from emdfile import tqdmnd

from py4DSTEM import show
import numpy as np
from py4DSTEM.datacube import DataCube
Expand Down Expand Up @@ -160,10 +161,13 @@ def preprocess(
# initialize object
if a0 == 0:
if self._initial_object_guess:
self._object = copy_to_device(self._initial_object_guess, storage)
self._object_initial = copy_to_device(
self._initial_object_guess, storage
)
del self._initial_object_guess
else:
diffraction_shape = self._initial_datacube_shape[-1]
self._object = xp_storage.zeros(
self._object_initial = xp_storage.zeros(
(
self._object_shape_x_y_z[0],
self._object_shape_x_y_z[1] * self._object_shape_x_y_z[2],
Expand Down Expand Up @@ -221,18 +225,29 @@ def preprocess(
def reconstruct(
self,
num_iter: int = 1,
store_iterations: bool = False,
reset: bool = True,
step_size: float = 0.5,
num_points: int = 60,
store_iterations: bool = False,
progress_bar: bool = True,
):
""" """
device = self._device
self.error_iterations = []

if store_iterations:
self.object_iterations = []
if reset is True:
self.error_iterations = []

for a0 in range(num_iter):
if store_iterations:
self.object_iterations = []

self._object = self._object_initial.copy()

for a0 in tqdmnd(
num_iter,
desc="Reconstructing object",
unit=" iter",
disable=not progress_bar,
):
error_iteration = 0
for a1 in range(self._num_datacubes):
diffraction_patterns_projected = copy_to_device(
Expand Down Expand Up @@ -265,7 +280,7 @@ def reconstruct(
self.error_iterations.append(error_iteration)
self.error = error_iteration
if store_iterations:
self.object_iterations.append(self._object)
self.object_iterations.append(self._object.copy())

return self

Expand Down Expand Up @@ -427,7 +442,7 @@ def _calculate_scan_positions(
device = self._device

# calculate shape
field_of_view_px = self._object.shape[0:2]
field_of_view_px = self._object_initial.shape[0:2]
self._field_of_view_A = (
self._voxel_size_A * field_of_view_px[0],
self._voxel_size_A * field_of_view_px[1],
Expand Down Expand Up @@ -1191,10 +1206,6 @@ def _calculate_update(
diffraction_patterns_resampled = diffraction_patterns_resampled[ind]
update = diffraction_patterns_resampled - object_sliced

# error = xp.mean((update.ravel()) ** 2) / xp.mean(
# (diffraction_patterns_projected.ravel()) ** 2
# )

error = xp.mean(update.ravel() ** 2) / xp.mean(
diffraction_patterns_projected.ravel() ** 2
)
Expand Down Expand Up @@ -1575,11 +1586,11 @@ def visualize(self, plot_convergence=True, figsize=(10, 10)):
# (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
# 1,
# ],
wspace=0.35,
wspace=0.15,
)

else:
spec = GridSpec(ncols=2, nrows=1, wspace=0.35)
spec = GridSpec(ncols=2, nrows=1)

fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(spec[0, 0])
Expand Down Expand Up @@ -1608,6 +1619,6 @@ def visualize(self, plot_convergence=True, figsize=(10, 10)):

@property
def object_6D(self):
""" 6D object"""
"""6D object"""

return self._object.reshape(self._object_shape_6D)
return self._object.reshape(self._object_shape_6D)

0 comments on commit de38241

Please sign in to comment.