Commit

chore(gpu): refactor lwe_chunk_size

agnesLeroy committed Sep 27, 2024
1 parent 45effa4 commit 25e55af
Showing 8 changed files with 57 additions and 71 deletions.
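
The pattern is the same in every file touched: the LWE chunk size used by the multi-bit PBS is now chosen once, when the scratch function builds the `pbs_buffer`, stored as a member of that buffer, and read back wherever it was previously recomputed via `get_lwe_chunk_size` or threaded through parameter lists. A minimal sketch of the pattern, with simplified and partly hypothetical names and types (not the actual tfhe-rs declarations):

    #include <cstdint>

    // Hypothetical stand-in for get_lwe_chunk_size; the real heuristic
    // depends on the GPU, the number of samples and the polynomial size.
    static uint32_t pick_lwe_chunk_size(uint32_t num_samples) {
      return num_samples >= 8 ? 4 : 1;
    }

    struct buffer_sketch {
      uint32_t lwe_chunk_size; // the member this commit adds to pbs_buffer
      explicit buffer_sketch(uint32_t chunk) : lwe_chunk_size(chunk) {}
    };

    // Scratch phase: decide once, store in the buffer.
    buffer_sketch *scratch(uint32_t num_samples) {
      return new buffer_sketch(pick_lwe_chunk_size(num_samples));
    }

    // Host phase: read the stored value instead of recomputing it
    // or receiving it as an extra parameter on every call.
    void host_pbs(buffer_sketch *buffer, uint32_t lwe_dimension,
                  uint32_t grouping_factor) {
      auto lwe_chunk_size = buffer->lwe_chunk_size;
      for (uint32_t lwe_offset = 0;
           lwe_offset < lwe_dimension / grouping_factor;
           lwe_offset += lwe_chunk_size) {
        // compute a keybundle for this chunk, then accumulate ...
      }
    }

One consequence is that the chunk size used to size the keybundle FFT buffer at allocation time and the stride used by the bootstrap loops can no longer drift apart.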
@@ -17,8 +17,7 @@ void cuda_convert_lwe_multi_bit_programmable_bootstrap_key_64(
 
 void scratch_cuda_multi_bit_programmable_bootstrap_64(
     void *stream, uint32_t gpu_index, int8_t **pbs_buffer,
-    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t level_count, uint32_t grouping_factor,
+    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
     uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);
 
 void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64(
@@ -130,7 +129,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
   int8_t *d_mem_acc_step_two = NULL;
   int8_t *d_mem_acc_cg = NULL;
   int8_t *d_mem_acc_tbc = NULL;
-
+  uint32_t lwe_chunk_size;
   double2 *keybundle_fft;
   Torus *global_accumulator;
   double2 *global_accumulator_fft;
@@ -142,6 +141,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
       uint32_t input_lwe_ciphertext_count, uint32_t lwe_chunk_size,
       PBS_VARIANT pbs_variant, bool allocate_gpu_memory) {
     this->pbs_variant = pbs_variant;
+    this->lwe_chunk_size = lwe_chunk_size;
     auto max_shared_memory = cuda_get_max_shared_memory(gpu_index);
 
     // default
@@ -271,9 +271,8 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
     if (grouping_factor == 0)
      PANIC("Multi-bit PBS error: grouping factor should be > 0.")
     scratch_cuda_multi_bit_programmable_bootstrap_64(
-        stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension,
-        polynomial_size, level_count, grouping_factor,
-        input_lwe_ciphertext_count, allocate_gpu_memory);
+        stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size,
+        level_count, input_lwe_ciphertext_count, allocate_gpu_memory);
     break;
   case CLASSICAL:
     scratch_cuda_programmable_bootstrap_64(
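
For callers of the C API, the 64-bit scratch entry point therefore loses `lwe_dimension` and `grouping_factor`. A hedged usage sketch based only on the signature shown above (the stream and parameter variables are assumed to be in scope):

    int8_t *pbs_buffer = nullptr;
    scratch_cuda_multi_bit_programmable_bootstrap_64(
        stream, gpu_index, &pbs_buffer, glwe_dimension, polynomial_size,
        level_count, input_lwe_ciphertext_count,
        /*allocate_gpu_memory=*/true);
    // ... run bootstraps, then release via the matching cleanup function
    // (cleanup_cuda_multi_bit_programmable_bootstrap, shown further down).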
@@ -229,9 +229,9 @@ __host__ void execute_cg_external_product_loop(
     pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
     uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
-    uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count,
-    uint32_t lut_stride) {
+    uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) {
 
+  auto lwe_chunk_size = buffer->lwe_chunk_size;
   uint64_t full_dm =
       get_buffer_size_full_sm_cg_multibit_programmable_bootstrap<Torus>(
           polynomial_size);
@@ -314,8 +314,7 @@ __host__ void host_cg_multi_bit_programmable_bootstrap(
     uint32_t base_log, uint32_t level_count, uint32_t num_samples,
     uint32_t lut_count, uint32_t lut_stride) {
 
-  auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
-      gpu_index, num_samples, polynomial_size);
+  auto lwe_chunk_size = buffer->lwe_chunk_size;
 
   for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
        lwe_offset += lwe_chunk_size) {
@@ -324,15 +323,15 @@
     execute_compute_keybundle<Torus, params>(
         stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
         buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
-        grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset);
+        grouping_factor, level_count, lwe_offset);
 
     // Accumulate
     execute_cg_external_product_loop<Torus, params>(
         stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
         lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer,
         num_samples, lwe_dimension, glwe_dimension, polynomial_size,
-        grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset,
-        lut_count, lut_stride);
+        grouping_factor, base_log, level_count, lwe_offset, lut_count,
+        lut_stride);
   }
 }
 
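The host loops clamp the final chunk (see the `std::min` in `execute_compute_keybundle` and `host_multi_bit_programmable_bootstrap` further down), since `lwe_dimension / grouping_factor` is generally not a multiple of the chunk size. A small self-contained example with hypothetical parameter values:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      // Hypothetical values: 920 / 4 = 230 keybundle columns, chunks of 64.
      uint32_t lwe_dimension = 920, grouping_factor = 4, lwe_chunk_size = 64;
      uint32_t columns = lwe_dimension / grouping_factor; // 230
      for (uint32_t lwe_offset = 0; lwe_offset < columns;
           lwe_offset += lwe_chunk_size) {
        uint32_t chunk_size = std::min(lwe_chunk_size, columns - lwe_offset);
        printf("offset %u -> chunk %u\n", lwe_offset, chunk_size);
        // prints: 0 -> 64, 64 -> 64, 128 -> 64, 192 -> 38
      }
      return 0;
    }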
@@ -331,52 +331,51 @@ void scratch_cuda_cg_multi_bit_programmable_bootstrap(
 template <typename Torus>
 void scratch_cuda_multi_bit_programmable_bootstrap(
     void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **buffer,
-    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t level_count, uint32_t grouping_factor,
+    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
     uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {
 
   switch (polynomial_size) {
   case 256:
     scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<256>>(
-        static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+        static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+        polynomial_size, level_count, input_lwe_ciphertext_count,
+        allocate_gpu_memory);
     break;
   case 512:
     scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<512>>(
-        static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+        static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+        polynomial_size, level_count, input_lwe_ciphertext_count,
+        allocate_gpu_memory);
     break;
   case 1024:
     scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<1024>>(
-        static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+        static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+        polynomial_size, level_count, input_lwe_ciphertext_count,
+        allocate_gpu_memory);
     break;
   case 2048:
     scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<2048>>(
-        static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+        static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+        polynomial_size, level_count, input_lwe_ciphertext_count,
+        allocate_gpu_memory);
     break;
   case 4096:
     scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<4096>>(
-        static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+        static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+        polynomial_size, level_count, input_lwe_ciphertext_count,
+        allocate_gpu_memory);
     break;
   case 8192:
     scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<8192>>(
-        static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+        static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+        polynomial_size, level_count, input_lwe_ciphertext_count,
+        allocate_gpu_memory);
     break;
   case 16384:
     scratch_multi_bit_programmable_bootstrap<Torus, AmortizedDegree<16384>>(
-        static_cast<cudaStream_t>(stream), gpu_index, buffer, lwe_dimension,
-        glwe_dimension, polynomial_size, level_count,
-        input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory);
+        static_cast<cudaStream_t>(stream), gpu_index, buffer, glwe_dimension,
+        polynomial_size, level_count, input_lwe_ciphertext_count,
+        allocate_gpu_memory);
     break;
   default:
     PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. Supported "
@@ -386,10 +385,9 @@ void scratch_cuda_multi_bit_programmable_bootstrap(
 }
 
 void scratch_cuda_multi_bit_programmable_bootstrap_64(
-    void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t lwe_dimension,
-    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
-    uint32_t grouping_factor, uint32_t input_lwe_ciphertext_count,
-    bool allocate_gpu_memory) {
+    void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t level_count,
+    uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {
 
 #if (CUDA_ARCH >= 900)
   if (has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
@@ -411,8 +409,8 @@ void scratch_cuda_multi_bit_programmable_bootstrap_64(
   else
     scratch_cuda_multi_bit_programmable_bootstrap<uint64_t>(
         stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
-        lwe_dimension, glwe_dimension, polynomial_size, level_count,
-        grouping_factor, input_lwe_ciphertext_count, allocate_gpu_memory);
+        glwe_dimension, polynomial_size, level_count,
+        input_lwe_ciphertext_count, allocate_gpu_memory);
 }
 
 void cleanup_cuda_multi_bit_programmable_bootstrap(void *stream,
@@ -490,10 +488,9 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
 
 template void scratch_cuda_multi_bit_programmable_bootstrap<uint64_t>(
     void *stream, uint32_t gpu_index,
-    pbs_buffer<uint64_t, MULTI_BIT> **pbs_buffer, uint32_t lwe_dimension,
-    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
-    uint32_t grouping_factor, uint32_t input_lwe_ciphertext_count,
-    bool allocate_gpu_memory);
+    pbs_buffer<uint64_t, MULTI_BIT> **pbs_buffer, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t level_count,
+    uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);
 
 template void
 cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector<uint64_t>(
@@ -385,10 +385,9 @@ uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two(
 template <typename Torus, typename params>
 __host__ void scratch_multi_bit_programmable_bootstrap(
     cudaStream_t stream, uint32_t gpu_index,
-    pbs_buffer<Torus, MULTI_BIT> **buffer, uint32_t lwe_dimension,
-    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
-    uint32_t input_lwe_ciphertext_count, uint32_t grouping_factor,
-    bool allocate_gpu_memory) {
+    pbs_buffer<Torus, MULTI_BIT> **buffer, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t level_count,
+    uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {
 
   auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
       gpu_index, input_lwe_ciphertext_count, polynomial_size);
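
So `get_lwe_chunk_size` remains the single source of the value, but it is now called only here, at scratch time, and the result is handed to the `pbs_buffer` constructor shown earlier. A hedged sketch of how this scratch function plausibly continues; the leading constructor arguments and the `PBS_VARIANT::DEFAULT` value are assumptions, only the trailing parameters are visible in the constructor hunk above:

    // Sketch, not the verbatim function body:
    auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
        gpu_index, input_lwe_ciphertext_count, polynomial_size);
    *buffer = new pbs_buffer<Torus, MULTI_BIT>(
        stream, gpu_index, glwe_dimension, polynomial_size, level_count,
        input_lwe_ciphertext_count, lwe_chunk_size, PBS_VARIANT::DEFAULT,
        allocate_gpu_memory);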
@@ -404,9 +403,9 @@ __host__ void execute_compute_keybundle(
     Torus *lwe_input_indexes, Torus *bootstrapping_key,
     pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
-    uint32_t lwe_chunk_size, uint32_t lwe_offset) {
+    uint32_t grouping_factor, uint32_t level_count, uint32_t lwe_offset) {
 
+  auto lwe_chunk_size = buffer->lwe_chunk_size;
   uint32_t chunk_size =
       std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
   if (chunk_size == 0)
@@ -507,9 +506,9 @@ __host__ void execute_step_two(
     Torus *lwe_output_indexes, pbs_buffer<Torus, MULTI_BIT> *buffer,
     uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension,
     uint32_t polynomial_size, int32_t grouping_factor, uint32_t level_count,
-    uint32_t j, uint32_t lwe_offset, uint32_t lwe_chunk_size,
-    uint32_t lut_count, uint32_t lut_stride) {
+    uint32_t j, uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) {
 
+  auto lwe_chunk_size = buffer->lwe_chunk_size;
   uint64_t full_sm_accumulate_step_two =
       get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
           polynomial_size);
@@ -555,8 +554,7 @@ __host__ void host_multi_bit_programmable_bootstrap(
     uint32_t base_log, uint32_t level_count, uint32_t num_samples,
     uint32_t lut_count, uint32_t lut_stride) {
 
-  auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
-      gpu_index, num_samples, polynomial_size);
+  auto lwe_chunk_size = buffer->lwe_chunk_size;
 
   for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
        lwe_offset += lwe_chunk_size) {
@@ -565,7 +563,7 @@
     execute_compute_keybundle<Torus, params>(
         stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
         buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
-        grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset);
+        grouping_factor, level_count, lwe_offset);
     // Accumulate
     uint32_t chunk_size = std::min(
         lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
@@ -578,8 +576,7 @@
       execute_step_two<Torus, params>(
           stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer,
           num_samples, lwe_dimension, glwe_dimension, polynomial_size,
-          grouping_factor, level_count, j, lwe_offset, lwe_chunk_size,
-          lut_count, lut_stride);
+          grouping_factor, level_count, j, lwe_offset, lut_count, lut_stride);
     }
   }
 }
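
For the default (non-CG) variant, accumulation is split into two steps per keybundle column, with `execute_step_two` receiving the column index `j`. The overall shape, reconstructed from the visible fragments (the loop bodies are elided in the diff, so this is a sketch of the structure rather than the verbatim code):

    auto lwe_chunk_size = buffer->lwe_chunk_size;
    for (uint32_t lwe_offset = 0;
         lwe_offset < (lwe_dimension / grouping_factor);
         lwe_offset += lwe_chunk_size) {
      // 1. Build the keybundle FFT for this chunk of the bootstrapping key.
      execute_compute_keybundle<Torus, params>(/* ... */);
      // 2. Consume the keybundle one column at a time.
      uint32_t chunk_size = std::min(
          lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
      for (uint32_t j = 0; j < chunk_size; j++) {
        // A matching step-one call presumably precedes this; it lies
        // outside the visible hunks.
        execute_step_two<Torus, params>(/* ..., j, lwe_offset, ... */);
      }
    }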
@@ -220,9 +220,9 @@ __host__ void execute_tbc_external_product_loop(
     pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
     uint32_t grouping_factor, uint32_t base_log, uint32_t level_count,
-    uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count,
-    uint32_t lut_stride) {
+    uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) {
 
+  auto lwe_chunk_size = buffer->lwe_chunk_size;
   auto supports_dsm =
       supports_distributed_shared_memory_on_multibit_programmable_bootstrap<
           Torus>(polynomial_size);
@@ -325,25 +325,23 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap(
     uint32_t lut_count, uint32_t lut_stride) {
   cudaSetDevice(gpu_index);
 
-  auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
-      gpu_index, num_samples, polynomial_size);
-
+  auto lwe_chunk_size = buffer->lwe_chunk_size;
   for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
        lwe_offset += lwe_chunk_size) {
 
     // Compute a keybundle
     execute_compute_keybundle<Torus, params>(
         stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
         buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
-        grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset);
+        grouping_factor, level_count, lwe_offset);
 
     // Accumulate
     execute_tbc_external_product_loop<Torus, params>(
         stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
         lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer,
         num_samples, lwe_dimension, glwe_dimension, polynomial_size,
-        grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset,
-        lut_count, lut_stride);
+        grouping_factor, base_log, level_count, lwe_offset, lut_count,
+        lut_stride);
   }
 }
 
2 changes: 0 additions & 2 deletions backends/tfhe-cuda-backend/src/cuda_bind.rs
@@ -1192,11 +1192,9 @@ extern "C" {
         stream: *mut c_void,
         gpu_index: u32,
         pbs_buffer: *mut *mut i8,
-        lwe_dimension: u32,
         glwe_dimension: u32,
         polynomial_size: u32,
         level_count: u32,
-        grouping_factor: u32,
         input_lwe_ciphertext_count: u32,
         allocate_gpu_memory: bool,
     );
2 changes: 0 additions & 2 deletions tfhe/src/core_crypto/gpu/mod.rs
@@ -180,11 +180,9 @@ pub unsafe fn programmable_bootstrap_multi_bit_async<T: UnsignedInteger>(
         streams.ptr[0],
         streams.gpu_indexes[0],
         std::ptr::addr_of_mut!(pbs_buffer),
-        lwe_dimension.0 as u32,
         glwe_dimension.0 as u32,
         polynomial_size.0 as u32,
         level.0 as u32,
-        grouping_factor.0 as u32,
         num_samples,
         true,
     );
