Skip to content

Commit

Permalink
error and vis
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Jun 14, 2024
1 parent 1d0354a commit e65ec1f
Showing 1 changed file with 90 additions and 7 deletions.
97 changes: 90 additions & 7 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Sequence, Tuple, Union

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
from py4DSTEM import show
import numpy as np
from py4DSTEM.datacube import DataCube
from py4DSTEM.preprocess.utils import bin2D
Expand Down Expand Up @@ -219,35 +223,50 @@ def reconstruct(
num_iter: int = 1,
step_size: float = 0.5,
num_points: int = 60,
store_iterations: bool = False,
):
""" """
device = self._device
self.error_iterations = []

if store_iterations:
self.object_iterations = []

for a0 in range(num_iter):
error_iteration = 0
for a1 in range(self._num_datacubes):
diffraction_patterns_projected = copy_to_device(
self._diffraction_patterns_projected[a1], device
)

for a2 in range(self._object_shape_6D[0]):
object_sliced = self._forward(
x_index=a2,
tilt_deg=self._tilt_deg[a1],
num_points=num_points,
)

update = self._calculate_update(
update, error = self._calculate_update(
object_sliced=object_sliced,
diffraction_patterns_projected=diffraction_patterns_projected,
x_index=a2,
datacube_number=a1,
)

error_iteration += error

update *= step_size
self._back(
num_points=num_points,
x_index=a2,
update=update,
)

self.error_iterations.append(error_iteration)
self.error = error_iteration
if store_iterations:
self.object_iterations.append(self._object)

return self

def _prepare_datacube(
Expand Down Expand Up @@ -1099,10 +1118,10 @@ def _calculate_update(

ind = self._positions_vox_F[0][0] == x_index

update = xp.zeros(
diffraction_patterns_resampled = xp.zeros(
(self._positions_vox_dF[0][0].shape[0], object_sliced.shape[-1])
)
update[
diffraction_patterns_resampled[
xp.ravel_multi_index(
(
self._positions_vox_F[datacube_number][0][ind],
Expand All @@ -1119,7 +1138,7 @@ def _calculate_update(
)[:, None]
)

update[
diffraction_patterns_resampled[
xp.ravel_multi_index(
(
self._positions_vox_F[datacube_number][0][ind] + 1,
Expand All @@ -1136,7 +1155,7 @@ def _calculate_update(
)[:, None]
)

update[
diffraction_patterns_resampled[
xp.ravel_multi_index(
(
self._positions_vox_F[datacube_number][0][ind],
Expand All @@ -1153,7 +1172,7 @@ def _calculate_update(
)[:, None]
)

update[
diffraction_patterns_resampled[
xp.ravel_multi_index(
(
self._positions_vox_F[datacube_number][0][ind] + 1,
Expand All @@ -1169,7 +1188,19 @@ def _calculate_update(
* (self._positions_vox_dF[datacube_number][1][ind])
)[:, None]
)
return update[ind]
diffraction_patterns_resampled = diffraction_patterns_resampled[ind]
update = diffraction_patterns_resampled - object_sliced

# error = xp.mean((update.ravel()) ** 2) / xp.mean(
# (diffraction_patterns_projected.ravel()) ** 2
# )

error = xp.mean(update.ravel() ** 2) / xp.mean(
diffraction_patterns_projected.ravel() ** 2
)
error = copy_to_device(error, "cpu")

return update, error

def _back(
self,
Expand Down Expand Up @@ -1528,3 +1559,55 @@ def set_storage(self, storage):
self._storage = storage

return self

def visualize(self, plot_convergence=True, figsize=(10, 10)):
"""
vis
"""

if plot_convergence:
spec = GridSpec(
ncols=2,
nrows=2,
height_ratios=[4, 1],
hspace=0.15,
# width_ratios=[
# (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
# 1,
# ],
wspace=0.35,
)

else:
spec = GridSpec(ncols=2, nrows=1, wspace=0.35)

fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(spec[0, 0])
show(
self.object_6D.mean((2, 3, 4, 5)),
figax=(fig, ax),
cmap="magma",
title="real space object",
)

ax = fig.add_subplot(spec[0, 1])
show(
self.object_6D.mean((0, 1, 2, 5)),
figax=(fig, ax),
cmap="magma",
title="diffraction space object",
)

if plot_convergence:
ax = fig.add_subplot(spec[1, :])
ax.plot(self.error_iterations, color="b")
ax.set_xlabel("iterations")
ax.set_ylabel("error")

return self

@property
def object_6D(self):
""" 6D object"""

return self._object.reshape(self._object_shape_6D)

0 comments on commit e65ec1f

Please sign in to comment.