Fixed another bug in batch multiply kernel
daurer committed Mar 7, 2024
1 parent 6f522a8 commit d739843
Showing 5 changed files with 104 additions and 3 deletions.
8 changes: 7 additions & 1 deletion ptypy/accelerate/cuda_common/batched_multiply.cu
@@ -22,10 +22,16 @@ extern "C" __global__ void batched_multiply(const complex<IN_TYPE>* input,
  int gy = threadIdx.y + blockIdx.y * blockDim.y;
  int gz = threadIdx.z + blockIdx.z * blockDim.z;

- if (gx > rows || gy > columns || gz > nBatches)
+ if (gx > rows - 1 || gy > columns - 1 || gz > nBatches)
    return;

  auto val = input[gz * rows * columns + gy * rows + gx];
+ //printf("gx = %d, gy = %d, gz = %d, val= %.1f +i%.1f\n", gz,gy,gz, val.real(), val.imag());
+ //printf("threads: x=%d y=%d z=%d\n", threadIdx.x, threadIdx.y, threadIdx.z);
+ //printf("blocks: x=%d y=%d z=%d\n", blockIdx.x, blockIdx.y, blockIdx.z);
+ //printf("grids: x=%d y=%d z=%d\n", blockDim.x, blockDim.y, blockDim.z);
+
+
  if (MPY_DO_FILT) // set at compile-time
  {
    val *= filter[gy * rows + gx];
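The change above tightens the thread guard: with gx > rows, the thread where gx equals rows is not rejected (rows > rows is false) and reads one element past the end of its slice, whereas gx > rows - 1, i.e. gx >= rows, discards it. The gz bound can stay as gz > nBatches because the wrappers below launch a grid whose z dimension is exactly the batch count, so gz never reaches nBatches. Below is a minimal CPU sketch of what the kernel computes when MPY_DO_SCALE and MPY_DO_FILT are both compiled in, assuming the elided remainder of the kernel writes the scaled, filtered value back to output; it is an illustration, not the kernel itself.

import numpy as np

def batched_multiply_reference(x, filt, scale=1.0):
    # Every 2D slice of x is multiplied elementwise by the same 2D filter
    # and by a scalar, in place (the wrappers pass x as both input and
    # output). Reference for illustration only.
    assert x.shape[-2:] == filt.shape
    x *= filt   # broadcasts over the leading batch dimension(s)
    x *= scale
    return x

# Equivalent view of the bounds guard for a 14-row slice:
rows = 14
for gx in (13, 14):
    kept_old = not (gx > rows)       # gx == 14 slips through (out of bounds)
    kept_new = not (gx > rows - 1)   # gx == 14 is rejected by the fix
    print(gx, kept_old, kept_new)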
31 changes: 31 additions & 0 deletions ptypy/accelerate/cuda_cupy/array_utils.py
@@ -63,6 +63,37 @@ def dot(self, A: cp.ndarray, B: cp.ndarray, out: cp.ndarray = None) -> cp.ndarra
def norm2(self, A, out=None):
return self.dot(A, A, out)

class BatchedMultiplyKernel:
def __init__(self, array, queue=None, math_type=np.complex64):
self.queue = queue
self.array_shape = array.shape[-2:]
self.batches = int(np.prod(array.shape[0:array.ndim-2]) if array.ndim > 2 else 1)
self.batched_multiply_cuda = load_kernel("batched_multiply", {
'MPY_DO_SCALE': 'true',
'MPY_DO_FILT': 'true',
'IN_TYPE': 'float' if array.dtype==np.complex64 else 'double',
'OUT_TYPE': 'float' if array.dtype==np.complex64 else 'double',
'MATH_TYPE': 'float' if math_type==np.complex64 else 'double'
})
self.block = (32,32,1)
self.grid = (
int((self.array_shape[0] + 31) // 32),
int((self.array_shape[1] + 31) // 32),
int(self.batches)
)

def multiply(self, x,y, scale=1.):
assert x.dtype == y.dtype, "Input arrays must be of same data type"
assert x.shape[-2:] == y.shape[-2:], "Input arrays must be of the same size in last 2 dims"
if self.queue is not None:
self.queue.use()
self.batched_multiply_cuda(self.grid,
self.block,
args=(x,x,y,
np.float32(scale),
np.int32(self.batches),
np.int32(self.array_shape[0]),
np.int32(self.array_shape[1])))

class TransposeKernel:

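In BatchedMultiplyKernel above, each batch slice is covered by 32x32 thread blocks and the grid is sized with (n + 31) // 32, which is integer ceiling division, so shapes that are not multiples of 32 get surplus threads; those are exactly the threads the kernel's bounds check discards. A small standalone sketch of that arithmetic follows (launch_dims is an illustrative helper name, not part of the module):

import math

def launch_dims(array_shape, batches, block=32):
    # (n + block - 1) // block == ceil(n / block): round up so the blocks
    # fully cover a slice whose shape is not a multiple of the block size.
    rows, cols = array_shape
    grid = ((rows + block - 1) // block, (cols + block - 1) // block, batches)
    assert grid[:2] == (math.ceil(rows / block), math.ceil(cols / block))
    return (block, block, 1), grid

block, grid = launch_dims((14, 24), batches=3)
print(block, grid)  # (32, 32, 1) (1, 1, 3): 1024 threads cover each 14x24 slice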
32 changes: 32 additions & 0 deletions ptypy/accelerate/cuda_pycuda/array_utils.py
@@ -66,6 +66,38 @@ def dot(self, A, B, out=None):
def norm2(self, A, out=None):
return self.dot(A, A, out)

class BatchedMultiplyKernel:
def __init__(self, array, queue=None, math_type=np.complex64):
self.queue = queue
self.array_shape = array.shape[-2:]
self.batches = int(np.prod(array.shape[0:array.ndim-2]) if array.ndim > 2 else 1)
self.batched_multiply_cuda = load_kernel("batched_multiply", {
'MPY_DO_SCALE': 'true',
'MPY_DO_FILT': 'true',
'IN_TYPE': 'float' if array.dtype==np.complex64 else 'double',
'OUT_TYPE': 'float' if array.dtype==np.complex64 else 'double',
'MATH_TYPE': 'float' if math_type==np.complex64 else 'double'
})
self.block = (32,32,1)
self.grid = (
int((self.array_shape[0] + 31) // 32),
int((self.array_shape[1] + 31) // 32),
int(self.batches)
)

def multiply(self, x,y, scale=1.):
assert x.dtype == y.dtype, "Input arrays must be of same data type"
assert x.shape[-2:] == y.shape[-2:], "Input arrays must be of the same size in last 2 dims"
self.batched_multiply_cuda(x,x,y,
np.float32(scale),
np.int32(self.batches),
np.int32(self.array_shape[0]),
np.int32(self.array_shape[1]),
block=self.block,
grid=self.grid,
stream=self.queue)


class TransposeKernel:

def __init__(self, queue=None):
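Both wrappers compute the batch count the same way: every axis in front of the last two is folded into a single batch dimension, and a plain 2D array counts as one batch. The main difference between the cupy and pycuda versions is the launch call itself, which takes (grid, block, args=...) in the former and positional arguments plus block=, grid= and stream= in the latter, as shown in the two classes above. A standalone sketch of the batch-count rule (batch_count is an illustrative helper, not part of either module):

import numpy as np

def batch_count(shape):
    # Mirrors np.prod(array.shape[0:array.ndim-2]) if array.ndim > 2 else 1:
    # all leading axes collapse into one batch dimension.
    return int(np.prod(shape[:-2])) if len(shape) > 2 else 1

print(batch_count((14, 24)))        # 1  -> a single 2D array
print(batch_count((3, 14, 24)))     # 3  -> the shape used by the tests below
print(batch_count((2, 5, 14, 24)))  # 10 -> leading axes multiply together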
19 changes: 17 additions & 2 deletions test/accelerate_tests/cuda_cupy_tests/array_utils_test.py
@@ -79,6 +79,23 @@ def test_dot_performance(self):
AU = gau.ArrayUtilsKernel(acc_dtype=np.float64)
AU.dot(A_dev, A_dev)

def test_batched_multiply(self):
# Arrange
sh = (3,14,24)
ksh = (14,24)
data = (np.random.random(sh) + 1j* np.random.random(sh)).astype(np.complex64)
kernel = (np.random.random(ksh) + 1j* np.random.random(ksh)).astype(np.complex64)
data_dev = cp.asarray(data)
kernel_dev = cp.asarray(kernel)

# Act
BM = gau.BatchedMultiplyKernel(data_dev)
BM.multiply(data_dev, kernel_dev, scale=2.)

# Assert
expected = data * kernel * 2.
np.testing.assert_array_almost_equal(data_dev.get(), expected)

def test_transpose_2D(self):
# Arrange
inp, _ = np.indices((5, 3), dtype=np.int32)
@@ -584,8 +601,6 @@ def test_fft_filter_batched_UNITY(self):
FF.apply_filter(data_dev)

output = au.fft_filter(data, kernel, prefactor, postfactor)
- print(data_dev.get())
-
np.testing.assert_allclose(output, data_dev.get(), rtol=1e-5, atol=1e-6)

def test_complex_gaussian_filter_fft_little_blurring_UNITY(self):
17 changes: 17 additions & 0 deletions test/accelerate_tests/cuda_pycuda_tests/array_utils_test.py
@@ -82,6 +82,23 @@ def test_dot_performance(self):
AU = gau.ArrayUtilsKernel(acc_dtype=np.float64)
out_dev = AU.dot(A_dev, A_dev)

def test_batched_multiply(self):
# Arrange
sh = (3,14,24)
ksh = (14,24)
data = (np.random.random(sh) + 1j* np.random.random(sh)).astype(np.complex64)
kernel = (np.random.random(ksh) + 1j* np.random.random(ksh)).astype(np.complex64)
data_dev = gpuarray.to_gpu(data)
kernel_dev = gpuarray.to_gpu(kernel)

# Act
BM = gau.BatchedMultiplyKernel(data_dev)
BM.multiply(data_dev, kernel_dev, scale=2.)

# Assert
expected = data * kernel * 2.
np.testing.assert_array_almost_equal(data_dev.get(), expected)

def test_transpose_2D(self):
## Arrange
inp,_ = np.indices((5,3), dtype=np.int32)
