Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Al/rework difference check #1871

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3279,8 +3279,7 @@ template <typename Torus> struct int_comparison_diff_buffer {
int_radix_params params;
COMPARISON_TYPE op;

Torus *tmp_packed_left;
Torus *tmp_packed_right;
Torus *tmp_packed;

std::function<Torus(Torus)> operator_f;

Expand Down Expand Up @@ -3317,11 +3316,8 @@ template <typename Torus> struct int_comparison_diff_buffer {

Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);

tmp_packed_left = (Torus *)cuda_malloc_async(
big_size * (num_radix_blocks / 2), streams[0], gpu_indexes[0]);

tmp_packed_right = (Torus *)cuda_malloc_async(
big_size * (num_radix_blocks / 2), streams[0], gpu_indexes[0]);
tmp_packed = (Torus *)cuda_malloc_async(big_size * num_radix_blocks,
streams[0], gpu_indexes[0]);

tree_buffer = new int_tree_sign_reduction_buffer<Torus>(
streams, gpu_indexes, gpu_count, operator_f, params, num_radix_blocks,
Expand All @@ -3344,8 +3340,7 @@ template <typename Torus> struct int_comparison_diff_buffer {
reduce_signs_lut->release(streams, gpu_indexes, gpu_count);
delete reduce_signs_lut;

cuda_drop_async(tmp_packed_left, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_packed_right, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_packed, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_signs_a, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_signs_b, streams[0], gpu_indexes[0]);
}
Expand Down
5 changes: 5 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ void cuda_comparison_integer_radix_ciphertext_kb_64(
case GE:
case LT:
case LE:
if (num_radix_blocks % 2 != 0)
PANIC("Cuda error (comparisons): the number of radix blocks has to be "
"even.")
host_integer_radix_difference_check_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
Expand All @@ -68,6 +71,8 @@ void cuda_comparison_integer_radix_ciphertext_kb_64(
break;
case MAX:
case MIN:
if (num_radix_blocks % 2 != 0)
PANIC("Cuda error (max/min): the number of radix blocks has to be even.")
host_integer_radix_maxmin_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
Expand Down
18 changes: 9 additions & 9 deletions backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,9 @@ __host__ void host_integer_radix_difference_check_kb(
if (carry_modulus >= message_modulus) {
// Packing is possible
// Pack inputs
Torus *packed_left = diff_buffer->tmp_packed_left;
Torus *packed_right = diff_buffer->tmp_packed_right;
Torus *packed_left = diff_buffer->tmp_packed;
Torus *packed_right =
diff_buffer->tmp_packed + num_radix_blocks / 2 * big_lwe_size;
// In case the ciphertext is signed, the sign block and the one before it
// are handled separately
if (mem_ptr->is_signed) {
Expand All @@ -510,10 +511,7 @@ __host__ void host_integer_radix_difference_check_kb(
auto identity_lut = mem_ptr->identity_lut;
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, packed_left, packed_left, bsks, ksks,
packed_num_radix_blocks, identity_lut);
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, packed_right, packed_right, bsks, ksks,
packed_num_radix_blocks, identity_lut);
2 * packed_num_radix_blocks, identity_lut);

lhs = packed_left;
rhs = packed_right;
Expand Down Expand Up @@ -542,11 +540,13 @@ __host__ void host_integer_radix_difference_check_kb(

// Compare the last block before the sign block separately
auto identity_lut = mem_ptr->identity_lut;
Torus *packed_left = diff_buffer->tmp_packed;
Torus *packed_right =
diff_buffer->tmp_packed + num_radix_blocks / 2 * big_lwe_size;
Torus *last_left_block_before_sign_block =
diff_buffer->tmp_packed_left + packed_num_radix_blocks * big_lwe_size;
packed_left + packed_num_radix_blocks * big_lwe_size;
Torus *last_right_block_before_sign_block =
diff_buffer->tmp_packed_right +
packed_num_radix_blocks * big_lwe_size;
packed_right + packed_num_radix_blocks * big_lwe_size;
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, last_left_block_before_sign_block,
lwe_array_left + (num_radix_blocks - 2) * big_lwe_size, bsks, ksks, 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
case GE:
case LT:
case LE:
if (lwe_ciphertext_count % 2 != 0)
PANIC("Cuda error (scalar comparisons): the number of radix blocks has "
"to be even.")
host_integer_radix_scalar_difference_check_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
Expand All @@ -32,6 +35,9 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
break;
case MAX:
case MIN:
if (lwe_ciphertext_count % 2 != 0)
PANIC("Cuda error (scalar max/min): the number of radix blocks has to be "
"even.")
host_integer_radix_scalar_maxmin_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
Expand Down
20 changes: 12 additions & 8 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(

//////////////
// lsb
Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks,
Expand Down Expand Up @@ -210,8 +211,9 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
uint32_t num_scalar_blocks = total_num_scalar_blocks;

Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;
agnesLeroy marked this conversation as resolved.
Show resolved Hide resolved

pack_blocks<Torus>(streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks,
Expand Down Expand Up @@ -358,8 +360,9 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(

//////////////
// lsb
Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks,
Expand Down Expand Up @@ -458,8 +461,9 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
auto lwe_array_sign_out =
lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
Torus *lhs = diff_buffer->tmp_packed_left;
Torus *rhs = diff_buffer->tmp_packed_right;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks - 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,8 @@ define_gpu_scalar_comparison_test_functions!(le, U256);
define_gpu_scalar_comparison_test_functions!(gt, U256);
define_gpu_scalar_comparison_test_functions!(ge, U256);

create_gpu_parameterized_test!(integer_unchecked_scalar_comparisons_edge);
create_gpu_parameterized_test!(integer_unchecked_scalar_comparisons_edge {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
});
Loading