Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to use full polynomial in ML linesearch #487

Merged
merged 4 commits into from
Mar 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
323 changes: 311 additions & 12 deletions ptypy/engines/ML.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class ML(PositionCorrectionEngine):
lowlim = 0
help = Number of iterations before probe update starts

[poly_line_coeffs]
default = quadratic
type = str
help = How many coefficients to be used in the the linesearch
doc = choose between the 'quadratic' approximation (default) or 'all'

"""

SUPPORTED_MODELS = [Full, Vanilla, Bragg3dModel, BlockVanilla, BlockFull, GradFull, BlockGradFull]
Expand Down Expand Up @@ -287,16 +293,25 @@ def engine_iterate(self, num=1):
# In principle, the way things are now programmed this part
# could be iterated over in a real Newton-Raphson style.
t2 = time.time()
B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h)
if self.p.poly_line_coeffs == "all":
B = self.ML_model.poly_line_all_coeffs(self.ob_h, self.pr_h)
diffB = np.arange(1,len(B))*B[1:] # coefficients of poly derivative
roots = np.roots(np.flip(diffB.astype(np.double))) # roots only supports double
real_roots = np.real(roots[np.isreal(roots)]) # not interested in complex roots
if real_roots.size == 1: # single real root
self.tmin = dt(real_roots[0])
else: # find real root with smallest poly objective
evalp = lambda root: np.polyval(np.flip(B),root)
self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective
elif self.p.poly_line_coeffs == "quadratic":
B = self.ML_model.poly_line_coeffs(self.ob_h, self.pr_h)
# same as above but quicker when poly quadratic
self.tmin = dt(-0.5 * B[1] / B[2])
else:
raise NotImplementedError("poly_line_coeffs should be 'quadratic' or 'all'")

tc += time.time() - t2

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.tmin = dt(-.5 * B[1] / B[2])

self.ob_h *= self.tmin
self.pr_h *= self.tmin
self.ob += self.ob_h
Expand Down Expand Up @@ -427,6 +442,13 @@ def poly_line_coeffs(self, ob_h, pr_h):
"""
raise NotImplementedError

def poly_line_all_coeffs(self, ob_h, pr_h):
"""
Compute all the coefficients of the polynomial for line minimization
in direction h
"""
raise NotImplementedError


