Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP GPU implementation of FFT-based smoothing in ML #504

Merged
merged 16 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion ptypy/accelerate/base/array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ def norm2(input):
return np.sum(abs2(input))


def gaussian_kernel_2d(shape, sigmau, sigmav):
"""
2D Gaussian kernel using the last 2 dimension of given shape
Requires sigma for both dimensions (sigmau and sigmav)
"""
u, v = np.fft.fftfreq(shape[-2]), np.fft.fftfreq(shape[-1])
uu, vv = np.meshgrid(u, v, sparse=True, indexing='ij')
kernel = np.exp(-2* ( (np.pi*sigmau)**2 * uu**2 + (np.pi*sigmav)**2 * vv**2 ) )
return kernel

def complex_gaussian_filter(input, mfs):
'''
takes 2D and 3D arrays. Complex input, complex output. mfs has len 0<x<=2
Expand All @@ -70,6 +80,44 @@ def complex_gaussian_filter(input, mfs):
input.dtype)


def complex_gaussian_filter_fft(input, mfs):
'''
takes 2D and 3D arrays. Complex input, complex output. mfs has len 0<x<=2
'''
if len(mfs) > 2:
raise NotImplementedError("Only batches of 2D arrays allowed!")
elif len(mfs) == 1:
mfs = np.array([mfs,mfs])
else:
mfs = np.array(mfs)

k = gaussian_kernel_2d(input.shape, mfs[0], mfs[1]).astype(input.dtype)
return fft_filter(input, k)


def fft_filter(input, kernel, prefactor=None, postfactor=None, forward=True):
"""
Compute
output = ifft(fft( prefactor * input ) * kernel) * postfactor
"""
# Make a copy (and cast if necessary)
x = np.array(input)


if prefactor is not None:
x *= prefactor

if forward:
x = np.fft.ifftn(np.fft.fftn(x, norm="ortho") * kernel, norm="ortho")
else:
x = np.fft.fftn(np.fft.ifftn(x, norm="ortho") * kernel, norm="ortho")

if postfactor is not None:
x *= postfactor

return x


def mass_center(A):
'''
Input will always be real, and 2d or 3d, single precision here
Expand All @@ -81,7 +129,7 @@ def interpolated_shift(c, shift, do_linear=False):
'''
complex bicubic interpolated shift.
complex output. This shift should be applied to 2D arrays. shift should have len=c.ndims

'''
if not do_linear:
return ndi.shift(np.real(c), shift, order=3, prefilter=True) + 1j * ndi.shift(
Expand Down
25 changes: 20 additions & 5 deletions ptypy/accelerate/base/engines/ML_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,23 @@
from ptypy.engines import register
from ptypy.accelerate.base.kernels import GradientDescentKernel, AuxiliaryWaveKernel, PoUpdateKernel, PositionCorrectionKernel
from ptypy.accelerate.base import address_manglers
from ptypy.accelerate.base.array_utils import complex_gaussian_filter, complex_gaussian_filter_fft


__all__ = ['ML_serial']

@register()
class ML_serial(ML):

"""
Defaults:

[smooth_gradient_method]
default = convolution
type = str
help = Method to be used for smoothing the gradient, choose between ```convolution``` or ```fft```.
"""

def __init__(self, ptycho_parent, pars=None):
"""
Maximum likelihood reconstruction engine.
Expand Down Expand Up @@ -143,7 +153,12 @@ def engine_prepare(self):
self.ML_model.prepare()

def _get_smooth_gradient(self, data, sigma):
return self.smooth_gradient(data)
if self.p.smooth_gradient_method == "convolution":
return complex_gaussian_filter(data, sigma)
elif self.p.smooth_gradient_method == "fft":
return complex_gaussian_filter_fft(data, sigma)
else:
raise NotImplementedError("smooth_gradient_method should be ```convolution``` or ```fft```.")

def _replace_ob_grad(self):
new_ob_grad = self.ob_grad_new
Expand Down Expand Up @@ -272,7 +287,7 @@ def engine_iterate(self, num=1):
return error_dct # np.array([[self.ML_model.LL[0]] * 3])

def position_update(self):
"""
"""
Position refinement
"""
if not self.do_position_refinement:
Expand All @@ -283,7 +298,7 @@ def position_update(self):
# Update positions
if do_update_pos:
"""
Iterates through all positions and refines them by a given algorithm.
Iterates through all positions and refines them by a given algorithm.
"""
log(4, "----------- START POS REF -------------")
for dID in self.di.S.keys():
Expand All @@ -308,7 +323,7 @@ def position_update(self):
max_oby = ob.shape[-2] - aux.shape[-2] - 1
max_obx = ob.shape[-1] - aux.shape[-1] - 1

# We need to re-calculate the current error
# We need to re-calculate the current error
PCK.build_aux(aux, addr, ob, pr)
aux[:] = FW(aux)
PCK.log_likelihood_ml(aux, addr, I, w, err_phot)
Expand Down Expand Up @@ -338,7 +353,7 @@ def engine_finalize(self):
for i,view in enumerate(d.views):
for j,(pname, pod) in enumerate(view.pods.items()):
delta = (prep.addr[i][j][1][1:] - prep.original_addr[i][j][1][1:]) * res
pod.ob_view.coord += delta
pod.ob_view.coord += delta
pod.ob_view.storage.update_views(pod.ob_view)
self.ptycho.record_positions = True

Expand Down
3 changes: 2 additions & 1 deletion ptypy/accelerate/cuda_common/batched_multiply.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ extern "C" __global__ void batched_multiply(const complex<IN_TYPE>* input,
int gy = threadIdx.y + blockIdx.y * blockDim.y;
int gz = threadIdx.z + blockIdx.z * blockDim.z;

if (gx > rows || gy > columns || gz > nBatches)
if (gx > rows - 1 || gy > columns - 1 || gz > nBatches)
return;

auto val = input[gz * rows * columns + gy * rows + gx];

if (MPY_DO_FILT) // set at compile-time
{
val *= filter[gy * rows + gx];
Expand Down
88 changes: 88 additions & 0 deletions ptypy/accelerate/cuda_cupy/array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np

from ptypy.accelerate.cuda_common.utils import map2ctype
from ptypy.accelerate.base.array_utils import gaussian_kernel_2d
from ptypy.utils.math_utils import gaussian
from . import load_kernel

Expand Down Expand Up @@ -62,6 +63,37 @@ def dot(self, A: cp.ndarray, B: cp.ndarray, out: cp.ndarray = None) -> cp.ndarra
def norm2(self, A, out=None):
return self.dot(A, A, out)

class BatchedMultiplyKernel:
def __init__(self, array, queue=None, math_type=np.complex64):
self.queue = queue
self.array_shape = array.shape[-2:]
self.batches = int(np.prod(array.shape[0:array.ndim-2]) if array.ndim > 2 else 1)
self.batched_multiply_cuda = load_kernel("batched_multiply", {
'MPY_DO_SCALE': 'true',
'MPY_DO_FILT': 'true',
'IN_TYPE': 'float' if array.dtype==np.complex64 else 'double',
'OUT_TYPE': 'float' if array.dtype==np.complex64 else 'double',
'MATH_TYPE': 'float' if math_type==np.complex64 else 'double'
})
self.block = (32,32,1)
self.grid = (
int((self.array_shape[0] + 31) // 32),
int((self.array_shape[1] + 31) // 32),
int(self.batches)
)

def multiply(self, x,y, scale=1.):
assert x.dtype == y.dtype, "Input arrays must be of same data type"
assert x.shape[-2:] == y.shape[-2:], "Input arrays must be of the same size in last 2 dims"
if self.queue is not None:
self.queue.use()
self.batched_multiply_cuda(self.grid,
self.block,
args=(x,x,y,
np.float32(scale),
np.int32(self.batches),
np.int32(self.array_shape[0]),
np.int32(self.array_shape[1])))

class TransposeKernel:

Expand Down Expand Up @@ -322,6 +354,62 @@ def delxb(self, input, out, axis=-1):
out, lower_dim, higher_dim, np.int32(input.shape[axis])))


class FFTGaussianSmoothingKernel:
def __init__(self, queue=None, kernel_type='float'):
if kernel_type not in ['float', 'double']:
raise ValueError('Invalid data type for kernel')
self.kernel_type = kernel_type
self.dtype = np.complex64
self.stype = "complex<float>"
self.queue = queue
self.sigma = None

from .kernels import FFTFilterKernel

# Create general FFT filter object
self.fft_filter = FFTFilterKernel(queue_thread=queue)

def allocate(self, shape, sigma=1.):

# Create kernel
self.sigma = sigma
kernel = self._compute_kernel(shape, sigma)

# Allocate filter
kernel_dev = cp.asarray(kernel)
self.fft_filter.allocate(kernel=kernel_dev)

def _compute_kernel(self, shape, sigma):
# Create kernel
sigma = np.array(sigma)
if sigma.size == 1:
sigma = np.array([sigma, sigma])
if sigma.size > 2:
raise NotImplementedError("Only batches of 2D arrays allowed!")
self.sigma = sigma

kernel = gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype)

# Extend kernel if needed
if len(shape) == 3:
kernel = np.tile(kernel, (shape[0],1,1))

return kernel

def filter(self, data, sigma=None):
"""
Apply filter in-place

If sigma is not None: reallocate a new fft filter first.
"""
if self.sigma is None:
self.allocate(shape=data.shape, sigma=sigma)
else:
self.fft_filter.set_kernel(self._compute_kernel(data.shape, sigma))

self.fft_filter.apply_filter(data)


class GaussianSmoothingKernel:
def __init__(self, queue=None, num_stdevs=4, kernel_type='float'):
if kernel_type not in ['float', 'double']:
Expand Down
21 changes: 15 additions & 6 deletions ptypy/accelerate/cuda_cupy/engines/ML_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .. import get_context, log_device_memory_stats
from ..kernels import PropagationKernel, RealSupportKernel, FourierSupportKernel
from ..kernels import GradientDescentKernel, AuxiliaryWaveKernel, PoUpdateKernel, PositionCorrectionKernel
from ..array_utils import ArrayUtilsKernel, DerivativesKernel, GaussianSmoothingKernel, TransposeKernel
from ..array_utils import ArrayUtilsKernel, DerivativesKernel, GaussianSmoothingKernel, FFTGaussianSmoothingKernel, TransposeKernel
from ..mem_utils import GpuDataManager

#from ..mem_utils import GpuDataManager
Expand Down Expand Up @@ -79,8 +79,12 @@ def engine_initialize(self):
self.qu_htod = cp.cuda.Stream()
self.qu_dtoh = cp.cuda.Stream()

self.GSK = GaussianSmoothingKernel(queue=self.queue)
self.GSK.tmp = None
if self.p.smooth_gradient_method == "convolution":
self.GSK = GaussianSmoothingKernel(queue=self.queue)
self.GSK.tmp = None

if self.p.smooth_gradient_method == "fft":
self.FGSK = FFTGaussianSmoothingKernel(queue=self.queue)

# Real/Fourier Support Kernel
self.RSK = {}
Expand Down Expand Up @@ -260,9 +264,14 @@ def _set_pr_ob_ref_for_data(self, dev='gpu', container=None, sync_copy=False):
dev=dev, container=container, sync_copy=sync_copy)

def _get_smooth_gradient(self, data, sigma):
if self.GSK.tmp is None:
self.GSK.tmp = cp.empty(data.shape, dtype=np.complex64)
self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp)
if self.p.smooth_gradient_method == "convolution":
if self.GSK.tmp is None:
self.GSK.tmp = cp.empty(data.shape, dtype=np.complex64)
self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp)
elif self.p.smooth_gradient_method == "fft":
self.FGSK.filter(data, sigma)
else:
raise NotImplementedError("smooth_gradient_method should be ```convolution``` or ```fft```.")
return data

def _replace_ob_grad(self):
Expand Down
28 changes: 28 additions & 0 deletions ptypy/accelerate/cuda_cupy/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,34 @@ def apply_real_support(self, x):
x *= self.support


class FFTFilterKernel:
def __init__(self, queue_thread=None, fft='cuda'):
# Current implementation recompiles every time there is a change in input shape.
self.queue = queue_thread
self._fft_type = fft
self._fft1 = None
self._fft2 = None
def allocate(self, kernel, prefactor=None, postfactor=None, forward=True):
FFT = choose_fft(kernel.shape[-2:], fft_type=self._fft_type)

self._fft1 = FFT(kernel, self.queue,
pre_fft=prefactor,
post_fft=kernel,
symmetric=True,
forward=forward)
self._fft2 = FFT(kernel, self.queue,
post_fft=postfactor,
symmetric=True,
forward=not forward)

def set_kernel(self, kernel):
self._fft1.post_fft.set(kernel)

def apply_filter(self, x):
self._fft1.ft(x, x)
self._fft2.ift(x, x)


class FourierUpdateKernel(ab.FourierUpdateKernel):

def __init__(self, aux, nmodes=1, queue_thread=None, accumulate_type='float', math_type='float'):
Expand Down
Loading