Skip to content

Commit

Permalink
Tidy and add full polynomial for Euclid model
Browse files Browse the repository at this point in the history
  • Loading branch information
jfowkes committed Feb 14, 2024
1 parent 3cc7b86 commit 8c4ca44
Showing 1 changed file with 96 additions and 12 deletions.
108 changes: 96 additions & 12 deletions ptypy/engines/ML.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def engine_iterate(self, num=1):
evalp = lambda root: np.polyval(np.flip(B),root)
self.tmin = dt(min(real_roots, key=evalp)) # root with smallest poly objective
else: # same as above but quicker when poly quadratic
self.tmin = dt(-.5 * B[1] / B[2])
self.tmin = dt(-0.5 * B[1] / B[2])
self.ob_h *= self.tmin
self.pr_h *= self.tmin
self.ob += self.ob_h
Expand Down Expand Up @@ -603,7 +603,7 @@ def poly_line_coeffs(self, ob_h, pr_h):
w = pod.upsample(w)

B[0] += np.dot(w.flat, (A0**2).flat) * Brenorm
B[1] += np.dot(w.flat, (2 * A0 * A1).flat) * Brenorm
B[1] += np.dot(w.flat, (2*A0*A1).flat) * Brenorm
B[2] += np.dot(w.flat, (A1**2 + 2*A0*A2).flat) * Brenorm

parallel.allreduce(B)
Expand Down Expand Up @@ -691,7 +691,7 @@ def poly_line_all_coeffs(self, ob_h, pr_h):
if self.regularizer:
for name, s in self.ob.storages.items():
B[:3] += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)
ob_h.storages[name].data, s.data)

self.B = B

Expand Down Expand Up @@ -834,7 +834,7 @@ def poly_line_coeffs(self, ob_h, pr_h):

B[0] += (self.LLbase[dname] + (m * (A0 - I * np.log(A0))).sum().astype(np.float64)) * Brenorm
B[1] += np.dot(m.flat, (A1*DI).flat) * Brenorm
B[2] += (np.dot(m.flat, (A2*DI).flat) + .5*np.dot(m.flat, (I*(A1/A0)**2.).flat)) * Brenorm
B[2] += (np.dot(m.flat, (A2*DI).flat) + 0.5*np.dot(m.flat, (I*(A1/A0)**2).flat)) * Brenorm

parallel.allreduce(B)

Expand Down Expand Up @@ -906,13 +906,13 @@ def poly_line_all_coeffs(self, ob_h, pr_h):

B[0] += (self.LLbase[dname] + (m * (A0 - I * np.log(A0))).sum().astype(np.float64)) * Brenorm
B[1] += np.dot(m.flat, (A1*DI).flat) * Brenorm
B[2] += (np.dot(m.flat, (A2*DI).flat) + .5*np.dot(m.flat, (I*(A1/A0)**2.).flat)) * Brenorm
B[3] += (np.dot(m.flat, (A3*DI).flat) + np.dot(m.flat, (I*((A1*A2)/(A0**2.))).flat)) * Brenorm
B[4] += (np.dot(m.flat, (A4*DI).flat) + .5*np.dot(m.flat, (I*(A2/A0)**2.).flat) + np.dot(m.flat, (I*((A1*A3)/(A0**2.))).flat)) * Brenorm
B[5] += (np.dot(m.flat, (I*((A1*A4)/(A0**2.))).flat) + np.dot(m.flat, (I*((A2*A3)/(A0**2.))).flat)) * Brenorm
B[6] += (.5*np.dot(m.flat, (I*(A3/A0)**2.).flat) + np.dot(m.flat, (I*((A2*A4)/(A0**2.))).flat)) * Brenorm
B[7] += np.dot(m.flat, (I*((A3*A4)/(A0**2.))).flat) * Brenorm
B[8] += (.5*np.dot(m.flat, (I*(A4/A0)**2.).flat)) * Brenorm
B[2] += (np.dot(m.flat, (A2*DI).flat) + 0.5*np.dot(m.flat, (I*(A1/A0)**2).flat)) * Brenorm
B[3] += (np.dot(m.flat, (A3*DI).flat) + 0.5*np.dot(m.flat, (I*((2*A1*A2)/A0**2)).flat)) * Brenorm
B[4] += (np.dot(m.flat, (A4*DI).flat) + 0.5*np.dot(m.flat, (I*((A2**2 + 2*A1*A3)/A0**2)).flat)) * Brenorm
B[5] += 0.5*np.dot(m.flat, (I*((2*A1*A4 + 2*A2*A3)/A0**2)).flat) * Brenorm
B[6] += 0.5*np.dot(m.flat, (I*((A3**2 + 2*A2*A4)/A0**2)).flat) * Brenorm
B[7] += 0.5*np.dot(m.flat, (I*((2*A3*A4)/A0**2)).flat) * Brenorm
B[8] += 0.5*np.dot(m.flat, (I*(A4/A0)**2).flat) * Brenorm

parallel.allreduce(B)

Expand Down Expand Up @@ -1078,7 +1078,7 @@ def poly_line_coeffs(self, ob_h, pr_h):

B[0] += np.dot(w.flat, ((np.sqrt(A0) - A)**2).flat) * Brenorm
B[1] += np.dot(w.flat, (A1*DA).flat) * Brenorm
B[2] += (np.dot(w.flat, (A2*DA).flat) + .25*np.dot(w.flat, (A1**2 * A/A0**(3/2)).flat)) * Brenorm
B[2] += (np.dot(w.flat, (A2*DA).flat) + 0.25*np.dot(w.flat, (A1**2 * A/A0**(3/2)).flat)) * Brenorm