class GaussianModel(BaseModel):
"""
Expand Down Expand Up @@ -578,7 +600,7 @@ def poly_line_coeffs(self, ob_h, pr_h):
w = pod.upsample(w)

B[0] += np.dot(w.flat, (A0**2).flat) * Brenorm
B[1] += np.dot(w.flat, (2 * A0 * A1).flat) * Brenorm
B[1] += np.dot(w.flat, (2*A0*A1).flat) * Brenorm
B[2] += np.dot(w.flat, (A1**2 + 2*A0*A2).flat) * Brenorm

parallel.allreduce(B)
Expand All @@ -589,10 +611,101 @@ def poly_line_coeffs(self, ob_h, pr_h):
B += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B

def poly_line_all_coeffs(self, ob_h, pr_h):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make sense to have a single method with an internal switch based on 'quadratic' or 'all'? It feels like there is a lot of repetition between the two...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes there is a lot of code duplication, but I believe @daurer preferred it this way as there's no risk of breakage...

"""
Compute all the coefficients of the polynomial for line minimization
in direction h
"""

B = np.zeros((9,), dtype=np.longdouble)
Brenorm = 1. / self.LL[0]**2

# Outer loop: through diffraction patterns
for dname, diff_view in self.di.views.items():
if not diff_view.active:
continue

# Weights and intensities for this view
w = self.weights[diff_view]
I = diff_view.data

A0 = None
A1 = None
A2 = None
A3 = None
A4 = None

for name, pod in diff_view.pods.items():
if not pod.active:
continue
f = pod.fw(pod.probe * pod.object)
a = pod.fw(pod.probe * ob_h[pod.ob_view]
+ pr_h[pod.pr_view] * pod.object)
b = pod.fw(pr_h[pod.pr_view] * ob_h[pod.ob_view])

if A0 is None:
A0 = u.abs2(f).astype(np.longdouble)
A1 = 2 * np.real(f * a.conj()).astype(np.longdouble)
A2 = (2 * np.real(f * b.conj()).astype(np.longdouble)
+ u.abs2(a).astype(np.longdouble))
A3 = 2 * np.real(a * b.conj()).astype(np.longdouble)
A4 = u.abs2(b).astype(np.longdouble)
else:
A0 += u.abs2(f)
A1 += 2 * np.real(f * a.conj())
A2 += 2 * np.real(f * b.conj()) + u.abs2(a)
A3 += 2 * np.real(a * b.conj())
A4 += u.abs2(b)

if self.p.floating_intensities:
A0 *= self.float_intens_coeff[dname]
A1 *= self.float_intens_coeff[dname]
A2 *= self.float_intens_coeff[dname]
A3 *= self.float_intens_coeff[dname]
A4 *= self.float_intens_coeff[dname]

A0 = np.double(A0) - pod.upsample(I)
#A0 -= pod.upsample(I)
w = pod.upsample(w)

B[0] += np.dot(w.flat, (A0**2).flat) * Brenorm
B[1] += np.dot(w.flat, (2*A0*A1).flat) * Brenorm
B[2] += np.dot(w.flat, (A1**2 + 2*A0*A2).flat) * Brenorm
B[3] += np.dot(w.flat, (2*A0*A3 + 2*A1*A2).flat) * Brenorm
B[4] += np.dot(w.flat, (A2**2 + 2*A1*A3 + 2*A0*A4).flat) * Brenorm
B[5] += np.dot(w.flat, (2*A1*A4 + 2*A2*A3).flat) * Brenorm
B[6] += np.dot(w.flat, (A3**2 + 2*A2*A4).flat) * Brenorm
B[7] += np.dot(w.flat, (2*A3*A4).flat) * Brenorm
B[8] += np.dot(w.flat, (A4**2).flat) * Brenorm

parallel.allreduce(B)

# Object regularizer
if self.regularizer:
for name, s in self.ob.storages.items():
B[:3] += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B


class PoissonModel(BaseModel):
"""
Expand Down Expand Up @@ -730,7 +843,7 @@ def poly_line_coeffs(self, ob_h, pr_h):

B[0] += (self.LLbase[dname] + (m * (A0 - I * np.log(A0))).sum().astype(np.float64)) * Brenorm
B[1] += np.dot(m.flat, (A1*DI).flat) * Brenorm
B[2] += (np.dot(m.flat, (A2*DI).flat) + .5*np.dot(m.flat, (I*(A1/A0)**2.).flat)) * Brenorm
B[2] += (np.dot(m.flat, (A2*DI).flat) + 0.5*np.dot(m.flat, (I*(A1/A0)**2).flat)) * Brenorm

parallel.allreduce(B)

Expand All @@ -740,6 +853,96 @@ def poly_line_coeffs(self, ob_h, pr_h):
B += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B

def poly_line_all_coeffs(self, ob_h, pr_h):
"""
Compute all the coefficients of the polynomial for line minimization
in direction h
"""
B = np.zeros((9,), dtype=np.longdouble)
Brenorm = 1/(self.tot_measpts * self.LL[0])**2

# Outer loop: through diffraction patterns
for dname, diff_view in self.di.views.items():
if not diff_view.active:
continue

# Weights and intensities for this view
I = diff_view.data
m = diff_view.pod.ma_view.data

A0 = None
A1 = None
A2 = None
A3 = None
A4 = None

for name, pod in diff_view.pods.items():
if not pod.active:
continue
f = pod.fw(pod.probe * pod.object)
a = pod.fw(pod.probe * ob_h[pod.ob_view]
+ pr_h[pod.pr_view] * pod.object)
b = pod.fw(pr_h[pod.pr_view] * ob_h[pod.ob_view])

