diff --git a/tfhe/src/integer/server_key/radix_parallel/count_zeros_ones.rs b/tfhe/src/integer/server_key/radix_parallel/count_zeros_ones.rs index 7479c27e88..abd3396c21 100644 --- a/tfhe/src/integer/server_key/radix_parallel/count_zeros_ones.rs +++ b/tfhe/src/integer/server_key/radix_parallel/count_zeros_ones.rs @@ -236,7 +236,7 @@ impl ServerKey { self.create_trivial_radix(-i64::from(max_possible_bit_count), num_signed_blocks), ); let result = self - .unchecked_partial_sum_ciphertexts_vec_parallelized(things_to_sum) + .unchecked_partial_sum_ciphertexts_vec_parallelized(things_to_sum, None) .expect("internal error, empty ciphertext count"); let (message_blocks, carry_blocks) = rayon::join( || { diff --git a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs index 5f851c5495..cc46d77d08 100644 --- a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs +++ b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs @@ -286,7 +286,7 @@ impl ServerKey { ); let result = self - .unchecked_partial_sum_ciphertexts_vec_parallelized(cts) + .unchecked_partial_sum_ciphertexts_vec_parallelized(cts, None) .expect("internal error, empty ciphertext count"); // This is the part where we extract message and carry blocks diff --git a/tfhe/src/integer/server_key/radix_parallel/sum.rs b/tfhe/src/integer/server_key/radix_parallel/sum.rs index 0df4870a66..de430be1c3 100644 --- a/tfhe/src/integer/server_key/radix_parallel/sum.rs +++ b/tfhe/src/integer/server_key/radix_parallel/sum.rs @@ -2,15 +2,18 @@ use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::{BooleanBlock, IntegerCiphertext, RadixCiphertext, ServerKey}; use crate::shortint::Ciphertext; use rayon::prelude::*; -use std::ops::RangeInclusive; impl ServerKey { /// Computes the sum of the ciphertexts in parallel. /// + /// output_carries: if not None, carries generated by last blocks will + /// be stored in it. + /// /// Returns a result that has non propagated carries pub(crate) fn unchecked_partial_sum_ciphertexts_vec_parallelized( &self, terms: Vec, + mut output_carries: Option<&mut Vec>, ) -> Option where T: IntegerRadixCiphertext, @@ -89,7 +92,7 @@ impl ServerKey { self.key.unchecked_add_assign(&mut result, c); } - if column_index < num_columns - 1 { + if (column_index < num_columns - 1) || output_carries.is_some() { rayon::join( || self.key.message_extract(&result), || Some(self.key.carry_extract(&result)), @@ -114,8 +117,12 @@ impl ServerKey { for (msg, maybe_carry) in column_output.drain(..) { columns[i].push(msg); - if let (Some(carry), true) = (maybe_carry, (i + 1) < columns.len()) { - columns[i + 1].push(carry); + if let Some(carry) = maybe_carry { + if (i + 1) < columns.len() { + columns[i + 1].push(carry); + } else if let Some(ref mut out) = output_carries { + out.push(carry); + } } } } @@ -152,7 +159,8 @@ impl ServerKey { where T: IntegerRadixCiphertext, { - let mut result = self.unchecked_partial_sum_ciphertexts_vec_parallelized(ciphertexts)?; + let mut result = + self.unchecked_partial_sum_ciphertexts_vec_parallelized(ciphertexts, None)?; self.full_propagate_parallelized(&mut result); assert!(result.block_carries_are_empty()); @@ -215,51 +223,6 @@ impl ServerKey { self.unchecked_sum_ciphertexts_parallelized(ciphertexts.as_ref()) } - /// This sums all ciphertext contained in the chunk into the first element of the chunk - /// i.e: [A, B, C] -> [A + B + C, B, C] - /// and returns the inclusive range indicating the range of blocks which where addition were - /// made that is, if the ciphertexts contains trailing (end or start) trivial zeros some - /// addition will be skipped (as adding a bunch of zeros is not useful) - fn unchecked_sum_ciphertext_chunk(&self, chunk: &mut [T]) -> RangeInclusive - where - T: IntegerRadixCiphertext, - { - assert_ne!(chunk.len(), 0); - if chunk.len() <= 1 { - return 0..=0; - } - let num_blocks = chunk[0].blocks().len(); - let (s, rest) = chunk.split_first_mut().unwrap(); - let mut first_block_where_addition_happened = num_blocks - 1; - let mut last_block_where_addition_happened = 0; - for a in rest.iter() { - let first_block_to_add = a - .blocks() - .iter() - .position(|block| block.degree.get() != 0) - .unwrap_or(num_blocks); - first_block_where_addition_happened = - first_block_where_addition_happened.min(first_block_to_add); - let last_block_to_add = a - .blocks() - .iter() - .rev() - .position(|block| block.degree.get() != 0) - .map_or(num_blocks - 1, |pos| num_blocks - pos - 1); - last_block_where_addition_happened = - last_block_where_addition_happened.max(last_block_to_add); - for (ct_left_i, ct_right_i) in &mut s.blocks_mut() - [first_block_to_add..last_block_to_add + 1] - .iter_mut() - .zip(a.blocks()[first_block_to_add..last_block_to_add + 1].iter()) - { - self.key.unchecked_add_assign(ct_left_i, ct_right_i); - } - } - - first_block_where_addition_happened..=last_block_where_addition_happened - } - /// - Expects all ciphertexts to have empty carries /// - Expects all ciphertexts to have the same size pub fn unchecked_unsigned_overflowing_sum_ciphertexts_vec_parallelized( @@ -291,117 +254,13 @@ impl ServerKey { ); } - assert!( - ciphertexts - .iter() - .all(IntegerRadixCiphertext::block_carries_are_empty), - "All ciphertexts must have empty carries" - ); - - let num_blocks = ciphertexts[0].blocks.len(); - assert!( - ciphertexts[1..] - .iter() - .all(|ct| ct.blocks.len() == num_blocks), - "Not all ciphertexts have the same number of blocks" - ); - assert!( - ciphertexts - .iter() - .all(RadixCiphertext::block_carries_are_empty), - "All ciphertexts must have empty carries" - ); - - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let message_max = message_modulus - 1; - - let num_elements_to_fill_carry = (total_modulus - 1) / message_max; - - let mut tmp_out = Vec::new(); - - let mut carries = Vec::::new(); - while ciphertexts.len() > num_elements_to_fill_carry { - let mut chunks_iter = ciphertexts.par_chunks_exact_mut(num_elements_to_fill_carry); - let remainder_len = chunks_iter.remainder().len(); - - chunks_iter - .map(|chunk| { - let addition_range = self.unchecked_sum_ciphertext_chunk(chunk); - let s = &mut chunk[0]; - - let mut carry_ct = s.clone(); - rayon::join( - || { - s.blocks[addition_range.clone()] - .par_iter_mut() - .for_each(|block| { - self.key.message_extract_assign(block); - }); - }, - || { - // Contrary to non overflowing version we always extract all carries - // as we need to track overflows - carry_ct.blocks[addition_range.clone()] - .par_iter_mut() - .for_each(|block| { - self.key.carry_extract_assign(block); - }); - // Blocks for which we do not extract carries, means carry value is 0 - for block in &mut carry_ct.blocks[..*addition_range.start()] { - self.key.create_trivial_assign(block, 0); - } - for block in &mut carry_ct.blocks[*addition_range.end() + 1..] { - self.key.create_trivial_assign(block, 0); - } - }, - ); - - let out_carry = if *addition_range.end() == num_blocks - 1 { - let carry = carry_ct.blocks[num_blocks - 1].clone(); - self.key - .create_trivial_assign(carry_ct.blocks.last_mut().unwrap(), 0); - carry - } else { - self.key.create_trivial(0) - }; - carry_ct.blocks.rotate_right(1); - - (s.clone(), carry_ct, out_carry) - }) - .collect_into_vec(&mut tmp_out); - - // tmp_out elements are tuple of 3 elements (message, carry, last_block_carry) - let num_ct_created = tmp_out.len() * 2; - // Ciphertexts not treated in this iteration are at the end of ciphertexts vec. - // the rotation will make them 'wrap around' and be placed at range index - // (num_ct_created..remainder_len + num_ct_created) - // We will then fill the indices in range (0..num_ct_created) - ciphertexts.rotate_right(remainder_len + num_ct_created); - - // Drain elements out of tmp_out to replace them - // at the beginning of the ciphertexts left to add - for (i, (m, c, out_carry)) in tmp_out.drain(..).enumerate() { - ciphertexts[i * 2] = m; - ciphertexts[(i * 2) + 1] = c; - carries.push(out_carry); - } - ciphertexts.truncate(num_ct_created + remainder_len); - } - - // Now we will add the last chunk of terms - // just as was done above, however we do it - // we want to use an addition that leaves - // the resulting ciphertext with empty carries - let (result, rest) = ciphertexts.split_first_mut().unwrap(); - for term in rest.iter() { - self.unchecked_add_assign(result, term); - } + let mut carries = Vec::with_capacity(15); + let un_propagated_result = self + .unchecked_partial_sum_ciphertexts_vec_parallelized(ciphertexts, Some(&mut carries))?; let (message_blocks, carry_blocks) = rayon::join( || { - result + un_propagated_result .blocks .par_iter() .map(|block| self.key.message_extract(block)) @@ -409,7 +268,7 @@ impl ServerKey { }, || { let mut carry_blocks = Vec::with_capacity(num_blocks); - result + un_propagated_result .blocks .par_iter() .map(|block| self.key.carry_extract(block))