diff --git a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h index c92dad2311..857ab80180 100644 --- a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h +++ b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h @@ -89,7 +89,37 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory, uint32_t lwe_chunk_size = 0); +template +__host__ __device__ uint64_t +get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle( + uint32_t polynomial_size); +template +__host__ __device__ uint64_t +get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one( + uint32_t polynomial_size); +template +__host__ __device__ uint64_t +get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( + uint32_t polynomial_size); +template +__host__ __device__ uint64_t +get_buffer_size_partial_sm_multibit_programmable_bootstrap_step_one( + uint32_t polynomial_size); +template +__host__ __device__ uint64_t +get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( + uint32_t polynomial_size); +template +__host__ __device__ uint64_t +get_buffer_size_partial_sm_cg_multibit_programmable_bootstrap( + uint32_t polynomial_size); + template struct pbs_buffer { + int8_t *d_mem_keybundle = NULL; + int8_t *d_mem_acc_step_one = NULL; + int8_t *d_mem_acc_step_two = NULL; + int8_t *d_mem_acc_cg = NULL; + double2 *keybundle_fft; Torus *global_accumulator; double2 *global_accumulator_fft; @@ -103,31 +133,99 @@ template struct pbs_buffer { this->pbs_variant = pbs_variant; auto max_shared_memory = cuda_get_max_shared_memory(stream->gpu_index); + uint64_t full_sm_keybundle = + get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle< + Torus>(polynomial_size); + uint64_t full_sm_accumulate_step_one = + get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one( + polynomial_size); + uint64_t partial_sm_accumulate_step_one = + get_buffer_size_partial_sm_multibit_programmable_bootstrap_step_one< + Torus>(polynomial_size); + uint64_t full_sm_accumulate_step_two = + get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( + polynomial_size); + uint64_t full_sm_cg_accumulate = + get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( + polynomial_size); + uint64_t partial_sm_cg_accumulate = + get_buffer_size_partial_sm_cg_multibit_programmable_bootstrap( + polynomial_size); + + auto num_blocks_keybundle = input_lwe_ciphertext_count * lwe_chunk_size * + (glwe_dimension + 1) * (glwe_dimension + 1) * + level_count; + auto num_blocks_acc_step_one = + level_count * (glwe_dimension + 1) * input_lwe_ciphertext_count; + auto num_blocks_acc_step_two = + input_lwe_ciphertext_count * (glwe_dimension + 1); + auto num_blocks_acc_cg = + level_count * (glwe_dimension + 1) * input_lwe_ciphertext_count; + if (allocate_gpu_memory) { + // Keybundle + if (max_shared_memory < full_sm_keybundle) + d_mem_keybundle = (int8_t *)cuda_malloc_async( + num_blocks_keybundle * full_sm_keybundle, stream); + switch (pbs_variant) { case DEFAULT: + // Accumulator step one + if (max_shared_memory < partial_sm_accumulate_step_one) + d_mem_acc_step_one = (int8_t *)cuda_malloc_async( + num_blocks_acc_step_one * full_sm_accumulate_step_one, stream); + else if (max_shared_memory < full_sm_accumulate_step_one) + d_mem_acc_step_one = (int8_t *)cuda_malloc_async( + num_blocks_acc_step_one * partial_sm_accumulate_step_one, 
stream); + + // Accumulator step two + if (max_shared_memory < full_sm_accumulate_step_two) + d_mem_acc_step_two = (int8_t *)cuda_malloc_async( + num_blocks_acc_step_two * full_sm_accumulate_step_two, stream); + break; case CG: - keybundle_fft = (double2 *)cuda_malloc_async( - input_lwe_ciphertext_count * lwe_chunk_size * level_count * - (glwe_dimension + 1) * (glwe_dimension + 1) * - (polynomial_size / 2) * sizeof(double2), - stream); - global_accumulator = (Torus *)cuda_malloc_async( - input_lwe_ciphertext_count * (glwe_dimension + 1) * - polynomial_size * sizeof(Torus), - stream); - global_accumulator_fft = (double2 *)cuda_malloc_async( - input_lwe_ciphertext_count * (glwe_dimension + 1) * level_count * - (polynomial_size / 2) * sizeof(double2), - stream); + // Accumulator CG + if (max_shared_memory < partial_sm_cg_accumulate) + d_mem_acc_cg = (int8_t *)cuda_malloc_async( + num_blocks_acc_cg * full_sm_cg_accumulate, stream); + else if (max_shared_memory < full_sm_cg_accumulate) + d_mem_acc_cg = (int8_t *)cuda_malloc_async( + num_blocks_acc_cg * partial_sm_cg_accumulate, stream); break; default: PANIC("Cuda error (PBS): unsupported implementation variant.") } + + keybundle_fft = (double2 *)cuda_malloc_async( + num_blocks_keybundle * (polynomial_size / 2) * sizeof(double2), + stream); + global_accumulator = (Torus *)cuda_malloc_async( + num_blocks_acc_step_two * polynomial_size * sizeof(Torus), stream); + global_accumulator_fft = (double2 *)cuda_malloc_async( + num_blocks_acc_step_one * (polynomial_size / 2) * sizeof(double2), + stream); } } void release(cuda_stream_t *stream) { + + if (d_mem_keybundle) + cuda_drop_async(d_mem_keybundle, stream); + switch (pbs_variant) { + case DEFAULT: + if (d_mem_acc_step_one) + cuda_drop_async(d_mem_acc_step_one, stream); + if (d_mem_acc_step_two) + cuda_drop_async(d_mem_acc_step_two, stream); + break; + case CG: + if (d_mem_acc_cg) + cuda_drop_async(d_mem_acc_cg, stream); + break; + default: + PANIC("Cuda error (PBS): unsupported implementation variant.") + } + cuda_drop_async(keybundle_fft, stream); cuda_drop_async(global_accumulator, stream); cuda_drop_async(global_accumulator_fft, stream); diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh index 8fbec54c68..30025ac627 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh @@ -16,7 +16,7 @@ #include "types/complex/operations.cuh" #include -template +template __global__ void device_multi_bit_programmable_bootstrap_cg_accumulate( Torus *lwe_array_out, Torus *lwe_output_indexes, Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in, Torus *lwe_input_indexes, @@ -24,7 +24,8 @@ __global__ void device_multi_bit_programmable_bootstrap_cg_accumulate( uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset, uint32_t lwe_chunk_size, - uint32_t keybundle_size_per_input) { + uint32_t keybundle_size_per_input, int8_t *device_mem, + uint64_t device_memory_size_per_block) { grid_group grid = this_grid(); @@ -34,14 +35,21 @@ __global__ void device_multi_bit_programmable_bootstrap_cg_accumulate( extern __shared__ int8_t sharedmem[]; int8_t *selected_memory; - selected_memory = sharedmem; + if constexpr (SMD == FULLSM) { + selected_memory = sharedmem; + } else { 
+ int block_index = blockIdx.x + blockIdx.y * gridDim.x + + blockIdx.z * gridDim.x * gridDim.y; + selected_memory = &device_mem[block_index * device_memory_size_per_block]; + } - // We always compute the pointer with most restrictive alignment to avoid - // alignment issues - double2 *accumulator_fft = (double2 *)selected_memory; - Torus *accumulator = - (Torus *)accumulator_fft + - (ptrdiff_t)(sizeof(double2) * polynomial_size / 2 / sizeof(Torus)); + Torus *accumulator = (Torus *)selected_memory; + double2 *accumulator_fft = + (double2 *)accumulator + + (ptrdiff_t)(sizeof(Torus) * polynomial_size / sizeof(double2)); + + if constexpr (SMD == PARTIALSM) + accumulator_fft = (double2 *)sharedmem; // The third dimension of the block is used to determine on which ciphertext // this block is operating, in the case of batch bootstraps @@ -126,6 +134,12 @@ __global__ void device_multi_bit_programmable_bootstrap_cg_accumulate( } } +template +__host__ __device__ uint64_t +get_buffer_size_partial_sm_cg_multibit_programmable_bootstrap( + uint32_t polynomial_size) { + return sizeof(Torus) * polynomial_size; // accumulator +} template __host__ __device__ uint64_t get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( @@ -166,25 +180,64 @@ __host__ void scratch_cg_multi_bit_programmable_bootstrap( uint64_t full_sm_keybundle = get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle( polynomial_size); - uint64_t full_sm_accumulate = + uint64_t full_sm_cg_accumulate = get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( polynomial_size); + uint64_t partial_sm_cg_accumulate = + get_buffer_size_partial_sm_cg_multibit_programmable_bootstrap( + polynomial_size); - check_cuda_error(cudaFuncSetAttribute( - device_multi_bit_programmable_bootstrap_keybundle, - cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_keybundle)); - cudaFuncSetCacheConfig( - device_multi_bit_programmable_bootstrap_keybundle, - cudaFuncCachePreferShared); - check_cuda_error(cudaGetLastError()); - - check_cuda_error(cudaFuncSetAttribute( - device_multi_bit_programmable_bootstrap_cg_accumulate, - cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_accumulate)); - cudaFuncSetCacheConfig( - device_multi_bit_programmable_bootstrap_cg_accumulate, - cudaFuncCachePreferShared); - check_cuda_error(cudaGetLastError()); + if (max_shared_memory < full_sm_keybundle) { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_keybundle, + cudaFuncAttributeMaxDynamicSharedMemorySize, 0)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_keybundle, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } else { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_keybundle, + cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_keybundle)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_keybundle, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } + + if (max_shared_memory < partial_sm_cg_accumulate) { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_cg_accumulate, + cudaFuncAttributeMaxDynamicSharedMemorySize, 0)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_cg_accumulate, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } else if (max_shared_memory < full_sm_cg_accumulate) { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_cg_accumulate, + cudaFuncAttributeMaxDynamicSharedMemorySize, 
partial_sm_cg_accumulate)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_cg_accumulate, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } else { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_cg_accumulate, + cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_cg_accumulate)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_cg_accumulate, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } if (!lwe_chunk_size) lwe_chunk_size = get_average_lwe_chunk_size( @@ -195,41 +248,37 @@ __host__ void scratch_cg_multi_bit_programmable_bootstrap( allocate_gpu_memory); } -template -__host__ void host_cg_multi_bit_programmable_bootstrap( - cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_output_indexes, - Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in, - Torus *lwe_input_indexes, uint64_t *bootstrapping_key, - pbs_buffer *pbs_buffer, uint32_t glwe_dimension, - uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples, - uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory, - uint32_t lwe_chunk_size = 0) { - cudaSetDevice(stream->gpu_index); - - if (!lwe_chunk_size) - lwe_chunk_size = get_average_lwe_chunk_size(lwe_dimension, level_count, - glwe_dimension, num_samples); - - // - double2 *keybundle_fft = pbs_buffer->keybundle_fft; - Torus *global_accumulator = pbs_buffer->global_accumulator; - double2 *buffer_fft = pbs_buffer->global_accumulator_fft; - - // - uint64_t full_sm_keybundle = - get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle( - polynomial_size); - uint64_t full_sm_accumulate = +template +__host__ void execute_external_product_loop( + cuda_stream_t *stream, Torus *lut_vector, Torus *lut_vector_indexes, + Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *lwe_array_out, + Torus *lwe_output_indexes, pbs_buffer *buffer, + uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, + uint32_t level_count, uint32_t lwe_chunk_size, uint32_t max_shared_memory, + int lwe_offset) { + + uint64_t full_dm = get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( polynomial_size); + uint64_t partial_dm = + get_buffer_size_partial_sm_cg_multibit_programmable_bootstrap( + polynomial_size); + uint64_t no_dm = 0; uint32_t keybundle_size_per_input = lwe_chunk_size * level_count * (glwe_dimension + 1) * (glwe_dimension + 1) * (polynomial_size / 2); - // - void *kernel_args[18]; + uint32_t chunk_size = + std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); + + auto d_mem = buffer->d_mem_acc_cg; + auto keybundle_fft = buffer->keybundle_fft; + auto global_accumulator = buffer->global_accumulator; + auto buffer_fft = buffer->global_accumulator_fft; + + void *kernel_args[20]; kernel_args[0] = &lwe_array_out; kernel_args[1] = &lwe_output_indexes; kernel_args[2] = &lut_vector; @@ -245,38 +294,68 @@ __host__ void host_cg_multi_bit_programmable_bootstrap( kernel_args[12] = &base_log; kernel_args[13] = &level_count; kernel_args[14] = &grouping_factor; + kernel_args[15] = &lwe_offset; + kernel_args[16] = &chunk_size; kernel_args[17] = &keybundle_size_per_input; + kernel_args[18] = &d_mem; - // dim3 grid_accumulate(level_count, glwe_dimension + 1, num_samples); dim3 thds(polynomial_size / params::opt, 1, 1); - for (uint32_t lwe_offset = 0; lwe_offset < 
(lwe_dimension / grouping_factor); - lwe_offset += lwe_chunk_size) { + if (max_shared_memory < partial_dm) { + kernel_args[19] = &full_dm; + check_cuda_error(cudaLaunchCooperativeKernel( + (void *)device_multi_bit_programmable_bootstrap_cg_accumulate< + Torus, params, NOSM>, + grid_accumulate, thds, (void **)kernel_args, 0, stream->stream)); + } else if (max_shared_memory < full_dm) { + kernel_args[19] = &partial_dm; + check_cuda_error(cudaLaunchCooperativeKernel( + (void *)device_multi_bit_programmable_bootstrap_cg_accumulate< + Torus, params, PARTIALSM>, + grid_accumulate, thds, (void **)kernel_args, partial_dm, + stream->stream)); + } else { + kernel_args[19] = &no_dm; + check_cuda_error(cudaLaunchCooperativeKernel( + (void *)device_multi_bit_programmable_bootstrap_cg_accumulate< + Torus, params, FULLSM>, + grid_accumulate, thds, (void **)kernel_args, full_dm, stream->stream)); + } +} - uint32_t chunk_size = std::min( - lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); +template +__host__ void host_cg_multi_bit_programmable_bootstrap( + cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_output_indexes, + Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in, + Torus *lwe_input_indexes, uint64_t *bootstrapping_key, + pbs_buffer *buffer, uint32_t glwe_dimension, + uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory, + uint32_t lwe_chunk_size = 0) { + cudaSetDevice(stream->gpu_index); - // Compute a keybundle - dim3 grid_keybundle(num_samples * chunk_size, - (glwe_dimension + 1) * (glwe_dimension + 1), - level_count); - device_multi_bit_programmable_bootstrap_keybundle - <<stream>>>( - lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key, - lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, - base_log, level_count, lwe_offset, chunk_size, - keybundle_size_per_input); - check_cuda_error(cudaGetLastError()); + if (!lwe_chunk_size) + lwe_chunk_size = get_average_lwe_chunk_size(lwe_dimension, level_count, + glwe_dimension, num_samples); - kernel_args[15] = &lwe_offset; - kernel_args[16] = &chunk_size; + for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); + lwe_offset += lwe_chunk_size) { - check_cuda_error(cudaLaunchCooperativeKernel( - (void *)device_multi_bit_programmable_bootstrap_cg_accumulate, - grid_accumulate, thds, (void **)kernel_args, full_sm_accumulate, - stream->stream)); + // Compute a keybundle + execute_compute_keybundle( + stream, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, + num_samples, lwe_dimension, glwe_dimension, polynomial_size, + grouping_factor, base_log, level_count, max_shared_memory, + lwe_chunk_size, lwe_offset); + + // Accumulate + execute_external_product_loop( + stream, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, + lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, + glwe_dimension, polynomial_size, grouping_factor, base_log, level_count, + lwe_chunk_size, max_shared_memory, lwe_offset); } } @@ -291,9 +370,12 @@ __host__ bool verify_cuda_programmable_bootstrap_cg_multi_bit_grid_size( return false; // Calculate the dimension of the kernel - uint64_t full_sm = + uint64_t full_sm_cg_accumulate = get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( params::degree); + uint64_t partial_sm_cg_accumulate = + get_buffer_size_partial_sm_cg_multibit_programmable_bootstrap( + 
params::degree); int thds = params::degree / params::opt; @@ -301,11 +383,25 @@ __host__ bool verify_cuda_programmable_bootstrap_cg_multi_bit_grid_size( int number_of_blocks = level_count * (glwe_dimension + 1) * num_samples; int max_active_blocks_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks_per_sm, - (void *) - device_multi_bit_programmable_bootstrap_cg_accumulate, - thds, full_sm); + if (max_shared_memory < partial_sm_cg_accumulate) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_per_sm, + (void *)device_multi_bit_programmable_bootstrap_cg_accumulate< + Torus, params, NOSM>, + thds, 0); + } else if (max_shared_memory < full_sm_cg_accumulate) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_per_sm, + (void *)device_multi_bit_programmable_bootstrap_cg_accumulate< + Torus, params, PARTIALSM>, + thds, partial_sm_cg_accumulate); + } else { + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_per_sm, + (void *)device_multi_bit_programmable_bootstrap_cg_accumulate< + Torus, params, FULLSM>, + thds, full_sm_cg_accumulate); + } // Get the number of streaming multiprocessors int number_of_sm = 0; diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh index 7fb81fbdde..93666d21b5 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh @@ -32,17 +32,26 @@ __device__ Torus calculates_monomial_degree(Torus *lwe_array_group, x, 2 * params::degree); // 2 * params::log2_degree + 1); } -template +template __global__ void device_multi_bit_programmable_bootstrap_keybundle( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *keybundle_array, Torus *bootstrapping_key, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size, - uint32_t keybundle_size_per_input) { + uint32_t keybundle_size_per_input, int8_t *device_mem, + uint64_t device_memory_size_per_block) { extern __shared__ int8_t sharedmem[]; int8_t *selected_memory = sharedmem; + if constexpr (SMD == FULLSM) { + selected_memory = sharedmem; + } else { + int block_index = blockIdx.x + blockIdx.y * gridDim.x + + blockIdx.z * gridDim.x * gridDim.y; + selected_memory = &device_mem[block_index * device_memory_size_per_block]; + } + // Ids uint32_t level_id = blockIdx.z; uint32_t glwe_id = blockIdx.y / (glwe_dimension + 1); @@ -99,7 +108,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle( synchronize_threads_in_block(); - double2 *fft = (double2 *)sharedmem; + double2 *fft = (double2 *)selected_memory; // Move accumulator to local memory double2 temp[params::opt / 2]; @@ -136,13 +145,14 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle( } } -template +template __global__ void device_multi_bit_programmable_bootstrap_accumulate_step_one( Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *lut_vector, Torus *lut_vector_indexes, Torus *global_accumulator, double2 *global_accumulator_fft, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t lwe_iteration) { + uint32_t level_count, uint32_t lwe_iteration, int8_t *device_mem, + uint64_t device_memory_size_per_block) { // We use shared memory for the 
polynomials that are used often during the // bootstrap, since shared memory is kept in L1 cache and accessing it is @@ -152,11 +162,22 @@ __global__ void device_multi_bit_programmable_bootstrap_accumulate_step_one( selected_memory = sharedmem; + if constexpr (SMD == FULLSM) { + selected_memory = sharedmem; + } else { + int block_index = blockIdx.x + blockIdx.y * gridDim.x + + blockIdx.z * gridDim.x * gridDim.y; + selected_memory = &device_mem[block_index * device_memory_size_per_block]; + } + Torus *accumulator = (Torus *)selected_memory; double2 *accumulator_fft = (double2 *)accumulator + (ptrdiff_t)(sizeof(Torus) * polynomial_size / sizeof(double2)); + if constexpr (SMD == PARTIALSM) + accumulator_fft = (double2 *)sharedmem; + Torus *block_lwe_array_in = &lwe_array_in[lwe_input_indexes[blockIdx.z] * (lwe_dimension + 1)]; @@ -219,13 +240,14 @@ __global__ void device_multi_bit_programmable_bootstrap_accumulate_step_one( accumulator_fft, global_fft_slice); } -template +template __global__ void device_multi_bit_programmable_bootstrap_accumulate_step_two( Torus *lwe_array_out, Torus *lwe_output_indexes, double2 *keybundle_array, Torus *global_accumulator, double2 *global_accumulator_fft, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t grouping_factor, uint32_t iteration, - uint32_t lwe_offset, uint32_t lwe_chunk_size) { + uint32_t lwe_offset, uint32_t lwe_chunk_size, int8_t *device_mem, + uint64_t device_memory_size_per_block) { // We use shared memory for the polynomials that are used often during the // bootstrap, since shared memory is kept in L1 cache and accessing it is // much faster than global memory @@ -233,8 +255,18 @@ __global__ void device_multi_bit_programmable_bootstrap_accumulate_step_two( int8_t *selected_memory; selected_memory = sharedmem; + + if constexpr (SMD == FULLSM) { + selected_memory = sharedmem; + } else { + int block_index = blockIdx.x + blockIdx.y * gridDim.x + + blockIdx.z * gridDim.x * gridDim.y; + selected_memory = &device_mem[block_index * device_memory_size_per_block]; + } + double2 *accumulator_fft = (double2 *)selected_memory; + // double2 *keybundle = keybundle_array + // select the input blockIdx.x * lwe_chunk_size * level_count * @@ -299,7 +331,6 @@ get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle( uint32_t polynomial_size) { return sizeof(Torus) * polynomial_size; // accumulator } - template __host__ __device__ uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one( @@ -308,6 +339,12 @@ get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one( } template __host__ __device__ uint64_t +get_buffer_size_partial_sm_multibit_programmable_bootstrap_step_one( + uint32_t polynomial_size) { + return sizeof(Torus) * polynomial_size; // accumulator +} +template +__host__ __device__ uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( uint32_t polynomial_size) { return sizeof(Torus) * polynomial_size; // accumulator @@ -350,36 +387,86 @@ __host__ void scratch_multi_bit_programmable_bootstrap( uint64_t full_sm_accumulate_step_two = get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( polynomial_size); + uint64_t partial_sm_accumulate_step_one = + get_buffer_size_partial_sm_multibit_programmable_bootstrap_step_one< + Torus>(polynomial_size); + + if (max_shared_memory < full_sm_keybundle) { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_keybundle, + cudaFuncAttributeMaxDynamicSharedMemorySize, 0)); 
+ cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_keybundle, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } else { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_keybundle, + cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_keybundle)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_keybundle, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } - check_cuda_error(cudaFuncSetAttribute( - device_multi_bit_programmable_bootstrap_keybundle, - cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_keybundle)); - cudaFuncSetCacheConfig( - device_multi_bit_programmable_bootstrap_keybundle, - cudaFuncCachePreferShared); - check_cuda_error(cudaGetLastError()); - - check_cuda_error(cudaFuncSetAttribute( - device_multi_bit_programmable_bootstrap_accumulate_step_one, - cudaFuncAttributeMaxDynamicSharedMemorySize, - full_sm_accumulate_step_one)); - cudaFuncSetCacheConfig( - device_multi_bit_programmable_bootstrap_accumulate_step_one, - cudaFuncCachePreferShared); - check_cuda_error(cudaGetLastError()); + if (max_shared_memory < partial_sm_accumulate_step_one) { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_accumulate_step_one< + Torus, params, NOSM>, + cudaFuncAttributeMaxDynamicSharedMemorySize, 0)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_accumulate_step_one< + Torus, params, NOSM>, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } else if (max_shared_memory < full_sm_accumulate_step_one) { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_accumulate_step_one< + Torus, params, PARTIALSM>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + partial_sm_accumulate_step_one)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_accumulate_step_one< + Torus, params, PARTIALSM>, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } else { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_accumulate_step_one< + Torus, params, FULLSM>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + full_sm_accumulate_step_one)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_accumulate_step_one< + Torus, params, FULLSM>, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } - check_cuda_error(cudaFuncSetAttribute( - device_multi_bit_programmable_bootstrap_accumulate_step_two, - cudaFuncAttributeMaxDynamicSharedMemorySize, - full_sm_accumulate_step_two)); - cudaFuncSetCacheConfig( - device_multi_bit_programmable_bootstrap_accumulate_step_two, - cudaFuncCachePreferShared); - check_cuda_error(cudaGetLastError()); + if (max_shared_memory < full_sm_accumulate_step_two) { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_accumulate_step_two< + Torus, params, NOSM>, + cudaFuncAttributeMaxDynamicSharedMemorySize, 0)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_accumulate_step_two< + Torus, params, NOSM>, + cudaFuncCachePreferShared); + check_cuda_error(cudaGetLastError()); + } else { + check_cuda_error(cudaFuncSetAttribute( + device_multi_bit_programmable_bootstrap_accumulate_step_two< + Torus, params, FULLSM>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + full_sm_accumulate_step_two)); + cudaFuncSetCacheConfig( + device_multi_bit_programmable_bootstrap_accumulate_step_two< + Torus, params, FULLSM>, + 
        cudaFuncCachePreferShared);
+    check_cuda_error(cudaGetLastError());
+  }
 
   if (!lwe_chunk_size)
     lwe_chunk_size = get_average_lwe_chunk_size(
@@ -390,84 +477,186 @@ __host__ void scratch_multi_bit_programmable_bootstrap(
       allocate_gpu_memory);
 }
 
-template
-__host__ void host_multi_bit_programmable_bootstrap(
-    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_output_indexes,
-    Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in,
-    Torus *lwe_input_indexes, uint64_t *bootstrapping_key,
-    pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t glwe_dimension,
-    uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
-    uint32_t base_log, uint32_t level_count, uint32_t num_samples,
-    uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory,
-    uint32_t lwe_chunk_size = 0) {
-  cudaSetDevice(stream->gpu_index);
+template <typename Torus, class params>
+__host__ void execute_compute_keybundle(
+    cuda_stream_t *stream, Torus *lwe_array_in, Torus *lwe_input_indexes,
+    Torus *bootstrapping_key, pbs_buffer<Torus, MULTI_BIT> *buffer,
+    uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log,
+    uint32_t level_count, uint32_t max_shared_memory, uint32_t lwe_chunk_size,
+    int lwe_offset) {
 
-  // If a chunk size is not passed to this function, select one.
-  if (!lwe_chunk_size)
-    lwe_chunk_size = get_average_lwe_chunk_size(lwe_dimension, level_count,
-                                                glwe_dimension, num_samples);
-  //
-  double2 *keybundle_fft = buffer->keybundle_fft;
-  Torus *global_accumulator = buffer->global_accumulator;
-  double2 *global_accumulator_fft = buffer->global_accumulator_fft;
+  uint32_t chunk_size =
+      std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+
+  uint32_t keybundle_size_per_input =
+      lwe_chunk_size * level_count * (glwe_dimension + 1) *
+      (glwe_dimension + 1) * (polynomial_size / 2);
 
-  //
   uint64_t full_sm_keybundle =
       get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle<Torus>(
          polynomial_size);
+
+  auto d_mem = buffer->d_mem_keybundle;
+  auto keybundle_fft = buffer->keybundle_fft;
+
+  // Compute a keybundle
+  dim3 grid_keybundle(num_samples * chunk_size,
+                      (glwe_dimension + 1) * (glwe_dimension + 1), level_count);
+  dim3 thds(polynomial_size / params::opt, 1, 1);
+
+  if (max_shared_memory < full_sm_keybundle)
+    device_multi_bit_programmable_bootstrap_keybundle<Torus, params, NOSM>
+        <<<grid_keybundle, thds, 0, stream->stream>>>(
+            lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
+            lwe_dimension, glwe_dimension, polynomial_size, grouping_factor,
+            base_log, level_count, lwe_offset, chunk_size,
+            keybundle_size_per_input, d_mem, full_sm_keybundle);
+  else
+    device_multi_bit_programmable_bootstrap_keybundle<Torus, params, FULLSM>
+        <<<grid_keybundle, thds, full_sm_keybundle, stream->stream>>>(
+            lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
+            lwe_dimension, glwe_dimension, polynomial_size, grouping_factor,
+            base_log, level_count, lwe_offset, chunk_size,
+            keybundle_size_per_input, d_mem, 0);
+  check_cuda_error(cudaGetLastError());
+}
+
+template <typename Torus, class params>
+__host__ void
+execute_step_one(cuda_stream_t *stream, Torus *lut_vector,
+                 Torus *lut_vector_indexes, Torus *lwe_array_in,
+                 Torus *lwe_input_indexes, pbs_buffer<Torus, MULTI_BIT> *buffer,
+                 uint32_t num_samples, uint32_t lwe_dimension,
+                 uint32_t glwe_dimension, uint32_t polynomial_size,
+                 uint32_t base_log, uint32_t level_count,
+                 uint32_t max_shared_memory, int j, int lwe_offset) {
+
   uint64_t full_sm_accumulate_step_one =
       get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one<Torus>(
          polynomial_size);
+  uint64_t partial_sm_accumulate_step_one =
+      get_buffer_size_partial_sm_multibit_programmable_bootstrap_step_one<
+          Torus>(polynomial_size);
+
+  //
+  auto d_mem = buffer->d_mem_acc_step_one;
+  auto global_accumulator = buffer->global_accumulator;
+  auto global_accumulator_fft = buffer->global_accumulator_fft;
+
+  dim3 grid_accumulate_step_one(level_count, glwe_dimension + 1, num_samples);
+  dim3 thds(polynomial_size / params::opt, 1, 1);
+
+  if (max_shared_memory < partial_sm_accumulate_step_one)
+    device_multi_bit_programmable_bootstrap_accumulate_step_one<Torus, params,
+                                                                NOSM>
+        <<<grid_accumulate_step_one, thds, 0, stream->stream>>>(
+            lwe_array_in, lwe_input_indexes, lut_vector, lut_vector_indexes,
+            global_accumulator, global_accumulator_fft, lwe_dimension,
+            glwe_dimension, polynomial_size, base_log, level_count,
+            j + lwe_offset, d_mem, full_sm_accumulate_step_one);
+  else if (max_shared_memory < full_sm_accumulate_step_one)
+    device_multi_bit_programmable_bootstrap_accumulate_step_one<Torus, params,
+                                                                PARTIALSM>
+        <<<grid_accumulate_step_one, thds, partial_sm_accumulate_step_one,
+           stream->stream>>>(
+            lwe_array_in, lwe_input_indexes, lut_vector, lut_vector_indexes,
+            global_accumulator, global_accumulator_fft, lwe_dimension,
+            glwe_dimension, polynomial_size, base_log, level_count,
+            j + lwe_offset, d_mem, partial_sm_accumulate_step_one);
+  else
+    device_multi_bit_programmable_bootstrap_accumulate_step_one<Torus, params,
+                                                                FULLSM>
+        <<<grid_accumulate_step_one, thds, full_sm_accumulate_step_one,
+           stream->stream>>>(lwe_array_in, lwe_input_indexes, lut_vector,
+                             lut_vector_indexes, global_accumulator,
+                             global_accumulator_fft, lwe_dimension,
+                             glwe_dimension, polynomial_size, base_log,
+                             level_count, j + lwe_offset, d_mem, 0);
+  check_cuda_error(cudaGetLastError());
+}
+
+template <typename Torus, class params>
+__host__ void execute_step_two(
+    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_output_indexes,
+    pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
+    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
+    int32_t grouping_factor, uint32_t level_count, uint32_t max_shared_memory,
+    int j, int lwe_offset, uint32_t lwe_chunk_size) {
+
   uint64_t full_sm_accumulate_step_two =
       get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
          polynomial_size);
 
-  uint32_t keybundle_size_per_input =
-      lwe_chunk_size * level_count * (glwe_dimension + 1) *
-      (glwe_dimension + 1) * (polynomial_size / 2);
-  //
-  dim3 grid_accumulate_step_one(level_count, glwe_dimension + 1, num_samples);
+  auto d_mem = buffer->d_mem_acc_step_two;
+  auto keybundle_fft = buffer->keybundle_fft;
+  auto global_accumulator = buffer->global_accumulator;
+  auto global_accumulator_fft = buffer->global_accumulator_fft;
+
+  dim3 grid_accumulate_step_two(num_samples, glwe_dimension + 1);
   dim3 thds(polynomial_size / params::opt, 1, 1);
 
+  if (max_shared_memory < full_sm_accumulate_step_two)
+    device_multi_bit_programmable_bootstrap_accumulate_step_two<Torus, params,
+                                                                NOSM>
+        <<<grid_accumulate_step_two, thds, 0, stream->stream>>>(
+            lwe_array_out, lwe_output_indexes, keybundle_fft,
+            global_accumulator, global_accumulator_fft, lwe_dimension,
+            glwe_dimension, polynomial_size, level_count, grouping_factor, j,
+            lwe_offset, lwe_chunk_size, d_mem, full_sm_accumulate_step_two);
+  else
+    device_multi_bit_programmable_bootstrap_accumulate_step_two<Torus, params,
+                                                                FULLSM>
+        <<<grid_accumulate_step_two, thds, full_sm_accumulate_step_two,
+           stream->stream>>>(lwe_array_out, lwe_output_indexes, keybundle_fft,
+                             global_accumulator, global_accumulator_fft,
+                             lwe_dimension, glwe_dimension, polynomial_size,
+                             level_count, grouping_factor, j, lwe_offset,
+                             lwe_chunk_size, d_mem, 0);
+  check_cuda_error(cudaGetLastError());
+}
+
+template <typename Torus, class params>
+__host__ void host_multi_bit_programmable_bootstrap(
+    cuda_stream_t *stream, Torus *lwe_array_out, Torus *lwe_output_indexes,
+    Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in,
+    Torus *lwe_input_indexes, Torus *bootstrapping_key,
+    pbs_buffer<Torus, MULTI_BIT> *buffer,
uint32_t glwe_dimension, + uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory, + uint32_t lwe_chunk_size = 0) { + cudaSetDevice(stream->gpu_index); + + // If a chunk size is not passed to this function, select one. + if (!lwe_chunk_size) + lwe_chunk_size = get_average_lwe_chunk_size(lwe_dimension, level_count, + glwe_dimension, num_samples); + for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); lwe_offset += lwe_chunk_size) { - uint32_t chunk_size = std::min( - lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); - // Compute a keybundle - dim3 grid_keybundle(num_samples * chunk_size, - (glwe_dimension + 1) * (glwe_dimension + 1), - level_count); - device_multi_bit_programmable_bootstrap_keybundle - <<stream>>>( - lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key, - lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, - base_log, level_count, lwe_offset, chunk_size, - keybundle_size_per_input); - check_cuda_error(cudaGetLastError()); - + execute_compute_keybundle( + stream, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, + num_samples, lwe_dimension, glwe_dimension, polynomial_size, + grouping_factor, base_log, level_count, max_shared_memory, + lwe_chunk_size, lwe_offset); // Accumulate + uint32_t chunk_size = std::min( + lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); for (int j = 0; j < chunk_size; j++) { - device_multi_bit_programmable_bootstrap_accumulate_step_one - <<stream>>>(lwe_array_in, lwe_input_indexes, lut_vector, - lut_vector_indexes, global_accumulator, - global_accumulator_fft, lwe_dimension, - glwe_dimension, polynomial_size, base_log, - level_count, j + lwe_offset); - check_cuda_error(cudaGetLastError()); - - device_multi_bit_programmable_bootstrap_accumulate_step_two - <<stream>>>(lwe_array_out, lwe_output_indexes, keybundle_fft, - global_accumulator, global_accumulator_fft, - lwe_dimension, glwe_dimension, polynomial_size, - level_count, grouping_factor, j, lwe_offset, - lwe_chunk_size); - check_cuda_error(cudaGetLastError()); + execute_step_one( + stream, lut_vector, lut_vector_indexes, lwe_array_in, + lwe_input_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, + polynomial_size, base_log, level_count, max_shared_memory, j, + lwe_offset); + + execute_step_two( + stream, lwe_array_out, lwe_output_indexes, buffer, num_samples, + lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, + level_count, max_shared_memory, j, lwe_offset, lwe_chunk_size); } } } diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp index 0e3c219c52..1d9b9746de 100644 --- a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp @@ -193,10 +193,36 @@ ::testing::internal::ParamGenerator new_gaussian_from_std_dev(sqrt(4.9571231961752025e-12)), new_gaussian_from_std_dev(sqrt(9.9409770026944e-32)), 21, 1, 2, 2, 1, 3, 1, 10}, + (MultiBitProgrammableBootstrapTestParams){ + 888, 1, 16384, + new_gaussian_from_std_dev(sqrt(4.9571231961752025e-12)), + new_gaussian_from_std_dev(sqrt(9.9409770026944e-32)), 21, 1, 2, 2, + 1, 3, 1, 10}, + + (MultiBitProgrammableBootstrapTestParams){ + 888, 1, 1024, + 
new_gaussian_from_std_dev(sqrt(4.9571231961752025e-12)), + new_gaussian_from_std_dev(sqrt(9.9409770026944e-32)), 21, 1, 2, 2, + 128, 3, 1, 10}, (MultiBitProgrammableBootstrapTestParams){ 888, 1, 2048, new_gaussian_from_std_dev(sqrt(4.9571231961752025e-12)), new_gaussian_from_std_dev(sqrt(9.9409770026944e-32)), 21, 1, 2, 2, + 128, 3, 1, 10}, + (MultiBitProgrammableBootstrapTestParams){ + 888, 1, 4096, + new_gaussian_from_std_dev(sqrt(4.9571231961752025e-12)), + new_gaussian_from_std_dev(sqrt(9.9409770026944e-32)), 21, 1, 2, 2, + 128, 3, 1, 10}, + (MultiBitProgrammableBootstrapTestParams){ + 888, 1, 8192, + new_gaussian_from_std_dev(sqrt(4.9571231961752025e-12)), + new_gaussian_from_std_dev(sqrt(9.9409770026944e-32)), 21, 1, 2, 2, + 128, 3, 1, 10}, + (MultiBitProgrammableBootstrapTestParams){ + 888, 1, 16384, + new_gaussian_from_std_dev(sqrt(4.9571231961752025e-12)), + new_gaussian_from_std_dev(sqrt(9.9409770026944e-32)), 21, 1, 2, 2, 128, 3, 1, 10}); std::string printParamName( ::testing::TestParamInfo p) {
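
Note on the pattern this patch applies throughout: the kernels now take a sharedMemDegree template parameter (FULLSM, PARTIALSM or NOSM), and the host picks a variant per launch by comparing the device's dynamic shared-memory limit (max_shared_memory) against the scratch size one block needs. When shared memory cannot hold the scratch, pbs_buffer allocates a global-memory fallback of num_blocks * bytes_per_block and each block addresses its own slice through the flattened block index. The sketch below is an illustrative reduction of that idea and is not part of the patch: the names (toy_accumulate, run_with_fallback, bytes_per_block) are hypothetical, the PARTIALSM middle tier (only part of the scratch fits in shared memory) is omitted for brevity, and C++17 is assumed for if constexpr.

#include <cstdint>
#include <cuda_runtime.h>

enum sharedMemDegree { NOSM, PARTIALSM, FULLSM };

// Each block needs `bytes_per_block` bytes of scratch. FULLSM takes it from
// dynamic shared memory; NOSM takes a per-block slice of a global buffer,
// addressed by the flattened block index (same formula as in the patch).
template <sharedMemDegree SMD>
__global__ void toy_accumulate(float *out, int8_t *device_mem,
                               uint64_t bytes_per_block) {
  extern __shared__ int8_t sharedmem[];
  int8_t *selected_memory;
  if constexpr (SMD == FULLSM) {
    selected_memory = sharedmem;
  } else {
    int block_index = blockIdx.x + blockIdx.y * gridDim.x +
                      blockIdx.z * gridDim.x * gridDim.y;
    selected_memory = &device_mem[block_index * bytes_per_block];
  }
  float *scratch = (float *)selected_memory;

  scratch[threadIdx.x] = (float)threadIdx.x;
  __syncthreads();
  // `out` must hold gridDim.x * blockDim.x floats.
  out[blockIdx.x * blockDim.x + threadIdx.x] = 2.0f * scratch[threadIdx.x];
}

// Host side: pick the variant from the device limit and allocate the global
// fallback only when shared memory is too small, mirroring pbs_buffer.
void run_with_fallback(float *d_out, int blocks, int threads) {
  uint64_t bytes_per_block = (uint64_t)threads * sizeof(float);

  int device = 0, max_shared = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&max_shared, cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);

  int8_t *d_mem = nullptr;
  if ((uint64_t)max_shared < bytes_per_block) {
    // NOSM: one scratch slice per block in global memory, and 0 bytes of
    // dynamic shared memory at launch.
    cudaMalloc(&d_mem, (uint64_t)blocks * bytes_per_block);
    toy_accumulate<NOSM>
        <<<blocks, threads, 0>>>(d_out, d_mem, bytes_per_block);
  } else {
    // FULLSM: opt in to the required dynamic shared memory, no fallback.
    cudaFuncSetAttribute(toy_accumulate<FULLSM>,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         (int)bytes_per_block);
    toy_accumulate<FULLSM>
        <<<blocks, threads, bytes_per_block>>>(d_out, nullptr, bytes_per_block);
  }
  cudaDeviceSynchronize();
  if (d_mem)
    cudaFree(d_mem);
}

Keeping the scratch-pointer selection inside the kernel via a template parameter, rather than writing separate kernel bodies, is what lets the host choose a tier per launch (as the scratch_* and execute_* helpers above do) without duplicating the accumulation logic.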