if A0 is None:
A0 = u.abs2(f).astype(np.longdouble)
A1 = 2 * np.real(f * a.conj()).astype(np.longdouble)
A2 = (2 * np.real(f * b.conj()).astype(np.longdouble)
+ u.abs2(a).astype(np.longdouble))
A3 = 2 * np.real(a * b.conj()).astype(np.longdouble)
A4 = u.abs2(b).astype(np.longdouble)
else:
A0 += u.abs2(f)
A1 += 2 * np.real(f * a.conj())
A2 += 2 * np.real(f * b.conj()) + u.abs2(a)
A3 += 2 * np.real(a * b.conj())
A4 += u.abs2(b)


if self.p.floating_intensities:
A0 *= self.float_intens_coeff[dname]
A1 *= self.float_intens_coeff[dname]
A2 *= self.float_intens_coeff[dname]
A3 *= self.float_intens_coeff[dname]
A4 *= self.float_intens_coeff[dname]

A0 += 1e-6
DI = 1. - I/A0

B[0] += (self.LLbase[dname] + (m * (A0 - I * np.log(A0))).sum().astype(np.float64)) * Brenorm
B[1] += np.dot(m.flat, (A1*DI).flat) * Brenorm
B[2] += (np.dot(m.flat, (A2*DI).flat) + 0.5*np.dot(m.flat, (I*(A1/A0)**2).flat)) * Brenorm
B[3] += (np.dot(m.flat, (A3*DI).flat) + 0.5*np.dot(m.flat, (I*((2*A1*A2)/A0**2)).flat)) * Brenorm
B[4] += (np.dot(m.flat, (A4*DI).flat) + 0.5*np.dot(m.flat, (I*((A2**2 + 2*A1*A3)/A0**2)).flat)) * Brenorm
B[5] += 0.5*np.dot(m.flat, (I*((2*A1*A4 + 2*A2*A3)/A0**2)).flat) * Brenorm
B[6] += 0.5*np.dot(m.flat, (I*((A3**2 + 2*A2*A4)/A0**2)).flat) * Brenorm
B[7] += 0.5*np.dot(m.flat, (I*((2*A3*A4)/A0**2)).flat) * Brenorm
B[8] += 0.5*np.dot(m.flat, (I*(A4/A0)**2).flat) * Brenorm

parallel.allreduce(B)

# Object regularizer
if self.regularizer:
for name, s in self.ob.storages.items():
B[:3] += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B
Expand Down Expand Up @@ -896,7 +1099,7 @@ def poly_line_coeffs(self, ob_h, pr_h):

B[0] += np.dot(w.flat, ((np.sqrt(A0) - A)**2).flat) * Brenorm
B[1] += np.dot(w.flat, (A1*DA).flat) * Brenorm
B[2] += (np.dot(w.flat, (A2*DA).flat) + .25*np.dot(w.flat, (A1**2 * A/A0**(3/2)).flat)) * Brenorm
B[2] += (np.dot(w.flat, (A2*DA).flat) + 0.25*np.dot(w.flat, (A1**2 * A/A0**(3/2)).flat)) * Brenorm

parallel.allreduce(B)

Expand All @@ -906,6 +1109,102 @@ def poly_line_coeffs(self, ob_h, pr_h):
B += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B

def poly_line_all_coeffs(self, ob_h, pr_h):
"""
Compute all the coefficients of the polynomial for line minimization
in direction h
"""

B = np.zeros((9,), dtype=np.longdouble)
Brenorm = 1. / self.LL[0]**2

# Outer loop: through diffraction patterns
for dname, diff_view in self.di.views.items():
if not diff_view.active:
continue

# Weights and amplitudes for this view
w = self.weights[diff_view]
A = np.sqrt(diff_view.data)

A0 = None
A1 = None
A2 = None
A3 = None
A4 = None

for name, pod in diff_view.pods.items():
if not pod.active:
continue
f = pod.fw(pod.probe * pod.object)
a = pod.fw(pod.probe * ob_h[pod.ob_view]
+ pr_h[pod.pr_view] * pod.object)
b = pod.fw(pr_h[pod.pr_view] * ob_h[pod.ob_view])

