From 72f1aee69e83faf5c20f77e18ed6b512d492f39f Mon Sep 17 00:00:00 2001
From: tmontaigu
Date: Wed, 8 Nov 2023 19:21:56 +0100
Subject: [PATCH] feat(integer): add unsigned_overflowing_add

---
 .../integer/server_key/radix_parallel/add.rs  | 101 ++++++++++++-
 .../integer/server_key/radix_parallel/mod.rs  |  12 +-
 .../integer/server_key/radix_parallel/neg.rs  |   2 +-
 .../server_key/radix_parallel/scalar_add.rs   |   2 +-
 .../server_key/radix_parallel/scalar_sub.rs   |   2 +-
 .../integer/server_key/radix_parallel/sub.rs  |   2 +-
 .../radix_parallel/tests_cases_unsigned.rs    | 136 +++++++++++++++++-
 .../radix_parallel/tests_unsigned.rs          |   9 ++
 8 files changed, 250 insertions(+), 16 deletions(-)

diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs
index 719c2b64d8..519aaccdbe 100644
--- a/tfhe/src/integer/server_key/radix_parallel/add.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/add.rs
@@ -232,12 +232,94 @@ impl ServerKey {
         };
 
         if self.is_eligible_for_parallel_single_carry_propagation(lhs) {
-            self.unchecked_add_assign_parallelized_low_latency(lhs, rhs);
+            let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, rhs);
         } else {
             self.unchecked_add_assign(lhs, rhs);
             self.full_propagate_parallelized(lhs);
         }
     }
 
+    /// Computes the addition of two unsigned ciphertexts and returns the overflow flag
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // Generate the client key and the server key:
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg1 = u8::MAX;
+    /// let msg2 = 1;
+    ///
+    /// let ct1 = cks.encrypt(msg1);
+    /// let ct2 = cks.encrypt(msg2);
+    ///
+    /// let (ct_res, overflowed) = sks.unsigned_overflowing_add_parallelized(&ct1, &ct2);
+    ///
+    /// // Decrypt:
+    /// let dec_result: u8 = cks.decrypt(&ct_res);
+    /// let dec_overflowed = cks.decrypt_one_block(&overflowed);
+    /// let (expected_result, expected_overflow) = msg1.overflowing_add(msg2);
+    /// assert_eq!(dec_result, expected_result);
+    /// assert_eq!(dec_overflowed, u64::from(expected_overflow));
+    /// ```
+    pub fn unsigned_overflowing_add_parallelized(
+        &self,
+        ct_left: &RadixCiphertext,
+        ct_right: &RadixCiphertext,
+    ) -> (RadixCiphertext, Ciphertext) {
+        let mut ct_res = ct_left.clone();
+        let overflowed = self.unsigned_overflowing_add_assign_parallelized(&mut ct_res, ct_right);
+        (ct_res, overflowed)
+    }
+
+    pub fn unsigned_overflowing_add_assign_parallelized(
+        &self,
+        ct_left: &mut RadixCiphertext,
+        ct_right: &RadixCiphertext,
+    ) -> Ciphertext {
+        let mut tmp_rhs: RadixCiphertext;
+        if ct_left.blocks.is_empty() || ct_right.blocks.is_empty() {
+            return self.key.create_trivial(0);
+        }
+
+        let (lhs, rhs) = match (
+            ct_left.block_carries_are_empty(),
+            ct_right.block_carries_are_empty(),
+        ) {
+            (true, true) => (ct_left, ct_right),
+            (true, false) => {
+                tmp_rhs = ct_right.clone();
+                self.full_propagate_parallelized(&mut tmp_rhs);
+                (ct_left, &tmp_rhs)
+            }
+            (false, true) => {
+                self.full_propagate_parallelized(ct_left);
+                (ct_left, ct_right)
+            }
+            (false, false) => {
+                tmp_rhs = ct_right.clone();
+                rayon::join(
+                    || self.full_propagate_parallelized(ct_left),
+                    || self.full_propagate_parallelized(&mut tmp_rhs),
+                );
+                (ct_left, &tmp_rhs)
+            }
+        };
+
+        if self.is_eligible_for_parallel_single_carry_propagation(lhs) {
+            self.unchecked_add_assign_parallelized_low_latency(lhs, rhs)
+        } else {
+            self.unchecked_add_assign(lhs, rhs);
+            let len = lhs.blocks.len();
+            for i in 0..len - 1 {
+                let _ = self.propagate_parallelized(lhs, i);
+            }
+            self.propagate_parallelized(lhs, len - 1)
+        }
+    }
 
     pub fn add_parallelized_work_efficient<T>(&self, ct_left: &T, ct_right: &T) -> T
     where
@@ -309,6 +391,9 @@ impl ServerKey {
     ///
     /// At most num_block - 1 threads are used
     ///
+    /// Returns the output carry that can be used to check for unsigned addition
+    /// overflow.
+    ///
     /// # Requirements
     ///
     /// - The parameters have 4 bits in total
@@ -317,7 +402,11 @@ impl ServerKey {
     /// # Output
     ///
     /// - lhs will have its carries empty
-    pub(crate) fn unchecked_add_assign_parallelized_low_latency<T>(&self, lhs: &mut T, rhs: &T)
+    pub(crate) fn unchecked_add_assign_parallelized_low_latency<T>(
+        &self,
+        lhs: &mut T,
+        rhs: &T,
+    ) -> Ciphertext
     where
         T: IntegerRadixCiphertext,
     {
@@ -342,12 +431,15 @@ impl ServerKey {
     /// - first unchecked_add
     /// - at this point at most one bit of carry is taken
     /// - use this function to propagate them in parallel
-    pub(crate) fn propagate_single_carry_parallelized_low_latency<T>(&self, ct: &mut T)
+    pub(crate) fn propagate_single_carry_parallelized_low_latency<T>(
+        &self,
+        ct: &mut T,
+    ) -> Ciphertext
     where
         T: IntegerRadixCiphertext,
     {
         let generates_or_propagates = self.generate_init_carry_array(ct);
-        let (input_carries, _) =
+        let (input_carries, output_carry) =
             self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates);
 
         ct.blocks_mut()
@@ -357,6 +449,7 @@ impl ServerKey {
                 self.key.unchecked_add_assign(block, input_carry);
                 self.key.message_extract_assign(block);
             });
+        output_carry
     }
 
     /// Backbone algorithm of parallel carry (only one bit) propagation
diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs
index 248c7e02d7..17aac98be7 100644
--- a/tfhe/src/integer/server_key/radix_parallel/mod.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs
@@ -62,7 +62,11 @@ impl ServerKey {
     /// let res: u64 = cks.decrypt_one_block(&ct_res.blocks()[1]);
     /// assert_eq!(3, res);
     /// ```
-    pub fn propagate_parallelized<T>(&self, ctxt: &mut T, index: usize)
+    pub fn propagate_parallelized<T>(
+        &self,
+        ctxt: &mut T,
+        index: usize,
+    ) -> crate::shortint::Ciphertext
     where
         T: IntegerRadixCiphertext,
     {
@@ -77,6 +81,8 @@ impl ServerKey {
             self.key
                 .unchecked_add_assign(&mut ctxt.blocks_mut()[index + 1], &carry);
         }
+
+        carry
     }
 
     pub fn partial_propagate_parallelized<T>(&self, ctxt: &mut T, start_index: usize)
@@ -107,11 +113,11 @@ impl ServerKey {
             ctxt.blocks_mut()[start_index..].swap_with_slice(&mut message_blocks);
             let carries = T::from_blocks(carry_blocks);
-            self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries);
+            let _ = self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries);
         } else {
             let len = ctxt.blocks().len();
             for i in start_index..len {
-                self.propagate_parallelized(ctxt, i);
+                let _ = self.propagate_parallelized(ctxt, i);
             }
         }
     }
diff --git a/tfhe/src/integer/server_key/radix_parallel/neg.rs b/tfhe/src/integer/server_key/radix_parallel/neg.rs
index 1e282d011b..ddfe96f4fc 100644
--- a/tfhe/src/integer/server_key/radix_parallel/neg.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/neg.rs
@@ -90,7 +90,7 @@ impl ServerKey {
 
         if self.is_eligible_for_parallel_single_carry_propagation(ct) {
             let mut ct = self.unchecked_neg(ct);
-            self.propagate_single_carry_parallelized_low_latency(&mut ct);
+            let _carry = self.propagate_single_carry_parallelized_low_latency(&mut ct);
             ct
         } else {
             let mut ct = self.unchecked_neg(ct);
diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs
index c7ba6392ae..89b233918e 100644
--- a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs
@@ -170,7 +170,7 @@ impl ServerKey {
 
         if self.is_eligible_for_parallel_single_carry_propagation(ct) {
             self.unchecked_scalar_add_assign(ct, scalar);
-            self.propagate_single_carry_parallelized_low_latency(ct);
+            let _carry = self.propagate_single_carry_parallelized_low_latency(ct);
         } else {
             self.unchecked_scalar_add_assign(ct, scalar);
             self.full_propagate_parallelized(ct);
diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
index 1c62c5846d..1b92f56cfc 100644
--- a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
@@ -107,7 +107,7 @@ impl ServerKey {
         self.unchecked_scalar_sub_assign(ct, scalar);
 
         if self.is_eligible_for_parallel_single_carry_propagation(ct) {
-            self.propagate_single_carry_parallelized_low_latency(ct);
+            let _carry = self.propagate_single_carry_parallelized_low_latency(ct);
         } else {
             self.full_propagate_parallelized(ct);
         }
diff --git a/tfhe/src/integer/server_key/radix_parallel/sub.rs b/tfhe/src/integer/server_key/radix_parallel/sub.rs
index f63e56dc58..bc0fcd036e 100644
--- a/tfhe/src/integer/server_key/radix_parallel/sub.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/sub.rs
@@ -227,7 +227,7 @@ impl ServerKey {
 
         if self.is_eligible_for_parallel_single_carry_propagation(lhs) {
             let neg = self.unchecked_neg(rhs);
-            self.unchecked_add_assign_parallelized_low_latency(lhs, &neg);
+            let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, &neg);
         } else {
             self.unchecked_sub_assign(lhs, rhs);
             self.full_propagate_parallelized(lhs);
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
index d578aa0c37..ac94c7be46 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
@@ -65,12 +65,17 @@ fn rotate_right_helper(value: u64, n: u32, actual_bit_size: u32) -> u64 {
 }
 
 fn overflowing_sub_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
-    let result = lhs.wrapping_sub(rhs);
-    // Technically using a div is not the fastest way to check for overflow,
-    // but as we have to do the remainder regardless, that /% should be one instruction
-    let (q, r) = (result / modulus, result % modulus);
+    assert!(
+        !(modulus.is_power_of_two() && (modulus - 1).overflowing_mul(2).1),
+        "If modulus is a power of two, then (modulus - 1) * 2 must not overflow u64"
+    );
+    let (result, overflowed) = lhs.overflowing_sub(rhs);
+    (result % modulus, overflowed)
+}
 
-    (r, q != 0)
+fn overflowing_add_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
+    let (result, overflowed) = lhs.overflowing_add(rhs);
+    (result % modulus, overflowed || result >= modulus)
 }
 
 /// This trait is to be implemented by a struct that is capable
@@ -1771,6 +1776,127 @@ where
     }
 }
 
+pub(crate) fn default_overflowing_add_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, &'a RadixCiphertext),
+        (RadixCiphertext, Ciphertext),
+    >,
+{
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks.clone());
+
+    for _ in 0..NB_TEST_SMALLER {
+        let clear_0 = rng.gen::<u64>() % modulus;
+        let clear_1 = rng.gen::<u64>() % modulus;
+
+        let ctxt_0 = cks.encrypt(clear_0);
+        let ctxt_1 = cks.encrypt(clear_1);
+
+        let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
+        let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
+        assert!(ct_res.block_carries_are_empty());
+        assert!(result_overflowed.carry_is_empty());
+        assert_eq!(ct_res, tmp_ct, "Failed determinism check");
+        assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
+
+        let (expected_result, expected_overflowed) =
+            overflowing_add_under_modulus(clear_0, clear_1, modulus);
+
+        let decrypted_result: u64 = cks.decrypt(&ct_res);
+        let decrypted_overflowed = cks.decrypt_one_block(&result_overflowed) == 1;
+        assert_eq!(
+            decrypted_result, expected_result,
+            "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
+             expected {expected_result}, got {decrypted_result}"
+        );
+        assert_eq!(
+            decrypted_overflowed,
+            expected_overflowed,
+            "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \
+             expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
+        );
+
+        for _ in 0..NB_TEST_SMALLER {
+            // Add a non-zero scalar to have non-clean ciphertexts
+            let clear_2 = random_non_zero_value(&mut rng, modulus);
+            let clear_3 = random_non_zero_value(&mut rng, modulus);
+
+            let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
+            let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);
+
+            let (clear_lhs, _) = overflowing_add_under_modulus(clear_0, clear_2, modulus);
+            let (clear_rhs, _) = overflowing_add_under_modulus(clear_1, clear_3, modulus);
+
+            let d0: u64 = cks.decrypt(&ctxt_0);
+            assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
+            let d1: u64 = cks.decrypt(&ctxt_1);
+            assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
+
+            let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
+            assert!(ct_res.block_carries_are_empty());
+            assert!(result_overflowed.carry_is_empty());
+
+            let (expected_result, expected_overflowed) =
+                overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus);
+
+            let decrypted_result: u64 = cks.decrypt(&ct_res);
+            let decrypted_overflowed = cks.decrypt_one_block(&result_overflowed) == 1;
+            assert_eq!(
+                decrypted_result, expected_result,
+                "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
+                 expected {expected_result}, got {decrypted_result}"
+            );
+            assert_eq!(
+                decrypted_overflowed,
+                expected_overflowed,
+                "Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
+                 expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
+            );
+        }
+    }
+
+    // Test with trivial inputs
+    for _ in 0..4 {
+        let clear_0 = rng.gen::<u64>() % modulus;
+        let clear_1 = rng.gen::<u64>() % modulus;
+
+        let a: RadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
+        let b: RadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
+
+        let (encrypted_result, encrypted_overflow) =
+            sks.unsigned_overflowing_add_parallelized(&a, &b);
+
+        let (expected_result, expected_overflowed) =
+            overflowing_add_under_modulus(clear_0, clear_1, modulus);
+
+        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
+        let decrypted_overflowed = cks.decrypt_one_block(&encrypted_overflow) == 1;
+        assert_eq!(
+            decrypted_result, expected_result,
+            "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
+             expected {expected_result}, got {decrypted_result}"
+        );
+        assert_eq!(
+            decrypted_overflowed,
+            expected_overflowed,
+            "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \
+             expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
+        );
+    }
+}
+
 pub(crate) fn default_overflowing_sub_test<P, T>(param: P, mut executor: T)
 where
     P: Into<PBSParameters>,
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
index fc92fa489b..44f9bc10ac 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
@@ -110,6 +110,7 @@ create_parametrized_test!(integer_smart_add);
 create_parametrized_test!(integer_smart_add_sequence_multi_thread);
 create_parametrized_test!(integer_smart_add_sequence_single_thread);
 create_parametrized_test!(integer_default_add);
+create_parametrized_test!(integer_default_overflowing_add);
 create_parametrized_test!(integer_default_add_work_efficient {
     // This algorithm requires 3 bits
     PARAM_MESSAGE_2_CARRY_2_KS_PBS,
@@ -717,6 +718,14 @@ where
     default_add_test(param, executor);
 }
 
+fn integer_default_overflowing_add<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized);
+    default_overflowing_add_test(param, executor);
+}
+
 fn integer_default_sub<P>(param: P)
 where
     P: Into<PBSParameters>,
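
Note on the semantics being tested: the flag returned by unsigned_overflowing_add_parallelized
encrypts 1 exactly when the clear sum reaches message_modulus^num_blocks, which is what the new
overflowing_add_under_modulus helper models on clear values. Below is a minimal plain-Rust sanity
check of that model; it reuses only the helper's body from the patch, and the main driver with its
constants is an illustrative assumption, not part of the patch:

    // Clear-value model used by the patch's tests: addition is performed
    // modulo `modulus`, and the flag reports whether the true sum overflowed,
    // either in u64 itself or by reaching `modulus`.
    fn overflowing_add_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
        let (result, overflowed) = lhs.overflowing_add(rhs);
        (result % modulus, overflowed || result >= modulus)
    }

    fn main() {
        // 4 blocks of 2-bit messages -> modulus = 2^8 = 256, i.e. u8 arithmetic.
        let modulus = 256u64;
        assert_eq!(overflowing_add_under_modulus(255, 1, modulus), (0, true));
        assert_eq!(overflowing_add_under_modulus(200, 55, modulus), (255, false));
        // Agrees with the standard library's u8 behaviour on the same values,
        // matching the u8::MAX + 1 case shown in the doc example above.
        assert_eq!(u8::MAX.overflowing_add(1), (0u8, true));
    }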