From efeb35bd59af19badb63452ab7af642751e4fb51 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Wed, 6 Mar 2024 11:07:12 +0100 Subject: [PATCH] feat(gpu): signed scalar add --- tfhe/benches/integer/signed_bench.rs | 118 +++++++- .../integers/unsigned/scalar_ops.rs | 2 +- .../gpu/server_key/radix/scalar_add.rs | 66 ++--- .../gpu/server_key/radix/tests_signed/mod.rs | 37 ++- .../radix/tests_signed/test_scalar_add.rs | 27 ++ .../server_key/radix/tests_unsigned/mod.rs | 19 +- .../radix/tests_unsigned/test_scalar_add.rs | 27 ++ .../radix_parallel/tests_cases_signed.rs | 273 +++++++++++++++++- .../radix_parallel/tests_signed/mod.rs | 254 +--------------- .../tests_signed/test_scalar_add.rs | 37 +++ .../radix_parallel/tests_unsigned/mod.rs | 29 +- .../tests_unsigned/test_scalar_add.rs | 37 +++ 12 files changed, 582 insertions(+), 344 deletions(-) create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_add.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_add.rs diff --git a/tfhe/benches/integer/signed_bench.rs b/tfhe/benches/integer/signed_bench.rs index 04744a807b..7847a97300 100644 --- a/tfhe/benches/integer/signed_bench.rs +++ b/tfhe/benches/integer/signed_bench.rs @@ -1382,6 +1382,100 @@ mod cuda { } ); + fn bench_cuda_server_key_binary_scalar_signed_function_clean_inputs( + c: &mut Criterion, + bench_name: &str, + display_name: &str, + binary_op: F, + rng_func: G, + ) where + F: Fn(&CudaServerKey, &mut CudaSignedRadixCiphertext, ScalarType, &CudaStream), + G: Fn(&mut ThreadRng, usize) -> ScalarType, + { + let mut bench_group = c.benchmark_group(bench_name); + bench_group + .sample_size(15) + .measurement_time(std::time::Duration::from_secs(60)); + let mut rng = rand::thread_rng(); + + let gpu_index = 0; + let device = CudaDevice::new(gpu_index); + let stream = CudaStream::new_unchecked(device); + + for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() { + if bit_size > ScalarType::BITS as usize { + break; + } + let param_name = param.name(); + + let max_value_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - bit_size); + + let bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits_scalar_{bit_size}"); + bench_group.bench_function(&bench_id, |b| { + let (cks, _cpu_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let gpu_sks = CudaServerKey::new(&cks, &stream); + + let encrypt_one_value = || { + let clearlow = rng.gen::(); + let clearhigh = rng.gen::(); + let clear_0 = tfhe::integer::I256::from((clearlow, clearhigh)); + let ct_0 = cks.encrypt_signed_radix(clear_0, num_block); + let d_ct_0 = + CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct_0, &stream); + + let clear_1 = rng_func(&mut rng, bit_size) & max_value_for_bit_size; + + (d_ct_0, clear_1) + }; + + b.iter_batched( + encrypt_one_value, + |(mut ct_0, clear_1)| { + binary_op(&gpu_sks, &mut ct_0, clear_1, &stream); + }, + criterion::BatchSize::SmallInput, + ) + }); + + write_to_json::( + &bench_id, + param, + param.name(), + display_name, + &OperatorType::Atomic, + bit_size as u32, + vec![param.message_modulus().0.ilog2(); num_block], + ); + } + + bench_group.finish() + } + + macro_rules! define_cuda_server_key_bench_clean_input_scalar_signed_fn ( + (method_name: $server_key_method:ident, display_name:$name:ident, rng_func:$($rng_fn:tt)*) => { + ::paste::paste!{ + fn [](c: &mut Criterion) { + bench_cuda_server_key_binary_scalar_signed_function_clean_inputs( + c, + concat!("integer::cuda::signed::", stringify!($server_key_method)), + stringify!($name), + |server_key, lhs, rhs, stream| { + server_key.$server_key_method(lhs, rhs, stream); + }, + $($rng_fn)* + ) + } + } + } + ); + + // Functions used to apply different way of selecting a scalar based on the context. + fn default_signed_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType { + let clearlow = rng.gen::(); + let clearhigh = rng.gen::(); + tfhe::integer::I256::from((clearlow, clearhigh)) + } + define_cuda_server_key_bench_clean_input_signed_fn!( method_name: unchecked_add, display_name: add @@ -1402,6 +1496,12 @@ mod cuda { display_name: mul ); + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_add, + display_name: add, + rng_func: default_signed_scalar + ); + //=========================================== // Default //=========================================== @@ -1426,28 +1526,42 @@ mod cuda { display_name: mul ); + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_add, + display_name: add, + rng_func: default_signed_scalar + ); + criterion_group!( unchecked_cuda_ops, cuda_unchecked_add, cuda_unchecked_sub, cuda_unchecked_neg, - cuda_unchecked_mul + cuda_unchecked_mul, ); + criterion_group!(unchecked_scalar_cuda_ops, cuda_unchecked_scalar_add,); + criterion_group!(default_cuda_ops, cuda_add, cuda_sub, cuda_neg, cuda_mul); + + criterion_group!(default_scalar_cuda_ops, cuda_scalar_add); } #[cfg(feature = "gpu")] -use cuda::{default_cuda_ops, unchecked_cuda_ops}; +use cuda::{ + default_cuda_ops, default_scalar_cuda_ops, unchecked_cuda_ops, unchecked_scalar_cuda_ops, +}; #[cfg(feature = "gpu")] fn go_through_gpu_bench_groups(val: &str) { match val.to_lowercase().as_str() { "default" => { default_cuda_ops(); + default_scalar_cuda_ops(); } "unchecked" => { unchecked_cuda_ops(); + unchecked_scalar_cuda_ops(); } _ => panic!("unknown benchmark operations flavor"), }; diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 813b46ac03..d5aab6ce6d 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -465,7 +465,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_stream(|stream| { cuda_key.key.scalar_add( - &lhs.ciphertext.on_gpu(), rhs, stream + &*lhs.ciphertext.on_gpu(), rhs, stream ) }); RadixCiphertext::Cuda(inner_result) diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs index d4225e4223..16b7383919 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs @@ -1,7 +1,7 @@ use crate::core_crypto::gpu::vec::CudaVec; use crate::core_crypto::gpu::CudaStream; use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; -use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; +use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; use crate::integer::gpu::server_key::CudaServerKey; use itertools::Itertools; @@ -43,14 +43,10 @@ impl CudaServerKey { /// let dec: u64 = cks.decrypt(&ct_res); /// assert_eq!(msg + scalar, dec); /// ``` - pub fn unchecked_scalar_add( - &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, - stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + pub fn unchecked_scalar_add(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let mut result = unsafe { ct.duplicate_async(stream) }; self.unchecked_scalar_add_assign(&mut result, scalar, stream); @@ -61,15 +57,16 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_add_assign_async( + pub unsafe fn unchecked_scalar_add_assign_async( &self, - ct: &mut CudaUnsignedRadixCiphertext, - scalar: T, + ct: &mut T, + scalar: Scalar, stream: &CudaStream, ) where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { - if scalar > T::ZERO { + if scalar != Scalar::ZERO { let bits_in_message = self.message_modulus.0.ilog2(); let decomposer = BlockDecomposer::with_early_stop_at_zero(scalar, bits_in_message).iter_as::(); @@ -95,18 +92,19 @@ impl CudaServerKey { self.message_modulus.0 as u32, self.carry_modulus.0 as u32, ); - } - ct.as_mut().info = ct.as_ref().info.after_scalar_add(scalar); + ct.as_mut().info = ct.as_ref().info.after_scalar_add(scalar); + } } - pub fn unchecked_scalar_add_assign( + pub fn unchecked_scalar_add_assign( &self, - ct: &mut CudaUnsignedRadixCiphertext, - scalar: T, + ct: &mut T, + scalar: Scalar, stream: &CudaStream, ) where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { unsafe { self.unchecked_scalar_add_assign_async(ct, scalar, stream); @@ -151,14 +149,10 @@ impl CudaServerKey { /// let dec: u64 = cks.decrypt(&ct_res); /// assert_eq!(msg + scalar, dec); /// ``` - pub fn scalar_add( - &self, - ct: &CudaUnsignedRadixCiphertext, - scalar: T, - stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + pub fn scalar_add(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { let mut result = unsafe { ct.duplicate_async(stream) }; self.scalar_add_assign(&mut result, scalar, stream); @@ -169,13 +163,14 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn scalar_add_assign_async( + pub unsafe fn scalar_add_assign_async( &self, - ct: &mut CudaUnsignedRadixCiphertext, - scalar: T, + ct: &mut T, + scalar: Scalar, stream: &CudaStream, ) where - T: DecomposableInto, + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { if !ct.block_carries_are_empty() { self.full_propagate_assign_async(ct, stream); @@ -185,13 +180,10 @@ impl CudaServerKey { self.full_propagate_assign_async(ct, stream); } - pub fn scalar_add_assign( - &self, - ct: &mut CudaUnsignedRadixCiphertext, - scalar: T, - stream: &CudaStream, - ) where - T: DecomposableInto, + pub fn scalar_add_assign(&self, ct: &mut T, scalar: Scalar, stream: &CudaStream) + where + Scalar: DecomposableInto, + T: CudaIntegerRadixCiphertext, { unsafe { self.scalar_add_assign_async(ct, scalar, stream); diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs index 5460b9c1e1..b6027965aa 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod test_add; pub(crate) mod test_mul; pub(crate) mod test_neg; +pub(crate) mod test_scalar_add; pub(crate) mod test_sub; use crate::core_crypto::gpu::CudaStream; @@ -100,13 +101,13 @@ where } /// For unchecked/default binary functions with one scalar input -impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, u64), SignedRadixCiphertext> +impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext> for GpuFunctionExecutor where F: Fn( &CudaServerKey, &CudaSignedRadixCiphertext, - u64, + i64, &CudaStream, ) -> CudaSignedRadixCiphertext, { @@ -114,7 +115,7 @@ where self.setup_from_keys(cks, &sks); } - fn execute(&mut self, input: (&'a SignedRadixCiphertext, u64)) -> SignedRadixCiphertext { + fn execute(&mut self, input: (&'a SignedRadixCiphertext, i64)) -> SignedRadixCiphertext { let context = self .context .as_ref() @@ -128,3 +129,33 @@ where gpu_result.to_signed_radix_ciphertext(&context.stream) } } + +/// For unchecked/default binary functions with one scalar input +impl FunctionExecutor<(SignedRadixCiphertext, i64), SignedRadixCiphertext> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaSignedRadixCiphertext, + i64, + &CudaStream, + ) -> CudaSignedRadixCiphertext, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute(&mut self, input: (SignedRadixCiphertext, i64)) -> SignedRadixCiphertext { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt_1 = + CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&input.0, &context.stream); + + let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.stream); + + gpu_result.to_signed_radix_ciphertext(&context.stream) + } +} diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs new file mode 100644 index 0000000000..9248e64e40 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs @@ -0,0 +1,27 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parametrized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_cases_signed::{ + signed_default_scalar_add_test, signed_unchecked_scalar_add_test, +}; +use crate::shortint::parameters::*; + +create_gpu_parametrized_test!(integer_signed_unchecked_scalar_add); +create_gpu_parametrized_test!(integer_signed_scalar_add); + +fn integer_signed_unchecked_scalar_add

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_add); + signed_unchecked_scalar_add_test(param, executor); +} + +fn integer_signed_scalar_add

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_add); + signed_default_scalar_add_test(param, executor); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index 22e48348ba..b0ab7a6ce8 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod test_add; pub(crate) mod test_mul; pub(crate) mod test_neg; +pub(crate) mod test_scalar_add; pub(crate) mod test_sub; use crate::core_crypto::gpu::{CudaDevice, CudaStream}; @@ -74,7 +75,6 @@ impl GpuFunctionExecutor { } // Unchecked operations -create_gpu_parametrized_test!(integer_unchecked_scalar_add); create_gpu_parametrized_test!(integer_unchecked_scalar_sub); create_gpu_parametrized_test!(integer_unchecked_small_scalar_mul); create_gpu_parametrized_test!(integer_unchecked_bitnot); @@ -107,7 +107,6 @@ create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_left); create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_right); // Default operations -create_gpu_parametrized_test!(integer_scalar_add); create_gpu_parametrized_test!(integer_scalar_sub); create_gpu_parametrized_test!(integer_small_scalar_mul); create_gpu_parametrized_test!(integer_scalar_right_shift); @@ -311,14 +310,6 @@ where } } -fn integer_unchecked_scalar_add

