diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index 3e35cc8e73..872ec7810b 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -2095,7 +2095,7 @@ template struct int_are_all_block_true_buffer { if (allocate_gpu_memory) { Torus total_modulus = params.message_modulus * params.carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = (total_modulus - 1) / (params.message_modulus - 1); int max_chunks = (num_radix_blocks + max_value - 1) / max_value; tmp_block_accumulated = (Torus *)cuda_malloc_async( diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh index ed2a3bbef5..3b288f2283 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh @@ -74,7 +74,7 @@ __host__ void are_all_comparisons_block_true( auto tmp_out = are_all_block_true_buffer->tmp_out; uint32_t total_modulus = message_modulus * carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = (total_modulus - 1) / (message_modulus - 1); cuda_memcpy_async_gpu_to_gpu(tmp_out, lwe_array_in, num_radix_blocks * (big_lwe_dimension + 1) * @@ -173,7 +173,7 @@ __host__ void is_at_least_one_comparisons_block_true( auto buffer = mem_ptr->eq_buffer->are_all_block_true_buffer; uint32_t total_modulus = message_modulus * carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = (total_modulus - 1) / (message_modulus - 1); cuda_memcpy_async_gpu_to_gpu(mem_ptr->tmp_lwe_array_out, lwe_array_in, num_radix_blocks * (big_lwe_dimension + 1) * diff --git a/tfhe/src/integer/server_key/mod.rs b/tfhe/src/integer/server_key/mod.rs index cdc3fd4a36..dd382d9e33 100644 --- a/tfhe/src/integer/server_key/mod.rs +++ b/tfhe/src/integer/server_key/mod.rs @@ -9,7 +9,7 @@ pub(crate) mod radix; pub(crate) mod radix_parallel; use crate::integer::client_key::ClientKey; -use crate::shortint::ciphertext::MaxDegree; +use crate::shortint::ciphertext::{Degree, MaxDegree}; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; @@ -227,6 +227,22 @@ impl ServerKey { num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize) } + + /// Returns how many ciphertext can be summed at once + /// + /// The number of ciphertext that can be added together depends on the degree + /// (in order not to go beyond the carry space and keep results correct) but also + /// on the noise level (in order to have the correct error probability and so correctness and + /// security) + /// + /// - `degree` is expected degree of all elements to be summed + pub(crate) fn max_sum_size(&self, degree: Degree) -> usize { + let max_degree = + MaxDegree::from_msg_carry_modulus(self.message_modulus(), self.carry_modulus()); + let max_sum_to_full_carry = max_degree.get() / degree.get(); + + max_sum_to_full_carry.min(self.key.max_noise_level.get()) + } } impl AsRef for ServerKey { diff --git a/tfhe/src/integer/server_key/radix/comparison.rs b/tfhe/src/integer/server_key/radix/comparison.rs index bf570567d1..a6e82b4e6c 100644 --- a/tfhe/src/integer/server_key/radix/comparison.rs +++ b/tfhe/src/integer/server_key/radix/comparison.rs @@ -2,6 +2,7 @@ use super::ServerKey; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::comparator::Comparator; +use crate::shortint::ciphertext::Degree; impl ServerKey { /// Compares for equality 2 ciphertexts @@ -53,30 +54,27 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_max_value = self .key - .generate_lookup_table(|x| u64::from((x & max_value as u64) == max_value as u64)); + .generate_lookup_table(|x| u64::from(x == max_sum_size as u64)); while block_comparisons.len() > 1 { block_comparisons = block_comparisons - .chunks(max_value) + .chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { self.key.unchecked_add_assign(&mut sum, other_block); } - if blocks.len() == max_value { + if blocks.len() == max_sum_size { self.key.apply_lookup_table(&sum, &is_max_value) } else { - let is_equal_to_num_blocks = self.key.generate_lookup_table(|x| { - u64::from((x & max_value as u64) == blocks.len() as u64) - }); + let is_equal_to_num_blocks = self + .key + .generate_lookup_table(|x| u64::from(x == blocks.len() as u64)); self.key.apply_lookup_table(&sum, &is_equal_to_num_blocks) } }) @@ -112,15 +110,12 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); while block_comparisons.len() > 1 { block_comparisons = block_comparisons - .chunks(max_value) + .chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { diff --git a/tfhe/src/integer/server_key/radix_parallel/comparison.rs b/tfhe/src/integer/server_key/radix_parallel/comparison.rs index 2f45fcd592..b23ce60562 100644 --- a/tfhe/src/integer/server_key/radix_parallel/comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/comparison.rs @@ -51,7 +51,7 @@ impl ServerKey { { // Even though the corresponding function // may already exist in self.key - // we generate our own lut to do less allocations + // we generate our own lut to do fewer allocations // one for all the threads as opposed to one per thread let lut = self .key @@ -76,7 +76,7 @@ impl ServerKey { { // Even though the corresponding function // may already exist in self.key - // we generate our own lut to do less allocations + // we generate our own lut to do fewer allocations // one for all the threads as opposed to one per thread let lut = self .key @@ -90,34 +90,8 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - - let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2); - let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); - - while block_comparisons.len() > 1 { - block_comparisons - .par_chunks(max_value) - .map(|blocks| { - let mut sum = blocks[0].clone(); - for other_block in &blocks[1..] { - self.key.unchecked_add_assign(&mut sum, other_block); - } - self.key.apply_lookup_table(&sum, &is_non_zero) - }) - .collect_into_vec(&mut block_comparisons_2); - std::mem::swap(&mut block_comparisons_2, &mut block_comparisons); - } - - BooleanBlock::new_unchecked( - block_comparisons - .into_iter() - .next() - .unwrap_or_else(|| self.key.create_trivial(0)), - ) + let result = self.is_at_least_one_comparisons_block_true(block_comparisons); + BooleanBlock::new_unchecked(result) } /// This implements all comparisons (<, <=, >, >=) for both signed and unsigned diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs index 34cd2ff05a..4d712f53f2 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs @@ -3,6 +3,7 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::comparator::{Comparator, ZeroComparisonType}; +use crate::shortint::ciphertext::Degree; use crate::shortint::server_key::LookupTableOwned; use crate::shortint::Ciphertext; use rayon::prelude::*; @@ -160,27 +161,23 @@ impl ServerKey { return self.key.create_trivial(1); } - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_max_value = self .key - .generate_lookup_table(|x| u64::from(x == max_value as u64)); + .generate_lookup_table(|x| u64::from(x == max_sum_size as u64)); while block_comparisons.len() > 1 { // Since all blocks encrypt either 0 or 1, we can sum max_value of them // as in the worst case we will be adding `max_value` ones block_comparisons = block_comparisons - .par_chunks(max_value) + .par_chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { self.key.unchecked_add_assign(&mut sum, other_block); } - if blocks.len() == max_value { + if blocks.len() == max_sum_size { self.key.apply_lookup_table(&sum, &is_max_value) } else { let is_equal_to_num_blocks = self @@ -213,25 +210,22 @@ impl ServerKey { return self.key.create_trivial(1); } - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - let is_not_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); + let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2); + let max_sum_size = self.max_sum_size(Degree::new(1)); while block_comparisons.len() > 1 { - block_comparisons = block_comparisons - .par_chunks(max_value) + block_comparisons + .par_chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { self.key.unchecked_add_assign(&mut sum, other_block); } - self.key.apply_lookup_table(&sum, &is_not_zero) }) - .collect::>(); + .collect_into_vec(&mut block_comparisons_2); + std::mem::swap(&mut block_comparisons_2, &mut block_comparisons); } block_comparisons @@ -423,10 +417,10 @@ impl ServerKey { let message_modulus = self.key.message_modulus.0; let carry_modulus = self.key.carry_modulus.0; let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); assert!(carry_modulus >= message_modulus); - u8::try_from(max_value).unwrap(); + u8::try_from(max_sum_size).unwrap(); let num_blocks = lhs.blocks().len(); let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2); @@ -516,10 +510,10 @@ impl ServerKey { let message_modulus = self.key.message_modulus.0; let carry_modulus = self.key.carry_modulus.0; let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); assert!(carry_modulus >= message_modulus); - u8::try_from(max_value).unwrap(); + u8::try_from(max_sum_size).unwrap(); let num_blocks = lhs.blocks().len(); let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2);