CUDA kernels for upsampling, implemented for DM, still need to do ML
daurer committed Apr 9, 2021
1 parent 9b458e7 commit 556d111
Showing 28 changed files with 1,485 additions and 269 deletions.
24 changes: 24 additions & 0 deletions ptypy/accelerate/base/array_utils.py
@@ -3,6 +3,7 @@
'''
import numpy as np
from scipy import ndimage as ndi
from ptypy import utils as u


def dot(A, B, acc_dtype=np.float64):
@@ -148,3 +149,26 @@ def crop_pad_2d_simple(A, B):
b1, b2 = B.shape[-2:]
offset = [0, a1 // 2 - b1 // 2, a2 // 2 - b2 // 2]
fill3D(A, B, offset)

def resample(A, B):
"""
Resamples the last two dimensions of B onto the shape of A and places the result in A.
The ratio between the shapes needs to be a power of 2 along the last two dimensions.
upsampling (A larger than B): nearest-neighbour interpolation
downsampling (B larger than A): integration over neighbouring regions
"""
assert A.ndim > 2, "Arrays must have more than 2 dimensions"
assert B.ndim > 2, "Arrays must have more than 2 dimensions"
assert A.shape[:-2] == B.shape[:-2], "Arrays must have the same shape except along the last 2 dimensions"
assert A.shape[-2] == A.shape[-1], "Last two dimensions must be of equal length"
assert B.shape[-2] == B.shape[-1], "Last two dimensions must be of equal length"
# same sampling, no need to call this function
assert A.shape != B.shape, "A and B have the same shape, no need to do any resampling"
# upsampling
if A.shape[-1] > B.shape[-1]:
resample = A.shape[-1] // B.shape[-1]
A[:] = u.repeat_2d(B, resample) / (resample**2)
# downsampling
elif A.shape[-1] < B.shape[-1]:
resample = B.shape[-1] // A.shape[-1]
A[:] = u.rebin_2d(B, resample) * (resample**2)
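
A minimal usage sketch of resample, assuming u.repeat_2d tiles each pixel into a resample x resample block and u.rebin_2d averages over such blocks; with the scaling above, both directions then conserve the total intensity:

    import numpy as np
    from ptypy.accelerate.base import array_utils as au

    B = np.random.rand(2, 4, 4).astype(np.float32)   # low-resolution stack
    A = np.zeros((2, 8, 8), dtype=np.float32)        # 2x upsampled target

    au.resample(A, B)                                # nearest-neighbour upsampling
    assert np.isclose(A.sum(), B.sum())              # total intensity conserved

    C = np.zeros_like(B)
    au.resample(C, A)                                # integrate back down
    assert np.allclose(C, B)                         # down(up(B)) round-trips
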
47 changes: 41 additions & 6 deletions ptypy/accelerate/base/engines/DM_serial.py
@@ -178,8 +178,18 @@ def _setup_kernels(self):
aux = np.zeros(ash, dtype=np.complex64)
kern.aux = aux

# create extra array for resampling (if needed)
kern.resample = scan.resample > 1
if kern.resample:
ish = (ash[0],) + tuple(geo.shape//scan.resample)
kern.aux_tmp1 = np.zeros(ash, dtype=np.float32)
kern.aux_tmp2 = np.zeros(ish, dtype=np.float32)
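# aux_tmp1 is sized for the (upsampled) aux resolution (ash),
# aux_tmp2 for the measured detector resolution (ish)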
aux_f = kern.aux_tmp2 # making sure to pass the correct shapes to FUK
else:
aux_f = aux

# setup kernels, one for each SCAN.
kern.FUK = FourierUpdateKernel(aux, nmodes)
kern.FUK = FourierUpdateKernel(aux_f, nmodes)
kern.FUK.allocate()

kern.POK = PoUpdateKernel()
@@ -279,6 +289,12 @@ def engine_iterate(self, num=1):
pbound = self.pbound_scan[prep.label]
aux = kern.aux

# resampling
resample = kern.resample
if resample:
aux_tmp1 = kern.aux_tmp1
aux_tmp2 = kern.aux_tmp2

# local references
ma = prep.ma
ob = self.ob.S[oID].data
@@ -290,7 +306,12 @@
t1 = time.time()
AWK.build_aux_no_ex(aux, addr, ob, pr)
aux[:] = FW(aux)
FUK.log_likelihood(aux, addr, mag, ma, err_phot)
if resample:
aux_tmp1 = np.abs(aux)**2
au.resample(aux_tmp2, aux_tmp1)
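# |aux|**2 is formed at the aux resolution and integrated down to the
# detector grid; the buffer now holds intensities, hence aux_is_intensity=True below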
FUK.log_likelihood(aux_tmp2, addr, mag, ma, err_phot, aux_is_intensity=True)
else:
FUK.log_likelihood(aux, addr, mag, ma, err_phot)
self.benchmark.F_LLerror += time.time() - t1

## build auxiliary wave
@@ -305,9 +326,18 @@

## Deviation from measured data
t1 = time.time()
FUK.fourier_error(aux, addr, mag, ma, ma_sum)
FUK.error_reduce(addr, err_fourier)
FUK.fmag_all_update(aux, addr, mag, ma, err_fourier, pbound)
if resample:
aux_tmp1 = np.abs(aux)**2
au.resample(aux_tmp2, aux_tmp1)
FUK.fourier_error(aux_tmp2, addr, mag, ma, ma_sum, aux_is_intensity=True)
FUK.error_reduce(addr, err_fourier)
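# with mult=False, fmag_all_update overwrites aux_tmp2 with the magnitude-
# correction factor instead of applying it in place; the factor is then
# expanded back to the aux resolution and applied to the full-resolution aux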
FUK.fmag_all_update(aux_tmp2, addr, mag, ma, err_fourier, pbound, mult=False)
au.resample(aux_tmp1, aux_tmp2)
aux *= aux_tmp1
else:
FUK.fourier_error(aux, addr, mag, ma, ma_sum)
FUK.error_reduce(addr, err_fourier)
FUK.fmag_all_update(aux, addr, mag, ma, err_fourier, pbound)
self.benchmark.C_Fourier_update += time.time() - t1

## backward FFT
@@ -318,7 +348,12 @@
## build exit wave
t1 = time.time()
AWK.build_exit(aux, addr, ob, pr, ex, alpha=self.p.alpha)
FUK.exit_error(aux,addr)
if resample:
aux_tmp1 = np.abs(aux)**2
au.resample(aux_tmp2, aux_tmp1)
FUK.exit_error(aux_tmp2, addr, aux_is_intensity=True)
else:
FUK.exit_error(aux, addr)
FUK.error_reduce(addr, err_exit)
self.benchmark.E_Build_exit += time.time() - t1

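The resampled Fourier update above works at two resolutions: model intensities are integrated down to the detector grid for the error metrics, and the magnitude correction is computed there, expanded back, and applied to the full-resolution wave. A minimal single-mode NumPy sketch of the same flow (hypothetical helper; no mask, no pbound, purely illustrative):

    import numpy as np

    def fourier_constraint_resampled(aux, mag, r):
        # model intensity on the fine grid, integrated down to the detector grid
        I_fine = np.abs(aux)**2
        N, n, _ = I_fine.shape
        I_det = I_fine.reshape(N, n//r, r, n//r, r).sum(axis=(2, 4))
        # magnitude-correction factor on the detector grid
        fm = mag / (np.sqrt(I_det) + 1e-6)
        # expand the factor (nearest neighbour) and apply it to the fine-grid wave
        return aux * fm.repeat(r, axis=-2).repeat(r, axis=-1)
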
42 changes: 37 additions & 5 deletions ptypy/accelerate/base/engines/ML_serial.py
@@ -21,8 +21,9 @@
from ptypy.utils import parallel
from ptypy.engines.utils import Cnorm2, Cdot
from ptypy.engines import register
from ptypy.accelerate.base.kernels import GradientDescentKernel, AuxiliaryWaveKernel, PoUpdateKernel
from ptypy.accelerate.base.kernels import GradientDescentKernel, AuxiliaryWaveKernel, PoUpdateKernel, PositionCorrectionKernel
from ptypy.accelerate.base import address_manglers
from ptypy.accelerate.base import array_utils as au


__all__ = ['ML_serial']
@@ -90,6 +91,13 @@ def _setup_kernels(self):
kern.a = np.zeros(ash, dtype=np.complex64)
kern.b = np.zeros(ash, dtype=np.complex64)

# create extra array for resampling (if needed)
kern.resample = scan.resample > 1
if kern.resample:
ish = (ash[0],) + tuple(geo.shape//scan.resample)
kern.aux_tmp1 = np.zeros(ash, dtype=np.float32)
kern.aux_tmp2 = np.zeros(ish, dtype=np.float32)
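# as in DM_serial: aux_tmp1 at the (upsampled) aux resolution,
# aux_tmp2 at the measured detector resolution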

# setup kernels, one for each SCAN.
kern.GDK = GradientDescentKernel(aux, nmodes)
kern.GDK.allocate()
@@ -284,6 +292,13 @@ def prepare(self):
prep = self.engine.diff_info[d.ID]
prep.weights = (self.Irenorm * self.engine.ma.S[d.ID].data
/ (1. / self.Irenorm + d.data)).astype(d.data.dtype)
prep.intensity = self.engine.di.S[d.ID].data
kern = self.engine.kernels[prep.label]
if kern.resample:
prep.weights2 = np.zeros((prep.weights.shape[0],) + kern.aux_tmp1.shape[-2:], dtype=prep.weights.dtype)
au.resample(prep.weights2, prep.weights)
prep.intensity2 = np.zeros((prep.intensity.shape[0],) + kern.aux_tmp1.shape[-2:], dtype=prep.weights.dtype)
au.resample(prep.intensity2, prep.intensity)
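# weights and measured intensities are resampled onto the aux grid once
# here, so the ML kernels can compare model and data at matching resolution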

def __del__(self):
"""
@@ -325,7 +340,8 @@ def new_grad(self):

# get addresses and auxilliary array
addr = prep.addr
w = prep.weights
w = prep.weights2
I = prep.intensity2
err_phot = prep.err_phot
fic = prep.float_intens_coeff

@@ -334,15 +350,26 @@
obg = ob_grad.S[oID].data
pr = self.engine.pr.S[pID].data
prg = pr_grad.S[pID].data
I = self.engine.di.S[dID].data

# resampling
if kern.resample:
aux_tmp1 = kern.aux_tmp1
aux_tmp2 = kern.aux_tmp2

# make propagated exit (to buffer)
AWK.build_aux_no_ex(aux, addr, ob, pr, add=False)

# forward prop
aux[:] = FW(aux)

GDK.make_model(aux, addr)
if kern.resample:
aux_tmp1 = np.abs(aux)**2
au.resample(aux_tmp2, aux_tmp1)
au.resample(aux_tmp1, aux_tmp2)
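# binning |aux|**2 down to the detector grid and expanding it back leaves
# every block carrying its integrated model intensity, matching the
# resampled measured data prepared in prepare()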
GDK.make_model(aux_tmp1, addr, aux_is_intensity=True)
else:
GDK.make_model(aux, addr)

if self.p.floating_intensities:
GDK.floating_intensity(addr, w, I, fic)
GDK.main(aux, addr, w, I)
@@ -406,14 +433,19 @@ def poly_line_coeffs(self, c_ob_h, c_pr_h):
# get addresses and auxiliary array
addr = prep.addr
w = prep.weights
I = prep.intensity
fic = prep.float_intens_coeff

# local references
ob = self.ob.S[oID].data
ob_h = c_ob_h.S[oID].data
pr = self.pr.S[pID].data
pr_h = c_pr_h.S[pID].data
I = self.di.S[dID].data

# resampling
if kern.resample:
w = prep.weights2
I = prep.intensity2

# make propagated exit (to buffer)
AWK.build_aux_no_ex(f, addr, ob, pr, add=False)
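In the ML path, weights and measured intensities are expanded onto the aux grid once in prepare(), while the model intensity is binned down and re-expanded every iteration. That down-then-up step amounts to a block-integration operator; a minimal sketch (hypothetical name bin_expand, one 2-D frame, factor r, mirroring au.resample applied down and then up):

    import numpy as np

    def bin_expand(I, r):
        m = I.shape[-1] // r
        # integrate over r x r neighbourhoods (downsampling step)
        binned = I.reshape(m, r, m, r).sum(axis=(1, 3))
        # nearest-neighbour expansion with 1/r**2 scaling (upsampling step)
        return binned.repeat(r, axis=0).repeat(r, axis=1) / r**2
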
66 changes: 43 additions & 23 deletions ptypy/accelerate/base/kernels.py
@@ -49,9 +49,9 @@ def allocate(self):
self.npy.fdev = np.zeros(self.fshape, dtype=np.float32)
self.npy.ferr = np.zeros(self.fshape, dtype=np.float32)

def fourier_error(self, b_aux, addr, mag, mask, mask_sum):
def fourier_error(self, b_aux, addr, mag, mask, mask_sum, aux_is_intensity=False):
# reference shape (write-to shape)
sh = self.fshape
sh = b_aux.shape
# stopper
maxz = mag.shape[0]

@@ -65,7 +65,10 @@ def fourier_error(self, b_aux, addr, mag, mask, mask_sum):
# build model from complex fourier magnitudes, summing up
# all modes incoherently
tf = aux.reshape(maxz, self.nmodes, sh[1], sh[2])
af = np.sqrt((np.abs(tf) ** 2).sum(1))
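# when the caller has already reduced aux to intensities (resampled data
# path), the squared modulus is skipped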
if aux_is_intensity:
af = np.sqrt(tf.sum(1))
else:
af = np.sqrt((np.abs(tf) ** 2).sum(1))

# calculate difference to real data (g_mag)
fdev[:] = af - mag
@@ -74,9 +77,9 @@
ferr[:] = mask * np.abs(fdev) ** 2 / mask_sum.reshape((maxz, 1, 1))
return

def fourier_deviation(self, b_aux, addr, mag):
def fourier_deviation(self, b_aux, addr, mag, aux_is_intensity=False):
# reference shape (write-to shape)
sh = self.fshape
sh = b_aux.shape
# stopper
maxz = mag.shape[0]

@@ -89,16 +92,17 @@
# build model from complex fourier magnitudes, summing up
# all modes incoherently
tf = aux.reshape(maxz, self.nmodes, sh[1], sh[2])
af = np.sqrt((np.abs(tf) ** 2).sum(1))
if aux_is_intensity:
af = np.sqrt(tf.sum(1))
else:
af = np.sqrt((np.abs(tf) ** 2).sum(1))

# calculate difference to real data (g_mag)
fdev[:] = af - mag

return

def error_reduce(self, addr, err_sum):
# reference shape (write-to shape)
sh = self.fshape

# stopper
maxz = err_sum.shape[0]
@@ -113,9 +117,9 @@
err_sum[:] = ferr.sum(-1).sum(-1)
return

def fmag_all_update(self, b_aux, addr, mag, mask, err_sum, pbound=0.0):
def fmag_all_update(self, b_aux, addr, mag, mask, err_sum, pbound=0.0, mult=True):

sh = self.fshape
sh = b_aux.shape
nmodes = self.nmodes

# stopper
@@ -153,12 +157,15 @@

#fm[:] = mag / (af + 1e-6)
# upcasting
aux[:] = (aux.reshape(ish[0] // nmodes, nmodes, ish[1], ish[2]) * fm[:, np.newaxis, :, :]).reshape(ish)
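# mult=False fills aux with the correction factor fm itself (broadcast
# over modes), so that a caller can resample the factor before applying it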
if mult:
aux[:] = (aux.reshape(ish[0] // nmodes, nmodes, ish[1], ish[2]) * fm[:, np.newaxis, :, :]).reshape(ish)
else:
aux[:] = (np.ones((ish[0] // nmodes, nmodes, ish[1], ish[2])) * fm[:, np.newaxis, :, :]).reshape(ish)
return

def fmag_update_nopbound(self, b_aux, addr, mag, mask):
def fmag_update_nopbound(self, b_aux, addr, mag, mask, mult=True):

sh = self.fshape
sh = b_aux.shape
nmodes = self.nmodes

# stopper
@@ -180,12 +187,15 @@
fm[:] = (1 - mask) + mask * mag / (af + self.denom)

# upcasting
aux[:] = (aux.reshape(ish[0] // nmodes, nmodes, ish[1], ish[2]) * fm[:, np.newaxis, :, :]).reshape(ish)
if mult:
aux[:] = (aux.reshape(ish[0] // nmodes, nmodes, ish[1], ish[2]) * fm[:, np.newaxis, :, :]).reshape(ish)
else:
aux[:] = (np.ones((ish[0] // nmodes, nmodes, ish[1], ish[2])) * fm[:, np.newaxis, :, :]).reshape(ish)
return

def log_likelihood(self, b_aux, addr, mag, mask, err_phot):
def log_likelihood(self, b_aux, addr, mag, mask, err_phot, aux_is_intensity=False):
# reference shape (write-to shape)
sh = self.fshape
sh = b_aux.shape
# stopper
maxz = mag.shape[0]

@@ -195,7 +205,10 @@
# build model from complex fourier magnitudes, summing up
# all modes incoherently
tf = aux.reshape(maxz, self.nmodes, sh[1], sh[2])
LL = (np.abs(tf) ** 2).sum(1)
if aux_is_intensity:
LL = tf.sum(1)
else:
LL = (np.abs(tf) ** 2).sum(1)

# Intensity data
I = mag**2
@@ -204,15 +217,19 @@
err_phot[:] = ((mask * (LL - I)**2 / (I + 1.)).sum(-1).sum(-1) / np.prod(LL.shape[-2:]))
return

def exit_error(self, aux, addr):
def exit_error(self, aux, addr, aux_is_intensity=False):
sh = addr.shape
maxz = sh[0]

# batch buffers
ferr = self.npy.ferr[:maxz]
dex = aux[:maxz * self.nmodes]
fsh = dex.shape[-2:]
ferr[:] = (np.abs(dex.reshape((maxz,self.nmodes,fsh[0], fsh[1])))**2).sum(axis=1) / np.prod(fsh)
tf = dex.reshape((maxz, self.nmodes, fsh[0], fsh[1]))
if aux_is_intensity:
ferr[:] = tf.sum(axis=1) / np.prod(fsh)
else:
ferr[:] = (np.abs(tf)**2).sum(axis=1) / np.prod(fsh)


class GradientDescentKernel(BaseKernel):
@@ -256,18 +273,21 @@ def allocate(self):

self.npy.fic_tmp = np.ones((self.fshape[0],), dtype=self.ftype)

def make_model(self, b_aux, addr):
def make_model(self, b_aux, addr, aux_is_intensity=False):

# reference shape (= GPU global dims)
sh = self.fshape
sh = b_aux.shape

# batch buffers
Imodel = self.npy.Imodel
aux = b_aux

## Actual math ## (subset of FUK.fourier_error)
tf = aux.reshape(sh[0], self.nmodes, sh[1], sh[2])
Imodel[:] = (np.abs(tf) ** 2).sum(1)
if aux_is_intensity:
Imodel[:] = tf.sum(1)
else:
Imodel[:] = (np.abs(tf) ** 2).sum(1)

def make_a012(self, b_f, b_a, b_b, addr, I, fic):

@@ -377,8 +397,8 @@ def main(self, b_aux, addr, w, I):
## math ##
DI = Imodel - I
tmp = w * DI
err[:] = tmp * DI

err[:] = tmp * DI
aux[:] = (aux.reshape(ish[0] // nmodes, nmodes, ish[1], ish[2]) * tmp[:, np.newaxis, :, :]).reshape(ish)
return

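The aux_is_intensity flag threaded through these kernels follows a single rule: when the buffer already holds |.|**2 values (the resampled paths above), the internal squared modulus is skipped. A self-contained check of that equivalence, independent of ptypy:

    import numpy as np

    nmodes, n = 2, 8
    tf = np.random.rand(3, nmodes, n, n) + 1j * np.random.rand(3, nmodes, n, n)

    # default path: complex wavefield in, modulus squared taken internally
    af = np.sqrt((np.abs(tf)**2).sum(1))

    # aux_is_intensity=True path: the caller passes intensities directly
    intensity = np.abs(tf)**2
    af_from_intensity = np.sqrt(intensity.sum(1))

    assert np.allclose(af, af_from_intensity)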