Skip to content

Commit

Permalink
chore(ci): use function executor for abs and signed div tests
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Nov 7, 2024
1 parent f8bde7f commit 5c189d6
Show file tree
Hide file tree
Showing 3 changed files with 414 additions and 305 deletions.
308 changes: 3 additions & 305 deletions tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod modulus_switch_compression;
pub(crate) mod test_abs;
pub(crate) mod test_add;
pub(crate) mod test_bitwise_op;
pub(crate) mod test_cmux;
pub(crate) mod test_comparison;
mod test_count_zeros_ones;
pub(crate) mod test_div_rem;
pub(crate) mod test_ilog2;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
Expand All @@ -24,7 +26,7 @@ use crate::core_crypto::prelude::SignedInteger;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_unsigned::{
nb_tests_for_params, nb_tests_smaller_for_params, CpuFunctionExecutor, MAX_NB_CTXT, NB_CTXT,
nb_tests_for_params, CpuFunctionExecutor, MAX_NB_CTXT, NB_CTXT,
};
use crate::integer::tests::create_parametrized_test;
use crate::integer::{
Expand Down Expand Up @@ -187,306 +189,6 @@ fn integer_signed_encrypt_decrypt(param: impl Into<PBSParameters>) {
}
}

//================================================================================
// Unchecked Tests
//================================================================================

create_parametrized_test!(
integer_signed_unchecked_div_rem {
coverage => {
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
},
no_coverage => {
// Does not support 1_1
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
}
}
);
create_parametrized_test!(
integer_signed_unchecked_div_rem_floor {
coverage => {
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
},
no_coverage => {
// Does not support 1_1
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
}
}
);
create_parametrized_test!(integer_signed_unchecked_absolute_value);

fn integer_signed_unchecked_absolute_value(param: impl Into<PBSParameters>) {
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();

let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

// For signed integers, the range of value is [-modulus..modulus[
// e.g.: for i8, the range is [-128..128[ <=> [-128..127]
// which means -modulus cannot be represented.
//
// In Rust, .abs() / .wrapping_abs() returns MIN (-modulus)
// https://doc.rust-lang.org/std/primitive.i8.html#method.wrapping_abs
//
// Here we test we have same behaviour
//
// (Conveniently, when using Two's complement, casting the result of abs to
// an unsigned to will give correct value for -modulus
// e.g.:(-128i8).wrapping_abs() as u8 == 128
{
let clear_0 = -modulus;
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ct_res = sks.unchecked_abs_parallelized(&ctxt_0);
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
assert_eq!(dec_res, -modulus);
}

for _ in 0..nb_tests {
let clear_0 = rng.gen::<i64>() % modulus;

let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);

let ct_res = sks.unchecked_abs_parallelized(&ctxt_0);
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
let clear_res = absolute_value_under_modulus(clear_0, modulus);
assert_eq!(clear_res, dec_res);
}
}

fn integer_signed_unchecked_div_rem(param: impl Into<PBSParameters>) {
let param = param.into();
let nb_tests_smaller = nb_tests_smaller_for_params(param);
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();

let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

// Test case of division by 0
// This is mainly to show we know the behaviour of division by 0
// using the current algorithm
for clear_0 in [0i64, rng.gen::<i64>() % modulus] {
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ctxt_1 = cks.encrypt_signed_radix(0, NB_CTXT);

let (q_res, r_res) = sks.unchecked_div_rem_parallelized(&ctxt_0, &ctxt_1);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);

assert_eq!(r, clear_0);
assert_eq!(q, if clear_0 >= 0 { -1 } else { 1 });
}

// Div is the slowest operation
for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = loop {
let value = rng.gen::<i64>() % modulus;
if value != 0 {
break value;
}
};

let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ctxt_1 = cks.encrypt_signed_radix(clear_1, NB_CTXT);

let (q_res, r_res) = sks.unchecked_div_rem_parallelized(&ctxt_0, &ctxt_1);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
let expected_q = signed_div_under_modulus(clear_0, clear_1, modulus);
assert_eq!(
q, expected_q,
"Invalid division result, for {clear_0} / {clear_1} \
expected quotient: {expected_q} got: {q}"
);
let expected_r = signed_rem_under_modulus(clear_0, clear_1, modulus);
assert_eq!(
r, expected_r,
"Invalid remainder result, for {clear_0} % {clear_1} \
expected quotient: {expected_r} got: {r}"
);
}
}

fn integer_signed_unchecked_div_rem_floor(param: impl Into<PBSParameters>) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();

let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

