From cb1110fc795ac5c558af5c595c769ebc0cb161bf Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Wed, 20 Mar 2024 14:56:06 +0100 Subject: [PATCH] feat(gpu): signed and unsigned scalar mul + remove small scalar mul + move around signed tests_cases --- .../tfhe-cuda-backend/cuda/include/integer.h | 177 +- .../cuda/src/integer/multiplication.cu | 142 +- .../cuda/src/integer/multiplication.cuh | 45 - .../cuda/src/integer/scalar_mul.cu | 89 + .../cuda/src/integer/scalar_mul.cuh | 136 ++ .../cuda/src/integer/scalar_shifts.cu | 2 + .../cuda/src/integer/shift_and_rotate.cuh | 4 +- .../cuda/src/pbs/programmable_bootstrap.cuh | 4 + backends/tfhe-cuda-backend/src/cuda_bind.rs | 33 +- tfhe/benches/integer/bench.rs | 14 + tfhe/benches/integer/signed_bench.rs | 26 + .../integers/unsigned/scalar_ops.rs | 16 +- tfhe/src/integer/gpu/ciphertext/info.rs | 14 +- tfhe/src/integer/gpu/mod.rs | 111 +- .../gpu/server_key/radix/scalar_add.rs | 32 +- .../gpu/server_key/radix/scalar_mul.rs | 179 +- .../gpu/server_key/radix/scalar_sub.rs | 13 +- .../gpu/server_key/radix/tests_signed/mod.rs | 1 + .../server_key/radix/tests_signed/test_add.rs | 2 +- .../radix/tests_signed/test_bitwise_op.rs | 2 +- .../server_key/radix/tests_signed/test_mul.rs | 2 +- .../server_key/radix/tests_signed/test_neg.rs | 2 +- .../radix/tests_signed/test_scalar_add.rs | 2 +- .../tests_signed/test_scalar_bitwise_op.rs | 2 +- .../radix/tests_signed/test_scalar_mul.rs | 16 + .../radix/tests_signed/test_scalar_shift.rs | 2 +- .../radix/tests_signed/test_scalar_sub.rs | 2 +- .../server_key/radix/tests_signed/test_sub.rs | 2 +- .../server_key/radix/tests_unsigned/mod.rs | 19 +- .../radix/tests_unsigned/test_scalar_mul.rs | 27 + .../server_key/radix_parallel/ilog2.rs | 5 +- .../integer/server_key/radix_parallel/mod.rs | 2 - .../server_key/radix_parallel/scalar_mul.rs | 314 +-- .../radix_parallel/tests_cases_signed.rs | 1933 ----------------- .../radix_parallel/tests_cases_unsigned.rs | 116 - .../radix_parallel/tests_signed/mod.rs | 344 ++- .../radix_parallel/tests_signed/test_add.rs | 152 +- .../tests_signed/test_bitwise_op.rs | 301 ++- .../radix_parallel/tests_signed/test_mul.rs | 92 +- .../radix_parallel/tests_signed/test_neg.rs | 149 +- .../tests_signed/test_rotate.rs | 4 +- .../tests_signed/test_scalar_add.rs | 280 ++- .../tests_signed/test_scalar_bitwise_op.rs | 141 +- .../tests_signed/test_scalar_mul.rs | 51 + .../tests_signed/test_scalar_shift.rs | 229 +- .../tests_signed/test_scalar_sub.rs | 232 +- .../radix_parallel/tests_signed/test_shift.rs | 7 +- .../radix_parallel/tests_signed/test_sub.rs | 106 +- .../radix_parallel/tests_unsigned/mod.rs | 98 +- .../tests_unsigned/test_scalar_mul.rs | 75 + 50 files changed, 2966 insertions(+), 2783 deletions(-) create mode 100644 backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cu create mode 100644 backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_mul.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_mul.rs delete mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_cases_signed.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_mul.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_mul.rs diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index 8e252e6344..a2669dd28f 100644 --- 
a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -85,14 +85,6 @@ void cuda_scalar_addition_integer_radix_ciphertext_64_inplace( uint32_t lwe_dimension, uint32_t lwe_ciphertext_count, uint32_t message_modulus, uint32_t carry_modulus); -void cuda_small_scalar_multiplication_integer_radix_ciphertext_64( - cuda_stream_t *stream, void *output_lwe_array, void *input_lwe_array, - uint64_t scalar, uint32_t lwe_dimension, uint32_t lwe_ciphertext_count); - -void cuda_small_scalar_multiplication_integer_radix_ciphertext_64_inplace( - cuda_stream_t *stream, void *lwe_array, uint64_t scalar, - uint32_t lwe_dimension, uint32_t lwe_ciphertext_count); - void scratch_cuda_integer_radix_logical_scalar_shift_kb_64( cuda_stream_t *stream, int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t big_lwe_dimension, @@ -269,8 +261,36 @@ void cuda_integer_radix_overflowing_sub_kb_64( void cleanup_cuda_integer_radix_overflowing_sub(cuda_stream_t *stream, int8_t **mem_ptr_void); -} // extern "C" +void scratch_cuda_integer_scalar_mul_kb_64( + cuda_stream_t *stream, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory); + +void cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace( + cuda_stream_t *stream, void *lwe_array, uint64_t *decomposed_scalar, + uint64_t *has_at_least_one_set, int8_t *mem_ptr, void *bsk, void *ksk, + uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t message_modulus, + uint32_t num_blocks, uint32_t num_scalars); + +void cleanup_cuda_integer_radix_scalar_mul(cuda_stream_t *stream, + int8_t **mem_ptr_void); +} + +template +__global__ void radix_blocks_rotate_right(Torus *dst, Torus *src, + uint32_t value, uint32_t blocks_count, + uint32_t lwe_size); +void generate_ids_update_degrees(int *terms_degree, size_t *h_lwe_idx_in, + size_t *h_lwe_idx_out, + int32_t *h_smart_copy_in, + int32_t *h_smart_copy_out, size_t ch_amount, + uint32_t num_radix, uint32_t num_blocks, + size_t chunk_size, size_t message_max, + size_t &total_count, size_t &message_count, + size_t &carry_count, size_t &sm_copy_count); /* * generate bivariate accumulator (lut) for device pointer * v_stream - cuda stream @@ -1225,6 +1245,8 @@ template struct int_logical_scalar_shift_buffer { Torus *tmp_rotated; + bool reuse_memory = false; + int_logical_scalar_shift_buffer(cuda_stream_t *stream, SHIFT_OR_ROTATE_TYPE shift_type, int_radix_params params, @@ -1310,6 +1332,89 @@ template struct int_logical_scalar_shift_buffer { } } + int_logical_scalar_shift_buffer(cuda_stream_t *stream, + SHIFT_OR_ROTATE_TYPE shift_type, + int_radix_params params, + uint32_t num_radix_blocks, + bool allocate_gpu_memory, + Torus *pre_allocated_buffer) { + this->shift_type = shift_type; + this->params = params; + tmp_rotated = pre_allocated_buffer; + reuse_memory = true; + + uint32_t max_amount_of_pbs = num_radix_blocks; + uint32_t big_lwe_size = params.big_lwe_dimension + 1; + uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus); + cuda_memset_async(tmp_rotated, 0, + (max_amount_of_pbs + 2) * big_lwe_size_bytes, stream); + if (allocate_gpu_memory) { + + uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus); + + // LUT + // pregenerate lut vector and indexes + // lut 
for left shift + // here we generate 'num_bits_in_block' times lut + // one for each 'shift_within_block' = 'shift' % 'num_bits_in_block' + // even though lut_left contains 'num_bits_in_block' lut + // lut_indexes will have indexes for single lut only and those indexes + // will be 0 it means for pbs corresponding lut should be selected and + // pass along lut_indexes filled with zeros + + // calculate bivariate lut for each 'shift_within_block' + // so that in case an application calls scratches only once for a whole + // circuit it can reuse memory for different shift values + for (int s_w_b = 1; s_w_b < num_bits_in_block; s_w_b++) { + auto cur_lut_bivariate = new int_radix_lut( + stream, params, 1, num_radix_blocks, allocate_gpu_memory); + + uint32_t shift_within_block = s_w_b; + + std::function shift_lut_f; + + if (shift_type == LEFT_SHIFT) { + shift_lut_f = [shift_within_block, + params](Torus current_block, + Torus previous_block) -> Torus { + current_block = current_block << shift_within_block; + previous_block = previous_block << shift_within_block; + + Torus message_of_current_block = + current_block % params.message_modulus; + Torus carry_of_previous_block = + previous_block / params.message_modulus; + return message_of_current_block + carry_of_previous_block; + }; + } else { + shift_lut_f = [num_bits_in_block, shift_within_block, params]( + Torus current_block, Torus next_block) -> Torus { + // left shift so as not to lose + // bits when shifting right afterwards + next_block <<= num_bits_in_block; + next_block >>= shift_within_block; + + // The way of getting carry / message is reversed compared + // to the usual way but its normal: + // The message is in the upper bits, the carry in lower bits + Torus message_of_current_block = + current_block >> shift_within_block; + Torus carry_of_previous_block = next_block % params.message_modulus; + + return message_of_current_block + carry_of_previous_block; + }; + } + + // right shift + generate_device_accumulator_bivariate( + stream, cur_lut_bivariate->lut, params.glwe_dimension, + params.polynomial_size, params.message_modulus, + params.carry_modulus, shift_lut_f); + + lut_buffers_bivariate.push_back(cur_lut_bivariate); + } + } + } void release(cuda_stream_t *stream) { for (auto &buffer : lut_buffers_bivariate) { buffer->release(stream); @@ -1317,7 +1422,8 @@ template struct int_logical_scalar_shift_buffer { } lut_buffers_bivariate.clear(); - cuda_drop_async(tmp_rotated, stream); + if (!reuse_memory) + cuda_drop_async(tmp_rotated, stream); } }; @@ -2048,4 +2154,55 @@ template struct int_bitop_buffer { } }; +template struct int_scalar_mul_buffer { + int_radix_params params; + int_logical_scalar_shift_buffer *logical_scalar_shift_buffer; + int_sum_ciphertexts_vec_memory *sum_ciphertexts_vec_mem; + Torus *preshifted_buffer; + Torus *all_shifted_buffer; + + int_scalar_mul_buffer(cuda_stream_t *stream, int_radix_params params, + uint32_t num_radix_blocks, bool allocate_gpu_memory) { + this->params = params; + + if (allocate_gpu_memory) { + uint32_t msg_bits = (uint32_t)std::log2(params.message_modulus); + uint32_t lwe_size = params.big_lwe_dimension + 1; + uint32_t lwe_size_bytes = lwe_size * sizeof(Torus); + size_t num_ciphertext_bits = msg_bits * num_radix_blocks; + + //// Contains all shifted values of lhs for shift in range (0..msg_bits) + //// The idea is that with these we can create all other shift that are in + //// range (0..total_bits) for free (block rotation) + preshifted_buffer = (Torus *)cuda_malloc_async( + 
num_ciphertext_bits * lwe_size_bytes, stream);
+
+      all_shifted_buffer = (Torus *)cuda_malloc_async(
+          num_ciphertext_bits * num_radix_blocks * lwe_size_bytes, stream);
+
+      cuda_memset_async(preshifted_buffer, 0,
+                        num_ciphertext_bits * lwe_size_bytes, stream);
+
+      cuda_memset_async(all_shifted_buffer, 0,
+                        num_ciphertext_bits * num_radix_blocks * lwe_size_bytes,
+                        stream);
+
+      logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer<Torus>(
+          stream, LEFT_SHIFT, params, num_radix_blocks, allocate_gpu_memory,
+          all_shifted_buffer);
+
+      sum_ciphertexts_vec_mem = new int_sum_ciphertexts_vec_memory<Torus>(
+          stream, params, num_radix_blocks, num_ciphertext_bits,
+          allocate_gpu_memory);
+    }
+  }
+
+  void release(cuda_stream_t *stream) {
+    logical_scalar_shift_buffer->release(stream);
+    sum_ciphertexts_vec_mem->release(stream);
+    cuda_drop_async(preshifted_buffer, stream);
+    cuda_drop_async(all_shifted_buffer, stream);
+  }
+};
+
 #endif // CUDA_INTEGER_H
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu
index a95143d4da..29ac9c622f 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu
@@ -1,5 +1,66 @@
 #include "integer/multiplication.cuh"
+/*
+ * When adding chunk_size terms at a time, some blocks may not have been
+ * touched by any addition (their degree stays zero); those do not need a
+ * lookup table applied. We therefore collect the indexes of the blocks where
+ * an addition did happen into h_lwe_idx_in. A block may yield both a message
+ * and a carry (if it is not the last block), so one input id can map to two
+ * output ids, which are stored in h_lwe_idx_out. Blocks that do not require
+ * a lookup table are either copied on both the message and carry side or
+ * replaced with zero ciphertexts; their indexes are stored in
+ * h_smart_copy_in (input ids) and h_smart_copy_out (output ids). An input id
+ * of -1 means that a zero ciphertext is copied to the output index.
+ */ +void generate_ids_update_degrees(int *terms_degree, size_t *h_lwe_idx_in, + size_t *h_lwe_idx_out, + int32_t *h_smart_copy_in, + int32_t *h_smart_copy_out, size_t ch_amount, + uint32_t num_radix, uint32_t num_blocks, + size_t chunk_size, size_t message_max, + size_t &total_count, size_t &message_count, + size_t &carry_count, size_t &sm_copy_count) { + for (size_t c_id = 0; c_id < ch_amount; c_id++) { + auto cur_chunk = &terms_degree[c_id * chunk_size * num_blocks]; + for (size_t r_id = 0; r_id < num_blocks; r_id++) { + size_t new_degree = 0; + for (size_t chunk_id = 0; chunk_id < chunk_size; chunk_id++) { + new_degree += cur_chunk[chunk_id * num_blocks + r_id]; + } + + if (new_degree > message_max) { + h_lwe_idx_in[message_count] = c_id * num_blocks + r_id; + h_lwe_idx_out[message_count] = c_id * num_blocks + r_id; + message_count++; + } else { + h_smart_copy_in[sm_copy_count] = c_id * num_blocks + r_id; + h_smart_copy_out[sm_copy_count] = c_id * num_blocks + r_id; + sm_copy_count++; + } + } + } + for (size_t i = 0; i < sm_copy_count; i++) { + h_smart_copy_in[i] = -1; + h_smart_copy_out[i] = h_smart_copy_out[i] + ch_amount * num_blocks + 1; + } + + for (size_t i = 0; i < message_count; i++) { + if (h_lwe_idx_in[i] % num_blocks != num_blocks - 1) { + h_lwe_idx_in[message_count + carry_count] = h_lwe_idx_in[i]; + h_lwe_idx_out[message_count + carry_count] = + ch_amount * num_blocks + h_lwe_idx_in[i] + 1; + carry_count++; + } else { + h_smart_copy_in[sm_copy_count] = -1; + h_smart_copy_out[sm_copy_count] = + h_lwe_idx_in[i] - (num_blocks - 1) + ch_amount * num_blocks; + sm_copy_count++; + } + } + + total_count = message_count + carry_count; +} /* * This scratch function allocates the necessary amount of data on the GPU for * the integer radix multiplication in keyswitch->bootstrap order. 
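To make the index bookkeeping in generate_ids_update_degrees concrete, here is a minimal host-side sketch (not part of the patch; the toy sizes, degree values and buffer capacities are assumptions, and the prototype is taken from integer.h above) for a single chunk of two radix blocks with message modulus 4:

#include <cstddef>
#include <cstdint>
#include <vector>

void toy_generate_ids_example() {
  // Toy setup: 1 chunk of chunk_size = 2 terms, 2 radix blocks,
  // message modulus 4, hence message_max = 3.
  const size_t ch_amount = 1, chunk_size = 2, message_max = 3;
  const uint32_t num_radix = 2, num_blocks = 2;
  // Both summed terms carry a full degree of 3 in every block.
  std::vector<int> terms_degree(ch_amount * chunk_size * num_blocks, 3);
  // Worst case: one message entry and one carry entry per block.
  std::vector<size_t> h_lwe_idx_in(2 * ch_amount * num_blocks);
  std::vector<size_t> h_lwe_idx_out(2 * ch_amount * num_blocks);
  std::vector<int32_t> h_smart_copy_in(2 * ch_amount * num_blocks);
  std::vector<int32_t> h_smart_copy_out(2 * ch_amount * num_blocks);
  size_t total_count = 0, message_count = 0, carry_count = 0, sm_copy_count = 0;

  generate_ids_update_degrees(
      terms_degree.data(), h_lwe_idx_in.data(), h_lwe_idx_out.data(),
      h_smart_copy_in.data(), h_smart_copy_out.data(), ch_amount, num_radix,
      num_blocks, chunk_size, message_max, total_count, message_count,
      carry_count, sm_copy_count);

  // Both block sums are 3 + 3 = 6 > message_max, so message_count == 2 and
  // both blocks get a message-extract PBS in place. Block 0 additionally
  // feeds a carry into output slot ch_amount * num_blocks + 1 == 3
  // (carry_count == 1), while the carry of the last block is dropped: it
  // becomes a smart copy with input id -1, i.e. a zero ciphertext is written
  // (sm_copy_count == 1). total_count == message_count + carry_count == 3.
}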
@@ -89,25 +150,6 @@ void cleanup_cuda_integer_mult(cuda_stream_t *stream, int8_t **mem_ptr_void) { mem_ptr->release(stream); } -void cuda_small_scalar_multiplication_integer_radix_ciphertext_64_inplace( - cuda_stream_t *stream, void *lwe_array, uint64_t scalar, - uint32_t lwe_dimension, uint32_t lwe_ciphertext_count) { - - cuda_small_scalar_multiplication_integer_radix_ciphertext_64( - stream, lwe_array, lwe_array, scalar, lwe_dimension, - lwe_ciphertext_count); -} - -void cuda_small_scalar_multiplication_integer_radix_ciphertext_64( - cuda_stream_t *stream, void *output_lwe_array, void *input_lwe_array, - uint64_t scalar, uint32_t lwe_dimension, uint32_t lwe_ciphertext_count) { - - host_integer_small_scalar_mult_radix( - stream, static_cast(output_lwe_array), - static_cast(input_lwe_array), scalar, lwe_dimension, - lwe_ciphertext_count); -} - void scratch_cuda_integer_radix_sum_ciphertexts_vec_kb_64( cuda_stream_t *stream, int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, @@ -197,65 +239,3 @@ void cleanup_cuda_integer_radix_sum_ciphertexts_vec(cuda_stream_t *stream, mem_ptr->release(stream); } - -/* - * when adding chunk_size times terms together, there might be some blocks - * where addition have not happened or degree is zero, in that case we don't - * need to apply lookup table, so we find the indexes of the blocks where - * addition happened and store them inside h_lwe_idx_in, from same block - * might be extracted message and carry(if it is not the last block), so - * one block id might have two output id and we store them in h_lwe_idx_out - * blocks that do not require applying lookup table might be copied on both - * message and carry side or be replaced with zero ciphertexts, indexes of such - * blocks are stored inside h_smart_copy_in as input ids and h_smart_copy_out - * as output ids, -1 value as an input id means that zero ciphertext will be - * copied on output index. 
- */ -void generate_ids_update_degrees(int *terms_degree, size_t *h_lwe_idx_in, - size_t *h_lwe_idx_out, - int32_t *h_smart_copy_in, - int32_t *h_smart_copy_out, size_t ch_amount, - uint32_t num_radix, uint32_t num_blocks, - size_t chunk_size, size_t message_max, - size_t &total_count, size_t &message_count, - size_t &carry_count, size_t &sm_copy_count) { - for (size_t c_id = 0; c_id < ch_amount; c_id++) { - auto cur_chunk = &terms_degree[c_id * chunk_size * num_blocks]; - for (size_t r_id = 0; r_id < num_blocks; r_id++) { - size_t new_degree = 0; - for (size_t chunk_id = 0; chunk_id < chunk_size; chunk_id++) { - new_degree += cur_chunk[chunk_id * num_blocks + r_id]; - } - - if (new_degree > message_max) { - h_lwe_idx_in[message_count] = c_id * num_blocks + r_id; - h_lwe_idx_out[message_count] = c_id * num_blocks + r_id; - message_count++; - } else { - h_smart_copy_in[sm_copy_count] = c_id * num_blocks + r_id; - h_smart_copy_out[sm_copy_count] = c_id * num_blocks + r_id; - sm_copy_count++; - } - } - } - for (size_t i = 0; i < sm_copy_count; i++) { - h_smart_copy_in[i] = -1; - h_smart_copy_out[i] = h_smart_copy_out[i] + ch_amount * num_blocks + 1; - } - - for (size_t i = 0; i < message_count; i++) { - if (h_lwe_idx_in[i] % num_blocks != num_blocks - 1) { - h_lwe_idx_in[message_count + carry_count] = h_lwe_idx_in[i]; - h_lwe_idx_out[message_count + carry_count] = - ch_amount * num_blocks + h_lwe_idx_in[i] + 1; - carry_count++; - } else { - h_smart_copy_in[sm_copy_count] = -1; - h_smart_copy_out[sm_copy_count] = - h_lwe_idx_in[i] - (num_blocks - 1) + ch_amount * num_blocks; - sm_copy_count++; - } - } - - total_count = message_count + carry_count; -} diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh index bebef38d63..ef73e372b1 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh @@ -39,15 +39,6 @@ __global__ void smart_copy(Torus *dst, Torus *src, int32_t *id_out, } } -void generate_ids_update_degrees(int *terms_degree, size_t *h_lwe_idx_in, - size_t *h_lwe_idx_out, - int32_t *h_smart_copy_in, - int32_t *h_smart_copy_out, size_t ch_amount, - uint32_t num_radix, uint32_t num_blocks, - size_t chunk_size, size_t message_max, - size_t &total_count, size_t &message_count, - size_t &carry_count, size_t &sm_copy_count); - template __global__ void all_shifted_lhs_rhs(Torus *radix_lwe_left, Torus *lsb_ciphertext, @@ -446,40 +437,4 @@ __host__ void scratch_cuda_integer_mult_radix_ciphertext_kb( allocate_gpu_memory); } -template -__global__ void device_small_scalar_radix_multiplication(T *output_lwe_array, - T *input_lwe_array, - T scalar, - uint32_t lwe_dimension, - uint32_t num_blocks) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int lwe_size = lwe_dimension + 1; - if (index < num_blocks * lwe_size) { - // Here we take advantage of the wrapping behaviour of uint - output_lwe_array[index] = input_lwe_array[index] * scalar; - } -} - -template -__host__ void host_integer_small_scalar_mult_radix( - cuda_stream_t *stream, T *output_lwe_array, T *input_lwe_array, T scalar, - uint32_t input_lwe_dimension, uint32_t input_lwe_ciphertext_count) { - - cudaSetDevice(stream->gpu_index); - // lwe_size includes the presence of the body - // whereas lwe_dimension is the number of elements in the mask - int lwe_size = input_lwe_dimension + 1; - // Create a 1-dimensional grid of threads - int num_blocks = 0, num_threads = 0; - 
int num_entries = input_lwe_ciphertext_count * lwe_size; - getNumBlocksAndThreads(num_entries, 512, num_blocks, num_threads); - dim3 grid(num_blocks, 1, 1); - dim3 thds(num_threads, 1, 1); - - device_small_scalar_radix_multiplication<<stream>>>( - output_lwe_array, input_lwe_array, scalar, input_lwe_dimension, - input_lwe_ciphertext_count); - check_cuda_error(cudaGetLastError()); -} - #endif diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cu b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cu new file mode 100644 index 0000000000..644ebe4bbc --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cu @@ -0,0 +1,89 @@ +#include "integer/scalar_mul.cuh" + +void scratch_cuda_integer_scalar_mul_kb_64( + cuda_stream_t *stream, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + glwe_dimension * polynomial_size, lwe_dimension, + ks_level, ks_base_log, pbs_level, pbs_base_log, + grouping_factor, message_modulus, carry_modulus); + + scratch_cuda_integer_radix_scalar_mul_kb( + stream, (int_scalar_mul_buffer **)mem_ptr, num_blocks, params, + allocate_gpu_memory); +} + +void cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace( + cuda_stream_t *stream, void *lwe_array, uint64_t *decomposed_scalar, + uint64_t *has_at_least_one_set, int8_t *mem, void *bsk, void *ksk, + uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t message_modulus, + uint32_t num_blocks, uint32_t num_scalars) { + + switch (polynomial_size) { + case 512: + host_integer_scalar_mul_radix>( + stream, static_cast(lwe_array), decomposed_scalar, + has_at_least_one_set, + reinterpret_cast *>(mem), bsk, + static_cast(ksk), lwe_dimension, message_modulus, + num_blocks, num_scalars); + break; + case 1024: + host_integer_scalar_mul_radix>( + stream, static_cast(lwe_array), decomposed_scalar, + has_at_least_one_set, + reinterpret_cast *>(mem), bsk, + static_cast(ksk), lwe_dimension, message_modulus, + num_blocks, num_scalars); + break; + case 2048: + host_integer_scalar_mul_radix>( + stream, static_cast(lwe_array), decomposed_scalar, + has_at_least_one_set, + reinterpret_cast *>(mem), bsk, + static_cast(ksk), lwe_dimension, message_modulus, + num_blocks, num_scalars); + break; + case 4096: + host_integer_scalar_mul_radix>( + stream, static_cast(lwe_array), decomposed_scalar, + has_at_least_one_set, + reinterpret_cast *>(mem), bsk, + static_cast(ksk), lwe_dimension, message_modulus, + num_blocks, num_scalars); + break; + case 8192: + host_integer_scalar_mul_radix>( + stream, static_cast(lwe_array), decomposed_scalar, + has_at_least_one_set, + reinterpret_cast *>(mem), bsk, + static_cast(ksk), lwe_dimension, message_modulus, + num_blocks, num_scalars); + break; + case 16384: + host_integer_scalar_mul_radix>( + stream, static_cast(lwe_array), decomposed_scalar, + has_at_least_one_set, + reinterpret_cast *>(mem), bsk, + static_cast(ksk), lwe_dimension, message_modulus, + num_blocks, num_scalars); + break; + default: + PANIC("Cuda error (scalar multiplication): unsupported polynomial size. 
" + "Only N = 512, 1024, 2048, 4096, 8192, 16384 are supported.") + } +} + +void cleanup_cuda_integer_radix_scalar_mul(cuda_stream_t *stream, + int8_t **mem_ptr_void) { + + cudaSetDevice(stream->gpu_index); + int_scalar_mul_buffer *mem_ptr = + (int_scalar_mul_buffer *)(*mem_ptr_void); + + mem_ptr->release(stream); +} diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh new file mode 100644 index 0000000000..a52de8604e --- /dev/null +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh @@ -0,0 +1,136 @@ +#ifndef CUDA_INTEGER_SCALAR_MUL_CUH +#define CUDA_INTEGER_SCALAR_MUL_CUH + +#ifdef __CDT_PARSER__ +#undef __CUDA_RUNTIME_H__ +#include +#endif + +#include "device.h" +#include "integer.h" +#include "multiplication.cuh" +#include "scalar_shifts.cuh" +#include "utils/kernel_dimensions.cuh" +#include + +template +__global__ void device_small_scalar_radix_multiplication(T *output_lwe_array, + T *input_lwe_array, + T scalar, + uint32_t lwe_dimension, + uint32_t num_blocks) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int lwe_size = lwe_dimension + 1; + if (index < num_blocks * lwe_size) { + // Here we take advantage of the wrapping behaviour of uint + output_lwe_array[index] = input_lwe_array[index] * scalar; + } +} + +template +__host__ void scratch_cuda_integer_radix_scalar_mul_kb( + cuda_stream_t *stream, int_scalar_mul_buffer **mem_ptr, + uint32_t num_radix_blocks, int_radix_params params, + bool allocate_gpu_memory) { + + cudaSetDevice(stream->gpu_index); + size_t sm_size = (params.big_lwe_dimension + 1) * sizeof(T); + check_cuda_error(cudaFuncSetAttribute( + tree_add_chunks, cudaFuncAttributeMaxDynamicSharedMemorySize, + sm_size)); + cudaFuncSetCacheConfig(tree_add_chunks, cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + + *mem_ptr = new int_scalar_mul_buffer(stream, params, num_radix_blocks, + allocate_gpu_memory); +} + +template +__host__ void host_integer_scalar_mul_radix( + cuda_stream_t *stream, T *lwe_array, T *decomposed_scalar, + T *has_at_least_one_set, int_scalar_mul_buffer *mem, void *bsk, T *ksk, + uint32_t input_lwe_dimension, uint32_t message_modulus, + uint32_t num_radix_blocks, uint32_t num_scalars) { + + if (num_radix_blocks == 0 | num_scalars == 0) + return; + + cudaSetDevice(stream->gpu_index); + // lwe_size includes the presence of the body + // whereas lwe_dimension is the number of elements in the mask + uint32_t lwe_size = input_lwe_dimension + 1; + uint32_t lwe_size_bytes = lwe_size * sizeof(T); + uint32_t msg_bits = (uint32_t)std::log2(message_modulus); + uint32_t num_ciphertext_bits = msg_bits * num_radix_blocks; + + T *preshifted_buffer = mem->preshifted_buffer; + T *all_shifted_buffer = mem->all_shifted_buffer; + + for (size_t shift_amount = 0; shift_amount < msg_bits; shift_amount++) { + T *ptr = preshifted_buffer + shift_amount * lwe_size * num_radix_blocks; + if (has_at_least_one_set[shift_amount] == 1) { + cuda_memcpy_async_gpu_to_gpu(ptr, lwe_array, + lwe_size_bytes * num_radix_blocks, stream); + host_integer_radix_logical_scalar_shift_kb_inplace( + stream, ptr, shift_amount, mem->logical_scalar_shift_buffer, bsk, ksk, + num_radix_blocks); + } else { + // create trivial assign for value = 0 + cuda_memset_async(ptr, 0, num_radix_blocks * lwe_size_bytes, stream); + } + } + size_t j = 0; + for (size_t i = 0; i < min(num_scalars, num_ciphertext_bits); i++) { + if (decomposed_scalar[i] == 1) { + // Perform a block shift + T 
*preshifted_radix_ct = + preshifted_buffer + (i % msg_bits) * num_radix_blocks * lwe_size; + T *block_shift_buffer = + all_shifted_buffer + j * num_radix_blocks * lwe_size; + radix_blocks_rotate_right<<stream>>>( + block_shift_buffer, preshifted_radix_ct, i / msg_bits, + num_radix_blocks, lwe_size); + // create trivial assign for value = 0 + cuda_memset_async(block_shift_buffer, 0, (i / msg_bits) * lwe_size_bytes, + stream); + j++; + } + } + + if (j == 0) { + // lwe array = 0 + cuda_memset_async(lwe_array, 0, num_radix_blocks * lwe_size_bytes, stream); + } else { + int terms_degree[j * num_radix_blocks]; + for (int i = 0; i < j * num_radix_blocks; i++) { + terms_degree[i] = message_modulus - 1; + } + host_integer_sum_ciphertexts_vec_kb( + stream, lwe_array, all_shifted_buffer, terms_degree, bsk, ksk, + mem->sum_ciphertexts_vec_mem, num_radix_blocks, j); + } +} + +// Small scalar_mul is used in shift/rotate +template +__host__ void host_integer_small_scalar_mul_radix( + cuda_stream_t *stream, T *output_lwe_array, T *input_lwe_array, T scalar, + uint32_t input_lwe_dimension, uint32_t input_lwe_ciphertext_count) { + + cudaSetDevice(stream->gpu_index); + // lwe_size includes the presence of the body + // whereas lwe_dimension is the number of elements in the mask + int lwe_size = input_lwe_dimension + 1; + // Create a 1-dimensional grid of threads + int num_blocks = 0, num_threads = 0; + int num_entries = input_lwe_ciphertext_count * lwe_size; + getNumBlocksAndThreads(num_entries, 512, num_blocks, num_threads); + dim3 grid(num_blocks, 1, 1); + dim3 thds(num_threads, 1, 1); + + device_small_scalar_radix_multiplication<<stream>>>( + output_lwe_array, input_lwe_array, scalar, input_lwe_dimension, + input_lwe_ciphertext_count); + check_cuda_error(cudaGetLastError()); +} +#endif diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu index b3189b6102..0c6bb9c8f4 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu @@ -72,6 +72,7 @@ void cuda_integer_radix_arithmetic_scalar_shift_kb_64_inplace( void cleanup_cuda_integer_radix_logical_scalar_shift(cuda_stream_t *stream, int8_t **mem_ptr_void) { + cudaSetDevice(stream->gpu_index); int_logical_scalar_shift_buffer *mem_ptr = (int_logical_scalar_shift_buffer *)(*mem_ptr_void); @@ -81,6 +82,7 @@ void cleanup_cuda_integer_radix_logical_scalar_shift(cuda_stream_t *stream, void cleanup_cuda_integer_radix_arithmetic_scalar_shift(cuda_stream_t *stream, int8_t **mem_ptr_void) { + cudaSetDevice(stream->gpu_index); int_arithmetic_scalar_shift_buffer *mem_ptr = (int_arithmetic_scalar_shift_buffer *)(*mem_ptr_void); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh index b40e598eed..ba4cc811b5 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh @@ -5,9 +5,9 @@ #include "device.h" #include "integer.cuh" #include "integer.h" -#include "multiplication.cuh" #include "pbs/programmable_bootstrap_classic.cuh" #include "pbs/programmable_bootstrap_multibit.cuh" +#include "scalar_mul.cuh" #include "types/complex/operations.cuh" #include "utils/helper.cuh" #include "utils/kernel_dimensions.cuh" @@ -157,7 +157,7 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace( lwe_last_out = lwe_array; for (int i = 
bits_per_block - 2; i >= 0; i--) { - host_integer_small_scalar_mult_radix( + host_integer_small_scalar_mul_radix( stream, lwe_last_out, lwe_last_out, 2, big_lwe_dimension, num_radix_blocks); diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh index ba8b88d1b4..d40a929abb 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh @@ -37,6 +37,8 @@ void execute_pbs(cuda_stream_t *stream, Torus *lwe_array_out, // 64 bits switch (pbs_type) { case MULTI_BIT: + if (grouping_factor == 0) + PANIC("Multi-bit PBS error: grouping factor should be > 0.") cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( stream, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, @@ -89,6 +91,8 @@ void execute_scratch_pbs(cuda_stream_t *stream, int8_t **pbs_buffer, // 64 bits switch (pbs_type) { case MULTI_BIT: + if (grouping_factor == 0) + PANIC("Multi-bit PBS error: grouping factor should be > 0.") scratch_cuda_multi_bit_programmable_bootstrap_64( stream, pbs_buffer, lwe_dimension, glwe_dimension, polynomial_size, level_count, grouping_factor, input_lwe_ciphertext_count, diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index 864625cd61..6d7438ddd1 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -516,14 +516,41 @@ extern "C" { carry_modulus: u32, ); - pub fn cuda_small_scalar_multiplication_integer_radix_ciphertext_64_inplace( + pub fn scratch_cuda_integer_scalar_mul_kb_64( + v_stream: *const c_void, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_blocks: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: u32, + allocate_gpu_memory: bool, + ); + + pub fn cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace( v_stream: *const c_void, lwe_array: *mut c_void, - scalar_input: u64, + decomposed_scalar: *const u64, + has_at_least_one_set: *const u64, + mem: *mut i8, + bsk: *const c_void, + ksk: *const c_void, lwe_dimension: u32, - lwe_ciphertext_count: u32, + polynomial_size: u32, + message_modulus: u32, + num_blocks: u32, + num_scalars: u32, ); + pub fn cleanup_cuda_integer_radix_scalar_mul(v_stream: *const c_void, mem_ptr: *mut *mut i8); + pub fn scratch_cuda_integer_radix_bitop_kb_64( v_stream: *const c_void, mem_ptr: *mut *mut i8, diff --git a/tfhe/benches/integer/bench.rs b/tfhe/benches/integer/bench.rs index 441b559a86..f6ae351523 100644 --- a/tfhe/benches/integer/bench.rs +++ b/tfhe/benches/integer/bench.rs @@ -1608,6 +1608,12 @@ mod cuda { rng_func: default_scalar ); + define_cuda_server_key_bench_clean_input_scalar_fn!( + method_name: unchecked_scalar_mul, + display_name: mul, + rng_func: mul_scalar + ); + define_cuda_server_key_bench_clean_input_scalar_fn!( method_name: unchecked_scalar_sub, display_name: sub, @@ -1802,6 +1808,12 @@ mod cuda { rng_func: default_scalar ); + define_cuda_server_key_bench_clean_input_scalar_fn!( + method_name: scalar_mul, + display_name: mul, + rng_func: mul_scalar + ); + define_cuda_server_key_bench_clean_input_scalar_fn!( method_name: scalar_left_shift, display_name: left_shift, @@ -1904,6 +1916,7 @@ mod cuda { cuda_unchecked_scalar_bitor, 
cuda_unchecked_scalar_bitxor,
         cuda_unchecked_scalar_add,
+        cuda_unchecked_scalar_mul,
         cuda_unchecked_scalar_sub,
         cuda_unchecked_scalar_left_shift,
         cuda_unchecked_scalar_right_shift,
@@ -1946,6 +1959,7 @@ mod cuda {
         default_scalar_cuda_ops,
         cuda_scalar_sub,
         cuda_scalar_add,
+        cuda_scalar_mul,
         cuda_scalar_left_shift,
         cuda_scalar_right_shift,
         cuda_scalar_bitand,
diff --git a/tfhe/benches/integer/signed_bench.rs b/tfhe/benches/integer/signed_bench.rs
index af834a93e1..f09cd87a18 100644
--- a/tfhe/benches/integer/signed_bench.rs
+++ b/tfhe/benches/integer/signed_bench.rs
@@ -1562,6 +1562,18 @@ mod cuda {
         tfhe::integer::I256::from((clearlow, clearhigh))
     }

+    fn mul_signed_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
+        loop {
+            let clearlow = rng.gen::<u128>();
+            let clearhigh = rng.gen::<u128>();
+            let scalar = tfhe::integer::I256::from((clearlow, clearhigh));
+            // If the scalar is a power of two, it is just a shift, which is a happy path.
+            if !scalar.is_power_of_two() {
+                return scalar;
+            }
+        }
+    }
+
     define_cuda_server_key_bench_clean_input_signed_fn!(
         method_name: unchecked_add,
         display_name: add
@@ -1628,6 +1640,12 @@
         rng_func: default_signed_scalar
     );

+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: unchecked_scalar_mul,
+        display_name: mul,
+        rng_func: mul_signed_scalar
+    );
+
     define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
         method_name: unchecked_scalar_sub,
         display_name: sub,
@@ -1734,6 +1752,12 @@
         rng_func: default_signed_scalar
     );

+    define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
+        method_name: scalar_mul,
+        display_name: mul,
+        rng_func: mul_signed_scalar
+    );
+
     define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
         method_name: scalar_sub,
         display_name: sub,
@@ -1789,6 +1813,7 @@
     criterion_group!(
         unchecked_scalar_cuda_ops,
         cuda_unchecked_scalar_add,
+        cuda_unchecked_scalar_mul,
         cuda_unchecked_scalar_sub,
         cuda_unchecked_scalar_bitand,
         cuda_unchecked_scalar_bitor,
@@ -1816,6 +1841,7 @@
     criterion_group!(
         default_scalar_cuda_ops,
         cuda_scalar_add,
+        cuda_scalar_mul,
         cuda_scalar_sub,
         cuda_scalar_bitand,
         cuda_scalar_bitor,
diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
index cede302bfe..075a63116a 100644
--- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
@@ -538,8 +538,13 @@ generic_integer_impl_scalar_operation!(
                     RadixCiphertext::Cpu(inner_result)
                 },
                 #[cfg(feature = "gpu")]
-                InternalServerKey::Cuda(_) => {
-                    panic!("Mul '*' with clear value is not yet supported by Cuda devices")
+                InternalServerKey::Cuda(cuda_key) => {
+                    let inner_result = with_thread_local_cuda_stream(|stream| {
+                        cuda_key.key.scalar_mul(
+                            &*lhs.ciphertext.on_gpu(), rhs, stream
+                        )
+                    });
+                    RadixCiphertext::Cuda(inner_result)
                 }
             })
         }
@@ -1371,8 +1376,11 @@ generic_integer_impl_scalar_operation_assign!(
                         .scalar_mul_assign_parallelized(lhs.ciphertext.as_cpu_mut(), rhs);
                 },
                 #[cfg(feature = "gpu")]
-                InternalServerKey::Cuda(_) => {
-                    panic!("MulAssign '*=' with clear value is not yet supported by Cuda devices")
+                InternalServerKey::Cuda(cuda_key) => {
+                    with_thread_local_cuda_stream(|stream| {
+                        cuda_key.key
+                            .scalar_mul_assign(lhs.ciphertext.as_gpu_mut(), rhs, stream);
+                    })
                 }
             })
         }
diff --git a/tfhe/src/integer/gpu/ciphertext/info.rs b/tfhe/src/integer/gpu/ciphertext/info.rs
index 63b1b88700..22a430a8ba 100644
--- a/tfhe/src/integer/gpu/ciphertext/info.rs
+++
b/tfhe/src/integer/gpu/ciphertext/info.rs @@ -187,17 +187,17 @@ impl CudaRadixCiphertextInfo { } } - pub(crate) fn after_small_scalar_mul(&self, scalar: u8) -> Self { + pub(crate) fn after_scalar_mul(&self) -> Self { Self { blocks: self .blocks .iter() - .map(|left| CudaBlockInfo { - degree: Degree::new(left.degree.get() * scalar as usize), - message_modulus: left.message_modulus, - carry_modulus: left.carry_modulus, - pbs_order: left.pbs_order, - noise_level: left.noise_level, + .map(|info| CudaBlockInfo { + degree: Degree::new(info.message_modulus.0 - 1), + message_modulus: info.message_modulus, + carry_modulus: info.carry_modulus, + pbs_order: info.pbs_order, + noise_level: info.noise_level + NoiseLevel::NOMINAL, }) .collect(), } diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 332d95e7b5..ce2d784364 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -170,24 +170,127 @@ impl CudaStream { ); } + #[allow(clippy::too_many_arguments)] /// # Safety /// /// - [CudaStream::synchronize] __must__ be called after this function /// as soon as synchronization is required - pub unsafe fn small_scalar_mult_integer_radix_assign_async( + pub unsafe fn unchecked_scalar_mul_integer_radix_classic_kb_async( &self, lwe_array: &mut CudaVec, - scalar: u64, + decomposed_scalar: &[T], + has_at_least_one_set: &[T], + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, lwe_dimension: LweDimension, + pbs_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + ks_level: DecompositionLevelCount, num_blocks: u32, + num_scalars: u32, ) { - cuda_small_scalar_multiplication_integer_radix_ciphertext_64_inplace( + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + scratch_cuda_integer_scalar_mul_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::Classical as u32, + true, + ); + + cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace( self.as_c_ptr(), lwe_array.as_mut_c_ptr(), - scalar, + decomposed_scalar.as_ptr().cast::(), + has_at_least_one_set.as_ptr().cast::(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + (glwe_dimension.0 * polynomial_size.0) as u32, + polynomial_size.0 as u32, + message_modulus.0 as u32, + num_blocks, + num_scalars, + ); + + cleanup_cuda_integer_radix_scalar_mul(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); + } + + #[allow(clippy::too_many_arguments)] + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_mul_integer_radix_multibit_kb_async( + &self, + lwe_array: &mut CudaVec, + decomposed_scalar: &[T], + has_at_least_one_set: &[T], + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_dimension: LweDimension, + pbs_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + ks_level: DecompositionLevelCount, + grouping_factor: LweBskGroupingFactor, + 
num_blocks: u32, + num_scalars: u32, + ) { + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + scratch_cuda_integer_scalar_mul_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + true, ); + + cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace( + self.as_c_ptr(), + lwe_array.as_mut_c_ptr(), + decomposed_scalar.as_ptr().cast::(), + has_at_least_one_set.as_ptr().cast::(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + (glwe_dimension.0 * polynomial_size.0) as u32, + polynomial_size.0 as u32, + message_modulus.0 as u32, + num_blocks, + num_scalars, + ); + + cleanup_cuda_integer_radix_scalar_mul(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs index 16b7383919..857d8f8dbf 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs @@ -3,7 +3,7 @@ use crate::core_crypto::gpu::CudaStream; use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; use crate::integer::gpu::server_key::CudaServerKey; -use itertools::Itertools; +use crate::prelude::CastInto; impl CudaServerKey { /// Computes homomorphically an addition between a scalar and a ciphertext. @@ -45,7 +45,7 @@ impl CudaServerKey { /// ``` pub fn unchecked_scalar_add(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T where - Scalar: DecomposableInto, + Scalar: DecomposableInto + CastInto, T: CudaIntegerRadixCiphertext, { let mut result = unsafe { ct.duplicate_async(stream) }; @@ -63,23 +63,19 @@ impl CudaServerKey { scalar: Scalar, stream: &CudaStream, ) where - Scalar: DecomposableInto, + Scalar: DecomposableInto + CastInto, T: CudaIntegerRadixCiphertext, { if scalar != Scalar::ZERO { let bits_in_message = self.message_modulus.0.ilog2(); - let decomposer = - BlockDecomposer::with_early_stop_at_zero(scalar, bits_in_message).iter_as::(); - let mut d_decomposed_scalar = CudaVec::::new_async(ct.as_ref().d_blocks.lwe_ciphertext_count().0, stream); - let scalar64 = decomposer - .collect_vec() - .iter() - .map(|&x| x as u64) - .take(d_decomposed_scalar.len()) - .collect_vec(); - d_decomposed_scalar.copy_from_cpu_async(scalar64.as_slice(), stream); + let decomposed_scalar = + BlockDecomposer::with_early_stop_at_zero(scalar, bits_in_message) + .iter_as::() + .take(d_decomposed_scalar.len()) + .collect::>(); + d_decomposed_scalar.copy_from_cpu_async(decomposed_scalar.as_slice(), stream); let lwe_dimension = ct.as_ref().d_blocks.lwe_dimension(); // If the scalar is decomposed using less than the number of blocks our ciphertext @@ -88,7 +84,7 @@ impl CudaServerKey { &mut ct.as_mut().d_blocks.0.d_vec, &d_decomposed_scalar, lwe_dimension, - scalar64.len() as u32, + decomposed_scalar.len() as u32, self.message_modulus.0 as u32, self.carry_modulus.0 as u32, ); @@ -103,7 +99,7 @@ impl CudaServerKey { scalar: Scalar, stream: &CudaStream, ) where - Scalar: DecomposableInto, + Scalar: DecomposableInto + CastInto, T: CudaIntegerRadixCiphertext, { unsafe { @@ -151,7 +147,7 @@ impl CudaServerKey { /// ``` pub fn 
scalar_add<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
     where
-        Scalar: DecomposableInto<u8>,
+        Scalar: DecomposableInto<u8> + CastInto<u64>,
         T: CudaIntegerRadixCiphertext,
     {
         let mut result = unsafe { ct.duplicate_async(stream) };
@@ -169,7 +165,7 @@
         scalar: Scalar,
         stream: &CudaStream,
     ) where
-        Scalar: DecomposableInto<u8>,
+        Scalar: DecomposableInto<u8> + CastInto<u64>,
         T: CudaIntegerRadixCiphertext,
     {
         if !ct.block_carries_are_empty() {
@@ -182,7 +178,7 @@
     pub fn scalar_add_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, stream: &CudaStream)
     where
-        Scalar: DecomposableInto<u8>,
+        Scalar: DecomposableInto<u8> + CastInto<u64>,
         T: CudaIntegerRadixCiphertext,
     {
         unsafe {
diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs
index afbcc5b69d..862802dace 100644
--- a/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs
@@ -1,6 +1,10 @@
 use crate::core_crypto::gpu::CudaStream;
-use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
-use crate::integer::gpu::server_key::CudaServerKey;
+use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
+use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
+use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
+use crate::integer::server_key::ScalarMultiplier;
+use crate::prelude::CastInto;
+use itertools::Itertools;

 impl CudaServerKey {
     /// Computes homomorphically a multiplication between a scalar and a ciphertext.
@@ -34,20 +38,19 @@
     /// let mut d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &mut stream);
     ///
     /// // Compute homomorphically a scalar multiplication:
-    /// let d_ct_res = sks.unchecked_small_scalar_mul(&d_ct, scalar, &mut stream);
+    /// let d_ct_res = sks.unchecked_scalar_mul(&d_ct, scalar, &mut stream);
     /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream);
     ///
     /// let clear: u64 = cks.decrypt(&ct_res);
     /// assert_eq!(scalar * msg, clear);
     /// ```
-    pub fn unchecked_small_scalar_mul(
-        &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: u64,
-        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext {
+    pub fn unchecked_scalar_mul<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
+    where
+        Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
+        T: CudaIntegerRadixCiphertext,
+    {
         let mut result = unsafe { ct.duplicate_async(stream) };
-        self.unchecked_small_scalar_mul_assign(&mut result, scalar, stream);
+        self.unchecked_scalar_mul_assign(&mut result, scalar, stream);
         result
     }

@@ -55,43 +58,110 @@
     ///
     /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
     ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_small_scalar_mul_assign_async(
+    pub unsafe fn unchecked_scalar_mul_assign_async<Scalar, T>(
         &self,
-        ct: &mut CudaUnsignedRadixCiphertext,
-        scalar: u64,
+        ct: &mut T,
+        scalar: Scalar,
         stream: &CudaStream,
-    ) {
-        match scalar {
-            0 => {
-                ct.as_mut().d_blocks.0.d_vec.memset_async(0, stream);
-            }
-            1 => {
-                // Multiplication by one is the identity
+    ) where
+        Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
+        T: CudaIntegerRadixCiphertext,
+    {
+        if scalar == Scalar::ZERO {
+            ct.as_mut().d_blocks.0.d_vec.memset_async(0, stream);
+            return;
+        }
+
+        if scalar == Scalar::ONE {
+            return;
+        }
+
+        if scalar.is_power_of_two() {
+            // Shifting costs one bivariate PBS so it's always faster
+            // than
multiplying + self.unchecked_scalar_left_shift_assign_async(ct, scalar.ilog2() as u64, stream); + return; + } + let ciphertext = ct.as_mut(); + let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0; + let msg_bits = self.message_modulus.0.ilog2() as usize; + let decomposer = BlockDecomposer::with_early_stop_at_zero(scalar, 1).iter_as::(); + + // We don't want to compute shifts if we are not going to use the + // resulting value + let mut has_at_least_one_set = vec![0u64; msg_bits]; + for (i, bit) in decomposer.collect_vec().iter().copied().enumerate() { + if bit == 1 { + has_at_least_one_set[i % msg_bits] = 1; } - _ => { - let lwe_dimension = ct.as_ref().d_blocks.lwe_dimension(); - let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count(); + } - stream.small_scalar_mult_integer_radix_assign_async( + let decomposed_scalar = BlockDecomposer::with_early_stop_at_zero(scalar, 1) + .iter_as::() + .collect::>(); + + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + stream.unchecked_scalar_mul_integer_radix_classic_kb_async( &mut ct.as_mut().d_blocks.0.d_vec, - scalar, - lwe_dimension, - lwe_ciphertext_count.0 as u32, + decomposed_scalar.as_slice(), + has_at_least_one_set.as_slice(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.decomp_base_log, + d_bsk.decomp_level_count, + self.key_switching_key.decomposition_base_log(), + self.key_switching_key.decomposition_level_count(), + num_blocks as u32, + decomposed_scalar.len() as u32, ); } - } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + stream.unchecked_scalar_mul_integer_radix_multibit_kb_async( + &mut ct.as_mut().d_blocks.0.d_vec, + decomposed_scalar.as_slice(), + has_at_least_one_set.as_slice(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.decomp_level_count, + self.key_switching_key.decomposition_base_log(), + self.key_switching_key.decomposition_level_count(), + d_multibit_bsk.grouping_factor, + num_blocks as u32, + decomposed_scalar.len() as u32, + ); + } + }; - ct.as_mut().info = ct.as_ref().info.after_small_scalar_mul(scalar as u8); + ct.as_mut().info = ct.as_ref().info.after_scalar_mul(); } - pub fn unchecked_small_scalar_mul_assign( + pub fn unchecked_scalar_mul_assign( &self, - ct: &mut CudaUnsignedRadixCiphertext, - scalar: u64, + ct: &mut T, + scalar: Scalar, stream: &CudaStream, - ) { + ) where + Scalar: ScalarMultiplier + DecomposableInto + CastInto, + T: CudaIntegerRadixCiphertext, + { unsafe { - self.unchecked_small_scalar_mul_assign_async(ct, scalar, stream); + self.unchecked_scalar_mul_assign_async(ct, scalar, stream); } stream.synchronize(); } @@ -127,20 +197,19 @@ impl CudaServerKey { /// let mut d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &mut stream); /// /// // Compute homomorphically a scalar multiplication: - /// let d_ct_res = sks.small_scalar_mul(&d_ct, scalar, &mut stream); + /// let d_ct_res = sks.scalar_mul(&d_ct, scalar, &mut stream); /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream); /// /// let clear: u64 = cks.decrypt(&ct_res); /// assert_eq!(scalar * msg, clear); /// ``` - pub fn small_scalar_mul( - &self, - 
ct: &CudaUnsignedRadixCiphertext, - scalar: u64, - stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext { + pub fn scalar_mul(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T + where + Scalar: ScalarMultiplier + DecomposableInto + CastInto, + T: CudaIntegerRadixCiphertext, + { let mut result = unsafe { ct.duplicate_async(stream) }; - self.small_scalar_mul_assign(&mut result, scalar, stream); + self.scalar_mul_assign(&mut result, scalar, stream); result } @@ -148,28 +217,30 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn small_scalar_mul_assign_async( + pub unsafe fn scalar_mul_assign_async( &self, - ct: &mut CudaUnsignedRadixCiphertext, - scalar: u64, + ct: &mut T, + scalar: Scalar, stream: &CudaStream, - ) { + ) where + Scalar: ScalarMultiplier + DecomposableInto + CastInto, + T: CudaIntegerRadixCiphertext, + { if !ct.block_carries_are_empty() { self.full_propagate_assign_async(ct, stream); }; - self.unchecked_small_scalar_mul_assign_async(ct, scalar, stream); + self.unchecked_scalar_mul_assign_async(ct, scalar, stream); self.full_propagate_assign_async(ct, stream); } - pub fn small_scalar_mul_assign( - &self, - ct: &mut CudaUnsignedRadixCiphertext, - scalar: u64, - stream: &CudaStream, - ) { + pub fn scalar_mul_assign(&self, ct: &mut T, scalar: Scalar, stream: &CudaStream) + where + Scalar: ScalarMultiplier + DecomposableInto + CastInto, + T: CudaIntegerRadixCiphertext, + { unsafe { - self.small_scalar_mul_assign_async(ct, scalar, stream); + self.scalar_mul_assign_async(ct, scalar, stream); } stream.synchronize(); } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs index 33332d8468..573ed0d219 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs @@ -4,6 +4,7 @@ use crate::integer::block_decomposition::DecomposableInto; use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; use crate::integer::gpu::server_key::CudaServerKey; use crate::integer::server_key::TwosComplementNegation; +use crate::prelude::CastInto; impl CudaServerKey { /// Computes homomorphically a subtraction between a ciphertext and a scalar. 
@@ -45,7 +46,7 @@ impl CudaServerKey { /// ``` pub fn unchecked_scalar_sub(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T where - Scalar: DecomposableInto + Numeric + TwosComplementNegation, + Scalar: DecomposableInto + Numeric + TwosComplementNegation + CastInto, T: CudaIntegerRadixCiphertext, { let mut result = unsafe { ct.duplicate_async(stream) }; @@ -63,7 +64,7 @@ impl CudaServerKey { scalar: Scalar, stream: &CudaStream, ) where - Scalar: DecomposableInto + Numeric + TwosComplementNegation, + Scalar: DecomposableInto + Numeric + TwosComplementNegation + CastInto, T: CudaIntegerRadixCiphertext, { let negated_scalar = scalar.twos_complement_negation(); @@ -77,7 +78,7 @@ impl CudaServerKey { scalar: Scalar, stream: &CudaStream, ) where - Scalar: DecomposableInto + Numeric + TwosComplementNegation, + Scalar: DecomposableInto + Numeric + TwosComplementNegation + CastInto, T: CudaIntegerRadixCiphertext, { unsafe { @@ -125,7 +126,7 @@ impl CudaServerKey { /// ``` pub fn scalar_sub(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T where - Scalar: DecomposableInto + Numeric + TwosComplementNegation, + Scalar: DecomposableInto + Numeric + TwosComplementNegation + CastInto, T: CudaIntegerRadixCiphertext, { let mut result = unsafe { ct.duplicate_async(stream) }; @@ -143,7 +144,7 @@ impl CudaServerKey { scalar: Scalar, stream: &CudaStream, ) where - Scalar: DecomposableInto + Numeric + TwosComplementNegation, + Scalar: DecomposableInto + Numeric + TwosComplementNegation + CastInto, T: CudaIntegerRadixCiphertext, { if !ct.block_carries_are_empty() { @@ -156,7 +157,7 @@ impl CudaServerKey { pub fn scalar_sub_assign(&self, ct: &mut T, scalar: Scalar, stream: &CudaStream) where - Scalar: DecomposableInto + Numeric + TwosComplementNegation, + Scalar: DecomposableInto + Numeric + TwosComplementNegation + CastInto, T: CudaIntegerRadixCiphertext, { unsafe { diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs index 52a288607c..0e4a21e949 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod test_neg; pub(crate) mod test_rotate; pub(crate) mod test_scalar_add; pub(crate) mod test_scalar_bitwise_op; +pub(crate) mod test_scalar_mul; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_add.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_add.rs index 562a5ea5e3..185a1ae45d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_add.rs @@ -2,7 +2,7 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ +use crate::integer::server_key::radix_parallel::tests_signed::test_add::{ signed_default_add_test, signed_unchecked_add_test, }; use crate::shortint::parameters::*; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_bitwise_op.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_bitwise_op.rs index bd77845f9f..0d7db13cae 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_bitwise_op.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_bitwise_op.rs @@ -2,7 +2,7 @@ use 
crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ +use crate::integer::server_key::radix_parallel::tests_signed::test_bitwise_op::{ signed_default_bitand_test, signed_default_bitnot_test, signed_default_bitor_test, signed_default_bitxor_test, signed_unchecked_bitand_test, signed_unchecked_bitor_test, signed_unchecked_bitxor_test, diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_mul.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_mul.rs index ed009dc2df..d6a0627de3 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_mul.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_mul.rs @@ -2,7 +2,7 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ +use crate::integer::server_key::radix_parallel::tests_signed::test_mul::{ signed_default_mul_test, signed_unchecked_mul_test, }; use crate::shortint::parameters::*; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_neg.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_neg.rs index 7b8fb2198c..b05ad75d9d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_neg.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_neg.rs @@ -2,7 +2,7 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ +use crate::integer::server_key::radix_parallel::tests_signed::test_neg::{ signed_default_neg_test, signed_unchecked_neg_test, }; use crate::shortint::parameters::*; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs index 9248e64e40..7303db2c5e 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs @@ -2,7 +2,7 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ +use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_add::{ signed_default_scalar_add_test, signed_unchecked_scalar_add_test, }; use crate::shortint::parameters::*; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_bitwise_op.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_bitwise_op.rs index d9fed7d861..5adbd2c67d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_bitwise_op.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_bitwise_op.rs @@ -2,7 +2,7 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ +use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_bitwise_op::{ signed_default_scalar_bitand_test, signed_default_scalar_bitor_test, signed_default_scalar_bitxor_test, }; diff --git 
a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_mul.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_mul.rs new file mode 100644 index 0000000000..bcf7572d4a --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_mul.rs @@ -0,0 +1,16 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parametrized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_mul::signed_unchecked_scalar_mul_test; +use crate::shortint::parameters::*; + +create_gpu_parametrized_test!(integer_signed_unchecked_scalar_mul); + +fn integer_signed_unchecked_scalar_mul
<P>
(param: P) +where + P: Into<PBSParameters>, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_mul); + signed_unchecked_scalar_mul_test(param, executor); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_shift.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_shift.rs index d6eda8ff89..526c999018 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_shift.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_shift.rs @@ -2,7 +2,7 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ +use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_shift::{ signed_default_scalar_left_shift_test, signed_default_scalar_right_shift_test, signed_unchecked_scalar_left_shift_test, signed_unchecked_scalar_right_shift_test, }; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_sub.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_sub.rs index 585cba9854..b6655ca1dd 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_sub.rs @@ -2,7 +2,7 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::signed_unchecked_scalar_sub_test; +use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_sub::signed_unchecked_scalar_sub_test; use crate::shortint::parameters::*; create_gpu_parametrized_test!(integer_signed_unchecked_scalar_sub); diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_sub.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_sub.rs index 7b3e7cc6b2..74c7d7f62d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_sub.rs @@ -2,7 +2,7 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ +use crate::integer::server_key::radix_parallel::tests_signed::test_sub::{ signed_default_sub_test, signed_unchecked_sub_test, }; use crate::shortint::parameters::*; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index d619833a40..6f11988704 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod test_neg; pub(crate) mod test_rotate; pub(crate) mod test_scalar_add; pub(crate) mod test_scalar_bitwise_op; +pub(crate) mod test_scalar_mul; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; @@ -82,7 +83,6 @@ impl GpuFunctionExecutor { } // Unchecked operations -create_gpu_parametrized_test!(integer_unchecked_small_scalar_mul); create_gpu_parametrized_test!(integer_unchecked_eq); create_gpu_parametrized_test!(integer_unchecked_ne); create_gpu_parametrized_test!(integer_unchecked_gt); @@ -104,7 +104,6 @@ create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_left);
create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_right); // Default operations -create_gpu_parametrized_test!(integer_small_scalar_mul); create_gpu_parametrized_test!(integer_eq); create_gpu_parametrized_test!(integer_ne); create_gpu_parametrized_test!(integer_gt); @@ -361,14 +360,6 @@ where } } -fn integer_unchecked_small_scalar_mul
<P>
(param: P) -where - P: Into<PBSParameters>, -{ - let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_small_scalar_mul); - unchecked_small_scalar_mul_test(param, executor); -} - -fn integer_unchecked_eq
<P>
(param: P) where P: Into<PBSParameters> + Copy, @@ -1226,14 +1217,6 @@ where } } -fn integer_small_scalar_mul
<P>
(param: P) -where - P: Into<PBSParameters>, -{ - let executor = GpuFunctionExecutor::new(&CudaServerKey::small_scalar_mul); - default_small_scalar_mul_test(param, executor); -} - -fn integer_eq
<P>
(param: P) where P: Into<PBSParameters> + Copy, diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_mul.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_mul.rs new file mode 100644 index 0000000000..7c2fa70e86 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_mul.rs @@ -0,0 +1,27 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parametrized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ + default_scalar_mul_test, unchecked_scalar_mul_corner_cases_test, +}; +use crate::shortint::parameters::*; + +create_gpu_parametrized_test!(integer_unchecked_scalar_mul); +create_gpu_parametrized_test!(integer_scalar_mul); + +fn integer_unchecked_scalar_mul
<P>
(param: P) +where + P: Into<PBSParameters> + Copy, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_mul); + unchecked_scalar_mul_corner_cases_test(param, executor); +} + +fn integer_scalar_mul
<P>
(param: P) +where + P: Into<PBSParameters> + Copy, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_mul); + default_scalar_mul_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs index 005e76cee0..98bba139b3 100644 --- a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs +++ b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs @@ -1116,9 +1116,10 @@ pub(crate) mod tests_unsigned { pub(crate) mod tests_signed { use super::*; use crate::integer::keycache::KEY_CACHE; - use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - random_non_zero_value, signed_add_under_modulus, NB_CTXT, NB_TESTS_SMALLER, + use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, signed_add_under_modulus, }; + use crate::integer::server_key::radix_parallel::tests_unsigned::{NB_CTXT, NB_TESTS_SMALLER}; use crate::integer::{IntegerKeyKind, RadixClientKey}; use crate::shortint::PBSParameters; use rand::Rng; diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs index 88088a471b..d7702fc17a 100644 --- a/tfhe/src/integer/server_key/radix_parallel/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -24,8 +24,6 @@ mod ilog2; #[cfg(test)] pub(crate) mod tests_cases_comparisons; #[cfg(test)] -pub(crate) mod tests_cases_signed; -#[cfg(test)] pub(crate) mod tests_cases_unsigned; #[cfg(test)] pub(crate) mod tests_signed; diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_mul.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_mul.rs index 08e6c675e9..2f48db4d69 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_mul.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_mul.rs @@ -1,322 +1,10 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; -use crate::integer::ciphertext::{IntegerRadixCiphertext, RadixCiphertext}; +use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::radix::scalar_mul::ScalarMultiplier; -use crate::integer::server_key::CheckError; use crate::integer::ServerKey; use rayon::prelude::*; impl ServerKey { - /// Computes homomorphically a multiplication between a scalar and a ciphertext. - /// - /// This function computes the operation without checking if it exceeds the capacity of the - /// ciphertext. - /// - /// The result is returned as a new ciphertext.
- /// - /// # Example - /// - /// ```rust - /// use tfhe::integer::gen_keys_radix; - /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - /// - /// // We have 4 * 2 = 8 bits of message - /// let size = 4; - /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size); - /// - /// let msg = 30; - /// let scalar = 3; - /// - /// let ct = cks.encrypt(msg); - /// - /// // Compute homomorphically a scalar multiplication: - /// let ct_res = sks.unchecked_small_scalar_mul_parallelized(&ct, scalar); - /// - /// let clear: u64 = cks.decrypt(&ct_res); - /// assert_eq!(scalar * msg, clear); - /// ``` - pub fn unchecked_small_scalar_mul_parallelized( - &self, - ctxt: &RadixCiphertext, - scalar: u64, - ) -> RadixCiphertext { - let mut ct_result = ctxt.clone(); - self.unchecked_small_scalar_mul_assign_parallelized(&mut ct_result, scalar); - ct_result - } - - pub fn unchecked_small_scalar_mul_assign_parallelized( - &self, - ctxt: &mut RadixCiphertext, - scalar: u64, - ) { - ctxt.blocks.par_iter_mut().for_each(|ct_i| { - self.key.unchecked_scalar_mul_assign(ct_i, scalar as u8); - }); - } - - /// Computes homomorphically a multiplication between a scalar and a ciphertext. - /// - /// If the operation can be performed, the result is returned in a new ciphertext. - /// Otherwise a [CheckError] is returned. - /// - /// # Example - /// - /// ```rust - /// use tfhe::integer::gen_keys_radix; - /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - /// - /// // We have 4 * 2 = 8 bits of message - /// let size = 4; - /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size); - /// - /// let msg = 33; - /// let scalar = 3; - /// - /// let ct = cks.encrypt(msg); - /// - /// // Compute homomorphically a scalar multiplication: - /// let ct_res = sks.checked_small_scalar_mul_parallelized(&ct, scalar); - /// - /// match ct_res { - /// Err(x) => panic!("{:?}", x), - /// Ok(y) => { - /// let clear: u64 = cks.decrypt(&y); - /// assert_eq!(msg * scalar, clear); - /// } - /// } - /// ``` - pub fn checked_small_scalar_mul_parallelized( - &self, - ct: &RadixCiphertext, - scalar: u64, - ) -> Result { - // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext - self.is_small_scalar_mul_possible(ct, scalar)?; - Ok(self.unchecked_small_scalar_mul_parallelized(ct, scalar)) - } - - /// Computes homomorphically a multiplication between a scalar and a ciphertext. - /// - /// If the operation can be performed, the result is assigned to the ciphertext given - /// as parameter. - /// Otherwise a [CheckError] is returned. 
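For contrast with the generic path that replaces it: the per-block loop deleted above (`unchecked_small_scalar_mul_assign_parallelized`) multiplies every radix block by the scalar independently (the real loop uses rayon's `par_iter_mut`) and relies on a later carry propagation to re-normalize. A plain-Rust model of that strategy, assuming 2-bit message blocks in base 4; all names are illustrative.

```rust
// Model of the removed per-block small-scalar multiply: with 2-bit messages
// and 2 carry bits, a block times a scalar < 4 still fits in the 4-bit block.
fn small_scalar_mul_blocks(blocks: &mut [u8], scalar: u8) {
    for block in blocks.iter_mut() {
        *block *= scalar; // per-block product; carries are cleaned up later
    }
}

fn main() {
    // 30 in base-4 digits, least significant first: [2, 3, 1]
    let mut blocks = [2u8, 3, 1, 0];
    small_scalar_mul_blocks(&mut blocks, 3);
    assert_eq!(blocks, [6, 9, 3, 0]);
    // Carry propagation would re-normalize this to 90 = [2, 2, 1, 1] in base 4.
}
```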
- /// - /// # Example - /// - /// ```rust - /// use tfhe::integer::gen_keys_radix; - /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - /// - /// // We have 4 * 2 = 8 bits of message - /// let size = 4; - /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size); - /// - /// let msg = 33; - /// let scalar = 3; - /// - /// let mut ct = cks.encrypt(msg); - /// - /// // Compute homomorphically a scalar multiplication: - /// sks.checked_small_scalar_mul_assign_parallelized(&mut ct, scalar); - /// - /// let clear_res: u64 = cks.decrypt(&ct); - /// assert_eq!(clear_res, msg * scalar); - /// ``` - pub fn checked_small_scalar_mul_assign_parallelized( - &self, - ct: &mut RadixCiphertext, - scalar: u64, - ) -> Result<(), CheckError> { - // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext - self.is_small_scalar_mul_possible(ct, scalar)?; - self.unchecked_small_scalar_mul_assign_parallelized(ct, scalar); - Ok(()) - } - - /// Computes homomorphically a multiplication between a scalar and a ciphertext. - /// - /// `small` means the scalar value shall fit in a __shortint block__. - /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2_KS_PBS, - /// the scalar should fit in 2 bits. - /// - /// The result is returned as a new ciphertext. - /// - /// # Example - /// - /// ```rust - /// use tfhe::integer::gen_keys_radix; - /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - /// - /// // We have 4 * 2 = 8 bits of message - /// let modulus = 1 << 8; - /// let size = 4; - /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size); - /// - /// let msg = 13; - /// let scalar = 3; - /// - /// let mut ct = cks.encrypt(msg); - /// - /// // Compute homomorphically a scalar multiplication: - /// let ct_res = sks.smart_small_scalar_mul_parallelized(&mut ct, scalar); - /// - /// // Decrypt: - /// let clear: u64 = cks.decrypt(&ct_res); - /// assert_eq!(msg * scalar % modulus, clear); - /// ``` - pub fn smart_small_scalar_mul_parallelized( - &self, - ctxt: &mut RadixCiphertext, - scalar: u64, - ) -> RadixCiphertext { - if self.is_small_scalar_mul_possible(ctxt, scalar).is_err() { - self.full_propagate_parallelized(ctxt); - } - self.is_small_scalar_mul_possible(ctxt, scalar).unwrap(); - self.unchecked_small_scalar_mul_parallelized(ctxt, scalar) - } - - /// Computes homomorphically a multiplication between a scalar and a ciphertext. - /// - /// `small` means the scalar shall value fit in a __shortint block__. - /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2_KS_PBS, - /// the scalar should fit in 2 bits. 
- /// - /// The result is assigned to the input ciphertext - /// - /// # Example - /// - /// ```rust - /// use tfhe::integer::gen_keys_radix; - /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - /// - /// // We have 4 * 2 = 8 bits of message - /// let modulus = 1 << 8; - /// let size = 4; - /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size); - /// - /// let msg = 9; - /// let scalar = 3; - /// - /// let mut ct = cks.encrypt(msg); - /// - /// // Compute homomorphically a scalar multiplication: - /// sks.smart_small_scalar_mul_assign_parallelized(&mut ct, scalar); - /// - /// // Decrypt: - /// let clear: u64 = cks.decrypt(&ct); - /// assert_eq!(msg * scalar % modulus, clear); - /// ``` - pub fn smart_small_scalar_mul_assign_parallelized( - &self, - ctxt: &mut RadixCiphertext, - scalar: u64, - ) { - if self.is_small_scalar_mul_possible(ctxt, scalar).is_err() { - self.full_propagate_parallelized(ctxt); - } - self.is_small_scalar_mul_possible(ctxt, scalar).unwrap(); - self.unchecked_small_scalar_mul_assign_parallelized(ctxt, scalar); - } - - /// Computes homomorphically a multiplication between a scalar and a ciphertext. - /// - /// `small` means the scalar value shall fit in a __shortint block__. - /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2_KS_PBS, - /// the scalar should fit in 2 bits. - /// - /// The result is returned as a new ciphertext. - /// - /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will - /// check that the input ciphertexts block carries are empty and clears them if it's not the - /// case and the operation requires it. It outputs a ciphertext whose block carries are always - /// empty. - /// - /// This means that when using only "default" operations, a given operation (like add for - /// example) has always the same performance characteristics from one call to another and - /// guarantees correctness by pre-emptively clearing carries of output ciphertexts. - /// - /// # Example - /// - /// ```rust - /// use tfhe::integer::gen_keys_radix; - /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - /// - /// // We have 4 * 2 = 8 bits of message - /// let modulus = 1 << 8; - /// let size = 4; - /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size); - /// - /// let msg = 13; - /// let scalar = 3; - /// - /// let mut ct = cks.encrypt(msg); - /// - /// // Compute homomorphically a scalar multiplication: - /// let ct_res = sks.small_scalar_mul_parallelized(&mut ct, scalar); - /// - /// // Decrypt: - /// let clear: u64 = cks.decrypt(&ct_res); - /// assert_eq!(msg * scalar % modulus, clear); - /// ``` - pub fn small_scalar_mul_parallelized( - &self, - ctxt: &RadixCiphertext, - scalar: u64, - ) -> RadixCiphertext { - let mut ct_res = ctxt.clone(); - self.small_scalar_mul_assign_parallelized(&mut ct_res, scalar); - ct_res - } - - /// Computes homomorphically a multiplication between a scalar and a ciphertext. - /// - /// `small` means the scalar shall value fit in a __shortint block__. - /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2_KS_PBS, - /// the scalar should fit in 2 bits. - /// - /// The result is assigned to the input ciphertext - /// - /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will - /// check that the input ciphertexts block carries are empty and clears them if it's not the - /// case and the operation requires it. It outputs a ciphertext whose block carries are always - /// empty. 
- /// - /// This means that when using only "default" operations, a given operation (like add for - /// example) has always the same performance characteristics from one call to another and - /// guarantees correctness by pre-emptively clearing carries of output ciphertexts. - /// - /// # Example - /// - /// ```rust - /// use tfhe::integer::gen_keys_radix; - /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - /// - /// // We have 4 * 2 = 8 bits of message - /// let modulus = 1 << 8; - /// let size = 4; - /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, size); - /// - /// let msg = 9; - /// let scalar = 3; - /// - /// let mut ct = cks.encrypt(msg); - /// - /// // Compute homomorphically a scalar multiplication: - /// sks.small_scalar_mul_assign_parallelized(&mut ct, scalar); - /// - /// // Decrypt: - /// let clear: u64 = cks.decrypt(&ct); - /// assert_eq!(msg * scalar % modulus, clear); - /// ``` - pub fn small_scalar_mul_assign_parallelized(&self, ctxt: &mut RadixCiphertext, scalar: u64) { - if !ctxt.block_carries_are_empty() { - self.full_propagate_parallelized(ctxt); - } - self.unchecked_small_scalar_mul_assign_parallelized(ctxt, scalar); - self.full_propagate_parallelized(ctxt); - } - pub fn unchecked_scalar_mul_parallelized(&self, ct: &T, scalar: Scalar) -> T where T: IntegerRadixCiphertext, diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_signed.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_signed.rs deleted file mode 100644 index 837a129bd9..0000000000 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_signed.rs +++ /dev/null @@ -1,1933 +0,0 @@ -use crate::integer::keycache::KEY_CACHE; -use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; -use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, SignedRadixCiphertext}; -use crate::shortint::ciphertext::NoiseLevel; -use crate::shortint::PBSParameters; -use itertools::izip; -use rand::prelude::ThreadRng; -use rand::Rng; -use std::sync::Arc; - -/// Number of loop iteration within randomized tests -#[cfg(not(tarpaulin))] -pub(crate) const NB_TESTS: usize = 30; -/// Smaller number of loop iteration within randomized test, -/// meant for test where the function tested is more expensive -#[cfg(not(tarpaulin))] -pub(crate) const NB_TESTS_SMALLER: usize = 10; -#[cfg(not(tarpaulin))] -pub(crate) const NB_TESTS_UNCHECKED: usize = NB_TESTS; - -// Use lower numbers for coverage to ensure fast tests to counter balance slowdown due to code -// instrumentation -#[cfg(tarpaulin)] -pub(crate) const NB_TESTS: usize = 1; -#[cfg(tarpaulin)] -pub(crate) const NB_TESTS_SMALLER: usize = 1; -/// Unchecked test cases needs a minimum number of tests of 4 in order to provide guarantees. 
-#[cfg(tarpaulin)] -pub(crate) const NB_TESTS_UNCHECKED: usize = 4; - -#[cfg(not(tarpaulin))] -pub(crate) const NB_CTXT: usize = 4; -#[cfg(tarpaulin)] -pub(crate) const NB_CTXT: usize = 2; - -//================================================================================ -// Helper functions -//================================================================================ - -pub(crate) fn signed_add_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 { - signed_overflowing_add_under_modulus(lhs, rhs, modulus).0 -} - -// Adds two signed number modulo the given modulus -// -// This is to 'simulate' i8, i16, ixy using i64 integers -// -// lhs and rhs must be in [-modulus..modulus[ -pub(crate) fn signed_overflowing_add_under_modulus( - lhs: i64, - rhs: i64, - modulus: i64, -) -> (i64, bool) { - assert!(modulus > 0); - assert!((-modulus..modulus).contains(&lhs)); - - // The code below requires rhs and lhs to be in range -modulus..modulus - // in scalar tests, rhs may exceed modulus - // so we truncate it (is the fhe ops does) - let (mut res, mut overflowed) = if (-modulus..modulus).contains(&rhs) { - (lhs + rhs, false) - } else { - // 2*modulus to get all the bits - (lhs + (rhs % (2 * modulus)), true) - }; - - if res < -modulus { - // rem_euclid(modulus) would also work - res = modulus + (res - -modulus); - overflowed = true; - } else if res > modulus - 1 { - res = -modulus + (res - modulus); - overflowed = true; - } - (res, overflowed) -} - -pub(crate) fn signed_neg_under_modulus(lhs: i64, modulus: i64) -> i64 { - assert!(modulus > 0); - let mut res = -lhs; - if res < -modulus { - // rem_euclid(modulus) would also work - res = modulus + (res - -modulus); - } else if res > modulus - 1 { - res = -modulus + (res - modulus); - } - res -} - -// Subs two signed number modulo the given modulus -// -// This is to 'simulate' i8, i16, ixy using i64 integers -// -// lhs and rhs must be in [-modulus..modulus[ -pub(crate) fn signed_sub_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 { - signed_overflowing_sub_under_modulus(lhs, rhs, modulus).0 -} - -pub(crate) fn signed_overflowing_sub_under_modulus( - lhs: i64, - rhs: i64, - modulus: i64, -) -> (i64, bool) { - // Technically we should be able to call overflowing_add_under_modulus(lhs, -rhs, ...) 
- // but due to -rhs being a 'special case' when rhs == -modulus, we have to - // so the impl here - assert!(modulus > 0); - assert!((-modulus..modulus).contains(&lhs)); - - // The code below requires rhs and lhs to be in range -modulus..modulus - // in scalar tests, rhs may exceed modulus - // so we truncate it (is the fhe ops does) - let (mut res, mut overflowed) = if (-modulus..modulus).contains(&rhs) { - (lhs - rhs, false) - } else { - // 2*modulus to get all the bits - (lhs - (rhs % (2 * modulus)), true) - }; - - if res < -modulus { - // rem_euclid(modulus) would also work - res = modulus + (res - -modulus); - overflowed = true; - } else if res > modulus - 1 { - res = -modulus + (res - modulus); - overflowed = true; - } - (res, overflowed) -} - -pub(crate) fn signed_mul_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 { - assert!(modulus > 0); - overflowing_mul_under_modulus(lhs, rhs, modulus).0 -} - -pub(crate) fn overflowing_mul_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> (i64, bool) { - let (mut res, mut overflowed) = lhs.overflowing_mul(rhs); - overflowed |= res < -modulus || res >= modulus; - res %= modulus * 2; - if res < -modulus { - // rem_euclid(modulus) would also work - res = modulus + (res - -modulus); - } else if res > modulus - 1 { - res = -modulus + (res - modulus); - } - - (res, overflowed) -} - -pub(crate) fn absolute_value_under_modulus(lhs: i64, modulus: i64) -> i64 { - if lhs < 0 { - signed_neg_under_modulus(lhs, modulus) - } else { - lhs - } -} - -pub(crate) fn signed_left_shift_under_modulus(lhs: i64, rhs: u32, modulus: i64) -> i64 { - signed_mul_under_modulus(lhs, 1 << rhs, modulus) -} - -pub(crate) fn signed_right_shift_under_modulus(lhs: i64, rhs: u32, _modulus: i64) -> i64 { - lhs >> rhs -} - -pub(crate) fn signed_div_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 { - // in signed integers, -modulus can be represented, but +modulus cannot - // thus, when dividing: -128 / -1 = 128 the results overflows to -128 - assert!(modulus > 0); - let mut res = lhs / rhs; - if res < -modulus { - // rem_euclid(modulus) would also work - res = modulus + (res - -modulus); - } else if res > modulus - 1 { - res = -modulus + (res - modulus); - } - res -} - -pub(crate) fn signed_rem_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 { - assert!(modulus > 0); - let q = signed_div_under_modulus(lhs, rhs, modulus); - let q_times_rhs = signed_mul_under_modulus(q, rhs, modulus); - signed_sub_under_modulus(lhs, q_times_rhs, modulus) -} - -pub(crate) fn signed_div_rem_floor_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> (i64, i64) { - let mut q = signed_div_under_modulus(lhs, rhs, modulus); - let mut r = signed_rem_under_modulus(lhs, rhs, modulus); - - if (r != 0) && ((r < 0) != (rhs < 0)) { - q = signed_sub_under_modulus(q, 1, modulus); - r = signed_add_under_modulus(r, rhs, modulus); - } - - (q, r) -} - -/// helper function to do a rotate left when the type used to store -/// the value is bigger than the actual intended bit size -pub(crate) fn rotate_left_helper(value: i64, n: u32, actual_bit_size: u32) -> i64 { - // We start with: - // [0000000000000|xxxx] - // 64 b 0 - // - // rotated will be - // [0000000000xx|xx00] - // 64 b 0 - let n = n % actual_bit_size; - let mask = 1i64.wrapping_shl(actual_bit_size) - 1; - let shifted_mask = mask.wrapping_shl(n) & !mask; - - // Value maybe be negative and so, have its msb - // set to one, so use mask to only keep the part that interest - // us - let rotated = (value & mask).rotate_left(n); - - let tmp = (rotated 
& mask) | ((rotated & shifted_mask) >> actual_bit_size); - // If the sign bit after rotation is one, - // then all bits above it needs to be one - let new_sign_bit = (tmp >> (actual_bit_size - 1)) & 1; - let mut pad = -new_sign_bit; - pad <<= actual_bit_size; // only bits above actual_bit_size should be set - - pad | tmp -} - -/// helper function to do a rotate right when the type used to store -/// the value is bigger than the actual intended bit size -pub(crate) fn rotate_right_helper(value: i64, n: u32, actual_bit_size: u32) -> i64 { - // We start with: - // [yyyyyyyyyyyy|xxxx] - // 64 b 0 - // where xs are bits that we are interested in - // and ys are either 0 or 1 depending on if value is positive - // - // mask: [yyyyyyyyyyyy|mmmm] - // shifted_ mask: [mmyyyyyyyyyy|0000] - // - // rotated will be - // [xxyyyyyyyyyy|00xx] - // 64 b 0 - // - // To get the 'cycled' bits where they should be, - // we get them using a mask then shift - let n = n % actual_bit_size; - let mask = 1i64.wrapping_shl(actual_bit_size) - 1; - // shifted mask only needs the bits that cycled - let shifted_mask = mask.rotate_right(n) & !mask; - - // Value maybe be negative and so, have its msb - // set to one, so use mask to only keep the part that interest - // us - let rotated = (value & mask).rotate_right(n); - - let tmp = (rotated & mask) | ((rotated & shifted_mask) >> (u64::BITS - actual_bit_size)); - // If the sign bit after rotation is one, - // then all bits above it needs to be one - let new_sign_bit = (tmp >> (actual_bit_size - 1)) & 1; - let mut pad = -new_sign_bit; - pad <<= actual_bit_size; // only bits above actual_bit_size should be set - - pad | tmp -} - -/// Returns an array filled with random values such that: -/// - the first half contains values in [0..modulus[ -/// - the second half contains values in [-modulus..0] -pub(crate) fn random_signed_value_under_modulus( - rng: &mut rand::prelude::ThreadRng, - modulus: i64, -) -> [i64; N] { - assert!(modulus > 0); - - let mut values = [0i64; N]; - - for value in &mut values[..N / 2] { - *value = rng.gen_range(0..modulus); - } - - for value in &mut values[N / 2..] { - *value = rng.gen_range(-modulus..=0); - } - - values -} - -/// Returns an array filled with random values such that: -/// - the first half contains values in ]0..modulus[ -/// - the second half contains values in [-modulus..0[ -pub(crate) fn random_non_zero_signed_value_under_modulus( - rng: &mut rand::prelude::ThreadRng, - modulus: i64, -) -> [i64; N] { - assert!(modulus > 0); - - let mut values = [0i64; N]; - - for value in &mut values[..N / 2] { - *value = rng.gen_range(1..modulus); - } - - for value in &mut values[N / 2..] 
{ - *value = rng.gen_range(-modulus..0); - } - - values -} - -/// Returns an iterator that yields pairs of i64 values in range `-modulus..modulus` -/// such that there is at least one pair of (P, P), (P, N), (N, N) (N, P) -/// where P means value >=0 and N means <= 0 -pub(crate) fn create_iterator_of_signed_random_pairs( - rng: &mut rand::prelude::ThreadRng, - modulus: i64, -) -> impl Iterator { - assert!(N >= 4, "N must be at least 4 to uphold the guarantee"); - let mut lhs_values = [0i64; N]; - let mut rhs_values = [0i64; N]; - - lhs_values[0] = rng.gen_range(0..modulus); - rhs_values[0] = rng.gen_range(0..modulus); - - lhs_values[1] = rng.gen_range(0..modulus); - rhs_values[1] = rng.gen_range(-modulus..=0); - - lhs_values[2] = rng.gen_range(-modulus..=0); - rhs_values[2] = rng.gen_range(-modulus..=0); - - lhs_values[3] = rng.gen_range(-modulus..=0); - rhs_values[3] = rng.gen_range(0..modulus); - - for i in 4..N { - lhs_values[i] = rng.gen_range(-modulus..modulus); - rhs_values[i] = rng.gen_range(-modulus..modulus); - } - - izip!(lhs_values, rhs_values) -} - -pub(crate) fn random_non_zero_value(rng: &mut ThreadRng, modulus: i64) -> i64 { - loop { - let value = rng.gen::() % modulus; - if value != 0 { - break value; - } - } -} - -// Signed tests - -pub(crate) fn signed_unchecked_add_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - // check some overflow behaviour - let overflowing_values = [ - (-modulus, -1, modulus - 1), - (modulus - 1, 1, -modulus), - (-modulus, -2, modulus - 2), - (modulus - 2, 2, -modulus), - ]; - for (clear_0, clear_1, expected_clear) in overflowing_values { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - assert_eq!(clear_res, expected_clear); - } - - for (clear_0, clear_1) in - create_iterator_of_signed_random_pairs::(&mut rng, modulus) - { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_default_add_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - let mut clear; - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = 
cks.encrypt_signed(clear_1); - - let mut ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let tmp_ct = executor.execute((&ctxt_0, &ctxt_1)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp_ct); - - clear = signed_add_under_modulus(clear_0, clear_1, modulus); - - // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1); - // add multiple times to raise the degree - for _ in 0..NB_TESTS_SMALLER { - ct_res = executor.execute((&ct_res, &ctxt_0)); - assert!(ct_res.block_carries_are_empty()); - clear = signed_add_under_modulus(clear, clear_0, modulus); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - // println!("clear = {}, dec_res = {}", clear, dec_res); - assert_eq!(clear, dec_res); - } - } -} - -pub(crate) fn signed_smart_add_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a mut SignedRadixCiphertext, &'a mut SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - let mut clear; - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen_range(-modulus..modulus); - let clear_1 = rng.gen_range(-modulus..modulus); - - let mut ctxt_0 = cks.encrypt_signed(clear_0); - let mut ctxt_1 = cks.encrypt_signed(clear_1); - - let mut ct_res = executor.execute((&mut ctxt_0, &mut ctxt_1)); - clear = signed_add_under_modulus(clear_0, clear_1, modulus); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - assert_eq!(clear, dec_res); - - // add multiple times to raise the degree - for _ in 0..NB_TESTS_SMALLER { - ct_res = executor.execute((&mut ct_res, &mut ctxt_0)); - clear = signed_add_under_modulus(clear, clear_0, modulus); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - // println!("clear = {}, dec_res = {}", clear, dec_res); - assert_eq!(clear, dec_res); - } - } -} - -pub(crate) fn signed_unchecked_sub_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - // check some overflow behaviour - let overflowing_values = [ - (-modulus, 1, modulus - 1), - (modulus - 1, -1, -modulus), - (-modulus, 2, modulus - 2), - (modulus - 2, -2, -modulus), - ]; - for (clear_0, clear_1, expected_clear) in overflowing_values { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - assert_eq!(clear_res, expected_clear); - } - - for (clear_0, clear_1) in - create_iterator_of_signed_random_pairs::(&mut rng, modulus) - { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus); - 
assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_default_sub_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - let mut clear; - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - - let mut ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let tmp_ct = executor.execute((&ctxt_0, &ctxt_1)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp_ct); - - clear = signed_sub_under_modulus(clear_0, clear_1, modulus); - - // sub multiple times to raise the degree - for _ in 0..NB_TESTS_SMALLER { - ct_res = executor.execute((&ct_res, &ctxt_0)); - assert!(ct_res.block_carries_are_empty()); - clear = signed_sub_under_modulus(clear, clear_0, modulus); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - // println!("clear = {}, dec_res = {}", clear, dec_res); - assert_eq!(clear, dec_res); - } - } -} - -pub(crate) fn signed_unchecked_neg_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let ctxt_zero = sks.create_trivial_radix(0i64, NB_CTXT); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - // -modulus is a special case, its negation cannot be - // represented. 
rust by default returns -modulus - // (which is what two complement result in) - { - let clear = -modulus; - let ctxt = cks.encrypt_signed(clear); - - let ct_res = executor.execute(&ctxt); - - let dec: i64 = cks.decrypt_signed(&ct_res); - let clear_result = signed_neg_under_modulus(clear, modulus); - - assert_eq!(clear_result, dec); - assert_eq!(clear_result, -modulus); - } - - for (clear_0, _) in - create_iterator_of_signed_random_pairs::(&mut rng, modulus) - { - let ctxt_0 = cks.encrypt_signed(clear_0); - - let ct_res = executor.execute(&ctxt_0); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_neg_under_modulus(clear_0, modulus); - assert_eq!(clear_res, dec_res); - } - - // negation of trivial 0 - { - let ct_res = executor.execute(&ctxt_zero); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - assert_eq!(0, dec_res); - } -} - -pub(crate) fn signed_smart_neg_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<&'a mut SignedRadixCiphertext, SignedRadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - for _ in 0..NB_TESTS_SMALLER { - let clear = rng.gen::() % modulus; - - let mut ctxt = cks.encrypt_signed(clear); - - let mut ct_res = executor.execute(&mut ctxt); - let mut clear_res = signed_neg_under_modulus(clear, modulus); - let dec: i64 = cks.decrypt_signed(&ct_res); - assert_eq!(clear_res, dec); - - for _ in 0..NB_TESTS_SMALLER { - ct_res = executor.execute(&mut ct_res); - clear_res = signed_neg_under_modulus(clear_res, modulus); - - let dec: i64 = cks.decrypt_signed(&ct_res); - println!("clear_res: {clear_res}, dec : {dec}"); - assert_eq!(clear_res, dec); - } - } -} - -pub(crate) fn signed_default_neg_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - // -modulus is a special case, its negation cannot be - // represented. 
rust by default returns -modulus - // (which is what two complement result in) - { - let clear = -modulus; - let ctxt = cks.encrypt_signed(clear); - - let ct_res = executor.execute(&ctxt); - let tmp = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let dec: i64 = cks.decrypt_signed(&ct_res); - let clear_result = signed_neg_under_modulus(clear, modulus); - - assert_eq!(clear_result, dec); - } - - for _ in 0..NB_TESTS_SMALLER { - let clear = rng.gen::() % modulus; - - let ctxt = cks.encrypt_signed(clear); - - let ct_res = executor.execute(&ctxt); - let tmp = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let dec: i64 = cks.decrypt_signed(&ct_res); - let clear_result = signed_neg_under_modulus(clear, modulus); - - assert_eq!(clear_result, dec); - } -} - -pub(crate) fn signed_unchecked_mul_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - for (clear_0, clear_1) in - create_iterator_of_signed_random_pairs::(&mut rng, modulus) - { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_mul_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_default_mul_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - let mut clear; - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - - let mut ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let tmp_ct = executor.execute((&ctxt_0, &ctxt_1)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp_ct); - - clear = signed_mul_under_modulus(clear_0, clear_1, modulus); - - // mul multiple times to raise the degree - for _ in 0..NB_TESTS_SMALLER { - ct_res = executor.execute((&ct_res, &ctxt_0)); - assert!(ct_res.block_carries_are_empty()); - clear = signed_mul_under_modulus(clear, clear_0, modulus); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - // println!("clear = {}, dec_res = {}", clear, dec_res); - assert_eq!(clear, dec_res); - } - } -} - -pub(crate) fn signed_unchecked_scalar_add_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let sks = Arc::new(sks); - let cks = 
RadixClientKey::from(( - cks, - crate::integer::server_key::radix_parallel::tests_cases_unsigned::NB_CTXT, - )); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - // check some overflow behaviour - let overflowing_values = [ - (-modulus, -1, modulus - 1), - (modulus - 1, 1, -modulus), - (-modulus, -2, modulus - 2), - (modulus - 2, 2, -modulus), - ]; - for (clear_0, clear_1, expected_clear) in overflowing_values { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ct_res = executor.execute((&ctxt_0, clear_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - assert_eq!(clear_res, expected_clear); - } - - for _ in 0..NB_TESTS { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - - let ct_res = executor.execute((&ctxt_0, clear_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_default_scalar_add_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - let mut clear; - - let mut rng = rand::thread_rng(); - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - - let mut ct_res = executor.execute((&ctxt_0, clear_1)); - assert!(ct_res.block_carries_are_empty()); - - clear = signed_add_under_modulus(clear_0, clear_1, modulus); - - // add multiple times to raise the degree - for _ in 0..NB_TESTS_SMALLER { - let tmp = executor.execute((&ct_res, clear_1)); - ct_res = executor.execute((&ct_res, clear_1)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - clear = signed_add_under_modulus(clear, clear_1, modulus); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - assert_eq!(clear, dec_res); - } - } -} - -pub(crate) fn signed_default_overflowing_scalar_add_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, i64), - (SignedRadixCiphertext, BooleanBlock), - >, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks.clone()); - - let hardcoded_values = [ - (-modulus, -1), - (modulus - 1, 1), - (-1, -modulus), - (1, modulus - 1), - ]; - for (clear_0, clear_1) in hardcoded_values { - let ctxt_0 = cks.encrypt_signed(clear_0); - - let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); - - let 
decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - } - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - - let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); - let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp_ct, "Failed determinism check"); - assert_eq!(tmp_o, result_overflowed, "Failed determinism check"); - - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - - for _ in 0..NB_TESTS_SMALLER { - // Add non zero scalar to have non clean ciphertexts - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear_rhs = random_non_zero_value(&mut rng, modulus); - - let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2); - let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus); - let d0: i64 = cks.decrypt_signed(&ctxt_0); - assert_eq!(d0, clear_lhs, "Failed sanity decryption check"); - - let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs)); - assert!(ct_res.block_carries_are_empty()); - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - } - } - - // Test with trivial inputs - for _ in 0..4 { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT); - - let 
(encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); - - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); - let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(encrypted_overflow.0.degree.get(), 1); - assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); - } - - // Test with scalar that is bigger than ciphertext modulus - for _ in 0..2 { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen_range(modulus..=i64::MAX); - - let a = cks.encrypt_signed(clear_0); - - let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); - - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); - let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert!(decrypted_overflowed); // Actually we know its an overflow case - assert_eq!(encrypted_overflow.0.degree.get(), 1); - assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); - } -} - -pub(crate) fn signed_unchecked_scalar_sub_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - // check some overflow behaviour - let overflowing_values = [ - (-modulus, 1, modulus - 1), - (modulus - 1, -1, -modulus), - (-modulus, 2, modulus - 2), - (modulus - 2, -2, -modulus), - ]; - for (clear_0, clear_1, expected_clear) in overflowing_values { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ct_res = executor.execute((&ctxt_0, clear_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - assert_eq!(clear_res, expected_clear); - } - - for _ in 0..NB_TESTS { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - - let ct_res = executor.execute((&ctxt_0, clear_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_default_overflowing_scalar_sub_test(param: P, mut executor: T) 
-where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, i64), - (SignedRadixCiphertext, BooleanBlock), - >, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks.clone()); - - let hardcoded_values = [ - (-modulus, 1), - (modulus - 1, -1), - (1, -modulus), - (-1, modulus - 1), - ]; - for (clear_0, clear_1) in hardcoded_values { - let ctxt_0 = cks.encrypt_signed(clear_0); - - let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); - let (expected_result, expected_overflowed) = - signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - } - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - - let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); - let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp_ct, "Failed determinism check"); - assert_eq!(tmp_o, result_overflowed, "Failed determinism check"); - - let (expected_result, expected_overflowed) = - signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - - for _ in 0..NB_TESTS_SMALLER { - // Add non zero scalar to have non clean ciphertexts - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear_rhs = random_non_zero_value(&mut rng, modulus); - - let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2); - let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus); - let d0: i64 = cks.decrypt_signed(&ctxt_0); - assert_eq!(d0, clear_lhs, "Failed sanity decryption check"); - - let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs)); - assert!(ct_res.block_carries_are_empty()); - let (expected_result, expected_overflowed) = - signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, 
modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for sub, for ({clear_lhs} + {clear_rhs}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - } - } - - // Test with trivial inputs - for _ in 0..4 { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT); - - let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); - - let (expected_result, expected_overflowed) = - signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); - let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_0} - {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(encrypted_overflow.0.degree.get(), 1); - assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); - } - - // Test with scalar that is bigger than ciphertext modulus - for _ in 0..2 { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen_range(modulus..=i64::MAX); - - let a = cks.encrypt_signed(clear_0); - - let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); - - let (expected_result, expected_overflowed) = - signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); - let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert!(decrypted_overflowed); // Actually we know its an overflow case - assert_eq!(encrypted_overflow.0.degree.get(), 1); - assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); - } -} - -pub(crate) fn signed_unchecked_bitand_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - for (clear_0, clear_1) in - create_iterator_of_signed_random_pairs::(&mut rng, 
modulus) - { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 & clear_1; - assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_unchecked_bitor_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - for (clear_0, clear_1) in - create_iterator_of_signed_random_pairs::(&mut rng, modulus) - { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 | clear_1; - assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_unchecked_bitxor_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - for (clear_0, clear_1) in - create_iterator_of_signed_random_pairs::(&mut rng, modulus) - { - let ctxt_0 = cks.encrypt_signed(clear_0); - let ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 ^ clear_1; - assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_default_bitnot_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks); - - for _ in 0..NB_TESTS { - let clear_0 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - - let ct_res = executor.execute(&ctxt_0); - let ct_res2 = executor.execute(&ctxt_0); - assert_eq!(ct_res, ct_res2); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = !clear_0; - assert_eq!(clear_res, dec_res); - } -} - -pub(crate) fn signed_default_bitand_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks.clone()); - - for _ in 0..NB_TESTS { 
- let mut clear_0 = rng.gen::() % modulus; - let mut clear_1 = rng.gen::() % modulus; - - let mut ctxt_0 = cks.encrypt_signed(clear_0); - let mut ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let ct_res2 = executor.execute((&ctxt_0, &ctxt_1)); - assert_eq!(ct_res, ct_res2); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 & clear_1; - assert_eq!(clear_res, dec_res); - - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear_3 = random_non_zero_value(&mut rng, modulus); - - sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); - sks.unchecked_scalar_add_assign(&mut ctxt_1, clear_3); - - assert!(!ctxt_0.block_carries_are_empty()); - assert!(!ctxt_1.block_carries_are_empty()); - - clear_0 = signed_add_under_modulus(clear_0, clear_2, modulus); - clear_1 = signed_add_under_modulus(clear_1, clear_3, modulus); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - assert!(ct_res.block_carries_are_empty()); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - let expected_result = clear_0 & clear_1; - assert_eq!(dec_res, expected_result); - } -} - -pub(crate) fn signed_default_bitor_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks.clone()); - - for _ in 0..NB_TESTS { - let mut clear_0 = rng.gen::() % modulus; - let mut clear_1 = rng.gen::() % modulus; - - let mut ctxt_0 = cks.encrypt_signed(clear_0); - let mut ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let ct_res2 = executor.execute((&ctxt_0, &ctxt_1)); - assert_eq!(ct_res, ct_res2); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 | clear_1; - assert_eq!(clear_res, dec_res); - - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear_3 = random_non_zero_value(&mut rng, modulus); - - sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); - sks.unchecked_scalar_add_assign(&mut ctxt_1, clear_3); - - assert!(!ctxt_0.block_carries_are_empty()); - assert!(!ctxt_1.block_carries_are_empty()); - - clear_0 = signed_add_under_modulus(clear_0, clear_2, modulus); - clear_1 = signed_add_under_modulus(clear_1, clear_3, modulus); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - assert!(ct_res.block_carries_are_empty()); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - let expected_result = clear_0 | clear_1; - assert_eq!(dec_res, expected_result); - } -} - -pub(crate) fn signed_default_bitxor_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor< - (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), - SignedRadixCiphertext, - >, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks.clone()); - - for _ in 0..NB_TESTS { - let mut clear_0 = rng.gen::() % modulus; - 
let mut clear_1 = rng.gen::() % modulus; - - let mut ctxt_0 = cks.encrypt_signed(clear_0); - let mut ctxt_1 = cks.encrypt_signed(clear_1); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - let ct_res2 = executor.execute((&ctxt_0, &ctxt_1)); - assert_eq!(ct_res, ct_res2); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 ^ clear_1; - assert_eq!(clear_res, dec_res); - - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear_3 = random_non_zero_value(&mut rng, modulus); - - sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); - sks.unchecked_scalar_add_assign(&mut ctxt_1, clear_3); - - assert!(!ctxt_0.block_carries_are_empty()); - assert!(!ctxt_1.block_carries_are_empty()); - - clear_0 = signed_add_under_modulus(clear_0, clear_2, modulus); - clear_1 = signed_add_under_modulus(clear_1, clear_3, modulus); - - let ct_res = executor.execute((&ctxt_0, &ctxt_1)); - assert!(ct_res.block_carries_are_empty()); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - let expected_result = clear_0 ^ clear_1; - assert_eq!(dec_res, expected_result); - } -} - -pub(crate) fn signed_default_scalar_bitand_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks.clone()); - - for _ in 0..NB_TESTS { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let mut ctxt_0 = cks.encrypt_signed(clear_0); - - let ct_res = executor.execute((&ctxt_0, clear_1)); - let ct_res2 = executor.execute((&ctxt_0, clear_1)); - assert_eq!(ct_res, ct_res2); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 & clear_1; - assert_eq!(clear_res, dec_res); - - let clear_2 = random_non_zero_value(&mut rng, modulus); - sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); - assert!(!ctxt_0.block_carries_are_empty()); - - let ct_res = executor.execute((&ctxt_0, clear_1)); - assert!(ct_res.block_carries_are_empty()); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - let expected_result = signed_add_under_modulus(clear_0, clear_2, modulus) & clear_1; - assert_eq!(dec_res, expected_result); - } -} - -pub(crate) fn signed_default_scalar_bitor_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks.clone()); - - for _ in 0..NB_TESTS { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let mut ctxt_0 = cks.encrypt_signed(clear_0); - - let ct_res = executor.execute((&ctxt_0, clear_1)); - let ct_res2 = executor.execute((&ctxt_0, clear_1)); - assert_eq!(ct_res, ct_res2); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 | clear_1; - assert_eq!(clear_res, dec_res); - - let clear_2 = random_non_zero_value(&mut rng, 
modulus); - - sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); - assert!(!ctxt_0.block_carries_are_empty()); - - let ct_res = executor.execute((&ctxt_0, clear_1)); - assert!(ct_res.block_carries_are_empty()); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - let expected_result = signed_add_under_modulus(clear_0, clear_2, modulus) | clear_1; - assert_eq!(dec_res, expected_result); - } -} - -pub(crate) fn signed_default_scalar_bitxor_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - executor.setup(&cks, sks.clone()); - - for _ in 0..NB_TESTS { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let mut ctxt_0 = cks.encrypt_signed(clear_0); - - let ct_res = executor.execute((&ctxt_0, clear_1)); - let ct_res2 = executor.execute((&ctxt_0, clear_1)); - assert_eq!(ct_res, ct_res2); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = clear_0 ^ clear_1; - assert_eq!(clear_res, dec_res); - - let clear_2 = random_non_zero_value(&mut rng, modulus); - - sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); - assert!(!ctxt_0.block_carries_are_empty()); - - let ct_res = executor.execute((&ctxt_0, clear_1)); - assert!(ct_res.block_carries_are_empty()); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - - let expected_result = signed_add_under_modulus(clear_0, clear_2, modulus) ^ clear_1; - assert_eq!(dec_res, expected_result); - } -} - -pub(crate) fn signed_unchecked_scalar_left_shift_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - executor.setup(&cks, sks); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - assert!(modulus > 0); - assert!((modulus as u64).is_power_of_two()); - let nb_bits = modulus.ilog2() + 1; // We are using signed numbers - - for _ in 0..NB_TESTS { - let clear = rng.gen::() % modulus; - let clear_shift = rng.gen::(); - - let ct = cks.encrypt_signed(clear); - - // case when 0 <= shift < nb_bits - { - let clear_shift = clear_shift % nb_bits; - let ct_res = executor.execute((&ct, clear_shift as i64)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let expected = signed_left_shift_under_modulus(clear, clear_shift, modulus); - assert_eq!(expected, dec_res); - } - - // case when shift >= nb_bits - { - let clear_shift = clear_shift.saturating_add(nb_bits); - let ct_res = executor.execute((&ct, clear_shift as i64)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let expected = signed_left_shift_under_modulus(clear, clear_shift % nb_bits, modulus); - assert_eq!(expected, dec_res); - } - } -} - -pub(crate) fn signed_unchecked_scalar_right_shift_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = 
RadixClientKey::from((cks, NB_CTXT)); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - executor.setup(&cks, sks); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - assert!(modulus > 0); - assert!((modulus as u64).is_power_of_two()); - let nb_bits = modulus.ilog2() + 1; // We are using signed numbers - - for _ in 0..NB_TESTS { - let clear = rng.gen::() % modulus; - let clear_shift = rng.gen::(); - - let ct = cks.encrypt_signed(clear); - - // case when 0 <= shift < nb_bits - { - let clear_shift = clear_shift % nb_bits; - let ct_res = executor.execute((&ct, clear_shift as i64)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let expected = signed_right_shift_under_modulus(clear, clear_shift, modulus); - assert_eq!(expected, dec_res); - } - - // case when shift >= nb_bits - { - let clear_shift = clear_shift.saturating_add(nb_bits); - let ct_res = executor.execute((&ct, clear_shift as i64)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let expected = signed_right_shift_under_modulus(clear, clear_shift % nb_bits, modulus); - assert_eq!(expected, dec_res); - } - } -} - -pub(crate) fn signed_default_scalar_left_shift_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - executor.setup(&cks, sks.clone()); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - assert!(modulus > 0); - assert!((modulus as u64).is_power_of_two()); - let nb_bits = modulus.ilog2() + 1; // We are using signed numbers - - for _ in 0..NB_TESTS_SMALLER { - let mut clear = rng.gen::() % modulus; - - let offset = random_non_zero_value(&mut rng, modulus); - - let mut ct = cks.encrypt_signed(clear); - sks.unchecked_scalar_add_assign(&mut ct, offset); - clear = signed_add_under_modulus(clear, offset, modulus); - - // case when 0 <= shift < nb_bits - { - let clear_shift = rng.gen::() % nb_bits; - let ct_res = executor.execute((&ct, clear_shift as i64)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_left_shift_under_modulus(clear, clear_shift, modulus); - assert_eq!( - clear_res, dec_res, - "Invalid left shift result, for '{clear} << {clear_shift}', \ - expected: {clear_res}, got: {dec_res}" - ); - - let ct_res2 = executor.execute((&ct, clear_shift as i64)); - assert_eq!(ct_res, ct_res2, "Failed determinism check"); - } - - // case when shift >= nb_bits - { - let clear_shift = rng.gen_range(nb_bits..=u32::MAX); - let ct_res = executor.execute((&ct, clear_shift as i64)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - // We mimic wrapping_shl manually as we use a bigger type - // than the nb_bits we actually simulate in this test - let clear_res = signed_left_shift_under_modulus(clear, clear_shift % nb_bits, modulus); - assert_eq!( - clear_res, dec_res, - "Invalid left shift result, for '{clear} << {clear_shift}', \ - expected: {clear_res}, got: {dec_res}" - ); - - let ct_res2 = executor.execute((&ct, clear_shift as i64)); - assert_eq!(ct_res, ct_res2, "Failed determinism check"); - } - } -} - -pub(crate) fn signed_default_scalar_right_shift_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), 
SignedRadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - executor.setup(&cks, sks.clone()); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - assert!(modulus > 0); - assert!((modulus as u64).is_power_of_two()); - let nb_bits = modulus.ilog2() + 1; // We are using signed numbers - - for _ in 0..NB_TESTS_SMALLER { - let mut clear = rng.gen::() % modulus; - - let offset = random_non_zero_value(&mut rng, modulus); - - let mut ct = cks.encrypt_signed(clear); - sks.unchecked_scalar_add_assign(&mut ct, offset); - clear = signed_add_under_modulus(clear, offset, modulus); - - // case when 0 <= shift < nb_bits - { - let clear_shift = rng.gen::() % nb_bits; - let ct_res = executor.execute((&ct, clear_shift as i64)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = signed_right_shift_under_modulus(clear, clear_shift, modulus); - assert_eq!( - clear_res, dec_res, - "Invalid right shift result, for '{clear} >> {clear_shift}', \ - expected: {clear_res}, got: {dec_res}" - ); - - let ct_res2 = executor.execute((&ct, clear_shift as i64)); - assert_eq!(ct_res, ct_res2, "Failed determinism check"); - } - - // case when shift >= nb_bits - { - let clear_shift = rng.gen_range(nb_bits..=u32::MAX); - let ct_res = executor.execute((&ct, clear_shift as i64)); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - // We mimic wrapping_shl manually as we use a bigger type - // than the nb_bits we actually simulate in this test - let clear_res = signed_right_shift_under_modulus(clear, clear_shift % nb_bits, modulus); - assert_eq!( - clear_res, dec_res, - "Invalid right shift result, for '{clear} >> {clear_shift}', \ - expected: {clear_res}, got: {dec_res}" - ); - - let ct_res2 = executor.execute((&ct, clear_shift as i64)); - assert_eq!(ct_res, ct_res2, "Failed determinism check"); - } - } -} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index e0fe84d0fa..feb86341af 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -687,38 +687,6 @@ where ); } } -pub(crate) fn unchecked_small_scalar_mul_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let scalar_modulus = cks.parameters().message_modulus().0 as u64; - - executor.setup(&cks, sks); - - for _ in 0..NB_TESTS { - let clear = rng.gen::() % modulus; - let scalar = rng.gen::() % scalar_modulus; - - let ct = cks.encrypt(clear); - - let encrypted_result = executor.execute((&ct, scalar)); - let decrypted_result: u64 = cks.decrypt(&encrypted_result); - - let expected_result = clear.wrapping_mul(scalar) % modulus; - - assert_eq!(decrypted_result, expected_result); - } -} - pub(crate) fn unchecked_scalar_mul_corner_cases_test(param: P, mut executor: T) where P: Into, @@ -1611,45 +1579,6 @@ where 
assert_eq!(clear.wrapping_mul(scalar as u128), dec_res); } -pub(crate) fn smart_small_scalar_mul_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - - let scalar_modulus = cks.parameters().message_modulus().0 as u64; - - executor.setup(&cks, sks); - - let mut clear_res; - for _ in 0..NB_TESTS_SMALLER { - let clear = rng.gen::() % modulus; - let scalar = rng.gen::() % scalar_modulus; - - let mut ct = cks.encrypt(clear); - - let mut ct_res = executor.execute((&mut ct, scalar)); - - clear_res = clear * scalar; - for _ in 0..NB_TESTS_SMALLER { - // scalar multiplication - ct_res = executor.execute((&mut ct_res, scalar)); - clear_res *= scalar; - } - - let dec_res: u64 = cks.decrypt(&ct_res); - assert_eq!(clear_res % modulus, dec_res); - } -} - //============================================================================= // Default Tests //============================================================================= @@ -2934,51 +2863,6 @@ where } } -pub(crate) fn default_small_scalar_mul_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - - let scalar_modulus = cks.parameters().message_modulus().0 as u64; - - executor.setup(&cks, sks); - - let mut clear_res; - for _ in 0..NB_TESTS_SMALLER { - let clear = rng.gen::() % modulus; - let scalar = rng.gen::() % scalar_modulus; - - let ct = cks.encrypt(clear); - - let mut ct_res = executor.execute((&ct, scalar)); - assert!(ct_res.block_carries_are_empty()); - - clear_res = clear * scalar; - for _ in 0..NB_TESTS_SMALLER { - // scalar multiplication - let tmp = executor.execute((&ct_res, scalar)); - ct_res = executor.execute((&ct_res, scalar)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(tmp, ct_res); - clear_res = clear_res.wrapping_mul(scalar); - } - - let dec_res: u64 = cks.decrypt(&ct_res); - assert_eq!(clear_res % modulus, dec_res); - } -} - pub(crate) fn default_scalar_mul_test(param: P, mut executor: T) where P: Into, diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs index 3fd638e191..f6eff6231f 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs @@ -6,15 +6,17 @@ pub(crate) mod test_neg; pub(crate) mod test_rotate; pub(crate) mod test_scalar_add; pub(crate) mod test_scalar_bitwise_op; +pub(crate) mod test_scalar_mul; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; pub(crate) mod test_sub; use crate::integer::keycache::KEY_CACHE; -use crate::integer::server_key::radix_parallel::tests_cases_signed::*; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; -use 
crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
+use crate::integer::server_key::radix_parallel::tests_unsigned::{
+    CpuFunctionExecutor, NB_CTXT, NB_TESTS, NB_TESTS_SMALLER, NB_TESTS_UNCHECKED,
+};
 use crate::integer::tests::create_parametrized_test;
 use crate::integer::{
     BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext,
@@ -22,7 +24,8 @@ use crate::integer::{
 #[cfg(tarpaulin)]
 use crate::shortint::parameters::coverage_parameters::*;
 use crate::shortint::parameters::*;
-use itertools::iproduct;
+use itertools::{iproduct, izip};
+use rand::prelude::ThreadRng;
 use rand::Rng;
 use std::sync::Arc;
@@ -507,32 +510,11 @@ where
 // Unchecked Scalar Tests
 //================================================================================
 
-create_parametrized_test!(integer_signed_unchecked_scalar_mul);
 create_parametrized_test!(integer_signed_unchecked_scalar_rotate_left);
 create_parametrized_test!(integer_signed_unchecked_scalar_rotate_right);
 create_parametrized_test!(integer_signed_unchecked_scalar_div_rem);
 create_parametrized_test!(integer_signed_unchecked_scalar_div_rem_floor);
 
-fn integer_signed_unchecked_scalar_mul(param: impl Into<PBSParameters>) {
-    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
-
-    let mut rng = rand::thread_rng();
-
-    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
-
-    for _ in 0..NB_TESTS {
-        let clear_0 = rng.gen::<i64>() % modulus;
-        let clear_1 = rng.gen::<i64>() % modulus;
-
-        let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
-
-        let ct_res = sks.unchecked_scalar_mul_parallelized(&ctxt_0, clear_1);
-        let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
-        let clear_res = signed_mul_under_modulus(clear_0, clear_1, modulus);
-        assert_eq!(clear_res, dec_res);
-    }
-}
-
 fn integer_signed_unchecked_scalar_rotate_left(param: impl Into<PBSParameters>) {
     let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
@@ -1042,3 +1024,317 @@ where
         }
     }
 }
+
+//================================================================================
+// Helper functions
+//================================================================================
+
+pub(crate) fn signed_add_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 {
+    signed_overflowing_add_under_modulus(lhs, rhs, modulus).0
+}
+
+// Adds two signed numbers modulo the given modulus
+//
+// This is to 'simulate' i8, i16, ixy using i64 integers
+//
+// lhs and rhs must be in [-modulus..modulus[
+pub(crate) fn signed_overflowing_add_under_modulus(
+    lhs: i64,
+    rhs: i64,
+    modulus: i64,
+) -> (i64, bool) {
+    assert!(modulus > 0);
+    assert!((-modulus..modulus).contains(&lhs));
+
+    // The code below requires rhs and lhs to be in range -modulus..modulus
+    // in scalar tests, rhs may exceed modulus
+    // so we truncate it (as the fhe ops do)
+    let (mut res, mut overflowed) = if (-modulus..modulus).contains(&rhs) {
+        (lhs + rhs, false)
+    } else {
+        // 2*modulus to get all the bits
+        (lhs + (rhs % (2 * modulus)), true)
+    };
+
+    if res < -modulus {
+        // rem_euclid(modulus) would also work
+        res = modulus + (res - -modulus);
+        overflowed = true;
+    } else if res > modulus - 1 {
+        res = -modulus + (res - modulus);
+        overflowed = true;
+    }
+    (res, overflowed)
+}
+
+pub(crate) fn signed_neg_under_modulus(lhs: i64, modulus: i64) -> i64 {
+    assert!(modulus > 0);
+    let mut res = -lhs;
+    if res < -modulus {
+        // rem_euclid(modulus) would also work
+        res = modulus + (res - -modulus);
+    } else if res > modulus - 1 {
+        res = -modulus + (res - modulus);
+    }
+    res
+}
+
+// Subtracts two signed numbers modulo the given modulus
+//
+// This is to 'simulate' i8, i16, ixy using i64 integers
+//
+// lhs and rhs must be in [-modulus..modulus[
+pub(crate) fn signed_sub_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 {
+    signed_overflowing_sub_under_modulus(lhs, rhs, modulus).0
+}
+
+pub(crate) fn signed_overflowing_sub_under_modulus(
+    lhs: i64,
+    rhs: i64,
+    modulus: i64,
+) -> (i64, bool) {
+    // Technically we should be able to call overflowing_add_under_modulus(lhs, -rhs, ...)
+    // but due to -rhs being a 'special case' when rhs == -modulus, we have to
+    // do the impl here
+    assert!(modulus > 0);
+    assert!((-modulus..modulus).contains(&lhs));
+
+    // The code below requires rhs and lhs to be in range -modulus..modulus
+    // in scalar tests, rhs may exceed modulus
+    // so we truncate it (as the fhe ops do)
+    let (mut res, mut overflowed) = if (-modulus..modulus).contains(&rhs) {
+        (lhs - rhs, false)
+    } else {
+        // 2*modulus to get all the bits
+        (lhs - (rhs % (2 * modulus)), true)
+    };
+
+    if res < -modulus {
+        // rem_euclid(modulus) would also work
+        res = modulus + (res - -modulus);
+        overflowed = true;
+    } else if res > modulus - 1 {
+        res = -modulus + (res - modulus);
+        overflowed = true;
+    }
+    (res, overflowed)
+}
+
+pub(crate) fn signed_mul_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 {
+    assert!(modulus > 0);
+    overflowing_mul_under_modulus(lhs, rhs, modulus).0
+}
+
+pub(crate) fn overflowing_mul_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> (i64, bool) {
+    let (mut res, mut overflowed) = lhs.overflowing_mul(rhs);
+    overflowed |= res < -modulus || res >= modulus;
+    res %= modulus * 2;
+    if res < -modulus {
+        // rem_euclid(modulus) would also work
+        res = modulus + (res - -modulus);
+    } else if res > modulus - 1 {
+        res = -modulus + (res - modulus);
+    }
+
+    (res, overflowed)
+}
+
+pub(crate) fn absolute_value_under_modulus(lhs: i64, modulus: i64) -> i64 {
+    if lhs < 0 {
+        signed_neg_under_modulus(lhs, modulus)
+    } else {
+        lhs
+    }
+}
+
+pub(crate) fn signed_left_shift_under_modulus(lhs: i64, rhs: u32, modulus: i64) -> i64 {
+    signed_mul_under_modulus(lhs, 1 << rhs, modulus)
+}
+
+pub(crate) fn signed_right_shift_under_modulus(lhs: i64, rhs: u32, _modulus: i64) -> i64 {
+    lhs >> rhs
+}
+
+pub(crate) fn signed_div_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 {
+    // in signed integers, -modulus can be represented, but +modulus cannot
+    // thus, when dividing: -128 / -1 = 128 the result overflows to -128
+    assert!(modulus > 0);
+    let mut res = lhs / rhs;
+    if res < -modulus {
+        // rem_euclid(modulus) would also work
+        res = modulus + (res - -modulus);
+    } else if res > modulus - 1 {
+        res = -modulus + (res - modulus);
+    }
+    res
+}
+
+pub(crate) fn signed_rem_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 {
+    assert!(modulus > 0);
+    let q = signed_div_under_modulus(lhs, rhs, modulus);
+    let q_times_rhs = signed_mul_under_modulus(q, rhs, modulus);
+    signed_sub_under_modulus(lhs, q_times_rhs, modulus)
+}
+
+pub(crate) fn signed_div_rem_floor_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> (i64, i64) {
+    let mut q = signed_div_under_modulus(lhs, rhs, modulus);
+    let mut r = signed_rem_under_modulus(lhs, rhs, modulus);
+
+    if (r != 0) && ((r < 0) != (rhs < 0)) {
+        q = signed_sub_under_modulus(q, 1, modulus);
+        r = signed_add_under_modulus(r, rhs, modulus);
+    }
+
+    (q, r)
+}
+
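These reference helpers model an N-bit two's complement integer inside an i64, with `modulus` playing the role of 2^(N-1). A minimal sketch of the intended semantics (not part of the diff itself), assuming `modulus = 128` so the model emulates `i8`:

    // Editorial sketch: cross-check the reference model against native i8.
    // Assumes the helpers above are in scope; modulus = 128 emulates i8.
    fn model_matches_i8(lhs: i8, rhs: i8) {
        let (expected, expected_overflow) = lhs.overflowing_add(rhs);
        let (got, got_overflow) =
            signed_overflowing_add_under_modulus(i64::from(lhs), i64::from(rhs), 128);
        assert_eq!(got, i64::from(expected));
        assert_eq!(got_overflow, expected_overflow);
    }
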
+/// helper function to do a rotate left when the type used to store
+/// the value is bigger than the actual intended bit size
+pub(crate) fn rotate_left_helper(value: i64, n: u32, actual_bit_size: u32) -> i64 {
+    // We start with:
+    // [0000000000000|xxxx]
+    // 64            b    0
+    //
+    // rotated will be
+    // [0000000000xx|xx00]
+    // 64           b    0
+    let n = n % actual_bit_size;
+    let mask = 1i64.wrapping_shl(actual_bit_size) - 1;
+    let shifted_mask = mask.wrapping_shl(n) & !mask;
+
+    // The value may be negative and so have its msb
+    // set to one, so use the mask to keep only the part
+    // that interests us
+    let rotated = (value & mask).rotate_left(n);
+
+    let tmp = (rotated & mask) | ((rotated & shifted_mask) >> actual_bit_size);
+    // If the sign bit after rotation is one,
+    // then all bits above it need to be one
+    let new_sign_bit = (tmp >> (actual_bit_size - 1)) & 1;
+    let mut pad = -new_sign_bit;
+    pad <<= actual_bit_size; // only bits above actual_bit_size should be set
+
+    pad | tmp
+}
+
+/// helper function to do a rotate right when the type used to store
+/// the value is bigger than the actual intended bit size
+pub(crate) fn rotate_right_helper(value: i64, n: u32, actual_bit_size: u32) -> i64 {
+    // We start with:
+    // [yyyyyyyyyyyy|xxxx]
+    // 64           b    0
+    // where xs are bits that we are interested in
+    // and ys are either 0 or 1 depending on if value is positive
+    //
+    // mask:          [yyyyyyyyyyyy|mmmm]
+    // shifted_mask:  [mmyyyyyyyyyy|0000]
+    //
+    // rotated will be
+    // [xxyyyyyyyyyy|00xx]
+    // 64           b    0
+    //
+    // To get the 'cycled' bits where they should be,
+    // we get them using a mask then shift
+    let n = n % actual_bit_size;
+    let mask = 1i64.wrapping_shl(actual_bit_size) - 1;
+    // shifted mask only needs the bits that cycled
+    let shifted_mask = mask.rotate_right(n) & !mask;
+
+    // The value may be negative and so have its msb
+    // set to one, so use the mask to keep only the part
+    // that interests us
+    let rotated = (value & mask).rotate_right(n);
+
+    let tmp = (rotated & mask) | ((rotated & shifted_mask) >> (u64::BITS - actual_bit_size));
+    // If the sign bit after rotation is one,
+    // then all bits above it need to be one
+    let new_sign_bit = (tmp >> (actual_bit_size - 1)) & 1;
+    let mut pad = -new_sign_bit;
+    pad <<= actual_bit_size; // only bits above actual_bit_size should be set
+
+    pad | tmp
+}
+
+/// Returns an array filled with random values such that:
+/// - the first half contains values in [0..modulus[
+/// - the second half contains values in [-modulus..0]
+pub(crate) fn random_signed_value_under_modulus<const N: usize>(
+    rng: &mut rand::prelude::ThreadRng,
+    modulus: i64,
+) -> [i64; N] {
+    assert!(modulus > 0);
+
+    let mut values = [0i64; N];
+
+    for value in &mut values[..N / 2] {
+        *value = rng.gen_range(0..modulus);
+    }
+
+    for value in &mut values[N / 2..] {
+        *value = rng.gen_range(-modulus..=0);
+    }
+
+    values
+}
+
+/// Returns an array filled with random values such that:
+/// - the first half contains values in ]0..modulus[
+/// - the second half contains values in [-modulus..0[
+pub(crate) fn random_non_zero_signed_value_under_modulus<const N: usize>(
+    rng: &mut rand::prelude::ThreadRng,
+    modulus: i64,
+) -> [i64; N] {
+    assert!(modulus > 0);
+
+    let mut values = [0i64; N];
+
+    for value in &mut values[..N / 2] {
+        *value = rng.gen_range(1..modulus);
+    }
+
+    for value in &mut values[N / 2..] {
+        *value = rng.gen_range(-modulus..0);
+    }
+
+    values
+}
+
+/// Returns an iterator that yields pairs of i64 values in range `-modulus..modulus`
+/// such that there is at least one pair of (P, P), (P, N), (N, N), (N, P)
+/// where P means value >= 0 and N means value <= 0
+pub(crate) fn create_iterator_of_signed_random_pairs<const N: usize>(
+    rng: &mut rand::prelude::ThreadRng,
+    modulus: i64,
+) -> impl Iterator<Item = (i64, i64)> {
+    assert!(N >= 4, "N must be at least 4 to uphold the guarantee");
+    let mut lhs_values = [0i64; N];
+    let mut rhs_values = [0i64; N];
+
+    lhs_values[0] = rng.gen_range(0..modulus);
+    rhs_values[0] = rng.gen_range(0..modulus);
+
+    lhs_values[1] = rng.gen_range(0..modulus);
+    rhs_values[1] = rng.gen_range(-modulus..=0);
+
+    lhs_values[2] = rng.gen_range(-modulus..=0);
+    rhs_values[2] = rng.gen_range(-modulus..=0);
+
+    lhs_values[3] = rng.gen_range(-modulus..=0);
+    rhs_values[3] = rng.gen_range(0..modulus);
+
+    for i in 4..N {
+        lhs_values[i] = rng.gen_range(-modulus..modulus);
+        rhs_values[i] = rng.gen_range(-modulus..modulus);
+    }
+
+    izip!(lhs_values, rhs_values)
+}
+
+pub(crate) fn random_non_zero_value(rng: &mut ThreadRng, modulus: i64) -> i64 {
+    loop {
+        let value = rng.gen::<i64>() % modulus;
+        if value != 0 {
+            break value;
+        }
+    }
+}
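The rotate helpers above reproduce an `actual_bit_size`-bit rotation and sign-extend the result back into the i64 carrier. A minimal sketch of that contract, assuming 8 bits so it can be checked against `i8::rotate_left`:

    // Editorial sketch: with actual_bit_size = 8, rotate_left_helper should
    // match i8::rotate_left on the low byte, sign-extended back to i64.
    fn rotate_model_matches_i8(value: i8, n: u32) {
        let expected = i64::from(value.rotate_left(n % 8));
        let got = rotate_left_helper(i64::from(value), n, 8);
        assert_eq!(got, expected);
    }
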
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs
index 0221645e2e..f667c11430 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs
@@ -1,11 +1,9 @@
 use crate::integer::keycache::KEY_CACHE;
 use crate::integer::server_key::radix_parallel::sub::SignedOperation;
-use crate::integer::server_key::radix_parallel::tests_cases_signed::{
-    random_non_zero_value, signed_add_under_modulus, signed_default_add_test,
-    signed_overflowing_add_under_modulus, signed_smart_add_test, signed_unchecked_add_test,
-};
+use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
 use crate::integer::server_key::radix_parallel::tests_signed::{
-    NB_CTXT, NB_TESTS, NB_TESTS_SMALLER,
+    create_iterator_of_signed_random_pairs, random_non_zero_value, signed_add_under_modulus,
+    signed_overflowing_add_under_modulus, NB_CTXT, NB_TESTS, NB_TESTS_SMALLER, NB_TESTS_UNCHECKED,
 };
 use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
 use crate::integer::tests::create_parametrized_test;
@@ -17,6 +15,7 @@ use crate::shortint::ciphertext::NoiseLevel;
 use crate::shortint::parameters::coverage_parameters::*;
 use crate::shortint::parameters::*;
 use rand::Rng;
+use std::sync::Arc;
 
 create_parametrized_test!(integer_signed_unchecked_add);
 create_parametrized_test!(integer_signed_unchecked_overflowing_add);
@@ -332,3 +331,146 @@ where
     let executor = CpuFunctionExecutor::new(&ServerKey::smart_add_parallelized);
     signed_smart_add_test(param, executor);
 }
+
+pub(crate) fn signed_unchecked_add_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks);
+
+    // check some overflow behaviour
+    let overflowing_values = [
+        (-modulus, -1, modulus - 1),
+        (modulus - 1, 1, -modulus),
+        (-modulus, -2, modulus - 2),
+        (modulus - 2, 2, -modulus),
+    ];
+    for (clear_0, clear_1, expected_clear) in overflowing_values {
+        let ctxt_0 = cks.encrypt_signed(clear_0);
+        let ctxt_1 = cks.encrypt_signed(clear_1);
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus);
+        assert_eq!(clear_res, dec_res);
+        assert_eq!(clear_res, expected_clear);
+    }
+
+    for (clear_0, clear_1) in
+        create_iterator_of_signed_random_pairs::<{ NB_TESTS_UNCHECKED }>(&mut rng, modulus)
+    {
+        let ctxt_0 = cks.encrypt_signed(clear_0);
+        let ctxt_1 = cks.encrypt_signed(clear_1);
+
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus);
+        assert_eq!(clear_res, dec_res);
+    }
+}
+
+pub(crate) fn signed_default_add_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks);
+
+    let mut clear;
+
+    for _ in 0..NB_TESTS_SMALLER {
+        let clear_0 = rng.gen::<i64>() % modulus;
+        let clear_1 = rng.gen::<i64>() % modulus;
+
+        let ctxt_0 = cks.encrypt_signed(clear_0);
+        let ctxt_1 = cks.encrypt_signed(clear_1);
+
+        let mut ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        let tmp_ct = executor.execute((&ctxt_0, &ctxt_1));
+        assert!(ct_res.block_carries_are_empty());
+        assert_eq!(ct_res, tmp_ct);
+
+        clear = signed_add_under_modulus(clear_0, clear_1, modulus);
+
+        // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1);
+        // add multiple times to raise the degree
+        for _ in 0..NB_TESTS_SMALLER {
+            ct_res = executor.execute((&ct_res, &ctxt_0));
+            assert!(ct_res.block_carries_are_empty());
+            clear = signed_add_under_modulus(clear, clear_0, modulus);
+
+            let dec_res: i64 = cks.decrypt_signed(&ct_res);
+
+            // println!("clear = {}, dec_res = {}", clear, dec_res);
+            assert_eq!(clear, dec_res);
+        }
+    }
+}
+
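Each generic test body is bound to a concrete server-key operation through the `FunctionExecutor` abstraction, mirroring the wrappers earlier in this file; a sketch of such a binding (the test name here is illustrative, `add_parallelized` is the real ServerKey method):

    // Sketch of binding the generic body to a concrete CPU operation.
    fn integer_signed_default_add_example(param: impl Into<PBSParameters>) {
        let executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized);
        signed_default_add_test(param, executor);
    }
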
+pub(crate) fn signed_smart_add_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a mut SignedRadixCiphertext, &'a mut SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks);
+
+    let mut clear;
+
+    for _ in 0..NB_TESTS_SMALLER {
+        let clear_0 = rng.gen_range(-modulus..modulus);
+        let clear_1 = rng.gen_range(-modulus..modulus);
+
+        let mut ctxt_0 = cks.encrypt_signed(clear_0);
+        let mut ctxt_1 = cks.encrypt_signed(clear_1);
+
+        let mut ct_res = executor.execute((&mut ctxt_0, &mut ctxt_1));
+        clear = signed_add_under_modulus(clear_0, clear_1, modulus);
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        assert_eq!(clear, dec_res);
+
+        // add multiple times to raise the degree
+        for _ in 0..NB_TESTS_SMALLER {
+            ct_res = executor.execute((&mut ct_res, &mut ctxt_0));
+            clear = signed_add_under_modulus(clear, clear_0, modulus);
+
+            let dec_res: i64 = cks.decrypt_signed(&ct_res);
+
+            // println!("clear = {}, dec_res = {}", clear, dec_res);
+            assert_eq!(clear, dec_res);
+        }
+    }
+}
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_bitwise_op.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_bitwise_op.rs
index 2761deb731..96f8043328 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_bitwise_op.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_bitwise_op.rs
@@ -1,14 +1,17 @@
-use crate::integer::server_key::radix_parallel::tests_cases_signed::{
-    signed_default_bitand_test, signed_default_bitnot_test, signed_default_bitor_test,
-    signed_default_bitxor_test, signed_unchecked_bitand_test, signed_unchecked_bitor_test,
-    signed_unchecked_bitxor_test,
+use crate::integer::keycache::KEY_CACHE;
+use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
+use crate::integer::server_key::radix_parallel::tests_signed::{
+    create_iterator_of_signed_random_pairs, random_non_zero_value, signed_add_under_modulus,
+    NB_CTXT, NB_TESTS, NB_TESTS_UNCHECKED,
 };
 use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
 use crate::integer::tests::create_parametrized_test;
-use crate::integer::ServerKey;
+use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext};
 #[cfg(tarpaulin)]
 use crate::shortint::parameters::coverage_parameters::*;
 use crate::shortint::parameters::*;
+use rand::Rng;
+use std::sync::Arc;
 
 create_parametrized_test!(integer_signed_unchecked_bitand);
 create_parametrized_test!(integer_signed_unchecked_bitor);
@@ -73,3 +76,291 @@ where
     let executor = CpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
     signed_default_bitxor_test(param, executor);
 }
+pub(crate) fn signed_unchecked_bitand_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks);
+
+    for (clear_0, clear_1) in
+        create_iterator_of_signed_random_pairs::<{ NB_TESTS_UNCHECKED }>(&mut rng, modulus)
+    {
+        let ctxt_0 = cks.encrypt_signed(clear_0);
+        let ctxt_1 = cks.encrypt_signed(clear_1);
+
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        let clear_res = clear_0 & clear_1;
+        assert_eq!(clear_res, dec_res);
+    }
+}
+
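The pair iterator guarantees that all four sign combinations show up, so these bitwise tests exercise negative operands without extra plumbing; a minimal usage sketch, assuming `modulus = 128`:

    // Minimal sketch: the first four pairs cover (+,+), (+,-), (-,-), (-,+).
    fn pairs_cover_sign_combinations() {
        let mut rng = rand::thread_rng();
        for (lhs, rhs) in create_iterator_of_signed_random_pairs::<4>(&mut rng, 128) {
            assert!((-128..128).contains(&lhs));
            assert!((-128..128).contains(&rhs));
        }
    }
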
+pub(crate) fn signed_unchecked_bitor_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks);
+
+    for (clear_0, clear_1) in
+        create_iterator_of_signed_random_pairs::<{ NB_TESTS_UNCHECKED }>(&mut rng, modulus)
+    {
+        let ctxt_0 = cks.encrypt_signed(clear_0);
+        let ctxt_1 = cks.encrypt_signed(clear_1);
+
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        let clear_res = clear_0 | clear_1;
+        assert_eq!(clear_res, dec_res);
+    }
+}
+
+pub(crate) fn signed_unchecked_bitxor_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks);
+
+    for (clear_0, clear_1) in
+        create_iterator_of_signed_random_pairs::<{ NB_TESTS_UNCHECKED }>(&mut rng, modulus)
+    {
+        let ctxt_0 = cks.encrypt_signed(clear_0);
+        let ctxt_1 = cks.encrypt_signed(clear_1);
+
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        let clear_res = clear_0 ^ clear_1;
+        assert_eq!(clear_res, dec_res);
+    }
+}
+
+pub(crate) fn signed_default_bitnot_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>,
+{
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    sks.set_deterministic_pbs_execution(true);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks);
+
+    for _ in 0..NB_TESTS {
+        let clear_0 = rng.gen::<i64>() % modulus;
+
+        let ctxt_0 = cks.encrypt_signed(clear_0);
+
+        let ct_res = executor.execute(&ctxt_0);
+        let ct_res2 = executor.execute(&ctxt_0);
+        assert_eq!(ct_res, ct_res2);
+
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        let clear_res = !clear_0;
+        assert_eq!(clear_res, dec_res);
+    }
+}
+
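The `default_*` tests below deliberately dirty the operands' carries with `unchecked_scalar_add_assign` before re-running the operation, to check that the default implementation propagates carries itself; schematically (`sks` and `ct` stand for a server key and a signed radix ciphertext from such a test):

    // Sketch of the dirty-carry pattern used by the default tests.
    fn dirty_carry_pattern(sks: &ServerKey, mut ct: SignedRadixCiphertext) {
        sks.unchecked_scalar_add_assign(&mut ct, 1i64); // leaves non-empty carries
        assert!(!ct.block_carries_are_empty());
        let res = sks.bitand_parallelized(&ct, &ct); // default op must clean them
        assert!(res.block_carries_are_empty());
    }
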
+pub(crate) fn signed_default_bitand_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    sks.set_deterministic_pbs_execution(true);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks.clone());
+
+    for _ in 0..NB_TESTS {
+        let mut clear_0 = rng.gen::<i64>() % modulus;
+        let mut clear_1 = rng.gen::<i64>() % modulus;
+
+        let mut ctxt_0 = cks.encrypt_signed(clear_0);
+        let mut ctxt_1 = cks.encrypt_signed(clear_1);
+
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        let ct_res2 = executor.execute((&ctxt_0, &ctxt_1));
+        assert_eq!(ct_res, ct_res2);
+
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        let clear_res = clear_0 & clear_1;
+        assert_eq!(clear_res, dec_res);
+
+        let clear_2 = random_non_zero_value(&mut rng, modulus);
+        let clear_3 = random_non_zero_value(&mut rng, modulus);
+
+        sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2);
+        sks.unchecked_scalar_add_assign(&mut ctxt_1, clear_3);
+
+        assert!(!ctxt_0.block_carries_are_empty());
+        assert!(!ctxt_1.block_carries_are_empty());
+
+        clear_0 = signed_add_under_modulus(clear_0, clear_2, modulus);
+        clear_1 = signed_add_under_modulus(clear_1, clear_3, modulus);
+
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        assert!(ct_res.block_carries_are_empty());
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+
+        let expected_result = clear_0 & clear_1;
+        assert_eq!(dec_res, expected_result);
+    }
+}
+
+pub(crate) fn signed_default_bitor_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    sks.set_deterministic_pbs_execution(true);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
+
+    executor.setup(&cks, sks.clone());
+
+    for _ in 0..NB_TESTS {
+        let mut clear_0 = rng.gen::<i64>() % modulus;
+        let mut clear_1 = rng.gen::<i64>() % modulus;
+
+        let mut ctxt_0 = cks.encrypt_signed(clear_0);
+        let mut ctxt_1 = cks.encrypt_signed(clear_1);
+
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        let ct_res2 = executor.execute((&ctxt_0, &ctxt_1));
+        assert_eq!(ct_res, ct_res2);
+
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+        let clear_res = clear_0 | clear_1;
+        assert_eq!(clear_res, dec_res);
+
+        let clear_2 = random_non_zero_value(&mut rng, modulus);
+        let clear_3 = random_non_zero_value(&mut rng, modulus);
+
+        sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2);
+        sks.unchecked_scalar_add_assign(&mut ctxt_1, clear_3);
+
+        assert!(!ctxt_0.block_carries_are_empty());
+        assert!(!ctxt_1.block_carries_are_empty());
+
+        clear_0 = signed_add_under_modulus(clear_0, clear_2, modulus);
+        clear_1 = signed_add_under_modulus(clear_1, clear_3, modulus);
+
+        let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+        assert!(ct_res.block_carries_are_empty());
+        let dec_res: i64 = cks.decrypt_signed(&ct_res);
+
+        let expected_result = clear_0 | clear_1;
+        assert_eq!(dec_res, expected_result);
+    }
+}
+
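Comparing decrypted results against plain `i64` bitwise operators is sound because bitwise operations commute with sign extension; a minimal sketch of that fact, assuming an 8-bit emulation:

    // Minimal sketch: XOR then sign-extend equals sign-extend then XOR,
    // which is what lets the tests use plain i64 `^` as the reference.
    fn xor_commutes_with_sign_extension(a: i8, b: i8) {
        let narrow_then_extend = i64::from(a ^ b);
        let extend_then_xor = i64::from(a) ^ i64::from(b);
        assert_eq!(narrow_then_extend, extend_then_xor);
    }
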
let dec_res: i64 = cks.decrypt_signed(&ct_res); + + let expected_result = clear_0 ^ clear_1; + assert_eq!(dec_res, expected_result); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_mul.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_mul.rs index f71d7fd8b5..99c0aa8f26 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_mul.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_mul.rs @@ -1,9 +1,10 @@ use crate::integer::keycache::KEY_CACHE; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - overflowing_mul_under_modulus, random_non_zero_value, signed_add_under_modulus, - signed_default_mul_test, signed_unchecked_mul_test, +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + create_iterator_of_signed_random_pairs, overflowing_mul_under_modulus, random_non_zero_value, + signed_add_under_modulus, signed_mul_under_modulus, NB_CTXT, NB_TESTS_SMALLER, + NB_TESTS_UNCHECKED, }; -use crate::integer::server_key::radix_parallel::tests_signed::{NB_CTXT, NB_TESTS_SMALLER}; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::create_parametrized_test; use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext}; @@ -12,6 +13,7 @@ use crate::shortint::ciphertext::NoiseLevel; use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::*; use rand::Rng; +use std::sync::Arc; create_parametrized_test!(integer_signed_unchecked_mul); create_parametrized_test!(integer_signed_default_mul); @@ -179,3 +181,85 @@ fn integer_signed_default_overflowing_mul(param: impl Into) { assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); } } + +pub(crate) fn signed_unchecked_mul_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), + SignedRadixCiphertext, + >, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + for (clear_0, clear_1) in + create_iterator_of_signed_random_pairs::<{ NB_TESTS_UNCHECKED }>(&mut rng, modulus) + { + let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(clear_1); + + let ct_res = executor.execute((&ctxt_0, &ctxt_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_mul_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + } +} + +pub(crate) fn signed_default_mul_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), + SignedRadixCiphertext, + >, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + let mut clear; + + for _ in 0..NB_TESTS_SMALLER { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + 
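+        // Inputs are drawn from (-modulus, modulus); e.g. with
+        // PARAM_MESSAGE_2_CARRY_2 (message_modulus = 4) and NB_CTXT = 4,
+        // the plaintext space has 4^4 = 256 values, so modulus = 128 and
+        // the clears fall in [-127, 127].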
let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(clear_1); + + let mut ct_res = executor.execute((&ctxt_0, &ctxt_1)); + let tmp_ct = executor.execute((&ctxt_0, &ctxt_1)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp_ct); + + clear = signed_mul_under_modulus(clear_0, clear_1, modulus); + + // mul multiple times to raise the degree + for _ in 0..NB_TESTS_SMALLER { + ct_res = executor.execute((&ct_res, &ctxt_0)); + assert!(ct_res.block_carries_are_empty()); + clear = signed_mul_under_modulus(clear, clear_0, modulus); + + let dec_res: i64 = cks.decrypt_signed(&ct_res); + + // println!("clear = {}, dec_res = {}", clear, dec_res); + assert_eq!(clear, dec_res); + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_neg.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_neg.rs index c3608ad4f1..bc2192dbad 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_neg.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_neg.rs @@ -1,12 +1,17 @@ -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - signed_default_neg_test, signed_smart_neg_test, signed_unchecked_neg_test, +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + create_iterator_of_signed_random_pairs, signed_neg_under_modulus, NB_CTXT, NB_TESTS_SMALLER, + NB_TESTS_UNCHECKED, }; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::create_parametrized_test; -use crate::integer::ServerKey; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext}; #[cfg(tarpaulin)] use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; create_parametrized_test!(integer_signed_unchecked_neg); create_parametrized_test!(integer_signed_smart_neg); @@ -35,3 +40,141 @@ where let executor = CpuFunctionExecutor::new(&ServerKey::neg_parallelized); signed_default_neg_test(param, executor); } + +pub(crate) fn signed_unchecked_neg_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let ctxt_zero = sks.create_trivial_radix(0i64, NB_CTXT); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + // -modulus is a special case, its negation cannot be + // represented. 
rust by default returns -modulus + // (which is what two complement result in) + { + let clear = -modulus; + let ctxt = cks.encrypt_signed(clear); + + let ct_res = executor.execute(&ctxt); + + let dec: i64 = cks.decrypt_signed(&ct_res); + let clear_result = signed_neg_under_modulus(clear, modulus); + + assert_eq!(clear_result, dec); + assert_eq!(clear_result, -modulus); + } + + for (clear_0, _) in + create_iterator_of_signed_random_pairs::<{ NB_TESTS_UNCHECKED }>(&mut rng, modulus) + { + let ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute(&ctxt_0); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_neg_under_modulus(clear_0, modulus); + assert_eq!(clear_res, dec_res); + } + + // negation of trivial 0 + { + let ct_res = executor.execute(&ctxt_zero); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + assert_eq!(0, dec_res); + } +} + +pub(crate) fn signed_smart_neg_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a mut SignedRadixCiphertext, SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + for _ in 0..NB_TESTS_SMALLER { + let clear = rng.gen::() % modulus; + + let mut ctxt = cks.encrypt_signed(clear); + + let mut ct_res = executor.execute(&mut ctxt); + let mut clear_res = signed_neg_under_modulus(clear, modulus); + let dec: i64 = cks.decrypt_signed(&ct_res); + assert_eq!(clear_res, dec); + + for _ in 0..NB_TESTS_SMALLER { + ct_res = executor.execute(&mut ct_res); + clear_res = signed_neg_under_modulus(clear_res, modulus); + + let dec: i64 = cks.decrypt_signed(&ct_res); + println!("clear_res: {clear_res}, dec : {dec}"); + assert_eq!(clear_res, dec); + } + } +} + +pub(crate) fn signed_default_neg_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + sks.set_deterministic_pbs_execution(true); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + // -modulus is a special case, its negation cannot be + // represented. 
rust by default returns -modulus + // (which is what two complement result in) + { + let clear = -modulus; + let ctxt = cks.encrypt_signed(clear); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let dec: i64 = cks.decrypt_signed(&ct_res); + let clear_result = signed_neg_under_modulus(clear, modulus); + + assert_eq!(clear_result, dec); + } + + for _ in 0..NB_TESTS_SMALLER { + let clear = rng.gen::() % modulus; + + let ctxt = cks.encrypt_signed(clear); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let dec: i64 = cks.decrypt_signed(&ct_res); + let clear_result = signed_neg_under_modulus(clear, modulus); + + assert_eq!(clear_result, dec); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_rotate.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_rotate.rs index 323df00d4f..1cdd1ac660 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_rotate.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_rotate.rs @@ -1,9 +1,7 @@ use crate::integer::keycache::KEY_CACHE; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - random_non_zero_value, rotate_left_helper, rotate_right_helper, signed_add_under_modulus, -}; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, rotate_left_helper, rotate_right_helper, signed_add_under_modulus, NB_CTXT, NB_TESTS, NB_TESTS_SMALLER, }; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs index 44095dbcea..d5e1ab215c 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs @@ -1,13 +1,20 @@ -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - signed_default_overflowing_scalar_add_test, signed_default_scalar_add_test, - signed_unchecked_scalar_add_test, +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, signed_add_under_modulus, signed_overflowing_add_under_modulus, NB_CTXT, + NB_TESTS, NB_TESTS_SMALLER, }; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::create_parametrized_test; -use crate::integer::ServerKey; +use crate::integer::{ + BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext, +}; +use crate::shortint::ciphertext::NoiseLevel; #[cfg(tarpaulin)] use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; create_parametrized_test!(integer_signed_unchecked_scalar_add); create_parametrized_test!(integer_signed_default_scalar_add); @@ -36,3 +43,268 @@ where let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_add_parallelized); signed_default_overflowing_scalar_add_test(param, executor); } +pub(crate) fn signed_unchecked_scalar_add_test(param: P, mut 
executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from(( + cks, + crate::integer::server_key::radix_parallel::tests_unsigned::NB_CTXT, + )); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + // check some overflow behaviour + let overflowing_values = [ + (-modulus, -1, modulus - 1), + (modulus - 1, 1, -modulus), + (-modulus, -2, modulus - 2), + (modulus - 2, 2, -modulus), + ]; + for (clear_0, clear_1, expected_clear) in overflowing_values { + let ctxt_0 = cks.encrypt_signed(clear_0); + let ct_res = executor.execute((&ctxt_0, clear_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + assert_eq!(clear_res, expected_clear); + } + + for _ in 0..NB_TESTS { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + } +} + +pub(crate) fn signed_default_scalar_add_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + sks.set_deterministic_pbs_execution(true); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + let mut clear; + + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TESTS_SMALLER { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let mut ct_res = executor.execute((&ctxt_0, clear_1)); + assert!(ct_res.block_carries_are_empty()); + + clear = signed_add_under_modulus(clear_0, clear_1, modulus); + + // add multiple times to raise the degree + for _ in 0..NB_TESTS_SMALLER { + let tmp = executor.execute((&ct_res, clear_1)); + ct_res = executor.execute((&ct_res, clear_1)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + clear = signed_add_under_modulus(clear, clear_1, modulus); + + let dec_res: i64 = cks.decrypt_signed(&ct_res); + assert_eq!(clear, dec_res); + } + } +} + +pub(crate) fn signed_default_overflowing_scalar_add_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, i64), + (SignedRadixCiphertext, BooleanBlock), + >, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks.clone()); + + let hardcoded_values = [ + (-modulus, -1), + (modulus - 1, 1), + (-1, -modulus), + (1, modulus - 1), + ]; + for (clear_0, clear_1) in hardcoded_values { + let ctxt_0 = 
cks.encrypt_signed(clear_0); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + } + + for _ in 0..NB_TESTS_SMALLER { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); + let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp_ct, "Failed determinism check"); + assert_eq!(tmp_o, result_overflowed, "Failed determinism check"); + + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + + for _ in 0..NB_TESTS_SMALLER { + // Add non zero scalar to have non clean ciphertexts + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear_rhs = random_non_zero_value(&mut rng, modulus); + + let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2); + let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus); + let d0: i64 = cks.decrypt_signed(&ctxt_0); + assert_eq!(d0, clear_lhs, "Failed sanity decryption check"); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs)); + assert!(ct_res.block_carries_are_empty()); + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + } + } + + 
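+    // Trivial radix ciphertexts carry no encryption noise, which is why the
+    // checks below expect the overflow flag to be at NoiseLevel::ZERO rather
+    // than the NoiseLevel::NOMINAL asserted for encrypted inputs above.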
// Test with trivial inputs + for _ in 0..4 { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT); + + let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); + + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); + let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(encrypted_overflow.0.degree.get(), 1); + assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); + } + + // Test with scalar that is bigger than ciphertext modulus + for _ in 0..2 { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen_range(modulus..=i64::MAX); + + let a = cks.encrypt_signed(clear_0); + + let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); + + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); + let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert!(decrypted_overflowed); // Actually we know its an overflow case + assert_eq!(encrypted_overflow.0.degree.get(), 1); + assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_bitwise_op.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_bitwise_op.rs index 25d178966c..95e8a421a8 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_bitwise_op.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_bitwise_op.rs @@ -1,13 +1,16 @@ -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - signed_default_scalar_bitand_test, signed_default_scalar_bitor_test, - signed_default_scalar_bitxor_test, +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, signed_add_under_modulus, NB_CTXT, NB_TESTS, }; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::create_parametrized_test; -use crate::integer::ServerKey; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext}; #[cfg(tarpaulin)] use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; 
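+// Each create_parametrized_test! invocation below generates one #[test] per
+// parameter set (coverage parameters under tarpaulin, the usual
+// PARAM_MESSAGE_X_CARRY_Y sets otherwise).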
create_parametrized_test!(integer_signed_default_scalar_bitand); create_parametrized_test!(integer_signed_default_scalar_bitor); @@ -36,3 +39,133 @@ where let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized); signed_default_scalar_bitxor_test(param, executor); } +pub(crate) fn signed_default_scalar_bitand_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + sks.set_deterministic_pbs_execution(true); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks.clone()); + + for _ in 0..NB_TESTS { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let mut ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + let ct_res2 = executor.execute((&ctxt_0, clear_1)); + assert_eq!(ct_res, ct_res2); + + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = clear_0 & clear_1; + assert_eq!(clear_res, dec_res); + + let clear_2 = random_non_zero_value(&mut rng, modulus); + sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); + assert!(!ctxt_0.block_carries_are_empty()); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + assert!(ct_res.block_carries_are_empty()); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + + let expected_result = signed_add_under_modulus(clear_0, clear_2, modulus) & clear_1; + assert_eq!(dec_res, expected_result); + } +} + +pub(crate) fn signed_default_scalar_bitor_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + sks.set_deterministic_pbs_execution(true); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks.clone()); + + for _ in 0..NB_TESTS { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let mut ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + let ct_res2 = executor.execute((&ctxt_0, clear_1)); + assert_eq!(ct_res, ct_res2); + + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = clear_0 | clear_1; + assert_eq!(clear_res, dec_res); + + let clear_2 = random_non_zero_value(&mut rng, modulus); + + sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); + assert!(!ctxt_0.block_carries_are_empty()); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + assert!(ct_res.block_carries_are_empty()); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + + let expected_result = signed_add_under_modulus(clear_0, clear_2, modulus) | clear_1; + assert_eq!(dec_res, expected_result); + } +} + +pub(crate) fn signed_default_scalar_bitxor_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + sks.set_deterministic_pbs_execution(true); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = 
rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks.clone()); + + for _ in 0..NB_TESTS { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let mut ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + let ct_res2 = executor.execute((&ctxt_0, clear_1)); + assert_eq!(ct_res, ct_res2); + + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = clear_0 ^ clear_1; + assert_eq!(clear_res, dec_res); + + let clear_2 = random_non_zero_value(&mut rng, modulus); + + sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_2); + assert!(!ctxt_0.block_carries_are_empty()); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + assert!(ct_res.block_carries_are_empty()); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + + let expected_result = signed_add_under_modulus(clear_0, clear_2, modulus) ^ clear_1; + assert_eq!(dec_res, expected_result); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_mul.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_mul.rs new file mode 100644 index 0000000000..7de2bf36f1 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_mul.rs @@ -0,0 +1,51 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + signed_mul_under_modulus, NB_CTXT, NB_TESTS, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::tests::create_parametrized_test; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext}; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; + +create_parametrized_test!(integer_signed_unchecked_scalar_mul); + +fn integer_signed_unchecked_scalar_mul
<P>
(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_mul_parallelized); + signed_unchecked_scalar_mul_test(param, executor); +} + +pub(crate) fn signed_unchecked_scalar_mul_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + for _ in 0..NB_TESTS { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_mul_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_shift.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_shift.rs index 8f4f01589e..b2ab24c040 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_shift.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_shift.rs @@ -1,13 +1,17 @@ -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - signed_default_scalar_left_shift_test, signed_default_scalar_right_shift_test, - signed_unchecked_scalar_left_shift_test, signed_unchecked_scalar_right_shift_test, +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, signed_add_under_modulus, signed_left_shift_under_modulus, + signed_right_shift_under_modulus, NB_CTXT, NB_TESTS, NB_TESTS_SMALLER, }; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::create_parametrized_test; -use crate::integer::ServerKey; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext}; #[cfg(tarpaulin)] use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; create_parametrized_test!(integer_signed_unchecked_scalar_left_shift); create_parametrized_test!(integer_signed_default_scalar_left_shift); @@ -45,3 +49,220 @@ where let executor = CpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized); signed_default_scalar_right_shift_test(param, executor); } +pub(crate) fn signed_unchecked_scalar_left_shift_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + assert!(modulus > 0); + assert!((modulus as u64).is_power_of_two()); + let nb_bits = modulus.ilog2() + 1; // We are using signed numbers + + for _ in 0..NB_TESTS { + let clear = rng.gen::() % modulus; + let clear_shift = rng.gen::(); + + let ct = cks.encrypt_signed(clear); + + // case when 0 <= shift < 
nb_bits + { + let clear_shift = clear_shift % nb_bits; + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let expected = signed_left_shift_under_modulus(clear, clear_shift, modulus); + assert_eq!(expected, dec_res); + } + + // case when shift >= nb_bits + { + let clear_shift = clear_shift.saturating_add(nb_bits); + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let expected = signed_left_shift_under_modulus(clear, clear_shift % nb_bits, modulus); + assert_eq!(expected, dec_res); + } + } +} + +pub(crate) fn signed_unchecked_scalar_right_shift_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + assert!(modulus > 0); + assert!((modulus as u64).is_power_of_two()); + let nb_bits = modulus.ilog2() + 1; // We are using signed numbers + + for _ in 0..NB_TESTS { + let clear = rng.gen::() % modulus; + let clear_shift = rng.gen::(); + + let ct = cks.encrypt_signed(clear); + + // case when 0 <= shift < nb_bits + { + let clear_shift = clear_shift % nb_bits; + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let expected = signed_right_shift_under_modulus(clear, clear_shift, modulus); + assert_eq!(expected, dec_res); + } + + // case when shift >= nb_bits + { + let clear_shift = clear_shift.saturating_add(nb_bits); + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let expected = signed_right_shift_under_modulus(clear, clear_shift % nb_bits, modulus); + assert_eq!(expected, dec_res); + } + } +} + +pub(crate) fn signed_default_scalar_left_shift_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks.clone()); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + assert!(modulus > 0); + assert!((modulus as u64).is_power_of_two()); + let nb_bits = modulus.ilog2() + 1; // We are using signed numbers + + for _ in 0..NB_TESTS_SMALLER { + let mut clear = rng.gen::() % modulus; + + let offset = random_non_zero_value(&mut rng, modulus); + + let mut ct = cks.encrypt_signed(clear); + sks.unchecked_scalar_add_assign(&mut ct, offset); + clear = signed_add_under_modulus(clear, offset, modulus); + + // case when 0 <= shift < nb_bits + { + let clear_shift = rng.gen::() % nb_bits; + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_left_shift_under_modulus(clear, clear_shift, modulus); + assert_eq!( + clear_res, dec_res, + "Invalid left shift result, for '{clear} << {clear_shift}', \ + expected: {clear_res}, got: {dec_res}" + ); + + let ct_res2 = executor.execute((&ct, clear_shift as i64)); + assert_eq!(ct_res, 
ct_res2, "Failed determinism check"); + } + + // case when shift >= nb_bits + { + let clear_shift = rng.gen_range(nb_bits..=u32::MAX); + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + // We mimic wrapping_shl manually as we use a bigger type + // than the nb_bits we actually simulate in this test + let clear_res = signed_left_shift_under_modulus(clear, clear_shift % nb_bits, modulus); + assert_eq!( + clear_res, dec_res, + "Invalid left shift result, for '{clear} << {clear_shift}', \ + expected: {clear_res}, got: {dec_res}" + ); + + let ct_res2 = executor.execute((&ct, clear_shift as i64)); + assert_eq!(ct_res, ct_res2, "Failed determinism check"); + } + } +} + +pub(crate) fn signed_default_scalar_right_shift_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks.clone()); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + assert!(modulus > 0); + assert!((modulus as u64).is_power_of_two()); + let nb_bits = modulus.ilog2() + 1; // We are using signed numbers + + for _ in 0..NB_TESTS_SMALLER { + let mut clear = rng.gen::() % modulus; + + let offset = random_non_zero_value(&mut rng, modulus); + + let mut ct = cks.encrypt_signed(clear); + sks.unchecked_scalar_add_assign(&mut ct, offset); + clear = signed_add_under_modulus(clear, offset, modulus); + + // case when 0 <= shift < nb_bits + { + let clear_shift = rng.gen::() % nb_bits; + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_right_shift_under_modulus(clear, clear_shift, modulus); + assert_eq!( + clear_res, dec_res, + "Invalid right shift result, for '{clear} >> {clear_shift}', \ + expected: {clear_res}, got: {dec_res}" + ); + + let ct_res2 = executor.execute((&ct, clear_shift as i64)); + assert_eq!(ct_res, ct_res2, "Failed determinism check"); + } + + // case when shift >= nb_bits + { + let clear_shift = rng.gen_range(nb_bits..=u32::MAX); + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + // We mimic wrapping_shl manually as we use a bigger type + // than the nb_bits we actually simulate in this test + let clear_res = signed_right_shift_under_modulus(clear, clear_shift % nb_bits, modulus); + assert_eq!( + clear_res, dec_res, + "Invalid right shift result, for '{clear} >> {clear_shift}', \ + expected: {clear_res}, got: {dec_res}" + ); + + let ct_res2 = executor.execute((&ct, clear_shift as i64)); + assert_eq!(ct_res, ct_res2, "Failed determinism check"); + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs index 649f3b3de6..8a98dd1f5f 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs @@ -1,12 +1,21 @@ -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - signed_default_overflowing_scalar_sub_test, signed_unchecked_scalar_sub_test, +use crate::integer::keycache::KEY_CACHE; +use 
crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, signed_overflowing_add_under_modulus, + signed_overflowing_sub_under_modulus, signed_sub_under_modulus, NB_CTXT, NB_TESTS, + NB_TESTS_SMALLER, }; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::create_parametrized_test; -use crate::integer::ServerKey; +use crate::integer::{ + BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext, +}; +use crate::shortint::ciphertext::NoiseLevel; #[cfg(tarpaulin)] use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; create_parametrized_test!(integer_signed_unchecked_scalar_sub); create_parametrized_test!(integer_signed_default_overflowing_scalar_sub); @@ -26,3 +35,220 @@ where let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized); signed_default_overflowing_scalar_sub_test(param, executor); } +pub(crate) fn signed_unchecked_scalar_sub_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + // check some overflow behaviour + let overflowing_values = [ + (-modulus, 1, modulus - 1), + (modulus - 1, -1, -modulus), + (-modulus, 2, modulus - 2), + (modulus - 2, -2, -modulus), + ]; + for (clear_0, clear_1, expected_clear) in overflowing_values { + let ctxt_0 = cks.encrypt_signed(clear_0); + let ct_res = executor.execute((&ctxt_0, clear_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + assert_eq!(clear_res, expected_clear); + } + + for _ in 0..NB_TESTS { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + } +} + +pub(crate) fn signed_default_overflowing_scalar_sub_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, i64), + (SignedRadixCiphertext, BooleanBlock), + >, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks.clone()); + + let hardcoded_values = [ + (-modulus, 1), + (modulus - 1, -1), + (1, -modulus), + (-1, modulus - 1), + ]; + for (clear_0, clear_1) in hardcoded_values { + let ctxt_0 = cks.encrypt_signed(clear_0); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); + let (expected_result, expected_overflowed) = + signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus); + + let 
decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + } + + for _ in 0..NB_TESTS_SMALLER { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); + let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp_ct, "Failed determinism check"); + assert_eq!(tmp_o, result_overflowed, "Failed determinism check"); + + let (expected_result, expected_overflowed) = + signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + + for _ in 0..NB_TESTS_SMALLER { + // Add non zero scalar to have non clean ciphertexts + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear_rhs = random_non_zero_value(&mut rng, modulus); + + let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2); + let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus); + let d0: i64 = cks.decrypt_signed(&ctxt_0); + assert_eq!(d0, clear_lhs, "Failed sanity decryption check"); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs)); + assert!(ct_res.block_carries_are_empty()); + let (expected_result, expected_overflowed) = + signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for sub, for ({clear_lhs} + {clear_rhs}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + } + } + + // Test with trivial inputs + for _ in 0..4 { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT); + + let (encrypted_result, 
encrypted_overflow) = executor.execute((&a, clear_1)); + + let (expected_result, expected_overflowed) = + signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); + let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_0} - {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(encrypted_overflow.0.degree.get(), 1); + assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); + } + + // Test with scalar that is bigger than ciphertext modulus + for _ in 0..2 { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen_range(modulus..=i64::MAX); + + let a = cks.encrypt_signed(clear_0); + + let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); + + let (expected_result, expected_overflowed) = + signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); + let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert!(decrypted_overflowed); // Actually we know its an overflow case + assert_eq!(encrypted_overflow.0.degree.get(), 1); + assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_shift.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_shift.rs index fcfd052d3f..853ba297b2 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_shift.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_shift.rs @@ -1,11 +1,8 @@ use crate::integer::keycache::KEY_CACHE; -use crate::integer::server_key::radix_parallel::tests_cases_signed::{ - random_non_zero_value, signed_add_under_modulus, signed_left_shift_under_modulus, - signed_right_shift_under_modulus, -}; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_signed::{ - NB_CTXT, NB_TESTS, NB_TESTS_SMALLER, + random_non_zero_value, signed_add_under_modulus, signed_left_shift_under_modulus, + signed_right_shift_under_modulus, NB_CTXT, NB_TESTS, NB_TESTS_SMALLER, }; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::create_parametrized_test; diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs index 5ab530bb7c..1f422051ba 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs @@ -1,5 +1,10 @@ use crate::integer::keycache::KEY_CACHE; -use 
crate::integer::server_key::radix_parallel::tests_cases_signed::*; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + create_iterator_of_signed_random_pairs, random_non_zero_value, signed_add_under_modulus, + signed_overflowing_sub_under_modulus, signed_sub_under_modulus, NB_CTXT, NB_TESTS, + NB_TESTS_SMALLER, NB_TESTS_UNCHECKED, +}; use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::create_parametrized_test; use crate::integer::{ @@ -10,6 +15,7 @@ use crate::shortint::ciphertext::NoiseLevel; use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::*; use rand::Rng; +use std::sync::Arc; create_parametrized_test!(integer_signed_unchecked_sub); create_parametrized_test!(integer_signed_unchecked_overflowing_sub); @@ -304,3 +310,101 @@ where assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); } } + +pub(crate) fn signed_unchecked_sub_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), + SignedRadixCiphertext, + >, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + // check some overflow behaviour + let overflowing_values = [ + (-modulus, 1, modulus - 1), + (modulus - 1, -1, -modulus), + (-modulus, 2, modulus - 2), + (modulus - 2, -2, -modulus), + ]; + for (clear_0, clear_1, expected_clear) in overflowing_values { + let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(clear_1); + let ct_res = executor.execute((&ctxt_0, &ctxt_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + assert_eq!(clear_res, expected_clear); + } + + for (clear_0, clear_1) in + create_iterator_of_signed_random_pairs::<{ NB_TESTS_UNCHECKED }>(&mut rng, modulus) + { + let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(clear_1); + + let ct_res = executor.execute((&ctxt_0, &ctxt_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + } +} + +pub(crate) fn signed_default_sub_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), + SignedRadixCiphertext, + >, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + let mut clear; + + for _ in 0..NB_TESTS_SMALLER { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(clear_1); + + let mut ct_res = executor.execute((&ctxt_0, &ctxt_1)); + let tmp_ct = executor.execute((&ctxt_0, &ctxt_1)); + assert!(ct_res.block_carries_are_empty()); + 
assert_eq!(ct_res, tmp_ct); + + clear = signed_sub_under_modulus(clear_0, clear_1, modulus); + + // sub multiple times to raise the degree + for _ in 0..NB_TESTS_SMALLER { + ct_res = executor.execute((&ct_res, &ctxt_0)); + assert!(ct_res.block_carries_are_empty()); + clear = signed_sub_under_modulus(clear, clear_0, modulus); + + let dec_res: i64 = cks.decrypt_signed(&ct_res); + + // println!("clear = {}, dec_res = {}", clear, dec_res); + assert_eq!(clear, dec_res); + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index 0a3ed20172..d7b2611db0 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -6,6 +6,7 @@ pub(crate) mod test_neg; pub(crate) mod test_rotate; pub(crate) mod test_scalar_add; pub(crate) mod test_scalar_bitwise_op; +pub(crate) mod test_scalar_mul; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; @@ -42,6 +43,12 @@ pub(crate) const NB_CTXT: usize = 4; #[cfg(tarpaulin)] pub(crate) const NB_CTXT: usize = 2; +#[cfg(not(tarpaulin))] +pub(crate) const NB_TESTS_UNCHECKED: usize = NB_TESTS; +/// Unchecked test cases needs a minimum number of tests of 4 in order to provide guarantees. +#[cfg(tarpaulin)] +pub(crate) const NB_TESTS_UNCHECKED: usize = 4; + pub(crate) fn random_non_zero_value(rng: &mut ThreadRng, modulus: u64) -> u64 { rng.gen_range(1..modulus) } @@ -414,33 +421,6 @@ create_parametrized_test!( ); create_parametrized_test!(integer_smart_sum_ciphertexts_slice); create_parametrized_test!(integer_default_unsigned_overflowing_sum_ciphertexts_vec); -create_parametrized_test!(integer_unchecked_small_scalar_mul); -create_parametrized_test!(integer_smart_small_scalar_mul); -create_parametrized_test!(integer_default_small_scalar_mul); -create_parametrized_test!( - integer_smart_scalar_mul_u128_fix_non_reg_test { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - }, - no_coverage => { - PARAM_MESSAGE_1_CARRY_1_KS_PBS, - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - } - } -); -create_parametrized_test!(integer_unchecked_scalar_mul_corner_cases); -create_parametrized_test!( - integer_default_scalar_mul_u128_fix_non_reg_test { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - }, - no_coverage => { - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - } - } -); -create_parametrized_test!(integer_smart_scalar_mul); -create_parametrized_test!(integer_default_scalar_mul); // left/right shifts create_parametrized_test!( integer_unchecked_left_shift { @@ -729,22 +709,6 @@ where // Unchecked Scalar Tests //============================================================================= -fn integer_unchecked_small_scalar_mul
<P>
(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_small_scalar_mul_parallelized); - unchecked_small_scalar_mul_test(param, executor); -} - -fn integer_unchecked_scalar_mul_corner_cases
<P>
(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized); - unchecked_scalar_mul_corner_cases_test(param, executor); -} - fn integer_unchecked_scalar_rotate_right
<P>
 where
     P: Into<PBSParameters>,
@@ -842,30 +806,6 @@ where
 // Smart Scalar Tests
 //=============================================================================
 
-fn integer_smart_small_scalar_mul<P>(param: P)
-where
-    P: Into<PBSParameters>,
-{
-    let executor = CpuFunctionExecutor::new(&ServerKey::smart_small_scalar_mul_parallelized);
-    smart_small_scalar_mul_test(param, executor);
-}
-
-fn integer_smart_scalar_mul<P>(param: P)
-where
-    P: Into<PBSParameters>,
-{
-    let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_mul_parallelized);
-    smart_scalar_mul_test(param, executor);
-}
-
-fn integer_smart_scalar_mul_u128_fix_non_reg_test<P>(param: P)
-where
-    P: Into<PBSParameters>,
-{
-    let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_mul_parallelized);
-    smart_scalar_mul_u128_fix_non_reg_test(param, executor);
-}
-
 //=============================================================================
 // Default Tests
 //=============================================================================
@@ -950,30 +890,6 @@ where
     default_checked_ilog2_test(param, executor);
 }
 
-fn integer_default_small_scalar_mul<P>(param: P)
-where
-    P: Into<PBSParameters>,
-{
-    let executor = CpuFunctionExecutor::new(&ServerKey::small_scalar_mul_parallelized);
-    default_small_scalar_mul_test(param, executor);
-}
-
-fn integer_default_scalar_mul_u128_fix_non_reg_test<P>(param: P)
-where
-    P: Into<PBSParameters>,
-{
-    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
-    default_scalar_mul_u128_fix_non_reg_test(param, executor);
-}
-
-fn integer_default_scalar_mul<P>(param: P)
-where
-    P: Into<PBSParameters>,
-{
-    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
-    default_scalar_mul_test(param, executor);
-}
-
 fn integer_default_scalar_rotate_right<P>(param: P)
 where
     P: Into<PBSParameters>,
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_mul.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_mul.rs
new file mode 100644
index 0000000000..0257f409cc
--- /dev/null
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_mul.rs
@@ -0,0 +1,75 @@
+use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{
+    default_scalar_mul_test, default_scalar_mul_u128_fix_non_reg_test, smart_scalar_mul_test,
+    smart_scalar_mul_u128_fix_non_reg_test, unchecked_scalar_mul_corner_cases_test,
+};
+use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
+use crate::integer::tests::create_parametrized_test;
+use crate::integer::ServerKey;
+#[cfg(tarpaulin)]
+use crate::shortint::parameters::coverage_parameters::*;
+use crate::shortint::parameters::*;
+
+create_parametrized_test!(
+    integer_smart_scalar_mul_u128_fix_non_reg_test {
+        coverage => {
+            COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+        },
+        no_coverage => {
+            PARAM_MESSAGE_1_CARRY_1_KS_PBS,
+            PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+        }
+    }
+);
+create_parametrized_test!(integer_unchecked_scalar_mul_corner_cases);
+create_parametrized_test!(
+    integer_default_scalar_mul_u128_fix_non_reg_test {
+        coverage => {
+            COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+        },
+        no_coverage => {
+            PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+        }
+    }
+);
+create_parametrized_test!(integer_smart_scalar_mul);
+create_parametrized_test!(integer_default_scalar_mul);
+
+fn integer_unchecked_scalar_mul_corner_cases<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
+    unchecked_scalar_mul_corner_cases_test(param, executor);
+}
+
+fn integer_smart_scalar_mul<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_mul_parallelized);
+    smart_scalar_mul_test(param, executor);
+}
+
+fn integer_smart_scalar_mul_u128_fix_non_reg_test<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_mul_parallelized);
+    smart_scalar_mul_u128_fix_non_reg_test(param, executor);
+}
+
+fn integer_default_scalar_mul_u128_fix_non_reg_test<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
+    default_scalar_mul_u128_fix_non_reg_test(param, executor);
+}
+
+fn integer_default_scalar_mul<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
+    default_scalar_mul_test(param, executor);
+}