Skip to content

Commit

Permalink
no more 6D object !!!!!
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Jun 8, 2024
1 parent d0c6e3d commit 0f4ba9d
Showing 1 changed file with 128 additions and 24 deletions.
152 changes: 128 additions & 24 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,17 @@ def preprocess(
else:
diffraction_shape = self._initial_datacube_shape[-1]
self._object = xp_storage.zeros(
self._object_shape_x_y_z
+ (
diffraction_shape,
diffraction_shape,
diffraction_shape,
(
self._object_shape_x_y_z[0],
self._object_shape_x_y_z[1] * self._object_shape_x_y_z[2],
diffraction_shape * diffraction_shape * diffraction_shape,
),
)
self._object_shape_6D = self._object_shape_x_y_z + (
diffraction_shape,
diffraction_shape,
diffraction_shape,
)

# ellpitical fitting?!

Expand Down Expand Up @@ -240,20 +244,20 @@ def preprocess(

def _forward(
self,
datacube_number: float,
tilt_deg: int,
slice_number: int,
tilt_deg: float,
num_points: int,
):
"""
Forward projection of object for simulation of diffraction data
Parameters
----------
datacube_number: int
index of datacube
slice_number: int
x slice for forward projection
tilt_deg: float
tilt of object in degrees
num_points: float
num_points: int
number of points for bilinear interpolation
Returns
Expand All @@ -264,26 +268,126 @@ def _forward(
datacube with diffraction data reshapped in 2D arrays
"""
xp = self._xp
s = self._object.shape
current_object = xp.asarray(self._object)
current_object_sliced = xp.zeros((s[0], s[1], s[-1], s[-1]))
s = self._object_shape_6D
obj = self._object[slice_number]
tilt = xp.deg2rad(tilt_deg)

for a0 in range(s[0]):
current_object_projected = self._real_space_radon(
current_object=current_object,
tilt_deg=tilt_deg,
x_index=a0,
num_points=num_points,
###solve for real space coordinates
line_z = xp.arange(0, 1, 1 / num_points) * (s[2] - 1)
line_y = line_z * xp.tan(tilt)
offset = xp.arange(s[1], dtype="int")

yF = xp.floor(line_y).astype("int")
zF = xp.floor(line_z).astype("int")
dy = line_y - yF
dz = line_z - zF

ind0 = np.hstack(
(
xp.tile(yF, (s[1], 1)) + offset[:, None],
xp.tile(yF + 1, (s[1], 1)) + offset[:, None],
xp.tile(yF, (s[1], 1)) + offset[:, None],
xp.tile(yF + 1, (s[1], 1)) + offset[:, None],
)
)

current_object_sliced[a0] = self._diffraction_space_slice(
current_object_projected=current_object_projected,
tilt_deg=tilt_deg,
ind1 = np.hstack(
(
xp.tile(zF, (s[1], 1)),
xp.tile(zF, (s[1], 1)),
xp.tile(zF + 1, (s[1], 1)),
xp.tile(zF + 1, (s[1], 1)),
)
)

current_object_sliced_2D = self._reshape_4D_array_to_2D(current_object_sliced)
weights_real = np.hstack(
(
xp.tile(((1 - dy) * (1 - dz)), (s[1], 1)),
xp.tile(((dy) * (1 - dz)), (s[1], 1)),
xp.tile(((1 - dy) * (dz)), (s[1], 1)),
xp.tile(((dy) * (dz)), (s[1], 1)),
)
)

###solve for diffraction space coordinates
xp = np
tilt_deg = 5
tilt = xp.deg2rad(tilt_deg)

l = s[-1] * xp.cos(tilt)
line_y_diff = xp.arange(-1 * (l) / 2, l / 2, l / s[-1])
line_z_diff = line_y_diff * xp.tan(tilt) + (s[-1] - 1) / 2
line_y_diff += s[-1] / 2

yF_diff = xp.floor(line_y_diff).astype("int")
zF_diff = xp.floor(line_z_diff).astype("int")
dy_diff = line_y_diff - yF_diff
dz_diff = line_z_diff - zF_diff

qx = xp.arange(11)
qy = xp.arange(11)
qxx, qyy = xp.meshgrid(qx, qy)

ind0_diff = np.hstack(
(
xp.tile(yF_diff, (s[-1], 1)),
xp.tile(yF_diff + 1, (s[-1], 1)),
xp.tile(yF_diff, (s[-1], 1)),
xp.tile(yF_diff + 1, (s[-1], 1)),
)
)

ind1_diff = np.hstack(
(
xp.tile(zF_diff, (s[-1], 1)),
xp.tile(zF_diff, (s[-1], 1)),
xp.tile(zF_diff + 1, (s[-1], 1)),
xp.tile(zF_diff + 1, (s[-1], 1)),
)
)

weights_diff = np.hstack(
(
xp.tile(((1 - dy_diff) * (1 - dz_diff)), (s[-1], 1)),
xp.tile(((dy_diff) * (1 - dz_diff)), (s[-1], 1)),
xp.tile(((1 - dy_diff) * (dz_diff)), (s[-1], 1)),
xp.tile(((dy_diff) * (dz_diff)), (s[-1], 1)),
)
)

ind_diff = xp.ravel_multi_index(
(
xp.tile(qxx.ravel(), (1, 4)),
ind0_diff.ravel(),
ind1_diff.ravel(),
),
(s[-1], s[-1], s[-1]),
"clip",
)

ind_max = self._ind_diffraction_ravel.max()
bincount_x = xp.tile(
weights_diff.ravel() * (xp.tile(self._ind_diffraction_ravel, (1, 4))),
(1, s[1]),
) + xp.repeat(xp.arange(s[1]), ind_diff.shape[1])
bincount_x = xp.asarray(bincount_x[0], dtype="int")

obj_projected = (
xp.bincount(
bincount_x,
(
obj[xp.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip"),]
* weights_real[:, :, None]
)
.sum(1)[:, ind_diff]
.ravel(),
minlength=self._q_length * 4 * s[1],
)
.reshape(s[1], self._q_length, 4) ## check this reshape
.sum(2)[:, self._circular_mask_bincount]
)

return current_object_sliced_2D
return obj_projected

def _prepare_datacube(
self,
Expand Down

0 comments on commit 0f4ba9d

Please sign in to comment.