diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h index 0917588ee0..399567acee 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h @@ -3279,8 +3279,7 @@ template struct int_comparison_diff_buffer { int_radix_params params; COMPARISON_TYPE op; - Torus *tmp_packed_left; - Torus *tmp_packed_right; + Torus *tmp_packed; std::function operator_f; @@ -3317,11 +3316,8 @@ template struct int_comparison_diff_buffer { Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus); - tmp_packed_left = (Torus *)cuda_malloc_async( - big_size * (num_radix_blocks / 2), streams[0], gpu_indexes[0]); - - tmp_packed_right = (Torus *)cuda_malloc_async( - big_size * (num_radix_blocks / 2), streams[0], gpu_indexes[0]); + tmp_packed = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, + streams[0], gpu_indexes[0]); tree_buffer = new int_tree_sign_reduction_buffer( streams, gpu_indexes, gpu_count, operator_f, params, num_radix_blocks, @@ -3344,8 +3340,7 @@ template struct int_comparison_diff_buffer { reduce_signs_lut->release(streams, gpu_indexes, gpu_count); delete reduce_signs_lut; - cuda_drop_async(tmp_packed_left, streams[0], gpu_indexes[0]); - cuda_drop_async(tmp_packed_right, streams[0], gpu_indexes[0]); + cuda_drop_async(tmp_packed, streams[0], gpu_indexes[0]); cuda_drop_async(tmp_signs_a, streams[0], gpu_indexes[0]); cuda_drop_async(tmp_signs_b, streams[0], gpu_indexes[0]); } diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu index 528fd54bec..3e5c7fb683 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu @@ -58,6 +58,9 @@ void cuda_comparison_integer_radix_ciphertext_kb_64( case GE: case LT: case LE: + if (num_radix_blocks % 2 != 0) + PANIC("Cuda error (comparisons): the number of radix blocks has to be " + "even.") host_integer_radix_difference_check_kb( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(lwe_array_out), @@ -68,6 +71,8 @@ void cuda_comparison_integer_radix_ciphertext_kb_64( break; case MAX: case MIN: + if (num_radix_blocks % 2 != 0) + PANIC("Cuda error (max/min): the number of radix blocks has to be even.") host_integer_radix_maxmin_kb( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(lwe_array_out), diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh index 80205eeff3..2535aa26d1 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh @@ -490,8 +490,9 @@ __host__ void host_integer_radix_difference_check_kb( if (carry_modulus >= message_modulus) { // Packing is possible // Pack inputs - Torus *packed_left = diff_buffer->tmp_packed_left; - Torus *packed_right = diff_buffer->tmp_packed_right; + Torus *packed_left = diff_buffer->tmp_packed; + Torus *packed_right = + diff_buffer->tmp_packed + num_radix_blocks / 2 * big_lwe_size; // In case the ciphertext is signed, the sign block and the one before it // are handled separately if (mem_ptr->is_signed) { @@ -510,10 +511,7 @@ __host__ void host_integer_radix_difference_check_kb( auto identity_lut = mem_ptr->identity_lut; integer_radix_apply_univariate_lookup_table_kb( streams, gpu_indexes, gpu_count, packed_left, packed_left, bsks, ksks, - packed_num_radix_blocks, identity_lut); - integer_radix_apply_univariate_lookup_table_kb( - streams, gpu_indexes, gpu_count, packed_right, packed_right, bsks, ksks, - packed_num_radix_blocks, identity_lut); + 2 * packed_num_radix_blocks, identity_lut); lhs = packed_left; rhs = packed_right; @@ -542,11 +540,13 @@ __host__ void host_integer_radix_difference_check_kb( // Compare the last block before the sign block separately auto identity_lut = mem_ptr->identity_lut; + Torus *packed_left = diff_buffer->tmp_packed; + Torus *packed_right = + diff_buffer->tmp_packed + num_radix_blocks / 2 * big_lwe_size; Torus *last_left_block_before_sign_block = - diff_buffer->tmp_packed_left + packed_num_radix_blocks * big_lwe_size; + packed_left + packed_num_radix_blocks * big_lwe_size; Torus *last_right_block_before_sign_block = - diff_buffer->tmp_packed_right + - packed_num_radix_blocks * big_lwe_size; + packed_right + packed_num_radix_blocks * big_lwe_size; integer_radix_apply_univariate_lookup_table_kb( streams, gpu_indexes, gpu_count, last_left_block_before_sign_block, lwe_array_left + (num_radix_blocks - 2) * big_lwe_size, bsks, ksks, 1, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu index 8293ce9c1f..3417754dca 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu @@ -22,6 +22,9 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64( case GE: case LT: case LE: + if (lwe_ciphertext_count % 2 != 0) + PANIC("Cuda error (scalar comparisons): the number of radix blocks has " + "to be even.") host_integer_radix_scalar_difference_check_kb( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(lwe_array_out), @@ -32,6 +35,9 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64( break; case MAX: case MIN: + if (lwe_ciphertext_count % 2 != 0) + PANIC("Cuda error (scalar max/min): the number of radix blocks has to be " + "even.") host_integer_radix_scalar_maxmin_kb( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(lwe_array_out), diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh index 10301b5055..4b79a24cec 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh @@ -141,8 +141,9 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb( ////////////// // lsb - Torus *lhs = diff_buffer->tmp_packed_left; - Torus *rhs = diff_buffer->tmp_packed_right; + Torus *lhs = diff_buffer->tmp_packed; + Torus *rhs = + diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size; pack_blocks(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in, big_lwe_dimension, num_lsb_radix_blocks, @@ -210,8 +211,9 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb( uint32_t num_lsb_radix_blocks = total_num_radix_blocks; uint32_t num_scalar_blocks = total_num_scalar_blocks; - Torus *lhs = diff_buffer->tmp_packed_left; - Torus *rhs = diff_buffer->tmp_packed_right; + Torus *lhs = diff_buffer->tmp_packed; + Torus *rhs = + diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size; pack_blocks(streams[0], gpu_indexes[0], lhs, lwe_array_in, big_lwe_dimension, num_lsb_radix_blocks, @@ -358,8 +360,9 @@ __host__ void integer_radix_signed_scalar_difference_check_kb( ////////////// // lsb - Torus *lhs = diff_buffer->tmp_packed_left; - Torus *rhs = diff_buffer->tmp_packed_right; + Torus *lhs = diff_buffer->tmp_packed; + Torus *rhs = + diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size; pack_blocks(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in, big_lwe_dimension, num_lsb_radix_blocks, @@ -458,8 +461,9 @@ __host__ void integer_radix_signed_scalar_difference_check_kb( auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out; auto lwe_array_sign_out = lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size; - Torus *lhs = diff_buffer->tmp_packed_left; - Torus *rhs = diff_buffer->tmp_packed_right; + Torus *lhs = diff_buffer->tmp_packed; + Torus *rhs = + diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size; pack_blocks(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in, big_lwe_dimension, num_lsb_radix_blocks - 1, diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs index 2da563f28e..f16d5a2b39 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs @@ -251,4 +251,8 @@ define_gpu_scalar_comparison_test_functions!(le, U256); define_gpu_scalar_comparison_test_functions!(gt, U256); define_gpu_scalar_comparison_test_functions!(ge, U256); -create_gpu_parameterized_test!(integer_unchecked_scalar_comparisons_edge); +create_gpu_parameterized_test!(integer_unchecked_scalar_comparisons_edge { + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, +});