diff --git a/ptypy/accelerate/cuda_cupy/array_utils.py b/ptypy/accelerate/cuda_cupy/array_utils.py index 9c68d9431..9f14015a1 100644 --- a/ptypy/accelerate/cuda_cupy/array_utils.py +++ b/ptypy/accelerate/cuda_cupy/array_utils.py @@ -2,6 +2,7 @@ import numpy as np from ptypy.accelerate.cuda_common.utils import map2ctype +from ptypy.accelerate.base.array_utils import gaussian_kernel_2d from ptypy.utils.math_utils import gaussian from . import load_kernel @@ -322,6 +323,59 @@ def delxb(self, input, out, axis=-1): out, lower_dim, higher_dim, np.int32(input.shape[axis]))) +class FFTGaussianSmoothingKernel: + def __init__(self, queue=None, kernel_type='float'): + if kernel_type not in ['float', 'double']: + raise ValueError('Invalid data type for kernel') + self.kernel_type = kernel_type + self.dtype = np.complex64 + self.stype = "complex" + self.queue = queue + self.sigma = None + + from .kernels import FFTFilterKernel + + # Create general FFT filter object + self.fft_filter = FFTFilterKernel(queue_thread=queue) + + def allocate(self, shape, sigma=1.): + + # Create kernel + self.sigma = sigma + kernel = self._compute_kernel(shape, sigma) + + # Extend kernel if needed + if len(shape) == 3: + kernel = np.tile(kernel, (shape[0],1,1)) + + # Allocate filter + kernel_dev = cp.asarray(kernel) + self.fft_filter.allocate(kernel=kernel_dev) + + def _compute_kernel(self, shape, sigma): + # Create kernel + sigma = np.array(sigma) + if sigma.size == 1: + sigma = np.array([sigma, sigma]) + if sigma.size > 2: + raise NotImplementedError("Only batches of 2D arrays allowed!") + self.sigma = sigma + return gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) + + def filter(self, data, sigma=None): + """ + Apply filter in-place + + If sigma is not None: reallocate a new fft filter first. + """ + if self.sigma is None: + self.allocate(shape=data.shape, sigma=sigma) + else: + self.fft_filter.set_kernel(self._compute_kernel(data.shape, sigma)) + + self.fft_filter.apply_filter(data) + + class GaussianSmoothingKernel: def __init__(self, queue=None, num_stdevs=4, kernel_type='float'): if kernel_type not in ['float', 'double']: diff --git a/ptypy/accelerate/cuda_cupy/kernels.py b/ptypy/accelerate/cuda_cupy/kernels.py index 049108e71..118aadbe6 100644 --- a/ptypy/accelerate/cuda_cupy/kernels.py +++ b/ptypy/accelerate/cuda_cupy/kernels.py @@ -170,6 +170,34 @@ def apply_real_support(self, x): x *= self.support +class FFTFilterKernel: + def __init__(self, queue_thread=None, fft='cuda'): + # Current implementation recompiles every time there is a change in input shape. + self.queue = queue_thread + self._fft_type = fft + self._fft1 = None + self._fft2 = None + def allocate(self, kernel, prefactor=None, postfactor=None, forward=True): + FFT = choose_fft(kernel.shape[-2:], fft_type=self._fft_type) + + self._fft1 = FFT(kernel, self.queue, + pre_fft=prefactor, + post_fft=kernel, + symmetric=True, + forward=forward) + self._fft2 = FFT(kernel, self.queue, + post_fft=postfactor, + symmetric=True, + forward=not forward) + + def set_kernel(self, kernel): + self._fft1.post_fft.set(kernel) + + def apply_filter(self, x): + self._fft1.ft(x, x) + self._fft2.ift(x, x) + + class FourierUpdateKernel(ab.FourierUpdateKernel): def __init__(self, aux, nmodes=1, queue_thread=None, accumulate_type='float', math_type='float'): diff --git a/ptypy/accelerate/cuda_pycuda/array_utils.py b/ptypy/accelerate/cuda_pycuda/array_utils.py index 2845a8b93..2879739e3 100644 --- a/ptypy/accelerate/cuda_pycuda/array_utils.py +++ b/ptypy/accelerate/cuda_pycuda/array_utils.py @@ -357,10 +357,10 @@ def allocate(self, shape, sigma=1.): def _compute_kernel(self, shape, sigma): # Create kernel sigma = np.array(sigma) - if sigma.ndim <= 1: + if sigma.size == 1: sigma = np.array([sigma, sigma]) - if sigma.ndim > 2: - raise NotImplementedError("Only batches of 2D arrays allowed!") + if sigma.size > 2: + raise NotImplementedError("Only batches of 2D arrays allowed!") self.sigma = sigma return gaussian_kernel_2d(shape, sigma[0], sigma[1]).astype(self.dtype) diff --git a/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py b/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py index 0c018b205..a7a92d7b9 100644 --- a/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py +++ b/test/accelerate_tests/cuda_cupy_tests/array_utils_test.py @@ -6,6 +6,7 @@ if have_cupy(): import cupy as cp import ptypy.accelerate.cuda_cupy.array_utils as gau + from ptypy.accelerate.cuda_cupy.kernels import FFTFilterKernel class ArrayUtilsTest(CupyCudaTest): @@ -534,3 +535,128 @@ def test_interpolate_shift_no_shift_UNITY(self): np.testing.assert_allclose(out, isk, rtol=1e-6, atol=1e-6, err_msg="The shifting of array has not been calculated as expected") + def test_fft_filter_UNITY(self): + sh = (32, 48) + data = np.zeros(sh, dtype=np.complex64) + data.flat[:] = np.arange(np.prod(sh)) + kernel = np.zeros_like(data) + kernel[0, 0] = 1. + kernel[0, 1] = 0.5 + + prefactor = np.zeros_like(data) + prefactor[:,2:] = 1. + postfactor = np.zeros_like(data) + postfactor[2:,:] = 1. + + data_dev = cp.asarray(data) + kernel_dev = cp.asarray(kernel) + pre_dev = cp.asarray(prefactor) + post_dev = cp.asarray(postfactor) + + FF = FFTFilterKernel(queue_thread=self.stream) + FF.allocate(kernel=kernel_dev, prefactor=pre_dev, postfactor=post_dev) + FF.apply_filter(data_dev) + + output = au.fft_filter(data, kernel, prefactor, postfactor) + + np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) + + def test_fft_filter_batched_UNITY(self): + sh = (2,16, 35) + data = np.zeros(sh, dtype=np.complex64) + data.flat[:] = np.arange(np.prod(sh)) + kernel = np.zeros_like(data) + kernel[:,0, 0] = 1. + kernel[:,0, 1] = 0.5 + + prefactor = np.zeros_like(data) + prefactor[:,:,2:] = 1. + postfactor = np.zeros_like(data) + postfactor[:,2:,:] = 1. + + data_dev = cp.asarray(data) + kernel_dev = cp.asarray(kernel) + pre_dev = cp.asarray(prefactor) + post_dev = cp.asarray(postfactor) + + FF = FFTFilterKernel(queue_thread=self.stream) + FF.allocate(kernel=kernel_dev, prefactor=pre_dev, postfactor=post_dev) + FF.apply_filter(data_dev) + + output = au.fft_filter(data, kernel, prefactor, postfactor) + print(data_dev.get()) + + np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6) + + def test_complex_gaussian_filter_fft_little_blurring_UNITY(self): + # Arrange + data = np.zeros((21, 21), dtype=np.complex64) + data[10:12, 10:12] = 2.0+2.0j + mfs = 0.2,0.2 + data_dev = cp.asarray(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_more_blurring_UNITY(self): + # Arrange + data = np.zeros((8, 8), dtype=np.complex64) + data[3:5, 3:5] = 2.0+2.0j + mfs = 3.0,4.0 + data_dev = cp.asarray(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_nonsquare_UNITY(self): + # Arrange + data = np.zeros((32, 16), dtype=np.complex64) + data[3:4, 11:12] = 2.0+2.0j + data[3:5, 3:5] = 2.0+2.0j + data[20:25,3:5] = 2.0+2.0j + mfs = 1.0,1.0 + data_dev = cp.asarray(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) + + def test_complex_gaussian_filter_fft_batched(self): + # Arrange + batch_number = 2 + A = 5 + B = 5 + data = np.zeros((batch_number, A, B), dtype=np.complex64) + data[:, 2:3, 2:3] = 2.0+2.0j + mfs = 3.0,4.0 + data_dev = cp.asarray(data) + + # Act + FGSK = gau.FFTGaussianSmoothingKernel(queue=self.stream) + FGSK.filter(data_dev, mfs) + + # Assert + out_exp = au.complex_gaussian_filter_fft(data, mfs) + out = data_dev.get() + + np.testing.assert_allclose(out_exp, out, atol=1e-6) \ No newline at end of file