Add required kernel function
jfowkes committed Jan 6, 2025
1 parent 17ba4ad commit fe0df7d
Showing 2 changed files with 33 additions and 2 deletions.
31 changes: 31 additions & 0 deletions ptypy/accelerate/base/kernels.py
@@ -269,6 +269,37 @@ def make_model(self, b_aux, addr):
    tf = aux.reshape(sh[0], self.nmodes, sh[1], sh[2])
    Imodel[:] = ((tf * tf.conj()).real).sum(1)

def make_a012_notb(self, b_f, b_a, addr, I, fic):
    """Variant of make_a012 without the second-order direction buffer b_b:
    fills A0, A1, A2 with the coefficients of the quadratic expansion of
    the model intensity along the search direction a."""

    # reference shape (= GPU global dims)
    sh = I.shape

    # stopper
    maxz = I.shape[0]

    A0 = self.npy.Imodel
    A1 = self.npy.LLerr
    A2 = self.npy.LLden

    # batch buffers
    f = b_f[:maxz * self.nmodes]
    a = b_a[:maxz * self.nmodes]

    ## Actual math ## (subset of FUK.fourier_error)
    fc = fic.reshape((maxz, 1, 1))

    # zeroth order: scaled model intensity minus measured intensity
    A0.fill(0.)
    tf = np.real(f * f.conj()).astype(self.ftype)
    A0[:maxz] = np.double(tf.reshape(maxz, self.nmodes, sh[1], sh[2]).sum(1) * fc) - I

    # first order in the step t: 2 Re(f a*), summed over modes
    A1.fill(0.)
    tf = 2. * np.real(f * a.conj())
    A1[:maxz] = tf.reshape(maxz, self.nmodes, sh[1], sh[2]).sum(1) * fc

    # second order: |a|^2 only (the f b* cross term of make_a012 has no b here)
    A2.fill(0.)
    tf = np.real(a * a.conj())
    A2[:maxz] = tf.reshape(maxz, self.nmodes, sh[1], sh[2]).sum(1) * fc
    return

def make_a012(self, b_f, b_a, b_b, addr, I, fic):

    # reference shape (= GPU global dims)
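For reference, the three buffers hold the exact quadratic expansion of the scaled model intensity along the line-search direction: with exit waves f, direction a and per-view scaling fic, sum_modes |f + t*a|^2 * fic - I = A0 + A1*t + A2*t^2 for every step t. A minimal standalone NumPy check of this identity (synthetic shapes and hypothetical names, independent of the ptypy kernel class):

import numpy as np

# Synthetic stand-ins: 2 views, 3 modes, 4x4 pixels (all names hypothetical).
nviews, nmodes, rows, cols = 2, 3, 4, 4
rng = np.random.default_rng(0)
cshape = (nviews * nmodes, rows, cols)
f = rng.standard_normal(cshape) + 1j * rng.standard_normal(cshape)  # exit waves
a = rng.standard_normal(cshape) + 1j * rng.standard_normal(cshape)  # search direction
fic = rng.random(nviews).reshape(nviews, 1, 1)                      # intensity scaling
I = rng.random((nviews, rows, cols))                                # measured intensity

def mode_sum(x):
    # Sum over the mode axis, as the kernel's reshape(...).sum(1) does
    return x.reshape(nviews, nmodes, rows, cols).sum(1)

# Coefficients exactly as make_a012_notb computes them
A0 = mode_sum(np.real(f * f.conj())) * fic - I
A1 = mode_sum(2. * np.real(f * a.conj())) * fic
A2 = mode_sum(np.real(a * a.conj())) * fic

# Check: for any step t, sum_modes |f + t*a|^2 * fic - I == A0 + t*A1 + t^2*A2
t = 0.37
lhs = mode_sum(np.abs(f + t * a) ** 2) * fic - I
rhs = A0 + t * A1 + t ** 2 * A2
assert np.allclose(lhs, rhs)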
4 changes: 2 additions & 2 deletions ptypy/custom/MLSeparateGrads_serial.py
@@ -536,7 +536,7 @@ def poly_line_coeffs_ob(self, c_ob_h):
    f[:] = FW(f)
    a[:] = FW(a)

-    GDK.make_a012(f, a, 0, addr, I, fic) # FIXME: need new kernel
+    GDK.make_a012_notb(f, a, addr, I, fic)
    GDK.fill_b(addr, Brenorm, w, B)

    parallel.allreduce(B)
@@ -596,7 +596,7 @@ def poly_line_coeffs_pr(self, c_pr_h):
    f[:] = FW(f)
    a[:] = FW(a)

-    GDK.make_a012(f, a, 0, addr, I, fic) # FIXME: need new kernel
+    GDK.make_a012_notb(f, a, addr, I, fic)
    GDK.fill_b(addr, Brenorm, w, B)

    parallel.allreduce(B)
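For context beyond this diff: GDK.fill_b folds the per-pixel coefficients into the scalars B[0..2] of the weighted error polynomial in the step t, which the engine minimizes after the allreduce. A hedged sketch of that bookkeeping, assuming the usual second-order expansion of sum(w * (A0 + A1*t + A2*t**2)**2); the actual ptypy implementation may differ in details:

import numpy as np

def fill_b_sketch(w, A0, A1, A2, Brenorm, B):
    # Coefficients of sum(w * (A0 + A1*t + A2*t**2)**2) up to order t^2:
    # t^0: w*A0^2, t^1: 2*w*A0*A1, t^2: w*(A1^2 + 2*A0*A2)
    B[0] += np.sum(w * A0 ** 2) * Brenorm
    B[1] += 2. * np.sum(w * A0 * A1) * Brenorm
    B[2] += np.sum(w * (A1 ** 2 + 2. * A0 * A2)) * Brenorm

# After parallel.allreduce(B), the minimizing step of the quadratic
# B[0] + B[1]*t + B[2]*t**2 is t_min = -0.5 * B[1] / B[2].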
