From 96d984264a6b72aabe3a073400aff99cacb635ca Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Wed, 4 Sep 2024 17:10:53 +0200 Subject: [PATCH] fix(integer): do sum by safe chunk sizes Parameters are made with with assumptions on the number of leveled add/sub/scalar_mul operations are made, so that the noise level before doing a PBS has a correct level and everything is safe, secure and correct. So the lib implementation has to uphold these assumptions in order to keep the error probability failure correct. In the comparisons, at some point we had a vector of ciphertexts with a degree == 1, so we greedily summed them (e.g with 2_2 params we summed them by chunks of 15), while it is correct with regards to the carry and message space it is however less correct with regards to the noise level. Noise wise, doing this huge sum is correct as long as the noise of each ciphertext is independent from the others in the same chunk. While it may generally be the case we are in, its not guaranteed, and since we do not track that information we have to take the safer approach of assuming the worst case: all noise are dependent. So to fix the issue we compute the correct size of sum chunk by also taking into account the max noise level. --- .../tfhe-cuda-backend/cuda/include/integer.h | 2 +- .../cuda/src/integer/comparison.cuh | 5 +-- tfhe/src/integer/server_key/mod.rs | 22 +++++++++++- .../integer/server_key/radix/comparison.rs | 21 +++++------ .../server_key/radix_parallel/comparison.rs | 34 +++--------------- .../radix_parallel/scalar_comparison.rs | 36 ++++++++----------- 6 files changed, 52 insertions(+), 68 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index 99860fd1c7..8452199a47 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -2087,7 +2087,7 @@ template struct int_are_all_block_true_buffer { if (allocate_gpu_memory) { Torus total_modulus = params.message_modulus * params.carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = (total_modulus - 1) / (params.message_modulus - 1); int max_chunks = (num_radix_blocks + max_value - 1) / max_value; tmp_block_accumulated = (Torus *)cuda_malloc_async( diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh index d56d6e3297..417023ba51 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh @@ -75,7 +75,8 @@ __host__ void are_all_comparisons_block_true( auto tmp_out = are_all_block_true_buffer->tmp_out; uint32_t total_modulus = message_modulus * carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = (total_modulus - 1) / (message_modulus - 1); + cuda_memcpy_async_gpu_to_gpu(tmp_out, lwe_array_in, num_radix_blocks * (big_lwe_dimension + 1) * @@ -174,7 +175,7 @@ __host__ void is_at_least_one_comparisons_block_true( auto buffer = mem_ptr->eq_buffer->are_all_block_true_buffer; uint32_t total_modulus = message_modulus * carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = (total_modulus - 1) / (message_modulus - 1); cuda_memcpy_async_gpu_to_gpu(mem_ptr->tmp_lwe_array_out, lwe_array_in, num_radix_blocks * (big_lwe_dimension + 1) * diff --git a/tfhe/src/integer/server_key/mod.rs b/tfhe/src/integer/server_key/mod.rs index bfe5d8fdd4..e189ab6446 100644 --- a/tfhe/src/integer/server_key/mod.rs +++ b/tfhe/src/integer/server_key/mod.rs @@ -9,7 +9,7 @@ pub(crate) mod radix; pub(crate) mod radix_parallel; use crate::integer::client_key::ClientKey; -use crate::shortint::ciphertext::MaxDegree; +use crate::shortint::ciphertext::{Degree, MaxDegree}; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; @@ -231,6 +231,26 @@ impl ServerKey { num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize) } + + /// Returns how many ciphertext can be summed at once + /// + /// The number of ciphertext that can be added together depends on the degree + /// (in order not to go beyond the carry space and keep results correct) but also + /// on the noise level (in order to have the correct error probability and so correctness and + /// security) + /// + /// - `degree` is expected degree of all elements to be summed + pub(crate) fn max_sum_size(&self, degree: Degree) -> usize { + let max_degree = + MaxDegree::from_msg_carry_modulus(self.message_modulus(), self.carry_modulus()); + let max_sum_to_full_carry = max_degree.get() / degree.get(); + + // We want to compare with the true max noise level, that is the one + // were we don't need enough room to accept a carry + let true_max_noise_level = max_degree.get() / (self.message_modulus().0 - 1); + + max_sum_to_full_carry.min(true_max_noise_level) + } } impl AsRef for ServerKey { diff --git a/tfhe/src/integer/server_key/radix/comparison.rs b/tfhe/src/integer/server_key/radix/comparison.rs index bf570567d1..6b00c17826 100644 --- a/tfhe/src/integer/server_key/radix/comparison.rs +++ b/tfhe/src/integer/server_key/radix/comparison.rs @@ -2,6 +2,7 @@ use super::ServerKey; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::comparator::Comparator; +use crate::shortint::ciphertext::Degree; impl ServerKey { /// Compares for equality 2 ciphertexts @@ -53,29 +54,26 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_max_value = self .key - .generate_lookup_table(|x| u64::from((x & max_value as u64) == max_value as u64)); + .generate_lookup_table(|x| u64::from((x & max_sum_size as u64) == max_sum_size as u64)); while block_comparisons.len() > 1 { block_comparisons = block_comparisons - .chunks(max_value) + .chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { self.key.unchecked_add_assign(&mut sum, other_block); } - if blocks.len() == max_value { + if blocks.len() == max_sum_size { self.key.apply_lookup_table(&sum, &is_max_value) } else { let is_equal_to_num_blocks = self.key.generate_lookup_table(|x| { - u64::from((x & max_value as u64) == blocks.len() as u64) + u64::from((x & max_sum_size as u64) == blocks.len() as u64) }); self.key.apply_lookup_table(&sum, &is_equal_to_num_blocks) } @@ -112,15 +110,12 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); while block_comparisons.len() > 1 { block_comparisons = block_comparisons - .chunks(max_value) + .chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { diff --git a/tfhe/src/integer/server_key/radix_parallel/comparison.rs b/tfhe/src/integer/server_key/radix_parallel/comparison.rs index 2f3e416ff1..2838f494d8 100644 --- a/tfhe/src/integer/server_key/radix_parallel/comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/comparison.rs @@ -11,7 +11,7 @@ impl ServerKey { { // Even though the corresponding function // may already exist in self.key - // we generate our own lut to do less allocations + // we generate our own lut to do fewer allocations // one for all the threads as opposed to one per thread let lut = self .key @@ -36,7 +36,7 @@ impl ServerKey { { // Even though the corresponding function // may already exist in self.key - // we generate our own lut to do less allocations + // we generate our own lut to do fewer allocations // one for all the threads as opposed to one per thread let lut = self .key @@ -50,34 +50,8 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - - let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2); - let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); - - while block_comparisons.len() > 1 { - block_comparisons - .par_chunks(max_value) - .map(|blocks| { - let mut sum = blocks[0].clone(); - for other_block in &blocks[1..] { - self.key.unchecked_add_assign(&mut sum, other_block); - } - self.key.apply_lookup_table(&sum, &is_non_zero) - }) - .collect_into_vec(&mut block_comparisons_2); - std::mem::swap(&mut block_comparisons_2, &mut block_comparisons); - } - - BooleanBlock::new_unchecked( - block_comparisons - .into_iter() - .next() - .unwrap_or_else(|| self.key.create_trivial(0)), - ) + let result = self.is_at_least_one_comparisons_block_true(block_comparisons); + BooleanBlock::new_unchecked(result) } pub fn unchecked_gt_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs index 34cd2ff05a..4d712f53f2 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs @@ -3,6 +3,7 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::comparator::{Comparator, ZeroComparisonType}; +use crate::shortint::ciphertext::Degree; use crate::shortint::server_key::LookupTableOwned; use crate::shortint::Ciphertext; use rayon::prelude::*; @@ -160,27 +161,23 @@ impl ServerKey { return self.key.create_trivial(1); } - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_max_value = self .key - .generate_lookup_table(|x| u64::from(x == max_value as u64)); + .generate_lookup_table(|x| u64::from(x == max_sum_size as u64)); while block_comparisons.len() > 1 { // Since all blocks encrypt either 0 or 1, we can sum max_value of them // as in the worst case we will be adding `max_value` ones block_comparisons = block_comparisons - .par_chunks(max_value) + .par_chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { self.key.unchecked_add_assign(&mut sum, other_block); } - if blocks.len() == max_value { + if blocks.len() == max_sum_size { self.key.apply_lookup_table(&sum, &is_max_value) } else { let is_equal_to_num_blocks = self @@ -213,25 +210,22 @@ impl ServerKey { return self.key.create_trivial(1); } - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - let is_not_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); + let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2); + let max_sum_size = self.max_sum_size(Degree::new(1)); while block_comparisons.len() > 1 { - block_comparisons = block_comparisons - .par_chunks(max_value) + block_comparisons + .par_chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { self.key.unchecked_add_assign(&mut sum, other_block); } - self.key.apply_lookup_table(&sum, &is_not_zero) }) - .collect::>(); + .collect_into_vec(&mut block_comparisons_2); + std::mem::swap(&mut block_comparisons_2, &mut block_comparisons); } block_comparisons @@ -423,10 +417,10 @@ impl ServerKey { let message_modulus = self.key.message_modulus.0; let carry_modulus = self.key.carry_modulus.0; let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); assert!(carry_modulus >= message_modulus); - u8::try_from(max_value).unwrap(); + u8::try_from(max_sum_size).unwrap(); let num_blocks = lhs.blocks().len(); let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2); @@ -516,10 +510,10 @@ impl ServerKey { let message_modulus = self.key.message_modulus.0; let carry_modulus = self.key.carry_modulus.0; let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); assert!(carry_modulus >= message_modulus); - u8::try_from(max_value).unwrap(); + u8::try_from(max_sum_size).unwrap(); let num_blocks = lhs.blocks().len(); let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2);