if A0 is None:
A0 = u.abs2(f).astype(np.longdouble)
A1 = 2 * np.real(f * a.conj()).astype(np.longdouble)
A2 = (2 * np.real(f * b.conj()).astype(np.longdouble)
+ u.abs2(a).astype(np.longdouble))
A3 = 2 * np.real(a * b.conj()).astype(np.longdouble)
A4 = u.abs2(b).astype(np.longdouble)
else:
A0 += u.abs2(f)
A1 += 2 * np.real(f * a.conj())
A2 += 2 * np.real(f * b.conj()) + u.abs2(a)
A3 += 2 * np.real(a * b.conj())
A4 += u.abs2(b)

if self.p.floating_intensities:
A0 *= self.float_intens_coeff[dname]
A1 *= self.float_intens_coeff[dname]
A2 *= self.float_intens_coeff[dname]
A3 *= self.float_intens_coeff[dname]
A4 *= self.float_intens_coeff[dname]

A0 += 1e-12 # cf Poisson model sqrt(1e-12) = 1e-6
DA = 1. - A/np.sqrt(A0)
DA32 = A/A0**(3/2)

B[0] += np.dot(w.flat, ((np.sqrt(A0) - A)**2).flat) * Brenorm
B[1] += np.dot(w.flat, (A1*DA).flat) * Brenorm
B[2] += (np.dot(w.flat, (A2*DA).flat) + 0.25*np.dot(w.flat, (A1**2 * DA32).flat)) * Brenorm
B[3] += (np.dot(w.flat, (A3*DA).flat) + 0.25*np.dot(w.flat, (2*A1*A2 * DA32).flat) - 0.125*np.dot(w.flat, (A1**3/A0**2).flat)) * Brenorm
B[4] += (np.dot(w.flat, (A4*DA).flat) + 0.25*np.dot(w.flat, ((A2**2 + 2*A1*A3) * DA32).flat) - 0.125*np.dot(w.flat, ((3*A1**2*A2)/A0**2).flat)
+ 0.015625*np.dot(w.flat, (A1**4/A0**3).flat)) * Brenorm
B[5] += (0.25*np.dot(w.flat, ((2*A2*A3 + 2*A1*A4) * DA32).flat) - 0.125*np.dot(w.flat, ((3*A1*A2**2 + 3*A1**2*A3)/A0**2).flat)
+ 0.015625*np.dot(w.flat, ((4*A1**3*A2)/A0**3).flat)) * Brenorm
B[6] += (0.25*np.dot(w.flat, ((A3**2 + 2*A2*A4) * DA32).flat) - 0.125*np.dot(w.flat, ((A2**3 + 3*A1**2*A4 + 6*A1*A2*A3)/A0**2).flat)
+ 0.015625*np.dot(w.flat, ((6*A1**2*A2**2 + 4*A1**3*A3)/A0**3).flat)) * Brenorm
B[7] += (0.25*np.dot(w.flat, (2*A3*A4 * DA32).flat) - 0.125*np.dot(w.flat, ((3*A2**2*A3 + 3*A1*A3**2 + 6*A1*A2*A4)/A0**2).flat)
+ 0.015625*np.dot(w.flat, ((4*A1*A2**3 + 12*A1**2*A2*A3 + 4*A1**3*A4)/A0**3).flat)) * Brenorm
B[8] += (0.25*np.dot(w.flat, (A4**2 * DA32).flat) - 0.125*np.dot(w.flat, ((3*A2*A3**2 + 3*A2**2*A4 + 6*A1*A3*A4)/A0**2).flat)
+ 0.015625*np.dot(w.flat, ((A2**4 + 12*A1*A2**2*A3 + 6*A1**2*A3**2 + 12*A1**2*A2*A4)/A0**3).flat)) * Brenorm

parallel.allreduce(B)

# Object regularizer
if self.regularizer:
for name, s in self.ob.storages.items():
B[:3] += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

if np.isinf(B).any() or np.isnan(B).any():
logger.warning(
'Warning! inf or nan found! Trying to continue...')
B[np.isinf(B)] = 0.
B[np.isnan(B)] = 0.

self.B = B

return B
Expand Down