Skip to content

Commit

Permalink
Merge pull request #311 from carterbox/transform
Browse files Browse the repository at this point in the history
NEW: Use fixed confidence for position regularization
  • Loading branch information
carterbox authored May 20, 2024
2 parents 1bff99d + fd2bc63 commit 4ae2596
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 53 deletions.
29 changes: 14 additions & 15 deletions src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,12 @@ def astuple(self) -> tuple:
self.t1,
)

def __call__(self, x: np.ndarray, gpu=False) -> np.ndarray:
def __call__(self, x: np.ndarray, gpu=False, shift=True) -> np.ndarray:
xp = cp.get_array_module(x)
return (x @ self.asarray(xp)) + xp.array((self.t0, self.t1))
result = x @ self.asarray(xp)
if shift:
result += xp.array((self.t0, self.t1))
return result


def estimate_global_transformation(
Expand Down Expand Up @@ -665,19 +668,13 @@ def _affine_position_helper(
scan,
position_options: PositionOptions,
max_error,
relax=0.1,
relax=0.9,
):
predicted_positions = position_options.transform(
position_options.initial_scan)
err = predicted_positions - position_options.initial_scan
# constrain more the probes in flat regions
W = relax * (1 - (position_options.confidence /
(1 + position_options.confidence)))
# penalize positions that are further than max_error from origin; avoid travel larger than max error
W = cp.minimum(10 * relax,
W + cp.maximum(0, err - max_error)**2 / max_error**2)
# allow free movement in depenence on realibility and max allowed error
new_scan = scan * (1 - W) + W * predicted_positions
position_options.initial_scan,
shift=False,
)
new_scan = scan * (1 - relax) + relax * predicted_positions
return new_scan


Expand Down Expand Up @@ -762,13 +759,15 @@ def gaussian_gradient(
sigma=sigma,
order=1,
axis=-2,
mode='nearest',
mode="nearest",
truncate=6.0,
),
cupyx.scipy.ndimage.gaussian_filter1d(
-x,
sigma=sigma,
order=1,
axis=-1,
mode='nearest',
mode="nearest",
truncate=6.0,
),
)
41 changes: 3 additions & 38 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,46 +633,11 @@ def keep_some_args_constant(

if position_options:
m = 0

# TODO: Try adjusting gradient sigma property
grad_x, grad_y = tike.ptycho.position.gaussian_gradient(
bpatches[blo:bhi])

# start section to compute position certainty metric
crop = probe.shape[-1] // 4
total_illumination = op.diffraction.patch.fwd(
images=object_preconditioner,
positions=scan[lo:hi],
patch_width=probe.shape[-1],
)[:, crop:-crop, crop:-crop].real

power = cp.abs(probe[0, 0, 0, crop:-crop, crop:-crop])**2

dX = cp.mean(
cp.abs(grad_x[:, 0, 0, crop:-crop, crop:-crop]).real *
total_illumination * power,
axis=(-2, -1),
keepdims=False,
)
dY = cp.mean(
cp.abs(grad_y[:, 0, 0, crop:-crop, crop:-crop]).real *
total_illumination * power,
axis=(-2, -1),
keepdims=False,
bpatches[blo:bhi],
sigma=0.333,
)

total_variation = cp.sqrt(cp.stack(
[dX, dY],
axis=1,
))
mean_variation = (cp.mean(
total_variation**4,
axis=0,
) + 1e-6)
position_options.confidence[
lo:hi] = total_variation**4 / mean_variation
# end section to compute position certainty metric

crop = probe.shape[-1] // 4
position_update_numerator[lo:hi, ..., 0] = cp.sum(
cp.real(
cp.conj(grad_x[..., crop:-crop, crop:-crop] *
Expand Down

0 comments on commit 4ae2596

Please sign in to comment.