diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs index 71516125f3..ba7b942820 100644 --- a/tfhe/src/integer/server_key/radix_parallel/add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -5,6 +5,12 @@ use crate::shortint::ciphertext::Degree; use crate::shortint::Ciphertext; use rayon::prelude::*; +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub(crate) enum CarryPropagationAlgorithm { + Sequential, + Parallel, + Automatic, +} /// Possible output flag that the advanced_add_assign_with_carry family of /// functions can compute. #[derive(Copy, Clone, PartialEq, Eq, Debug)] @@ -448,6 +454,7 @@ impl ServerKey { rhs.blocks(), input_carry, OutputFlag::None, + CarryPropagationAlgorithm::Automatic, ); } @@ -468,6 +475,7 @@ impl ServerKey { rhs.blocks(), input_carry, OutputFlag::from_signedness(T::IS_SIGNED), + CarryPropagationAlgorithm::Automatic, ) .expect("internal error, overflow computation was not returned as was requested") } @@ -493,21 +501,43 @@ impl ServerKey { rhs: &[Ciphertext], input_carry: Option<&BooleanBlock>, requested_flag: OutputFlag, + mut algorithm: CarryPropagationAlgorithm, ) -> Option { - if self.is_eligible_for_parallel_single_carry_propagation(lhs.len()) { - self.advanced_add_assign_with_carry_at_least_4_bits( - lhs, - rhs, - input_carry, - requested_flag, - ) - } else { - self.advanced_add_assign_with_carry_sequential_parallelized( - lhs, - rhs, - input_carry, - requested_flag, - ) + // having 4-bits is a hard requirement + // So to protect against bad carry prop choice we do this check + let total_modulus = self.key.message_modulus.0 * self.key.carry_modulus.0; + let has_enough_bits_per_block = total_modulus >= (1 << 4); + if !has_enough_bits_per_block { + algorithm = CarryPropagationAlgorithm::Sequential; + } + + if algorithm == CarryPropagationAlgorithm::Automatic { + if should_parallel_propagation_be_faster( + self.message_modulus().0 * self.carry_modulus().0, + lhs.len(), + rayon::current_num_threads(), + ) { + algorithm = CarryPropagationAlgorithm::Parallel; + } else { + algorithm = CarryPropagationAlgorithm::Sequential + } + } + match algorithm { + CarryPropagationAlgorithm::Parallel => self + .advanced_add_assign_with_carry_at_least_4_bits( + lhs, + rhs, + input_carry, + requested_flag, + ), + CarryPropagationAlgorithm::Sequential => self + .advanced_add_assign_with_carry_sequential_parallelized( + lhs, + rhs, + input_carry, + requested_flag, + ), + CarryPropagationAlgorithm::Automatic => unreachable!(), } } @@ -969,10 +999,16 @@ impl ServerKey { rayon::join( || { let block = output_flag.as_mut().unwrap(); - self.key.unchecked_add_assign( - block, - &resolved_carries[resolved_carries.len() - 1], - ); + // When num block is 1, we have to use the input carry + // given by the caller + let carry_into_last_block = input_carry + .as_ref() + .filter(|_| num_blocks == 1) + .map_or_else( + || &resolved_carries[resolved_carries.len() - 1], + |input_carry| &input_carry.0, + ); + self.key.unchecked_add_assign(block, carry_into_last_block); self.key .apply_lookup_table_assign(block, &overflow_flag_lut); }, @@ -1013,7 +1049,7 @@ impl ServerKey { ) -> (Vec, Vec) { if block_states.is_empty() { return ( - vec![self.key.create_trivial(0)], + vec![self.key.create_trivial(1)], vec![self.key.create_trivial(0)], ); } @@ -1175,7 +1211,7 @@ impl ServerKey { let mut propagation_simulators = Vec::with_capacity(num_blocks); // First block does not get a carry from - propagation_simulators.push(self.key.create_trivial(0)); + propagation_simulators.push(self.key.create_trivial(1)); for block in propagation_cum_sums.drain(..) { if propagation_simulators.len() % grouping_size == 0 { groupings_pgns.push(block); diff --git a/tfhe/src/integer/server_key/radix_parallel/sub.rs b/tfhe/src/integer/server_key/radix_parallel/sub.rs index 2474a5f32e..1ee94f88ca 100644 --- a/tfhe/src/integer/server_key/radix_parallel/sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/sub.rs @@ -1,6 +1,9 @@ use crate::integer::ciphertext::IntegerRadixCiphertext; +use crate::integer::server_key::radix_parallel::add::CarryPropagationAlgorithm; use crate::integer::server_key::radix_parallel::OutputFlag; -use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext}; +use crate::integer::{ + BooleanBlock, IntegerCiphertext, RadixCiphertext, ServerKey, SignedRadixCiphertext, +}; use crate::shortint::Ciphertext; use rayon::prelude::*; @@ -218,6 +221,7 @@ impl ServerKey { neg.blocks(), None, OutputFlag::None, + CarryPropagationAlgorithm::Automatic, ); } @@ -654,6 +658,19 @@ impl ServerKey { &self, lhs: &SignedRadixCiphertext, rhs: &SignedRadixCiphertext, + ) -> (SignedRadixCiphertext, BooleanBlock) { + self.unchecked_signed_overflowing_sub_parallelized_with_choice( + lhs, + rhs, + CarryPropagationAlgorithm::Automatic, + ) + } + + pub(crate) fn unchecked_signed_overflowing_sub_parallelized_with_choice( + &self, + lhs: &SignedRadixCiphertext, + rhs: &SignedRadixCiphertext, + algorithm: CarryPropagationAlgorithm, ) -> (SignedRadixCiphertext, BooleanBlock) { assert_eq!( lhs.blocks.len(), @@ -675,8 +692,15 @@ impl ServerKey { let flipped_rhs = self.bitnot(rhs); let input_carry = self.create_trivial_boolean_block(true); let mut result = lhs.clone(); - let overflowed = - self.overflowing_add_assign_with_carry(&mut result, &flipped_rhs, Some(&input_carry)); + let overflowed = self + .advanced_add_assign_with_carry_parallelized( + result.blocks_mut(), + flipped_rhs.blocks(), + Some(&input_carry), + OutputFlag::Overflow, + algorithm, + ) + .expect("internal error, overflow computation was not returned as was requested"); (result, overflowed) } @@ -714,6 +738,19 @@ impl ServerKey { &self, ctxt_left: &SignedRadixCiphertext, ctxt_right: &SignedRadixCiphertext, + ) -> (SignedRadixCiphertext, BooleanBlock) { + self.signed_overflowing_sub_parallelized_with_choice( + ctxt_left, + ctxt_right, + CarryPropagationAlgorithm::Automatic, + ) + } + + pub(crate) fn signed_overflowing_sub_parallelized_with_choice( + &self, + ctxt_left: &SignedRadixCiphertext, + ctxt_right: &SignedRadixCiphertext, + algorithm: CarryPropagationAlgorithm, ) -> (SignedRadixCiphertext, BooleanBlock) { let mut tmp_lhs; let mut tmp_rhs; @@ -744,6 +781,6 @@ impl ServerKey { } }; - self.unchecked_signed_overflowing_sub_parallelized(lhs, rhs) + self.unchecked_signed_overflowing_sub_parallelized_with_choice(lhs, rhs, algorithm) } } diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs index 608b674505..26657e81e5 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs @@ -1,4 +1,5 @@ use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::add::CarryPropagationAlgorithm; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::server_key::radix_parallel::tests_signed::{ create_iterator_of_signed_random_pairs, random_non_zero_value, signed_add_under_modulus, @@ -23,6 +24,22 @@ create_parametrized_test!(integer_signed_unchecked_overflowing_sub); create_parametrized_test!(integer_signed_default_sub); create_parametrized_test!(integer_extensive_trivial_signed_default_sub); create_parametrized_test!(integer_signed_default_overflowing_sub); +create_parametrized_test!(integer_signed_default_overflowing_sub_sequential); +create_parametrized_test!(integer_signed_default_overflowing_sub_parallel { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS + }, + no_coverage => { + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64 + } +}); create_parametrized_test!(integer_extensive_trivial_signed_default_overflowing_sub); fn integer_signed_unchecked_sub

(param: P) @@ -73,6 +90,36 @@ where signed_default_overflowing_sub_test(param, executor); } +fn integer_signed_default_overflowing_sub_sequential

(param: P) +where + P: Into, +{ + let func = |sks: &ServerKey, lhs: &SignedRadixCiphertext, rhs: &SignedRadixCiphertext| { + sks.signed_overflowing_sub_parallelized_with_choice( + lhs, + rhs, + CarryPropagationAlgorithm::Sequential, + ) + }; + let executor = CpuFunctionExecutor::new(&func); + signed_default_overflowing_sub_test(param, executor); +} + +fn integer_signed_default_overflowing_sub_parallel

(param: P) +where + P: Into, +{ + let func = |sks: &ServerKey, lhs: &SignedRadixCiphertext, rhs: &SignedRadixCiphertext| { + sks.signed_overflowing_sub_parallelized_with_choice( + lhs, + rhs, + CarryPropagationAlgorithm::Parallel, + ) + }; + let executor = CpuFunctionExecutor::new(&func); + signed_default_overflowing_sub_test(param, executor); +} + pub(crate) fn signed_default_overflowing_sub_test(param: P, mut executor: T) where P: Into,