diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index 523f1ae266..d636e059aa 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -36,7 +36,7 @@ enum COMPARISON_TYPE { MAX = 6, MIN = 7, }; -enum IS_RELATIONSHIP { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 }; +enum CMP_ORDERING { IS_INFERIOR = 0, IS_EQUAL = 1, IS_SUPERIOR = 2 }; extern "C" { void scratch_cuda_full_propagation_64( @@ -1846,6 +1846,8 @@ template struct int_tree_sign_reduction_buffer { bool allocate_gpu_memory) { this->params = params; + Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus); + block_selector_f = [](Torus msb, Torus lsb) -> Torus { if (msb == IS_EQUAL) // EQUAL return lsb; @@ -1854,13 +1856,8 @@ template struct int_tree_sign_reduction_buffer { }; if (allocate_gpu_memory) { - tmp_x = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) * - num_radix_blocks * sizeof(Torus), - stream); - tmp_y = (Torus *)cuda_malloc_async((params.big_lwe_dimension + 1) * - num_radix_blocks * sizeof(Torus), - stream); - + tmp_x = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream); + tmp_y = (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream); // LUTs tree_inner_leaf_lut = new int_radix_lut( stream, params, 1, num_radix_blocks, allocate_gpu_memory); @@ -1901,6 +1898,10 @@ template struct int_comparison_diff_buffer { int_tree_sign_reduction_buffer *tree_buffer; + Torus *tmp_signs_a; + Torus *tmp_signs_b; + int_radix_lut *reduce_signs_lut; + int_comparison_diff_buffer(cuda_stream_t *stream, COMPARISON_TYPE op, int_radix_params params, uint32_t num_radix_blocks, bool allocate_gpu_memory) { @@ -1922,7 +1923,6 @@ template struct int_comparison_diff_buffer { return 42; } }; - if (allocate_gpu_memory) { Torus big_size = (params.big_lwe_dimension + 1) * sizeof(Torus); @@ -1935,15 +1935,26 @@ template struct int_comparison_diff_buffer { tree_buffer = new int_tree_sign_reduction_buffer( stream, operator_f, params, num_radix_blocks, allocate_gpu_memory); + tmp_signs_a = + (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream); + tmp_signs_b = + (Torus *)cuda_malloc_async(big_size * num_radix_blocks, stream); + // LUTs + reduce_signs_lut = new int_radix_lut( + stream, params, 1, num_radix_blocks, allocate_gpu_memory); } } void release(cuda_stream_t *stream) { tree_buffer->release(stream); delete tree_buffer; + reduce_signs_lut->release(stream); + delete reduce_signs_lut; cuda_drop_async(tmp_packed_left, stream); cuda_drop_async(tmp_packed_right, stream); + cuda_drop_async(tmp_signs_a, stream); + cuda_drop_async(tmp_signs_b, stream); } }; @@ -1963,6 +1974,7 @@ template struct int_comparison_buffer { Torus *tmp_block_comparisons; Torus *tmp_lwe_array_out; + Torus *tmp_trivial_sign_block; // Scalar EQ / NE Torus *tmp_packed_input; @@ -1975,6 +1987,7 @@ template struct int_comparison_buffer { bool is_signed; // Used for scalar comparisons + int_radix_lut *signed_msb_lut; cuda_stream_t *lsb_stream; cuda_stream_t *msb_stream; @@ -1987,22 +2000,22 @@ template struct int_comparison_buffer { identity_lut_f = [](Torus x) -> Torus { return x; }; + auto big_lwe_size = params.big_lwe_dimension + 1; + if (allocate_gpu_memory) { lsb_stream = cuda_create_stream(stream->gpu_index); msb_stream = cuda_create_stream(stream->gpu_index); + // +1 to have space for signed comparison tmp_lwe_array_out = (Torus *)cuda_malloc_async( - (params.big_lwe_dimension + 1) * num_radix_blocks * 
sizeof(Torus), - stream); + big_lwe_size * (num_radix_blocks + 1) * sizeof(Torus), stream); tmp_packed_input = (Torus *)cuda_malloc_async( - (params.big_lwe_dimension + 1) * 2 * num_radix_blocks * sizeof(Torus), - stream); + big_lwe_size * 2 * num_radix_blocks * sizeof(Torus), stream); // Block comparisons tmp_block_comparisons = (Torus *)cuda_malloc_async( - (params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus), - stream); + big_lwe_size * num_radix_blocks * sizeof(Torus), stream); // Cleaning LUT identity_lut = new int_radix_lut( @@ -2054,13 +2067,19 @@ template struct int_comparison_buffer { } if (is_signed) { + + tmp_trivial_sign_block = + (Torus *)cuda_malloc_async(big_lwe_size * sizeof(Torus), stream); + signed_lut = new int_radix_lut(stream, params, 1, 1, allocate_gpu_memory); + signed_msb_lut = + new int_radix_lut(stream, params, 1, 1, allocate_gpu_memory); auto message_modulus = (int)params.message_modulus; uint32_t sign_bit_pos = log2(message_modulus) - 1; - std::function signed_lut_f; - signed_lut_f = [sign_bit_pos](Torus x, Torus y) -> Torus { + std::function signed_lut_f = + [sign_bit_pos](Torus x, Torus y) -> Torus { auto x_sign_bit = x >> sign_bit_pos; auto y_sign_bit = y >> sign_bit_pos; @@ -2076,14 +2095,14 @@ template struct int_comparison_buffer { return (Torus)(IS_INFERIOR); else if (x == y) return (Torus)(IS_EQUAL); - else if (x > y) + else return (Torus)(IS_SUPERIOR); } else { if (x < y) return (Torus)(IS_SUPERIOR); else if (x == y) return (Torus)(IS_EQUAL); - else if (x > y) + else return (Torus)(IS_INFERIOR); } PANIC("Cuda error: sign_lut creation failed due to wrong function.") @@ -2126,8 +2145,11 @@ template struct int_comparison_buffer { cuda_drop_async(tmp_packed_input, stream); if (is_signed) { + cuda_drop_async(tmp_trivial_sign_block, stream); signed_lut->release(stream); delete (signed_lut); + signed_msb_lut->release(stream); + delete (signed_msb_lut); } cuda_destroy_stream(lsb_stream); cuda_destroy_stream(msb_stream); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh index 5d4a294519..4f89eb8a60 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh @@ -273,7 +273,7 @@ __host__ void host_compare_with_zero_equality( remainder_blocks -= (chunk_size - 1); // Update operands - chunk += chunk_size * big_lwe_size; + chunk += (chunk_size - 1) * big_lwe_size; sum_i += big_lwe_size; } } diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh index 4fd9923be3..b96e8f6e61 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh @@ -587,7 +587,7 @@ __global__ void device_pack_blocks(Torus *lwe_array_out, Torus *lwe_array_in, packed_block[tid] = lsb_block[tid] + factor * msb_block[tid]; } - if (num_radix_blocks % 2 != 0) { + if (num_radix_blocks % 2 == 1) { // We couldn't pack the last block, so we just copy it Torus *lsb_block = lwe_array_in + (num_radix_blocks - 1) * (lwe_dimension + 1); @@ -599,6 +599,36 @@ __global__ void device_pack_blocks(Torus *lwe_array_out, Torus *lwe_array_in, } } +// template +//__global__ void device_pack_blocks(Torus *lwe_array_out, Torus *lwe_array_in, +// uint32_t lwe_dimension, +// uint32_t num_radix_blocks, uint32_t +// factor) { +// int tid = threadIdx.x + blockIdx.x * blockDim.x; +// int bid = tid / (lwe_dimension + 1); 
+// int total_blocks = (num_radix_blocks / 2) + (num_radix_blocks % 2); +// +// if (tid < total_blocks * (lwe_dimension + 1)) { +// +// if (bid < num_radix_blocks / 2) { +// Torus *lsb_block = lwe_array_in + (2 * bid) * (lwe_dimension + 1); +// Torus *msb_block = lsb_block + (lwe_dimension + 1); +// +// Torus *packed_block = lwe_array_out + bid * (lwe_dimension + 1); +// +// packed_block[tid] = lsb_block[tid] + factor * msb_block[tid]; +// }else if (bid == num_radix_blocks / 2) { +// // We can't pack the last block, so we just copy it +// Torus *lsb_block = +// lwe_array_in + (num_radix_blocks - 1) * (lwe_dimension + 1); +// Torus *last_block = +// lwe_array_out + (num_radix_blocks / 2) * (lwe_dimension + 1); +// +// last_block[tid] = lsb_block[tid]; +// } +// } +// } + // Packs the low ciphertext in the message parts of the high ciphertext // and moves the high ciphertext into the carry part. // @@ -684,4 +714,91 @@ __host__ void extract_n_bits(cuda_stream_t *stream, Torus *lwe_array_out, num_radix_blocks * bits_per_block, bit_extract->lut); } +template +__host__ void reduce_signs(cuda_stream_t *stream, Torus *signs_array_out, + Torus *signs_array_in, + int_comparison_buffer *mem_ptr, + std::function sign_handler_f, + void *bsk, Torus *ksk, uint32_t num_sign_blocks) { + + auto diff_buffer = mem_ptr->diff_buffer; + + auto params = mem_ptr->params; + auto big_lwe_dimension = params.big_lwe_dimension; + auto glwe_dimension = params.glwe_dimension; + auto polynomial_size = params.polynomial_size; + auto message_modulus = params.message_modulus; + auto carry_modulus = params.carry_modulus; + + std::function reduce_two_orderings_function = + [diff_buffer, sign_handler_f](Torus x) -> Torus { + int msb = (x >> 2) & 3; + int lsb = x & 3; + + return diff_buffer->tree_buffer->block_selector_f(msb, lsb); + }; + + auto signs_a = diff_buffer->tmp_signs_a; + auto signs_b = diff_buffer->tmp_signs_b; + + cuda_memcpy_async_gpu_to_gpu( + signs_a, signs_array_in, + (big_lwe_dimension + 1) * num_sign_blocks * sizeof(Torus), stream); + if (num_sign_blocks > 2) { + auto lut = diff_buffer->reduce_signs_lut; + generate_device_accumulator( + stream, lut->lut, glwe_dimension, polynomial_size, message_modulus, + carry_modulus, reduce_two_orderings_function); + + while (num_sign_blocks > 2) { + pack_blocks(stream, signs_b, signs_a, big_lwe_dimension, num_sign_blocks, + 4); + integer_radix_apply_univariate_lookup_table_kb( + stream, signs_a, signs_b, bsk, ksk, num_sign_blocks / 2, lut); + + auto last_block_signs_b = + signs_b + (num_sign_blocks / 2) * (big_lwe_dimension + 1); + auto last_block_signs_a = + signs_a + (num_sign_blocks / 2) * (big_lwe_dimension + 1); + if (num_sign_blocks % 2 == 1) + cuda_memcpy_async_gpu_to_gpu(last_block_signs_a, last_block_signs_b, + (big_lwe_dimension + 1) * sizeof(Torus), + stream); + + num_sign_blocks = (num_sign_blocks / 2) + (num_sign_blocks % 2); + } + } + + if (num_sign_blocks == 2) { + std::function final_lut_f = + [reduce_two_orderings_function, sign_handler_f](Torus x) -> Torus { + Torus final_sign = reduce_two_orderings_function(x); + return sign_handler_f(final_sign); + }; + + auto lut = diff_buffer->reduce_signs_lut; + generate_device_accumulator(stream, lut->lut, glwe_dimension, + polynomial_size, message_modulus, + carry_modulus, final_lut_f); + + pack_blocks(stream, signs_b, signs_a, big_lwe_dimension, 2, 4); + integer_radix_apply_univariate_lookup_table_kb(stream, signs_array_out, + signs_b, bsk, ksk, 1, lut); + + } else { + + std::function final_lut_f = + [mem_ptr, 
sign_handler_f](Torus x) -> Torus { + return sign_handler_f(x & 3); + }; + + auto lut = mem_ptr->diff_buffer->reduce_signs_lut; + generate_device_accumulator(stream, lut->lut, glwe_dimension, + polynomial_size, message_modulus, + carry_modulus, final_lut_f); + + integer_radix_apply_univariate_lookup_table_kb(stream, signs_array_out, + signs_a, bsk, ksk, 1, lut); + } +} #endif // TFHE_RS_INTERNAL_INTEGER_CUH diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh index 8a4b0e1616..409493a626 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh @@ -5,7 +5,7 @@ #include template -__host__ void host_integer_radix_scalar_difference_check_kb( +__host__ void integer_radix_unsigned_scalar_difference_check_kb( cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in, Torus *scalar_blocks, int_comparison_buffer *mem_ptr, std::function sign_handler_f, void *bsk, Torus *ksk, @@ -184,6 +184,344 @@ __host__ void host_integer_radix_scalar_difference_check_kb( } } +template +__host__ void integer_radix_signed_scalar_difference_check_kb( + cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in, + Torus *scalar_blocks, int_comparison_buffer *mem_ptr, + std::function sign_handler_f, void *bsk, Torus *ksk, + uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) { + + cudaSetDevice(stream->gpu_index); + auto params = mem_ptr->params; + auto big_lwe_dimension = params.big_lwe_dimension; + auto glwe_dimension = params.glwe_dimension; + auto polynomial_size = params.polynomial_size; + auto message_modulus = params.message_modulus; + auto carry_modulus = params.carry_modulus; + + auto diff_buffer = mem_ptr->diff_buffer; + + size_t big_lwe_size = big_lwe_dimension + 1; + + // Reducing the signs is the bottleneck of the comparison algorithms, + // however if the scalar case there is an improvement: + // + // The idea is to reduce the number of signs block we have to + // reduce. We can do that by splitting the comparison problem in two parts. + // + // - One part where we compute the signs block between the scalar with just + // enough blocks + // from the ciphertext that can represent the scalar value + // + // - The other part is to compare the ciphertext blocks not considered for the + // sign + // computation with zero, and create a single sign block from that. + // + // The smaller the scalar value is compared to the ciphertext num bits + // encrypted, the more the comparisons with zeros we have to do, and the less + // signs block we will have to reduce. 
+ // + // This will create a speedup as comparing a bunch of blocks with 0 + // is faster + if (total_num_scalar_blocks == 0) { + // We only have to compare blocks with zero + // means scalar is zero + Torus *are_all_msb_zeros = mem_ptr->tmp_lwe_array_out; + host_compare_with_zero_equality(stream, are_all_msb_zeros, lwe_array_in, + mem_ptr, bsk, ksk, total_num_radix_blocks, + mem_ptr->is_zero_lut); + Torus *sign_block = + lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size; + + auto sign_bit_pos = (int)std::log2(message_modulus) - 1; + + auto scalar_last_leaf_with_respect_to_zero_lut_f = + [sign_handler_f, sign_bit_pos, + message_modulus](Torus sign_block) -> Torus { + sign_block %= message_modulus; + int sign_bit_is_set = (sign_block >> sign_bit_pos) == 1; + CMP_ORDERING sign_block_ordering; + if (sign_bit_is_set) { + sign_block_ordering = CMP_ORDERING::IS_INFERIOR; + } else if (sign_block != 0) { + sign_block_ordering = CMP_ORDERING::IS_SUPERIOR; + } else { + sign_block_ordering = CMP_ORDERING::IS_EQUAL; + } + + return sign_block_ordering; + }; + + auto block_selector_f = mem_ptr->diff_buffer->tree_buffer->block_selector_f; + auto scalar_bivariate_last_leaf_lut_f = + [scalar_last_leaf_with_respect_to_zero_lut_f, sign_handler_f, + block_selector_f](Torus are_all_zeros, Torus sign_block) -> Torus { + // "re-code" are_all_zeros as an ordering value + if (are_all_zeros == 1) { + are_all_zeros = CMP_ORDERING::IS_EQUAL; + } else { + are_all_zeros = CMP_ORDERING::IS_SUPERIOR; + }; + + return sign_handler_f(block_selector_f( + scalar_last_leaf_with_respect_to_zero_lut_f(sign_block), + are_all_zeros)); + }; + + auto lut = mem_ptr->diff_buffer->tree_buffer->tree_last_leaf_scalar_lut; + generate_device_accumulator_bivariate( + stream, lut->lut, glwe_dimension, polynomial_size, message_modulus, + carry_modulus, scalar_bivariate_last_leaf_lut_f); + + integer_radix_apply_bivariate_lookup_table_kb( + stream, lwe_array_out, are_all_msb_zeros, sign_block, bsk, ksk, 1, lut); + + } else if (total_num_scalar_blocks < total_num_radix_blocks) { + // We have to handle both part of the work described above + // And the sign bit is located in the most_significant_blocks + + uint32_t num_lsb_radix_blocks = total_num_scalar_blocks; + uint32_t num_msb_radix_blocks = + total_num_radix_blocks - num_lsb_radix_blocks; + auto msb = lwe_array_in + num_lsb_radix_blocks * big_lwe_size; + + auto lwe_array_lsb_out = mem_ptr->tmp_lwe_array_out; + auto lwe_array_msb_out = lwe_array_lsb_out + big_lwe_size; + + cuda_synchronize_stream(stream); + auto lsb_stream = mem_ptr->lsb_stream; + auto msb_stream = mem_ptr->msb_stream; + +#pragma omp parallel sections + { + // Both sections may be executed in parallel +#pragma omp section + { + ////////////// + // lsb + Torus *lhs = diff_buffer->tmp_packed_left; + Torus *rhs = diff_buffer->tmp_packed_right; + + pack_blocks(lsb_stream, lhs, lwe_array_in, big_lwe_dimension, + num_lsb_radix_blocks, message_modulus); + pack_blocks(lsb_stream, rhs, scalar_blocks, 0, total_num_scalar_blocks, + message_modulus); + + // From this point we have half number of blocks + num_lsb_radix_blocks /= 2; + num_lsb_radix_blocks += (total_num_scalar_blocks % 2); + + // comparisons will be assigned + // - 0 if lhs < rhs + // - 1 if lhs == rhs + // - 2 if lhs > rhs + + auto comparisons = mem_ptr->tmp_block_comparisons; + scalar_compare_radix_blocks_kb(lsb_stream, comparisons, lhs, rhs, + mem_ptr, bsk, ksk, num_lsb_radix_blocks); + + // Reduces a vec containing radix blocks that encrypts a sign + // 
(inferior, equal, superior) to one single radix block containing the + // final sign + tree_sign_reduction(lsb_stream, lwe_array_lsb_out, comparisons, + mem_ptr->diff_buffer->tree_buffer, + mem_ptr->identity_lut_f, bsk, ksk, + num_lsb_radix_blocks); + } +#pragma omp section + { + ////////////// + // msb + // We remove the last block (which is the sign) + Torus *are_all_msb_zeros = lwe_array_msb_out; + host_compare_with_zero_equality(msb_stream, are_all_msb_zeros, msb, + mem_ptr, bsk, ksk, num_msb_radix_blocks, + mem_ptr->is_zero_lut); + + auto sign_bit_pos = (int)log2(message_modulus) - 1; + + auto lut_f = [mem_ptr, sign_bit_pos](Torus sign_block, + Torus msb_are_zeros) { + bool sign_bit_is_set = (sign_block >> sign_bit_pos) == 1; + CMP_ORDERING sign_block_ordering; + if (sign_bit_is_set) { + sign_block_ordering = CMP_ORDERING::IS_INFERIOR; + } else if (sign_block != 0) { + sign_block_ordering = CMP_ORDERING::IS_SUPERIOR; + } else { + sign_block_ordering = CMP_ORDERING::IS_EQUAL; + } + + CMP_ORDERING msb_ordering; + if (msb_are_zeros == 1) + msb_ordering = CMP_ORDERING::IS_EQUAL; + else + msb_ordering = CMP_ORDERING::IS_SUPERIOR; + + return mem_ptr->diff_buffer->tree_buffer->block_selector_f( + sign_block_ordering, msb_ordering); + }; + + auto signed_msb_lut = mem_ptr->signed_msb_lut; + generate_device_accumulator_bivariate( + msb_stream, signed_msb_lut->lut, params.glwe_dimension, + params.polynomial_size, params.message_modulus, + params.carry_modulus, lut_f); + + Torus *sign_block = msb + (num_msb_radix_blocks - 1) * big_lwe_size; + integer_radix_apply_bivariate_lookup_table_kb( + msb_stream, lwe_array_msb_out, sign_block, are_all_msb_zeros, bsk, + ksk, 1, signed_msb_lut); + } + } + cuda_synchronize_stream(lsb_stream); + cuda_synchronize_stream(msb_stream); + + ////////////// + // Reduce the two blocks into one final + reduce_signs(stream, lwe_array_out, lwe_array_lsb_out, mem_ptr, + sign_handler_f, bsk, ksk, 2); + + } else { + // We only have to do the regular comparison + // And not the part where we compare most significant blocks with zeros + // total_num_radix_blocks == total_num_scalar_blocks + uint32_t num_lsb_radix_blocks = total_num_radix_blocks; + + cuda_synchronize_stream(stream); + auto lsb_stream = mem_ptr->lsb_stream; + auto msb_stream = mem_ptr->msb_stream; + + auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out; + auto lwe_array_sign_out = + lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size; +#pragma omp parallel sections + { + // Both sections may be executed in parallel +#pragma omp section + { + Torus *lhs = diff_buffer->tmp_packed_left; + Torus *rhs = diff_buffer->tmp_packed_right; + + pack_blocks(lsb_stream, lhs, lwe_array_in, big_lwe_dimension, + num_lsb_radix_blocks - 1, message_modulus); + pack_blocks(lsb_stream, rhs, scalar_blocks, 0, num_lsb_radix_blocks - 1, + message_modulus); + + // From this point we have half number of blocks + num_lsb_radix_blocks /= 2; + + // comparisons will be assigned + // - 0 if lhs < rhs + // - 1 if lhs == rhs + // - 2 if lhs > rhs + scalar_compare_radix_blocks_kb(lsb_stream, lwe_array_ct_out, lhs, rhs, + mem_ptr, bsk, ksk, num_lsb_radix_blocks); + } +#pragma omp section + { + Torus *encrypted_sign_block = + lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size; + Torus *scalar_sign_block = + scalar_blocks + (total_num_scalar_blocks - 1); + + auto trivial_sign_block = mem_ptr->tmp_trivial_sign_block; + create_trivial_radix(msb_stream, trivial_sign_block, scalar_sign_block, + big_lwe_dimension, 1, 1, message_modulus, + 
carry_modulus); + + integer_radix_apply_bivariate_lookup_table_kb( + msb_stream, lwe_array_sign_out, encrypted_sign_block, + trivial_sign_block, bsk, ksk, 1, mem_ptr->signed_lut); + } + } + cuda_synchronize_stream(lsb_stream); + cuda_synchronize_stream(msb_stream); + + // Reduces a vec containing radix blocks that encrypts a sign + // (inferior, equal, superior) to one single radix block containing the + // final sign + reduce_signs(stream, lwe_array_out, lwe_array_ct_out, mem_ptr, + sign_handler_f, bsk, ksk, num_lsb_radix_blocks + 1); + } +} + +template +__host__ void integer_radix_signed_scalar_maxmin_kb( + cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in, + Torus *scalar_blocks, int_comparison_buffer *mem_ptr, void *bsk, + Torus *ksk, uint32_t total_num_radix_blocks, + uint32_t total_num_scalar_blocks) { + + cudaSetDevice(stream->gpu_index); + auto params = mem_ptr->params; + // Calculates the difference sign between the ciphertext and the scalar + // - 0 if lhs < rhs + // - 1 if lhs == rhs + // - 2 if lhs > rhs + auto sign = mem_ptr->tmp_lwe_array_out; + integer_radix_signed_scalar_difference_check_kb( + stream, sign, lwe_array_in, scalar_blocks, mem_ptr, + mem_ptr->identity_lut_f, bsk, ksk, total_num_radix_blocks, + total_num_scalar_blocks); + + // There is no optimized CMUX for scalars, so we convert to a trivial + // ciphertext + auto lwe_array_left = lwe_array_in; + auto lwe_array_right = mem_ptr->tmp_block_comparisons; + + create_trivial_radix(stream, lwe_array_right, scalar_blocks, + params.big_lwe_dimension, total_num_radix_blocks, + total_num_scalar_blocks, params.message_modulus, + params.carry_modulus); + + // Selector + // CMUX for Max or Min + host_integer_radix_cmux_kb(stream, lwe_array_out, sign, lwe_array_left, + lwe_array_right, mem_ptr->cmux_buffer, bsk, ksk, + total_num_radix_blocks); +} + +template +__host__ void host_integer_radix_scalar_difference_check_kb( + cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in, + Torus *scalar_blocks, int_comparison_buffer *mem_ptr, + std::function sign_handler_f, void *bsk, Torus *ksk, + uint32_t total_num_radix_blocks, uint32_t total_num_scalar_blocks) { + + if (mem_ptr->is_signed) { + // is signed and scalar is positive + integer_radix_signed_scalar_difference_check_kb( + stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, + sign_handler_f, bsk, ksk, total_num_radix_blocks, + total_num_scalar_blocks); + } else { + integer_radix_unsigned_scalar_difference_check_kb( + stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, + sign_handler_f, bsk, ksk, total_num_radix_blocks, + total_num_scalar_blocks); + } +} + +template +__host__ void host_integer_radix_signed_scalar_maxmin_kb( + cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_array_in, + Torus *scalar_blocks, int_comparison_buffer *mem_ptr, void *bsk, + Torus *ksk, uint32_t total_num_radix_blocks, + uint32_t total_num_scalar_blocks) { + + if (mem_ptr->is_signed) { + // is signed and scalar is positive + integer_radix_signed_scalar_maxmin_kb( + stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, bsk, ksk, + total_num_radix_blocks, total_num_scalar_blocks); + } else { + integer_radix_unsigned_scalar_maxmin_kb( + stream, lwe_array_out, lwe_array_in, scalar_blocks, mem_ptr, bsk, ksk, + total_num_radix_blocks, total_num_scalar_blocks); + } +} + template __host__ void scalar_compare_radix_blocks_kb(cuda_stream_t *stream, Torus *lwe_array_out, diff --git a/tfhe/benches/integer/signed_bench.rs 
b/tfhe/benches/integer/signed_bench.rs index 114212c650..64e8ecdd0a 100644 --- a/tfhe/benches/integer/signed_bench.rs +++ b/tfhe/benches/integer/signed_bench.rs @@ -1795,6 +1795,54 @@ mod cuda { rng_func: shift_scalar ); + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_eq, + display_name: eq, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_ne, + display_name: ne, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_gt, + display_name: gt, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_ge, + display_name: ge, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_lt, + display_name: lt, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_le, + display_name: le, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_min, + display_name: min, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_max, + display_name: max, + rng_func: default_signed_scalar + ); + //=========================================== // Default //=========================================== @@ -1959,6 +2007,54 @@ mod cuda { rng_func: shift_scalar ); + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_eq, + display_name: eq, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_ne, + display_name: ne, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_gt, + display_name: gt, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_ge, + display_name: ge, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_lt, + display_name: lt, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_le, + display_name: le, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_min, + display_name: min, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_max, + display_name: max, + rng_func: default_signed_scalar + ); + criterion_group!( unchecked_cuda_ops, cuda_unchecked_add, @@ -1995,6 +2091,14 @@ mod cuda { cuda_unchecked_scalar_right_shift, cuda_unchecked_scalar_rotate_left, cuda_unchecked_scalar_rotate_right, + cuda_unchecked_scalar_eq, + cuda_unchecked_scalar_ne, + cuda_unchecked_scalar_gt, + cuda_unchecked_scalar_ge, + cuda_unchecked_scalar_lt, + cuda_unchecked_scalar_le, + cuda_unchecked_scalar_min, + cuda_unchecked_scalar_max, ); criterion_group!( @@ -2034,6 +2138,14 @@ mod cuda { cuda_scalar_right_shift, cuda_scalar_rotate_left, cuda_scalar_rotate_right, + cuda_scalar_eq, + cuda_scalar_ne, + cuda_scalar_gt, + cuda_scalar_ge, + cuda_scalar_lt, + cuda_scalar_le, + cuda_scalar_min, + 
cuda_scalar_max, ); } diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs index e12416b855..6fb205e9ac 100644 --- a/tfhe/src/high_level_api/booleans/base.rs +++ b/tfhe/src/high_level_api/booleans/base.rs @@ -337,7 +337,7 @@ impl FheEq for FheBool { let inner = cuda_key .key - .scalar_eq(&self.ciphertext.on_gpu(), u8::from(other), stream); + .scalar_eq(&*self.ciphertext.on_gpu(), u8::from(other), stream); InnerBoolean::Cuda(inner) }), }); @@ -376,7 +376,7 @@ impl FheEq for FheBool { let inner = cuda_key .key - .scalar_ne(&self.ciphertext.on_gpu(), u8::from(other), stream); + .scalar_ne(&*self.ciphertext.on_gpu(), u8::from(other), stream); InnerBoolean::Cuda(inner) }), }); diff --git a/tfhe/src/high_level_api/booleans/encrypt.rs b/tfhe/src/high_level_api/booleans/encrypt.rs index ac348bf012..d0d109c975 100644 --- a/tfhe/src/high_level_api/booleans/encrypt.rs +++ b/tfhe/src/high_level_api/booleans/encrypt.rs @@ -107,9 +107,7 @@ impl FheTryTrivialEncrypt for FheBool { let inner = cuda_key .key .create_trivial_radix(u64::from(value), 1, stream); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( - inner.ciphertext, - )) + InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext(inner)) }), }); Ok(Self::new(ciphertext)) diff --git a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs index 83b69f7347..d18f559238 100644 --- a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs +++ b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs @@ -5,6 +5,8 @@ use crate::high_level_api::global_state::with_thread_local_cuda_stream; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom}; +#[cfg(feature = "gpu")] +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; use crate::prelude::{FheDecrypt, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt}; use crate::{ClientKey, CompactPublicKey, CompressedPublicKey, FheUint, PublicKey}; @@ -133,11 +135,14 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { - let inner = cuda_key.key.create_trivial_radix( - value, - Id::num_blocks(cuda_key.key.message_modulus), - stream, - ); + let inner: CudaUnsignedRadixCiphertext = + ::from( + cuda_key.key.create_trivial_radix( + value, + Id::num_blocks(cuda_key.key.message_modulus), + stream, + ), + ); Ok(Self::new(inner)) }), }) diff --git a/tfhe/src/high_level_api/integers/unsigned/ops.rs b/tfhe/src/high_level_api/integers/unsigned/ops.rs index 7265436fd2..b6e2197060 100644 --- a/tfhe/src/high_level_api/integers/unsigned/ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/ops.rs @@ -15,6 +15,8 @@ use crate::high_level_api::traits::{ }; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; +#[cfg(feature = "gpu")] +use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; use crate::{FheBool, FheUint}; use std::borrow::Borrow; use std::ops::{ @@ -78,12 +80,12 @@ where let inner = cuda_key .key .sum_ciphertexts(cts, stream) - .unwrap_or_else(|| { - cuda_key.key.create_trivial_radix( + .unwrap_or_else(|| CudaUnsignedRadixCiphertext { + ciphertext: cuda_key.key.create_trivial_radix( 0, Id::num_blocks(cuda_key.message_modulus()), stream, - ) + ), }); Self::new(inner) }), @@ -166,12 +168,12 @@ where let inner = cuda_key .key 
.sum_ciphertexts(cts, stream) - .unwrap_or_else(|| { - cuda_key.key.create_trivial_radix( + .unwrap_or_else(|| CudaUnsignedRadixCiphertext { + ciphertext: cuda_key.key.create_trivial_radix( 0, Id::num_blocks(cuda_key.message_modulus()), stream, - ) + ), }); Self::new(inner) }) diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index c0bab8a68b..6c511afb81 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -17,6 +17,8 @@ use crate::high_level_api::traits::{ }; use crate::integer::block_decomposition::DecomposableInto; use crate::integer::ciphertext::IntegerCiphertext; +#[cfg(feature = "gpu")] +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; use crate::integer::U256; use crate::FheBool; use std::ops::{ @@ -57,9 +59,12 @@ where FheBool::new(inner_result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support equality with clear"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { + let inner_result = cuda_key + .key + .scalar_eq(&*self.ciphertext.on_gpu(), rhs, stream); + FheBool::new(inner_result) + }), }) } @@ -91,9 +96,12 @@ where FheBool::new(inner_result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - todo!("cuda devices do not support difference with clear") - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { + let inner_result = cuda_key + .key + .scalar_ne(&*self.ciphertext.on_gpu(), rhs, stream); + FheBool::new(inner_result) + }), }) } } @@ -134,7 +142,7 @@ where InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { let inner_result = cuda_key .key - .scalar_lt(&self.ciphertext.on_gpu(), rhs, stream); + .scalar_lt(&*self.ciphertext.on_gpu(), rhs, stream); FheBool::new(inner_result) }), }) @@ -171,7 +179,7 @@ where InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { let inner_result = cuda_key .key - .scalar_le(&self.ciphertext.on_gpu(), rhs, stream); + .scalar_le(&*self.ciphertext.on_gpu(), rhs, stream); FheBool::new(inner_result) }), }) @@ -208,7 +216,7 @@ where InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { let inner_result = cuda_key .key - .scalar_gt(&self.ciphertext.on_gpu(), rhs, stream); + .scalar_gt(&*self.ciphertext.on_gpu(), rhs, stream); FheBool::new(inner_result) }), }) @@ -245,7 +253,7 @@ where InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { let inner_result = cuda_key .key - .scalar_ge(&self.ciphertext.on_gpu(), rhs, stream); + .scalar_ge(&*self.ciphertext.on_gpu(), rhs, stream); FheBool::new(inner_result) }), }) @@ -290,7 +298,7 @@ where InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { let inner_result = cuda_key .key - .scalar_max(&self.ciphertext.on_gpu(), rhs, stream); + .scalar_max(&*self.ciphertext.on_gpu(), rhs, stream); Self::new(inner_result) }), }) @@ -335,7 +343,7 @@ where InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| { let inner_result = cuda_key .key - .scalar_min(&self.ciphertext.on_gpu(), rhs, stream); + .scalar_min(&*self.ciphertext.on_gpu(), rhs, stream); Self::new(inner_result) }), }) @@ -1026,7 +1034,9 @@ generic_integer_impl_scalar_left_operation!( #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_stream(|stream| { - let 
mut result = cuda_key.key.create_trivial_radix(lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), stream); + let mut result = ::from(cuda_key.key.create_trivial_radix(lhs, rhs + .ciphertext.on_gpu().ciphertext.info.blocks.len(), stream)); cuda_key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(), stream); RadixCiphertext::Cuda(result) }) diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 20726c06df..758c34a976 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -1136,7 +1136,7 @@ impl CudaStream { num_blocks: u32, num_scalar_blocks: u32, op: ComparisonType, - is_signed: bool, + signed_with_positive_scalar: bool, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); scratch_cuda_integer_radix_comparison_kb_64( @@ -1156,7 +1156,7 @@ impl CudaStream { carry_modulus.0 as u32, PBSType::Classical as u32, op as u32, - is_signed, + signed_with_positive_scalar, true, ); @@ -1203,7 +1203,7 @@ impl CudaStream { num_blocks: u32, num_scalar_blocks: u32, op: ComparisonType, - is_signed: bool, + signed_with_positive_scalar: bool, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); scratch_cuda_integer_radix_comparison_kb_64( @@ -1223,7 +1223,7 @@ impl CudaStream { carry_modulus.0 as u32, PBSType::MultiBit as u32, op as u32, - is_signed, + signed_with_positive_scalar, true, ); cuda_scalar_comparison_integer_radix_ciphertext_kb_64( diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index 385aa029d1..17d4dddc8e 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -71,7 +71,7 @@ impl CudaServerKey { num_blocks: usize, stream: &CudaStream, ) -> T { - T::from(self.create_trivial_radix(0, num_blocks, stream).ciphertext) + T::from(self.create_trivial_radix(0, num_blocks, stream)) } /// Create a trivial ciphertext on the GPU @@ -105,7 +105,7 @@ impl CudaServerKey { scalar: Scalar, num_blocks: usize, stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + ) -> CudaRadixCiphertext where Scalar: DecomposableInto, { @@ -140,11 +140,9 @@ impl CudaServerKey { let d_blocks = CudaLweCiphertextList::from_lwe_ciphertext_list(&cpu_lwe_list, stream); - CudaUnsignedRadixCiphertext { - ciphertext: CudaRadixCiphertext { - d_blocks, - info: CudaRadixCiphertextInfo { blocks: info }, - }, + CudaRadixCiphertext { + d_blocks, + info: CudaRadixCiphertextInfo { blocks: info }, } } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs index 76166b0026..7c090b852e 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs @@ -5,36 +5,121 @@ use crate::core_crypto::prelude::{CiphertextModulus, LweCiphertextCount}; use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo; -use crate::integer::gpu::ciphertext::{ - CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaUnsignedRadixCiphertext, -}; +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext}; use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; use crate::integer::gpu::ComparisonType; use crate::integer::server_key::comparator::Comparator; use crate::shortint::ciphertext::Degree; impl CudaServerKey { + /// Returns whether the clear scalar is 
outside of the + /// value range the ciphertext can hold. + /// + /// - Returns None if the scalar is in the range of values that the ciphertext can represent + /// + /// - Returns Some(ordering) when the scalar is out of representable range of the ciphertext. + /// - Equal will never be returned + /// - Less means the scalar is less than the min value representable by the ciphertext + /// - Greater means the scalar is greater that the max value representable by the ciphertext + pub(crate) fn is_scalar_out_of_bounds( + &self, + ct: &T, + scalar: Scalar, + ) -> Option + where + T: CudaIntegerRadixCiphertext, + Scalar: DecomposableInto, + { + let scalar_blocks = + BlockDecomposer::with_early_stop_at_zero(scalar, self.message_modulus.0.ilog2()) + .iter_as::() + .collect::>(); + + let ct_len = ct.as_ref().d_blocks.lwe_ciphertext_count(); + + if T::IS_SIGNED { + let sign_bit_pos = self.message_modulus.0.ilog2() - 1; + let sign_bit_is_set = scalar_blocks + .get(ct_len.0 - 1) + .map_or(false, |block| (block >> sign_bit_pos) == 1); + + if scalar > Scalar::ZERO + && (scalar_blocks.len() > ct_len.0 + || (scalar_blocks.len() == ct_len.0 && sign_bit_is_set)) + { + // If scalar is positive and that any bits above the ct's n-1 bits is set + // it means scalar is bigger. + // + // This is checked in two step + // - If there a more scalar blocks than ct blocks then ct is trivially bigger + // - If there are the same number of blocks but the "sign bit" / msb of st scalar is + // set then, the scalar is trivially bigger + return Some(std::cmp::Ordering::Greater); + } else if scalar < Scalar::ZERO { + // If scalar is negative, and that any bits above the ct's n-1 bits is not set + // it means scalar is smaller. + + if ct_len.0 > scalar_blocks.len() { + // Ciphertext has more blocks, the scalar may be in range + return None; + } + + // (returns false for empty iter) + let at_least_one_block_is_not_full_of_1s = scalar_blocks[ct_len.0..] 
+ .iter() + .any(|&scalar_block| scalar_block != (self.message_modulus.0 as u64 - 1)); + + let sign_bit_pos = self.message_modulus.0.ilog2() - 1; + let sign_bit_is_unset = scalar_blocks + .get(ct_len.0 - 1) + .map_or(false, |block| (block >> sign_bit_pos) == 0); + + if at_least_one_block_is_not_full_of_1s || sign_bit_is_unset { + // Scalar is smaller than lowest value of T + return Some(std::cmp::Ordering::Less); + } + } + } else { + // T is unsigned + if scalar < Scalar::ZERO { + // ct represent an unsigned (always >= 0) + return Some(std::cmp::Ordering::Less); + } else if scalar > Scalar::ZERO { + // scalar is obviously bigger if it has non-zero + // blocks after lhs's last block + let is_scalar_obviously_bigger = + scalar_blocks.get(ct_len.0..).is_some_and(|sub_slice| { + sub_slice.iter().any(|&scalar_block| scalar_block != 0) + }); + if is_scalar_obviously_bigger { + return Some(std::cmp::Ordering::Greater); + } + } + } + + None + } + /// # Safety /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_comparison_async( + pub unsafe fn unchecked_signed_and_unsigned_scalar_comparison_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, op: ComparisonType, + signed_with_positive_scalar: bool, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { - if scalar < T::ZERO { + if scalar < Scalar::ZERO { // ct represents an unsigned (always >= 0) let ct_res = self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream); - return CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new( - ct_res.ciphertext.d_blocks, - ct_res.ciphertext.info, - )); + return CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res); } let message_modulus = self.message_modulus.0; @@ -51,10 +136,7 @@ impl CudaServerKey { .is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0)); if is_scalar_obviously_bigger { let ct_res = self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream); - return CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new( - ct_res.ciphertext.d_blocks, - ct_res.ciphertext.info, - )); + return CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res); } // If we are still here, that means scalar_blocks above @@ -105,7 +187,7 @@ impl CudaServerKey { lwe_ciphertext_count.0 as u32, scalar_blocks.len() as u32, op, - false, + signed_with_positive_scalar, ); } CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { @@ -133,7 +215,7 @@ impl CudaServerKey { lwe_ciphertext_count.0 as u32, scalar_blocks.len() as u32, op, - false, + signed_with_positive_scalar, ); } } @@ -145,49 +227,98 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_minmax_async( + pub unsafe fn unchecked_scalar_comparison_async( &self, - ct: &CudaUnsignedRadixCiphertext, + ct: &T, scalar: Scalar, op: ComparisonType, stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + ) -> CudaBooleanBlock where Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { - if scalar < Scalar::ZERO { - // ct represents an unsigned (always >= 0) - return self.create_trivial_radix(Comparator::IS_SUPERIOR, 1, stream); - } + let num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0; + + if T::IS_SIGNED { + match 
self.is_scalar_out_of_bounds(ct, scalar) { + Some(std::cmp::Ordering::Greater) => { + // Scalar is greater than the bounds, so ciphertext is smaller + let result = match op { + ComparisonType::LT | ComparisonType::LE => { + self.create_trivial_radix(1, num_blocks, stream) + } + _ => self.create_trivial_radix( + 0, + ct.as_ref().d_blocks.lwe_ciphertext_count().0, + stream, + ), + }; + return CudaBooleanBlock::from_cuda_radix_ciphertext(result); + } + Some(std::cmp::Ordering::Less) => { + // Scalar is smaller than the bounds, so ciphertext is bigger + let result = match op { + ComparisonType::GT | ComparisonType::GE => { + self.create_trivial_radix(1, num_blocks, stream) + } + _ => self.create_trivial_radix( + 0, + ct.as_ref().d_blocks.lwe_ciphertext_count().0, + stream, + ), + }; + return CudaBooleanBlock::from_cuda_radix_ciphertext(result); + } + Some(std::cmp::Ordering::Equal) => unreachable!("Internal error: invalid value"), + None => { + // scalar is in range, fallthrough + } + } + if scalar >= Scalar::ZERO { + self.unchecked_signed_and_unsigned_scalar_comparison_async( + ct, scalar, op, true, stream, + ) + } else { + let scalar_as_trivial = + T::from(self.create_trivial_radix(scalar, num_blocks, stream)); + self.unchecked_comparison_async(ct, &scalar_as_trivial, op, stream) + } + } else { + // Unsigned + self.unchecked_signed_and_unsigned_scalar_comparison_async( + ct, scalar, op, false, stream, + ) + } + } + /// # Safety + /// + /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until stream is synchronised + pub unsafe fn unchecked_scalar_minmax_async( + &self, + ct: &T, + scalar: Scalar, + op: ComparisonType, + stream: &CudaStream, + ) -> T + where + T: CudaIntegerRadixCiphertext, + Scalar: DecomposableInto, + { let message_modulus = self.message_modulus.0; - let mut scalar_blocks = + let scalar_blocks = BlockDecomposer::with_early_stop_at_zero(scalar, message_modulus.ilog2()) .iter_as::() .collect::>(); - // scalar is obviously bigger if it has non-zero - // blocks after lhs's last block - let is_scalar_obviously_bigger = scalar_blocks - .get(ct.as_ref().d_blocks.lwe_ciphertext_count().0..) - .is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0)); - if is_scalar_obviously_bigger { - return self.create_trivial_radix(Comparator::IS_INFERIOR, 1, stream); - } - - // If we are still here, that means scalar_blocks above - // num_blocks are 0s, we can remove them - // as we will handle them separately. 
- scalar_blocks.truncate(ct.as_ref().d_blocks.lwe_ciphertext_count().0); - let d_scalar_blocks: CudaVec = CudaVec::from_cpu_async(&scalar_blocks, stream); let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count(); - let mut result = CudaUnsignedRadixCiphertext { - ciphertext: ct.as_ref().duplicate_async(stream), - }; + let mut result = ct.duplicate_async(stream); match &self.bootstrapping_key { CudaBootstrappingKey::Classic(d_bsk) => { @@ -214,7 +345,7 @@ impl CudaServerKey { lwe_ciphertext_count.0 as u32, scalar_blocks.len() as u32, op, - false, + T::IS_SIGNED, ); } CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { @@ -242,7 +373,7 @@ impl CudaServerKey { lwe_ciphertext_count.0 as u32, scalar_blocks.len() as u32, op, - false, + T::IS_SIGNED, ); } } @@ -254,26 +385,28 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_eq_async( + pub unsafe fn unchecked_scalar_eq_async( &self, - ct: &CudaUnsignedRadixCiphertext, + ct: &T, scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where + T: CudaIntegerRadixCiphertext, Scalar: DecomposableInto, { self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::EQ, stream) } - pub fn unchecked_scalar_eq( + pub fn unchecked_scalar_eq( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + T: CudaIntegerRadixCiphertext, + Scalar: DecomposableInto, { let result = unsafe { self.unchecked_scalar_eq_async(ct, scalar, stream) }; stream.synchronize(); @@ -284,14 +417,15 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn scalar_eq_async( + pub unsafe fn scalar_eq_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + T: CudaIntegerRadixCiphertext, + Scalar: DecomposableInto, { let mut tmp_lhs; let lhs = if ct.block_carries_are_empty() { @@ -346,14 +480,15 @@ impl CudaServerKey { /// let dec_result = cks.decrypt_bool(&ct_res); /// assert_eq!(dec_result, msg1 == msg2); /// ``` - pub fn scalar_eq( + pub fn scalar_eq( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + T: CudaIntegerRadixCiphertext, + Scalar: DecomposableInto, { let result = unsafe { self.scalar_eq_async(ct, scalar, stream) }; stream.synchronize(); @@ -364,14 +499,15 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn scalar_ne_async( + pub unsafe fn scalar_ne_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + T: CudaIntegerRadixCiphertext, + Scalar: DecomposableInto, { let mut tmp_lhs; let lhs = if ct.block_carries_are_empty() { @@ -426,14 +562,15 @@ impl CudaServerKey { /// let dec_result = cks.decrypt_bool(&ct_res); /// assert_eq!(dec_result, msg1 != msg2); /// ``` - pub fn scalar_ne( + pub fn scalar_ne( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> 
CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let result = unsafe { self.scalar_ne_async(ct, scalar, stream) }; stream.synchronize(); @@ -444,26 +581,28 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_ne_async( + pub unsafe fn unchecked_scalar_ne_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + T: CudaIntegerRadixCiphertext, + Scalar: DecomposableInto, { self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::NE, stream) } - pub fn unchecked_scalar_ne( + pub fn unchecked_scalar_ne( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + T: CudaIntegerRadixCiphertext, + Scalar: DecomposableInto, { let result = unsafe { self.unchecked_scalar_ne_async(ct, scalar, stream) }; stream.synchronize(); @@ -474,26 +613,28 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_gt_async( + pub unsafe fn unchecked_scalar_gt_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GT, stream) } - pub fn unchecked_scalar_gt( + pub fn unchecked_scalar_gt( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let result = unsafe { self.unchecked_scalar_gt_async(ct, scalar, stream) }; stream.synchronize(); @@ -504,26 +645,28 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_ge_async( + pub unsafe fn unchecked_scalar_ge_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GE, stream) } - pub fn unchecked_scalar_ge( + pub fn unchecked_scalar_ge( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let result = unsafe { self.unchecked_scalar_ge_async(ct, scalar, stream) }; stream.synchronize(); @@ -534,26 +677,28 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_lt_async( + pub unsafe fn unchecked_scalar_lt_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { 
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LT, stream) } - pub fn unchecked_scalar_lt( + pub fn unchecked_scalar_lt( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let result = unsafe { self.unchecked_scalar_lt_async(ct, scalar, stream) }; stream.synchronize(); @@ -564,26 +709,28 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_le_async( + pub unsafe fn unchecked_scalar_le_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LE, stream) } - pub fn unchecked_scalar_le( + pub fn unchecked_scalar_le( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let result = unsafe { self.unchecked_scalar_le_async(ct, scalar, stream) }; stream.synchronize(); @@ -593,14 +740,15 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn scalar_gt_async( + pub unsafe fn scalar_gt_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let mut tmp_lhs; let lhs = if ct.block_carries_are_empty() { @@ -614,14 +762,15 @@ impl CudaServerKey { self.unchecked_scalar_gt_async(lhs, scalar, stream) } - pub fn scalar_gt( + pub fn scalar_gt( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let result = unsafe { self.scalar_gt_async(ct, scalar, stream) }; stream.synchronize(); @@ -632,14 +781,15 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn scalar_ge_async( + pub unsafe fn scalar_ge_async( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let mut tmp_lhs; let lhs = if ct.block_carries_are_empty() { @@ -653,14 +803,15 @@ impl CudaServerKey { self.unchecked_scalar_ge_async(lhs, scalar, stream) } - pub fn scalar_ge( + pub fn scalar_ge( &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, + ct: &T, + scalar: Scalar, stream: &CudaStream, ) -> CudaBooleanBlock where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let result = unsafe { self.scalar_ge_async(ct, scalar, stream) }; stream.synchronize(); @@ -671,14 +822,15 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is 
synchronised
-    pub unsafe fn scalar_lt_async<T>(
+    pub unsafe fn scalar_lt_async<T, Scalar>(
         &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &T,
+        scalar: Scalar,
         stream: &CudaStream,
     ) -> CudaBooleanBlock
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let mut tmp_lhs;
         let lhs = if ct.block_carries_are_empty() {
@@ -692,14 +844,15 @@ impl CudaServerKey {
         self.unchecked_scalar_lt_async(lhs, scalar, stream)
     }

-    pub fn scalar_lt<T>(
+    pub fn scalar_lt<T, Scalar>(
         &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &T,
+        scalar: Scalar,
         stream: &CudaStream,
     ) -> CudaBooleanBlock
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let result = unsafe { self.scalar_lt_async(ct, scalar, stream) };
         stream.synchronize();
@@ -709,14 +862,15 @@ impl CudaServerKey {
     ///
     /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
     ///   not be dropped until stream is synchronised
-    pub unsafe fn scalar_le_async<T>(
+    pub unsafe fn scalar_le_async<T, Scalar>(
         &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &T,
+        scalar: Scalar,
         stream: &CudaStream,
     ) -> CudaBooleanBlock
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let mut tmp_lhs;
         let lhs = if ct.block_carries_are_empty() {
@@ -730,14 +884,15 @@ impl CudaServerKey {
         self.unchecked_scalar_le_async(lhs, scalar, stream)
     }

-    pub fn scalar_le<T>(
+    pub fn scalar_le<T, Scalar>(
         &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &T,
+        scalar: Scalar,
         stream: &CudaStream,
     ) -> CudaBooleanBlock
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let result = unsafe { self.scalar_le_async(ct, scalar, stream) };
         stream.synchronize();
@@ -748,26 +903,23 @@ impl CudaServerKey {
     ///
     /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
     ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_scalar_max_async<T>(
+    pub unsafe fn unchecked_scalar_max_async<T, Scalar>(
         &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &T,
+        scalar: Scalar,
         stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    ) -> T
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MAX, stream)
     }

-    pub fn unchecked_scalar_max<T>(
-        &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
-        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    pub fn unchecked_scalar_max<T, Scalar>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let result = unsafe { self.unchecked_scalar_max_async(ct, scalar, stream) };
         stream.synchronize();
@@ -778,26 +930,23 @@ impl CudaServerKey {
     ///
     /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
     ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_scalar_min_async<Scalar>(
+    pub unsafe fn unchecked_scalar_min_async<T, Scalar>(
         &self,
-        ct: &CudaUnsignedRadixCiphertext,
+        ct: &T,
         scalar: Scalar,
         stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    ) -> T
     where
         Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MIN, stream)
     }

-    pub fn unchecked_scalar_min<Scalar>(
-        &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: Scalar,
-        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    pub fn unchecked_scalar_min<T, Scalar>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
     where
         Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let result = unsafe { self.unchecked_scalar_min_async(ct, scalar, stream) };
         stream.synchronize();
@@ -808,14 +957,15 @@ impl CudaServerKey {
     ///
     /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
     ///   not be dropped until stream is synchronised
-    pub unsafe fn scalar_max_async<T>(
+    pub unsafe fn scalar_max_async<T, Scalar>(
         &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &T,
+        scalar: Scalar,
         stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    ) -> T
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let mut tmp_lhs;
         let lhs = if ct.block_carries_are_empty() {
@@ -829,14 +979,10 @@ impl CudaServerKey {
         self.unchecked_scalar_max_async(lhs, scalar, stream)
     }

-    pub fn scalar_max<T>(
-        &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
-        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    pub fn scalar_max<T, Scalar>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let result = unsafe { self.scalar_max_async(ct, scalar, stream) };
         stream.synchronize();
@@ -847,14 +993,15 @@ impl CudaServerKey {
     ///
     /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
     ///   not be dropped until stream is synchronised
-    pub unsafe fn scalar_min_async<T>(
+    pub unsafe fn scalar_min_async<T, Scalar>(
         &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &T,
+        scalar: Scalar,
         stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    ) -> T
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let mut tmp_lhs;
         let lhs = if ct.block_carries_are_empty() {
@@ -868,14 +1015,10 @@ impl CudaServerKey {
         self.unchecked_scalar_min_async(lhs, scalar, stream)
     }

-    pub fn scalar_min<T>(
-        &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
-        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    pub fn scalar_min<T, Scalar>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
     where
-        T: DecomposableInto<u64>,
+        Scalar: DecomposableInto<u64>,
+        T: CudaIntegerRadixCiphertext,
     {
         let result = unsafe { self.scalar_min_async(ct, scalar, stream) };
         stream.synchronize();
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
index d8f3f86409..4c486681b7 100644
--- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
@@ -7,6 +7,7 @@ pub(crate) mod test_neg;
 pub(crate) mod test_rotate;
 pub(crate) mod test_scalar_add;
 pub(crate) mod test_scalar_bitwise_op;
+pub(crate) mod test_scalar_comparison;
 pub(crate) mod test_scalar_mul;
 pub(crate) mod test_scalar_rotate;
 pub(crate) mod test_scalar_shift;
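The hunks above make every scalar comparison entry point generic over the ciphertext type (`T: CudaIntegerRadixCiphertext`) while the scalar gets its own `Scalar: DecomposableInto<u64>` parameter, so the same methods now accept both unsigned and signed radix ciphertexts. A rough caller-side sketch, not part of this patch; it assumes `gen_keys_gpu`, `CudaSignedRadixCiphertext::from_signed_radix_ciphertext` and the decryption helpers are reachable from the calling module, as they are in the tests further down:

use crate::core_crypto::gpu::{CudaDevice, CudaStream};
use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::gen_keys_gpu;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn scalar_lt_on_both_ciphertext_kinds() {
    let stream = CudaStream::new_unchecked(CudaDevice::new(0));
    let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &stream);
    let num_blocks = 4;

    // Unsigned radix ciphertext compared against a u64 scalar
    let a = cks.encrypt_radix(17u64, num_blocks);
    let d_a = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&a, &stream);
    let d_res = sks.scalar_lt(&d_a, 42u64, &stream);
    assert!(cks.decrypt_bool(&d_res.to_boolean_block(&stream)));

    // Signed radix ciphertext and a negative scalar go through the same entry point
    let b = cks.encrypt_signed_radix(-5i64, num_blocks);
    let d_b = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&b, &stream);
    let d_res = sks.scalar_lt(&d_b, -2i64, &stream);
    assert!(cks.decrypt_bool(&d_res.to_boolean_block(&stream)));
}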
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_comparison.rs
new file mode 100644
index 0000000000..f78e883a20
--- /dev/null
+++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_comparison.rs
@@ -0,0 +1,92 @@
+use crate::integer::gpu::server_key::radix::tests_unsigned::{
+    create_gpu_parametrized_test, GpuFunctionExecutor,
+};
+use crate::integer::gpu::CudaServerKey;
+use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_comparison::{
+    test_signed_default_scalar_function, test_signed_default_scalar_minmax,
+    test_signed_unchecked_scalar_function, test_signed_unchecked_scalar_minmax,
+};
+use crate::shortint::parameters::*;
+
+/// This macro generates the tests for a given comparison fn
+///
+/// All our comparison functions have 2 variants:
+/// - unchecked_$comparison_name
+/// - $comparison_name
+///
+/// So, for example, for the `gt` comparison fn, this macro will generate the tests for
+/// the 2 variants described above
+macro_rules! define_gpu_signed_scalar_comparison_test_functions {
+    ($comparison_name:ident, $clear_type:ty) => {
+        ::paste::paste!{
+            fn [<integer_signed_unchecked_scalar_ $comparison_name _ $clear_type:lower>]<P>(param: P) where P: Into<PBSParameters> {
+                let num_tests = 1;
+                let executor = GpuFunctionExecutor::new(&CudaServerKey::[<unchecked_scalar_ $comparison_name>]);
+                test_signed_unchecked_scalar_function(
+                    param,
+                    num_tests,
+                    executor,
+                    |lhs, rhs| $clear_type::from(<$clear_type>::$comparison_name(&lhs, &rhs)),
+                )
+            }
+
+            fn [<integer_signed_default_scalar_ $comparison_name _ $clear_type:lower>]<P>(param: P) where P: Into<PBSParameters> {
+                let num_tests = 10;
+                let executor = GpuFunctionExecutor::new(&CudaServerKey::[<scalar_ $comparison_name>]);
+                test_signed_default_scalar_function(
+                    param,
+                    num_tests,
+                    executor,
+                    |lhs, rhs| $clear_type::from(<$clear_type>::$comparison_name(&lhs, &rhs)),
+                )
+            }
+
+            create_gpu_parametrized_test!([<integer_signed_unchecked_scalar_ $comparison_name _ $clear_type:lower>]);
+            create_gpu_parametrized_test!([<integer_signed_default_scalar_ $comparison_name _ $clear_type:lower>]);
+        }
+    };
+}
+
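In the macro above, the `[<...>]` groups are `::paste::paste!` identifier concatenation: the prefix, the comparison name and the clear type are spliced into a single test-function name. A minimal, self-contained illustration of that mechanism (it only assumes the `paste` crate; it is not taken from this patch):

// `[<check_ $name>]` glues the literal prefix and the macro argument into the
// identifier `check_gt`, which then calls the matching `PartialOrd` method.
macro_rules! make_check {
    ($name:ident) => {
        ::paste::paste! {
            fn [<check_ $name>](lhs: i32, rhs: i32) -> bool {
                lhs.$name(&rhs)
            }
        }
    };
}

make_check!(gt);

fn main() {
    assert!(check_gt(3, 2));
}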

+fn integer_signed_unchecked_scalar_min_i128<P>(params: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_min);
+    test_signed_unchecked_scalar_minmax(params, 2, executor, std::cmp::min::<i128>);
+}
+
+fn integer_signed_unchecked_scalar_max_i128<P>(params: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_max);
+    test_signed_unchecked_scalar_minmax(params, 2, executor, std::cmp::max::<i128>);
+}
+
+fn integer_signed_scalar_min_i128<P>(params: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_min);
+    test_signed_default_scalar_minmax(params, 2, executor, std::cmp::min::<i128>);
+}
+
+fn integer_signed_scalar_max_i128<P>(params: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_max);
+    test_signed_default_scalar_minmax(params, 2, executor, std::cmp::max::<i128>);
+}
+
+create_gpu_parametrized_test!(integer_signed_unchecked_scalar_max_i128);
+create_gpu_parametrized_test!(integer_signed_unchecked_scalar_min_i128);
+create_gpu_parametrized_test!(integer_signed_scalar_max_i128);
+create_gpu_parametrized_test!(integer_signed_scalar_min_i128);
+
+define_gpu_signed_scalar_comparison_test_functions!(eq, i128);
+define_gpu_signed_scalar_comparison_test_functions!(ne, i128);
+define_gpu_signed_scalar_comparison_test_functions!(lt, i128);
+define_gpu_signed_scalar_comparison_test_functions!(le, i128);
+define_gpu_signed_scalar_comparison_test_functions!(gt, i128);
+define_gpu_signed_scalar_comparison_test_functions!(ge, i128);
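The clear reference closure used by the generated tests, `$clear_type::from(<$clear_type>::$comparison_name(&lhs, &rhs))`, encodes the boolean comparison result as 0 or 1 in the clear type so it can be checked against the decrypted boolean block. A tiny standalone check of that encoding (plain Rust, no tfhe types involved):

fn main() {
    let lhs: i128 = -3;
    let rhs: i128 = 7;
    // `i128::lt` resolves to `PartialOrd::lt`; `From<bool>` maps true/false to 1/0.
    assert_eq!(i128::from(i128::lt(&lhs, &rhs)), 1);
    assert_eq!(i128::from(i128::gt(&lhs, &rhs)), 0);
}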
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs
index fe130acf14..08136fc635 100644
--- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs
@@ -79,6 +79,137 @@ where
     test_default_scalar_minmax(params, 2, executor, std::cmp::max::<U256>);
 }

+// The goal of this function is to ensure that scalar comparisons
+// work when the scalar type used is either bigger or smaller (in bit size)
+// compared to the ciphertext
+//fn integer_unchecked_scalar_comparisons_edge(param: ClassicPBSParameters) {
+//    let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
+//
+//    let gpu_index = 0;
+//    let device = CudaDevice::new(gpu_index);
+//    let stream = CudaStream::new_unchecked(device);
+//
+//    let (cks, sks) = gen_keys_gpu(param, &stream);
+//
+//    let mut rng = rand::thread_rng();
+//
+//    for _ in 0..4 {
+//        let clear_a = rng.gen_range((u128::from(u64::MAX) + 1)..=u128::MAX);
+//        let smaller_clear = rng.gen::<u64>();
+//        let bigger_clear = rng.gen::<U256>();
+//
+//        let a = cks.encrypt_radix(clear_a, num_block);
+//        // Copy to the GPU
+//        let d_a = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&a, &stream);
+//
+//        // >=
+//        {
+//            let d_result = sks.unchecked_scalar_ge(&d_a, smaller_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) >= U256::from(smaller_clear));
+//
+//            let d_result = sks.unchecked_scalar_ge(&d_a, bigger_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) >= bigger_clear);
+//        }
+//
+//        // >
+//        {
+//            let d_result = sks.unchecked_scalar_gt(&d_a, smaller_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) > U256::from(smaller_clear));
+//
+//            let d_result = sks.unchecked_scalar_gt(&d_a, bigger_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) > bigger_clear);
+//        }
+//
+//        // <=
+//        {
+//            let d_result = sks.unchecked_scalar_le(&d_a, smaller_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) <= U256::from(smaller_clear));
+//
+//            let d_result = sks.unchecked_scalar_le(&d_a, bigger_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) <= bigger_clear);
+//        }
+//
+//        // <
+//        {
+//            let d_result = sks.unchecked_scalar_lt(&d_a, smaller_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) < U256::from(smaller_clear));
+//
+//            let d_result = sks.unchecked_scalar_lt(&d_a, bigger_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) < bigger_clear);
+//        }
+//
+//        // ==
+//        {
+//            let d_result = sks.unchecked_scalar_eq(&d_a, smaller_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) == U256::from(smaller_clear));
+//
+//            let d_result = sks.unchecked_scalar_eq(&d_a, bigger_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) == bigger_clear);
+//        }
+//
+//        // !=
+//        {
+//            let d_result = sks.unchecked_scalar_ne(&d_a, smaller_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) != U256::from(smaller_clear));
+//
+//            let d_result = sks.unchecked_scalar_ne(&d_a, bigger_clear, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) != bigger_clear);
+//        }
+//
+//        // Here the goal is to test, the branching
+//        // made in the scalar sign function
+//        //
+//        // We are forcing one of the two branches to work on empty slices
+//        {
+//            let d_result = sks.unchecked_scalar_lt(&d_a, U256::ZERO, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) < U256::ZERO);
+//
+//            let d_result = sks.unchecked_scalar_lt(&d_a, U256::MAX, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) < U256::MAX);
+//
+//            // == (as it does not share same code)
+//            let d_result = sks.unchecked_scalar_eq(&d_a, U256::ZERO, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) == U256::ZERO);
+//
+//            // != (as it does not share same code)
+//            let d_result = sks.unchecked_scalar_ne(&d_a, U256::MAX, &stream);
+//            let result = d_result.to_boolean_block(&stream);
+//            let decrypted = cks.decrypt_bool(&result);
+//            assert_eq!(decrypted, U256::from(clear_a) != U256::MAX);
+//        }
+//    }
+//}
+
 create_gpu_parametrized_test!(integer_unchecked_scalar_min_u256);
 create_gpu_parametrized_test!(integer_unchecked_scalar_max_u256);
 create_gpu_parametrized_test!(integer_scalar_min_u256);
@@ -90,3 +221,8 @@ define_gpu_scalar_comparison_test_functions!(lt, U256);
 define_gpu_scalar_comparison_test_functions!(le, U256);
 define_gpu_scalar_comparison_test_functions!(gt, U256);
 define_gpu_scalar_comparison_test_functions!(ge, U256);
+
+//create_gpu_parametrized_test!(integer_unchecked_scalar_comparisons_edge {
+//    PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+//});
+//
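The commented-out edge test above checks scalars that are narrower or wider (in bit size) than the 128-bit ciphertext; on the clear side it promotes both operands to the wider type before comparing (`U256::from(...)`). A minimal sketch of that clear-side rule using only standard integer types, with u64 standing in for the encrypted value and u32/u128 for the narrower/wider scalars (illustrative only, not from the patch):

fn main() {
    let clear_a: u64 = u64::MAX - 1; // stands in for the encrypted value
    let smaller_scalar: u32 = 123; // narrower than the "ciphertext"
    let bigger_scalar: u128 = u128::MAX; // wider than the "ciphertext"

    // Narrow scalar: widen the scalar to the ciphertext width, then compare.
    assert!(clear_a > u64::from(smaller_scalar));
    // Wide scalar: widen the ciphertext value to the scalar width, then compare.
    assert!(u128::from(clear_a) < bigger_scalar);
}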