Skip to content

Commit

Permalink
refactor(integer): sum by columns in overflowing_sum_parallelized
Browse files Browse the repository at this point in the history
At some point, the sum was refactored to be written reasoning
in columns rather than rows, which simplified the code and helped
gain some performance.

The overflowing version had not been reworked until this commit.
  • Loading branch information
tmontaigu committed Oct 16, 2024
1 parent 4cd8a9c commit 96571ba
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ impl ServerKey {
self.create_trivial_radix(-i64::from(max_possible_bit_count), num_signed_blocks),
);
let result = self
.unchecked_partial_sum_ciphertexts_vec_parallelized(things_to_sum)
.unchecked_partial_sum_ciphertexts_vec_parallelized(things_to_sum, None)
.expect("internal error, empty ciphertext count");
let (message_blocks, carry_blocks) = rayon::join(
|| {
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/ilog2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ impl ServerKey {
);

let result = self
.unchecked_partial_sum_ciphertexts_vec_parallelized(cts)
.unchecked_partial_sum_ciphertexts_vec_parallelized(cts, None)
.expect("internal error, empty ciphertext count");

// This is the part where we extract message and carry blocks
Expand Down
177 changes: 18 additions & 159 deletions tfhe/src/integer/server_key/radix_parallel/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::{BooleanBlock, IntegerCiphertext, RadixCiphertext, ServerKey};
use crate::shortint::Ciphertext;
use rayon::prelude::*;
use std::ops::RangeInclusive;

impl ServerKey {
/// Computes the sum of the ciphertexts in parallel.
///
/// output_carries: if not None, carries generated by last blocks will
/// be stored in it.
///
/// Returns a result that has non propagated carries
pub(crate) fn unchecked_partial_sum_ciphertexts_vec_parallelized<T>(
&self,
terms: Vec<T>,
mut output_carries: Option<&mut Vec<Ciphertext>>,
) -> Option<T>
where
T: IntegerRadixCiphertext,
Expand Down Expand Up @@ -89,7 +92,7 @@ impl ServerKey {
self.key.unchecked_add_assign(&mut result, c);
}

if column_index < num_columns - 1 {
if (column_index < num_columns - 1) || output_carries.is_some() {
rayon::join(
|| self.key.message_extract(&result),
|| Some(self.key.carry_extract(&result)),
Expand All @@ -114,8 +117,12 @@ impl ServerKey {
for (msg, maybe_carry) in column_output.drain(..) {
columns[i].push(msg);

if let (Some(carry), true) = (maybe_carry, (i + 1) < columns.len()) {
columns[i + 1].push(carry);
if let Some(carry) = maybe_carry {
if (i + 1) < columns.len() {
columns[i + 1].push(carry);
} else if let Some(ref mut out) = output_carries {
out.push(carry);
}
}
}
}
Expand Down Expand Up @@ -152,7 +159,8 @@ impl ServerKey {
where
T: IntegerRadixCiphertext,
{
let mut result = self.unchecked_partial_sum_ciphertexts_vec_parallelized(ciphertexts)?;
let mut result =
self.unchecked_partial_sum_ciphertexts_vec_parallelized(ciphertexts, None)?;

self.full_propagate_parallelized(&mut result);
assert!(result.block_carries_are_empty());
Expand Down Expand Up @@ -215,51 +223,6 @@ impl ServerKey {
self.unchecked_sum_ciphertexts_parallelized(ciphertexts.as_ref())
}

/// Accumulates every ciphertext of `chunk` into the first element of the chunk,
/// block by block, i.e: [A, B, C] -> [A + B + C, B, C]
///
/// For each term after the first, only the blocks located between its first and
/// its last block with a non-zero degree are added: leading and trailing blocks
/// of degree 0 (trivial zeros) are skipped, as adding them is not useful.
///
/// Returns the inclusive range of block indices covering the blocks where
/// additions were made. A chunk made of a single ciphertext returns `0..=0`.
/// Note that a term whose blocks all have degree 0 triggers no addition, yet
/// still extends the upper bound of the returned range to the last block.
///
/// # Panics
///
/// Panics if `chunk` is empty.
fn unchecked_sum_ciphertext_chunk<T>(&self, chunk: &mut [T]) -> RangeInclusive<usize>
where
    T: IntegerRadixCiphertext,
{
    assert_ne!(chunk.len(), 0);
    let (accumulator, terms) = chunk.split_first_mut().unwrap();
    if terms.is_empty() {
        // Nothing to add to the first (and only) ciphertext
        return 0..=0;
    }

    let num_blocks = accumulator.blocks().len();

    // Bounds of the returned range, widened as terms are processed
    let mut lowest_touched = num_blocks - 1;
    let mut highest_touched = 0;

    for term in terms.iter() {
        let term_blocks = term.blocks();

        // Index of the first block that actually carries something to add,
        // `num_blocks` (one past the end) when there is none
        let start = match term_blocks
            .iter()
            .position(|block| block.degree.get() != 0)
        {
            Some(index) => index,
            None => num_blocks,
        };
        if start < lowest_touched {
            lowest_touched = start;
        }

        // Index of the last block that actually carries something to add,
        // the search is done from the end so the position has to be mirrored
        let end = match term_blocks
            .iter()
            .rev()
            .position(|block| block.degree.get() != 0)
        {
            Some(distance_from_end) => num_blocks - distance_from_end - 1,
            None => num_blocks - 1,
        };
        if end > highest_touched {
            highest_touched = end;
        }

        // When the term has no block to add, `start..end + 1` is empty
        let destination = &mut accumulator.blocks_mut()[start..end + 1];
        let source = &term_blocks[start..end + 1];
        for (lhs, rhs) in destination.iter_mut().zip(source.iter()) {
            self.key.unchecked_add_assign(lhs, rhs);
        }
    }

    lowest_touched..=highest_touched
}

/// - Expects all ciphertexts to have empty carries
/// - Expects all ciphertexts to have the same size
pub fn unchecked_unsigned_overflowing_sum_ciphertexts_vec_parallelized(
Expand Down Expand Up @@ -291,125 +254,21 @@ impl ServerKey {
);
}

assert!(
ciphertexts
.iter()
.all(IntegerRadixCiphertext::block_carries_are_empty),
"All ciphertexts must have empty carries"
);

let num_blocks = ciphertexts[0].blocks.len();
assert!(
ciphertexts[1..]
.iter()
.all(|ct| ct.blocks.len() == num_blocks),
"Not all ciphertexts have the same number of blocks"
);
assert!(
ciphertexts
.iter()
.all(RadixCiphertext::block_carries_are_empty),
"All ciphertexts must have empty carries"
);

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let message_max = message_modulus - 1;

let num_elements_to_fill_carry = (total_modulus - 1) / message_max;

let mut tmp_out = Vec::new();

let mut carries = Vec::<Ciphertext>::new();
while ciphertexts.len() > num_elements_to_fill_carry {
let mut chunks_iter = ciphertexts.par_chunks_exact_mut(num_elements_to_fill_carry);
let remainder_len = chunks_iter.remainder().len();

chunks_iter
.map(|chunk| {
let addition_range = self.unchecked_sum_ciphertext_chunk(chunk);
let s = &mut chunk[0];

let mut carry_ct = s.clone();
rayon::join(
|| {
s.blocks[addition_range.clone()]
.par_iter_mut()
.for_each(|block| {
self.key.message_extract_assign(block);
});
},
|| {
// Contrary to non overflowing version we always extract all carries
// as we need to track overflows
carry_ct.blocks[addition_range.clone()]
.par_iter_mut()
.for_each(|block| {
self.key.carry_extract_assign(block);
});
// Blocks for which we do not extract carries, means carry value is 0
for block in &mut carry_ct.blocks[..*addition_range.start()] {
self.key.create_trivial_assign(block, 0);
}
for block in &mut carry_ct.blocks[*addition_range.end() + 1..] {
self.key.create_trivial_assign(block, 0);
}
},
);

let out_carry = if *addition_range.end() == num_blocks - 1 {
let carry = carry_ct.blocks[num_blocks - 1].clone();
self.key
.create_trivial_assign(carry_ct.blocks.last_mut().unwrap(), 0);
carry
} else {
self.key.create_trivial(0)
};
carry_ct.blocks.rotate_right(1);

(s.clone(), carry_ct, out_carry)
})
.collect_into_vec(&mut tmp_out);

// tmp_out elements are tuple of 3 elements (message, carry, last_block_carry)
let num_ct_created = tmp_out.len() * 2;
// Ciphertexts not treated in this iteration are at the end of ciphertexts vec.
// the rotation will make them 'wrap around' and be placed at range index
// (num_ct_created..remainder_len + num_ct_created)
// We will then fill the indices in range (0..num_ct_created)
ciphertexts.rotate_right(remainder_len + num_ct_created);

// Drain elements out of tmp_out to replace them
// at the beginning of the ciphertexts left to add
for (i, (m, c, out_carry)) in tmp_out.drain(..).enumerate() {
ciphertexts[i * 2] = m;
ciphertexts[(i * 2) + 1] = c;
carries.push(out_carry);
}
ciphertexts.truncate(num_ct_created + remainder_len);
}

// Now we will add the last chunk of terms
// just as was done above, however we do it
// we want to use an addition that leaves
// the resulting ciphertext with empty carries
let (result, rest) = ciphertexts.split_first_mut().unwrap();
for term in rest.iter() {
self.unchecked_add_assign(result, term);
}
let mut carries = Vec::with_capacity(15);
let un_propagated_result = self
.unchecked_partial_sum_ciphertexts_vec_parallelized(ciphertexts, Some(&mut carries))?;

let (message_blocks, carry_blocks) = rayon::join(
|| {
result
un_propagated_result
.blocks
.par_iter()
.map(|block| self.key.message_extract(block))
.collect::<Vec<_>>()
},
|| {
let mut carry_blocks = Vec::with_capacity(num_blocks);
result
un_propagated_result
.blocks
.par_iter()
.map(|block| self.key.carry_extract(block))
Expand Down

0 comments on commit 96571ba

Please sign in to comment.