Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(integer): do sum by safe chunk sizes #1512

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -2088,7 +2088,7 @@ template <typename Torus> struct int_are_all_block_true_buffer {

if (allocate_gpu_memory) {
Torus total_modulus = params.message_modulus * params.carry_modulus;
uint32_t max_value = total_modulus - 1;
uint32_t max_value = (total_modulus - 1) / (params.message_modulus - 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@agnesLeroy this is fine as is for our set of parameters dedicated to the blockchain but ideally we need to take the min with the noise level see the tfhe/src/integer/server_key/mod.rs file for the final min with noise level


int max_chunks = (num_radix_blocks + max_value - 1) / max_value;
tmp_block_accumulated = (Torus *)cuda_malloc_async(
Expand Down
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ __host__ void are_all_comparisons_block_true(
auto tmp_out = are_all_block_true_buffer->tmp_out;

uint32_t total_modulus = message_modulus * carry_modulus;
uint32_t max_value = total_modulus - 1;
uint32_t max_value = (total_modulus - 1) / (message_modulus - 1);

cuda_memcpy_async_gpu_to_gpu(tmp_out, lwe_array_in,
num_radix_blocks * (big_lwe_dimension + 1) *
Expand Down Expand Up @@ -173,7 +173,7 @@ __host__ void is_at_least_one_comparisons_block_true(
auto buffer = mem_ptr->eq_buffer->are_all_block_true_buffer;

uint32_t total_modulus = message_modulus * carry_modulus;
uint32_t max_value = total_modulus - 1;
uint32_t max_value = (total_modulus - 1) / (message_modulus - 1);

cuda_memcpy_async_gpu_to_gpu(mem_ptr->tmp_lwe_array_out, lwe_array_in,
num_radix_blocks * (big_lwe_dimension + 1) *
Expand Down
18 changes: 17 additions & 1 deletion tfhe/src/integer/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub(crate) mod radix;
pub(crate) mod radix_parallel;

use crate::integer::client_key::ClientKey;
use crate::shortint::ciphertext::MaxDegree;
use crate::shortint::ciphertext::{Degree, MaxDegree};
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;

Expand Down Expand Up @@ -231,6 +231,22 @@ impl ServerKey {

num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
}

/// Returns how many ciphertext can be summed at once
///
/// The number of ciphertext that can be added together depends on the degree
/// (in order not to go beyond the carry space and keep results correct) but also
/// on the noise level (in order to have the correct error probability and so correctness and
/// security)
///
/// - `degree` is expected degree of all elements to be summed
pub(crate) fn max_sum_size(&self, degree: Degree) -> usize {
let max_degree =
MaxDegree::from_msg_carry_modulus(self.message_modulus(), self.carry_modulus());
let max_sum_to_full_carry = max_degree.get() / degree.get();

max_sum_to_full_carry.min(self.key.max_noise_level.get())
}
}

impl AsRef<crate::shortint::ServerKey> for ServerKey {
Expand Down
25 changes: 10 additions & 15 deletions tfhe/src/integer/server_key/radix/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::ServerKey;
use crate::integer::ciphertext::boolean_value::BooleanBlock;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::server_key::comparator::Comparator;
use crate::shortint::ciphertext::Degree;

impl ServerKey {
/// Compares for equality 2 ciphertexts
Expand Down Expand Up @@ -53,30 +54,27 @@ impl ServerKey {
.unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut);
});

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;
let max_sum_size = self.max_sum_size(Degree::new(1));

let is_max_value = self
.key
.generate_lookup_table(|x| u64::from((x & max_value as u64) == max_value as u64));
.generate_lookup_table(|x| u64::from(x == max_sum_size as u64));

while block_comparisons.len() > 1 {
block_comparisons = block_comparisons
.chunks(max_value)
.chunks(max_sum_size)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
self.key.unchecked_add_assign(&mut sum, other_block);
}

if blocks.len() == max_value {
if blocks.len() == max_sum_size {
self.key.apply_lookup_table(&sum, &is_max_value)
} else {
let is_equal_to_num_blocks = self.key.generate_lookup_table(|x| {
u64::from((x & max_value as u64) == blocks.len() as u64)
});
let is_equal_to_num_blocks = self
.key
.generate_lookup_table(|x| u64::from(x == blocks.len() as u64));
self.key.apply_lookup_table(&sum, &is_equal_to_num_blocks)
}
})
Expand Down Expand Up @@ -112,15 +110,12 @@ impl ServerKey {
.unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut);
});

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;
let max_sum_size = self.max_sum_size(Degree::new(1));
let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0));

while block_comparisons.len() > 1 {
block_comparisons = block_comparisons
.chunks(max_value)
.chunks(max_sum_size)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
Expand Down
34 changes: 4 additions & 30 deletions tfhe/src/integer/server_key/radix_parallel/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ impl ServerKey {
{
// Even though the corresponding function
// may already exist in self.key
// we generate our own lut to do less allocations
// we generate our own lut to do fewer allocations
// one for all the threads as opposed to one per thread
let lut = self
.key
Expand All @@ -36,7 +36,7 @@ impl ServerKey {
{
// Even though the corresponding function
// may already exist in self.key
// we generate our own lut to do less allocations
// we generate our own lut to do fewer allocations
// one for all the threads as opposed to one per thread
let lut = self
.key
Expand All @@ -50,34 +50,8 @@ impl ServerKey {
.unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut);
});

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;

let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2);
let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0));

while block_comparisons.len() > 1 {
block_comparisons
.par_chunks(max_value)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
self.key.unchecked_add_assign(&mut sum, other_block);
}
self.key.apply_lookup_table(&sum, &is_non_zero)
})
.collect_into_vec(&mut block_comparisons_2);
std::mem::swap(&mut block_comparisons_2, &mut block_comparisons);
}

