Add Lipschitz preconditioner
jfowkes committed Sep 12, 2024
1 parent 0c44010 commit 1884eaa
Showing 1 changed file with 73 additions and 0 deletions: ptypy/engines/ML.py
@@ -106,6 +106,22 @@ class ML(PositionCorrectionEngine):
help = How many coefficients to be used in the linesearch
doc = choose between the 'quadratic' approximation (default) or 'all'
[lipschitz_precond]
default = False
type = bool
help = Whether to use the Lipschitz preconditioner
doc = This parameter can give faster convergence.
[lipschitz_delta_object]
default = 0.1
type = float
help = Lipschitz preconditioner damping constant for the object.
[lipschitz_delta_probe]
default = 0.1
type = float
help = Lipschitz preconditioner damping constant for the probe.
"""

SUPPORTED_MODELS = [Full, Vanilla, Bragg3dModel, BlockVanilla, BlockFull, GradFull, BlockGradFull]
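
The three lipschitz_* entries above become ordinary ML engine parameters. A rough usage sketch of switching them on from a reconstruction script follows; the engine label 'engine00' and the surrounding parameter tree are the usual ptypy template idiom and not part of this commit:

    import ptypy.utils as u

    p = u.Param()
    p.engines = u.Param()
    p.engines.engine00 = u.Param()
    p.engines.engine00.name = 'ML'
    p.engines.engine00.numiter = 100
    # Parameters introduced by this commit
    p.engines.engine00.lipschitz_precond = True       # enable the Lipschitz preconditioner
    p.engines.engine00.lipschitz_delta_object = 0.1   # damping constant for the object
    p.engines.engine00.lipschitz_delta_probe = 0.1    # damping constant for the probe

With lipschitz_precond left at its default of False, every new code path below is skipped and the engine behaves as before.
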
@@ -137,6 +153,10 @@ def __init__(self, ptycho_parent, pars=None):
# Probe gradient
self.pr_grad_new = None

# Object and probe normalisation
if self.p.lipschitz_precond:
self.ob_nrm = None
self.pr_nrm = None

# Other
self.tmin = None
@@ -172,6 +192,11 @@ def engine_initialize(self):
self.pr_grad_new = self.pr.copy(self.pr.ID + '_grad_new', fill=0.)
self.pr_h = self.pr.copy(self.pr.ID + '_h', fill=0.)

# Object and probe normalisation
if self.p.lipschitz_precond:
self.ob_nrm = self.ob.copy(self.ob.ID + '_nrm', fill=0., dtype='real')
self.pr_nrm = self.pr.copy(self.pr.ID + '_nrm', fill=0., dtype='real')

self.tmin = 1.

# Other options
@@ -290,6 +315,13 @@ def engine_iterate(self, num=1):
self.pr_grad *= self.scale_p_o
self.pr_h -= self.pr_grad

# Lipschitz preconditioner
if self.p.lipschitz_precond:
self.ob_nrm += self.p.lipschitz_delta_object
self.ob_h /= self.ob_nrm
self.pr_nrm += self.p.lipschitz_delta_probe
self.pr_h /= self.pr_nrm

# In principle, the way things are now programmed this part
# could be iterated over in a real Newton-Raphson style.
t2 = time.time()
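
Taken together with the normalisation sums accumulated in new_grad further down, the Lipschitz block above applies a damped diagonal rescaling to the conjugate-gradient search directions. In LaTeX notation (the symbols are mine, chosen to mirror the code, not taken from the source):

    \[
    h_O(\mathbf{r}) \leftarrow \frac{h_O(\mathbf{r})}{\sum_j |P_j(\mathbf{r})|^2 + \delta_O},
    \qquad
    h_P(\mathbf{r}) \leftarrow \frac{h_P(\mathbf{r})}{\sum_j |O_j(\mathbf{r})|^2 + \delta_P}
    \]

where the sums run over every view j covering pixel r, P_j and O_j are the probe and object restricted to that view, and delta_O, delta_P are lipschitz_delta_object and lipschitz_delta_probe. Pixels with strong probe overlap therefore take proportionally smaller steps, which is the intended source of the faster convergence mentioned in the parameter doc.
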
@@ -353,6 +385,11 @@ def engine_finalize(self):
del self.pr_grad_new
del self.ptycho.containers[self.pr_h.ID]
del self.pr_h
if self.p.lipschitz_precond:
del self.ptycho.containers[self.ob_nrm.ID]
del self.ob_nrm
del self.ptycho.containers[self.pr_nrm.ID]
del self.pr_nrm

# Save floating intensities into runtime
self.ptycho.runtime["float_intens"] = parallel.gather_dict(self.ML_model.float_intens_coeff)
@@ -377,6 +414,9 @@ def __init__(self, MLengine):
self.ob = self.engine.ob
self.ob_grad = self.engine.ob_grad_new
self.pr_grad = self.engine.pr_grad_new
if self.p.lipschitz_precond:
self.ob_nrm = self.engine.ob_nrm
self.pr_nrm = self.engine.pr_nrm
self.pr = self.engine.pr
self.float_intens_coeff = {}

@@ -490,6 +530,9 @@ def new_grad(self):
"""
self.ob_grad.fill(0.)
self.pr_grad.fill(0.)
if self.p.lipschitz_precond:
self.ob_nrm.fill(0.)
self.pr_nrm.fill(0.)

# We need an array for MPI
LL = np.array([0.])
@@ -531,13 +574,21 @@ def new_grad(self):
self.ob_grad[pod.ob_view] += 2. * xi * pod.probe.conj()
self.pr_grad[pod.pr_view] += 2. * xi * pod.object.conj()

# Compute normalisations for object and probe
if self.p.lipschitz_precond:
self.ob_nrm[pod.ob_view] += u.abs2(pod.probe)
self.pr_nrm[pod.pr_view] += u.abs2(pod.object)

diff_view.error = LLL
error_dct[dname] = np.array([0, LLL / np.prod(DI.shape), 0])
LL += LLL

# MPI reduction of gradients
self.ob_grad.allreduce()
self.pr_grad.allreduce()
if self.p.lipschitz_precond:
self.ob_nrm.allreduce()
self.pr_nrm.allreduce()
parallel.allreduce(LL)

# Object regularizer
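
The same accumulate-then-divide pattern is repeated in each likelihood model below. A minimal NumPy illustration of the idea on toy arrays (made-up shapes, positions and names, not the ptypy Container/View API):

    import numpy as np

    rng = np.random.default_rng(0)
    ob_h = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))  # toy object search direction
    probe = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16)) # toy probe
    positions = [(0, 0), (8, 8), (16, 24), (32, 32)]                           # toy scan positions
    delta_object = 0.1                                                         # plays the role of lipschitz_delta_object

    # Accumulate |probe|^2 into the object normalisation over all overlapping views
    ob_nrm = np.zeros(ob_h.shape)
    for (i, j) in positions:
        ob_nrm[i:i+16, j:j+16] += np.abs(probe)**2

    # Damp and divide: strongly illuminated pixels receive smaller steps
    ob_h /= (ob_nrm + delta_object)
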
@@ -733,6 +784,9 @@ def new_grad(self):
"""
self.ob_grad.fill(0.)
self.pr_grad.fill(0.)
if self.p.lipschitz_precond:
self.ob_nrm.fill(0.)
self.pr_nrm.fill(0.)

# We need an array for MPI
LL = np.array([0.])
@@ -774,13 +828,21 @@ def new_grad(self):
self.ob_grad[pod.ob_view] += 2 * xi * pod.probe.conj()
self.pr_grad[pod.pr_view] += 2 * xi * pod.object.conj()

# Compute normalisations for object and probe
if self.p.lipschitz_precond:
self.ob_nrm[pod.ob_view] += u.abs2(pod.probe)
self.pr_nrm[pod.pr_view] += u.abs2(pod.object)

diff_view.error = LLL
error_dct[dname] = np.array([0, LLL / np.prod(DI.shape), 0])
LL += LLL

# MPI reduction of gradients
self.ob_grad.allreduce()
self.pr_grad.allreduce()
if self.p.lipschitz_precond:
self.ob_nrm.allreduce()
self.pr_nrm.allreduce()
parallel.allreduce(LL)

# Object regularizer
@@ -989,6 +1051,9 @@ def new_grad(self):
"""
self.ob_grad.fill(0.)
self.pr_grad.fill(0.)
if self.p.lipschitz_precond:
self.ob_nrm.fill(0.)
self.pr_nrm.fill(0.)

# We need an array for MPI
LL = np.array([0.])
@@ -1030,13 +1095,21 @@ def new_grad(self):
self.ob_grad[pod.ob_view] += 2. * xi * pod.probe.conj()
self.pr_grad[pod.pr_view] += 2. * xi * pod.object.conj()

# Compute normalisations for object and probe
if self.p.lipschitz_precond:
self.ob_nrm[pod.ob_view] += u.abs2(pod.probe)
self.pr_nrm[pod.pr_view] += u.abs2(pod.object)

diff_view.error = LLL
error_dct[dname] = np.array([0, LLL / np.prod(DA.shape), 0])
LL += LLL

# MPI reduction of gradients
self.ob_grad.allreduce()
self.pr_grad.allreduce()
if self.p.lipschitz_precond:
self.ob_nrm.allreduce()
self.pr_nrm.allreduce()
parallel.allreduce(LL)

# Object regularizer
