From 5c189d6bf3330370c54aba68a39129d2a01b96ea Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Thu, 7 Nov 2024 14:05:59 +0100 Subject: [PATCH] chore(ci): use function executor for abs and signed div tests --- .../radix_parallel/tests_signed/mod.rs | 308 +----------------- .../radix_parallel/tests_signed/test_abs.rs | 180 ++++++++++ .../tests_signed/test_div_rem.rs | 231 +++++++++++++ 3 files changed, 414 insertions(+), 305 deletions(-) create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_signed/test_abs.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_signed/test_div_rem.rs diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs index 70138f86e6..28f8add13e 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs @@ -1,9 +1,11 @@ mod modulus_switch_compression; +pub(crate) mod test_abs; pub(crate) mod test_add; pub(crate) mod test_bitwise_op; pub(crate) mod test_cmux; pub(crate) mod test_comparison; mod test_count_zeros_ones; +pub(crate) mod test_div_rem; pub(crate) mod test_ilog2; pub(crate) mod test_mul; pub(crate) mod test_neg; @@ -24,7 +26,7 @@ use crate::core_crypto::prelude::SignedInteger; use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_unsigned::{ - nb_tests_for_params, nb_tests_smaller_for_params, CpuFunctionExecutor, MAX_NB_CTXT, NB_CTXT, + nb_tests_for_params, CpuFunctionExecutor, MAX_NB_CTXT, NB_CTXT, }; use crate::integer::tests::create_parametrized_test; use crate::integer::{ @@ -187,306 +189,6 @@ fn integer_signed_encrypt_decrypt(param: impl Into) { } } -//================================================================================ -// Unchecked Tests -//================================================================================ - -create_parametrized_test!( - integer_signed_unchecked_div_rem { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - }, - no_coverage => { - // Does not support 1_1 - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, - PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, - PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, - PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - } - } -); -create_parametrized_test!( - integer_signed_unchecked_div_rem_floor { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - }, - no_coverage => { - // Does not support 1_1 - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, - PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, - PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, - PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - } - } -); -create_parametrized_test!(integer_signed_unchecked_absolute_value); - -fn integer_signed_unchecked_absolute_value(param: impl Into) { - let param = param.into(); - let nb_tests = nb_tests_for_params(param); - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - // For signed integers, the range of value is [-modulus..modulus[ - // e.g.: for i8, the range is [-128..128[ <=> [-128..127] - // which means -modulus cannot be represented. - // - // In Rust, .abs() / .wrapping_abs() returns MIN (-modulus) - // https://doc.rust-lang.org/std/primitive.i8.html#method.wrapping_abs - // - // Here we test we have same behaviour - // - // (Conveniently, when using Two's complement, casting the result of abs to - // an unsigned to will give correct value for -modulus - // e.g.:(-128i8).wrapping_abs() as u8 == 128 - { - let clear_0 = -modulus; - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - let ct_res = sks.unchecked_abs_parallelized(&ctxt_0); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - assert_eq!(dec_res, -modulus); - } - - for _ in 0..nb_tests { - let clear_0 = rng.gen::() % modulus; - - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - - let ct_res = sks.unchecked_abs_parallelized(&ctxt_0); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let clear_res = absolute_value_under_modulus(clear_0, modulus); - assert_eq!(clear_res, dec_res); - } -} - -fn integer_signed_unchecked_div_rem(param: impl Into) { - let param = param.into(); - let nb_tests_smaller = nb_tests_smaller_for_params(param); - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - // Test case of division by 0 - // This is mainly to show we know the behaviour of division by 0 - // using the current algorithm - for clear_0 in [0i64, rng.gen::() % modulus] { - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - let ctxt_1 = cks.encrypt_signed_radix(0, NB_CTXT); - - let (q_res, r_res) = sks.unchecked_div_rem_parallelized(&ctxt_0, &ctxt_1); - let q: i64 = cks.decrypt_signed_radix(&q_res); - let r: i64 = cks.decrypt_signed_radix(&r_res); - - assert_eq!(r, clear_0); - assert_eq!(q, if clear_0 >= 0 { -1 } else { 1 }); - } - - // Div is the slowest operation - for _ in 0..nb_tests_smaller { - let clear_0 = rng.gen::() % modulus; - let clear_1 = loop { - let value = rng.gen::() % modulus; - if value != 0 { - break value; - } - }; - - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - let ctxt_1 = cks.encrypt_signed_radix(clear_1, NB_CTXT); - - let (q_res, r_res) = sks.unchecked_div_rem_parallelized(&ctxt_0, &ctxt_1); - let q: i64 = cks.decrypt_signed_radix(&q_res); - let r: i64 = cks.decrypt_signed_radix(&r_res); - let expected_q = signed_div_under_modulus(clear_0, clear_1, modulus); - assert_eq!( - q, expected_q, - "Invalid division result, for {clear_0} / {clear_1} \ - expected quotient: {expected_q} got: {q}" - ); - let expected_r = signed_rem_under_modulus(clear_0, clear_1, modulus); - assert_eq!( - r, expected_r, - "Invalid remainder result, for {clear_0} % {clear_1} \ - expected quotient: {expected_r} got: {r}" - ); - } -} - -fn integer_signed_unchecked_div_rem_floor(param: impl Into) { - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - if modulus > 8 { - // Some hard coded test for flooring div - // For example, truncating_div(-7, 3) would give q = -2 and r = -1 - // truncating div is the default in rust (and many other languages) - // Python does use a flooring div, so you can try these values in you local - // interpreter. - let values = [ - (-8, 3, -3, 1), - (8, -3, -3, -1), - (7, 3, 2, 1), - (-7, 3, -3, 2), - (7, -3, -3, -2), - (-7, -3, 2, -1), - ]; - for (clear_0, clear_1, expected_q, expected_r) in values { - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - let ctxt_1 = cks.encrypt_signed_radix(clear_1, NB_CTXT); - - let (q_res, r_res) = sks.unchecked_div_rem_floor_parallelized(&ctxt_0, &ctxt_1); - let q: i64 = cks.decrypt_signed_radix(&q_res); - let r: i64 = cks.decrypt_signed_radix(&r_res); - - // Uses the hardcoded values to also test our clear function - let (q2, r2) = signed_div_rem_floor_under_modulus(clear_0, clear_1, modulus); - - assert_eq!(q2, expected_q); - assert_eq!(r2, expected_r); - assert_eq!(q, expected_q); - assert_eq!(r, expected_r); - } - } - - // A test where the division is whole, aka remainder is zero - { - let ctxt_0 = cks.encrypt_signed_radix(4, NB_CTXT); - let ctxt_1 = cks.encrypt_signed_radix(-2, NB_CTXT); - - let (q_res, r_res) = sks.unchecked_div_rem_floor_parallelized(&ctxt_0, &ctxt_1); - let q: i64 = cks.decrypt_signed_radix(&q_res); - let r: i64 = cks.decrypt_signed_radix(&r_res); - - // Uses the hardcoded values to also test our clear function - let (q2, r2) = signed_div_rem_floor_under_modulus(4, -2, modulus); - - assert_eq!(q2, -2); - assert_eq!(r2, 0); - assert_eq!(q, -2); - assert_eq!(r, 0); - } - - // Div is the slowest operation - for _ in 0..5 { - let clear_0 = rng.gen::() % modulus; - let clear_1 = loop { - let value = rng.gen::() % modulus; - if value != 0 { - break value; - } - }; - - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - let ctxt_1 = cks.encrypt_signed_radix(clear_1, NB_CTXT); - - let (q_res, r_res) = sks.unchecked_div_rem_floor_parallelized(&ctxt_0, &ctxt_1); - let q: i64 = cks.decrypt_signed_radix(&q_res); - let r: i64 = cks.decrypt_signed_radix(&r_res); - let (expected_q, expected_r) = - signed_div_rem_floor_under_modulus(clear_0, clear_1, modulus); - - println!("{clear_0} / {clear_1} -> ({q}, {r})"); - assert_eq!(q, expected_q); - assert_eq!(r, expected_r); - } -} - -//================================================================================ -// Smart Tests -//================================================================================ - -create_parametrized_test!(integer_signed_smart_absolute_value); - -fn integer_signed_smart_absolute_value(param: impl Into) { - let param = param.into(); - let nb_tests = nb_tests_for_params(param); - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - { - let clear_0 = -modulus; - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - let ct_res = sks.abs_parallelized(&ctxt_0); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - assert_eq!(dec_res, -modulus); - } - - for _ in 0..nb_tests { - let mut clear_0 = rng.gen::() % modulus; - let clear_to_add = rng.gen::() % modulus; - - let mut ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_to_add); - clear_0 = signed_add_under_modulus(clear_0, clear_to_add, modulus); - - let ct_res = sks.abs_parallelized(&ctxt_0); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let clear_res = absolute_value_under_modulus(clear_0, modulus); - assert_eq!(clear_res, dec_res); - } -} - -//================================================================================ -// Default Tests -//================================================================================ - -create_parametrized_test!(integer_signed_default_absolute_value); - -fn integer_signed_default_absolute_value(param: impl Into) { - let param = param.into(); - let nb_tests_smaller = nb_tests_smaller_for_params(param); - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - sks.set_deterministic_pbs_execution(true); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - { - let clear_0 = -modulus; - let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - let ct_res = sks.abs_parallelized(&ctxt_0); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - assert_eq!(dec_res, -modulus); - } - - for _ in 0..nb_tests_smaller { - let mut clear_0 = rng.gen::() % modulus; - let clear_to_add = rng.gen::() % modulus; - - let mut ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT); - sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_to_add); - clear_0 = signed_add_under_modulus(clear_0, clear_to_add, modulus); - - let ct_res = sks.abs_parallelized(&ctxt_0); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let clear_res = absolute_value_under_modulus(clear_0, modulus); - assert_eq!(clear_res, dec_res); - - let ct_res2 = sks.abs_parallelized(&ctxt_0); - assert_eq!(ct_res2, ct_res); - } -} - //================================================================================ // Unchecked Scalar Tests //================================================================================ @@ -595,10 +297,6 @@ fn integer_signed_unchecked_scalar_div_rem_floor(param: impl Into } } -//================================================================================ -// Smart Scalar Tests -//================================================================================ - //================================================================================ // Default Scalar Tests //================================================================================ diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_abs.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_abs.rs new file mode 100644 index 0000000000..7f04d1e10d --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_abs.rs @@ -0,0 +1,180 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + absolute_value_under_modulus, signed_add_under_modulus, NB_CTXT, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::{ + nb_tests_for_params, CpuFunctionExecutor, +}; +use crate::integer::tests::create_parametrized_test; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext}; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; + +create_parametrized_test!(integer_signed_default_absolute_value); +create_parametrized_test!(integer_signed_unchecked_absolute_value); +create_parametrized_test!(integer_signed_smart_absolute_value); + +fn integer_signed_default_absolute_value

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::abs_parallelized); + signed_default_absolute_value_test(param, executor); +} + +fn integer_signed_unchecked_absolute_value

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_abs_parallelized); + signed_unchecked_absolute_value_test(param, executor); +} + +fn integer_signed_smart_absolute_value

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::smart_abs_parallelized); + signed_smart_absolute_value_test(param, executor); +} + +pub(crate) fn signed_unchecked_absolute_value_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + executor.setup(&cks, sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + // For signed integers, the range of value is [-modulus..modulus[ + // e.g.: for i8, the range is [-128..128[ <=> [-128..127] + // which means -modulus cannot be represented. + // + // In Rust, .abs() / .wrapping_abs() returns MIN (-modulus) + // https://doc.rust-lang.org/std/primitive.i8.html#method.wrapping_abs + // + // Here we test we have same behaviour + // + // (Conveniently, when using Two's complement, casting the result of abs to + // an unsigned to will give correct value for -modulus + // e.g.:(-128i8).wrapping_abs() as u8 == 128 + { + let clear_0 = -modulus; + let ctxt_0 = cks.encrypt_signed(clear_0); + let ct_res = executor.execute(&ctxt_0); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + assert_eq!(dec_res, -modulus); + } + + for _ in 0..nb_tests { + let clear_0 = rng.gen::() % modulus; + + let ctxt_0 = cks.encrypt_signed(clear_0); + + let ct_res = executor.execute(&ctxt_0); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = absolute_value_under_modulus(clear_0, modulus); + assert_eq!(clear_res, dec_res); + } +} + +pub(crate) fn signed_smart_absolute_value_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a mut SignedRadixCiphertext, SignedRadixCiphertext>, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + executor.setup(&cks, sks.clone()); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + { + let clear_0 = -modulus; + let mut ctxt_0 = cks.encrypt_signed(clear_0); + let ct_res = executor.execute(&mut ctxt_0); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + assert_eq!(dec_res, -modulus); + } + + for _ in 0..nb_tests { + let mut clear_0 = rng.gen::() % modulus; + let clear_to_add = rng.gen::() % modulus; + + let mut ctxt_0 = cks.encrypt_signed(clear_0); + sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_to_add); + clear_0 = signed_add_under_modulus(clear_0, clear_to_add, modulus); + + let ct_res = executor.execute(&mut ctxt_0); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = absolute_value_under_modulus(clear_0, modulus); + assert_eq!(clear_res, dec_res); + } +} + +pub(crate) fn signed_default_absolute_value_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + executor.setup(&cks, sks.clone()); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + { + let clear_0 = -modulus; + let ctxt_0 = cks.encrypt_signed(clear_0); + let ct_res = executor.execute(&ctxt_0); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + assert_eq!(dec_res, -modulus); + } + + for _ in 0..nb_tests { + let mut clear_0 = rng.gen::() % modulus; + let clear_to_add = rng.gen::() % modulus; + + let mut ctxt_0 = cks.encrypt_signed(clear_0); + sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_to_add); + clear_0 = signed_add_under_modulus(clear_0, clear_to_add, modulus); + + let ct_res = executor.execute(&ctxt_0); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = absolute_value_under_modulus(clear_0, modulus); + assert_eq!(clear_res, dec_res); + + let ct_res2 = executor.execute(&ctxt_0); + assert_eq!(ct_res2, ct_res); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_div_rem.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_div_rem.rs new file mode 100644 index 0000000000..125313dd2d --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_div_rem.rs @@ -0,0 +1,231 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + signed_div_rem_floor_under_modulus, signed_div_under_modulus, signed_rem_under_modulus, NB_CTXT, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::{ + nb_tests_smaller_for_params, CpuFunctionExecutor, +}; +use crate::integer::tests::create_parametrized_test; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext}; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; + +create_parametrized_test!( + integer_signed_unchecked_div_rem { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, + }, + no_coverage => { + // Does not support 1_1 + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + } + } +); +create_parametrized_test!( + integer_signed_unchecked_div_rem_floor { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, + }, + no_coverage => { + // Does not support 1_1 + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + } + } +); +fn integer_signed_unchecked_div_rem

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_div_rem_parallelized); + signed_unchecked_div_rem_test(param, executor); +} + +fn integer_signed_unchecked_div_rem_floor

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_div_rem_floor_parallelized); + signed_unchecked_div_rem_floor_test(param, executor); +} + +pub(crate) fn signed_unchecked_div_rem_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), + (SignedRadixCiphertext, SignedRadixCiphertext), + >, +{ + let param = param.into(); + let nb_tests_smaller = nb_tests_smaller_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + executor.setup(&cks, sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + // Test case of division by 0 + // This is mainly to show we know the behaviour of division by 0 + // using the current algorithm + for clear_0 in [0i64, rng.gen::() % modulus] { + let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(0); + + let (q_res, r_res) = executor.execute((&ctxt_0, &ctxt_1)); + let q: i64 = cks.decrypt_signed(&q_res); + let r: i64 = cks.decrypt_signed(&r_res); + + assert_eq!(r, clear_0); + assert_eq!(q, if clear_0 >= 0 { -1 } else { 1 }); + } + + // Div is the slowest operation + for _ in 0..nb_tests_smaller { + let clear_0 = rng.gen::() % modulus; + let clear_1 = loop { + let value = rng.gen::() % modulus; + if value != 0 { + break value; + } + }; + + let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(clear_1); + + let (q_res, r_res) = executor.execute((&ctxt_0, &ctxt_1)); + let q: i64 = cks.decrypt_signed(&q_res); + let r: i64 = cks.decrypt_signed(&r_res); + let expected_q = signed_div_under_modulus(clear_0, clear_1, modulus); + assert_eq!( + q, expected_q, + "Invalid division result, for {clear_0} / {clear_1} \ + expected quotient: {expected_q} got: {q}" + ); + let expected_r = signed_rem_under_modulus(clear_0, clear_1, modulus); + assert_eq!( + r, expected_r, + "Invalid remainder result, for {clear_0} % {clear_1} \ + expected quotient: {expected_r} got: {r}" + ); + } +} + +pub(crate) fn signed_unchecked_div_rem_floor_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), + (SignedRadixCiphertext, SignedRadixCiphertext), + >, +{ + let param = param.into(); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + executor.setup(&cks, sks); + + let mut rng = rand::thread_rng(); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + if modulus > 8 { + // Some hard coded test for flooring div + // For example, truncating_div(-7, 3) would give q = -2 and r = -1 + // truncating div is the default in rust (and many other languages) + // Python does use a flooring div, so you can try these values in you local + // interpreter. + let values = [ + (-8, 3, -3, 1), + (8, -3, -3, -1), + (7, 3, 2, 1), + (-7, 3, -3, 2), + (7, -3, -3, -2), + (-7, -3, 2, -1), + ]; + for (clear_0, clear_1, expected_q, expected_r) in values { + let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(clear_1); + + let (q_res, r_res) = executor.execute((&ctxt_0, &ctxt_1)); + let q: i64 = cks.decrypt_signed(&q_res); + let r: i64 = cks.decrypt_signed(&r_res); + + // Uses the hardcoded values to also test our clear function + let (q2, r2) = signed_div_rem_floor_under_modulus(clear_0, clear_1, modulus); + + assert_eq!(q2, expected_q); + assert_eq!(r2, expected_r); + assert_eq!(q, expected_q); + assert_eq!(r, expected_r); + } + } + + // A test where the division is whole, aka remainder is zero + { + let ctxt_0 = cks.encrypt_signed(4); + let ctxt_1 = cks.encrypt_signed(-2); + + let (q_res, r_res) = executor.execute((&ctxt_0, &ctxt_1)); + let q: i64 = cks.decrypt_signed(&q_res); + let r: i64 = cks.decrypt_signed(&r_res); + + // Uses the hardcoded values to also test our clear function + let (q2, r2) = signed_div_rem_floor_under_modulus(4, -2, modulus); + + assert_eq!(q2, -2); + assert_eq!(r2, 0); + assert_eq!(q, -2); + assert_eq!(r, 0); + } + + // Div is the slowest operation + for _ in 0..5 { + let clear_0 = rng.gen::() % modulus; + let clear_1 = loop { + let value = rng.gen::() % modulus; + if value != 0 { + break value; + } + }; + + let ctxt_0 = cks.encrypt_signed(clear_0); + let ctxt_1 = cks.encrypt_signed(clear_1); + + let (q_res, r_res) = executor.execute((&ctxt_0, &ctxt_1)); + let q: i64 = cks.decrypt_signed(&q_res); + let r: i64 = cks.decrypt_signed(&r_res); + let (expected_q, expected_r) = + signed_div_rem_floor_under_modulus(clear_0, clear_1, modulus); + + println!("{clear_0} / {clear_1} -> ({q}, {r})"); + assert_eq!(q, expected_q); + assert_eq!(r, expected_r); + } +}