From e735eb7d8efc845e4ef22bf3df143a3054fd65b7 Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Wed, 18 Oct 2023 10:24:28 +0800 Subject: [PATCH] source/kernel.cu: use async data copies --- source/kernel.cu | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/source/kernel.cu b/source/kernel.cu index 6b7b2c3..e6b34a8 100644 --- a/source/kernel.cu +++ b/source/kernel.cu @@ -1,5 +1,7 @@ #include +#include + #define BLOCK_X 16 #define BLOCK_Y 8 @@ -31,7 +33,11 @@ static void bilateral( int sy = min(max(cy - static_cast(threadIdx.y) - radius + y, 0), height - 1); for (int cx = threadIdx.x; cx < 2 * radius + BLOCK_X; cx += BLOCK_X) { int sx = min(max(cx - static_cast(threadIdx.x) - radius + x, 0), width - 1); - buffer[cy * (2 * radius + BLOCK_X) + cx] = src[sy * stride + sx]; + __pipeline_memcpy_async( + &buffer[cy * (2 * radius + BLOCK_X) + cx], + &src[sy * stride + sx], + 4 + ); } } @@ -40,18 +46,24 @@ static void bilateral( int sy = min(max(cy - static_cast(threadIdx.y) - radius + y, 0), height - 1); for (int cx = threadIdx.x; cx < 2 * radius + BLOCK_X; cx += BLOCK_X) { int sx = min(max(cx - static_cast(threadIdx.x) - radius + x, 0), width - 1); - buffer[(2 * radius + BLOCK_Y + cy) * (2 * radius + BLOCK_X) + cx] = src[(height + sy) * stride + sx]; + __pipeline_memcpy_async( + &buffer[(2 * radius + BLOCK_Y + cy) * (2 * radius + BLOCK_X) + cx], + &src[(height + sy) * stride + sx], + 4 + ); } } } + __pipeline_commit(); + __pipeline_wait_prior(0); __syncthreads(); - + if (x >= width || y >= height) return; const float center = buffer[ - (has_ref * (2 * radius + BLOCK_Y) + radius + threadIdx.y) * (2 * radius + BLOCK_X) + + (has_ref * (2 * radius + BLOCK_Y) + radius + threadIdx.y) * (2 * radius + BLOCK_X) + radius + threadIdx.x ]; // src[(has_ref * height + y) * stride + x]; @@ -148,12 +160,12 @@ cudaGraphExec_t get_graphexec( kernel_params.func = ( useSharedMem ? - (has_ref ? - reinterpret_cast(bilateral) : + (has_ref ? + reinterpret_cast(bilateral) : reinterpret_cast(bilateral) ) : - (has_ref ? - reinterpret_cast(bilateral) : + (has_ref ? + reinterpret_cast(bilateral) : reinterpret_cast(bilateral) ) );