Skip to content

Commit

Permalink
chore(gpu): fix partial-sum ciphertext computation with 0 or 1 inputs in the vector
Browse files Browse the repository at this point in the history
Also refactor the interface for Hillis & Steele prefix sum
  • Loading branch information
agnesLeroy committed Sep 11, 2024
1 parent 2a4026c commit 61136af
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 32 deletions.
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ void scratch_cuda_integer_compute_prefix_sum_hillis_steele_64(

void cuda_integer_compute_prefix_sum_hillis_steele_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks,
void **bsks, uint32_t num_blocks, uint32_t shift);
void *output_radix_lwe, void *generates_or_propagates, int8_t *mem_ptr,
void **ksks, void **bsks, uint32_t num_blocks, uint32_t shift);

void cleanup_cuda_integer_compute_prefix_sum_hillis_steele_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
Expand Down
6 changes: 3 additions & 3 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,15 @@ void scratch_cuda_integer_compute_prefix_sum_hillis_steele_64(

void cuda_integer_compute_prefix_sum_hillis_steele_64(
void **streams, uint32_t *gpu_indexes, uint32_t gpu_count,
void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks,
void **bsks, uint32_t num_blocks, uint32_t shift) {
void *output_radix_lwe, void *generates_or_propagates, int8_t *mem_ptr,
void **ksks, void **bsks, uint32_t num_blocks, uint32_t shift) {

int_radix_params params = ((int_radix_lut<uint64_t> *)mem_ptr)->params;

host_compute_prefix_sum_hillis_steele<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(output_radix_lwe),
static_cast<uint64_t *>(input_radix_lwe), params,
static_cast<uint64_t *>(generates_or_propagates), params,
(int_radix_lut<uint64_t> *)mem_ptr, bsks, (uint64_t **)(ksks),
num_blocks);
}
Expand Down
18 changes: 12 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu
Original file line number Diff line number Diff line change
Expand Up @@ -241,47 +241,53 @@ void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 1024:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<1024>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 2048:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<2048>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 4096:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<4096>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 8192:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<8192>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
case 16384:
host_integer_partial_sum_ciphertexts_vec_kb<uint64_t,
AmortizedDegree<16384>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_vec), terms_degree, bsks,
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec);
(uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec,
nullptr);
break;
default:
PANIC("Cuda error (integer multiplication): unsupported polynomial size. "
Expand Down
12 changes: 10 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
Torus *radix_lwe_out, Torus *terms, int *terms_degree, void **bsks,
uint64_t **ksks, int_sum_ciphertexts_vec_memory<uint64_t> *mem_ptr,
uint32_t num_blocks_in_radix, uint32_t num_radix_in_vec,
int_radix_lut<Torus> *reused_lut = nullptr) {
int_radix_lut<Torus> *reused_lut) {

auto new_blocks = mem_ptr->new_blocks;
auto old_blocks = mem_ptr->old_blocks;
Expand All @@ -205,6 +205,15 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
auto small_lwe_dimension = mem_ptr->params.small_lwe_dimension;
auto small_lwe_size = small_lwe_dimension + 1;

if (num_radix_in_vec == 0)
return;
if (num_radix_in_vec == 1) {
cuda_memcpy_async_gpu_to_gpu(radix_lwe_out, terms,
num_blocks_in_radix * big_lwe_size *
sizeof(Torus),
streams[0], gpu_indexes[0]);
return;
}
if (old_blocks != terms) {
cuda_memcpy_async_gpu_to_gpu(old_blocks, terms,
num_blocks_in_radix * num_radix_in_vec *
Expand Down Expand Up @@ -288,7 +297,6 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
terms_degree, h_lwe_idx_in, h_lwe_idx_out, h_smart_copy_in,
h_smart_copy_out, ch_amount, r, num_blocks, chunk_size, message_max,
total_count, message_count, carry_count, sm_copy_count);
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
auto lwe_indexes_in = luts_message_carry->lwe_indexes_in;
auto lwe_indexes_out = luts_message_carry->lwe_indexes_out;
luts_message_carry->set_lwe_indexes(streams[0], gpu_indexes[0],
Expand Down
2 changes: 1 addition & 1 deletion backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ __host__ void host_integer_scalar_mul_radix(
host_integer_partial_sum_ciphertexts_vec_kb<T, params>(
streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer,
terms_degree, bsks, ksks, mem->sum_ciphertexts_vec_mem,
num_radix_blocks, j);
num_radix_blocks, j, nullptr);

auto scp_mem_ptr = mem->sum_ciphertexts_vec_mem->scp_mem;
host_propagate_single_carry<T>(streams, gpu_indexes, gpu_count, lwe_array,
Expand Down
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/src/cuda_bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ extern "C" {
gpu_indexes: *const u32,
gpu_count: u32,
radix_lwe_out: *mut c_void,
radix_lwe_vec: *const c_void,
radix_lwe_vec: *mut c_void,
num_radix_in_vec: u32,
mem_ptr: *mut i8,
bsks: *const *mut c_void,
Expand Down Expand Up @@ -958,7 +958,7 @@ extern "C" {
gpu_indexes: *const u32,
gpu_count: u32,
output_radix_lwe: *mut c_void,
input_radix_lwe: *const c_void,
generates_or_propagates: *mut c_void,
mem_ptr: *mut i8,
ksks: *const *mut c_void,
bsks: *const *mut c_void,
Expand Down
6 changes: 3 additions & 3 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2578,7 +2578,7 @@ pub unsafe fn unchecked_signed_overflowing_add_or_sub_radix_kb_assign_async<
pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
radix_lwe_output: &mut CudaSliceMut<T>,
radix_lwe_input: &CudaSlice<T>,
generates_or_propagates: &mut CudaSliceMut<T>,
input_lut: &[T],
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
Expand All @@ -2598,7 +2598,7 @@ pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Nume
) {
assert_eq!(
streams.gpu_indexes[0],
radix_lwe_input.gpu_index(0),
generates_or_propagates.gpu_index(0),
"GPU error: all data should reside on the same GPU."
);
assert_eq!(
Expand Down Expand Up @@ -2643,7 +2643,7 @@ pub unsafe fn compute_prefix_sum_hillis_steele_async<T: UnsignedInteger, B: Nume
streams.gpu_indexes.as_ptr(),
streams.len() as u32,
radix_lwe_output.as_mut_c_ptr(0),
radix_lwe_input.as_c_ptr(0),
generates_or_propagates.as_mut_c_ptr(0),
mem_ptr,
keyswitch_key.ptr.as_ptr(),
bootstrapping_key.ptr.as_ptr(),
Expand Down
23 changes: 11 additions & 12 deletions tfhe/src/integer/gpu/server_key/radix/ilog2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ impl CudaServerKey {
let lwe_size = ct.as_ref().d_blocks.0.lwe_dimension.to_lwe_size().0;

// Allocate the necessary amount of memory
let mut output_radix =
CudaVec::new(num_ct_blocks * lwe_size, streams, streams.gpu_indexes[0]);
let mut tmp_radix = CudaVec::new(num_ct_blocks * lwe_size, streams, streams.gpu_indexes[0]);

let lut = match direction {
Direction::Trailing => self.generate_lookup_table(|x| {
Expand Down Expand Up @@ -70,12 +69,12 @@ impl CudaServerKey {
}),
};

output_radix.copy_from_gpu_async(
tmp_radix.copy_from_gpu_async(
&ct.as_ref().d_blocks.0.d_vec,
streams,
streams.gpu_indexes[0],
);
let mut output_slice = output_radix
let mut output_slice = tmp_radix
.as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
.unwrap();

Expand Down Expand Up @@ -167,27 +166,27 @@ impl CudaServerKey {
},
);

let mut cts = CudaLweCiphertextList::new(
let mut output_cts = CudaLweCiphertextList::new(
ct.as_ref().d_blocks.lwe_dimension(),
LweCiphertextCount(num_ct_blocks * ct.as_ref().d_blocks.lwe_ciphertext_count().0),
ct.as_ref().d_blocks.ciphertext_modulus(),
streams,
);

let input_radix_slice = output_radix
.as_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
let mut generates_or_propagates = tmp_radix
.as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
.unwrap();

match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
compute_prefix_sum_hillis_steele_async(
streams,
&mut cts
&mut output_cts
.0
.d_vec
.as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
.unwrap(),
&input_radix_slice,
&mut generates_or_propagates,
sum_lut.acc.acc.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
Expand All @@ -211,12 +210,12 @@ impl CudaServerKey {
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
compute_prefix_sum_hillis_steele_async(
streams,
&mut cts
&mut output_cts
.0
.d_vec
.as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0])
.unwrap(),
&input_radix_slice,
&mut generates_or_propagates,
sum_lut.acc.acc.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
Expand All @@ -238,7 +237,7 @@ impl CudaServerKey {
);
}
}
cts
output_cts
}

/// Counts how many consecutive bits there are
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub(crate) const fn nb_tests_for_params(params: PBSParameters) -> usize {

// >= 6 bits (3_3)
if full_modulus >= 1 << 6 {
return 15;
return 5;
}

30
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ where

let num_block =
(32f64 / (cks.parameters().message_modulus().0 as f64).log(2.0)).ceil() as usize;
println!(
"num block: {}, msg modulus: {}",
num_block,
cks.parameters().message_modulus().0
);

let cks = RadixClientKey::from((cks, num_block));
sks.set_deterministic_pbs_execution(true);
Expand All @@ -54,6 +59,7 @@ where
// 10, 7, 14 are from the paper and should trigger different branches
// 16 is a power of two and should trigger the corresponding branch
let hard_coded_divisors: [u64; 4] = [10, 7, 14, 16];
println!("Hard coded divisors");
for divisor in hard_coded_divisors {
let clear = rng.gen::<u64>() % modulus;
let ct = cks.encrypt(clear);
Expand All @@ -66,12 +72,14 @@ where
assert_eq!(r_res, clear % divisor);
}

println!("nb_tests loop");
for _ in 0..nb_tests {
let clear = rng.gen::<u64>() % modulus;
let scalar = rng.gen_range(1u32..=u32::MAX) as u64;

let ct = cks.encrypt(clear);

println!("first case");
{
let (q, r) = executor.execute((&ct, scalar));
let (q2, r2) = executor.execute((&ct, scalar));
Expand All @@ -86,6 +94,7 @@ where
assert_eq!(r_res, clear % scalar);
}

println!("second case");
{
// Test when scalar is trivially bigger than the ct
let scalar = rng.gen_range(u32::MAX as u64 + 1..=u64::MAX);
Expand Down

0 comments on commit 61136af

Please sign in to comment.