if modulus > 8 {
// Some hard coded test for flooring div
// For example, truncating_div(-7, 3) would give q = -2 and r = -1
// truncating div is the default in rust (and many other languages)
// Python does use a flooring div, so you can try these values in you local
// interpreter.
let values = [
(-8, 3, -3, 1),
(8, -3, -3, -1),
(7, 3, 2, 1),
(-7, 3, -3, 2),
(7, -3, -3, -2),
(-7, -3, 2, -1),
];
for (clear_0, clear_1, expected_q, expected_r) in values {
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ctxt_1 = cks.encrypt_signed_radix(clear_1, NB_CTXT);

let (q_res, r_res) = sks.unchecked_div_rem_floor_parallelized(&ctxt_0, &ctxt_1);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);

// Uses the hardcoded values to also test our clear function
let (q2, r2) = signed_div_rem_floor_under_modulus(clear_0, clear_1, modulus);

assert_eq!(q2, expected_q);
assert_eq!(r2, expected_r);
assert_eq!(q, expected_q);
assert_eq!(r, expected_r);
}
}

// A test where the division is whole, aka remainder is zero
{
let ctxt_0 = cks.encrypt_signed_radix(4, NB_CTXT);
let ctxt_1 = cks.encrypt_signed_radix(-2, NB_CTXT);

let (q_res, r_res) = sks.unchecked_div_rem_floor_parallelized(&ctxt_0, &ctxt_1);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);

// Uses the hardcoded values to also test our clear function
let (q2, r2) = signed_div_rem_floor_under_modulus(4, -2, modulus);

assert_eq!(q2, -2);
assert_eq!(r2, 0);
assert_eq!(q, -2);
assert_eq!(r, 0);
}

// Div is the slowest operation
for _ in 0..5 {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = loop {
let value = rng.gen::<i64>() % modulus;
if value != 0 {
break value;
}
};

let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ctxt_1 = cks.encrypt_signed_radix(clear_1, NB_CTXT);

let (q_res, r_res) = sks.unchecked_div_rem_floor_parallelized(&ctxt_0, &ctxt_1);
let q: i64 = cks.decrypt_signed_radix(&q_res);
let r: i64 = cks.decrypt_signed_radix(&r_res);
let (expected_q, expected_r) =
signed_div_rem_floor_under_modulus(clear_0, clear_1, modulus);

println!("{clear_0} / {clear_1} -> ({q}, {r})");
assert_eq!(q, expected_q);
assert_eq!(r, expected_r);
}
}

//================================================================================
// Smart Tests
//================================================================================

create_parametrized_test!(integer_signed_smart_absolute_value);

fn integer_signed_smart_absolute_value(param: impl Into<PBSParameters>) {
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let mut rng = rand::thread_rng();

let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

{
let clear_0 = -modulus;
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ct_res = sks.abs_parallelized(&ctxt_0);
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
assert_eq!(dec_res, -modulus);
}

for _ in 0..nb_tests {
let mut clear_0 = rng.gen::<i64>() % modulus;
let clear_to_add = rng.gen::<i64>() % modulus;

let mut ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_to_add);
clear_0 = signed_add_under_modulus(clear_0, clear_to_add, modulus);

let ct_res = sks.abs_parallelized(&ctxt_0);
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
let clear_res = absolute_value_under_modulus(clear_0, modulus);
assert_eq!(clear_res, dec_res);
}
}

//================================================================================
// Default Tests
//================================================================================

create_parametrized_test!(integer_signed_default_absolute_value);

fn integer_signed_default_absolute_value(param: impl Into<PBSParameters>) {
let param = param.into();
let nb_tests_smaller = nb_tests_smaller_for_params(param);
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
sks.set_deterministic_pbs_execution(true);

let mut rng = rand::thread_rng();

let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

{
let clear_0 = -modulus;
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ct_res = sks.abs_parallelized(&ctxt_0);
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
assert_eq!(dec_res, -modulus);
}

for _ in 0..nb_tests_smaller {
let mut clear_0 = rng.gen::<i64>() % modulus;
let clear_to_add = rng.gen::<i64>() % modulus;

let mut ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
sks.unchecked_scalar_add_assign(&mut ctxt_0, clear_to_add);
clear_0 = signed_add_under_modulus(clear_0, clear_to_add, modulus);

let ct_res = sks.abs_parallelized(&ctxt_0);
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
let clear_res = absolute_value_under_modulus(clear_0, modulus);
assert_eq!(clear_res, dec_res);

let ct_res2 = sks.abs_parallelized(&ctxt_0);
assert_eq!(ct_res2, ct_res);
}
}

//================================================================================
// Unchecked Scalar Tests
//================================================================================
Expand Down Expand Up @@ -595,10 +297,6 @@ fn integer_signed_unchecked_scalar_div_rem_floor(param: impl Into<PBSParameters>
}
}

//================================================================================
// Smart Scalar Tests
//================================================================================

//================================================================================
// Default Scalar Tests
//================================================================================
Expand Down
Loading

0 comments on commit 5c189d6

Please sign in to comment.