Skip to content

Commit

Permalink
fix: parallel overflow flag on single block
Browse files Browse the repository at this point in the history
Fix a bug in the overflow flag computation of the
parallel algorithm when the input has only one block.
The bug was caused by the first block not having a proper
propagation simulator.

Explicitly add tests for the parallel and sequential
versions of signed_overflowing_sub, so that both are
tested regardless of CPU thread count and block count.
  • Loading branch information
tmontaigu committed Nov 22, 2024
1 parent c3def17 commit 68016da
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 24 deletions.
76 changes: 56 additions & 20 deletions tfhe/src/integer/server_key/radix_parallel/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ use crate::shortint::ciphertext::Degree;
use crate::shortint::Ciphertext;
use rayon::prelude::*;

/// Carry-propagation strategy requested by the caller of the
/// add-with-carry routines that take an `algorithm` argument.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub(crate) enum CarryPropagationAlgorithm {
/// Dispatches to `advanced_add_assign_with_carry_sequential_parallelized`.
Sequential,
/// Dispatches to `advanced_add_assign_with_carry_at_least_4_bits`.
/// The dispatcher downgrades this choice to `Sequential` when
/// `message_modulus * carry_modulus` is below `1 << 4`.
Parallel,
/// Lets the dispatcher pick: it becomes `Parallel` when
/// `should_parallel_propagation_be_faster` (given the total block modulus,
/// the number of blocks and `rayon::current_num_threads()`) returns true,
/// and `Sequential` otherwise.
Automatic,
}
/// Possible output flag that the advanced_add_assign_with_carry family of
/// functions can compute.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
Expand Down Expand Up @@ -448,6 +454,7 @@ impl ServerKey {
rhs.blocks(),
input_carry,
OutputFlag::None,
CarryPropagationAlgorithm::Automatic,
);
}

