From 87da4aafb01b4ce9121bf60a6b94d6dc903bb8e6 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Thu, 25 Jul 2024 15:08:36 +0200 Subject: [PATCH] fix(gpu): fix scalar shifts --- .../tfhe-cuda-backend/cuda/include/integer.h | 4 +- .../cuda/src/integer/scalar_shifts.cuh | 117 +++++++++--------- 2 files changed, 62 insertions(+), 59 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index df4eb39df0..33d862cd29 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -1755,12 +1755,12 @@ template struct int_arithmetic_scalar_shift_buffer { uint32_t big_lwe_size = params.big_lwe_dimension + 1; uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus); - tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 2) * + tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 3) * big_lwe_size_bytes, streams[0], gpu_indexes[0]); cuda_memset_async(tmp_rotated, 0, - (num_radix_blocks + 2) * big_lwe_size_bytes, streams[0], + (num_radix_blocks + 3) * big_lwe_size_bytes, streams[0], gpu_indexes[0]); uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh index e612c9ab2f..47eea7b426 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh @@ -52,8 +52,6 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace( Torus *full_rotated_buffer = mem->tmp_rotated; Torus *rotated_buffer = &full_rotated_buffer[big_lwe_size]; - auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1]; - // rotate right all the blocks in radix ciphertext // copy result in new buffer // 1024 threads are used in every block @@ -76,6 +74,7 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace( return; } + auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1]; auto partial_current_blocks = &lwe_array[rotations * big_lwe_size]; auto partial_previous_blocks = &full_rotated_buffer[rotations * big_lwe_size]; @@ -109,6 +108,7 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace( auto partial_current_blocks = lwe_array; auto partial_next_blocks = &rotated_buffer[big_lwe_size]; + auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1]; size_t partial_block_count = num_blocks - rotations; @@ -139,8 +139,6 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace( int_arithmetic_scalar_shift_buffer *mem, void **bsks, Torus **ksks, uint32_t num_blocks) { - cudaSetDevice(gpu_indexes[0]); - auto params = mem->params; auto glwe_dimension = params.glwe_dimension; auto polynomial_size = params.polynomial_size; @@ -160,15 +158,9 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace( size_t shift_within_block = shift % num_bits_in_block; Torus *rotated_buffer = mem->tmp_rotated; - Torus *padding_block = &rotated_buffer[num_blocks * big_lwe_size]; + Torus *padding_block = &rotated_buffer[(num_blocks + 1) * big_lwe_size]; Torus *last_block_copy = &padding_block[big_lwe_size]; - auto lut_univariate_shift_last_block = - mem->lut_buffers_univariate[shift_within_block - 1]; - auto lut_univariate_padding_block = - mem->lut_buffers_univariate[num_bits_in_block - 1]; - auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1]; - if (mem->shift_type == RIGHT_SHIFT) { host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count, rotated_buffer, lwe_array, rotations, @@ -197,59 +189,70 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace( return; } - // In the arithmetic shift case we have to pad with the value of the sign - // bit. This creates the need for a different shifting lut than in the - // logical shift case. We also need another PBS to create the padding block. - Torus *last_block = lwe_array + (num_blocks - rotations - 1) * big_lwe_size; - cuda_memcpy_async_gpu_to_gpu( - last_block_copy, - rotated_buffer + (num_blocks - rotations - 1) * big_lwe_size, - big_lwe_size_bytes, streams[0], gpu_indexes[0]); - auto partial_current_blocks = lwe_array; - auto partial_next_blocks = &rotated_buffer[big_lwe_size]; - size_t partial_block_count = num_blocks - rotations; - if (shift_within_block != 0 && rotations != num_blocks) { - integer_radix_apply_bivariate_lookup_table_kb( - streams, gpu_indexes, gpu_count, partial_current_blocks, - partial_current_blocks, partial_next_blocks, bsks, ksks, - partial_block_count, lut_bivariate, - lut_bivariate->params.message_modulus); - } - // Since our CPU threads will be working on different streams we shall - // assert the work in the main stream is completed - for (uint j = 0; j < gpu_count; j++) { - cuda_synchronize_stream(streams[j], gpu_indexes[j]); - } + if (num_blocks != rotations) { + // In the arithmetic shift case we have to pad with the value of the sign + // bit. This creates the need for a different shifting lut than in the + // logical shift case. We also need another PBS to create the padding + // block. + Torus *last_block = + lwe_array + (num_blocks - rotations - 1) * big_lwe_size; + cuda_memcpy_async_gpu_to_gpu( + last_block_copy, + rotated_buffer + (num_blocks - rotations - 1) * big_lwe_size, + big_lwe_size_bytes, streams[0], gpu_indexes[0]); + if (shift_within_block != 0) { + auto partial_current_blocks = lwe_array; + auto partial_next_blocks = &rotated_buffer[big_lwe_size]; + size_t partial_block_count = num_blocks - rotations; + auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1]; + + integer_radix_apply_bivariate_lookup_table_kb( + streams, gpu_indexes, gpu_count, partial_current_blocks, + partial_current_blocks, partial_next_blocks, bsks, ksks, + partial_block_count, lut_bivariate, + lut_bivariate->params.message_modulus); + } + // Since our CPU threads will be working on different streams we shall + // assert the work in the main stream is completed + for (uint j = 0; j < gpu_count; j++) { + cuda_synchronize_stream(streams[j], gpu_indexes[j]); + } #pragma omp parallel sections - { - // All sections may be executed in parallel -#pragma omp section { - integer_radix_apply_univariate_lookup_table_kb( - mem->local_streams_1, gpu_indexes, gpu_count, padding_block, - last_block_copy, bsks, ksks, 1, lut_univariate_padding_block); - // Replace blocks 'pulled' from the left with the correct padding block - for (uint i = 0; i < rotations; i++) { - cuda_memcpy_async_gpu_to_gpu( - lwe_array + (num_blocks - rotations + i) * big_lwe_size, - padding_block, big_lwe_size_bytes, mem->local_streams_1[0], - gpu_indexes[0]); - } - } + // All sections may be executed in parallel #pragma omp section - { - if (shift_within_block != 0 && rotations != num_blocks) { + { + auto lut_univariate_padding_block = + mem->lut_buffers_univariate[num_bits_in_block - 1]; integer_radix_apply_univariate_lookup_table_kb( - mem->local_streams_2, gpu_indexes, gpu_count, last_block, - last_block_copy, bsks, ksks, 1, lut_univariate_shift_last_block); + mem->local_streams_1, gpu_indexes, gpu_count, padding_block, + last_block_copy, bsks, ksks, 1, lut_univariate_padding_block); + // Replace blocks 'pulled' from the left with the correct padding + // block + for (uint i = 0; i < rotations; i++) { + cuda_memcpy_async_gpu_to_gpu( + lwe_array + (num_blocks - rotations + i) * big_lwe_size, + padding_block, big_lwe_size_bytes, mem->local_streams_1[0], + gpu_indexes[0]); + } + } +#pragma omp section + { + if (shift_within_block != 0) { + auto lut_univariate_shift_last_block = + mem->lut_buffers_univariate[shift_within_block - 1]; + integer_radix_apply_univariate_lookup_table_kb( + mem->local_streams_2, gpu_indexes, gpu_count, last_block, + last_block_copy, bsks, ksks, 1, + lut_univariate_shift_last_block); + } } } + for (uint j = 0; j < mem->active_gpu_count; j++) { + cuda_synchronize_stream(mem->local_streams_1[j], gpu_indexes[j]); + cuda_synchronize_stream(mem->local_streams_2[j], gpu_indexes[j]); + } } - for (uint j = 0; j < mem->active_gpu_count; j++) { - cuda_synchronize_stream(mem->local_streams_1[j], gpu_indexes[j]); - cuda_synchronize_stream(mem->local_streams_2[j], gpu_indexes[j]); - } - } else { PANIC("Cuda error (scalar shift): left scalar shift is never of the " "arithmetic type")