From 6158d8cdc75b2ebb3e808a3adc5b0f83b784fa65 Mon Sep 17 00:00:00 2001
From: Pedro Alves
Date: Wed, 3 Apr 2024 11:15:05 -0300
Subject: [PATCH] feat(gpu): implement signed scalar ge, gt, le, lt, max, and min

---
 .../tfhe-cuda-backend/cuda/include/integer.h  |  56 ++-
 .../cuda/src/integer/integer.cuh              | 119 +++++-
 .../cuda/src/integer/scalar_comparison.cuh    | 338 +++++++++++++++++-
 tfhe/src/high_level_api/booleans/encrypt.rs   |   5 +-
 .../integers/unsigned/encrypt.rs              |  14 +-
 .../integers/unsigned/scalar_ops.rs           |   9 +-
 tfhe/src/integer/gpu/mod.rs                   |   8 +-
 tfhe/src/integer/gpu/server_key/radix/cmux.rs |   7 +-
 tfhe/src/integer/gpu/server_key/radix/mod.rs  |  16 +-
 .../gpu/server_key/radix/scalar_comparison.rs | 177 ++++++++-
 tfhe/src/integer/gpu/server_key/radix/sub.rs  |   4 +-
 .../gpu/server_key/radix/tests_signed/mod.rs  |   1 +
 .../tests_signed/test_scalar_comparison.rs    |   2 +-
 13 files changed, 694 insertions(+), 62 deletions(-)

diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h
index 523f1ae266..1bc285c7df 100644
--- a/backends/tfhe-cuda-backend/cuda/include/integer.h
+++ b/backends/tfhe-cuda-backend/cuda/include/integer.h
@@ -36,7 +36,7 @@ enum COMPARISON_TYPE {
   MAX = 6,
   MIN = 7,
 };
-enum IS_RELATIONSHIP { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };
+enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };
 
 extern "C" {
 void scratch_cuda_full_propagation_64(
@@ -1846,6 +1846,8 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
                                  bool allocate_gpu_memory) {
     this->params = params;
 
+    Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);
+
     block_selector_f = [](Torus msb, Torus lsb) -> Torus {
       if (msb == IS_EQUAL) // EQUAL
         return lsb;
@@ -1854,13 +1856,8 @@ template <typename Torus> struct int_tree_sign_reduction_buffer {
     };
 
     if (allocate_gpu_memory) {
-      tmp_x = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) *
-                                             num_radix_blocks * sizeof(Torus),
-                                         stream);
-      tmp_y = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) *
-                                             num_radix_blocks * sizeof(Torus),
-                                         stream);
-
+      tmp_x = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
+      tmp_y = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
       // LUTs
       tree_inner_leaf_lut = new int_radix_lut<Torus>(
           stream, params, 1, num_radix_blocks, allocate_gpu_memory);
@@ -1901,6 +1898,10 @@ template <typename Torus> struct int_comparison_diff_buffer {
 
   int_tree_sign_reduction_buffer<Torus> *tree_buffer;
 
+  Torus *tmp_signs_a;
+  Torus *tmp_signs_b;
+  int_radix_lut<Torus> *reduce_signs_lut;
+
   int_comparison_diff_buffer(cuda_stream_t *stream, COMPARISON_TYPE op,
                              int_radix_params params,
                              uint32_t num_radix_blocks,
                              bool allocate_gpu_memory) {
@@ -1922,7 +1923,6 @@ template <typename Torus> struct int_comparison_diff_buffer {
         return 42;
       }
     };
-
     if (allocate_gpu_memory) {
 
       Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus);
@@ -1935,15 +1935,26 @@ template <typename Torus> struct int_comparison_diff_buffer {
       tree_buffer = new int_tree_sign_reduction_buffer<Torus>(
           stream, operator_f, params, num_radix_blocks, allocate_gpu_memory);
+      tmp_signs_a =
+          (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
+      tmp_signs_b =
+          (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream);
+      // LUTs
+      reduce_signs_lut = new int_radix_lut<Torus>(
+          stream, params, 1, num_radix_blocks, allocate_gpu_memory);
     }
   }
 
   void release(cuda_stream_t *stream) {
     tree_buffer->release(stream);
     delete tree_buffer;
+    reduce_signs_lut->release(stream);
+    delete reduce_signs_lut;
 
     cuda_drop_async(tmp_packed_left, stream);
     cuda_drop_async(tmp_packed_right, stream);
+    cuda_drop_async(tmp_signs_a, stream);
+    cuda_drop_async(tmp_signs_b, stream);
   }
 };
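The two-bit ordering encoding and `block_selector_f` are the heart of every comparison in this patch: a more significant ordering wins unless it is EQUAL, in which case the less significant one decides. A minimal clear-value sketch of that rule in plain C++ (no CUDA; the linear fold below is only the semantics, the GPU code reduces pairwise through LUTs in logarithmic depth):

```cpp
#include <cstdint>
#include <vector>

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

// Same rule as block_selector_f above.
uint64_t block_selector(uint64_t msb, uint64_t lsb) {
  return msb == IS_EQUAL ? lsb : msb;
}

int main() {
  // Per-block orderings, least significant block first, e.g. from
  // comparing two integers that differ only in their lowest block.
  std::vector<uint64_t> orderings = {IS_INFERIOR, IS_EQUAL, IS_EQUAL,
                                     IS_EQUAL};

  // Fold from the most significant block down.
  uint64_t acc = orderings.back();
  for (int i = (int)orderings.size() - 2; i >= 0; --i)
    acc = block_selector(acc, orderings[i]);

  return acc == IS_INFERIOR ? 0 : 1; // lhs < rhs
}
```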
@@ -1963,6 +1974,7 @@ template <typename Torus> struct int_comparison_buffer {
 
   Torus *tmp_block_comparisons;
   Torus *tmp_lwe_array_out;
+  Torus *tmp_trivial_sign_block;
 
   // Scalar EQ / NE
   Torus *tmp_packed_input;
@@ -1975,6 +1987,7 @@ template <typename Torus> struct int_comparison_buffer {
   bool is_signed;
 
   // Used for scalar comparisons
+  int_radix_lut<Torus> *signed_msb_lut;
   cuda_stream_t *lsb_stream;
   cuda_stream_t *msb_stream;
 
@@ -1991,9 +2004,11 @@ template <typename Torus> struct int_comparison_buffer {
       lsb_stream = cuda_create_stream(stream->gpu_index);
       msb_stream = cuda_create_stream(stream->gpu_index);
 
-      tmp_lwe_array_out = (Torus *)cuda_malloc_async(
-          (params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus),
-          stream);
+      // +1 block to leave space for the extra sign block produced by
+      // signed comparisons
+      tmp_lwe_array_out =
+          (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) *
+                                         (num_radix_blocks + 1) * sizeof(Torus),
+                                     stream);
 
       tmp_packed_input = (Torus *)cuda_malloc_async(
           (params.big_lwe_dimension + 1) * 2 * num_radix_blocks * sizeof(Torus),
@@ -2054,13 +2069,19 @@ template <typename Torus> struct int_comparison_buffer {
     }
 
     if (is_signed) {
+
+      tmp_trivial_sign_block = (Torus *)cuda_malloc_async(
+          (params.big_lwe_dimension + 1) * sizeof(Torus), stream);
+
       signed_lut =
           new int_radix_lut<Torus>(stream, params, 1, 1, allocate_gpu_memory);
+      signed_msb_lut =
+          new int_radix_lut<Torus>(stream, params, 1, 1, allocate_gpu_memory);
 
       auto message_modulus = (int)params.message_modulus;
       uint32_t sign_bit_pos = log2(message_modulus) - 1;
-      std::function<Torus(Torus, Torus)> signed_lut_f;
-      signed_lut_f = [sign_bit_pos](Torus x, Torus y) -> Torus {
+      std::function<Torus(Torus, Torus)> signed_lut_f =
+          [sign_bit_pos](Torus x, Torus y) -> Torus {
         auto x_sign_bit = x >> sign_bit_pos;
         auto y_sign_bit = y >> sign_bit_pos;
 
@@ -2076,14 +2097,14 @@ template <typename Torus> struct int_comparison_buffer {
             return (Torus)(IS_INFERIOR);
           else if (x == y)
             return (Torus)(IS_EQUAL);
-          else if (x > y)
+          else
             return (Torus)(IS_SUPERIOR);
         } else {
           if (x < y)
             return (Torus)(IS_SUPERIOR);
           else if (x == y)
             return (Torus)(IS_EQUAL);
-          else if (x > y)
+          else
             return (Torus)(IS_INFERIOR);
         }
         PANIC("Cuda error: sign_lut creation failed due to wrong function.")
@@ -2126,8 +2147,11 @@ template <typename Torus> struct int_comparison_buffer {
     cuda_drop_async(tmp_packed_input, stream);
 
     if (is_signed) {
+      cuda_drop_async(tmp_trivial_sign_block, stream);
       signed_lut->release(stream);
       delete (signed_lut);
+      signed_msb_lut->release(stream);
+      delete (signed_msb_lut);
     }
     cuda_destroy_stream(lsb_stream);
     cuda_destroy_stream(msb_stream);
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
index 4fd9923be3..cd5bd8d88b 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
@@ -587,18 +587,48 @@ __global__ void device_pack_blocks(Torus *lwe_array_out, Torus *lwe_array_in,
       packed_block[tid] = lsb_block[tid] + factor * msb_block[tid];
     }
 
-    if (num_radix_blocks % 2 != 0) {
+    if (num_radix_blocks % 2 == 1) {
       // We couldn't pack the last block, so we just copy it
       Torus *lsb_block =
           lwe_array_in + (num_radix_blocks - 1) * (lwe_dimension + 1);
       Torus *last_block =
           lwe_array_out + (num_radix_blocks / 2) * (lwe_dimension + 1);
 
       last_block[tid] = lsb_block[tid];
     }
   }
 }
 
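Packing two radix blocks is plain scalar arithmetic on each LWE coefficient: `packed = lsb + factor * msb`, with `factor = message_modulus` for data blocks and `factor = 4` for the two-bit ordering blocks. A clear-value sketch in plain C++ (clear digits stand in for whole LWE blocks; the helper name is illustrative, not part of the patch):

```cpp
#include <cstdint>
#include <vector>

// Pack pairs of radix blocks: out[i] = in[2i] + factor * in[2i+1].
// With message_modulus = 4 (2-bit blocks), factor = 4 keeps both digits
// readable inside one block's message+carry space.
std::vector<uint64_t> pack_blocks_clear(const std::vector<uint64_t> &in,
                                        uint64_t factor) {
  std::vector<uint64_t> out;
  for (size_t i = 0; i + 1 < in.size(); i += 2)
    out.push_back(in[i] + factor * in[i + 1]);
  if (in.size() % 2 == 1) // odd count: the last block is copied as-is
    out.push_back(in.back());
  return out;
}

int main() {
  // 42 = 0b101010 in 2-bit blocks, LSB first: [2, 2, 2]
  auto packed = pack_blocks_clear({2, 2, 2}, 4);
  // packed = [2 + 4*2, 2] = [10, 2]
  return packed[0] == 10 && packed[1] == 2 ? 0 : 1;
}
```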
 // Packs the low ciphertext in the message parts of the high ciphertext
 // and moves the high ciphertext into the carry part.
 //
@@ -684,4 +714,89 @@ __host__ void extract_n_bits(cuda_stream_t *stream, Torus *lwe_array_out,
                                   num_radix_blocks * bits_per_block,
                                   bit_extract->lut);
 }
 
+template <typename Torus>
+__host__ void reduce_signs(cuda_stream_t *stream, Torus *signs_array_out,
+                           Torus *signs_array_in,
+                           int_comparison_buffer<Torus> *mem_ptr,
+                           std::function<Torus(Torus)> sign_handler_f,
+                           void *bsk, Torus *ksk, uint32_t num_sign_blocks) {
+
+  auto diff_buffer = mem_ptr->diff_buffer;
+
+  auto params = mem_ptr->params;
+  auto big_lwe_dimension = params.big_lwe_dimension;
+  auto glwe_dimension = params.glwe_dimension;
+  auto polynomial_size = params.polynomial_size;
+  auto message_modulus = params.message_modulus;
+  auto carry_modulus = params.carry_modulus;
+
+  std::function<Torus(Torus)> reduce_two_orderings_function =
+      [diff_buffer, sign_handler_f](Torus x) -> Torus {
+    int msb = (x >> 2) & 3;
+    int lsb = x & 3;
+
+    return diff_buffer->tree_buffer->block_selector_f(msb, lsb);
+  };
+
+  auto signs_a = diff_buffer->tmp_signs_a;
+  auto signs_b = diff_buffer->tmp_signs_b;
+
+  cuda_memcpy_async_gpu_to_gpu(
+      signs_a, signs_array_in,
+      (big_lwe_dimension + 1) * num_sign_blocks * sizeof(Torus), stream);
+
+  if (num_sign_blocks > 2) {
+    auto lut = diff_buffer->reduce_signs_lut;
+    generate_device_accumulator<Torus>(
+        stream, lut->lut, glwe_dimension, polynomial_size, message_modulus,
+        carry_modulus, reduce_two_orderings_function);
+
+    while (num_sign_blocks > 2) {
+      pack_blocks(stream, signs_b, signs_a, big_lwe_dimension, num_sign_blocks,
+                  4);
+      integer_radix_apply_univariate_lookup_table_kb(
+          stream, signs_a, signs_b, bsk, ksk, num_sign_blocks / 2, lut);
+
+      auto last_block_signs_b =
+          signs_b + (num_sign_blocks / 2) * (big_lwe_dimension + 1);
+      auto last_block_signs_a =
+          signs_a + (num_sign_blocks / 2) * (big_lwe_dimension + 1);
+      if (num_sign_blocks % 2 == 1)
+        cuda_memcpy_async_gpu_to_gpu(last_block_signs_a, last_block_signs_b,
                                     (big_lwe_dimension + 1) * sizeof(Torus),
+                                     stream);
+
+      num_sign_blocks = (num_sign_blocks / 2) + (num_sign_blocks % 2);
+    }
+  }
+
+  if (num_sign_blocks == 2) {
+    std::function<Torus(Torus)> final_lut_f =
+        [reduce_two_orderings_function, sign_handler_f](Torus x) -> Torus {
+      Torus final_sign = reduce_two_orderings_function(x);
+      return sign_handler_f(final_sign);
+    };
+
+    auto lut = diff_buffer->reduce_signs_lut;
+    generate_device_accumulator<Torus>(stream, lut->lut, glwe_dimension,
+                                       polynomial_size, message_modulus,
+                                       carry_modulus, final_lut_f);
+
+    pack_blocks(stream, signs_b, signs_a, big_lwe_dimension, 2, 4);
+    integer_radix_apply_univariate_lookup_table_kb(stream, signs_array_out,
+                                                   signs_b, bsk, ksk, 1, lut);
+
+  } else {
+
+    std::function<Torus(Torus)> final_lut_f =
+        [mem_ptr, sign_handler_f](Torus x) -> Torus {
+      return sign_handler_f(x & 3);
+    };
+
+    auto lut = mem_ptr->diff_buffer->reduce_signs_lut;
+    generate_device_accumulator<Torus>(stream, lut->lut, glwe_dimension,
+                                       polynomial_size, message_modulus,
+                                       carry_modulus, final_lut_f);
+
+    integer_radix_apply_univariate_lookup_table_kb(stream, signs_array_out,
+                                                   signs_a, bsk, ksk, 1, lut);
+  }
+}
 #endif // TFHE_RS_INTERNAL_INTEGER_CUH
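`reduce_signs` halves the number of ordering blocks per pass: two 2-bit orderings are packed with factor 4 into a single plaintext, and one LUT evaluation maps the packed value back to one ordering. The same loop on clear values, in plain C++ (illustrative, not the patch's API):

```cpp
#include <cstdint>
#include <vector>

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

// What reduce_two_orderings_function computes on a packed pair:
// the packed value is lsb + 4 * msb, each a 2-bit ordering.
uint64_t reduce_two_orderings(uint64_t packed) {
  uint64_t msb = (packed >> 2) & 3;
  uint64_t lsb = packed & 3;
  return msb == IS_EQUAL ? lsb : msb;
}

int main() {
  // One ordering per block, LSB first.
  std::vector<uint64_t> signs = {IS_SUPERIOR, IS_EQUAL, IS_EQUAL};
  while (signs.size() > 1) {
    std::vector<uint64_t> next;
    for (size_t i = 0; i + 1 < signs.size(); i += 2)
      next.push_back(reduce_two_orderings(signs[i] + 4 * signs[i + 1]));
    if (signs.size() % 2 == 1) // odd tail carried over, as in the CUDA loop
      next.push_back(signs.back());
    signs = next;
  }
  return signs[0] == IS_SUPERIOR ? 0 : 1; // lhs > rhs
}
```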
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
index 8a4b0e1616..72afb0ff2b 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
@@ -5,7 +5,7 @@
 #include
 
 template <typename Torus>
-__host__ void host_integer_radix_scalar_difference_check_kb(
+__host__ void integer_radix_unsigned_scalar_difference_check_kb(
     cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
     Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
     std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
     uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) {
@@ -184,6 +184,342 @@ __host__ void host_integer_radix_scalar_difference_check_kb(
   }
 }
 
+template <typename Torus>
+__host__ void integer_radix_signed_scalar_difference_check_kb(
+    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
+    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
+    std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
+    uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) {
+
+  cudaSetDevice(stream->gpu_index);
+  auto params = mem_ptr->params;
+  auto big_lwe_dimension = params.big_lwe_dimension;
+  auto glwe_dimension = params.glwe_dimension;
+  auto polynomial_size = params.polynomial_size;
+  auto message_modulus = params.message_modulus;
+  auto carry_modulus = params.carry_modulus;
+
+  auto diff_buffer = mem_ptr->diff_buffer;
+
+  size_t big_lwe_size = big_lwe_dimension + 1;
+
+  // Reducing the signs is the bottleneck of the comparison algorithms;
+  // however, in the scalar case there is an improvement:
+  //
+  // The idea is to reduce the number of sign blocks we have to reduce.
+  // We can do that by splitting the comparison problem in two parts.
+  //
+  // - One part where we compute the sign blocks between the scalar and just
+  //   enough blocks from the ciphertext to represent the scalar value
+  //
+  // - The other part is to compare the ciphertext blocks not considered for
+  //   the sign computation with zero, and create a single sign block from
+  //   that.
+  //
+  // The smaller the scalar value is compared to the number of bits encrypted
+  // in the ciphertext, the more comparisons with zero we have to do, and the
+  // fewer sign blocks we will have to reduce.
+  //
+  // This creates a speedup, as comparing a bunch of blocks with 0 is faster.
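Concretely: against a 16-block ciphertext, a scalar that fits in 4 blocks leaves only 4 sign blocks (plus one combined zero-check block) for the expensive reduction. A clear-value model of the unsigned split in plain C++ (the signed variant additionally folds the encrypted sign block in; names are illustrative):

```cpp
#include <cstdint>
#include <vector>

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

uint64_t block_selector(uint64_t msb, uint64_t lsb) {
  return msb == IS_EQUAL ? lsb : msb;
}

// Compare a radix integer (LSB-first blocks) with a scalar occupying only
// the first n blocks: zero-check the upper part, block-compare the lower.
uint64_t split_compare(const std::vector<uint64_t> &ct_blocks,
                       const std::vector<uint64_t> &scalar_blocks) {
  size_t n = scalar_blocks.size();
  bool msb_all_zero = true;
  for (size_t i = n; i < ct_blocks.size(); ++i)
    msb_all_zero &= (ct_blocks[i] == 0);
  uint64_t msb_sign = msb_all_zero ? IS_EQUAL : IS_SUPERIOR;

  uint64_t lsb_sign = IS_EQUAL;
  for (size_t i = 0; i < n; ++i) {
    uint64_t s = ct_blocks[i] < scalar_blocks[i]    ? IS_INFERIOR
                 : ct_blocks[i] == scalar_blocks[i] ? IS_EQUAL
                                                    : IS_SUPERIOR;
    lsb_sign = block_selector(s, lsb_sign);
  }
  return block_selector(msb_sign, lsb_sign);
}

int main() {
  // 2-bit blocks: ct = 9 = [1, 2, 0, 0], scalar = 6 = [2, 1]
  return split_compare({1, 2, 0, 0}, {2, 1}) == IS_SUPERIOR ? 0 : 1;
}
```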
+  if (total_num_scalar_blocks == 0) {
+    // The scalar is zero: we only have to compare the ciphertext blocks
+    // with zero
+    Torus *are_all_msb_zeros = mem_ptr->tmp_lwe_array_out;
+    host_compare_with_zero_equality(stream, are_all_msb_zeros, lwe_array_in,
+                                    mem_ptr, bsk, ksk, total_num_radix_blocks,
+                                    mem_ptr->is_zero_lut);
+    Torus *sign_block =
+        lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
+
+    auto sign_bit_pos = (int)std::log2(message_modulus) - 1;
+
+    auto scalar_last_leaf_with_respect_to_zero_lut_f =
+        [sign_handler_f, sign_bit_pos,
+         message_modulus](Torus sign_block) -> Torus {
+      bool sign_bit_is_set = (sign_block >> sign_bit_pos) == 1;
+      CMP_ORDERING sign_block_ordering;
+      if (sign_bit_is_set) {
+        sign_block_ordering = CMP_ORDERING::IS_INFERIOR;
+      } else if (sign_block != 0) {
+        sign_block_ordering = CMP_ORDERING::IS_SUPERIOR;
+      } else {
+        sign_block_ordering = CMP_ORDERING::IS_EQUAL;
+      }
+
+      return sign_block_ordering;
+    };
+
+    auto scalar_bivariate_last_leaf_lut_f =
+        [scalar_last_leaf_with_respect_to_zero_lut_f, sign_handler_f,
+         mem_ptr](Torus are_all_zeros, Torus sign_block) -> Torus {
+      // "re-code" are_all_zeros as an ordering value
+      if (are_all_zeros == 1) {
+        are_all_zeros = CMP_ORDERING::IS_EQUAL;
+      } else {
+        are_all_zeros = CMP_ORDERING::IS_SUPERIOR;
+      }
+
+      return mem_ptr->diff_buffer->tree_buffer->block_selector_f(
+          scalar_last_leaf_with_respect_to_zero_lut_f(sign_block),
+          are_all_zeros);
+    };
+
+    auto lut = mem_ptr->diff_buffer->tree_buffer->tree_last_leaf_scalar_lut;
+    generate_device_accumulator_bivariate<Torus>(
+        stream, lut->lut, glwe_dimension, polynomial_size, message_modulus,
+        carry_modulus, scalar_bivariate_last_leaf_lut_f);
+
+    integer_radix_apply_bivariate_lookup_table_kb(
+        stream, lwe_array_out, are_all_msb_zeros, sign_block, bsk, ksk, 1, lut);
+
+  } else if (total_num_scalar_blocks < total_num_radix_blocks) {
+    // We have to handle both parts of the work described above,
+    // and the sign bit is located in the most significant blocks
+
+    uint32_t num_lsb_radix_blocks = total_num_scalar_blocks;
+    uint32_t num_msb_radix_blocks =
+        total_num_radix_blocks - num_lsb_radix_blocks;
+    auto msb = lwe_array_in + num_lsb_radix_blocks * big_lwe_size;
+
+    auto lwe_array_lsb_out = mem_ptr->tmp_lwe_array_out;
+    auto lwe_array_msb_out = lwe_array_lsb_out + big_lwe_size;
+
+    cuda_synchronize_stream(stream);
+    auto lsb_stream = mem_ptr->lsb_stream;
+    auto msb_stream = mem_ptr->msb_stream;
+
+#pragma omp parallel sections
+    {
+      // Both sections may be executed in parallel
+#pragma omp section
+      {
+        //////////////
+        // lsb
+        Torus *lhs = diff_buffer->tmp_packed_left;
+        Torus *rhs = diff_buffer->tmp_packed_right;
+
+        pack_blocks(lsb_stream, lhs, lwe_array_in, big_lwe_dimension,
+                    num_lsb_radix_blocks, message_modulus);
+        pack_blocks(lsb_stream, rhs, scalar_blocks, 0, total_num_scalar_blocks,
+                    message_modulus);
+
+        // From this point we have half the number of blocks
+        num_lsb_radix_blocks /= 2;
+        num_lsb_radix_blocks += (total_num_scalar_blocks % 2);
+
+        // comparisons will be assigned
+        // - 0 if lhs < rhs
+        // - 1 if lhs == rhs
+        // - 2 if lhs > rhs
+        auto comparisons = mem_ptr->tmp_block_comparisons;
+        scalar_compare_radix_blocks_kb(lsb_stream, comparisons, lhs, rhs,
+                                       mem_ptr, bsk, ksk, num_lsb_radix_blocks);
+
+        // Reduces a vec containing radix blocks that encrypts a sign
+        // (inferior, equal, superior) to one single radix block containing the
+        // final sign
+        tree_sign_reduction(lsb_stream, lwe_array_lsb_out, comparisons,
+                            mem_ptr->diff_buffer->tree_buffer,
+                            mem_ptr->identity_lut_f, bsk, ksk,
+                            num_lsb_radix_blocks);
+      }
+#pragma omp section
+      {
+        //////////////
+        // msb
+        // We remove the last block (which is the sign)
+        Torus *are_all_msb_zeros = lwe_array_msb_out;
+        host_compare_with_zero_equality(msb_stream, are_all_msb_zeros, msb,
+                                        mem_ptr, bsk, ksk, num_msb_radix_blocks,
+                                        mem_ptr->is_zero_lut);
+
+        auto sign_bit_pos = (int)log2(message_modulus) - 1;
+
+        auto lut_f = [mem_ptr, sign_bit_pos](Torus sign_block,
+                                             Torus msb_are_zeros) {
+          bool sign_bit_is_set = (sign_block >> sign_bit_pos) == 1;
+          CMP_ORDERING sign_block_ordering;
+          if (sign_bit_is_set) {
+            sign_block_ordering = CMP_ORDERING::IS_INFERIOR;
+          } else if (sign_block != 0) {
+            sign_block_ordering = CMP_ORDERING::IS_SUPERIOR;
+          } else {
+            sign_block_ordering = CMP_ORDERING::IS_EQUAL;
+          }
+
+          CMP_ORDERING msb_ordering;
+          if (msb_are_zeros == 1)
+            msb_ordering = CMP_ORDERING::IS_EQUAL;
+          else
+            msb_ordering = CMP_ORDERING::IS_SUPERIOR;
+
+          return mem_ptr->diff_buffer->tree_buffer->block_selector_f(
+              sign_block_ordering, msb_ordering);
+        };
+
+        auto signed_msb_lut = mem_ptr->signed_msb_lut;
+        generate_device_accumulator_bivariate<Torus>(
+            msb_stream, signed_msb_lut->lut, params.glwe_dimension,
+            params.polynomial_size, params.message_modulus,
+            params.carry_modulus, lut_f);
+
+        Torus *sign_block = msb + (num_msb_radix_blocks - 1) * big_lwe_size;
+        integer_radix_apply_bivariate_lookup_table_kb(
+            msb_stream, lwe_array_msb_out, sign_block, are_all_msb_zeros, bsk,
+            ksk, 1, signed_msb_lut);
+      }
+    }
+    cuda_synchronize_stream(lsb_stream);
+    cuda_synchronize_stream(msb_stream);
+
+    //////////////
+    // Reduce the two blocks into one final block
+    reduce_signs(stream, lwe_array_out, lwe_array_lsb_out, mem_ptr,
+                 sign_handler_f, bsk, ksk, 2);
+
+  } else {
+    // We only have to do the regular comparison,
+    // not the part where we compare the most significant blocks with zeros
+    // (total_num_radix_blocks == total_num_scalar_blocks)
+    uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
+
+    cuda_synchronize_stream(stream);
+    auto lsb_stream = mem_ptr->lsb_stream;
+    auto msb_stream = mem_ptr->msb_stream;
+
+    auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
+    auto lwe_array_sign_out =
+        lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
+
+#pragma omp parallel sections
+    {
+      // Both sections may be executed in parallel
+#pragma omp section
+      {
+        Torus *lhs = diff_buffer->tmp_packed_left;
+        Torus *rhs = diff_buffer->tmp_packed_right;
+
+        pack_blocks(lsb_stream, lhs, lwe_array_in, big_lwe_dimension,
+                    num_lsb_radix_blocks - 1, message_modulus);
+        pack_blocks(lsb_stream, rhs, scalar_blocks, 0,
+                    num_lsb_radix_blocks - 1, message_modulus);
+
+        // From this point we have half the number of blocks
+        num_lsb_radix_blocks /= 2;
+
+        // comparisons will be assigned
+        // - 0 if lhs < rhs
+        // - 1 if lhs == rhs
+        // - 2 if lhs > rhs
+        scalar_compare_radix_blocks_kb(lsb_stream, lwe_array_ct_out, lhs, rhs,
+                                       mem_ptr, bsk, ksk, num_lsb_radix_blocks);
+      }
+#pragma omp section
+      {
+        Torus *encrypted_sign_block =
+            lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
+        Torus *scalar_sign_block =
+            scalar_blocks + (total_num_scalar_blocks - 1);
+
+        auto trivial_sign_block = mem_ptr->tmp_trivial_sign_block;
+        create_trivial_radix(msb_stream, trivial_sign_block, scalar_sign_block,
+                             big_lwe_dimension, 1, 1, message_modulus,
+                             carry_modulus);
+
+        integer_radix_apply_bivariate_lookup_table_kb(
+            msb_stream, lwe_array_sign_out, encrypted_sign_block,
+            trivial_sign_block, bsk, ksk, 1, mem_ptr->signed_lut);
+      }
+    }
+    cuda_synchronize_stream(lsb_stream);
+    cuda_synchronize_stream(msb_stream);
+
+    // Reduces a vec containing radix blocks that encrypts a sign
+    // (inferior, equal, superior) to one single radix block containing the
+    // final sign
+    reduce_signs(stream, lwe_array_out, lwe_array_ct_out, mem_ptr,
+                 sign_handler_f, bsk, ksk, num_lsb_radix_blocks + 1);
+  }
+}
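The `signed_lut` applied to the sign blocks above encodes the two's-complement rule declared in `integer.h`: if both operands' sign bits agree, the larger block value is the larger number; if they differ, the order flips. A clear-value sketch of that bivariate function in plain C++ (2-bit blocks assumed, so `sign_bit_pos = 1`; the equal-value case inside the different-signs branch is unreachable and omitted here):

```cpp
#include <cstdint>

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };

// Compare two top blocks x (encrypted side) and y (scalar side) of
// two's-complement integers; the block's msb is the sign bit.
uint64_t signed_block_cmp(uint64_t x, uint64_t y) {
  const uint32_t sign_bit_pos = 1; // message_modulus = 4
  uint64_t x_sign = x >> sign_bit_pos;
  uint64_t y_sign = y >> sign_bit_pos;
  if (x_sign == y_sign) { // same sign: compare as plain values
    return x < y ? IS_INFERIOR : x == y ? IS_EQUAL : IS_SUPERIOR;
  } else {                // different signs: the order is inverted
    return x < y ? IS_SUPERIOR : IS_INFERIOR;
  }
}

int main() {
  // x = 0b11 (negative top block), y = 0b01 (positive): negative < positive.
  return signed_block_cmp(3, 1) == IS_INFERIOR ? 0 : 1;
}
```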
+
+template <typename Torus>
+__host__ void integer_radix_signed_scalar_maxmin_kb(
+    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
+    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr, void *bsk,
+    Torus *ksk, uint32_t total_num_radix_blocks,
+    uint32_t total_num_scalar_blocks) {
+
+  cudaSetDevice(stream->gpu_index);
+  auto params = mem_ptr->params;
+  // Calculates the difference sign between the ciphertext and the scalar
+  // - 0 if lhs < rhs
+  // - 1 if lhs == rhs
+  // - 2 if lhs > rhs
+  auto sign = mem_ptr->tmp_lwe_array_out;
+  integer_radix_signed_scalar_difference_check_kb(
+      stream, sign, lwe_array_in, scalar_blocks, mem_ptr,
+      mem_ptr->identity_lut_f, bsk, ksk, total_num_radix_blocks,
+      total_num_scalar_blocks);
+
+  // There is no optimized CMUX for scalars, so we convert the scalar to a
+  // trivial ciphertext
+  auto lwe_array_left = lwe_array_in;
+  auto lwe_array_right = mem_ptr->tmp_block_comparisons;
+
+  create_trivial_radix(stream, lwe_array_right, scalar_blocks,
+                       params.big_lwe_dimension, total_num_radix_blocks,
+                       total_num_scalar_blocks, params.message_modulus,
+                       params.carry_modulus);
+
+  // Selector: CMUX for Max or Min
+  host_integer_radix_cmux_kb(stream, lwe_array_out, sign, lwe_array_left,
+                             lwe_array_right, mem_ptr->cmux_buffer, bsk, ksk,
+                             total_num_radix_blocks);
+}
+
+template <typename Torus>
+__host__ void host_integer_radix_scalar_difference_check_kb(
+    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
+    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr,
+    std::function<Torus(Torus)> sign_handler_f, void *bsk, Torus *ksk,
+    uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) {
+
+  if (mem_ptr->is_signed) {
+    // signed, with the scalar assumed positive
+    integer_radix_signed_scalar_difference_check_kb(
+        stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr,
+        sign_handler_f, bsk, ksk, total_num_radix_blocks,
+        total_num_scalar_blocks);
+  } else {
+    integer_radix_unsigned_scalar_difference_check_kb(
+        stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr,
+        sign_handler_f, bsk, ksk, total_num_radix_blocks,
+        total_num_scalar_blocks);
+  }
+}
+
+template <typename Torus>
+__host__ void host_integer_radix_signed_scalar_maxmin_kb(
+    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in,
+    Torus *scalar_blocks, int_comparison_buffer<Torus> *mem_ptr, void *bsk,
+    Torus *ksk, uint32_t total_num_radix_blocks,
+    uint32_t total_num_scalar_blocks) {
+
+  if (mem_ptr->is_signed) {
+    // signed, with the scalar assumed positive
+    integer_radix_signed_scalar_maxmin_kb(
+        stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, bsk, ksk,
+        total_num_radix_blocks, total_num_scalar_blocks);
+  } else {
+    integer_radix_unsigned_scalar_maxmin_kb(
+        stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, bsk, ksk,
+        total_num_radix_blocks, total_num_scalar_blocks);
+  }
+}
+
 template <typename Torus>
 __host__ void scalar_compare_radix_blocks_kb(cuda_stream_t *stream,
                                              Torus *lwe_array_out,
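Max and min reuse the difference check: the encrypted ordering drives a CMUX between the ciphertext and the trivially-encrypted scalar. The selection rule in the clear, in plain C++ (the `Op` enum is illustrative, standing in for `COMPARISON_TYPE::MAX`/`MIN`):

```cpp
#include <cstdint>

enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 };
enum class Op { MAX, MIN };

// sign = ordering of (lhs vs rhs); the CMUX picks lhs or rhs from it.
uint64_t select_from_sign(Op op, uint64_t sign, uint64_t lhs, uint64_t rhs) {
  bool lhs_wins = (op == Op::MAX) ? sign == IS_SUPERIOR : sign == IS_INFERIOR;
  // On equality both choices are correct; pick rhs, like a plain ternary.
  return lhs_wins ? lhs : rhs;
}

int main() {
  uint64_t lhs = 7, rhs = 12;
  uint64_t sign =
      lhs < rhs ? IS_INFERIOR : lhs == rhs ? IS_EQUAL : IS_SUPERIOR;
  return select_from_sign(Op::MAX, sign, lhs, rhs) == 12 ? 0 : 1;
}
```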
diff --git a/tfhe/src/high_level_api/booleans/encrypt.rs b/tfhe/src/high_level_api/booleans/encrypt.rs
index 6781cee58b..3dc9c6e422 100644
--- a/tfhe/src/high_level_api/booleans/encrypt.rs
+++ b/tfhe/src/high_level_api/booleans/encrypt.rs
@@ -107,10 +107,7 @@ impl FheTryTrivialEncrypt<bool> for FheBool {
                 let inner = cuda_key
                     .key
                     .create_trivial_radix(u64::from(value), 1, stream);
-                InnerBoolean::Cuda(CudaBooleanBlock::new(
-                    inner.ciphertext.d_blocks,
-                    inner.ciphertext.info,
-                ))
+                InnerBoolean::Cuda(CudaBooleanBlock::new(inner.d_blocks, inner.info))
             }),
         });
         Ok(Self::new(ciphertext))
diff --git a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs
index 83b69f7347..ccfcb35642 100644
--- a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs
@@ -5,6 +5,7 @@ use crate::high_level_api::global_state::with_thread_local_cuda_stream;
 use crate::high_level_api::integers::FheUintId;
 use crate::high_level_api::keys::InternalServerKey;
 use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
+use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
 use crate::prelude::{FheDecrypt, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt};
 use crate::{ClientKey, CompactPublicKey, CompressedPublicKey, FheUint, PublicKey};
 
@@ -133,11 +134,14 @@ where
             }
             #[cfg(feature = "gpu")]
             InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
-                let inner = cuda_key.key.create_trivial_radix(
-                    value,
-                    Id::num_blocks(cuda_key.key.message_modulus),
-                    stream,
-                );
+                let inner: CudaUnsignedRadixCiphertext =
+                    <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
+                        cuda_key.key.create_trivial_radix(
+                            value,
+                            Id::num_blocks(cuda_key.key.message_modulus),
+                            stream,
+                        ),
+                    );
                 Ok(Self::new(inner))
             }),
         })
diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
index 2f7b45a473..c6fc323c1a 100644
--- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
@@ -17,6 +17,7 @@ use crate::high_level_api::traits::{
 };
 use crate::integer::block_decomposition::DecomposableInto;
 use crate::integer::ciphertext::IntegerCiphertext;
+use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
 use crate::integer::U256;
 use crate::FheBool;
 use std::ops::{
@@ -61,7 +62,7 @@ where
                 let inner_result = cuda_key
                     .key
                     .scalar_eq(&*self.ciphertext.on_gpu(), rhs, stream);
-                FheBool::new(inner_result.to_cuda_unsigned_radix_ciphertext())
+                FheBool::new(inner_result)
             }),
         })
     }
@@ -98,7 +99,7 @@ where
                 let inner_result = cuda_key
                     .key
                     .scalar_ne(&*self.ciphertext.on_gpu(), rhs, stream);
-                FheBool::new(inner_result.to_cuda_unsigned_radix_ciphertext())
+                FheBool::new(inner_result)
             }),
         })
     }
@@ -1032,7 +1033,9 @@ generic_integer_impl_scalar_left_operation!(
             #[cfg(feature = "gpu")]
             InternalServerKey::Cuda(cuda_key) => {
                 with_thread_local_cuda_stream(|stream| {
-                    let mut result = cuda_key.key.create_trivial_radix(lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), stream);
+                    let mut result = <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
+                        cuda_key.key.create_trivial_radix(
+                            lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), stream));
                     cuda_key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(), stream);
                     RadixCiphertext::Cuda(result)
                 })
diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs
index 20726c06df..758c34a976 100644
--- a/tfhe/src/integer/gpu/mod.rs
+++ b/tfhe/src/integer/gpu/mod.rs
@@ -1136,7 +1136,7 @@ impl CudaStream {
         num_blocks: u32,
         num_scalar_blocks: u32,
         op: ComparisonType,
-        is_signed: bool,
+        signed_with_positive_scalar: bool,
     ) {
         let mut mem_ptr: *mut i8 = std::ptr::null_mut();
         scratch_cuda_integer_radix_comparison_kb_64(
@@ -1156,7 +1156,7 @@ impl CudaStream {
             carry_modulus.0 as u32,
             PBSType::Classical as u32,
             op as u32,
-            is_signed,
+            signed_with_positive_scalar,
             true,
         );
 
@@ -1203,7 +1203,7 @@ impl CudaStream {
         num_blocks: u32,
         num_scalar_blocks: u32,
         op: ComparisonType,
-        is_signed: bool,
+        signed_with_positive_scalar: bool,
     ) {
         let mut mem_ptr: *mut i8 = std::ptr::null_mut();
         scratch_cuda_integer_radix_comparison_kb_64(
@@ -1223,7 +1223,7 @@ impl CudaStream {
             carry_modulus.0 as u32,
             PBSType::MultiBit as u32,
             op as u32,
-            is_signed,
+            signed_with_positive_scalar,
             true,
         );
         cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
diff --git a/tfhe/src/integer/gpu/server_key/radix/cmux.rs b/tfhe/src/integer/gpu/server_key/radix/cmux.rs
index de7e88d6f2..597e2ea375 100644
--- a/tfhe/src/integer/gpu/server_key/radix/cmux.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/cmux.rs
@@ -17,8 +17,11 @@ impl CudaServerKey {
         stream: &CudaStream,
     ) -> T {
         let lwe_ciphertext_count = true_ct.as_ref().d_blocks.lwe_ciphertext_count();
-        let mut result: T = self
-            .create_trivial_zero_radix(true_ct.as_ref().d_blocks.lwe_ciphertext_count().0, stream);
+        let mut result =
+            T::from(self.create_trivial_zero_radix(
+                true_ct.as_ref().d_blocks.lwe_ciphertext_count().0,
+                stream,
+            ));
 
         match &self.bootstrapping_key {
             CudaBootstrappingKey::Classic(d_bsk) => {
diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs
index 645a169f97..c20c72ea45 100644
--- a/tfhe/src/integer/gpu/server_key/radix/mod.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs
@@ -64,12 +64,12 @@ impl CudaServerKey {
     ///     let dec: u64 = cks.decrypt(&ctxt);
     ///     assert_eq!(0, dec);
     /// ```
-    pub fn create_trivial_zero_radix<T: CudaIntegerRadixCiphertext>(
+    pub fn create_trivial_zero_radix(
         &self,
         num_blocks: usize,
         stream: &CudaStream,
-    ) -> T {
-        T::from(self.create_trivial_radix(0, num_blocks, stream).ciphertext)
+    ) -> CudaRadixCiphertext {
+        self.create_trivial_radix(0, num_blocks, stream)
     }
 
     /// Create a trivial ciphertext on the GPU
@@ -103,7 +103,7 @@ impl CudaServerKey {
         scalar: Scalar,
         num_blocks: usize,
         stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    ) -> CudaRadixCiphertext
     where
         Scalar: DecomposableInto<u64>,
     {
@@ -138,11 +138,9 @@ impl CudaServerKey {
 
         let d_blocks = CudaLweCiphertextList::from_lwe_ciphertext_list(&cpu_lwe_list, stream);
 
-        CudaUnsignedRadixCiphertext {
-            ciphertext: CudaRadixCiphertext {
-                d_blocks,
-                info: CudaRadixCiphertextInfo { blocks: info },
-            },
+        CudaRadixCiphertext {
+            d_blocks,
+            info: CudaRadixCiphertextInfo { blocks: info },
         }
     }
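`create_trivial_radix` decomposes the clear value into `num_blocks` message-modulus digits, least significant first, and trivially encrypts each one. The decomposition itself in the clear, as a plain C++ sketch (the function name is illustrative):

```cpp
#include <cstdint>
#include <vector>

// Decompose `value` into `num_blocks` digits modulo `message_modulus`,
// least significant first: the block layout create_trivial_radix encrypts.
std::vector<uint64_t> decompose(uint64_t value, uint64_t message_modulus,
                                size_t num_blocks) {
  std::vector<uint64_t> blocks(num_blocks, 0);
  for (size_t i = 0; i < num_blocks; ++i) {
    blocks[i] = value % message_modulus;
    value /= message_modulus;
  }
  return blocks;
}

int main() {
  // 27 = 0b011011 with 2-bit blocks: [3, 2, 1, 0]
  auto b = decompose(27, 4, 4);
  return (b[0] == 3 && b[1] == 2 && b[2] == 1 && b[3] == 0) ? 0 : 1;
}
```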
diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs
index 079b2ee96e..890af0f918 100644
--- a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs
@@ -12,15 +12,104 @@ use crate::integer::server_key::comparator::Comparator;
 use crate::shortint::ciphertext::Degree;
 
 impl CudaServerKey {
+    /// Returns whether the clear scalar is outside of the
+    /// value range the ciphertext can hold.
+    ///
+    /// - Returns None if the scalar is in the range of values that the ciphertext can represent
+    ///
+    /// - Returns Some(ordering) when the scalar is out of the representable range of the
+    ///   ciphertext.
+    ///     - Equal will never be returned
+    ///     - Less means the scalar is less than the min value representable by the ciphertext
+    ///     - Greater means the scalar is greater than the max value representable by the
+    ///       ciphertext
+    pub(crate) fn is_scalar_out_of_bounds<T, Scalar>(
+        &self,
+        ct: &T,
+        scalar: Scalar,
+    ) -> Option<std::cmp::Ordering>
+    where
+        T: CudaIntegerRadixCiphertext,
+        Scalar: DecomposableInto<u64>,
+    {
+        let scalar_blocks =
+            BlockDecomposer::with_early_stop_at_zero(scalar, self.message_modulus.0.ilog2())
+                .iter_as::<u64>()
+                .collect::<Vec<_>>();
+
+        let ct_len = ct.as_ref().d_blocks.lwe_ciphertext_count();
+
+        if T::IS_SIGNED {
+            let sign_bit_pos = self.message_modulus.0.ilog2() - 1;
+            let sign_bit_is_set = scalar_blocks
+                .get(ct_len.0 - 1)
+                .map_or(false, |block| (block >> sign_bit_pos) == 1);
+
+            if scalar > Scalar::ZERO
+                && (scalar_blocks.len() > ct_len.0
+                    || (scalar_blocks.len() == ct_len.0 && sign_bit_is_set))
+            {
+                // If the scalar is positive and any bit above the ct's n-1 bits is set,
+                // the scalar is bigger.
+                //
+                // This is checked in two steps:
+                // - If there are more scalar blocks than ct blocks, the scalar is
+                //   trivially bigger
+                // - If there are the same number of blocks but the "sign bit" / msb of
+                //   the scalar is set, the scalar is trivially bigger
+                return Some(std::cmp::Ordering::Greater);
+            } else if scalar < Scalar::ZERO {
+                // If the scalar is negative, and any bit above the ct's n-1 bits is not
+                // set, the scalar is smaller.
+
+                if ct_len.0 > scalar_blocks.len() {
+                    // Ciphertext has more blocks, the scalar may be in range
+                    return None;
+                }
+
+                // (returns false for an empty iter)
+                let at_least_one_block_is_not_full_of_1s = scalar_blocks[ct_len.0..]
+                    .iter()
+                    .any(|&scalar_block| scalar_block != (self.message_modulus.0 as u64 - 1));
+
+                let sign_bit_pos = self.message_modulus.0.ilog2() - 1;
+                let sign_bit_is_unset = scalar_blocks
+                    .get(ct_len.0 - 1)
+                    .map_or(false, |block| (block >> sign_bit_pos) == 0);
+
+                if at_least_one_block_is_not_full_of_1s || sign_bit_is_unset {
+                    // Scalar is smaller than the lowest value of T
+                    return Some(std::cmp::Ordering::Less);
+                }
+            }
+        } else {
+            // T is unsigned
+            if scalar < Scalar::ZERO {
+                // ct represents an unsigned (always >= 0)
+                return Some(std::cmp::Ordering::Less);
+            } else if scalar > Scalar::ZERO {
+                // the scalar is obviously bigger if it has non-zero
+                // blocks after lhs's last block
+                let is_scalar_obviously_bigger =
+                    scalar_blocks.get(ct_len.0..).is_some_and(|sub_slice| {
+                        sub_slice.iter().any(|&scalar_block| scalar_block != 0)
+                    });
+                if is_scalar_obviously_bigger {
+                    return Some(std::cmp::Ordering::Greater);
+                }
+            }
+        }
+
+        None
+    }
+
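The block-wise tests above are equivalent to a plain range check: an `n`-block signed radix integer with 2-bit blocks holds `2n` bits of two's complement, so the representable range is `[-2^(2n-1), 2^(2n-1) - 1]`. A clear-value model of the signed case in plain C++ (the function computes the same answer as `is_scalar_out_of_bounds` without materializing blocks; names are illustrative):

```cpp
#include <cstdint>
#include <optional>

enum class Bound { Less, Greater };

// Clear model of is_scalar_out_of_bounds for a signed radix integer with
// 2-bit blocks (message_modulus = 4): `n` encrypted blocks hold 2n bits.
std::optional<Bound> scalar_out_of_bounds_signed(int64_t scalar, size_t n) {
  const int64_t min = -(int64_t(1) << (2 * n - 1));    // n = 4 -> -128
  const int64_t max = (int64_t(1) << (2 * n - 1)) - 1; //          127
  if (scalar > max)
    return Bound::Greater;
  if (scalar < min)
    return Bound::Less;
  return std::nullopt; // in range: a real encrypted comparison is needed
}

int main() {
  bool ok = scalar_out_of_bounds_signed(200, 4) == Bound::Greater &&
            scalar_out_of_bounds_signed(-300, 4) == Bound::Less &&
            !scalar_out_of_bounds_signed(-5, 4).has_value();
  return ok ? 0 : 1;
}
```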
     /// # Safety
     ///
     /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
     ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_scalar_comparison_async<T, Scalar>(
+    pub unsafe fn unchecked_signed_and_unsigned_scalar_comparison_async<T, Scalar>(
         &self,
         ct: &T,
         scalar: Scalar,
         op: ComparisonType,
+        signed_with_positive_scalar: bool,
         stream: &CudaStream,
     ) -> CudaBooleanBlock
     where
@@ -30,7 +119,7 @@ impl CudaServerKey {
         if scalar < Scalar::ZERO {
             // ct represents an unsigned (always >= 0)
             let ct_res = self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream);
-            return CudaBooleanBlock::new(ct_res.ciphertext.d_blocks, ct_res.ciphertext.info);
+            return CudaBooleanBlock::new(ct_res.d_blocks, ct_res.info);
         }
 
         let message_modulus = self.message_modulus.0;
@@ -47,7 +136,7 @@ impl CudaServerKey {
             .is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0));
         if is_scalar_obviously_bigger {
             let ct_res = self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream);
-            return CudaBooleanBlock::new(ct_res.ciphertext.d_blocks, ct_res.ciphertext.info);
+            return CudaBooleanBlock::new(ct_res.d_blocks, ct_res.info);
         }
 
         // If we are still here, that means scalar_blocks above
@@ -97,7 +186,7 @@ impl CudaServerKey {
                     lwe_ciphertext_count.0 as u32,
                     scalar_blocks.len() as u32,
                     op,
-                    false,
+                    signed_with_positive_scalar,
                 );
             }
             CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
@@ -125,7 +214,7 @@ impl CudaServerKey {
                     lwe_ciphertext_count.0 as u32,
                     scalar_blocks.len() as u32,
                     op,
-                    false,
+                    signed_with_positive_scalar,
                 );
             }
         }
 
         result
     }
 
+    /// # Safety
+    ///
+    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
+    ///   not be dropped until stream is synchronised
+    pub unsafe fn unchecked_scalar_comparison_async<T, Scalar>(
+        &self,
+        ct: &T,
+        scalar: Scalar,
+        op: ComparisonType,
+        stream: &CudaStream,
+    ) -> CudaBooleanBlock
+    where
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
+    {
+        let num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0;
+
+        if T::IS_SIGNED {
+            match self.is_scalar_out_of_bounds(ct, scalar) {
+                Some(std::cmp::Ordering::Greater) => {
+                    // Scalar is greater than the bounds, so the ciphertext is smaller
+                    let result = match op {
+                        ComparisonType::LT | ComparisonType::LE => {
+                            self.create_trivial_radix(1, num_blocks, stream)
+                        }
+                        _ => self.create_trivial_radix(
+                            0,
+                            ct.as_ref().d_blocks.lwe_ciphertext_count().0,
+                            stream,
+                        ),
+                    };
+                    return CudaBooleanBlock::new(result.d_blocks, result.info);
+                }
+                Some(std::cmp::Ordering::Less) => {
+                    // Scalar is smaller than the bounds, so the ciphertext is bigger
+                    let result = match op {
+                        ComparisonType::GT | ComparisonType::GE => {
+                            self.create_trivial_radix(1, num_blocks, stream)
+                        }
+                        _ => self.create_trivial_radix(
+                            0,
+                            ct.as_ref().d_blocks.lwe_ciphertext_count().0,
+                            stream,
+                        ),
+                    };
+                    return CudaBooleanBlock::new(result.d_blocks, result.info);
+                }
+                Some(std::cmp::Ordering::Equal) => unreachable!("Internal error: invalid value"),
+                None => {
+                    // scalar is in range, fallthrough
+                }
+            }
+
+            if scalar >= Scalar::ZERO {
+                self.unchecked_signed_and_unsigned_scalar_comparison_async(
+                    ct, scalar, op, true, stream,
+                )
+            } else {
+                let scalar_as_trivial =
+                    T::from(self.create_trivial_radix(scalar, num_blocks, stream));
+                self.unchecked_comparison_async(ct, &scalar_as_trivial, op, stream)
+            }
+        } else {
+            self.unchecked_signed_and_unsigned_scalar_comparison_async(
+                ct, scalar, op, false, stream,
+            )
+        }
+    }
+
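The dispatch above turns an out-of-range scalar into a constant answer without touching the GPU comparison kernels: for `LT`/`LE`, a scalar above the representable range means the ciphertext is necessarily smaller, so the result is a trivial 1. The mapping in the clear, as a plain C++ sketch (op names mirror `ComparisonType`; the enums are illustrative):

```cpp
#include <cstdint>

enum class Op { GT, GE, LT, LE };
enum class Bound { Less, Greater }; // scalar below or above the ct's range

// Result of `ct <op> scalar` once the scalar is known to be out of range.
bool constant_answer(Op op, Bound b) {
  if (b == Bound::Greater) // ct < scalar, always
    return op == Op::LT || op == Op::LE;
  return op == Op::GT || op == Op::GE; // ct > scalar, always
}

int main() {
  // An i8-like ciphertext vs scalar 300: 300 > 127, so ct <= 300 is
  // trivially true.
  return constant_answer(Op::LE, Bound::Greater) ? 0 : 1;
}
```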
     /// # Safety
     ///
     /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
     ///   not be dropped until stream is synchronised
@@ -148,12 +303,9 @@
         T: CudaIntegerRadixCiphertext,
         Scalar: DecomposableInto<u64>,
     {
-        if scalar < Scalar::ZERO {
+        if !T::IS_SIGNED && scalar < Scalar::ZERO {
             // ct represents an unsigned (always >= 0)
-            return T::from(
-                self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream)
-                    .ciphertext,
-            );
+            return T::from(self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream));
         }
 
         let message_modulus = self.message_modulus.0;
@@ -169,10 +321,7 @@
             .get(ct.as_ref().d_blocks.lwe_ciphertext_count().0..)
             .is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0));
         if is_scalar_obviously_bigger {
-            return T::from(
-                self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream)
-                    .ciphertext,
-            );
+            return T::from(self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream));
         }
 
         // If we are still here, that means scalar_blocks above
diff --git a/tfhe/src/integer/gpu/server_key/radix/sub.rs b/tfhe/src/integer/gpu/server_key/radix/sub.rs
index 9fb57bbc44..98cfd17c1d 100644
--- a/tfhe/src/integer/gpu/server_key/radix/sub.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/sub.rs
@@ -350,7 +350,9 @@ impl CudaServerKey {
         stream: &CudaStream,
     ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
         let num_blocks = lhs.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
-        let mut tmp: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, stream);
+        let mut tmp = <CudaUnsignedRadixCiphertext as CudaIntegerRadixCiphertext>::from(
+            self.create_trivial_zero_radix(1, stream),
+        );
         if lhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
             && rhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
         {
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
index b162973cde..584706d5eb 100644
--- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
@@ -7,6 +7,7 @@ pub(crate) mod test_neg;
 pub(crate) mod test_rotate;
 pub(crate) mod test_scalar_add;
 pub(crate) mod test_scalar_bitwise_op;
+pub(crate) mod test_scalar_comparison;
 pub(crate) mod test_scalar_mul;
 pub(crate) mod test_scalar_shift;
 pub(crate) mod test_scalar_sub;
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_comparison.rs
index 8c24e76f8c..f78e883a20 100644
--- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_comparison.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_comparison.rs
@@ -31,7 +31,7 @@ macro_rules! define_gpu_signed_scalar_comparison_test_functions {
 }
 fn [<integer_signed_default_scalar_ $comparison_name _ $clear_type:lower>]<P>(param: P)
 where P: Into<PBSParameters> {
-        let num_tests = 1;
+        let num_tests = 10;
         let executor = GpuFunctionExecutor::new(&CudaServerKey::[<scalar_ $comparison_name>]);
         test_signed_default_scalar_function(
             param,