fix(gpu): fix scalar shifts
agnesLeroy committed Jul 25, 2024
1 parent d3f2ecd commit 87da4aa
Showing 2 changed files with 62 additions and 59 deletions.
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
@@ -1755,12 +1755,12 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
uint32_t big_lwe_size = params.big_lwe_dimension + 1;
uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);

-    tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 2) *
+    tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 3) *
big_lwe_size_bytes,
streams[0], gpu_indexes[0]);

cuda_memset_async(tmp_rotated, 0,
-                      (num_radix_blocks + 2) * big_lwe_size_bytes, streams[0],
+                      (num_radix_blocks + 3) * big_lwe_size_bytes, streams[0],
gpu_indexes[0]);

uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);
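
For context, a minimal sketch (not part of the commit) of the scratch-buffer layout this size accounts for, assuming big_lwe_size Torus elements per block stored back to back and the offsets used in the arithmetic shift path of scalar_shifts.cuh below. The struct and helper names are illustrative only, and uint64_t stands in for Torus.

#include <cstddef>
#include <cstdint>

// tmp_rotated holds the num_blocks rotated blocks, then (leaving one block of
// headroom) a padding block and a copy of the last block:
// (num_radix_blocks + 3) big-LWE ciphertexts in total.
struct TmpRotatedLayout {
  uint64_t *rotated;         // blocks [0, num_blocks)
  uint64_t *padding_block;   // block num_blocks + 1
  uint64_t *last_block_copy; // block num_blocks + 2
};

inline TmpRotatedLayout split_tmp_rotated(uint64_t *tmp_rotated,
                                          size_t num_blocks,
                                          size_t big_lwe_size) {
  TmpRotatedLayout layout;
  layout.rotated = tmp_rotated;
  layout.padding_block = tmp_rotated + (num_blocks + 1) * big_lwe_size;
  layout.last_block_copy = layout.padding_block + big_lwe_size;
  return layout;
}

With these offsets, the previous allocation of num_radix_blocks + 2 blocks would leave last_block_copy entirely outside the buffer, which is why the cuda_malloc_async and cuda_memset_async sizes grow together.
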
117 changes: 60 additions & 57 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
@@ -52,8 +52,6 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
Torus *full_rotated_buffer = mem->tmp_rotated;
Torus *rotated_buffer = &full_rotated_buffer[big_lwe_size];

-  auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
-
// rotate right all the blocks in radix ciphertext
// copy result in new buffer
// 1024 threads are used in every block
@@ -76,6 +74,7 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
return;
}

+    auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
auto partial_current_blocks = &lwe_array[rotations * big_lwe_size];
auto partial_previous_blocks =
&full_rotated_buffer[rotations * big_lwe_size];
@@ -109,6 +108,7 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(

auto partial_current_blocks = lwe_array;
auto partial_next_blocks = &rotated_buffer[big_lwe_size];
+    auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

size_t partial_block_count = num_blocks - rotations;

@@ -139,8 +139,6 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
int_arithmetic_scalar_shift_buffer<Torus> *mem, void **bsks, Torus **ksks,
uint32_t num_blocks) {

-  cudaSetDevice(gpu_indexes[0]);
-
auto params = mem->params;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
@@ -160,15 +158,9 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
size_t shift_within_block = shift % num_bits_in_block;

Torus *rotated_buffer = mem->tmp_rotated;
-  Torus *padding_block = &rotated_buffer[num_blocks * big_lwe_size];
+  Torus *padding_block = &rotated_buffer[(num_blocks + 1) * big_lwe_size];
Torus *last_block_copy = &padding_block[big_lwe_size];

-  auto lut_univariate_shift_last_block =
-      mem->lut_buffers_univariate[shift_within_block - 1];
-  auto lut_univariate_padding_block =
-      mem->lut_buffers_univariate[num_bits_in_block - 1];
-  auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
-
if (mem->shift_type == RIGHT_SHIFT) {
host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
rotated_buffer, lwe_array, rotations,
@@ -197,59 +189,70 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
      return;
    }

-    // In the arithmetic shift case we have to pad with the value of the sign
-    // bit. This creates the need for a different shifting lut than in the
-    // logical shift case. We also need another PBS to create the padding block.
-    Torus *last_block = lwe_array + (num_blocks - rotations - 1) * big_lwe_size;
-    cuda_memcpy_async_gpu_to_gpu(
-        last_block_copy,
-        rotated_buffer + (num_blocks - rotations - 1) * big_lwe_size,
-        big_lwe_size_bytes, streams[0], gpu_indexes[0]);
-    auto partial_current_blocks = lwe_array;
-    auto partial_next_blocks = &rotated_buffer[big_lwe_size];
-    size_t partial_block_count = num_blocks - rotations;
-    if (shift_within_block != 0 && rotations != num_blocks) {
-      integer_radix_apply_bivariate_lookup_table_kb<Torus>(
-          streams, gpu_indexes, gpu_count, partial_current_blocks,
-          partial_current_blocks, partial_next_blocks, bsks, ksks,
-          partial_block_count, lut_bivariate,
-          lut_bivariate->params.message_modulus);
-    }
-    // Since our CPU threads will be working on different streams we shall
-    // assert the work in the main stream is completed
-    for (uint j = 0; j < gpu_count; j++) {
-      cuda_synchronize_stream(streams[j], gpu_indexes[j]);
-    }
-#pragma omp parallel sections
-    {
-      // All sections may be executed in parallel
-#pragma omp section
-      {
-        integer_radix_apply_univariate_lookup_table_kb(
-            mem->local_streams_1, gpu_indexes, gpu_count, padding_block,
-            last_block_copy, bsks, ksks, 1, lut_univariate_padding_block);
-        // Replace blocks 'pulled' from the left with the correct padding block
-        for (uint i = 0; i < rotations; i++) {
-          cuda_memcpy_async_gpu_to_gpu(
-              lwe_array + (num_blocks - rotations + i) * big_lwe_size,
-              padding_block, big_lwe_size_bytes, mem->local_streams_1[0],
-              gpu_indexes[0]);
-        }
-      }
-      // All sections may be executed in parallel
-#pragma omp section
-      {
-        if (shift_within_block != 0 && rotations != num_blocks) {
-          integer_radix_apply_univariate_lookup_table_kb(
-              mem->local_streams_2, gpu_indexes, gpu_count, last_block,
-              last_block_copy, bsks, ksks, 1, lut_univariate_shift_last_block);
-        }
-      }
-    }
-    for (uint j = 0; j < mem->active_gpu_count; j++) {
-      cuda_synchronize_stream(mem->local_streams_1[j], gpu_indexes[j]);
-      cuda_synchronize_stream(mem->local_streams_2[j], gpu_indexes[j]);
-    }
+    if (num_blocks != rotations) {
+      // In the arithmetic shift case we have to pad with the value of the sign
+      // bit. This creates the need for a different shifting lut than in the
+      // logical shift case. We also need another PBS to create the padding
+      // block.
+      Torus *last_block =
+          lwe_array + (num_blocks - rotations - 1) * big_lwe_size;
+      cuda_memcpy_async_gpu_to_gpu(
+          last_block_copy,
+          rotated_buffer + (num_blocks - rotations - 1) * big_lwe_size,
+          big_lwe_size_bytes, streams[0], gpu_indexes[0]);
+      if (shift_within_block != 0) {
+        auto partial_current_blocks = lwe_array;
+        auto partial_next_blocks = &rotated_buffer[big_lwe_size];
+        size_t partial_block_count = num_blocks - rotations;
+        auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
+
+        integer_radix_apply_bivariate_lookup_table_kb<Torus>(
+            streams, gpu_indexes, gpu_count, partial_current_blocks,
+            partial_current_blocks, partial_next_blocks, bsks, ksks,
+            partial_block_count, lut_bivariate,
+            lut_bivariate->params.message_modulus);
+      }
+      // Since our CPU threads will be working on different streams we shall
+      // assert the work in the main stream is completed
+      for (uint j = 0; j < gpu_count; j++) {
+        cuda_synchronize_stream(streams[j], gpu_indexes[j]);
+      }
+#pragma omp parallel sections
+      {
+        // All sections may be executed in parallel
+#pragma omp section
+        {
+          auto lut_univariate_padding_block =
+              mem->lut_buffers_univariate[num_bits_in_block - 1];
+          integer_radix_apply_univariate_lookup_table_kb(
+              mem->local_streams_1, gpu_indexes, gpu_count, padding_block,
+              last_block_copy, bsks, ksks, 1, lut_univariate_padding_block);
+          // Replace blocks 'pulled' from the left with the correct padding
+          // block
+          for (uint i = 0; i < rotations; i++) {
+            cuda_memcpy_async_gpu_to_gpu(
+                lwe_array + (num_blocks - rotations + i) * big_lwe_size,
+                padding_block, big_lwe_size_bytes, mem->local_streams_1[0],
+                gpu_indexes[0]);
+          }
+        }
+#pragma omp section
+        {
+          if (shift_within_block != 0) {
+            auto lut_univariate_shift_last_block =
+                mem->lut_buffers_univariate[shift_within_block - 1];
+            integer_radix_apply_univariate_lookup_table_kb(
+                mem->local_streams_2, gpu_indexes, gpu_count, last_block,
+                last_block_copy, bsks, ksks, 1,
+                lut_univariate_shift_last_block);
+          }
+        }
+      }
+      for (uint j = 0; j < mem->active_gpu_count; j++) {
+        cuda_synchronize_stream(mem->local_streams_1[j], gpu_indexes[j]);
+        cuda_synchronize_stream(mem->local_streams_2[j], gpu_indexes[j]);
+      }
+    }

  } else {
    PANIC("Cuda error (scalar shift): left scalar shift is never of the "
          "arithmetic type")
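
Both shift functions select their lookup tables with the index shift_within_block - 1; since shift_within_block is unsigned (declared size_t in the arithmetic path above), that index wraps around whenever the shift is an exact multiple of the bits per block (shift_within_block == 0). Below is a minimal sketch of the guarded-fetch pattern the commit enforces, with illustrative names rather than the backend's actual types.

#include <cstddef>
#include <vector>

// Stand-in for mem->lut_buffers_bivariate / mem->lut_buffers_univariate: the
// index shift_within_block - 1 is only formed once shift_within_block != 0 is
// established, mirroring how the LUT fetches were moved inside the guarded
// branches in the diff above.
template <typename Lut>
const Lut *select_shift_lut(const std::vector<Lut> &lut_buffers,
                            size_t shift_within_block) {
  if (shift_within_block == 0)
    return nullptr; // block-aligned shift: no per-block PBS is needed
  return &lut_buffers[shift_within_block - 1];
}

In the arithmetic path, the new if (num_blocks != rotations) wrapper additionally skips the copy that read rotated_buffer at block index num_blocks - rotations - 1, which the old code issued even when every block had been rotated out.
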