Expand All @@ -468,6 +475,7 @@ impl ServerKey {
rhs.blocks(),
input_carry,
OutputFlag::from_signedness(T::IS_SIGNED),
CarryPropagationAlgorithm::Automatic,
)
.expect("internal error, overflow computation was not returned as was requested")
}
Expand All @@ -493,21 +501,43 @@ impl ServerKey {
rhs: &[Ciphertext],
input_carry: Option<&BooleanBlock>,
requested_flag: OutputFlag,
mut algorithm: CarryPropagationAlgorithm,
) -> Option<BooleanBlock> {
if self.is_eligible_for_parallel_single_carry_propagation(lhs.len()) {
self.advanced_add_assign_with_carry_at_least_4_bits(
lhs,
rhs,
input_carry,
requested_flag,
)
} else {
self.advanced_add_assign_with_carry_sequential_parallelized(
lhs,
rhs,
input_carry,
requested_flag,
)
// having 4-bits is a hard requirement
// So to protect against bad carry prop choice we do this check
let total_modulus = self.key.message_modulus.0 * self.key.carry_modulus.0;
let has_enough_bits_per_block = total_modulus >= (1 << 4);
if !has_enough_bits_per_block {
algorithm = CarryPropagationAlgorithm::Sequential;
}

if algorithm == CarryPropagationAlgorithm::Automatic {
if should_parallel_propagation_be_faster(
self.message_modulus().0 * self.carry_modulus().0,
lhs.len(),
rayon::current_num_threads(),
) {
algorithm = CarryPropagationAlgorithm::Parallel;
} else {
algorithm = CarryPropagationAlgorithm::Sequential
}
}
match algorithm {
CarryPropagationAlgorithm::Parallel => self
.advanced_add_assign_with_carry_at_least_4_bits(
lhs,
rhs,
input_carry,
requested_flag,
),
CarryPropagationAlgorithm::Sequential => self
.advanced_add_assign_with_carry_sequential_parallelized(
lhs,
rhs,
input_carry,
requested_flag,
),
CarryPropagationAlgorithm::Automatic => unreachable!(),
}
}

Expand Down Expand Up @@ -969,10 +999,16 @@ impl ServerKey {
rayon::join(
|| {
let block = output_flag.as_mut().unwrap();
self.key.unchecked_add_assign(
block,
&resolved_carries[resolved_carries.len() - 1],
);
// When num block is 1, we have to use the input carry
// given by the caller
let carry_into_last_block = input_carry
.as_ref()
.filter(|_| num_blocks == 1)
.map_or_else(
|| &resolved_carries[resolved_carries.len() - 1],
|input_carry| &input_carry.0,
);
self.key.unchecked_add_assign(block, carry_into_last_block);
self.key
.apply_lookup_table_assign(block, &overflow_flag_lut);
},
Expand Down Expand Up @@ -1013,7 +1049,7 @@ impl ServerKey {
) -> (Vec<Ciphertext>, Vec<Ciphertext>) {
if block_states.is_empty() {
return (
vec![self.key.create_trivial(0)],
vec![self.key.create_trivial(1)],
vec![self.key.create_trivial(0)],
);
}
Expand Down Expand Up @@ -1175,7 +1211,7 @@ impl ServerKey {
let mut propagation_simulators = Vec::with_capacity(num_blocks);

// First block does not get a carry from
propagation_simulators.push(self.key.create_trivial(0));
propagation_simulators.push(self.key.create_trivial(1));
for block in propagation_cum_sums.drain(..) {
if propagation_simulators.len() % grouping_size == 0 {
groupings_pgns.push(block);
Expand Down
45 changes: 41 additions & 4 deletions tfhe/src/integer/server_key/radix_parallel/sub.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::server_key::radix_parallel::add::CarryPropagationAlgorithm;
use crate::integer::server_key::radix_parallel::OutputFlag;
use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext};
use crate::integer::{
BooleanBlock, IntegerCiphertext, RadixCiphertext, ServerKey, SignedRadixCiphertext,
};
use crate::shortint::Ciphertext;
use rayon::prelude::*;

Expand Down Expand Up @@ -218,6 +221,7 @@ impl ServerKey {
neg.blocks(),
None,
OutputFlag::None,
CarryPropagationAlgorithm::Automatic,
);
}

Expand Down Expand Up @@ -654,6 +658,19 @@ impl ServerKey {
&self,
lhs: &SignedRadixCiphertext,
rhs: &SignedRadixCiphertext,
) -> (SignedRadixCiphertext, BooleanBlock) {
self.unchecked_signed_overflowing_sub_parallelized_with_choice(
lhs,
rhs,
CarryPropagationAlgorithm::Automatic,
)
}

pub(crate) fn unchecked_signed_overflowing_sub_parallelized_with_choice(
&self,
lhs: &SignedRadixCiphertext,
rhs: &SignedRadixCiphertext,
algorithm: CarryPropagationAlgorithm,
) -> (SignedRadixCiphertext, BooleanBlock) {
assert_eq!(
lhs.blocks.len(),
Expand All @@ -675,8 +692,15 @@ impl ServerKey {
let flipped_rhs = self.bitnot(rhs);
let input_carry = self.create_trivial_boolean_block(true);
let mut result = lhs.clone();
let overflowed =
self.overflowing_add_assign_with_carry(&mut result, &flipped_rhs, Some(&input_carry));
let overflowed = self
.advanced_add_assign_with_carry_parallelized(
result.blocks_mut(),
flipped_rhs.blocks(),
Some(&input_carry),
OutputFlag::Overflow,
algorithm,
)
.expect("internal error, overflow computation was not returned as was requested");
(result, overflowed)
}

Expand Down Expand Up @@ -714,6 +738,19 @@ impl ServerKey {
&self,
ctxt_left: &SignedRadixCiphertext,
ctxt_right: &SignedRadixCiphertext,
) -> (SignedRadixCiphertext, BooleanBlock) {
self.signed_overflowing_sub_parallelized_with_choice(
ctxt_left,
ctxt_right,
CarryPropagationAlgorithm::Automatic,
)
}

pub(crate) fn signed_overflowing_sub_parallelized_with_choice(
&self,
ctxt_left: &SignedRadixCiphertext,
ctxt_right: &SignedRadixCiphertext,
algorithm: CarryPropagationAlgorithm,
) -> (SignedRadixCiphertext, BooleanBlock) {
let mut tmp_lhs;
let mut tmp_rhs;
Expand Down Expand Up @@ -744,6 +781,6 @@ impl ServerKey {
}
};

self.unchecked_signed_overflowing_sub_parallelized(lhs, rhs)
self.unchecked_signed_overflowing_sub_parallelized_with_choice(lhs, rhs, algorithm)
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::add::CarryPropagationAlgorithm;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_signed::{
create_iterator_of_signed_random_pairs, random_non_zero_value, signed_add_under_modulus,
Expand All @@ -23,6 +24,22 @@ create_parametrized_test!(integer_signed_unchecked_overflowing_sub);
create_parametrized_test!(integer_signed_default_sub);
create_parametrized_test!(integer_extensive_trivial_signed_default_sub);
create_parametrized_test!(integer_signed_default_overflowing_sub);
// Sequential carry propagation is registered with the macro's default parameter list.
create_parametrized_test!(integer_signed_default_overflowing_sub_sequential);
// Parallel carry propagation is registered with an explicit parameter list.
// NOTE(review): presumably limited to parameter sets with at least 4 bits per
// block, since the dispatcher silently falls back to the sequential algorithm
// below that threshold — confirm against the macro's default list.
create_parametrized_test!(integer_signed_default_overflowing_sub_parallel {
coverage => {
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
},
no_coverage => {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64
}
});
create_parametrized_test!(integer_extensive_trivial_signed_default_overflowing_sub);

fn integer_signed_unchecked_sub<P>(param: P)
Expand Down Expand Up @@ -73,6 +90,36 @@ where
signed_default_overflowing_sub_test(param, executor);
}

/// Runs the shared signed overflowing-sub test with carry propagation pinned
/// to `CarryPropagationAlgorithm::Sequential`, so this path is exercised no
/// matter what the automatic selection would have chosen.
fn integer_signed_default_overflowing_sub_sequential<P: Into<PBSParameters>>(param: P) {
    let sequential_sub =
        |server_key: &ServerKey, left: &SignedRadixCiphertext, right: &SignedRadixCiphertext| {
            let choice = CarryPropagationAlgorithm::Sequential;
            server_key.signed_overflowing_sub_parallelized_with_choice(left, right, choice)
        };
    // The executor only borrows the closure, which outlives the call below.
    signed_default_overflowing_sub_test(param, CpuFunctionExecutor::new(&sequential_sub));
}

/// Runs the shared signed overflowing-sub test with carry propagation pinned
/// to `CarryPropagationAlgorithm::Parallel`, so this path is requested no
/// matter what the automatic selection would have chosen.
fn integer_signed_default_overflowing_sub_parallel<P: Into<PBSParameters>>(param: P) {
    let parallel_sub =
        |server_key: &ServerKey, left: &SignedRadixCiphertext, right: &SignedRadixCiphertext| {
            let choice = CarryPropagationAlgorithm::Parallel;
            server_key.signed_overflowing_sub_parallelized_with_choice(left, right, choice)
        };
    // The executor only borrows the closure, which outlives the call below.
    signed_default_overflowing_sub_test(param, CpuFunctionExecutor::new(&parallel_sub));
}

pub(crate) fn signed_default_overflowing_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
Expand Down

0 comments on commit 68016da

Please sign in to comment.