feat(integer): add unsigned_overflowing_add #674

Merged · 1 commit · Nov 14, 2023
101 changes: 97 additions & 4 deletions tfhe/src/integer/server_key/radix_parallel/add.rs
@@ -232,12 +232,94 @@ impl ServerKey {
};

if self.is_eligible_for_parallel_single_carry_propagation(lhs) {
self.unchecked_add_assign_parallelized_low_latency(lhs, rhs);
let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, rhs);
} else {
self.unchecked_add_assign(lhs, rhs);
self.full_propagate_parallelized(lhs);
}
}
/// Computes the addition of two unsigned ciphertexts and returns the overflow flag
///
/// # Example
///
/// ```rust
/// use tfhe::integer::gen_keys_radix;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
///
/// let msg1 = u8::MAX;
/// let msg2 = 1;
///
/// let ct1 = cks.encrypt(msg1);
/// let ct2 = cks.encrypt(msg2);
///
/// let (ct_res, overflowed) = sks.unsigned_overflowing_add_parallelized(&ct1, &ct2);
///
/// // Decrypt:
/// let dec_result: u8 = cks.decrypt(&ct_res);
/// let dec_overflowed = cks.decrypt_one_block(&overflowed);
/// let (expected_result, expected_overflow) = msg1.overflowing_add(msg2);
/// assert_eq!(dec_result, expected_result);
/// assert_eq!(dec_overflowed, u64::from(expected_overflow));
/// ```
pub fn unsigned_overflowing_add_parallelized(
&self,
ct_left: &RadixCiphertext,
ct_right: &RadixCiphertext,
) -> (RadixCiphertext, Ciphertext) {
let mut ct_res = ct_left.clone();
let overflowed = self.unsigned_overflowing_add_assign_parallelized(&mut ct_res, ct_right);
(ct_res, overflowed)
}

pub fn unsigned_overflowing_add_assign_parallelized(
&self,
ct_left: &mut RadixCiphertext,
ct_right: &RadixCiphertext,
) -> Ciphertext {
let mut tmp_rhs: RadixCiphertext;
if ct_left.blocks.is_empty() || ct_right.blocks.is_empty() {
return self.key.create_trivial(0);
}

let (lhs, rhs) = match (
ct_left.block_carries_are_empty(),
ct_right.block_carries_are_empty(),
) {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.clone();
self.full_propagate_parallelized(&mut tmp_rhs);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_parallelized(ct_left);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.clone();
rayon::join(
|| self.full_propagate_parallelized(ct_left),
|| self.full_propagate_parallelized(&mut tmp_rhs),
);
(ct_left, &tmp_rhs)
}
};

if self.is_eligible_for_parallel_single_carry_propagation(lhs) {
self.unchecked_add_assign_parallelized_low_latency(lhs, rhs)
} else {
self.unchecked_add_assign(lhs, rhs);
let len = lhs.blocks.len();
for i in 0..len - 1 {
let _ = self.propagate_parallelized(lhs, i);
}
self.propagate_parallelized(lhs, len - 1)
}
}
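
For reviewers, a minimal sketch of how the new assign variant is meant to be used, assuming the same key setup as the doc-test above (`cks`, `sks`, 4 blocks of 2-bit messages, i.e. an 8-bit clear range); this is illustrative and not part of the diff:

let mut ct_left = cks.encrypt(200u64);
let ct_right = cks.encrypt(100u64);
// The sum is written into `ct_left`; the returned shortint block encrypts the overflow flag.
let overflow = sks.unsigned_overflowing_add_assign_parallelized(&mut ct_left, &ct_right);
let sum: u64 = cks.decrypt(&ct_left);
assert_eq!(sum, (200u64 + 100) % 256); // the sum wraps modulo 2^8
assert_eq!(cks.decrypt_one_block(&overflow), 1); // 300 > 255, so the flag is set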

pub fn add_parallelized_work_efficient<T>(&self, ct_left: &T, ct_right: &T) -> T
where
@@ -309,6 +391,9 @@ impl ServerKey {
///
/// At most num_block - 1 threads are used
///
/// Returns the output carry that can be used to check for unsigned addition
/// overflow.
///
/// # Requirements
///
/// - The parameters have 4 bits in total
@@ -317,7 +402,11 @@
/// # Output
///
/// - lhs will have its carries empty
pub(crate) fn unchecked_add_assign_parallelized_low_latency<T>(&self, lhs: &mut T, rhs: &T)
pub(crate) fn unchecked_add_assign_parallelized_low_latency<T>(
&self,
lhs: &mut T,
rhs: &T,
) -> Ciphertext
where
T: IntegerRadixCiphertext,
{
@@ -342,12 +431,15 @@
/// - first unchecked_add
/// - at this point at most one bit of carry is taken
/// - use this function to propagate them in parallel
pub(crate) fn propagate_single_carry_parallelized_low_latency<T>(&self, ct: &mut T)
pub(crate) fn propagate_single_carry_parallelized_low_latency<T>(
&self,
ct: &mut T,
) -> Ciphertext
where
T: IntegerRadixCiphertext,
{
let generates_or_propagates = self.generate_init_carry_array(ct);
let (input_carries, _) =
let (input_carries, output_carry) =
self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates);

ct.blocks_mut()
@@ -357,6 +449,7 @@
self.key.unchecked_add_assign(block, input_carry);
self.key.message_extract_assign(block);
});
output_carry
}
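
A sketch of the internal call pattern the doc comment above describes (pub(crate) usage inside ServerKey methods; illustrative, not a verbatim excerpt):

// After unchecked_add_assign each block holds at most one bit of carry,
// so a single parallel propagation pass suffices; the returned block is
// the carry out of the most significant block.
self.unchecked_add_assign(lhs, rhs);
let output_carry = self.propagate_single_carry_parallelized_low_latency(lhs);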

/// Backbone algorithm of parallel carry (only one bit) propagation
12 changes: 9 additions & 3 deletions tfhe/src/integer/server_key/radix_parallel/mod.rs
@@ -62,7 +62,11 @@ impl ServerKey {
/// let res: u64 = cks.decrypt_one_block(&ct_res.blocks()[1]);
/// assert_eq!(3, res);
/// ```
pub fn propagate_parallelized<T>(&self, ctxt: &mut T, index: usize)
pub fn propagate_parallelized<T>(
&self,
ctxt: &mut T,
index: usize,
) -> crate::shortint::Ciphertext
where
T: IntegerRadixCiphertext,
{
@@ -77,6 +81,8 @@
self.key
.unchecked_add_assign(&mut ctxt.blocks_mut()[index + 1], &carry);
}

carry
}
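
A hedged sketch of consuming the new return value (hypothetical `ct`; not part of the diff). Propagating the last block returns the carry out of the whole ciphertext, which is exactly how unsigned_overflowing_add_assign_parallelized uses it:

// Propagate the carry of the last block; since there is no block after it,
// the returned shortint block is the overflow of the whole radix ciphertext.
let len = ct.blocks().len();
let output_carry = sks.propagate_parallelized(&mut ct, len - 1);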

pub fn partial_propagate_parallelized<T>(&self, ctxt: &mut T, start_index: usize)
@@ -107,11 +113,11 @@

ctxt.blocks_mut()[start_index..].swap_with_slice(&mut message_blocks);
let carries = T::from_blocks(carry_blocks);
self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries);
let _ = self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries);
} else {
let len = ctxt.blocks().len();
for i in start_index..len {
self.propagate_parallelized(ctxt, i);
let _ = self.propagate_parallelized(ctxt, i);
}
}
}
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/neg.rs
@@ -90,7 +90,7 @@ impl ServerKey {

if self.is_eligible_for_parallel_single_carry_propagation(ct) {
let mut ct = self.unchecked_neg(ct);
self.propagate_single_carry_parallelized_low_latency(&mut ct);
let _carry = self.propagate_single_carry_parallelized_low_latency(&mut ct);
ct
} else {
let mut ct = self.unchecked_neg(ct);
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/scalar_add.rs
@@ -170,7 +170,7 @@ impl ServerKey {

if self.is_eligible_for_parallel_single_carry_propagation(ct) {
self.unchecked_scalar_add_assign(ct, scalar);
self.propagate_single_carry_parallelized_low_latency(ct);
let _carry = self.propagate_single_carry_parallelized_low_latency(ct);
} else {
self.unchecked_scalar_add_assign(ct, scalar);
self.full_propagate_parallelized(ct);
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
@@ -107,7 +107,7 @@ impl ServerKey {
self.unchecked_scalar_sub_assign(ct, scalar);

if self.is_eligible_for_parallel_single_carry_propagation(ct) {
self.propagate_single_carry_parallelized_low_latency(ct);
let _carry = self.propagate_single_carry_parallelized_low_latency(ct);
} else {
self.full_propagate_parallelized(ct);
}
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/sub.rs
@@ -227,7 +227,7 @@ impl ServerKey {

if self.is_eligible_for_parallel_single_carry_propagation(lhs) {
let neg = self.unchecked_neg(rhs);
self.unchecked_add_assign_parallelized_low_latency(lhs, &neg);
let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, &neg);
} else {
self.unchecked_sub_assign(lhs, rhs);
self.full_propagate_parallelized(lhs);
136 changes: 131 additions & 5 deletions tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
@@ -65,12 +65,17 @@ fn rotate_right_helper(value: u64, n: u32, actual_bit_size: u32) -> u64 {
}

fn overflowing_sub_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
let result = lhs.wrapping_sub(rhs);
// Technically using a div is not the fastest way to check for overflow,
// but as we have to do the remainder regardless, that /% should be one instruction
let (q, r) = (result / modulus, result % modulus);
(r, q != 0)
assert!(
!(modulus.is_power_of_two() && (modulus - 1).overflowing_mul(2).1),
"If modulus is not a power of two, then must not overflow u64"
);
let (result, overflowed) = lhs.overflowing_sub(rhs);
(result % modulus, overflowed)
}

fn overflowing_add_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
let (result, overflowed) = lhs.overflowing_add(rhs);
(result % modulus, overflowed || result >= modulus)
}
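
Quick clear-value sanity checks for the new helper (8-bit modulus of 256, matching 4 blocks of 2-bit messages):

assert_eq!(overflowing_add_under_modulus(255, 1, 256), (0, true)); // wraps and overflows
assert_eq!(overflowing_add_under_modulus(100, 100, 256), (200, false)); // stays in range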

/// This trait is to be implemented by a struct that is capable
@@ -1771,6 +1776,127 @@ where
}
}

pub(crate) fn default_overflowing_add_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext),
(RadixCiphertext, Ciphertext),
>,
{
let (cks, mut sks) = KEY_CACHE.get_from_params(param);
let cks = RadixClientKey::from((cks, NB_CTXT));

sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);

let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

executor.setup(&cks, sks.clone());

for _ in 0..NB_TEST_SMALLER {
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;

let ctxt_0 = cks.encrypt(clear_0);
let ctxt_1 = cks.encrypt(clear_1);

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert!(result_overflowed.carry_is_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");

let (expected_result, expected_overflowed) =
overflowing_add_under_modulus(clear_0, clear_1, modulus);

let decrypted_result: u64 = cks.decrypt(&ct_res);
let decrypted_overflowed = cks.decrypt_one_block(&result_overflowed) == 1;
assert_eq!(
decrypted_result, expected_result,
"Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);

for _ in 0..NB_TEST_SMALLER {
// Add non zero scalar to have non clean ciphertexts
let clear_2 = random_non_zero_value(&mut rng, modulus);
let clear_3 = random_non_zero_value(&mut rng, modulus);

let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);

let (clear_lhs, _) = overflowing_add_under_modulus(clear_0, clear_2, modulus);
let (clear_rhs, _) = overflowing_add_under_modulus(clear_1, clear_3, modulus);

let d0: u64 = cks.decrypt(&ctxt_0);
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
let d1: u64 = cks.decrypt(&ctxt_1);
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert!(result_overflowed.carry_is_empty());

let (expected_result, expected_overflowed) =
overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus);

let decrypted_result: u64 = cks.decrypt(&ct_res);
let decrypted_overflowed = cks.decrypt_one_block(&result_overflowed) == 1;
assert_eq!(
decrypted_result, expected_result,
"Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
}
}

// Test with trivial inputs
for _ in 0..4 {
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;

let a: RadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let b: RadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);

let (encrypted_result, encrypted_overflow) =
sks.unsigned_overflowing_add_parallelized(&a, &b);

let (expected_result, expected_overflowed) =
overflowing_add_under_modulus(clear_0, clear_1, modulus);

let decrypted_result: u64 = cks.decrypt(&encrypted_result);
let decrypted_overflowed = cks.decrypt_one_block(&encrypted_overflow) == 1;
assert_eq!(
decrypted_result, expected_result,
"Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
}
}

pub(crate) fn default_overflowing_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
9 changes: 9 additions & 0 deletions tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
@@ -110,6 +110,7 @@ create_parametrized_test!(integer_smart_add);
create_parametrized_test!(integer_smart_add_sequence_multi_thread);
create_parametrized_test!(integer_smart_add_sequence_single_thread);
create_parametrized_test!(integer_default_add);
create_parametrized_test!(integer_default_overflowing_add);
create_parametrized_test!(integer_default_add_work_efficient {
// This algorithm requires 3 bits
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
@@ -717,6 +718,14 @@
default_add_test(param, executor);
}

fn integer_default_overflowing_add<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized);
default_overflowing_add_test(param, executor);
}

fn integer_default_sub<P>(param: P)
where
P: Into<PBSParameters>,