From 7eb128d31a89562005291cf1a3909d03e462a27b Mon Sep 17 00:00:00 2001 From: Pierre Thibault Date: Thu, 24 Aug 2023 12:04:56 +0200 Subject: [PATCH 01/16] WIP kernel implementation of general FFT filter --- ptypy/accelerate/base/array_utils.py | 25 ++++++++++++++++++++++++- ptypy/accelerate/cuda_pycuda/kernels.py | 25 +++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/ptypy/accelerate/base/array_utils.py b/ptypy/accelerate/base/array_utils.py index b94a056f3..216b54c63 100644 --- a/ptypy/accelerate/base/array_utils.py +++ b/ptypy/accelerate/base/array_utils.py @@ -70,6 +70,29 @@ def complex_gaussian_filter(input, mfs): input.dtype) +def fft_filter(input, kernel, prefactor=None, postfactor=None, forward=True): + """ + Compute + output = ifft(fft( prefactor * input ) * kernel) * postfactor + """ + # Make a copy (and cast if necessary) + x = np.array(input) + + + if prefactor is not None: + x *= prefactor + + if forward: + x = np.fft.ifftn(np.fft.fftn(x) * kernel) + else: + x = np.fft.fftn(np.fft.ifftn(x) * kernel) + + if postfactor is not None: + x *= postfactor + + return x + + def mass_center(A): ''' Input will always be real, and 2d or 3d, single precision here @@ -81,7 +104,7 @@ def interpolated_shift(c, shift, do_linear=False): ''' complex bicubic interpolated shift. complex output. This shift should be applied to 2D arrays. shift should have len=c.ndims - + ''' if not do_linear: return ndi.shift(np.real(c), shift, order=3, prefilter=True) + 1j * ndi.shift( diff --git a/ptypy/accelerate/cuda_pycuda/kernels.py b/ptypy/accelerate/cuda_pycuda/kernels.py index 8f7378715..166836a9b 100644 --- a/ptypy/accelerate/cuda_pycuda/kernels.py +++ b/ptypy/accelerate/cuda_pycuda/kernels.py @@ -160,6 +160,31 @@ def allocate(self): def apply_real_support(self, x): x *= self.support +class FFTFilterKernel: + def __init__(self, queue_thread=None, fft='reikna'): + # Current implementation recompiles every time there is a change in input shape. 
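# For the default forward=True, the two FFT plans built in allocate() implement the
# same mapping as base.array_utils.fft_filter above:
#     x  ->  postfactor * ifft( fft( prefactor * x ) * kernel )
# _fft1 carries the prefactor, the forward transform and the kernel (as its post_fft);
# _fft2 carries the inverse transform and the postfactor. apply_filter() runs both in-place.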
+ self.queue = queue_thread + self._fft_type = fft + self.shape = None + self._fft1 = None + self._fft2 = None + def allocate(self, kernel, prefactor=None, postfactor=None, forward=True): + FFT = choose_fft(self._fft_type, kernel.shape[-2:]) + + self._fft1 = FFT(kernel, self.queue, + pre_fft=prefactor, + post_fft=kernel, + symmetric=True, + forward=forward) + self._fft2 = FFT(kernel, self.queue, + post_fft=postfactor, + symmetric=True, + forward=not forward) + def apply_filter(self, x): + self._fft1.ft(x,x) + self._fft2.ift(x,x) + + class FourierUpdateKernel(ab.FourierUpdateKernel): def __init__(self, aux, nmodes=1, queue_thread=None, accumulate_type='float', math_type='float'): From 9c72469764c04ae77088cb4b1a7f419b514fb685 Mon Sep 17 00:00:00 2001 From: Pierre Thibault Date: Thu, 24 Aug 2023 12:05:21 +0200 Subject: [PATCH 02/16] Tests for FFT filter --- .../base_tests/array_utils_test.py | 50 +++++++++++++++++++ .../cuda_pycuda_tests/array_utils_test.py | 26 ++++++++++ 2 files changed, 76 insertions(+) diff --git a/test/accelerate_tests/base_tests/array_utils_test.py b/test/accelerate_tests/base_tests/array_utils_test.py index 06646b813..fed73e17d 100644 --- a/test/accelerate_tests/base_tests/array_utils_test.py +++ b/test/accelerate_tests/base_tests/array_utils_test.py @@ -291,6 +291,56 @@ def test_crop_pad_3(self): dtype=np.complex64) np.testing.assert_array_almost_equal(A, exp_A) + def test_fft_filter(self): + data = np.zeros((256, 512), dtype=COMPLEX_TYPE) + data[64:-64,128:-128] = 1 + 1.j + + prefactor = np.zeros_like(data) + prefactor[:,256:] = 1. + postfactor = np.zeros_like(data) + postfactor[128:,:] = 1. + + rk = np.zeros_like(data) + rk[:30, :30] = 1. + kernel = np.fft.fftn(rk) + + output = au.fft_filter(data, kernel, prefactor, postfactor) + + known_test_output = np.array([-0.00000000e+00+0.00000000e+00j, 0.00000000e+00-0.00000000e+00j, + -0.00000000e+00 + 0.00000000e+00j, 0.00000000e+00 - 0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, -0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + -0.00000000e+00+0.00000000e+00j, -0.00000000e+00+0.00000000e+00j, + -0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 0.00000000e+00-0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00-0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 0.00000000e+00-0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 8.66422277e-14+4.86768828e-14j, + 7.23113320e-14+2.82331542e-14j, 9.00000000e+02+9.00000000e+02j, + 9.00000000e+02+9.00000000e+02j, 5.10000000e+02+5.10000000e+02j, + 1.41172830e-14+3.62223425e-14j, 2.61684238e-14-4.13866575e-14j, + 2.16691314e-14-1.95102733e-14j, -1.36536942e-13-9.94589021e-14j, + -1.42905371e-13-5.77964697e-14j, -5.00005072e-14+4.08620637e-14j, + 6.38160272e-14+7.61753583e-14j, 3.90000000e+02+3.90000000e+02j, + 9.00000000e+02+9.00000000e+02j, 9.00000000e+02+9.00000000e+02j, + 3.00000000e+01+3.00000000e+01j, 8.63255773e-14+7.08532924e-14j, + 1.80941313e-14-3.85517154e-14j, 7.84277340e-14-1.32008745e-14j, + 
-6.57025196e-14-1.72739350e-14j, -6.69570857e-15+6.49622898e-14j, + 6.27436466e-15+7.57162569e-14j, 2.01150157e-15+3.65538558e-14j, + 8.70000000e+01+8.70000000e+01j, -1.13686838e-13-1.70530257e-13j, + 0.00000000e+00-2.27373675e-13j, -1.84492121e-14-9.21502853e-14j, + 2.12418687e-14-8.62209232e-14j, 1.20880692e-13+3.86522371e-14j, + 1.03754734e-13+9.19851759e-14j, 5.50926123e-14+1.17150422e-13j, + -5.47869215e-14+5.87176511e-14j, -3.52652980e-14+8.44455504e-15j]) + + np.testing.assert_array_almost_equal(output.flat[::2000], known_test_output) if __name__ == '__main__': unittest.main() diff --git a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py index 912b68bfe..8e7cfbcf5 100644 --- a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py @@ -538,3 +538,29 @@ def test_interpolate_shift_no_shift_UNITY(self): np.testing.assert_allclose(out, isk, rtol=1e-6, atol=1e-6, err_msg="The shifting of array has not been calculated as expected") + def test_fft_filter_UNITY(self): + data = np.zeros((256, 512), dtype=np.complex64) + data[64:-64,128:-128] = 1 + 1.j + + prefactor = np.zeros_like(data) + prefactor[:,256:] = 1. + postfactor = np.zeros_like(data) + postfactor[128:,:] = 1. + + rk = np.zeros_like(data) + rk[:30, :30] = 1. + kernel = np.fft.fftn(rk) + + data_dev = gpuarray.to_gpu(data) + kernel_dev = gpuarray.to_gpu(kernel) + prefactor_dev = gpuarray.to_gpu(prefactor) + postfactor_dev = gpuarray.to_gpu(postfactor) + + FF = gau.FFTFilterKernel() + FF.allocate(kernel=kernel, prefactor=prefactor, postfactor=prefactor) + FF.apply_filter(data_dev) + + output = au.fft_filter(data, kernel, prefactor, postfactor) + + np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5) + From 12e293cf2d2830c4998e65e74d7ff8e1f7321a4a Mon Sep 17 00:00:00 2001 From: Pierre Thibault Date: Fri, 25 Aug 2023 15:48:16 +0200 Subject: [PATCH 03/16] FFT-based gaussian smoothing in ML --- ptypy/accelerate/base/array_utils.py | 4 +- ptypy/accelerate/cuda_pycuda/array_utils.py | 47 +++++++++++++++++++ .../cuda_pycuda/engines/ML_pycuda.py | 11 ++++- ptypy/accelerate/cuda_pycuda/fft.py | 4 +- ptypy/accelerate/cuda_pycuda/kernels.py | 9 ++-- .../cuda_pycuda_tests/array_utils_test.py | 27 ++++++----- 6 files changed, 81 insertions(+), 21 deletions(-) diff --git a/ptypy/accelerate/base/array_utils.py b/ptypy/accelerate/base/array_utils.py index 216b54c63..2f9954faf 100644 --- a/ptypy/accelerate/base/array_utils.py +++ b/ptypy/accelerate/base/array_utils.py @@ -83,9 +83,9 @@ def fft_filter(input, kernel, prefactor=None, postfactor=None, forward=True): x *= prefactor if forward: - x = np.fft.ifftn(np.fft.fftn(x) * kernel) + x = np.fft.ifftn(np.fft.fftn(x, norm="ortho") * kernel, norm="ortho") else: - x = np.fft.fftn(np.fft.ifftn(x) * kernel) + x = np.fft.fftn(np.fft.ifftn(x, norm="ortho") * kernel, norm="ortho") if postfactor is not None: x *= postfactor diff --git a/ptypy/accelerate/cuda_pycuda/array_utils.py b/ptypy/accelerate/cuda_pycuda/array_utils.py index 72eae996f..0334edafe 100644 --- a/ptypy/accelerate/cuda_pycuda/array_utils.py +++ b/ptypy/accelerate/cuda_pycuda/array_utils.py @@ -324,6 +324,53 @@ def delxb(self, input, out, axis=-1): ) +class FFTGaussianSmoothingKernel: + def __init__(self, queue=None, kernel_type='float'): + if kernel_type not in ['float', 'double']: + raise ValueError('Invalid data type for kernel') + self.kernel_type = kernel_type + self.dtype = np.complex64 + 
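# The Fourier-space kernel computed in _compute_kernel below,
#     exp(-2 * (np.pi * sigma)**2 * (u**2 + v**2))  sampled on np.fft.fftfreq grids,
# is the Fourier transform of a real-space Gaussian of standard deviation sigma,
# so multiplying the spectrum by it (via FFTFilterKernel) performs a Gaussian blur.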
self.stype = "complex" + self.queue = queue + self.sigma = None + + from .kernels import FFTFilterKernel + + # Create general FFT filter object + self.fft_filter = FFTFilterKernel(queue_thread=queue) + + def allocate(self, shape, sigma=1.): + + # Create kernel + self.sigma = sigma + kernel = self._compute_kernel(shape, sigma) + + # Allocate filter + kernel_dev = gpuarray.to_gpu(kernel) + self.fft_filter.allocate(kernel=kernel_dev) + + def _compute_kernel(self, shape, sigma): + # Create kernel + self.sigma = sigma + u, v = np.fft.fftfreq(shape[-2]), np.fft.fftfreq(shape[-1]) + uu, vv = np.meshgrid(u, v, sparse=True, indexing='ij') + kernel = np.exp(-2*(np.pi*sigma)**2 * (uu**2 + vv**2)) + return kernel.astype(self.dtype) + + def filter(self, data, sigma=None): + """ + Apply filter in-place + + If sigma is not None: reallocate a new fft filter first. + """ + if self.sigma is None: + self.allocate(shape=data.shape, sigma=sigma) + elif sigma is not None: + self.fft_filter.set_kernel(self._compute_kernel(data.shape, sigma)) + + self.fft_filter.apply_filter(data) + + class GaussianSmoothingKernel: def __init__(self, queue=None, num_stdevs=4, kernel_type='float'): if kernel_type not in ['float', 'double']: diff --git a/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py b/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py index 339102452..afa3f26db 100644 --- a/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py +++ b/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py @@ -24,7 +24,7 @@ from .. import get_context, get_dev_pool from ..kernels import PropagationKernel, RealSupportKernel, FourierSupportKernel from ..kernels import GradientDescentKernel, AuxiliaryWaveKernel, PoUpdateKernel, PositionCorrectionKernel -from ..array_utils import ArrayUtilsKernel, DerivativesKernel, GaussianSmoothingKernel, TransposeKernel +from ..array_utils import ArrayUtilsKernel, DerivativesKernel, GaussianSmoothingKernel, FFTGaussianSmoothingKernel, TransposeKernel from ..mem_utils import GpuDataManager from ptypy.accelerate.base import address_manglers @@ -93,6 +93,8 @@ def engine_initialize(self): self.GSK = GaussianSmoothingKernel(queue=self.queue) self.GSK.tmp = None + self.FGSK = FFTGaussianSmoothingKernel(queue=self.queue) + # Real/Fourier Support Kernel self.RSK = {} self.FSK = {} @@ -260,13 +262,18 @@ def _get_smooth_gradient(self, data, sigma): self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) return data + def _get_smooth_gradient_fft(self, data, sigma): + self.FGSK.filter(data, sigma) + return data + def _replace_ob_grad(self): new_ob_grad = self.ob_grad_new # Smoothing preconditioner if self.smooth_gradient: self.smooth_gradient.sigma *= (1. 
- self.p.smooth_gradient_decay) for name, s in new_ob_grad.storages.items(): - s.gpu = self._get_smooth_gradient(s.gpu, self.smooth_gradient.sigma) + #s.gpu = self._get_smooth_gradient(s.gpu, self.smooth_gradient.sigma) + s.gpu = self._get_smooth_gradient_fft(s.gpu, self.smooth_gradient.sigma) return self._replace_grad(self.ob_grad, new_ob_grad) diff --git a/ptypy/accelerate/cuda_pycuda/fft.py b/ptypy/accelerate/cuda_pycuda/fft.py index 916ed7e54..318e46cec 100644 --- a/ptypy/accelerate/cuda_pycuda/fft.py +++ b/ptypy/accelerate/cuda_pycuda/fft.py @@ -11,7 +11,9 @@ def __init__(self, array, queue=None, post_fft=None, symmetric=True, forward=True): - + """ + array should be gpuarray already + """ self._queue = queue from pycuda import gpuarray ## reikna diff --git a/ptypy/accelerate/cuda_pycuda/kernels.py b/ptypy/accelerate/cuda_pycuda/kernels.py index 166836a9b..8fcebb6ef 100644 --- a/ptypy/accelerate/cuda_pycuda/kernels.py +++ b/ptypy/accelerate/cuda_pycuda/kernels.py @@ -165,7 +165,6 @@ def __init__(self, queue_thread=None, fft='reikna'): # Current implementation recompiles every time there is a change in input shape. self.queue = queue_thread self._fft_type = fft - self.shape = None self._fft1 = None self._fft2 = None def allocate(self, kernel, prefactor=None, postfactor=None, forward=True): @@ -180,9 +179,13 @@ def allocate(self, kernel, prefactor=None, postfactor=None, forward=True): post_fft=postfactor, symmetric=True, forward=not forward) + + def set_kernel(self, kernel): + self._fft1.post_fft.set(kernel) + def apply_filter(self, x): - self._fft1.ft(x,x) - self._fft2.ift(x,x) + self._fft1.ft(x, x) + self._fft2.ift(x, x) class FourierUpdateKernel(ab.FourierUpdateKernel): diff --git a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py index 8e7cfbcf5..5d8dcf974 100644 --- a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py @@ -11,6 +11,7 @@ if have_pycuda(): from pycuda import gpuarray import ptypy.accelerate.cuda_pycuda.array_utils as gau + from ptypy.accelerate.cuda_pycuda.kernels import FFTFilterKernel class ArrayUtilsTest(PyCudaTest): @@ -539,28 +540,28 @@ def test_interpolate_shift_no_shift_UNITY(self): err_msg="The shifting of array has not been calculated as expected") def test_fft_filter_UNITY(self): - data = np.zeros((256, 512), dtype=np.complex64) - data[64:-64,128:-128] = 1 + 1.j + sh = (16, 35) + data = np.zeros(sh, dtype=np.complex64) + data.flat[:] = np.arange(np.prod(sh)) + kernel = np.zeros_like(data) + kernel[0, 0] = 1. + kernel[0, 1] = 0.5 prefactor = np.zeros_like(data) - prefactor[:,256:] = 1. + prefactor[:,2:] = 1. postfactor = np.zeros_like(data) - postfactor[128:,:] = 1. - - rk = np.zeros_like(data) - rk[:30, :30] = 1. - kernel = np.fft.fftn(rk) + postfactor[2:,:] = 1. 
data_dev = gpuarray.to_gpu(data) kernel_dev = gpuarray.to_gpu(kernel) - prefactor_dev = gpuarray.to_gpu(prefactor) - postfactor_dev = gpuarray.to_gpu(postfactor) + pre_dev = gpuarray.to_gpu(prefactor) + post_dev = gpuarray.to_gpu(postfactor) - FF = gau.FFTFilterKernel() - FF.allocate(kernel=kernel, prefactor=prefactor, postfactor=prefactor) + FF = FFTFilterKernel(queue_thread=self.stream) + FF.allocate(kernel=kernel_dev, prefactor=pre_dev, postfactor=post_dev) FF.apply_filter(data_dev) output = au.fft_filter(data, kernel, prefactor, postfactor) - np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5) + np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) From 7466026306c85a3f07fc1ed26f4a0aff0b305570 Mon Sep 17 00:00:00 2001 From: Timothy Poon Date: Tue, 5 Dec 2023 18:40:33 +0000 Subject: [PATCH 04/16] Add dummy _get_smooth_gradient_fft in ML_serial --- ptypy/accelerate/base/engines/ML_serial.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ptypy/accelerate/base/engines/ML_serial.py b/ptypy/accelerate/base/engines/ML_serial.py index 38f63f385..f03882636 100644 --- a/ptypy/accelerate/base/engines/ML_serial.py +++ b/ptypy/accelerate/base/engines/ML_serial.py @@ -145,6 +145,9 @@ def engine_prepare(self): def _get_smooth_gradient(self, data, sigma): return self.smooth_gradient(data) + def _get_smooth_gradient_fft(self, data, sigma): + return self.smooth_gradient(data) + def _replace_ob_grad(self): new_ob_grad = self.ob_grad_new # Smoothing preconditioner @@ -231,7 +234,8 @@ def engine_iterate(self, num=1): # Smoothing preconditioner if self.smooth_gradient: for name, s in self.ob_h.storages.items(): - s.data[:] -= self._get_smooth_gradient(self.ob_grad.storages[name].data, self.smooth_gradient.sigma) + # s.data[:] -= self._get_smooth_gradient(self.ob_grad.storages[name].data, self.smooth_gradient.sigma) + s.data[:] -= self._get_smooth_gradient_fft(self.ob_grad.storages[name].data, self.smooth_gradient.sigma) else: self.ob_h -= self.ob_grad @@ -272,7 +276,7 @@ def engine_iterate(self, num=1): return error_dct # np.array([[self.ML_model.LL[0]] * 3]) def position_update(self): - """ + """ Position refinement """ if not self.do_position_refinement: @@ -283,7 +287,7 @@ def position_update(self): # Update positions if do_update_pos: """ - Iterates through all positions and refines them by a given algorithm. + Iterates through all positions and refines them by a given algorithm. 
""" log(4, "----------- START POS REF -------------") for dID in self.di.S.keys(): @@ -308,7 +312,7 @@ def position_update(self): max_oby = ob.shape[-2] - aux.shape[-2] - 1 max_obx = ob.shape[-1] - aux.shape[-1] - 1 - # We need to re-calculate the current error + # We need to re-calculate the current error PCK.build_aux(aux, addr, ob, pr) aux[:] = FW(aux) PCK.log_likelihood_ml(aux, addr, I, w, err_phot) @@ -338,7 +342,7 @@ def engine_finalize(self): for i,view in enumerate(d.views): for j,(pname, pod) in enumerate(view.pods.items()): delta = (prep.addr[i][j][1][1:] - prep.original_addr[i][j][1][1:]) * res - pod.ob_view.coord += delta + pod.ob_view.coord += delta pod.ob_view.storage.update_views(pod.ob_view) self.ptycho.record_positions = True From d33032c8dab84b982efe189c76666440f07df1c8 Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Thu, 29 Feb 2024 13:08:25 +0000 Subject: [PATCH 05/16] added batched fft_filter tests --- .../base_tests/array_utils_test.py | 52 +++++++++++++++++++ .../cuda_pycuda_tests/array_utils_test.py | 25 +++++++++ 2 files changed, 77 insertions(+) diff --git a/test/accelerate_tests/base_tests/array_utils_test.py b/test/accelerate_tests/base_tests/array_utils_test.py index fed73e17d..0b8ca9b10 100644 --- a/test/accelerate_tests/base_tests/array_utils_test.py +++ b/test/accelerate_tests/base_tests/array_utils_test.py @@ -342,5 +342,57 @@ def test_fft_filter(self): np.testing.assert_array_almost_equal(output.flat[::2000], known_test_output) + def test_fft_filter_batched(self): + data = np.zeros((2,256, 512), dtype=COMPLEX_TYPE) + data[:,64:-64,128:-128] = 1 + 1.j + + prefactor = np.zeros_like(data) + prefactor[:,:,256:] = 1. + postfactor = np.zeros_like(data) + postfactor[:,128:,:] = 1. + + rk = np.zeros_like(data)[0] + rk[:30, :30] = 1. 
+ kernel = np.fft.fftn(rk) + + output = au.fft_filter(data, kernel, prefactor, postfactor) + + known_test_output = np.array([-0.00000000e+00+0.00000000e+00j, 0.00000000e+00-0.00000000e+00j, + -0.00000000e+00 + 0.00000000e+00j, 0.00000000e+00 - 0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, -0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + -0.00000000e+00+0.00000000e+00j, -0.00000000e+00+0.00000000e+00j, + -0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 0.00000000e+00-0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00+0.00000000e+00j, + 0.00000000e+00+0.00000000e+00j, 0.00000000e+00-0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 0.00000000e+00-0.00000000e+00j, + 0.00000000e+00-0.00000000e+00j, 8.66422277e-14+4.86768828e-14j, + 7.23113320e-14+2.82331542e-14j, 9.00000000e+02+9.00000000e+02j, + 9.00000000e+02+9.00000000e+02j, 5.10000000e+02+5.10000000e+02j, + 1.41172830e-14+3.62223425e-14j, 2.61684238e-14-4.13866575e-14j, + 2.16691314e-14-1.95102733e-14j, -1.36536942e-13-9.94589021e-14j, + -1.42905371e-13-5.77964697e-14j, -5.00005072e-14+4.08620637e-14j, + 6.38160272e-14+7.61753583e-14j, 3.90000000e+02+3.90000000e+02j, + 9.00000000e+02+9.00000000e+02j, 9.00000000e+02+9.00000000e+02j, + 3.00000000e+01+3.00000000e+01j, 8.63255773e-14+7.08532924e-14j, + 1.80941313e-14-3.85517154e-14j, 7.84277340e-14-1.32008745e-14j, + -6.57025196e-14-1.72739350e-14j, -6.69570857e-15+6.49622898e-14j, + 6.27436466e-15+7.57162569e-14j, 2.01150157e-15+3.65538558e-14j, + 8.70000000e+01+8.70000000e+01j, -1.13686838e-13-1.70530257e-13j, + 0.00000000e+00-2.27373675e-13j, -1.84492121e-14-9.21502853e-14j, + 2.12418687e-14-8.62209232e-14j, 1.20880692e-13+3.86522371e-14j, + 1.03754734e-13+9.19851759e-14j, 5.50926123e-14+1.17150422e-13j, + -5.47869215e-14+5.87176511e-14j, -3.52652980e-14+8.44455504e-15j]) + + np.testing.assert_array_almost_equal(output[1].flat[::2000], known_test_output) + + if __name__ == '__main__': unittest.main() diff --git a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py index 5d8dcf974..cab1147d4 100644 --- a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py @@ -565,3 +565,28 @@ def test_fft_filter_UNITY(self): np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) + def test_fft_filter_batched_UNITY(self): + sh = (2,16, 35) + data = np.zeros(sh, dtype=np.complex64) + data.flat[:] = np.arange(np.prod(sh)) + kernel = np.zeros_like(data) + kernel[:,0, 0] = 1. + kernel[:,0, 1] = 0.5 + + prefactor = np.zeros_like(data) + prefactor[:,:,2:] = 1. + postfactor = np.zeros_like(data) + postfactor[:,2:,:] = 1. 
+
+        data_dev = gpuarray.to_gpu(data)
+        kernel_dev = gpuarray.to_gpu(kernel)
+        pre_dev = gpuarray.to_gpu(prefactor)
+        post_dev = gpuarray.to_gpu(postfactor)
+
+        FF = FFTFilterKernel(queue_thread=self.stream)
+        FF.allocate(kernel=kernel_dev, prefactor=pre_dev, postfactor=post_dev)
+        FF.apply_filter(data_dev)
+
+        output = au.fft_filter(data, kernel, prefactor, postfactor)
+
+        np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6)
\ No newline at end of file

From 190e5b23547df0f2cb865e9ea9478db53e10cc7d Mon Sep 17 00:00:00 2001
From: Benedikt Daurer
Date: Thu, 29 Feb 2024 15:01:06 +0000
Subject: [PATCH 06/16] implemented numpy based gaussian fft filter and added more tests

---
 ptypy/accelerate/base/array_utils.py        | 25 ++++++
 ptypy/accelerate/cuda_pycuda/array_utils.py | 18 +++--
 .../base_tests/array_utils_test.py          | 46 +++++++++++
 .../cuda_pycuda_tests/array_utils_test.py   | 76 ++++++++++++++++++-
 4 files changed, 159 insertions(+), 6 deletions(-)

diff --git a/ptypy/accelerate/base/array_utils.py b/ptypy/accelerate/base/array_utils.py
index 2f9954faf..bec4cc8e4 100644
--- a/ptypy/accelerate/base/array_utils.py
+++ b/ptypy/accelerate/base/array_utils.py
@@ -56,6 +56,16 @@ def norm2(input):
     return np.sum(abs2(input))
 
 
+def gaussian_kernel_2d(shape, sigmau, sigmav):
+    """
+    2D Gaussian kernel using the last 2 dimension of given shape
+    Requires sigma for both dimensions (sigmau and sigmav)
+    """
+    u, v = np.fft.fftfreq(shape[-2]), np.fft.fftfreq(shape[-1])
+    uu, vv = np.meshgrid(u, v, sparse=True, indexing='ij')
+    kernel = np.exp(-2* ( (np.pi*sigmau)**2 * uu**2 + (np.pi*sigmav)**2 * vv**2 ) )
+    return kernel
+
 def complex_gaussian_filter(input, mfs):
     '''
     takes 2D and 3D arrays. Complex input, complex output. mfs has len 0<len(mfs)<=2
@@ -70,6 +80,21 @@ def complex_gaussian_filter(input, mfs):
         input.dtype)
 
 
+def complex_gaussian_filter_fft(input, mfs):
+    """
+    takes 2D and 3D arrays. Complex input, complex output. mfs has len 0<len(mfs)<=2
+    """
+    if len(mfs) > 2:
+        raise NotImplementedError("Only batches of 2D arrays allowed!")
+    elif len(mfs) == 1:
+        mfs = np.array([mfs,mfs])
+    else:
+        mfs = np.array(mfs)
+
+    k = gaussian_kernel_2d(input.shape, mfs[0], mfs[1]).astype(input.dtype)
+    return fft_filter(input, k)
+
+
 def fft_filter(input, kernel, prefactor=None, postfactor=None, forward=True):
     """
     Compute
diff --git a/ptypy/accelerate/cuda_pycuda/array_utils.py b/ptypy/accelerate/cuda_pycuda/array_utils.py
index 0334edafe..8c4c8efd0 100644
--- a/ptypy/accelerate/cuda_pycuda/array_utils.py
+++ b/ptypy/accelerate/cuda_pycuda/array_utils.py
@@ -1,4 +1,5 @@
 from ptypy.accelerate.cuda_common.utils import map2ctype
+from ptypy.accelerate.base.array_utils import gaussian_kernel_2d
 from . 
import load_kernel from pycuda import gpuarray @@ -345,6 +346,10 @@ def allocate(self, shape, sigma=1.): self.sigma = sigma kernel = self._compute_kernel(shape, sigma) + # Extend kernel if needed + if len(shape) == 3: + kernel = np.tile(kernel, (shape[0],1,1)) + # Allocate filter kernel_dev = gpuarray.to_gpu(kernel) self.fft_filter.allocate(kernel=kernel_dev) @@ -352,10 +357,13 @@ def allocate(self, shape, sigma=1.): def _compute_kernel(self, shape, sigma): # Create kernel self.sigma = sigma - u, v = np.fft.fftfreq(shape[-2]), np.fft.fftfreq(shape[-1]) - uu, vv = np.meshgrid(u, v, sparse=True, indexing='ij') - kernel = np.exp(-2*(np.pi*sigma)**2 * (uu**2 + vv**2)) - return kernel.astype(self.dtype) + if len(sigma) == 1: + sigma = np.array([sigma, sigma]) + elif len(sigma) == 2: + sigma = np.array(sigma) + else: + raise NotImplementedError("Only batches of 2D arrays allowed!") + return gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) def filter(self, data, sigma=None): """ @@ -365,7 +373,7 @@ def filter(self, data, sigma=None): """ if self.sigma is None: self.allocate(shape=data.shape, sigma=sigma) - elif sigma is not None: + else: self.fft_filter.set_kernel(self._compute_kernel(data.shape, sigma)) self.fft_filter.apply_filter(data) diff --git a/test/accelerate_tests/base_tests/array_utils_test.py b/test/accelerate_tests/base_tests/array_utils_test.py index 0b8ca9b10..364352a5c 100644 --- a/test/accelerate_tests/base_tests/array_utils_test.py +++ b/test/accelerate_tests/base_tests/array_utils_test.py @@ -394,5 +394,51 @@ def test_fft_filter_batched(self): np.testing.assert_array_almost_equal(output[1].flat[::2000], known_test_output) + def test_complex_gaussian_filter_fft(self): + data = np.zeros((8, 8), dtype=COMPLEX_TYPE) + data[3:5, 3:5] = 2.0 + 2.0j + mfs = 3.0, 4.0 + + out = au.complex_gaussian_filter_fft(data, mfs) + expected_out = np.array([0.11033735 + 0.11033735j, 0.11888228 + 0.11888228j, 0.13116673 + 0.13116673j + , 0.13999543 + 0.13999543j, 0.13999543 + 0.13999543j, 0.13116673 + 0.13116673j + , 0.11888228 + 0.11888228j, 0.11033735 + 0.11033735j], dtype=COMPLEX_TYPE) + np.testing.assert_array_almost_equal(np.diagonal(out), expected_out, decimal=5) + + def test_complex_gaussian_filter_fft_batched(self): + batch_number = 2 + A = 5 + B = 5 + + data = np.zeros((batch_number, A, B), dtype=COMPLEX_TYPE) + data[:, 2:3, 2:3] = 2.0 + 2.0j + mfs = 3.0, 4.0 + out = au.complex_gaussian_filter_fft(data, mfs) + + expected_out = np.array([[[0.07988770 + 0.0798877j, 0.07989411 + 0.07989411j, 0.07989471 + 0.07989471j, + 0.07989411 + 0.07989411j, 0.07988770 + 0.0798877j], + [0.08003781 + 0.08003781j, 0.08004424 + 0.08004424j, 0.08004485 + 0.08004485j, + 0.08004424 + 0.08004424j, 0.08003781 + 0.08003781j], + [0.08012911 + 0.08012911j, 0.08013555 + 0.08013555j, 0.08013615 + 0.08013615j, + 0.08013555 + 0.08013555j, 0.08012911 + 0.08012911j], + [0.08003781 + 0.08003781j, 0.08004424 + 0.08004424j, 0.08004485 + 0.08004485j, + 0.08004424 + 0.08004424j, 0.08003781 + 0.08003781j], + [0.07988770 + 0.0798877j, 0.07989411 + 0.07989411j, 0.07989471 + 0.07989471j, + 0.07989411 + 0.07989411j, 0.07988770 + 0.0798877j]], + + [[0.07988770 + 0.0798877j, 0.07989411 + 0.07989411j, 0.07989471 + 0.07989471j, + 0.07989411 + 0.07989411j, 0.07988770 + 0.0798877j], + [0.08003781 + 0.08003781j, 0.08004424 + 0.08004424j, 0.08004485 + 0.08004485j, + 0.08004424 + 0.08004424j, 0.08003781 + 0.08003781j], + [0.08012911 + 0.08012911j, 0.08013555 + 0.08013555j, 0.08013615 + 0.08013615j, + 0.08013555 + 0.08013555j, 
0.08012911 + 0.08012911j], + [0.08003781 + 0.08003781j, 0.08004424 + 0.08004424j, 0.08004485 + 0.08004485j, + 0.08004424 + 0.08004424j, 0.08003781 + 0.08003781j], + [0.07988770 + 0.0798877j, 0.07989411 + 0.07989411j, 0.07989471 + 0.07989471j, + 0.07989411 + 0.07989411j, 0.07988770 + 0.0798877j]]], dtype=COMPLEX_TYPE) + + np.testing.assert_array_almost_equal(out, expected_out, decimal=5) + + if __name__ == '__main__': unittest.main() diff --git a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py index cab1147d4..e46230953 100644 --- a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py @@ -588,5 +588,79 @@ def test_fft_filter_batched_UNITY(self): FF.apply_filter(data_dev) output = au.fft_filter(data, kernel, prefactor, postfactor) + print(data_dev.get()) - np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) \ No newline at end of file + np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) + + def test_complex_gaussian_filter_fft_little_blurring_UNITY(self): + # Arrange + data = np.zeros((21, 21), dtype=np.complex64) + data[10:12, 10:12] = 2.0+2.0j + mfs = 0.2,0.2 + data_dev = gpuarray.to_gpu(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_more_blurring_UNITY(self): + # Arrange + data = np.zeros((8, 8), dtype=np.complex64) + data[3:5, 3:5] = 2.0+2.0j + mfs = 3.0,4.0 + data_dev = gpuarray.to_gpu(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_nonsquare_UNITY(self): + # Arrange + data = np.zeros((32, 16), dtype=np.complex64) + data[3:4, 11:12] = 2.0+2.0j + data[3:5, 3:5] = 2.0+2.0j + data[20:25,3:5] = 2.0+2.0j + mfs = 1.0,1.0 + data_dev = gpuarray.to_gpu(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_batched(self): + # Arrange + batch_number = 2 + A = 5 + B = 5 + data = np.zeros((batch_number, A, B), dtype=np.complex64) + data[:, 2:3, 2:3] = 2.0+2.0j + mfs = 3.0,4.0 + data_dev = gpuarray.to_gpu(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) \ No newline at end of file From bec7675eb782dc3afcc7df427211863c9350b1aa Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Thu, 29 Feb 2024 15:25:50 +0000 Subject: [PATCH 07/16] Introduced new parameter for changing method for smoothing kernel --- ptypy/accelerate/base/engines/ML_serial.py | 20 ++++++++++++---- .../cuda_pycuda/engines/ML_pycuda.py | 23 +++++++++++-------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/ptypy/accelerate/base/engines/ML_serial.py b/ptypy/accelerate/base/engines/ML_serial.py index f03882636..c468755b8 100644 
--- a/ptypy/accelerate/base/engines/ML_serial.py +++ b/ptypy/accelerate/base/engines/ML_serial.py @@ -23,6 +23,7 @@ from ptypy.engines import register from ptypy.accelerate.base.kernels import GradientDescentKernel, AuxiliaryWaveKernel, PoUpdateKernel, PositionCorrectionKernel from ptypy.accelerate.base import address_manglers +from ptypy.accelerate.base.array_utils import complex_gaussian_filter, complex_gaussian_filter_fft __all__ = ['ML_serial'] @@ -30,6 +31,15 @@ @register() class ML_serial(ML): + """ + Defaults: + + [smooth_gradient.method] + default = convolution + type = str + help = Method to be used for smoothing the gradient, choose between ```convolution``` or ```fft```. + """ + def __init__(self, ptycho_parent, pars=None): """ Maximum likelihood reconstruction engine. @@ -143,10 +153,12 @@ def engine_prepare(self): self.ML_model.prepare() def _get_smooth_gradient(self, data, sigma): - return self.smooth_gradient(data) - - def _get_smooth_gradient_fft(self, data, sigma): - return self.smooth_gradient(data) + if self.p.smooth_gradient.method == "convolution": + return complex_gaussian_filter(data, sigma) + elif self.p.smooth_gradient_method == "fft": + return complex_gaussian_filter_fft(data, sigma) + else: + raise NotImplementedError("smooth_gradient.method can only be ```convolution``` or ```fft```.") def _replace_ob_grad(self): new_ob_grad = self.ob_grad_new diff --git a/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py b/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py index afa3f26db..0f7097357 100644 --- a/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py +++ b/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py @@ -90,10 +90,12 @@ def engine_initialize(self): self.qu_htod = cuda.Stream() self.qu_dtoh = cuda.Stream() - self.GSK = GaussianSmoothingKernel(queue=self.queue) - self.GSK.tmp = None + if self.p.smooth_gradient.method == "convolution": + self.GSK = GaussianSmoothingKernel(queue=self.queue) + self.GSK.tmp = None - self.FGSK = FFTGaussianSmoothingKernel(queue=self.queue) + if self.p.smooth_gradient.method == "fft": + self.FGSK = FFTGaussianSmoothingKernel(queue=self.queue) # Real/Fourier Support Kernel self.RSK = {} @@ -257,13 +259,14 @@ def _set_pr_ob_ref_for_data(self, dev='gpu', container=None, sync_copy=False): self._set_pr_ob_ref_for_data(dev=dev, container=container, sync_copy=sync_copy) def _get_smooth_gradient(self, data, sigma): - if self.GSK.tmp is None: - self.GSK.tmp = gpuarray.empty(data.shape, dtype=np.complex64) - self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) - return data - - def _get_smooth_gradient_fft(self, data, sigma): - self.FGSK.filter(data, sigma) + if self.p.smooth_gradient.method == "convolution": + if self.GSK.tmp is None: + self.GSK.tmp = gpuarray.empty(data.shape, dtype=np.complex64) + self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) + elif self.p.smooth_gradient.method == "fft": + self.FGSK.filter(data, sigma) + else: + raise NotImplementedError("smooth_gradient.method can only be ```convolution``` or ```fft```.") return data def _replace_ob_grad(self): From e2a2f0a0d43064f100744d3ca8bb7d1de5ca15dd Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Mon, 4 Mar 2024 09:52:48 +0000 Subject: [PATCH 08/16] Fixed smooth gradient method parameter --- ptypy/accelerate/base/engines/ML_serial.py | 9 ++++----- ptypy/accelerate/cuda_pycuda/array_utils.py | 11 +++++------ ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py | 13 ++++++------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git 
a/ptypy/accelerate/base/engines/ML_serial.py b/ptypy/accelerate/base/engines/ML_serial.py index c468755b8..1e953d26a 100644 --- a/ptypy/accelerate/base/engines/ML_serial.py +++ b/ptypy/accelerate/base/engines/ML_serial.py @@ -34,7 +34,7 @@ class ML_serial(ML): """ Defaults: - [smooth_gradient.method] + [smooth_gradient_method] default = convolution type = str help = Method to be used for smoothing the gradient, choose between ```convolution``` or ```fft```. @@ -153,12 +153,12 @@ def engine_prepare(self): self.ML_model.prepare() def _get_smooth_gradient(self, data, sigma): - if self.p.smooth_gradient.method == "convolution": + if self.p.smooth_gradient_method == "convolution": return complex_gaussian_filter(data, sigma) elif self.p.smooth_gradient_method == "fft": return complex_gaussian_filter_fft(data, sigma) else: - raise NotImplementedError("smooth_gradient.method can only be ```convolution``` or ```fft```.") + raise NotImplementedError("smooth_gradient_method should be ```convolution``` or ```fft```.") def _replace_ob_grad(self): new_ob_grad = self.ob_grad_new @@ -246,8 +246,7 @@ def engine_iterate(self, num=1): # Smoothing preconditioner if self.smooth_gradient: for name, s in self.ob_h.storages.items(): - # s.data[:] -= self._get_smooth_gradient(self.ob_grad.storages[name].data, self.smooth_gradient.sigma) - s.data[:] -= self._get_smooth_gradient_fft(self.ob_grad.storages[name].data, self.smooth_gradient.sigma) + s.data[:] -= self._get_smooth_gradient(self.ob_grad.storages[name].data, self.smooth_gradient.sigma) else: self.ob_h -= self.ob_grad diff --git a/ptypy/accelerate/cuda_pycuda/array_utils.py b/ptypy/accelerate/cuda_pycuda/array_utils.py index 8c4c8efd0..2845a8b93 100644 --- a/ptypy/accelerate/cuda_pycuda/array_utils.py +++ b/ptypy/accelerate/cuda_pycuda/array_utils.py @@ -356,13 +356,12 @@ def allocate(self, shape, sigma=1.): def _compute_kernel(self, shape, sigma): # Create kernel - self.sigma = sigma - if len(sigma) == 1: + sigma = np.array(sigma) + if sigma.ndim <= 1: sigma = np.array([sigma, sigma]) - elif len(sigma) == 2: - sigma = np.array(sigma) - else: - raise NotImplementedError("Only batches of 2D arrays allowed!") + if sigma.ndim > 2: + raise NotImplementedError("Only batches of 2D arrays allowed!") + self.sigma = sigma return gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) def filter(self, data, sigma=None): diff --git a/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py b/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py index 0f7097357..b712f8974 100644 --- a/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py +++ b/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py @@ -90,11 +90,11 @@ def engine_initialize(self): self.qu_htod = cuda.Stream() self.qu_dtoh = cuda.Stream() - if self.p.smooth_gradient.method == "convolution": + if self.p.smooth_gradient_method == "convolution": self.GSK = GaussianSmoothingKernel(queue=self.queue) self.GSK.tmp = None - if self.p.smooth_gradient.method == "fft": + if self.p.smooth_gradient_method == "fft": self.FGSK = FFTGaussianSmoothingKernel(queue=self.queue) # Real/Fourier Support Kernel @@ -259,14 +259,14 @@ def _set_pr_ob_ref_for_data(self, dev='gpu', container=None, sync_copy=False): self._set_pr_ob_ref_for_data(dev=dev, container=container, sync_copy=sync_copy) def _get_smooth_gradient(self, data, sigma): - if self.p.smooth_gradient.method == "convolution": + if self.p.smooth_gradient_method == "convolution": if self.GSK.tmp is None: self.GSK.tmp = gpuarray.empty(data.shape, dtype=np.complex64) 
self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) - elif self.p.smooth_gradient.method == "fft": + elif self.p.smooth_gradient_method == "fft": self.FGSK.filter(data, sigma) else: - raise NotImplementedError("smooth_gradient.method can only be ```convolution``` or ```fft```.") + raise NotImplementedError("smooth_gradient_method should be ```convolution``` or ```fft```.") return data def _replace_ob_grad(self): @@ -275,8 +275,7 @@ def _replace_ob_grad(self): if self.smooth_gradient: self.smooth_gradient.sigma *= (1. - self.p.smooth_gradient_decay) for name, s in new_ob_grad.storages.items(): - #s.gpu = self._get_smooth_gradient(s.gpu, self.smooth_gradient.sigma) - s.gpu = self._get_smooth_gradient_fft(s.gpu, self.smooth_gradient.sigma) + s.gpu = self._get_smooth_gradient(s.gpu, self.smooth_gradient.sigma) return self._replace_grad(self.ob_grad, new_ob_grad) From cffdd848bd0bd2ccbcdf6e7b88380e2c7b65944b Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Tue, 5 Mar 2024 15:34:02 +0000 Subject: [PATCH 09/16] add example for FFT smoothing in ML pycuda --- .../moonflower_ML_pycuda_fft_smoothing.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 templates/misc/moonflower_ML_pycuda_fft_smoothing.py diff --git a/templates/misc/moonflower_ML_pycuda_fft_smoothing.py b/templates/misc/moonflower_ML_pycuda_fft_smoothing.py new file mode 100644 index 000000000..d27a8a504 --- /dev/null +++ b/templates/misc/moonflower_ML_pycuda_fft_smoothing.py @@ -0,0 +1,64 @@ +""" +This script is a test for ptychographic reconstruction in the absence +of actual data. It uses the test Scan class +`ptypy.core.data.MoonFlowerScan` to provide "data". +""" + +from ptypy.core import Ptycho +from ptypy import utils as u +import ptypy +ptypy.load_gpu_engines(arch="pycuda") + +import tempfile +tmpdir = tempfile.gettempdir() + +p = u.Param() + +# for verbose output +p.verbose_level = "info" +p.frames_per_block = 400 +# set home path +p.io = u.Param() +p.io.home = "/".join([tmpdir, "ptypy"]) +p.io.autosave = u.Param(active=False) +p.io.autoplot = u.Param(active=False) +p.io.interaction = u.Param(active=False) + +# max 200 frames (128x128px) of diffraction data +p.scans = u.Param() +p.scans.MF = u.Param() +# now you have to specify which ScanModel to use with scans.XX.name, +# just as you have to give 'name' for engines and PtyScan subclasses. +p.scans.MF.name = 'BlockFull' +p.scans.MF.data= u.Param() +p.scans.MF.data.name = 'MoonFlowerScan' +p.scans.MF.data.shape = 128 +p.scans.MF.data.num_frames = 100 +p.scans.MF.data.save = None + +p.scans.MF.illumination = u.Param(diversity=None) +p.scans.MF.coherence = u.Param(num_probe_modes=1) +# position distance in fraction of illumination frame +p.scans.MF.data.density = 0.2 +# total number of photon in empty beam +p.scans.MF.data.photons = 1e8 +# Gaussian FWHM of possible detector blurring +p.scans.MF.data.psf = 0. + +# attach a reconstrucion engine +p.engines = u.Param() +p.engines.engine00 = u.Param() +p.engines.engine00.name = 'ML_pycuda' +p.engines.engine00.numiter = 300 +p.engines.engine00.numiter_contiguous = 5 +p.engines.engine00.reg_del2 = True # Whether to use a Gaussian prior (smoothing) regularizer +p.engines.engine00.reg_del2_amplitude = 1. # Amplitude of the Gaussian prior if used +p.engines.engine00.scale_precond = True +p.engines.engine00.smooth_gradient = 50. +p.engines.engine00.smooth_gradient_decay = 1/50. 
+p.engines.engine00.smooth_gradient_method = "fft" # with method "convolution" there can be shared memory issues
+p.engines.engine00.floating_intensities = False
+
+# prepare and run
+if __name__ == "__main__":
+    P = Ptycho(p,level=5)

From 4f534a0db1e133fa67da9ea76ba8d1420ad8be13 Mon Sep 17 00:00:00 2001
From: Benedikt Daurer
Date: Wed, 6 Mar 2024 13:53:36 +0000
Subject: [PATCH 10/16] Added new template

---
 .../misc/moonflower_ML_cupy_fft_smoothing.py | 64 +++++++++++++++++++
 1 file changed, 64 insertions(+)
 create mode 100644 templates/misc/moonflower_ML_cupy_fft_smoothing.py

diff --git a/templates/misc/moonflower_ML_cupy_fft_smoothing.py b/templates/misc/moonflower_ML_cupy_fft_smoothing.py
new file mode 100644
index 000000000..2c1007f32
--- /dev/null
+++ b/templates/misc/moonflower_ML_cupy_fft_smoothing.py
@@ -0,0 +1,64 @@
+"""
+This script is a test for ptychographic reconstruction in the absence
+of actual data. It uses the test Scan class
+`ptypy.core.data.MoonFlowerScan` to provide "data".
+"""
+
+from ptypy.core import Ptycho
+from ptypy import utils as u
+import ptypy
+ptypy.load_gpu_engines(arch="cupy")
+
+import tempfile
+tmpdir = tempfile.gettempdir()
+
+p = u.Param()
+
+# for verbose output
+p.verbose_level = "info"
+p.frames_per_block = 400
+# set home path
+p.io = u.Param()
+p.io.home = "/".join([tmpdir, "ptypy"])
+p.io.autosave = u.Param(active=False)
+p.io.autoplot = u.Param(active=False)
+p.io.interaction = u.Param(active=False)
+
+# max 200 frames (128x128px) of diffraction data
+p.scans = u.Param()
+p.scans.MF = u.Param()
+# now you have to specify which ScanModel to use with scans.XX.name,
+# just as you have to give 'name' for engines and PtyScan subclasses.
+p.scans.MF.name = 'BlockFull'
+p.scans.MF.data= u.Param()
+p.scans.MF.data.name = 'MoonFlowerScan'
+p.scans.MF.data.shape = 128
+p.scans.MF.data.num_frames = 100
+p.scans.MF.data.save = None
+
+p.scans.MF.illumination = u.Param(diversity=None)
+p.scans.MF.coherence = u.Param(num_probe_modes=1)
+# position distance in fraction of illumination frame
+p.scans.MF.data.density = 0.2
+# total number of photon in empty beam
+p.scans.MF.data.photons = 1e8
+# Gaussian FWHM of possible detector blurring
+p.scans.MF.data.psf = 0.
+
+# attach a reconstruction engine
+p.engines = u.Param()
+p.engines.engine00 = u.Param()
+p.engines.engine00.name = 'ML_cupy'
+p.engines.engine00.numiter = 300
+p.engines.engine00.numiter_contiguous = 5
+p.engines.engine00.reg_del2 = True # Whether to use a Gaussian prior (smoothing) regularizer
+p.engines.engine00.reg_del2_amplitude = 1. # Amplitude of the Gaussian prior if used
+p.engines.engine00.scale_precond = True
+p.engines.engine00.smooth_gradient = 50.
+p.engines.engine00.smooth_gradient_decay = 1/50.
+p.engines.engine00.smooth_gradient_method = "fft" # with method "convolution" there can be shared memory issues
+p.engines.engine00.floating_intensities = False
+
+# prepare and run
+if __name__ == "__main__":
+    P = Ptycho(p,level=5)

From 6f522a8d198bca6d111bded941e77ae3ac8b20cf Mon Sep 17 00:00:00 2001
From: Benedikt Daurer
Date: Wed, 6 Mar 2024 15:52:42 +0000
Subject: [PATCH 11/16] Working on FFT based Gaussian filter for cupy engines

---
 ptypy/accelerate/cuda_cupy/array_utils.py   |  54 ++++++++
 ptypy/accelerate/cuda_cupy/kernels.py       |  28 ++++
 ptypy/accelerate/cuda_pycuda/array_utils.py |   6 +-
 .../cuda_cupy_tests/array_utils_test.py     | 126 ++++++++++++++++++
 4 files changed, 211 insertions(+), 3 deletions(-)

diff --git a/ptypy/accelerate/cuda_cupy/array_utils.py b/ptypy/accelerate/cuda_cupy/array_utils.py
index 9c68d9431..9f14015a1 100644
--- a/ptypy/accelerate/cuda_cupy/array_utils.py
+++ b/ptypy/accelerate/cuda_cupy/array_utils.py
@@ -2,6 +2,7 @@
 import numpy as np
 
 from ptypy.accelerate.cuda_common.utils import map2ctype
+from ptypy.accelerate.base.array_utils import gaussian_kernel_2d
 from ptypy.utils.math_utils import gaussian
 from . import load_kernel
 
@@ -322,6 +323,59 @@ def delxb(self, input, out, axis=-1):
             out, lower_dim, higher_dim, np.int32(input.shape[axis])))
 
 
+class FFTGaussianSmoothingKernel:
+    def __init__(self, queue=None, kernel_type='float'):
+        if kernel_type not in ['float', 'double']:
+            raise ValueError('Invalid data type for kernel')
+        self.kernel_type = kernel_type
+        self.dtype = np.complex64
+        self.stype = "complex"
+        self.queue = queue
+        self.sigma = None
+
+        from .kernels import FFTFilterKernel
+
+        # Create general FFT filter object
+        self.fft_filter = FFTFilterKernel(queue_thread=queue)
+
+    def allocate(self, shape, sigma=1.):
+
+        # Create kernel
+        self.sigma = sigma
+        kernel = self._compute_kernel(shape, sigma)
+
+        # Extend kernel if needed
+        if len(shape) == 3:
+            kernel = np.tile(kernel, (shape[0],1,1))
+
+        # Allocate filter
+        kernel_dev = cp.asarray(kernel)
+        self.fft_filter.allocate(kernel=kernel_dev)
+
+    def _compute_kernel(self, shape, sigma):
+        # Create kernel
+        sigma = np.array(sigma)
+        if sigma.size == 1:
+            sigma = np.array([sigma, sigma])
+        if sigma.size > 2:
+            raise NotImplementedError("Only batches of 2D arrays allowed!")
+        self.sigma = sigma
+        return gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype)
+
+    def filter(self, data, sigma=None):
+        """
+        Apply filter in-place
+
+        If sigma is not None: reallocate a new fft filter first.
+        """
+        if self.sigma is None:
+            self.allocate(shape=data.shape, sigma=sigma)
+        else:
+            self.fft_filter.set_kernel(self._compute_kernel(data.shape, sigma))
+
+        self.fft_filter.apply_filter(data)
+
+
 class GaussianSmoothingKernel:
     def __init__(self, queue=None, num_stdevs=4, kernel_type='float'):
         if kernel_type not in ['float', 'double']:
diff --git a/ptypy/accelerate/cuda_cupy/kernels.py b/ptypy/accelerate/cuda_cupy/kernels.py
index 049108e71..118aadbe6 100644
--- a/ptypy/accelerate/cuda_cupy/kernels.py
+++ b/ptypy/accelerate/cuda_cupy/kernels.py
@@ -170,6 +170,34 @@ def apply_real_support(self, x):
         x *= self.support
 
 
+class FFTFilterKernel:
+    def __init__(self, queue_thread=None, fft='cuda'):
+        # Current implementation recompiles every time there is a change in input shape.
+ self.queue = queue_thread + self._fft_type = fft + self._fft1 = None + self._fft2 = None + def allocate(self, kernel, prefactor=None, postfactor=None, forward=True): + FFT = choose_fft(kernel.shape[-2:], fft_type=self._fft_type) + + self._fft1 = FFT(kernel, self.queue, + pre_fft=prefactor, + post_fft=kernel, + symmetric=True, + forward=forward) + self._fft2 = FFT(kernel, self.queue, + post_fft=postfactor, + symmetric=True, + forward=not forward) + + def set_kernel(self, kernel): + self._fft1.post_fft.set(kernel) + + def apply_filter(self, x): + self._fft1.ft(x, x) + self._fft2.ift(x, x) + + class FourierUpdateKernel(ab.FourierUpdateKernel): def __init__(self, aux, nmodes=1, queue_thread=None, accumulate_type='float', math_type='float'): diff --git a/ptypy/accelerate/cuda_pycuda/array_utils.py b/ptypy/accelerate/cuda_pycuda/array_utils.py index 2845a8b93..2879739e3 100644 --- a/ptypy/accelerate/cuda_pycuda/array_utils.py +++ b/ptypy/accelerate/cuda_pycuda/array_utils.py @@ -357,10 +357,10 @@ def allocate(self, shape, sigma=1.): def _compute_kernel(self, shape, sigma): # Create kernel sigma = np.array(sigma) - if sigma.ndim <= 1: + if sigma.size == 1: sigma = np.array([sigma, sigma]) - if sigma.ndim > 2: - raise NotImplementedError("Only batches of 2D arrays allowed!") + if sigma.size > 2: + raise NotImplementedError("Only batches of 2D arrays allowed!") self.sigma = sigma return gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) diff --git a/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py b/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py index 0c018b205..a7a92d7b9 100644 --- a/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py @@ -6,6 +6,7 @@ if have_cupy(): import cupy as cp import ptypy.accelerate.cuda_cupy.array_utils as gau + from ptypy.accelerate.cuda_cupy.kernels import FFTFilterKernel class ArrayUtilsTest(CupyCudaTest): @@ -534,3 +535,128 @@ def test_interpolate_shift_no_shift_UNITY(self): np.testing.assert_allclose(out, isk, rtol=1e-6, atol=1e-6, err_msg="The shifting of array has not been calculated as expected") + def test_fft_filter_UNITY(self): + sh = (32, 48) + data = np.zeros(sh, dtype=np.complex64) + data.flat[:] = np.arange(np.prod(sh)) + kernel = np.zeros_like(data) + kernel[0, 0] = 1. + kernel[0, 1] = 0.5 + + prefactor = np.zeros_like(data) + prefactor[:,2:] = 1. + postfactor = np.zeros_like(data) + postfactor[2:,:] = 1. + + data_dev = cp.asarray(data) + kernel_dev = cp.asarray(kernel) + pre_dev = cp.asarray(prefactor) + post_dev = cp.asarray(postfactor) + + FF = FFTFilterKernel(queue_thread=self.stream) + FF.allocate(kernel=kernel_dev, prefactor=pre_dev, postfactor=post_dev) + FF.apply_filter(data_dev) + + output = au.fft_filter(data, kernel, prefactor, postfactor) + + np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) + + def test_fft_filter_batched_UNITY(self): + sh = (2,16, 35) + data = np.zeros(sh, dtype=np.complex64) + data.flat[:] = np.arange(np.prod(sh)) + kernel = np.zeros_like(data) + kernel[:,0, 0] = 1. + kernel[:,0, 1] = 0.5 + + prefactor = np.zeros_like(data) + prefactor[:,:,2:] = 1. + postfactor = np.zeros_like(data) + postfactor[:,2:,:] = 1. 
+ + data_dev = cp.asarray(data) + kernel_dev = cp.asarray(kernel) + pre_dev = cp.asarray(prefactor) + post_dev = cp.asarray(postfactor) + + FF = FFTFilterKernel(queue_thread=self.stream) + FF.allocate(kernel=kernel_dev, prefactor=pre_dev, postfactor=post_dev) + FF.apply_filter(data_dev) + + output = au.fft_filter(data, kernel, prefactor, postfactor) + print(data_dev.get()) + + np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) + + def test_complex_gaussian_filter_fft_little_blurring_UNITY(self): + # Arrange + data = np.zeros((21, 21), dtype=np.complex64) + data[10:12, 10:12] = 2.0+2.0j + mfs = 0.2,0.2 + data_dev = cp.asarray(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_more_blurring_UNITY(self): + # Arrange + data = np.zeros((8, 8), dtype=np.complex64) + data[3:5, 3:5] = 2.0+2.0j + mfs = 3.0,4.0 + data_dev = cp.asarray(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_nonsquare_UNITY(self): + # Arrange + data = np.zeros((32, 16), dtype=np.complex64) + data[3:4, 11:12] = 2.0+2.0j + data[3:5, 3:5] = 2.0+2.0j + data[20:25,3:5] = 2.0+2.0j + mfs = 1.0,1.0 + data_dev = cp.asarray(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_batched(self): + # Arrange + batch_number = 2 + A = 5 + B = 5 + data = np.zeros((batch_number, A, B), dtype=np.complex64) + data[:, 2:3, 2:3] = 2.0+2.0j + mfs = 3.0,4.0 + data_dev = cp.asarray(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) \ No newline at end of file From d7398431f2aa08dde6b31f33d44a8e7f69990fcd Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Thu, 7 Mar 2024 16:40:13 +0000 Subject: [PATCH 12/16] Fixed another bug in batch multiply kernel --- .../cuda_common/batched_multiply.cu | 8 ++++- ptypy/accelerate/cuda_cupy/array_utils.py | 31 ++++++++++++++++++ ptypy/accelerate/cuda_pycuda/array_utils.py | 32 +++++++++++++++++++ .../cuda_cupy_tests/array_utils_test.py | 19 +++++++++-- .../cuda_pycuda_tests/array_utils_test.py | 17 ++++++++++ 5 files changed, 104 insertions(+), 3 deletions(-) diff --git a/ptypy/accelerate/cuda_common/batched_multiply.cu b/ptypy/accelerate/cuda_common/batched_multiply.cu index 11394f68c..7976ef765 100644 --- a/ptypy/accelerate/cuda_common/batched_multiply.cu +++ b/ptypy/accelerate/cuda_common/batched_multiply.cu @@ -22,10 +22,16 @@ extern "C" __global__ void batched_multiply(const complex* input, int gy = threadIdx.y + blockIdx.y * blockDim.y; int gz = threadIdx.z + blockIdx.z * blockDim.z; - if (gx > rows || gy > columns || gz > nBatches) + if (gx > rows - 1 || gy > columns - 1 || gz > nBatches) return; auto val = input[gz * rows * columns + gy * rows + gx]; + //printf("gx = %d, gy = %d, gz = 
%d, val= %.1f +i%.1f\n", gz,gy,gz, val.real(), val.imag()); + //printf("threads: x=%d y=%d z=%d\n", threadIdx.x, threadIdx.y, threadIdx.z); + //printf("blocks: x=%d y=%d z=%d\n", blockIdx.x, blockIdx.y, blockIdx.z); + //printf("grids: x=%d y=%d z=%d\n", blockDim.x, blockDim.y, blockDim.z); + + if (MPY_DO_FILT) // set at compile-time { val *= filter[gy * rows + gx]; diff --git a/ptypy/accelerate/cuda_cupy/array_utils.py b/ptypy/accelerate/cuda_cupy/array_utils.py index 9f14015a1..84c729b56 100644 --- a/ptypy/accelerate/cuda_cupy/array_utils.py +++ b/ptypy/accelerate/cuda_cupy/array_utils.py @@ -63,6 +63,37 @@ def dot(self, A: cp.ndarray, B: cp.ndarray, out: cp.ndarray = None) -> cp.ndarra def norm2(self, A, out=None): return self.dot(A, A, out) +class BatchedMultiplyKernel: + def __init__(self, array, queue=None, math_type=np.complex64): + self.queue = queue + self.array_shape = array.shape[-2:] + self.batches = int(np.prod(array.shape[0:array.ndim-2]) if array.ndim > 2 else 1) + self.batched_multiply_cuda = load_kernel("batched_multiply", { + 'MPY_DO_SCALE': 'true', + 'MPY_DO_FILT': 'true', + 'IN_TYPE': 'float' if array.dtype==np.complex64 else 'double', + 'OUT_TYPE': 'float' if array.dtype==np.complex64 else 'double', + 'MATH_TYPE': 'float' if math_type==np.complex64 else 'double' + }) + self.block = (32,32,1) + self.grid = ( + int((self.array_shape[0] + 31) // 32), + int((self.array_shape[1] + 31) // 32), + int(self.batches) + ) + + def multiply(self, x,y, scale=1.): + assert x.dtype == y.dtype, "Input arrays must be of same data type" + assert x.shape[-2:] == y.shape[-2:], "Input arrays must be of the same size in last 2 dims" + if self.queue is not None: + self.queue.use() + self.batched_multiply_cuda(self.grid, + self.block, + args=(x,x,y, + np.float32(scale), + np.int32(self.batches), + np.int32(self.array_shape[0]), + np.int32(self.array_shape[1]))) class TransposeKernel: diff --git a/ptypy/accelerate/cuda_pycuda/array_utils.py b/ptypy/accelerate/cuda_pycuda/array_utils.py index 2879739e3..7fb730516 100644 --- a/ptypy/accelerate/cuda_pycuda/array_utils.py +++ b/ptypy/accelerate/cuda_pycuda/array_utils.py @@ -66,6 +66,38 @@ def dot(self, A, B, out=None): def norm2(self, A, out=None): return self.dot(A, A, out) +class BatchedMultiplyKernel: + def __init__(self, array, queue=None, math_type=np.complex64): + self.queue = queue + self.array_shape = array.shape[-2:] + self.batches = int(np.prod(array.shape[0:array.ndim-2]) if array.ndim > 2 else 1) + self.batched_multiply_cuda = load_kernel("batched_multiply", { + 'MPY_DO_SCALE': 'true', + 'MPY_DO_FILT': 'true', + 'IN_TYPE': 'float' if array.dtype==np.complex64 else 'double', + 'OUT_TYPE': 'float' if array.dtype==np.complex64 else 'double', + 'MATH_TYPE': 'float' if math_type==np.complex64 else 'double' + }) + self.block = (32,32,1) + self.grid = ( + int((self.array_shape[0] + 31) // 32), + int((self.array_shape[1] + 31) // 32), + int(self.batches) + ) + + def multiply(self, x,y, scale=1.): + assert x.dtype == y.dtype, "Input arrays must be of same data type" + assert x.shape[-2:] == y.shape[-2:], "Input arrays must be of the same size in last 2 dims" + self.batched_multiply_cuda(x,x,y, + np.float32(scale), + np.int32(self.batches), + np.int32(self.array_shape[0]), + np.int32(self.array_shape[1]), + block=self.block, + grid=self.grid, + stream=self.queue) + + class TransposeKernel: def __init__(self, queue=None): diff --git a/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py 
b/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py index a7a92d7b9..23cfc0156 100644 --- a/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py @@ -79,6 +79,23 @@ def test_dot_performance(self): AU = gau.ArrayUtilsKernel(acc_dtype=np.float64) AU.dot(A_dev, A_dev) + def test_batched_multiply(self): + # Arrange + sh = (3,14,24) + ksh = (14,24) + data = (np.random.random(sh) + 1j* np.random.random(sh)).astype(np.complex64) + kernel = (np.random.random(ksh) + 1j* np.random.random(ksh)).astype(np.complex64) + data_dev = cp.asarray(data) + kernel_dev = cp.asarray(kernel) + + # Act + BM = gau.BatchedMultiplyKernel(data_dev) + BM.multiply(data_dev, kernel_dev, scale=2.) + + # Assert + expected = data * kernel * 2. + np.testing.assert_array_almost_equal(data_dev.get(), expected) + def test_transpose_2D(self): # Arrange inp, _ = np.indices((5, 3), dtype=np.int32) @@ -584,8 +601,6 @@ def test_fft_filter_batched_UNITY(self): FF.apply_filter(data_dev) output = au.fft_filter(data, kernel, prefactor, postfactor) - print(data_dev.get()) - np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) def test_complex_gaussian_filter_fft_little_blurring_UNITY(self): diff --git a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py index e46230953..be82cc292 100644 --- a/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py @@ -82,6 +82,23 @@ def test_dot_performance(self): AU = gau.ArrayUtilsKernel(acc_dtype=np.float64) out_dev = AU.dot(A_dev, A_dev) + def test_batched_multiply(self): + # Arrange + sh = (3,14,24) + ksh = (14,24) + data = (np.random.random(sh) + 1j* np.random.random(sh)).astype(np.complex64) + kernel = (np.random.random(ksh) + 1j* np.random.random(ksh)).astype(np.complex64) + data_dev = gpuarray.to_gpu(data) + kernel_dev = gpuarray.to_gpu(kernel) + + # Act + BM = gau.BatchedMultiplyKernel(data_dev) + BM.multiply(data_dev, kernel_dev, scale=2.) + + # Assert + expected = data * kernel * 2. 
+ np.testing.assert_array_almost_equal(data_dev.get(), expected) + def test_transpose_2D(self): ## Arrange inp,_ = np.indices((5,3), dtype=np.int32) From 859335e63fa9486d04857a20024e05e6b226f3fe Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Fri, 8 Mar 2024 11:46:49 +0000 Subject: [PATCH 13/16] fft based Gaussian smoothing works with both Ml_pycuda and ML_cupy --- ptypy/accelerate/cuda_cupy/array_utils.py | 13 +++++++----- ptypy/accelerate/cuda_cupy/engines/ML_cupy.py | 21 +++++++++++++------ ptypy/accelerate/cuda_pycuda/array_utils.py | 13 +++++++----- .../misc/moonflower_ML_cupy_fft_smoothing.py | 2 +- .../moonflower_ML_pycuda_fft_smoothing.py | 2 +- 5 files changed, 33 insertions(+), 18 deletions(-) diff --git a/ptypy/accelerate/cuda_cupy/array_utils.py b/ptypy/accelerate/cuda_cupy/array_utils.py index 84c729b56..626aa985a 100644 --- a/ptypy/accelerate/cuda_cupy/array_utils.py +++ b/ptypy/accelerate/cuda_cupy/array_utils.py @@ -375,10 +375,6 @@ def allocate(self, shape, sigma=1.): self.sigma = sigma kernel = self._compute_kernel(shape, sigma) - # Extend kernel if needed - if len(shape) == 3: - kernel = np.tile(kernel, (shape[0],1,1)) - # Allocate filter kernel_dev = cp.asarray(kernel) self.fft_filter.allocate(kernel=kernel_dev) @@ -391,7 +387,14 @@ def _compute_kernel(self, shape, sigma): if sigma.size > 2: raise NotImplementedError("Only batches of 2D arrays allowed!") self.sigma = sigma - return gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) + + kernel = gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) + + # Extend kernel if needed + if len(shape) == 3: + kernel = np.tile(kernel, (shape[0],1,1)) + + return kernel def filter(self, data, sigma=None): """ diff --git a/ptypy/accelerate/cuda_cupy/engines/ML_cupy.py b/ptypy/accelerate/cuda_cupy/engines/ML_cupy.py index efcc42338..cd68701a2 100644 --- a/ptypy/accelerate/cuda_cupy/engines/ML_cupy.py +++ b/ptypy/accelerate/cuda_cupy/engines/ML_cupy.py @@ -23,7 +23,7 @@ from .. 
import get_context, log_device_memory_stats from ..kernels import PropagationKernel, RealSupportKernel, FourierSupportKernel from ..kernels import GradientDescentKernel, AuxiliaryWaveKernel, PoUpdateKernel, PositionCorrectionKernel -from ..array_utils import ArrayUtilsKernel, DerivativesKernel, GaussianSmoothingKernel, TransposeKernel +from ..array_utils import ArrayUtilsKernel, DerivativesKernel, GaussianSmoothingKernel, FFTGaussianSmoothingKernel, TransposeKernel from ..mem_utils import GpuDataManager #from ..mem_utils import GpuDataManager @@ -79,8 +79,12 @@ def engine_initialize(self): self.qu_htod = cp.cuda.Stream() self.qu_dtoh = cp.cuda.Stream() - self.GSK = GaussianSmoothingKernel(queue=self.queue) - self.GSK.tmp = None + if self.p.smooth_gradient_method == "convolution": + self.GSK = GaussianSmoothingKernel(queue=self.queue) + self.GSK.tmp = None + + if self.p.smooth_gradient_method == "fft": + self.FGSK = FFTGaussianSmoothingKernel(queue=self.queue) # Real/Fourier Support Kernel self.RSK = {} @@ -260,9 +264,14 @@ def _set_pr_ob_ref_for_data(self, dev='gpu', container=None, sync_copy=False): dev=dev, container=container, sync_copy=sync_copy) def _get_smooth_gradient(self, data, sigma): - if self.GSK.tmp is None: - self.GSK.tmp = cp.empty(data.shape, dtype=np.complex64) - self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) + if self.p.smooth_gradient_method == "convolution": + if self.GSK.tmp is None: + self.GSK.tmp = cp.empty(data.shape, dtype=np.complex64) + self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) + elif self.p.smooth_gradient_method == "fft": + self.FGSK.filter(data, sigma) + else: + raise NotImplementedError("smooth_gradient_method should be ```convolution``` or ```fft```.") return data def _replace_ob_grad(self): diff --git a/ptypy/accelerate/cuda_pycuda/array_utils.py b/ptypy/accelerate/cuda_pycuda/array_utils.py index 7fb730516..7e592228f 100644 --- a/ptypy/accelerate/cuda_pycuda/array_utils.py +++ b/ptypy/accelerate/cuda_pycuda/array_utils.py @@ -378,10 +378,6 @@ def allocate(self, shape, sigma=1.): self.sigma = sigma kernel = self._compute_kernel(shape, sigma) - # Extend kernel if needed - if len(shape) == 3: - kernel = np.tile(kernel, (shape[0],1,1)) - # Allocate filter kernel_dev = gpuarray.to_gpu(kernel) self.fft_filter.allocate(kernel=kernel_dev) @@ -394,7 +390,14 @@ def _compute_kernel(self, shape, sigma): if sigma.size > 2: raise NotImplementedError("Only batches of 2D arrays allowed!") self.sigma = sigma - return gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) + + kernel = gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) + + # Extend kernel if needed + if len(shape) == 3: + kernel = np.tile(kernel, (shape[0],1,1)) + + return kernel def filter(self, data, sigma=None): """ diff --git a/templates/misc/moonflower_ML_cupy_fft_smoothing.py b/templates/misc/moonflower_ML_cupy_fft_smoothing.py index 2c1007f32..f127d2c46 100644 --- a/templates/misc/moonflower_ML_cupy_fft_smoothing.py +++ b/templates/misc/moonflower_ML_cupy_fft_smoothing.py @@ -56,7 +56,7 @@ p.engines.engine00.scale_precond = True p.engines.engine00.smooth_gradient = 50. p.engines.engine00.smooth_gradient_decay = 1/50. 
-p.engines.engine00.smooth_gradient_metehod = "fft" # with method "convolution" there can be shared memory issues +p.engines.engine00.smooth_gradient_method = "fft" # with method "convolution" there can be shared memory issues p.engines.engine00.floating_intensities = False # prepare and run diff --git a/templates/misc/moonflower_ML_pycuda_fft_smoothing.py b/templates/misc/moonflower_ML_pycuda_fft_smoothing.py index d27a8a504..d7ea338a3 100644 --- a/templates/misc/moonflower_ML_pycuda_fft_smoothing.py +++ b/templates/misc/moonflower_ML_pycuda_fft_smoothing.py @@ -56,7 +56,7 @@ p.engines.engine00.scale_precond = True p.engines.engine00.smooth_gradient = 50. p.engines.engine00.smooth_gradient_decay = 1/50. -p.engines.engine00.smooth_gradient_metehod = "fft" # with method "convolution" there can be shared memory issues +p.engines.engine00.smooth_gradient_method = "fft" # with method "convolution" there can be shared memory issues p.engines.engine00.floating_intensities = False # prepare and run From b3e0ded7e93c9affb15c7b10c691169236e21bde Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Fri, 8 Mar 2024 12:31:39 +0000 Subject: [PATCH 14/16] small changes to accelerate tests --- test/accelerate_tests/cuda_cupy_tests/array_utils_test.py | 2 +- test/accelerate_tests/cuda_pycuda_tests/fft_setstream_test.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py b/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py index 23cfc0156..bc245ca71 100644 --- a/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py @@ -656,7 +656,7 @@ def test_complex_gaussian_filter_fft_nonsquare_UNITY(self): np.testing.assert_allclose(out_exp, out, atol=1e-6) - def test_complex_gaussian_filter_fft_batched(self): + def test_complex_gaussian_filter_fft_batched_UNITY(self): # Arrange batch_number = 2 A = 5 diff --git a/test/accelerate_tests/cuda_pycuda_tests/fft_setstream_test.py b/test/accelerate_tests/cuda_pycuda_tests/fft_setstream_test.py index 5816e3bf3..c30c08e92 100644 --- a/test/accelerate_tests/cuda_pycuda_tests/fft_setstream_test.py +++ b/test/accelerate_tests/cuda_pycuda_tests/fft_setstream_test.py @@ -95,5 +95,6 @@ def test_set_stream_a_reikna(self): def test_set_stream_b_cufft(self): self.helper(cuFFT) + @unittest.skip("Skcuda is currently broken") def test_set_stream_c_skcuda_cufft(self): self.helper(SkcudaCuFFT) From aeec23485b5e4f2a1df4f3a67e0e30bd473a8486 Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Fri, 8 Mar 2024 14:21:12 +0000 Subject: [PATCH 15/16] removed debugging traces --- ptypy/accelerate/cuda_common/batched_multiply.cu | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ptypy/accelerate/cuda_common/batched_multiply.cu b/ptypy/accelerate/cuda_common/batched_multiply.cu index 7976ef765..e54940277 100644 --- a/ptypy/accelerate/cuda_common/batched_multiply.cu +++ b/ptypy/accelerate/cuda_common/batched_multiply.cu @@ -26,11 +26,6 @@ extern "C" __global__ void batched_multiply(const complex* input, return; auto val = input[gz * rows * columns + gy * rows + gx]; - //printf("gx = %d, gy = %d, gz = %d, val= %.1f +i%.1f\n", gz,gy,gz, val.real(), val.imag()); - //printf("threads: x=%d y=%d z=%d\n", threadIdx.x, threadIdx.y, threadIdx.z); - //printf("blocks: x=%d y=%d z=%d\n", blockIdx.x, blockIdx.y, blockIdx.z); - //printf("grids: x=%d y=%d z=%d\n", blockDim.x, blockDim.y, blockDim.z); - if (MPY_DO_FILT) // set at compile-time { From 
f2c2cb65402f583d52955b633ef95677e3e3a372 Mon Sep 17 00:00:00 2001 From: Benedikt Daurer Date: Fri, 8 Mar 2024 17:29:42 +0000 Subject: [PATCH 16/16] Improve error message when convolution kernel too big --- ptypy/accelerate/cuda_cupy/engines/ML_cupy.py | 6 +++++- ptypy/accelerate/cuda_cupy/kernels.py | 2 +- ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py | 6 +++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/ptypy/accelerate/cuda_cupy/engines/ML_cupy.py b/ptypy/accelerate/cuda_cupy/engines/ML_cupy.py index cd68701a2..caa0a192d 100644 --- a/ptypy/accelerate/cuda_cupy/engines/ML_cupy.py +++ b/ptypy/accelerate/cuda_cupy/engines/ML_cupy.py @@ -267,7 +267,11 @@ def _get_smooth_gradient(self, data, sigma): if self.p.smooth_gradient_method == "convolution": if self.GSK.tmp is None: self.GSK.tmp = cp.empty(data.shape, dtype=np.complex64) - self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) + try: + self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) + except MemoryError: + raise RuntimeError("Convolution kernel too large for direct convolution on GPU", + "Please reduce parameter smooth_gradient or set smooth_gradient_method='fft'.") elif self.p.smooth_gradient_method == "fft": self.FGSK.filter(data, sigma) else: diff --git a/ptypy/accelerate/cuda_cupy/kernels.py b/ptypy/accelerate/cuda_cupy/kernels.py index 118aadbe6..b743e9b08 100644 --- a/ptypy/accelerate/cuda_cupy/kernels.py +++ b/ptypy/accelerate/cuda_cupy/kernels.py @@ -171,7 +171,7 @@ def apply_real_support(self, x): class FFTFilterKernel: - def __init__(self, queue_thread=None, fft='cuda'): + def __init__(self, queue_thread=None, fft='cupy'): # Current implementation recompiles every time there is a change in input shape. self.queue = queue_thread self._fft_type = fft diff --git a/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py b/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py index b712f8974..d527d9f15 100644 --- a/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py +++ b/ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py @@ -262,7 +262,11 @@ def _get_smooth_gradient(self, data, sigma): if self.p.smooth_gradient_method == "convolution": if self.GSK.tmp is None: self.GSK.tmp = gpuarray.empty(data.shape, dtype=np.complex64) - self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) + try: + self.GSK.convolution(data, [sigma, sigma], tmp=self.GSK.tmp) + except MemoryError: + raise RuntimeError("Convolution kernel too large for direct convolution on GPU", + "Please reduce parameter smooth_gradient or set smooth_gradient_method='fft'.") elif self.p.smooth_gradient_method == "fft": self.FGSK.filter(data, sigma) else:
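For context on the "fft" smoothing path used above: multiplying the FFT of an array by the transfer function of a Gaussian and transforming back is equivalent to a circular convolution with that Gaussian, which is what the direct "convolution" method computes in real space. The standalone NumPy/SciPy sketch below illustrates this equivalence on the CPU; it is not part of the patch series, the helper name fft_gaussian_smooth is invented for illustration, and the array size, sigma and tolerance are arbitrary choices.

    import numpy as np
    from scipy.ndimage import gaussian_filter

    def fft_gaussian_smooth(a, sigma):
        # Multiply the FFT of `a` by the Gaussian transfer function
        # exp(-2 * pi^2 * sigma^2 * |f|^2), i.e. perform a circular
        # convolution with a Gaussian of standard deviation `sigma` pixels.
        fy = np.fft.fftfreq(a.shape[0])
        fx = np.fft.fftfreq(a.shape[1])
        ky, kx = np.meshgrid(fy, fx, indexing="ij")
        transfer = np.exp(-2.0 * np.pi**2 * sigma**2 * (ky**2 + kx**2))
        return np.fft.ifftn(np.fft.fftn(a) * transfer)

    a = np.zeros((64, 64), dtype=np.complex128)
    a[28:36, 28:36] = 1.0 + 1.0j
    sigma = 2.0

    out_fft = fft_gaussian_smooth(a, sigma)
    # Reference: direct Gaussian filtering of real and imaginary parts with
    # wrap-around boundaries, matching the circular convolution above.
    out_direct = (gaussian_filter(a.real, sigma, mode="wrap", truncate=6.0)
                  + 1j * gaussian_filter(a.imag, sigma, mode="wrap", truncate=6.0))

    print(np.allclose(out_fft, out_direct, atol=1e-6))  # expected: True

Because the Fourier-domain multiply is an elementwise operation whose cost does not grow with the width of the Gaussian, it avoids the kernel-size limits of the direct GPU convolution. This is why the RuntimeError added in this patch recommends either reducing smooth_gradient or setting smooth_gradient_method = "fft" when the convolution kernel no longer fits on the device.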