(param: P) -where - P: Into + Copy, -{ - let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_add); - unchecked_scalar_add_test(param, executor); -} - fn integer_unchecked_small_scalar_mul

(param: P) where P: Into, @@ -1497,14 +1488,6 @@ where } } -fn integer_scalar_add

(param: P) -where - P: Into + Copy, -{ - let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_add); - default_scalar_add_test(param, executor); -} - fn integer_small_scalar_mul

(param: P) where P: Into, diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_add.rs new file mode 100644 index 0000000000..869c006efb --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_add.rs @@ -0,0 +1,27 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parametrized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ + default_scalar_add_test, unchecked_scalar_add_test, +}; +use crate::shortint::parameters::*; + +create_gpu_parametrized_test!(integer_unchecked_scalar_add); +create_gpu_parametrized_test!(integer_scalar_add); + +fn integer_unchecked_scalar_add

(param: P) +where + P: Into + Copy, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_add); + unchecked_scalar_add_test(param, executor); +} + +fn integer_scalar_add

(param: P) +where + P: Into + Copy, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_add); + default_scalar_add_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_signed.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_signed.rs index a5c63a80d9..ab66b70e1e 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_signed.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_signed.rs @@ -1,7 +1,10 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; -use crate::integer::server_key::radix_parallel::tests_signed::{NB_CTXT, NB_TESTS_SMALLER}; -use crate::integer::{IntegerKeyKind, RadixClientKey, SignedRadixCiphertext}; +use crate::integer::server_key::radix_parallel::tests_signed::{ + NB_CTXT, NB_TESTS, NB_TESTS_SMALLER, +}; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, SignedRadixCiphertext}; +use crate::shortint::ciphertext::NoiseLevel; use crate::shortint::PBSParameters; use itertools::izip; use rand::prelude::ThreadRng; @@ -788,3 +791,269 @@ where } } } + +pub(crate) fn signed_unchecked_scalar_add_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from(( + cks, + crate::integer::server_key::radix_parallel::tests_cases_unsigned::NB_CTXT, + )); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + // check some overflow behaviour + let overflowing_values = [ + (-modulus, -1, modulus - 1), + (modulus - 1, 1, -modulus), + (-modulus, -2, modulus - 2), + (modulus - 2, 2, -modulus), + ]; + for (clear_0, clear_1, expected_clear) in overflowing_values { + let ctxt_0 = cks.encrypt_signed(clear_0); + let ct_res = executor.execute((&ctxt_0, clear_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + assert_eq!(clear_res, expected_clear); + } + + for _ in 0..NB_TESTS { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute((&ctxt_0, clear_1)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); + assert_eq!(clear_res, dec_res); + } +} + +pub(crate) fn signed_default_scalar_add_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + sks.set_deterministic_pbs_execution(true); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks); + + let mut clear; + + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TESTS_SMALLER { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let mut ct_res = executor.execute((&ctxt_0, clear_1)); + assert!(ct_res.block_carries_are_empty()); + + clear = signed_add_under_modulus(clear_0, clear_1, modulus); + + // add multiple times to raise the degree + for _ in 0..NB_TESTS_SMALLER { + let tmp = executor.execute((&ct_res, clear_1)); + ct_res = executor.execute((&ct_res, clear_1)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + clear = signed_add_under_modulus(clear, clear_1, modulus); + + let dec_res: i64 = cks.decrypt_signed(&ct_res); + assert_eq!(clear, dec_res); + } + } +} + +pub(crate) fn signed_default_overflowing_scalar_add_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, i64), + (SignedRadixCiphertext, BooleanBlock), + >, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + executor.setup(&cks, sks.clone()); + + let hardcoded_values = [ + (-modulus, -1), + (modulus - 1, 1), + (-1, -modulus), + (1, modulus - 1), + ]; + for (clear_0, clear_1) in hardcoded_values { + let ctxt_0 = cks.encrypt_signed(clear_0); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + } + + for _ in 0..NB_TESTS_SMALLER { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); + let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp_ct, "Failed determinism check"); + assert_eq!(tmp_o, result_overflowed, "Failed determinism check"); + + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + + for _ in 0..NB_TESTS_SMALLER { + // Add non zero scalar to have non clean ciphertexts + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear_rhs = random_non_zero_value(&mut rng, modulus); + + let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2); + let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus); + let d0: i64 = cks.decrypt_signed(&ctxt_0); + assert_eq!(d0, clear_lhs, "Failed sanity decryption check"); + + let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs)); + assert!(ct_res.block_carries_are_empty()); + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&ct_res); + let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(result_overflowed.0.degree.get(), 1); + assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); + } + } + + // Test with trivial inputs + for _ in 0..4 { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT); + + let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); + + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); + let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert_eq!(encrypted_overflow.0.degree.get(), 1); + assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); + } + + // Test with scalar that is bigger than ciphertext modulus + for _ in 0..2 { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen_range(modulus..=i64::MAX); + + let a = cks.encrypt_signed(clear_0); + + let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1)); + + let (expected_result, expected_overflowed) = + signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); + + let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); + let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ + expected {expected_result}, got {decrypted_result}" + ); + assert_eq!( + decrypted_overflowed, + expected_overflowed, + "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ + expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" + ); + assert!(decrypted_overflowed); // Actually we know its an overflow case + assert_eq!(encrypted_overflow.0.degree.get(), 1); + assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs index 0299e7c4b5..eb023b1535 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod test_add; pub(crate) mod test_mul; pub(crate) mod test_neg; +pub(crate) mod test_scalar_add; pub(crate) mod test_sub; use crate::integer::keycache::KEY_CACHE; @@ -1511,7 +1512,6 @@ where // Unchecked Scalar Tests //================================================================================ -create_parametrized_test!(integer_signed_unchecked_scalar_add); create_parametrized_test!(integer_signed_unchecked_scalar_sub); create_parametrized_test!(integer_signed_unchecked_scalar_mul); create_parametrized_test!(integer_signed_unchecked_scalar_rotate_left); @@ -1524,42 +1524,6 @@ create_parametrized_test!(integer_signed_unchecked_scalar_bitxor); create_parametrized_test!(integer_signed_unchecked_scalar_div_rem); create_parametrized_test!(integer_signed_unchecked_scalar_div_rem_floor); -fn integer_signed_unchecked_scalar_add(param: impl Into) { - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - // check some overflow behaviour - let overflowing_values = [ - (-modulus, -1, modulus - 1), - (modulus - 1, 1, -modulus), - (-modulus, -2, modulus - 2), - (modulus - 2, 2, -modulus), - ]; - for (clear_0, clear_1, expected_clear) in overflowing_values { - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - let ct_res = sks.unchecked_scalar_add(&ctxt_0, clear_1); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - assert_eq!(clear_res, expected_clear); - } - - for _ in 0..NB_TESTS { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - - let ct_res = sks.unchecked_scalar_add(&ctxt_0, clear_1); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let clear_res = signed_add_under_modulus(clear_0, clear_1, modulus); - assert_eq!(clear_res, dec_res); - } -} - fn integer_signed_unchecked_scalar_sub(param: impl Into) { let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); @@ -2076,8 +2040,6 @@ fn integer_signed_unchecked_scalar_div_rem_floor(param: impl Into // Default Scalar Tests //================================================================================ -create_parametrized_test!(integer_signed_default_scalar_add); -create_parametrized_test!(integer_signed_default_overflowing_scalar_add); create_parametrized_test!(integer_signed_default_overflowing_scalar_sub); create_parametrized_test!(integer_signed_default_scalar_bitand); create_parametrized_test!(integer_signed_default_scalar_bitor); @@ -2088,220 +2050,6 @@ create_parametrized_test!(integer_signed_default_scalar_right_shift); create_parametrized_test!(integer_signed_default_scalar_rotate_right); create_parametrized_test!(integer_signed_default_scalar_rotate_left); -fn integer_signed_default_scalar_add

