Skip to content

Commit

Permalink
Make requested changes to ML.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jfowkes authored Dec 6, 2024
1 parent 4a74199 commit 501eb4d
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions ptypy/engines/ML.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,8 @@ def __init__(self, ptycho_parent, pars=None):
self.pr_grad_new = None

# Object and probe fluence maps
if self.p.wavefield_precond:
self.ob_fln = None
self.pr_fln = None
self.ob_fln = None
self.pr_fln = None

# Other
self.tmin = None
Expand Down Expand Up @@ -257,11 +256,9 @@ def engine_iterate(self, num=1):

# Wavefield preconditioner
if self.p.wavefield_precond:
self.ob_fln += self.p.wavefield_delta_object
self.pr_fln += self.p.wavefield_delta_probe
for name, s in new_ob_grad.storages.items():
new_ob_grad.storages[name].data /= np.sqrt(self.ob_fln.storages[name].data)
new_pr_grad.storages[name].data /= np.sqrt(self.pr_fln.storages[name].data)
new_ob_grad.storages[name].data /= np.sqrt(self.ob_fln.storages[name].data + self.p.wavefield_delta_object)
new_pr_grad.storages[name].data /= np.sqrt(self.pr_fln.storages[name].data + self.p.wavefield_delta_probe)

# Smoothing preconditioner
if self.smooth_gradient:
Expand Down Expand Up @@ -312,14 +309,17 @@ def engine_iterate(self, num=1):

# 3. Next conjugate
self.ob_h *= bt / self.tmin
# Smoothing and wavefield preconditioners for the object
if self.smooth_gradient and self.p.wavefield_precond:
for name, s in self.ob_h.storages.items():
s.data[:] -= self.smooth_gradient(self.ob_grad.storages[name].data / np.sqrt(self.ob_fln.storages[name].data))
elif self.p.wavefield_precond:
# Wavefield preconditioner for the object (with and without smoothing preconditioner)
if self.p.wavefield_precond:
for name, s in self.ob_h.storages.items():
s.data[:] -= self.ob_grad.storages[name].data / np.sqrt(self.ob_fln.storages[name].data)
elif self.smooth_gradient:
if self.smooth_gradient:
s.data[:] -= self.smooth_gradient(self.ob_grad.storages[name].data
/ np.sqrt(self.ob_fln.storages[name].data + self.p.wavefield_delta_object))
else:
s.data[:] -= (self.ob_grad.storages[name].data
/ np.sqrt(self.ob_fln.storages[name].data + self.p.wavefield_delta_object))
# Smoothing preconditioner for the object
if self.smooth_gradient:
for name, s in self.ob_h.storages.items():
s.data[:] -= self.smooth_gradient(self.ob_grad.storages[name].data)
else:
Expand All @@ -330,7 +330,8 @@ def engine_iterate(self, num=1):
# Wavefield preconditioner for the probe
if self.p.wavefield_precond:
for name, s in self.pr_h.storages.items():
s.data[:] -= self.pr_grad.storages[name].data / np.sqrt(self.pr_fln.storages[name].data)
s.data[:] -= (self.pr_grad.storages[name].data
/ np.sqrt(self.pr_fln.storages[name].data + self.p.wavefield_delta_probe))
else:
self.pr_h -= self.pr_grad

Expand Down Expand Up @@ -426,9 +427,8 @@ def __init__(self, MLengine):
self.ob = self.engine.ob
self.ob_grad = self.engine.ob_grad_new
self.pr_grad = self.engine.pr_grad_new
if self.p.wavefield_precond:
self.ob_fln = self.engine.ob_fln
self.pr_fln = self.engine.pr_fln
self.ob_fln = self.engine.ob_fln
self.pr_fln = self.engine.pr_fln
self.pr = self.engine.pr
self.float_intens_coeff = {}

Expand Down

0 comments on commit 501eb4d

Please sign in to comment.