Skip to content

Commit

Permalink
chore(gpu): speed up signed comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Apr 3, 2024
1 parent 1fc3297 commit 1ccc72c
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 34 deletions.
14 changes: 7 additions & 7 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -1974,9 +1974,9 @@ template <typename Torus> struct int_comparison_buffer {
int_radix_lut<Torus> *signed_lut;
bool is_signed;

// Used for scalar comparisons
cuda_stream_t *lsb_stream;
cuda_stream_t *msb_stream;
// Used for scalar comparisons & signed comparisons
cuda_stream_t *local_stream_1;
cuda_stream_t *local_stream_2;

int_comparison_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
int_radix_params params, uint32_t num_radix_blocks,
Expand All @@ -1988,8 +1988,8 @@ template <typename Torus> struct int_comparison_buffer {
identity_lut_f = [](Torus x) -> Torus { return x; };

if (allocate_gpu_memory) {
lsb_stream = cuda_create_stream(stream->gpu_index);
msb_stream = cuda_create_stream(stream->gpu_index);
local_stream_1 = cuda_create_stream(stream->gpu_index);
local_stream_2 = cuda_create_stream(stream->gpu_index);

tmp_lwe_array_out = (Torus *)cuda_malloc_async(
(params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
Expand Down Expand Up @@ -2125,12 +2125,12 @@ template <typename Torus> struct int_comparison_buffer {
cuda_drop_async(tmp_block_comparisons, stream);
cuda_drop_async(tmp_packed_input, stream);

cuda_destroy_stream(local_stream_1);
cuda_destroy_stream(local_stream_2);
if (is_signed) {
signed_lut->release(stream);
delete (signed_lut);
}
cuda_destroy_stream(lsb_stream);
cuda_destroy_stream(msb_stream);
}
};

Expand Down
62 changes: 39 additions & 23 deletions backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -498,35 +498,51 @@ __host__ void host_integer_radix_difference_check_kb(
} else {
// Packing is possible
if (carry_modulus >= message_modulus) {
// Compare (num_radix_blocks - 2) / 2 packed blocks
compare_radix_blocks_kb(stream, comparisons, lhs, rhs, mem_ptr, bsk, ksk,
packed_num_radix_blocks);

// Compare the last block before the sign block separately
auto identity_lut = mem_ptr->identity_lut;
Torus *last_left_block_before_sign_block =
diff_buffer->tmp_packed_left + packed_num_radix_blocks * big_lwe_size;
Torus *last_right_block_before_sign_block =
diff_buffer->tmp_packed_right +
packed_num_radix_blocks * big_lwe_size;
integer_radix_apply_univariate_lookup_table_kb(
stream, last_left_block_before_sign_block,
lwe_array_left + (num_radix_blocks - 2) * big_lwe_size, bsk, ksk, 1,
identity_lut);
integer_radix_apply_univariate_lookup_table_kb(
stream, last_right_block_before_sign_block,
lwe_array_right + (num_radix_blocks - 2) * big_lwe_size, bsk, ksk, 1,
identity_lut);
compare_radix_blocks_kb(
stream, comparisons + packed_num_radix_blocks * big_lwe_size,
last_left_block_before_sign_block, last_right_block_before_sign_block,
mem_ptr, bsk, ksk, 1);
// Compare the sign block separately
integer_radix_apply_bivariate_lookup_table_kb(
stream, comparisons + (packed_num_radix_blocks + 1) * big_lwe_size,
lwe_array_left + (num_radix_blocks - 1) * big_lwe_size,
lwe_array_right + (num_radix_blocks - 1) * big_lwe_size, bsk, ksk, 1,
mem_ptr->signed_lut);
// Compare (num_radix_blocks - 2) / 2 packed blocks
compare_radix_blocks_kb(mem_ptr->local_stream_1, comparisons, lhs, rhs,
mem_ptr, bsk, ksk, packed_num_radix_blocks);
// Since our CPU threads will be working on different streams we shall
// assert the work in the main stream is completed
stream->synchronize();
#pragma omp parallel sections
{
// All sections may be executed in parallel
#pragma omp section
{
// Compare the last block before the sign block separately
integer_radix_apply_univariate_lookup_table_kb(
mem_ptr->local_stream_1, last_left_block_before_sign_block,
lwe_array_left + (num_radix_blocks - 2) * big_lwe_size, bsk, ksk,
1, identity_lut);
integer_radix_apply_univariate_lookup_table_kb(
mem_ptr->local_stream_1, last_right_block_before_sign_block,
lwe_array_right + (num_radix_blocks - 2) * big_lwe_size, bsk, ksk,
1, identity_lut);
compare_radix_blocks_kb(
mem_ptr->local_stream_1,
comparisons + packed_num_radix_blocks * big_lwe_size,
last_left_block_before_sign_block,
last_right_block_before_sign_block, mem_ptr, bsk, ksk, 1);
}
#pragma omp section
{
// Compare the sign block separately
integer_radix_apply_bivariate_lookup_table_kb(
mem_ptr->local_stream_2,
comparisons + (packed_num_radix_blocks + 1) * big_lwe_size,
lwe_array_left + (num_radix_blocks - 1) * big_lwe_size,
lwe_array_right + (num_radix_blocks - 1) * big_lwe_size, bsk, ksk,
1, mem_ptr->signed_lut);
}
}
cuda_synchronize_stream(mem_ptr->local_stream_1);
cuda_synchronize_stream(mem_ptr->local_stream_2);
num_comparisons = packed_num_radix_blocks + 2;

} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ __host__ void host_integer_radix_scalar_difference_check_kb(
auto lwe_array_msb_out = lwe_array_lsb_out + big_lwe_size;

cuda_synchronize_stream(stream);
auto lsb_stream = mem_ptr->lsb_stream;
auto msb_stream = mem_ptr->msb_stream;
auto lsb_stream = mem_ptr->local_stream_1;
auto msb_stream = mem_ptr->local_stream_2;

#pragma omp parallel sections
{
Expand Down Expand Up @@ -305,8 +305,8 @@ __host__ void host_integer_radix_scalar_equality_check_kb(

cuda_synchronize_stream(stream);

auto lsb_stream = mem_ptr->lsb_stream;
auto msb_stream = mem_ptr->msb_stream;
auto lsb_stream = mem_ptr->local_stream_1;
auto msb_stream = mem_ptr->local_stream_2;

#pragma omp parallel sections
{
Expand Down

0 comments on commit 1ccc72c

Please sign in to comment.