(param: P) -where - P: Into, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - let mut clear; - - let mut rng = rand::thread_rng(); - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - - let mut ct_res = sks.scalar_add_parallelized(&ctxt_0, clear_1); - assert!(ct_res.block_carries_are_empty()); - - clear = signed_add_under_modulus(clear_0, clear_1, modulus); - - // add multiple times to raise the degree - for _ in 0..NB_TESTS_SMALLER { - let tmp = sks.scalar_add_parallelized(&ct_res, clear_1); - ct_res = sks.scalar_add_parallelized(&ct_res, clear_1); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - clear = signed_add_under_modulus(clear, clear_1, modulus); - - let dec_res: i64 = cks.decrypt_signed(&ct_res); - assert_eq!(clear, dec_res); - } - } -} - -pub(crate) fn integer_signed_default_overflowing_scalar_add

(param: P) -where - P: Into, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - let hardcoded_values = [ - (-modulus, -1), - (modulus - 1, 1), - (-1, -modulus), - (1, modulus - 1), - ]; - for (clear_0, clear_1) in hardcoded_values { - let ctxt_0 = cks.encrypt_signed(clear_0); - - let (ct_res, result_overflowed) = - sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1); - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - } - - for _ in 0..NB_TESTS_SMALLER { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed(clear_0); - - let (ct_res, result_overflowed) = - sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1); - let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp_ct, "Failed determinism check"); - assert_eq!(tmp_o, result_overflowed, "Failed determinism check"); - - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - - for _ in 0..NB_TESTS_SMALLER { - // Add non zero scalar to have non clean ciphertexts - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear_rhs = random_non_zero_value(&mut rng, modulus); - - let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2); - let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus); - let d0: i64 = cks.decrypt_signed(&ctxt_0); - assert_eq!(d0, clear_lhs, "Failed sanity decryption check"); - - let (ct_res, result_overflowed) = - sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_rhs); - assert!(ct_res.block_carries_are_empty()); - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&ct_res); - let decrypted_overflowed = cks.decrypt_bool(&result_overflowed); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(result_overflowed.0.degree.get(), 1); - assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL); - } - } - - // Test with trivial inputs - for _ in 0..4 { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen::() % modulus; - - let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT); - - let (encrypted_result, encrypted_overflow) = - sks.signed_overflowing_scalar_add_parallelized(&a, clear_1); - - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); - let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert_eq!(encrypted_overflow.0.degree.get(), 1); - assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); - } - - // Test with scalar that is bigger than ciphertext modulus - for _ in 0..2 { - let clear_0 = rng.gen::() % modulus; - let clear_1 = rng.gen_range(modulus..=i64::MAX); - - let a = cks.encrypt_signed(clear_0); - - let (encrypted_result, encrypted_overflow) = - sks.signed_overflowing_scalar_add_parallelized(&a, clear_1); - - let (expected_result, expected_overflowed) = - signed_overflowing_add_under_modulus(clear_0, clear_1, modulus); - - let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result); - let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ - expected {expected_result}, got {decrypted_result}" - ); - assert_eq!( - decrypted_overflowed, - expected_overflowed, - "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \ - expected overflow flag {expected_overflowed}, got {decrypted_overflowed}" - ); - assert!(decrypted_overflowed); // Actually we know its an overflow case - assert_eq!(encrypted_overflow.0.degree.get(), 1); - assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO); - } -} - pub(crate) fn integer_signed_default_overflowing_scalar_sub

(param: P) where P: Into, diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs new file mode 100644 index 0000000000..08744ad14a --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs @@ -0,0 +1,37 @@ +use crate::integer::server_key::radix_parallel::tests_cases_signed::{ + signed_default_overflowing_scalar_add_test, signed_default_scalar_add_test, + signed_unchecked_scalar_add_test, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::ServerKey; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; + +create_parametrized_test!(integer_signed_unchecked_scalar_add); +create_parametrized_test!(integer_signed_default_scalar_add); +create_parametrized_test!(integer_signed_default_overflowing_scalar_add); + +fn integer_signed_unchecked_scalar_add

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_add); + signed_unchecked_scalar_add_test(param, executor); +} + +fn integer_signed_default_scalar_add

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized); + signed_default_scalar_add_test(param, executor); +} + +fn integer_signed_default_overflowing_scalar_add

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_add_parallelized); + signed_default_overflowing_scalar_add_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index 0a84345061..e4ba883e9b 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod test_add; pub(crate) mod test_mul; pub(crate) mod test_neg; +pub(crate) mod test_scalar_add; pub(crate) mod test_sub; use super::tests_cases_unsigned::*; @@ -599,9 +600,6 @@ create_parametrized_test!(integer_default_scalar_div_rem); create_parametrized_test!(integer_smart_scalar_sub); create_parametrized_test!(integer_default_scalar_sub); create_parametrized_test!(integer_default_overflowing_scalar_sub); -create_parametrized_test!(integer_smart_scalar_add); -create_parametrized_test!(integer_default_scalar_add); -create_parametrized_test!(integer_default_overflowing_scalar_add); create_parametrized_test!(integer_smart_if_then_else); create_parametrized_test!(integer_default_if_then_else); create_parametrized_test!(integer_trim_radix_msb_blocks_handles_dirty_inputs); @@ -984,14 +982,6 @@ where // Smart Scalar Tests //============================================================================= -fn integer_smart_scalar_add

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_add_parallelized); - smart_scalar_add_test(param, executor); -} - fn integer_smart_scalar_sub