parallel.allreduce(B)

Expand All @@ -1092,6 +1092,90 @@ def poly_line_coeffs(self, ob_h, pr_h):

return B

def poly_line_all_coeffs(self, ob_h, pr_h):
"""
Compute all the coefficients of the polynomial for line minimization
in direction h
"""

B = np.zeros((9,), dtype=np.longdouble)
Brenorm = 1. / self.LL[0]**2

# Outer loop: through diffraction patterns
for dname, diff_view in self.di.views.items():
if not diff_view.active:
continue

# Weights and amplitudes for this view
w = self.weights[diff_view]
A = np.sqrt(diff_view.data)

A0 = None
A1 = None
A2 = None
A3 = None
A4 = None

for name, pod in diff_view.pods.items():
if not pod.active:
continue
f = pod.fw(pod.probe * pod.object)
a = pod.fw(pod.probe * ob_h[pod.ob_view]
+ pr_h[pod.pr_view] * pod.object)
b = pod.fw(pr_h[pod.pr_view] * ob_h[pod.ob_view])

if A0 is None:
A0 = u.abs2(f).astype(np.longdouble)
A1 = 2 * np.real(f * a.conj()).astype(np.longdouble)
A2 = (2 * np.real(f * b.conj()).astype(np.longdouble)
+ u.abs2(a).astype(np.longdouble))
A3 = 2 * np.real(a * b.conj()).astype(np.longdouble)
A4 = u.abs2(b).astype(np.longdouble)
else:
A0 += u.abs2(f)
A1 += 2 * np.real(f * a.conj())
A2 += 2 * np.real(f * b.conj()) + u.abs2(a)
A3 += 2 * np.real(a * b.conj())
A4 += u.abs2(b)

if self.p.floating_intensities:
A0 *= self.float_intens_coeff[dname]
A1 *= self.float_intens_coeff[dname]
A2 *= self.float_intens_coeff[dname]
A3 *= self.float_intens_coeff[dname]
A4 *= self.float_intens_coeff[dname]

A0 += 1e-12 # cf Poisson model sqrt(1e-12) = 1e-6
DA = 1. - A/np.sqrt(A0)
DA32 = A/A0**(3/2)

B[0] += np.dot(w.flat, ((np.sqrt(A0) - A)**2).flat) * Brenorm
B[1] += np.dot(w.flat, (A1*DA).flat) * Brenorm
B[2] += (np.dot(w.flat, (A2*DA).flat) + 0.25*np.dot(w.flat, (A1**2 * DA32).flat)) * Brenorm
B[3] += (np.dot(w.flat, (A3*DA).flat) + 0.25*np.dot(w.flat, (2*A1*A2 * DA32).flat) - 0.125*np.dot(w.flat, (A1**3/A0**2).flat)) * Brenorm
B[4] += (np.dot(w.flat, (A4*DA).flat) + 0.25*np.dot(w.flat, ((A2**2 + 2*A1*A3) * DA32).flat) - 0.125*np.dot(w.flat, ((3*A1**2*A2)/A0**2).flat)
+ 0.015625*np.dot(w.flat, (A1**4/A0**3).flat)) * Brenorm
B[5] += (0.25*np.dot(w.flat, ((2*A2*A3 + 2*A1*A4) * DA32).flat) - 0.125*np.dot(w.flat, ((3*A1*A2**2 + 3*A1**2*A3)/A0**2).flat)
+ 0.015625*np.dot(w.flat, ((4*A1**3*A2)/A0**3).flat)) * Brenorm
B[6] += (0.25*np.dot(w.flat, ((A3**2 + 2*A2*A4) * DA32).flat) - 0.125*np.dot(w.flat, ((A2**3 + 3*A1**2*A4 + 6*A1*A2*A3)/A0**2).flat)
+ 0.015625*np.dot(w.flat, ((6*A1**2*A2**2 + 4*A1**3*A3)/A0**3).flat)) * Brenorm
B[7] += (0.25*np.dot(w.flat, (2*A3*A4 * DA32).flat) - 0.125*np.dot(w.flat, ((3*A2**2*A3 + 3*A1*A3**2 + 6*A1*A2*A4)/A0**2).flat)
+ 0.015625*np.dot(w.flat, ((4*A1*A2**3 + 12*A1**2*A2*A3 + 4*A1**3*A4)/A0**3).flat)) * Brenorm
B[8] += (0.25*np.dot(w.flat, (A4**2 * DA32).flat) - 0.125*np.dot(w.flat, ((3*A2*A3**2 + 3*A2**2*A4 + 6*A1*A3*A4)/A0**2).flat)
+ 0.015625*np.dot(w.flat, ((A2**4 + 12*A1*A2**2*A3 + 6*A1**2*A3**2 + 12*A1**2*A2*A4)/A0**3).flat)) * Brenorm

parallel.allreduce(B)

# Object regularizer
if self.regularizer:
for name, s in self.ob.storages.items():
B[:3] += Brenorm * self.regularizer.poly_line_coeffs(
ob_h.storages[name].data, s.data)

self.B = B

return B


class Regul_del2(object):
"""\
Expand Down

0 comments on commit 8c4ca44

Please sign in to comment.