Skip to content

Commit

Permalink
more careful normalizations
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed May 30, 2024
1 parent 7ed3d2e commit 936df0e
Showing 1 changed file with 77 additions and 37 deletions.
114 changes: 77 additions & 37 deletions py4DSTEM/process/phase/xray_magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,49 +887,22 @@ def _gradient_descent_adjoint(
+ (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2
)

probe_magnetic_abs = xp.abs(shifted_probes * xp.exp(1.0j * object_patches[1]))
probe_magnetic_normalization = self._sum_overlapping_patches_bincounts(
probe_magnetic_abs**2,
positions_px,
)
probe_magnetic_normalization = 1 / xp.sqrt(
1e-16
+ ((1 - normalization_min) * probe_magnetic_normalization) ** 2
+ (normalization_min * xp.max(probe_magnetic_normalization)) ** 2
)
match (self._recon_mode, self._active_measurement_index):
case (0, 0) | (1, 0): # reverse

if not fix_probe:
electrostatic_magnetic_abs = xp.abs(
object_patches[0] * xp.exp(1.0j * object_patches[1])
)
electrostatic_magnetic_normalization = xp.sum(
electrostatic_magnetic_abs**2,
axis=0,
)
electrostatic_magnetic_normalization = 1 / xp.sqrt(
1e-16
+ ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2
+ (normalization_min * xp.max(electrostatic_magnetic_normalization))
** 2
)
magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1]))

if self._recon_mode > 0:
electrostatic_abs = xp.abs(object_patches[0])
electrostatic_normalization = xp.sum(
electrostatic_abs**2,
axis=0,
probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp)
probe_magnetic_normalization = self._sum_overlapping_patches_bincounts(
probe_magnetic_abs**2,
positions_px,
)
electrostatic_normalization = 1 / xp.sqrt(
probe_magnetic_normalization = 1 / xp.sqrt(
1e-16
+ ((1 - normalization_min) * electrostatic_normalization) ** 2
+ (normalization_min * xp.max(electrostatic_normalization)) ** 2
+ ((1 - normalization_min) * probe_magnetic_normalization) ** 2
+ (normalization_min * xp.max(probe_magnetic_normalization)) ** 2
)

match (self._recon_mode, self._active_measurement_index):
case (0, 0) | (1, 0): # reverse

magnetic_exp = xp.exp(1.0j * xp.conj(object_patches[1]))

# P* exp(i M*)
electrostatic_update = self._sum_overlapping_patches_bincounts(
probe_conj * magnetic_exp * exit_waves,
Expand All @@ -950,6 +923,28 @@ def _gradient_descent_adjoint(
)

if not fix_probe:

electrostatic_magnetic_abs = xp.abs(
object_patches[0] * magnetic_exp
)
electrostatic_magnetic_normalization = xp.sum(
electrostatic_magnetic_abs**2,
axis=0,
)
electrostatic_magnetic_normalization = 1 / xp.sqrt(
1e-16
+ (
(1 - normalization_min)
* electrostatic_magnetic_normalization
)
** 2
+ (
normalization_min
* xp.max(electrostatic_magnetic_normalization)
)
** 2
)

# exp(i M*) C*
current_probe += step_size * (
xp.sum(
Expand All @@ -963,6 +958,17 @@ def _gradient_descent_adjoint(

magnetic_exp = xp.exp(-1.0j * xp.conj(object_patches[1]))

probe_magnetic_abs = xp.abs(shifted_probes * magnetic_exp)
probe_magnetic_normalization = self._sum_overlapping_patches_bincounts(
probe_magnetic_abs**2,
positions_px,
)
probe_magnetic_normalization = 1 / xp.sqrt(
1e-16
+ ((1 - normalization_min) * probe_magnetic_normalization) ** 2
+ (normalization_min * xp.max(probe_magnetic_normalization)) ** 2
)

# P* exp(-i M*)
electrostatic_update = self._sum_overlapping_patches_bincounts(
probe_conj * magnetic_exp * exit_waves,
Expand All @@ -983,6 +989,28 @@ def _gradient_descent_adjoint(
)

if not fix_probe:

electrostatic_magnetic_abs = xp.abs(
object_patches[0] * magnetic_exp
)
electrostatic_magnetic_normalization = xp.sum(
electrostatic_magnetic_abs**2,
axis=0,
)
electrostatic_magnetic_normalization = 1 / xp.sqrt(
1e-16
+ (
(1 - normalization_min)
* electrostatic_magnetic_normalization
)
** 2
+ (
normalization_min
* xp.max(electrostatic_magnetic_normalization)
)
** 2
)

# exp(-i M*) C*
current_probe += step_size * (
xp.sum(
Expand Down Expand Up @@ -1015,6 +1043,18 @@ def _gradient_descent_adjoint(
)

if not fix_probe:

electrostatic_abs = xp.abs(object_patches[0])
electrostatic_normalization = xp.sum(
electrostatic_abs**2,
axis=0,
)
electrostatic_normalization = 1 / xp.sqrt(
1e-16
+ ((1 - normalization_min) * electrostatic_normalization) ** 2
+ (normalization_min * xp.max(electrostatic_normalization)) ** 2
)

# V*
current_probe += step_size * (
xp.sum(
Expand Down

0 comments on commit 936df0e

Please sign in to comment.