Skip to content

Commit

Permalink
source/kernel.cu: use async data copies
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Oct 18, 2023
1 parent 8d71855 commit e735eb7
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions source/kernel.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <iterator>

#include <cuda_pipeline_primitives.h>

// Tile dimensions used by the bilateral kernel when filling `buffer`:
// each buffer row holds (2 * radius + BLOCK_X) elements and each plane
// occupies (2 * radius + BLOCK_Y) rows, i.e. a BLOCK_X x BLOCK_Y tile plus a
// `radius`-wide border on every side. The x fill loop starts at threadIdx.x
// and advances by BLOCK_X.
// NOTE(review): presumably these must equal the launch's blockDim.x and
// blockDim.y -- confirm against the kernel launch configuration.
#define BLOCK_X 16
#define BLOCK_Y 8

Expand Down Expand Up @@ -31,7 +33,11 @@ static void bilateral(
int sy = min(max(cy - static_cast<int>(threadIdx.y) - radius + y, 0), height - 1);
for (int cx = threadIdx.x; cx < 2 * radius + BLOCK_X; cx += BLOCK_X) {
int sx = min(max(cx - static_cast<int>(threadIdx.x) - radius + x, 0), width - 1);
buffer[cy * (2 * radius + BLOCK_X) + cx] = src[sy * stride + sx];
__pipeline_memcpy_async(
&buffer[cy * (2 * radius + BLOCK_X) + cx],
&src[sy * stride + sx],
4
);
}
}

Expand All @@ -40,18 +46,24 @@ static void bilateral(
int sy = min(max(cy - static_cast<int>(threadIdx.y) - radius + y, 0), height - 1);
for (int cx = threadIdx.x; cx < 2 * radius + BLOCK_X; cx += BLOCK_X) {
int sx = min(max(cx - static_cast<int>(threadIdx.x) - radius + x, 0), width - 1);
buffer[(2 * radius + BLOCK_Y + cy) * (2 * radius + BLOCK_X) + cx] = src[(height + sy) * stride + sx];
__pipeline_memcpy_async(
&buffer[(2 * radius + BLOCK_Y + cy) * (2 * radius + BLOCK_X) + cx],
&src[(height + sy) * stride + sx],
4
);
}
}
}

__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();

if (x >= width || y >= height)
return;

const float center = buffer[
(has_ref * (2 * radius + BLOCK_Y) + radius + threadIdx.y) * (2 * radius + BLOCK_X) +
(has_ref * (2 * radius + BLOCK_Y) + radius + threadIdx.y) * (2 * radius + BLOCK_X) +
radius + threadIdx.x
]; // src[(has_ref * height + y) * stride + x];

Expand Down Expand Up @@ -148,12 +160,12 @@ cudaGraphExec_t get_graphexec(

kernel_params.func = (
useSharedMem ?
(has_ref ?
reinterpret_cast<void *>(bilateral<true, true>) :
(has_ref ?
reinterpret_cast<void *>(bilateral<true, true>) :
reinterpret_cast<void *>(bilateral<true, false>)
) :
(has_ref ?
reinterpret_cast<void *>(bilateral<false, true>) :
(has_ref ?
reinterpret_cast<void *>(bilateral<false, true>) :
reinterpret_cast<void *>(bilateral<false, false>)
)
);
Expand Down

0 comments on commit e735eb7

Please sign in to comment.