(param: P) where P: Into, @@ -1144,23 +1134,6 @@ where // Default Scalar Tests //============================================================================= -fn integer_default_scalar_add

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized); - default_scalar_add_test(param, executor); -} - -fn integer_default_overflowing_scalar_add

(param: P) -where - P: Into, -{ - let executor = - CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_add_parallelized); - default_overflowing_scalar_add_test(param, executor); -} - fn integer_default_scalar_sub

(param: P) where P: Into, diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_add.rs new file mode 100644 index 0000000000..af3c5ac0c6 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_add.rs @@ -0,0 +1,37 @@ +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ + default_overflowing_scalar_add_test, default_scalar_add_test, smart_scalar_add_test, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::ServerKey; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; + +create_parametrized_test!(integer_smart_scalar_add); +create_parametrized_test!(integer_default_scalar_add); +create_parametrized_test!(integer_default_overflowing_scalar_add); + +fn integer_smart_scalar_add

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_add_parallelized); + smart_scalar_add_test(param, executor); +} + +fn integer_default_scalar_add

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized); + default_scalar_add_test(param, executor); +} + +fn integer_default_overflowing_scalar_add

(param: P) +where + P: Into, +{ + let executor = + CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_add_parallelized); + default_overflowing_scalar_add_test(param, executor); +}