BooleanBlock::new_unchecked(
block_comparisons
.into_iter()
.next()
.unwrap_or_else(|| self.key.create_trivial(0)),
)
IceTDrinker marked this conversation as resolved.
Show resolved Hide resolved
let result = self.is_at_least_one_comparisons_block_true(block_comparisons);
BooleanBlock::new_unchecked(result)
}

pub fn unchecked_gt_parallelized<T>(&self, lhs: &T, rhs: &T) -> BooleanBlock
Expand Down
36 changes: 15 additions & 21 deletions tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::ciphertext::boolean_value::BooleanBlock;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::server_key::comparator::{Comparator, ZeroComparisonType};
use crate::shortint::ciphertext::Degree;
use crate::shortint::server_key::LookupTableOwned;
use crate::shortint::Ciphertext;
use rayon::prelude::*;
Expand Down Expand Up @@ -160,27 +161,23 @@ impl ServerKey {
return self.key.create_trivial(1);
}

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;

let max_sum_size = self.max_sum_size(Degree::new(1));
let is_max_value = self
.key
.generate_lookup_table(|x| u64::from(x == max_value as u64));
.generate_lookup_table(|x| u64::from(x == max_sum_size as u64));

while block_comparisons.len() > 1 {
// Since all blocks encrypt either 0 or 1, we can sum max_value of them
// as in the worst case we will be adding `max_value` ones
block_comparisons = block_comparisons
.par_chunks(max_value)
.par_chunks(max_sum_size)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
self.key.unchecked_add_assign(&mut sum, other_block);
}

if blocks.len() == max_value {
if blocks.len() == max_sum_size {
self.key.apply_lookup_table(&sum, &is_max_value)
} else {
let is_equal_to_num_blocks = self
Expand Down Expand Up @@ -213,25 +210,22 @@ impl ServerKey {
return self.key.create_trivial(1);
}

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;

let is_not_zero = self.key.generate_lookup_table(|x| u64::from(x != 0));
let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2);
let max_sum_size = self.max_sum_size(Degree::new(1));

while block_comparisons.len() > 1 {
block_comparisons = block_comparisons
.par_chunks(max_value)
block_comparisons
.par_chunks(max_sum_size)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
self.key.unchecked_add_assign(&mut sum, other_block);
}

self.key.apply_lookup_table(&sum, &is_not_zero)
})
.collect::<Vec<_>>();
.collect_into_vec(&mut block_comparisons_2);
std::mem::swap(&mut block_comparisons_2, &mut block_comparisons);
}

block_comparisons
Expand Down Expand Up @@ -423,10 +417,10 @@ impl ServerKey {
let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;
let max_sum_size = self.max_sum_size(Degree::new(1));

assert!(carry_modulus >= message_modulus);
u8::try_from(max_value).unwrap();
u8::try_from(max_sum_size).unwrap();

let num_blocks = lhs.blocks().len();
let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2);
Expand Down Expand Up @@ -516,10 +510,10 @@ impl ServerKey {
let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;
let max_sum_size = self.max_sum_size(Degree::new(1));

assert!(carry_modulus >= message_modulus);
u8::try_from(max_value).unwrap();
u8::try_from(max_sum_size).unwrap();

let num_blocks = lhs.blocks().len();
let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2);
Expand Down
Loading