Skip to content

Commit

Permalink
chore(gpu): run pbs in parallel in difference_check
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Dec 16, 2024
1 parent b1ce34f commit 86f0704
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3279,8 +3279,7 @@ template <typename Torus> struct int_comparison_diff_buffer {
int_radix_params params;
COMPARISON_TYPE op;

Torus *tmp_packed_left;
Torus *tmp_packed_right;
Torus *tmp_packed;

std::function<Torus(Torus)> operator_f;

Expand Down Expand Up @@ -3317,11 +3316,8 @@ template <typename Torus> struct int_comparison_diff_buffer {

Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);

tmp_packed_left = (Torus *)cuda_malloc_async(
big_size * (num_radix_blocks / 2), streams[0], gpu_indexes[0]);

tmp_packed_right = (Torus *)cuda_malloc_async(
big_size * (num_radix_blocks / 2), streams[0], gpu_indexes[0]);
tmp_packed = (Torus *)cuda_malloc_async(big_size * num_radix_blocks,
streams[0], gpu_indexes[0]);

tree_buffer = new int_tree_sign_reduction_buffer<Torus>(
streams, gpu_indexes, gpu_count, operator_f, params, num_radix_blocks,
Expand All @@ -3344,8 +3340,7 @@ template <typename Torus> struct int_comparison_diff_buffer {
reduce_signs_lut->release(streams, gpu_indexes, gpu_count);
delete reduce_signs_lut;

cuda_drop_async(tmp_packed_left, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_packed_right, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_packed, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_signs_a, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_signs_b, streams[0], gpu_indexes[0]);
}
Expand Down
5 changes: 5 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ void cuda_comparison_integer_radix_ciphertext_kb_64(
case GE:
case LT:
case LE:
if (num_radix_blocks % 2 != 0)
PANIC("Cuda error (comparisons): the number of radix blocks has to be "
"even.")
host_integer_radix_difference_check_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
Expand All @@ -68,6 +71,8 @@ void cuda_comparison_integer_radix_ciphertext_kb_64(
break;
case MAX:
case MIN:
if (num_radix_blocks % 2 != 0)
PANIC("Cuda error (max/min): the number of radix blocks has to be even.")
host_integer_radix_maxmin_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
Expand Down
18 changes: 9 additions & 9 deletions backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,9 @@ __host__ void host_integer_radix_difference_check_kb(
if (carry_modulus >= message_modulus) {
// Packing is possible
// Pack inputs
Torus *packed_left = diff_buffer->tmp_packed_left;
Torus *packed_right = diff_buffer->tmp_packed_right;
Torus *packed_left = diff_buffer->tmp_packed;
Torus *packed_right =
diff_buffer->tmp_packed + num_radix_blocks / 2 * big_lwe_size;
// In case the ciphertext is signed, the sign block and the one before it
// are handled separately
if (mem_ptr->is_signed) {
Expand All @@ -510,10 +511,7 @@ __host__ void host_integer_radix_difference_check_kb(
auto identity_lut = mem_ptr->identity_lut;
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, packed_left, packed_left, bsks, ksks,
packed_num_radix_blocks, identity_lut);
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, packed_right, packed_right, bsks, ksks,
packed_num_radix_blocks, identity_lut);
2 * packed_num_radix_blocks, identity_lut);

lhs = packed_left;
rhs = packed_right;
Expand Down Expand Up @@ -542,11 +540,13 @@ __host__ void host_integer_radix_difference_check_kb(

// Compare the last block before the sign block separately
auto identity_lut = mem_ptr->identity_lut;
Torus *packed_left = diff_buffer->tmp_packed;
Torus *packed_right =
diff_buffer->tmp_packed + num_radix_blocks / 2 * big_lwe_size;
Torus *last_left_block_before_sign_block =
diff_buffer->tmp_packed_left + packed_num_radix_blocks * big_lwe_size;
packed_left + packed_num_radix_blocks * big_lwe_size;
Torus *last_right_block_before_sign_block =
diff_buffer->tmp_packed_right +
packed_num_radix_blocks * big_lwe_size;
packed_right + packed_num_radix_blocks * big_lwe_size;
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, last_left_block_before_sign_block,
lwe_array_left + (num_radix_blocks - 2) * big_lwe_size, bsks, ksks, 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
case GE:
case LT:
case LE:
if (lwe_ciphertext_count % 2 != 0)
PANIC("Cuda error (scalar comparisons): the number of radix blocks has "
"to be even.")
host_integer_radix_scalar_difference_check_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
Expand All @@ -32,6 +35,9 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
break;
case MAX:
case MIN:
if (lwe_ciphertext_count % 2 != 0)
PANIC("Cuda error (scalar max/min): the number of radix blocks has to be "
"even.")
host_integer_radix_scalar_maxmin_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
Expand Down
20 changes: 12 additions & 8 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(

//////////////
// lsb
Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks,
Expand Down Expand Up @@ -210,8 +211,9 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
uint32_t num_scalar_blocks = total_num_scalar_blocks;

Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks,
Expand Down Expand Up @@ -358,8 +360,9 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(

//////////////
// lsb
Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks,
Expand Down Expand Up @@ -458,8 +461,9 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
auto lwe_array_sign_out =
lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks - 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,8 @@ define_gpu_scalar_comparison_test_functions!(le, U256);
define_gpu_scalar_comparison_test_functions!(gt, U256);
define_gpu_scalar_comparison_test_functions!(ge, U256);

create_gpu_parameterized_test!(integer_unchecked_scalar_comparisons_edge);
create_gpu_parameterized_test!(integer_unchecked_scalar_comparisons_edge {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
});

0 comments on commit 86f0704

Please sign in to comment.