diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index 872ec7810b..d6b56ab5e5 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -80,6 +80,11 @@ void cleanup_cuda_apply_bivariate_lut_kb_64(void **streams, uint32_t gpu_count, int8_t **mem_ptr_void); +void cuda_apply_many_univariate_lut_kb_64( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, + void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks, + void **bsks, uint32_t num_blocks, uint32_t lut_count, uint32_t lut_stride); + void scratch_cuda_full_propagation_64( void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, diff --git a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap.h b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap.h index be006cfb77..fa7ddffcbc 100644 --- a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap.h +++ b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap.h @@ -69,7 +69,7 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_32( void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key, int8_t *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, - uint32_t num_samples); + uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride); void cuda_programmable_bootstrap_lwe_ciphertext_vector_64( void *stream, uint32_t gpu_index, void *lwe_array_out, @@ -77,7 +77,7 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64( void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key, int8_t *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, - uint32_t num_samples); + uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride); void cleanup_cuda_programmable_bootstrap(void *stream, uint32_t gpu_index, int8_t **pbs_buffer); @@ -331,7 +331,8 @@ void cuda_programmable_bootstrap_cg_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); template <typename Torus> void cuda_programmable_bootstrap_lwe_ciphertext_vector( @@ -340,7 +341,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); #if (CUDA_ARCH >= 900) template <typename Torus> @@ -350,7 +352,8 @@ void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); template <typename Torus> void scratch_cuda_programmable_bootstrap_tbc(
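Note on the API change above: the two trailing parameters added to every PBS entry point thread a multi-LUT sample extract through the bootstrap, and passing lut_count = 1 with lut_stride = 0 keeps the existing single-LUT behavior. A minimal calling sketch under that assumption (the wrapper name is illustrative; stream, keys, and buffer are assumed to come from the existing scratch/setup functions):

#include <cstdint>
#include "programmable_bootstrap.h" // assumed include path for the declarations above

// Sketch only: existing callers keep single-LUT semantics by appending (1, 0).
void single_lut_pbs_64(void *stream, uint32_t gpu_index, void *lwe_out,
                       void *out_idx, void *lut, void *lut_idx, void *lwe_in,
                       void *in_idx, void *bsk, int8_t *buffer,
                       uint32_t lwe_dim, uint32_t glwe_dim, uint32_t poly_size,
                       uint32_t base_log, uint32_t level_count,
                       uint32_t num_samples) {
  cuda_programmable_bootstrap_lwe_ciphertext_vector_64(
      stream, gpu_index, lwe_out, out_idx, lut, lut_idx, lwe_in, in_idx, bsk,
      buffer, lwe_dim, glwe_dim, poly_size, base_log, level_count, num_samples,
      /*lut_count=*/1, /*lut_stride=*/0);
}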
diff --git a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h index 2d5c150364..c72ee20b15 100644 --- a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h +++ b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h @@ -27,7 +27,8 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key, int8_t *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); void cleanup_cuda_multi_bit_programmable_bootstrap(void *stream, uint32_t gpu_index, @@ -58,7 +59,8 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer<Torus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples); + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride); #endif template <typename Torus> @@ -74,7 +76,8 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer<Torus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples); + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride); template <typename Torus> void scratch_cuda_multi_bit_programmable_bootstrap( @@ -90,7 +93,8 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer<Torus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples); + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride); template <typename Torus> uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle( diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh index 164d1e908c..d8e9b73515 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh @@ -194,12 +194,16 @@ host_integer_decompress(cudaStream_t *streams, uint32_t *gpu_indexes, compression_params.glwe_dimension, compression_params.polynomial_size); + // When extracting a single LWE these parameters are dummy values + uint32_t lut_count = 1; + uint32_t lut_stride = 0; /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE /// dimension to a big LWE dimension auto encryption_params = mem_ptr->encryption_params; auto lut = mem_ptr->carry_extract_lut; auto active_gpu_count = get_active_gpu_count(num_lwes, gpu_count); if (active_gpu_count == 1) { + execute_pbs_async<Torus>( streams, gpu_indexes, active_gpu_count, lwe_array_out, lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec, extracted_lwe,
@@ -208,7 +212,7 @@ host_integer_decompress(cudaStream_t *streams, uint32_t *gpu_indexes, compression_params.small_lwe_dimension, encryption_params.polynomial_size, encryption_params.pbs_base_log, encryption_params.pbs_level, encryption_params.grouping_factor, - num_lwes, encryption_params.pbs_type); + num_lwes, encryption_params.pbs_type, lut_count, lut_stride); } else { /// For multi GPU execution we create vectors of pointers for inputs and /// outputs @@ -235,7 +239,7 @@ host_integer_decompress(cudaStream_t *streams, uint32_t *gpu_indexes, compression_params.small_lwe_dimension, encryption_params.polynomial_size, encryption_params.pbs_base_log, encryption_params.pbs_level, encryption_params.grouping_factor, - num_lwes, encryption_params.pbs_type); + num_lwes, encryption_params.pbs_type, lut_count, lut_stride); /// Copy data back to GPU 0 and release vecs multi_gpu_gather_lwe_async<Torus>(streams, gpu_indexes, active_gpu_count, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu index 05b466f5f1..d141f6cf1c 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu @@ -131,6 +131,19 @@ void cleanup_cuda_apply_univariate_lut_kb_64(void **streams, mem_ptr->release((cudaStream_t *)(streams), gpu_indexes, gpu_count); } +void cuda_apply_many_univariate_lut_kb_64( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, + void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks, + void **bsks, uint32_t num_blocks, uint32_t lut_count, uint32_t lut_stride) { + + host_apply_many_univariate_lut_kb<uint64_t>( + (cudaStream_t *)(streams), gpu_indexes, gpu_count, + static_cast<uint64_t *>(output_radix_lwe), + static_cast<uint64_t *>(input_radix_lwe), + (int_radix_lut<uint64_t> *)mem_ptr, (uint64_t **)(ksks), bsks, num_blocks, + lut_count, lut_stride); +} + void scratch_cuda_apply_bivariate_lut_kb_64( void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr, void *input_lut, uint32_t lwe_dimension, uint32_t glwe_dimension, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh index b23f6e5ab2..fd0288ac2d 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh @@ -189,6 +189,9 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb( auto polynomial_size = params.polynomial_size; auto grouping_factor = params.grouping_factor; + // When extracting a single LWE these parameters are dummy values + uint32_t lut_count = 1; + uint32_t lut_stride = 0; /// For multi GPU execution we create vectors of pointers for inputs and /// outputs std::vector<Torus *> lwe_array_in_vec = lut->lwe_array_in_vec; @@ -211,7 +214,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb( lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level, - grouping_factor, num_radix_blocks, pbs_type); + grouping_factor, num_radix_blocks, pbs_type, lut_count, lut_stride); } else { /// Make sure all data that should be on GPU 0 is indeed there cuda_synchronize_stream(streams[0], gpu_indexes[0]);
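The cuda_apply_many_univariate_lut_kb_64 wrapper above is the piece exposed through the C API. A hedged usage sketch follows (the LUT structure behind mem_ptr is assumed to have been built with lut_count functions packed into one accumulator, with the stride fixed at LUT-build time; the function and variable names here are illustrative):

#include <cstdint>
#include "integer.h" // assumed include path for the declaration above

// Sketch: evaluate lut_count functions on every block of a radix ciphertext
// with a single keyswitch + bootstrap per block. d_out must have room for
// lut_count * num_blocks big-LWE ciphertexts.
void apply_two_luts_example(void **streams, uint32_t *gpu_indexes,
                            uint32_t gpu_count, void *d_out, void *d_in,
                            int8_t *mem_ptr, void **ksks, void **bsks,
                            uint32_t num_blocks, uint32_t lut_stride) {
  uint32_t lut_count = 2; // assumption: two functions packed into the LUT
  cuda_apply_many_univariate_lut_kb_64(streams, gpu_indexes, gpu_count, d_out,
                                       d_in, mem_ptr, ksks, bsks, num_blocks,
                                       lut_count, lut_stride);
}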
@@ -237,7 +240,92 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb( lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer, glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log, - pbs_level, grouping_factor, num_radix_blocks, pbs_type); + pbs_level, grouping_factor, num_radix_blocks, pbs_type, lut_count, + lut_stride); + + /// Copy data back to GPU 0 and release vecs + multi_gpu_gather_lwe_async<Torus>(streams, gpu_indexes, active_gpu_count, + lwe_array_out, lwe_after_pbs_vec, + lut->h_lwe_indexes_out, + lut->using_trivial_lwe_indexes, + num_radix_blocks, big_lwe_dimension + 1); + + /// Synchronize all GPUs + for (uint i = 0; i < active_gpu_count; i++) { + cuda_synchronize_stream(streams[i], gpu_indexes[i]); + } + } +} + +template <typename Torus> +__host__ void integer_radix_apply_many_univariate_lookup_table_kb( + cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count, + Torus *lwe_array_out, Torus *lwe_array_in, void **bsks, Torus **ksks, + uint32_t num_radix_blocks, int_radix_lut<Torus> *lut, uint32_t lut_count, + uint32_t lut_stride) { + // apply_lookup_table + auto params = lut->params; + auto pbs_type = params.pbs_type; + auto big_lwe_dimension = params.big_lwe_dimension; + auto small_lwe_dimension = params.small_lwe_dimension; + auto ks_level = params.ks_level; + auto ks_base_log = params.ks_base_log; + auto pbs_level = params.pbs_level; + auto pbs_base_log = params.pbs_base_log; + auto glwe_dimension = params.glwe_dimension; + auto polynomial_size = params.polynomial_size; + auto grouping_factor = params.grouping_factor; + + /// For multi GPU execution we create vectors of pointers for inputs and + /// outputs + std::vector<Torus *> lwe_array_in_vec = lut->lwe_array_in_vec; + std::vector<Torus *> lwe_after_ks_vec = lut->lwe_after_ks_vec; + std::vector<Torus *> lwe_after_pbs_vec = lut->lwe_after_pbs_vec; + std::vector<Torus *> lwe_trivial_indexes_vec = lut->lwe_trivial_indexes_vec; + + auto active_gpu_count = get_active_gpu_count(num_radix_blocks, gpu_count); + if (active_gpu_count == 1) { + execute_keyswitch_async<Torus>(streams, gpu_indexes, 1, lwe_after_ks_vec[0], + lwe_trivial_indexes_vec[0], lwe_array_in, + lut->lwe_indexes_in, ksks, big_lwe_dimension, + small_lwe_dimension, ks_base_log, ks_level, + num_radix_blocks); + + /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE + /// dimension to a big LWE dimension + execute_pbs_async<Torus>( + streams, gpu_indexes, 1, lwe_array_out, lut->lwe_indexes_out, + lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0], + lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension, + small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level, + grouping_factor, num_radix_blocks, pbs_type, lut_count, lut_stride); + } else { + /// Make sure all data that should be on GPU 0 is indeed there + cuda_synchronize_stream(streams[0], gpu_indexes[0]); + + /// With multiple GPUs we push to the vectors on each GPU then when we + /// gather data to GPU 0 we can copy back to the original indexing + multi_gpu_scatter_lwe_async<Torus>( + streams, gpu_indexes, active_gpu_count, lwe_array_in_vec, lwe_array_in, + lut->h_lwe_indexes_in, lut->using_trivial_lwe_indexes, num_radix_blocks, + big_lwe_dimension + 1); + + /// Apply KS to go from a big LWE dimension to a small LWE dimension + execute_keyswitch_async<Torus>(streams, gpu_indexes, active_gpu_count, + lwe_after_ks_vec, lwe_trivial_indexes_vec, + lwe_array_in_vec, lwe_trivial_indexes_vec, + ksks, big_lwe_dimension, small_lwe_dimension, + ks_base_log, ks_level, num_radix_blocks); + + /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE + /// dimension to a big LWE dimension + execute_pbs_async<Torus>( + streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec, + lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec, + lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer, + glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log, + pbs_level, grouping_factor, num_radix_blocks, pbs_type, lut_count, + lut_stride); /// Copy data back to GPU 0 and release vecs multi_gpu_gather_lwe_async<Torus>(streams, gpu_indexes, active_gpu_count,
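Each input block of the new function yields lut_count output ciphertexts, and the sample-extract loops in the PBS kernels further down write them grouped by LUT index: all blocks for LUT 0, then all blocks for LUT 1, and so on. A small sketch of that indexing, assuming trivial output indexes (my illustration, not code from this change):

#include <cstddef>
#include <cstdint>

// Offset (in Torus elements) of the ciphertext produced by LUT `lut_index`
// for input block `block_index`; matches the kernels' layout of
// lwe_array_out + lut_index * num_blocks * big_lwe_size.
static inline size_t many_lut_output_offset(uint32_t lut_index,
                                            uint32_t block_index,
                                            uint32_t num_blocks,
                                            uint32_t glwe_dimension,
                                            uint32_t polynomial_size) {
  size_t big_lwe_size = (size_t)glwe_dimension * polynomial_size + 1;
  return ((size_t)lut_index * num_blocks + block_index) * big_lwe_size;
}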
@@ -272,6 +360,10 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb( auto polynomial_size = params.polynomial_size; auto grouping_factor = params.grouping_factor; + // When extracting a single LWE these parameters are dummy values + uint32_t lut_count = 1; + uint32_t lut_stride = 0; + // Left message is shifted auto lwe_array_pbs_in = lut->tmp_lwe_before_ks; pack_bivariate_blocks<Torus>(streams, gpu_indexes, gpu_count, @@ -302,7 +394,7 @@ lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level, - grouping_factor, num_radix_blocks, pbs_type); + grouping_factor, num_radix_blocks, pbs_type, lut_count, lut_stride); } else { cuda_synchronize_stream(streams[0], gpu_indexes[0]); multi_gpu_scatter_lwe_async<Torus>( @@ -324,7 +416,8 @@ lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer, glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log, - pbs_level, grouping_factor, num_radix_blocks, pbs_type); + pbs_level, grouping_factor, num_radix_blocks, pbs_type, lut_count, + lut_stride); /// Copy data back to GPU 0 and release vecs multi_gpu_gather_lwe_async<Torus>(streams, gpu_indexes, active_gpu_count, @@ -700,6 +793,9 @@ void host_full_propagate_inplace(cudaStream_t *streams, uint32_t *gpu_indexes, int big_lwe_size = (params.glwe_dimension * params.polynomial_size + 1); int small_lwe_size = (params.small_lwe_dimension + 1); + // When extracting a single LWE these parameters are dummy values + uint32_t lut_count = 1; + uint32_t lut_stride = 0; for (int i = 0; i < num_blocks; i++) { auto cur_input_block = &input_blocks[i * big_lwe_size]; @@ -722,7 +818,7 @@ void host_full_propagate_inplace(cudaStream_t *streams, uint32_t *gpu_indexes, mem_ptr->lut->lwe_trivial_indexes, bsks, mem_ptr->lut->buffer, params.glwe_dimension, params.small_lwe_dimension, params.polynomial_size, params.pbs_base_log, params.pbs_level, - params.grouping_factor, 2, params.pbs_type); + params.grouping_factor, 2, params.pbs_type, lut_count, lut_stride); cuda_memcpy_async_gpu_to_gpu(cur_input_block, mem_ptr->tmp_big_lwe_vector, big_lwe_size * sizeof(Torus), streams[0], @@ -994,6 +1090,18 @@ void host_apply_univariate_lut_kb(cudaStream_t *streams, uint32_t *gpu_indexes, num_blocks, mem); } +template <typename Torus> +void host_apply_many_univariate_lut_kb( + cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count, + Torus *radix_lwe_out, Torus *radix_lwe_in, int_radix_lut<Torus> *mem, + Torus **ksks, void **bsks, uint32_t num_blocks, uint32_t lut_count, + uint32_t lut_stride) { + + integer_radix_apply_many_univariate_lookup_table_kb<Torus>( + streams, gpu_indexes, gpu_count, radix_lwe_out, radix_lwe_in, bsks, ksks, + num_blocks, mem, lut_count, lut_stride); +} + template <typename Torus> void scratch_cuda_apply_bivariate_lut_kb( cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count,
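The PBS kernels below implement the many-LUT trick at the sample-extract stage: the blind rotation is done once, and LUT i is then read out of the same accumulator at coefficient i * lut_stride instead of coefficient 0, so only the extra extractions are paid for, not extra rotations. As a reference for what extraction at coefficient nth means, here is the standard negacyclic rule in host code (an illustration, not code from this PR):

#include <cstdint>

// Reference LWE sample extraction at coefficient `nth` from one GLWE mask
// polynomial a(X) of size N over Z[X]/(X^N + 1): the LWE mask coefficient
// paired with secret coefficient k is a[nth - k] for k <= nth and the
// modular negation of a[N + nth - k] for k > nth; the LWE body is the body
// polynomial's coefficient `nth`.
static void sample_extract_at_ref(const uint64_t *glwe_mask_poly,
                                  const uint64_t *glwe_body_poly,
                                  uint64_t *lwe_mask, uint64_t *lwe_body,
                                  uint32_t N, uint32_t nth) {
  for (uint32_t k = 0; k <= nth; k++)
    lwe_mask[k] = glwe_mask_poly[nth - k];
  for (uint32_t k = nth + 1; k < N; k++)
    lwe_mask[k] = -glwe_mask_poly[N + nth - k]; // wraps mod 2^64
  *lwe_body = glwe_body_poly[nth];
}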
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh index e0116124a9..7bd838217f 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh @@ -206,6 +206,10 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( auto small_lwe_dimension = mem_ptr->params.small_lwe_dimension; auto small_lwe_size = small_lwe_dimension + 1; + // When extracting a single LWE these parameters are dummy values + uint32_t lut_count = 1; + uint32_t lut_stride = 0; + if (num_radix_in_vec == 0) return; if (num_radix_in_vec == 1) { @@ -364,7 +368,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( glwe_dimension, small_lwe_dimension, polynomial_size, mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level, mem_ptr->params.grouping_factor, total_count, - mem_ptr->params.pbs_type); + mem_ptr->params.pbs_type, lut_count, lut_stride); } else { cuda_synchronize_stream(streams[0], gpu_indexes[0]); @@ -412,7 +416,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb( glwe_dimension, small_lwe_dimension, polynomial_size, mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level, mem_ptr->params.grouping_factor, total_count, - mem_ptr->params.pbs_type); + mem_ptr->params.pbs_type, lut_count, lut_stride); multi_gpu_gather_lwe_async<Torus>( streams, gpu_indexes, active_gpu_count, new_blocks, lwe_after_pbs_vec, diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh index f8f12d4403..e3bf1471b7 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh @@ -127,7 +127,8 @@ void execute_pbs_async( std::vector<int8_t *> pbs_buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t grouping_factor, - uint32_t input_lwe_ciphertext_count, PBS_TYPE pbs_type) { + uint32_t input_lwe_ciphertext_count, PBS_TYPE pbs_type, uint32_t lut_count, + uint32_t lut_stride) { switch (sizeof(Torus)) { case sizeof(uint32_t): // 32 bits @@ -159,7 +160,8 @@ void execute_pbs_async( current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes, current_lwe_array_in, current_lwe_input_indexes, bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension, - polynomial_size, base_log, level_count, num_inputs_on_gpu); + polynomial_size, base_log, level_count, num_inputs_on_gpu, + lut_count, lut_stride); } break; default: @@ -198,7 +200,7 @@ void execute_pbs_async( current_lwe_array_in, current_lwe_input_indexes, bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_inputs_on_gpu); + num_inputs_on_gpu, lut_count, lut_stride); } break; case CLASSICAL: @@ -226,7 +228,8 @@ void execute_pbs_async( current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes, current_lwe_array_in, current_lwe_input_indexes, bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension, - polynomial_size, base_log, level_count, num_inputs_on_gpu); + polynomial_size, base_log, level_count, num_inputs_on_gpu, + lut_count, lut_stride); } break; default: diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh index 94d383a2f3..04ff5348c5 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh +++ 
b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_classic.cuh @@ -44,7 +44,8 @@ __global__ void device_programmable_bootstrap_cg( const double2 *__restrict__ bootstrapping_key, double2 *join_buffer, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, int8_t *device_mem, - uint64_t device_memory_size_per_block) { + uint64_t device_memory_size_per_block, uint32_t lut_count, + uint32_t lut_stride) { grid_group grid = this_grid(); @@ -151,8 +152,38 @@ __global__ void device_programmable_bootstrap_cg( // we do the computation at block 0 to avoid waiting for extra blocks, in // case they're not synchronized sample_extract_mask(block_lwe_array_out, accumulator); + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_mask(next_block_lwe_array_out, + accumulator, glwe_dimension, + i * lut_stride); + } + } } else if (blockIdx.x == 0 && blockIdx.y == glwe_dimension) { sample_extract_body(block_lwe_array_out, accumulator, 0); + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_body(next_block_lwe_array_out, + accumulator, 0, i * lut_stride); + } + } } } @@ -202,7 +233,8 @@ __host__ void host_programmable_bootstrap_cg( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t input_lwe_ciphertext_count) { + uint32_t level_count, uint32_t input_lwe_ciphertext_count, + uint32_t lut_count, uint32_t lut_stride) { // With SM each block corresponds to either the mask or body, no need to // duplicate data for each @@ -226,7 +258,7 @@ __host__ void host_programmable_bootstrap_cg( int thds = polynomial_size / params::opt; dim3 grid(level_count, glwe_dimension + 1, input_lwe_ciphertext_count); - void *kernel_args[14]; + void *kernel_args[16]; kernel_args[0] = &lwe_array_out; kernel_args[1] = &lwe_output_indexes; kernel_args[2] = &lut_vector; @@ -240,6 +272,8 @@ __host__ void host_programmable_bootstrap_cg( kernel_args[10] = &base_log; kernel_args[11] = &level_count; kernel_args[12] = &d_mem; + kernel_args[14] = &lut_count; + kernel_args[15] = &lut_stride; if (max_shared_memory < partial_sm) { kernel_args[13] = &full_dm; diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh index 847f9e03d9..9ad863708e 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh @@ -30,7 +30,8 @@ __global__ void __launch_bounds__(params::degree / params::opt) uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset, uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input, - int8_t *device_mem, uint64_t 
device_memory_size_per_block) { + int8_t *device_mem, uint64_t device_memory_size_per_block, + uint32_t lut_count, uint32_t lut_stride) { grid_group grid = this_grid(); @@ -129,9 +130,44 @@ __global__ void __launch_bounds__(params::degree / params::opt) // Perform a sample extract. At this point, all blocks have the result, // but we do the computation at block 0 to avoid waiting for extra blocks, // in case they're not synchronized + // Always extract one by default sample_extract_mask(block_lwe_array_out, accumulator); + + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_mask(next_block_lwe_array_out, + accumulator, glwe_dimension, + i * lut_stride); + } + } + } else if (blockIdx.x == 0 && blockIdx.y == glwe_dimension) { + sample_extract_body(block_lwe_array_out, accumulator, 0); + + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_body(next_block_lwe_array_out, + accumulator, 0, i * lut_stride); + } + } } } else { // Load the accumulator calculated in previous iterations @@ -256,7 +292,8 @@ __host__ void execute_cg_external_product_loop( pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, - uint32_t lwe_chunk_size, uint32_t lwe_offset) { + uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count, + uint32_t lut_stride) { uint64_t full_dm = get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( @@ -283,7 +320,7 @@ __host__ void execute_cg_external_product_loop( auto global_accumulator = buffer->global_accumulator; auto buffer_fft = buffer->global_accumulator_fft; - void *kernel_args[20]; + void *kernel_args[22]; kernel_args[0] = &lwe_array_out; kernel_args[1] = &lwe_output_indexes; kernel_args[2] = &lut_vector; @@ -303,6 +340,8 @@ __host__ void execute_cg_external_product_loop( kernel_args[16] = &chunk_size; kernel_args[17] = &keybundle_size_per_input; kernel_args[18] = &d_mem; + kernel_args[20] = &lut_count; + kernel_args[21] = &lut_stride; dim3 grid_accumulate(level_count, glwe_dimension + 1, num_samples); dim3 thds(polynomial_size / params::opt, 1, 1); @@ -335,7 +374,8 @@ __host__ void host_cg_multi_bit_programmable_bootstrap( Torus *lwe_array_in, Torus *lwe_input_indexes, uint64_t *bootstrapping_key, pbs_buffer *buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples) { + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride) { auto lwe_chunk_size = get_lwe_chunk_size( gpu_index, num_samples, polynomial_size); @@ -354,7 +394,8 @@ __host__ void host_cg_multi_bit_programmable_bootstrap( stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - 
grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset); + grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset, + lut_count, lut_stride); } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cu b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cu index 1ee92f9c90..66a90fca69 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cu +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cu @@ -122,7 +122,8 @@ void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples) { + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride) { switch (polynomial_size) { case 256: @@ -130,49 +131,56 @@ void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 512: host_programmable_bootstrap_tbc>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 1024: host_programmable_bootstrap_tbc>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 2048: host_programmable_bootstrap_tbc>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 4096: host_programmable_bootstrap_tbc>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 8192: host_programmable_bootstrap_tbc>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 16384: host_programmable_bootstrap_tbc>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, 
bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; default: PANIC("Cuda error (classical PBS): unsupported polynomial size. " @@ -370,7 +378,8 @@ void cuda_programmable_bootstrap_cg_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples) { + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride) { switch (polynomial_size) { case 256: @@ -378,49 +387,56 @@ void cuda_programmable_bootstrap_cg_lwe_ciphertext_vector( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 512: host_programmable_bootstrap_cg>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 1024: host_programmable_bootstrap_cg>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 2048: host_programmable_bootstrap_cg>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 4096: host_programmable_bootstrap_cg>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 8192: host_programmable_bootstrap_cg>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 16384: host_programmable_bootstrap_cg>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; default: PANIC("Cuda error 
(classical PBS): unsupported polynomial size. " @@ -436,7 +452,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples) { + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride) { switch (polynomial_size) { case 256: @@ -444,49 +461,56 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 512: host_programmable_bootstrap>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 1024: host_programmable_bootstrap>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 2048: host_programmable_bootstrap>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 4096: host_programmable_bootstrap>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 8192: host_programmable_bootstrap>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case 16384: host_programmable_bootstrap>( static_cast(stream), gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, glwe_dimension, - lwe_dimension, polynomial_size, base_log, level_count, num_samples); + lwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; default: PANIC("Cuda error (classical PBS): unsupported polynomial size. 
" @@ -503,7 +527,7 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_32( void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key, int8_t *mem_ptr, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, - uint32_t num_samples) { + uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { if (base_log > 32) PANIC("Cuda error (classical PBS): base log should be > number of bits " @@ -523,7 +547,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_32( static_cast(lwe_array_in), static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, - glwe_dimension, polynomial_size, base_log, level_count, num_samples); + glwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; #else PANIC("Cuda error (PBS): TBC pbs is not supported.") @@ -537,7 +562,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_32( static_cast(lwe_array_in), static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, - glwe_dimension, polynomial_size, base_log, level_count, num_samples); + glwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case DEFAULT: cuda_programmable_bootstrap_lwe_ciphertext_vector( @@ -548,7 +574,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_32( static_cast(lwe_array_in), static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, - glwe_dimension, polynomial_size, base_log, level_count, num_samples); + glwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; default: PANIC("Cuda error (PBS): unknown pbs variant.") @@ -622,7 +649,7 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64( void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key, int8_t *mem_ptr, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, - uint32_t num_samples) { + uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { if (base_log > 64) PANIC("Cuda error (classical PBS): base log should be > number of bits " "in the ciphertext representation (64)"); @@ -641,7 +668,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64( static_cast(lwe_array_in), static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, - glwe_dimension, polynomial_size, base_log, level_count, num_samples); + glwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; #else PANIC("Cuda error (PBS): TBC pbs is not supported.") @@ -655,7 +683,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64( static_cast(lwe_array_in), static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, - glwe_dimension, polynomial_size, base_log, level_count, num_samples); + glwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; case PBS_VARIANT::DEFAULT: cuda_programmable_bootstrap_lwe_ciphertext_vector( @@ -666,7 +695,8 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_64( static_cast(lwe_array_in), static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, - glwe_dimension, polynomial_size, base_log, level_count, num_samples); + glwe_dimension, polynomial_size, base_log, level_count, num_samples, + lut_count, lut_stride); break; default: PANIC("Cuda error (PBS): unknown 
pbs variant.") @@ -694,7 +724,8 @@ template void cuda_programmable_bootstrap_cg_lwe_ciphertext_vector( uint64_t *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); template void cuda_programmable_bootstrap_lwe_ciphertext_vector( void *stream, uint32_t gpu_index, uint64_t *lwe_array_out, @@ -703,7 +734,8 @@ template void cuda_programmable_bootstrap_lwe_ciphertext_vector( uint64_t *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); template void scratch_cuda_programmable_bootstrap_cg( void *stream, uint32_t gpu_index, @@ -723,7 +755,8 @@ template void cuda_programmable_bootstrap_cg_lwe_ciphertext_vector( uint32_t *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); template void cuda_programmable_bootstrap_lwe_ciphertext_vector( void *stream, uint32_t gpu_index, uint32_t *lwe_array_out, @@ -732,7 +765,8 @@ template void cuda_programmable_bootstrap_lwe_ciphertext_vector( uint32_t *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); template void scratch_cuda_programmable_bootstrap_cg( void *stream, uint32_t gpu_index, @@ -760,7 +794,8 @@ template void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector( uint32_t *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); template void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector( void *stream, uint32_t gpu_index, uint64_t *lwe_array_out, uint64_t *lwe_output_indexes, uint64_t *lut_vector, @@ -768,7 +803,8 @@ template void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector( uint64_t *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t num_samples); + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride); template void scratch_cuda_programmable_bootstrap_tbc( void *stream, uint32_t gpu_index, pbs_buffer **pbs_buffer, uint32_t glwe_dimension, diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh index 47d6955cb1..b9dfdf415c 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh @@ -141,7 +141,8 @@ __global__ void 
__launch_bounds__(params::degree / params::opt) Torus *global_accumulator, double2 *global_accumulator_fft, uint32_t lwe_iteration, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, - int8_t *device_mem, uint64_t device_memory_size_per_block) { + int8_t *device_mem, uint64_t device_memory_size_per_block, + uint32_t lut_count, uint32_t lut_stride) { // We use shared memory for the polynomials that are used often during the // bootstrap, since shared memory is kept in L1 cache and accessing it is @@ -216,8 +217,38 @@ __global__ void __launch_bounds__(params::degree / params::opt) // but we do the computation at block 0 to avoid waiting for extra blocks, // in case they're not synchronized sample_extract_mask(block_lwe_array_out, accumulator); + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_mask(next_block_lwe_array_out, + accumulator, glwe_dimension, + i * lut_stride); + } + } } else if (blockIdx.y == glwe_dimension) { sample_extract_body(block_lwe_array_out, accumulator, 0); + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_body(next_block_lwe_array_out, + accumulator, 0, i * lut_stride); + } + } } } else { // Persist the updated accumulator @@ -375,16 +406,15 @@ execute_step_one(cudaStream_t stream, uint32_t gpu_index, Torus *lut_vector, } template -__host__ void -execute_step_two(cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out, - Torus *lwe_output_indexes, Torus *lut_vector, - Torus *lut_vector_indexes, double2 *bootstrapping_key, - Torus *global_accumulator, double2 *global_accumulator_fft, - uint32_t input_lwe_ciphertext_count, uint32_t lwe_dimension, - uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t base_log, uint32_t level_count, int8_t *d_mem, - int lwe_iteration, uint64_t partial_sm, uint64_t partial_dm, - uint64_t full_sm, uint64_t full_dm) { +__host__ void execute_step_two( + cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out, + Torus *lwe_output_indexes, Torus *lut_vector, Torus *lut_vector_indexes, + double2 *bootstrapping_key, Torus *global_accumulator, + double2 *global_accumulator_fft, uint32_t input_lwe_ciphertext_count, + uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t base_log, uint32_t level_count, int8_t *d_mem, int lwe_iteration, + uint64_t partial_sm, uint64_t partial_dm, uint64_t full_sm, + uint64_t full_dm, uint32_t lut_count, uint32_t lut_stride) { int max_shared_memory = cuda_get_max_shared_memory(0); cudaSetDevice(gpu_index); @@ -397,21 +427,21 @@ execute_step_two(cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out, lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, bootstrapping_key, global_accumulator, global_accumulator_fft, lwe_iteration, lwe_dimension, polynomial_size, base_log, - level_count, d_mem, full_dm); + level_count, d_mem, full_dm, lut_count, lut_stride); } else if (max_shared_memory < full_sm) { 
device_programmable_bootstrap_step_two <<>>( lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, bootstrapping_key, global_accumulator, global_accumulator_fft, lwe_iteration, lwe_dimension, polynomial_size, base_log, - level_count, d_mem, partial_dm); + level_count, d_mem, partial_dm, lut_count, lut_stride); } else { device_programmable_bootstrap_step_two <<>>( lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, bootstrapping_key, global_accumulator, global_accumulator_fft, lwe_iteration, lwe_dimension, polynomial_size, base_log, - level_count, d_mem, 0); + level_count, d_mem, 0, lut_count, lut_stride); } check_cuda_error(cudaGetLastError()); } @@ -425,7 +455,8 @@ __host__ void host_programmable_bootstrap( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t input_lwe_ciphertext_count) { + uint32_t level_count, uint32_t input_lwe_ciphertext_count, + uint32_t lut_count, uint32_t lut_stride) { cudaSetDevice(gpu_index); // With SM each block corresponds to either the mask or body, no need to @@ -461,7 +492,8 @@ __host__ void host_programmable_bootstrap( lut_vector_indexes, bootstrapping_key, global_accumulator, global_accumulator_fft, input_lwe_ciphertext_count, lwe_dimension, glwe_dimension, polynomial_size, base_log, level_count, d_mem, i, - partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two); + partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two, + lut_count, lut_stride); } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu index ec5e39eb41..e466d77d0b 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu @@ -65,7 +65,8 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples) { + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride) { if (base_log > 64) PANIC("Cuda error (multi-bit PBS): base log should be > number of bits in " @@ -78,7 +79,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 512: host_cg_multi_bit_programmable_bootstrap>( @@ -86,7 +87,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 1024: host_cg_multi_bit_programmable_bootstrap>( @@ -94,7 +95,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, 
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 2048: host_cg_multi_bit_programmable_bootstrap>( @@ -102,7 +103,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 4096: host_cg_multi_bit_programmable_bootstrap>( @@ -110,7 +111,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 8192: host_cg_multi_bit_programmable_bootstrap>( @@ -118,7 +119,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 16384: host_cg_multi_bit_programmable_bootstrap>( @@ -126,7 +127,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; default: PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. 
Supported " @@ -142,7 +143,8 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples) { + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride) { if (base_log > 64) PANIC("Cuda error (multi-bit PBS): base log should be > number of bits in " @@ -155,7 +157,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 512: host_multi_bit_programmable_bootstrap>( @@ -163,7 +165,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 1024: host_multi_bit_programmable_bootstrap>( @@ -171,7 +173,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 2048: host_multi_bit_programmable_bootstrap>( @@ -179,7 +181,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 4096: host_multi_bit_programmable_bootstrap>( @@ -187,7 +189,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 8192: host_multi_bit_programmable_bootstrap>( @@ -195,7 +197,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 16384: host_multi_bit_programmable_bootstrap>( @@ -203,7 +205,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; default: PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. 
Supported " @@ -218,7 +220,8 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key, int8_t *mem_ptr, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, - uint32_t level_count, uint32_t num_samples) { + uint32_t level_count, uint32_t num_samples, uint32_t lut_count, + uint32_t lut_stride) { pbs_buffer *buffer = (pbs_buffer *)mem_ptr; @@ -235,7 +238,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; #else PANIC("Cuda error (multi-bit PBS): TBC pbs is not supported.") @@ -250,7 +253,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case PBS_VARIANT::DEFAULT: cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( @@ -262,7 +265,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( static_cast(lwe_input_indexes), static_cast(bootstrapping_key), buffer, lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; default: PANIC("Cuda error (multi-bit PBS): unsupported implementation variant.") @@ -499,7 +502,8 @@ cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( uint64_t *lwe_input_indexes, uint64_t *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples); + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride); template void scratch_cuda_cg_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, @@ -515,7 +519,8 @@ cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( uint64_t *lwe_input_indexes, uint64_t *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples); + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride); template bool has_support_to_cuda_programmable_bootstrap_tbc_multi_bit( @@ -586,7 +591,8 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples) { + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride) { if (base_log > 64) PANIC("Cuda error (multi-bit PBS): base log should be > number of bits in " @@ -599,7 +605,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, 
level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 512: host_tbc_multi_bit_programmable_bootstrap>( @@ -607,7 +613,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 1024: host_tbc_multi_bit_programmable_bootstrap>( @@ -615,7 +621,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 2048: host_tbc_multi_bit_programmable_bootstrap>( @@ -623,7 +629,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 4096: host_tbc_multi_bit_programmable_bootstrap>( @@ -631,7 +637,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 8192: host_tbc_multi_bit_programmable_bootstrap>( @@ -639,7 +645,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; case 16384: host_tbc_multi_bit_programmable_bootstrap>( @@ -647,7 +653,7 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, grouping_factor, base_log, level_count, - num_samples); + num_samples, lut_count, lut_stride); break; default: PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. 
Supported " @@ -670,5 +676,6 @@ cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( uint64_t *lwe_input_indexes, uint64_t *bootstrapping_key, pbs_buffer *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples); + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride); #endif diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh index e55c559f4e..c39816e3c1 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh @@ -252,7 +252,8 @@ __global__ void __launch_bounds__(params::degree / params::opt) uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t grouping_factor, uint32_t iteration, uint32_t lwe_offset, uint32_t lwe_chunk_size, int8_t *device_mem, - uint64_t device_memory_size_per_block) { + uint64_t device_memory_size_per_block, uint32_t lut_count, + uint32_t lut_stride) { // We use shared memory for the polynomials that are used often during the // bootstrap, since shared memory is kept in L1 cache and accessing it is // much faster than global memory @@ -325,8 +326,38 @@ __global__ void __launch_bounds__(params::degree / params::opt) // but we do the computation at block 0 to avoid waiting for extra blocks, // in case they're not synchronized sample_extract_mask(block_lwe_array_out, global_slice); + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_mask(next_block_lwe_array_out, + global_slice, glwe_dimension, + i * lut_stride); + } + } } else if (blockIdx.y == glwe_dimension) { sample_extract_body(block_lwe_array_out, global_slice, 0); + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_body(next_block_lwe_array_out, + global_slice, 0, i * lut_stride); + } + } } } } @@ -567,7 +598,8 @@ __host__ void execute_step_two( Torus *lwe_output_indexes, pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, int32_t grouping_factor, uint32_t level_count, - uint32_t j, uint32_t lwe_offset, uint32_t lwe_chunk_size) { + uint32_t j, uint32_t lwe_offset, uint32_t lwe_chunk_size, + uint32_t lut_count, uint32_t lut_stride) { uint64_t full_sm_accumulate_step_two = get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( @@ -590,7 +622,8 @@ __host__ void execute_step_two( lwe_array_out, lwe_output_indexes, keybundle_fft, global_accumulator, global_accumulator_fft, lwe_dimension, glwe_dimension, polynomial_size, level_count, grouping_factor, j, - lwe_offset, lwe_chunk_size, d_mem, full_sm_accumulate_step_two); + lwe_offset, lwe_chunk_size, d_mem, full_sm_accumulate_step_two, + lut_count, lut_stride); else 
device_multi_bit_programmable_bootstrap_accumulate_step_two @@ -598,7 +631,8 @@ __host__ void execute_step_two( stream>>>(lwe_array_out, lwe_output_indexes, keybundle_fft, global_accumulator, global_accumulator_fft, lwe_dimension, glwe_dimension, polynomial_size, level_count, - grouping_factor, j, lwe_offset, lwe_chunk_size, d_mem, 0); + grouping_factor, j, lwe_offset, lwe_chunk_size, d_mem, 0, + lut_count, lut_stride); check_cuda_error(cudaGetLastError()); } @@ -609,7 +643,8 @@ __host__ void host_multi_bit_programmable_bootstrap( Torus *lwe_array_in, Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer *buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples) { + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride) { auto lwe_chunk_size = get_lwe_chunk_size( gpu_index, num_samples, polynomial_size); @@ -634,7 +669,8 @@ __host__ void host_multi_bit_programmable_bootstrap( execute_step_two( stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, level_count, j, lwe_offset, lwe_chunk_size); + grouping_factor, level_count, j, lwe_offset, lwe_chunk_size, + lut_count, lut_stride); } } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_classic.cuh index 8bb01bc4c3..5dccab3606 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_classic.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_classic.cuh @@ -44,7 +44,8 @@ __global__ void device_programmable_bootstrap_tbc( const double2 *__restrict__ bootstrapping_key, double2 *join_buffer, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, int8_t *device_mem, - uint64_t device_memory_size_per_block, bool support_dsm) { + uint64_t device_memory_size_per_block, bool support_dsm, uint32_t lut_count, + uint32_t lut_stride) { cluster_group cluster = this_cluster(); @@ -155,8 +156,40 @@ __global__ void device_programmable_bootstrap_tbc( // we do the computation at block 0 to avoid waiting for extra blocks, in // case they're not synchronized sample_extract_mask(block_lwe_array_out, accumulator); + + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_mask(next_block_lwe_array_out, + accumulator, glwe_dimension, + i * lut_stride); + } + } } else if (blockIdx.x == 0 && blockIdx.y == glwe_dimension) { sample_extract_body(block_lwe_array_out, accumulator, 0); + + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_body(next_block_lwe_array_out, + accumulator, 0, i * lut_stride); + } + } } } @@ -225,7 +258,8 @@ __host__ void host_programmable_bootstrap_tbc( Torus *lwe_array_in, Torus *lwe_input_indexes, double2 
*bootstrapping_key, pbs_buffer *buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, - uint32_t level_count, uint32_t input_lwe_ciphertext_count) { + uint32_t level_count, uint32_t input_lwe_ciphertext_count, + uint32_t lut_count, uint32_t lut_stride) { auto supports_dsm = supports_distributed_shared_memory_on_classic_programmable_bootstrap< @@ -281,7 +315,7 @@ __host__ void host_programmable_bootstrap_tbc( lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft, lwe_dimension, polynomial_size, base_log, level_count, d_mem, full_dm, - supports_dsm)); + supports_dsm, lut_count, lut_stride)); } else if (max_shared_memory < full_sm + minimum_sm_tbc) { config.dynamicSmemBytes = partial_sm + minimum_sm_tbc; @@ -290,7 +324,7 @@ __host__ void host_programmable_bootstrap_tbc( lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft, lwe_dimension, polynomial_size, base_log, level_count, d_mem, - partial_dm, supports_dsm)); + partial_dm, supports_dsm, lut_count, lut_stride)); } else { config.dynamicSmemBytes = full_sm + minimum_sm_tbc; @@ -299,7 +333,7 @@ __host__ void host_programmable_bootstrap_tbc( lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer_fft, lwe_dimension, polynomial_size, base_log, level_count, d_mem, 0, - supports_dsm)); + supports_dsm, lut_count, lut_stride)); } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh index 9a28690c25..d2cc68d6c7 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh @@ -31,7 +31,7 @@ __global__ void __launch_bounds__(params::degree / params::opt) uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset, uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input, int8_t *device_mem, uint64_t device_memory_size_per_block, - bool support_dsm) { + bool support_dsm, uint32_t lut_count, uint32_t lut_stride) { cluster_group cluster = this_cluster(); @@ -138,8 +138,39 @@ __global__ void __launch_bounds__(params::degree / params::opt) // but we do the computation at block 0 to avoid waiting for extra blocks, // in case they're not synchronized sample_extract_mask(block_lwe_array_out, accumulator); + + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_mask(next_block_lwe_array_out, + accumulator, glwe_dimension, + i * lut_stride); + } + } } else if (blockIdx.x == 0 && blockIdx.y == glwe_dimension) { sample_extract_body(block_lwe_array_out, accumulator, 0); + if (lut_count > 1) { + for (int i = 1; i < lut_count; i++) { + + auto next_lwe_array_out = + lwe_array_out + + (i * gridDim.z * (glwe_dimension * polynomial_size + 1)); + auto next_block_lwe_array_out = + &next_lwe_array_out[lwe_output_indexes[blockIdx.z] * + (glwe_dimension * polynomial_size + 1) + + blockIdx.y * polynomial_size]; + + sample_extract_body(next_block_lwe_array_out, + 
accumulator, 0, i * lut_stride); + } + } } } else { // Load the accumulator calculated in previous iterations @@ -267,7 +298,8 @@ __host__ void execute_tbc_external_product_loop( pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, - uint32_t lwe_chunk_size, uint32_t lwe_offset) { + uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count, + uint32_t lut_stride) { auto supports_dsm = supports_distributed_shared_memory_on_multibit_programmable_bootstrap< @@ -331,7 +363,8 @@ __host__ void execute_tbc_external_product_loop( lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft, global_accumulator, lwe_dimension, glwe_dimension, polynomial_size, base_log, level_count, grouping_factor, lwe_offset, chunk_size, - keybundle_size_per_input, d_mem, full_dm, supports_dsm)); + keybundle_size_per_input, d_mem, full_dm, supports_dsm, lut_count, + lut_stride)); } else if (max_shared_memory < full_dm + minimum_dm) { config.dynamicSmemBytes = partial_dm + minimum_dm; check_cuda_error(cudaLaunchKernelEx( @@ -342,7 +375,8 @@ __host__ void execute_tbc_external_product_loop( lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft, global_accumulator, lwe_dimension, glwe_dimension, polynomial_size, base_log, level_count, grouping_factor, lwe_offset, chunk_size, - keybundle_size_per_input, d_mem, partial_dm, supports_dsm)); + keybundle_size_per_input, d_mem, partial_dm, supports_dsm, lut_count, + lut_stride)); } else { config.dynamicSmemBytes = full_dm + minimum_dm; check_cuda_error(cudaLaunchKernelEx( @@ -353,7 +387,8 @@ __host__ void execute_tbc_external_product_loop( lwe_array_in, lwe_input_indexes, keybundle_fft, buffer_fft, global_accumulator, lwe_dimension, glwe_dimension, polynomial_size, base_log, level_count, grouping_factor, lwe_offset, chunk_size, - keybundle_size_per_input, d_mem, 0, supports_dsm)); + keybundle_size_per_input, d_mem, 0, supports_dsm, lut_count, + lut_stride)); } } @@ -364,7 +399,8 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap( Torus *lwe_array_in, Torus *lwe_input_indexes, uint64_t *bootstrapping_key, pbs_buffer *buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, - uint32_t base_log, uint32_t level_count, uint32_t num_samples) { + uint32_t base_log, uint32_t level_count, uint32_t num_samples, + uint32_t lut_count, uint32_t lut_stride) { cudaSetDevice(gpu_index); auto lwe_chunk_size = get_lwe_chunk_size( @@ -384,7 +420,8 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap( stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset); + grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset, + lut_count, lut_stride); } } diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp index 7a76e667cc..8ad5831eeb 100644 --- a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/benchmarks/benchmark_pbs.cpp @@ -179,7 +179,8 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, TbcMultiBit) stream, gpu_index, (pbs_buffer **)&buffer, lwe_dimension, glwe_dimension, 
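
The TBC host driver walks the LWE mask in chunks and re-launches the external-product loop once per chunk; only the call is visible in the hunk, so the following is a sketch of the control flow under the assumption that lwe_offset advances by lwe_chunk_size over the grouped mask (the exact loop bound is not shown here):

// Sketch: every launch forwards lut_count / lut_stride so that the iteration
// which completes the blind rotation can extract all LUT results at once.
for (uint32_t lwe_offset = 0; lwe_offset < lwe_dimension / grouping_factor;
     lwe_offset += lwe_chunk_size) {
  execute_tbc_external_product_loop<Torus, params>(
      stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
      lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer,
      num_samples, lwe_dimension, glwe_dimension, polynomial_size,
      grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset,
      lut_count, lut_stride);
}
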
polynomial_size, pbs_level, grouping_factor, input_lwe_ciphertext_count, true); - + uint32_t lut_count = 1; + uint32_t lut_stride = 0; for (auto _ : st) { // Execute PBS cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( @@ -187,7 +188,8 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, TbcMultiBit) d_lut_pbs_identity, d_lut_pbs_indexes, d_lwe_ct_in_array, d_lwe_input_indexes, d_bsk, (pbs_buffer *)buffer, lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, - pbs_base_log, pbs_level, input_lwe_ciphertext_count); + pbs_base_log, pbs_level, input_lwe_ciphertext_count, lut_count, + lut_stride); cuda_synchronize_stream(stream, gpu_index); } @@ -208,7 +210,8 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, CgMultiBit) stream, gpu_index, (pbs_buffer **)&buffer, glwe_dimension, polynomial_size, pbs_level, input_lwe_ciphertext_count, true); - + uint32_t lut_count = 1; + uint32_t lut_stride = 0; for (auto _ : st) { // Execute PBS cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( @@ -216,7 +219,8 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, CgMultiBit) d_lut_pbs_identity, d_lut_pbs_indexes, d_lwe_ct_in_array, d_lwe_input_indexes, d_bsk, (pbs_buffer *)buffer, lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, - pbs_base_log, pbs_level, input_lwe_ciphertext_count); + pbs_base_log, pbs_level, input_lwe_ciphertext_count, lut_count, + lut_stride); cuda_synchronize_stream(stream, gpu_index); } @@ -229,7 +233,8 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, DefaultMultiBit) stream, gpu_index, (pbs_buffer **)&buffer, lwe_dimension, glwe_dimension, polynomial_size, pbs_level, grouping_factor, input_lwe_ciphertext_count, true); - + uint32_t lut_count = 1; + uint32_t lut_stride = 0; for (auto _ : st) { // Execute PBS cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( @@ -237,7 +242,8 @@ BENCHMARK_DEFINE_F(MultiBitBootstrap_u64, DefaultMultiBit) d_lut_pbs_identity, d_lut_pbs_indexes, d_lwe_ct_in_array, d_lwe_input_indexes, d_bsk, (pbs_buffer *)buffer, lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, - pbs_base_log, pbs_level, input_lwe_ciphertext_count); + pbs_base_log, pbs_level, input_lwe_ciphertext_count, lut_count, + lut_stride); cuda_synchronize_stream(stream, gpu_index); } @@ -258,7 +264,8 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, TbcPBC) stream, gpu_index, (pbs_buffer **)&buffer, glwe_dimension, polynomial_size, pbs_level, input_lwe_ciphertext_count, true); - + uint32_t lut_count = 1; + uint32_t lut_stride = 0; for (auto _ : st) { // Execute PBS cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector( @@ -268,7 +275,7 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, TbcPBC) (uint64_t *)d_lwe_input_indexes, (double2 *)d_fourier_bsk, (pbs_buffer *)buffer, lwe_dimension, glwe_dimension, polynomial_size, pbs_base_log, pbs_level, - input_lwe_ciphertext_count); + input_lwe_ciphertext_count, lut_count, lut_stride); cuda_synchronize_stream(stream, gpu_index); } @@ -289,7 +296,8 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, CgPBS) stream, gpu_index, (pbs_buffer **)&buffer, glwe_dimension, polynomial_size, pbs_level, input_lwe_ciphertext_count, true); - + uint32_t lut_count = 1; + uint32_t lut_stride = 0; for (auto _ : st) { // Execute PBS cuda_programmable_bootstrap_cg_lwe_ciphertext_vector( @@ -299,7 +307,7 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, CgPBS) (uint64_t *)d_lwe_input_indexes, (double2 *)d_fourier_bsk, (pbs_buffer *)buffer, lwe_dimension, glwe_dimension, polynomial_size, pbs_base_log, pbs_level, - input_lwe_ciphertext_count); + 
input_lwe_ciphertext_count, lut_count, lut_stride); cuda_synchronize_stream(stream, gpu_index); } @@ -313,7 +321,8 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, DefaultPBS) stream, gpu_index, (pbs_buffer **)&buffer, glwe_dimension, polynomial_size, pbs_level, input_lwe_ciphertext_count, true); - + uint32_t lut_count = 1; + uint32_t lut_stride = 0; for (auto _ : st) { // Execute PBS cuda_programmable_bootstrap_lwe_ciphertext_vector( @@ -323,7 +332,7 @@ BENCHMARK_DEFINE_F(ClassicalBootstrap_u64, DefaultPBS) (uint64_t *)d_lwe_input_indexes, (double2 *)d_fourier_bsk, (pbs_buffer *)buffer, lwe_dimension, glwe_dimension, polynomial_size, pbs_base_log, pbs_level, - input_lwe_ciphertext_count); + input_lwe_ciphertext_count, lut_count, lut_stride); cuda_synchronize_stream(stream, gpu_index); } diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_classical_pbs.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_classical_pbs.cpp index cc6b11ba37..5ddb1430f2 100644 --- a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_classical_pbs.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_classical_pbs.cpp @@ -173,6 +173,8 @@ TEST_P(ClassicalProgrammableBootstrapTestPrimitives_u64, bootstrap) { cudaDeviceGetAttribute(&number_of_sm, cudaDevAttrMultiProcessorCount, 0); int bsk_size = (glwe_dimension + 1) * (glwe_dimension + 1) * pbs_level * polynomial_size * (lwe_dimension + 1); + uint32_t lut_count = 1; + uint32_t lut_stride = 0; // Here execute the PBS for (int r = 0; r < repetitions; r++) { double *d_fourier_bsk = d_fourier_bsk_array + (ptrdiff_t)(bsk_size * r); @@ -190,7 +192,7 @@ TEST_P(ClassicalProgrammableBootstrapTestPrimitives_u64, bootstrap) { (void *)d_lut_pbs_indexes, (void *)d_lwe_ct_in, (void *)d_lwe_input_indexes, (void *)d_fourier_bsk, pbs_buffer, lwe_dimension, glwe_dimension, polynomial_size, pbs_base_log, - pbs_level, number_of_inputs); + pbs_level, number_of_inputs, lut_count, lut_stride); // Copy result back cuda_memcpy_async_to_cpu(lwe_ct_out_array, d_lwe_ct_out_array, (glwe_dimension * polynomial_size + 1) * diff --git a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp index 0a6e00390f..4d05d87412 100644 --- a/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp +++ b/backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/tests/test_multibit_pbs.cpp @@ -119,6 +119,8 @@ TEST_P(MultiBitProgrammableBootstrapTestPrimitives_u64, (glwe_dimension + 1) * (glwe_dimension + 1) * polynomial_size * (1 << grouping_factor); + uint32_t lut_count = 1; + uint32_t lut_stride = 0; for (int r = 0; r < repetitions; r++) { uint64_t *d_bsk = d_bsk_array + (ptrdiff_t)(bsk_size * r); uint64_t *lwe_sk_out = @@ -135,7 +137,7 @@ TEST_P(MultiBitProgrammableBootstrapTestPrimitives_u64, (void *)d_lut_pbs_indexes, (void *)d_lwe_ct_in, (void *)d_lwe_input_indexes, (void *)d_bsk, pbs_buffer, lwe_dimension, glwe_dimension, polynomial_size, grouping_factor, pbs_base_log, - pbs_level, number_of_inputs); + pbs_level, number_of_inputs, lut_count, lut_stride); // Copy result to the host memory cuda_memcpy_async_to_cpu(lwe_ct_out_array, d_lwe_ct_out_array, diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index e707151b42..98fce652c9 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -196,6 +196,20 
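
The Rust bindings below sit on top of a scratch / run / cleanup triple. A C++-side sketch of that sequence for the new many-LUT entry point, mirroring the argument order of the Rust wrapper apply_many_univariate_lut_kb_async further down (all variables are assumed to be prepared by the caller; the scratch and cleanup functions are the same ones the single-LUT path uses, so only the run call differs):

int8_t *mem_ptr = nullptr;
// 1. Allocate scratch and upload the packed accumulator (input_lut).
scratch_cuda_apply_univariate_lut_kb_64(
    streams, gpu_indexes, gpu_count, &mem_ptr, input_lut, lwe_dimension,
    glwe_dimension, polynomial_size, ks_level, ks_base_log, pbs_level,
    pbs_base_log, grouping_factor, num_blocks, message_modulus, carry_modulus,
    pbs_type, /*allocate_gpu_memory=*/true);
// 2. One keyswitch + PBS pass; num_luts LWEs are extracted per input block,
//    lut_stride coefficients apart, into contiguous output batches.
cuda_apply_many_univariate_lut_kb_64(
    streams, gpu_indexes, gpu_count, output_radix_lwe, input_radix_lwe,
    mem_ptr, ksks, bsks, num_blocks, num_luts, lut_stride);
// 3. Release the scratch buffer.
cleanup_cuda_apply_univariate_lut_kb_64(streams, gpu_indexes, gpu_count,
                                        &mem_ptr);
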
@@ extern "C" { num_blocks: u32, ); + pub fn cuda_apply_many_univariate_lut_kb_64( + streams: *const *mut c_void, + gpu_indexes: *const u32, + gpu_count: u32, + output_radix_lwe: *mut c_void, + input_radix_lwe: *const c_void, + mem_ptr: *mut i8, + ksks: *const *mut c_void, + bsks: *const *mut c_void, + num_blocks: u32, + num_luts: u32, + lut_stride: u32, + ); + pub fn cleanup_cuda_apply_univariate_lut_kb_64( streams: *const *mut c_void, gpu_indexes: *const u32, @@ -1152,6 +1166,8 @@ extern "C" { base_log: u32, level_count: u32, num_samples: u32, + lut_count: u32, + lut_stride: u32, ); pub fn cleanup_cuda_programmable_bootstrap( @@ -1203,6 +1219,8 @@ extern "C" { base_log: u32, level_count: u32, num_samples: u32, + lut_count: u32, + lut_stride: u32, ); pub fn cleanup_cuda_multi_bit_programmable_bootstrap( diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index adc6717971..2fecf70271 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -109,6 +109,8 @@ pub unsafe fn programmable_bootstrap_async( level: DecompositionLevelCount, num_samples: u32, ) { + let lut_count = 1u32; + let lut_stride = 0u32; let mut pbs_buffer: *mut i8 = std::ptr::null_mut(); scratch_cuda_programmable_bootstrap_64( streams.ptr[0], @@ -137,6 +139,8 @@ pub unsafe fn programmable_bootstrap_async( base_log.0 as u32, level.0 as u32, num_samples, + lut_count, + lut_stride, ); cleanup_cuda_programmable_bootstrap( streams.ptr[0], @@ -169,6 +173,8 @@ pub unsafe fn programmable_bootstrap_multi_bit_async( grouping_factor: LweBskGroupingFactor, num_samples: u32, ) { + let lut_count = 1u32; + let lut_stride = 0u32; let mut pbs_buffer: *mut i8 = std::ptr::null_mut(); scratch_cuda_multi_bit_programmable_bootstrap_64( streams.ptr[0], @@ -200,6 +206,8 @@ pub unsafe fn programmable_bootstrap_multi_bit_async( base_log.0 as u32, level.0 as u32, num_samples, + lut_count, + lut_stride, ); cleanup_cuda_multi_bit_programmable_bootstrap( streams.ptr[0], diff --git a/tfhe/src/integer/gpu/ciphertext/mod.rs b/tfhe/src/integer/gpu/ciphertext/mod.rs index 51e702d0c8..c7760a39b0 100644 --- a/tfhe/src/integer/gpu/ciphertext/mod.rs +++ b/tfhe/src/integer/gpu/ciphertext/mod.rs @@ -424,7 +424,7 @@ impl CudaRadixCiphertext { /// /// assert_eq!(msg, msg_copied); /// ``` - pub(crate) fn duplicate(&self, streams: &CudaStreams) -> Self { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { let ct = unsafe { self.duplicate_async(streams) }; streams.synchronize(); ct @@ -433,7 +433,7 @@ impl CudaRadixCiphertext { /// /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until streams is synchronised - pub(crate) unsafe fn duplicate_async(&self, streams: &CudaStreams) -> Self { + pub unsafe fn duplicate_async(&self, streams: &CudaStreams) -> Self { let lwe_ciphertext_count = self.d_blocks.lwe_ciphertext_count(); let ciphertext_modulus = self.d_blocks.ciphertext_modulus(); diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index ec9dd5bc5a..ba0bb62986 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -2331,6 +2331,95 @@ pub unsafe fn apply_univariate_lut_kb_async( ); } +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization +/// is required +pub unsafe fn apply_many_univariate_lut_kb_async( + streams: &CudaStreams, + radix_lwe_output: &mut CudaSliceMut, + radix_lwe_input: &CudaSlice, + 
input_lut: &[T], + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + lwe_dimension: LweDimension, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + num_blocks: u32, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + pbs_type: PBSType, + grouping_factor: LweBskGroupingFactor, + lut_count: u32, + lut_stride: u32, +) { + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_input.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_output.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + bootstrapping_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + keyswitch_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + scratch_cuda_apply_univariate_lut_kb_64( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + input_lut.as_ptr().cast(), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + ); + cuda_apply_many_univariate_lut_kb_64( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + radix_lwe_output.as_mut_c_ptr(0), + radix_lwe_input.as_c_ptr(0), + mem_ptr, + keyswitch_key.ptr.as_ptr(), + bootstrapping_key.ptr.as_ptr(), + num_blocks, + lut_count, + lut_stride, + ); + cleanup_cuda_apply_univariate_lut_kb_64( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + ); +} + #[allow(clippy::too_many_arguments)] /// # Safety /// diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index e5b82caa9c..3d90c9acd4 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -13,13 +13,15 @@ use crate::integer::gpu::ciphertext::{ }; use crate::integer::gpu::server_key::CudaBootstrappingKey; use crate::integer::gpu::{ - apply_univariate_lut_kb_async, full_propagate_assign_async, + apply_many_univariate_lut_kb_async, apply_univariate_lut_kb_async, full_propagate_assign_async, propagate_single_carry_assign_async, propagate_single_carry_get_input_carries_assign_async, CudaServerKey, PBSType, }; use crate::shortint::ciphertext::{Degree, NoiseLevel}; -use crate::shortint::engine::fill_accumulator; -use crate::shortint::server_key::{BivariateLookupTableOwned, LookupTableOwned}; +use crate::shortint::engine::{fill_accumulator, fill_many_lut_accumulator}; +use crate::shortint::server_key::{ + BivariateLookupTableOwned, LookupTableOwned, ManyLookupTableOwned, +}; use crate::shortint::PBSOrder; mod add; @@ -676,6 +678,38 @@ impl CudaServerKey { } } + pub fn generate_many_lookup_table( + &self, + functions: &[&dyn Fn(u64) -> u64], + ) -> ManyLookupTableOwned { + let (glwe_size, polynomial_size) = match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + (d_bsk.glwe_dimension.to_glwe_size(), d_bsk.polynomial_size) + } + 
CudaBootstrappingKey::MultiBit(d_bsk) => {
+                (d_bsk.glwe_dimension.to_glwe_size(), d_bsk.polynomial_size)
+            }
+        };
+        let mut acc = GlweCiphertext::new(0, glwe_size, polynomial_size, self.ciphertext_modulus);
+
+        let (input_max_degree, sample_extraction_stride, per_function_output_degree) =
+            fill_many_lut_accumulator(
+                &mut acc,
+                polynomial_size,
+                glwe_size,
+                self.message_modulus,
+                self.carry_modulus,
+                functions,
+            );
+
+        ManyLookupTableOwned {
+            acc,
+            input_max_degree,
+            sample_extraction_stride,
+            per_function_output_degree,
+        }
+    }
+
     /// Generates a bivariate accumulator
     pub(crate) fn generate_lookup_table_bivariate<F>(&self, f: F) -> BivariateLookupTableOwned
     where
@@ -801,6 +835,162 @@ impl CudaServerKey {
             info.noise_level = NoiseLevel::NOMINAL;
         }
     }
+
+    /// Applies several lookup tables to a radix ciphertext in a single
+    /// many-LUT PBS pass, returning one radix ciphertext per table
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::core_crypto::gpu::CudaStreams;
+    /// use tfhe::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
+    /// use tfhe::integer::gpu::gen_keys_gpu;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// let gpu_index = 0;
+    /// let mut stream = CudaStreams::new_single_gpu(gpu_index);
+    /// // Generate the client key and the server key:
+    /// let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &mut stream);
+    /// let num_blocks = 2;
+    /// let msg = 3;
+    /// let ct = cks.encrypt_radix(msg, num_blocks);
+    /// let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &mut stream);
+    /// // Generate one packed lookup table for the two functions
+    /// // f1: x -> x*x mod 4
+    /// // f2: x -> count_ones(x as binary) mod 4
+    /// let f1 = |x: u64| x.pow(2) % 4;
+    /// let f2 = |x: u64| x.count_ones() as u64 % 4;
+    /// let luts = sks.generate_many_lookup_table(&[&f1, &f2]);
+    /// let vec_res = unsafe { sks.apply_many_lookup_table_async(d_ct.as_ref(), &luts, &stream) };
+    /// stream.synchronize();
+    /// // Put the functions behind trait objects so they can be zipped with the results
+    /// let functions: &[&dyn Fn(u64) -> u64] = &[&f1, &f2];
+    /// for (d_res, function) in vec_res.iter().zip(functions) {
+    ///     let d_res_unsigned = CudaUnsignedRadixCiphertext {
+    ///         ciphertext: d_res.duplicate(&stream),
+    ///     };
+    ///     let res = d_res_unsigned.to_radix_ciphertext(&mut stream);
+    ///     let dec: u64 = cks.decrypt_radix(&res);
+    ///     assert_eq!(dec, function(msg));
+    /// }
+    /// ```
+    ///
+    /// # Safety
+    ///
+    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs
+    ///   must not be dropped until `streams` is synchronized
+    pub unsafe fn apply_many_lookup_table_async(
+        &self,
+        input: &CudaRadixCiphertext,
+        lut: &ManyLookupTableOwned,
+        streams: &CudaStreams,
+    ) -> Vec<CudaRadixCiphertext> {
+        let lwe_dimension = input.d_blocks.lwe_dimension();
+        let lwe_size = lwe_dimension.to_lwe_size().0;
+
+        let input_slice = input
+            .d_blocks
+            .0
+            .d_vec
+            .as_slice(.., streams.gpu_indexes[0])
+            .unwrap();
+
+        // Allocate one contiguous output batch per function; the many-LUT PBS
+        // below fills all of them in a single pass.
+        let function_count = lut.function_count();
+        let num_ct_blocks = input.d_blocks.lwe_ciphertext_count().0;
+        let total_radixes_size = num_ct_blocks * lwe_size * function_count;
+        let mut output_radixes = CudaVec::new(total_radixes_size, streams, streams.gpu_indexes[0]);
+
+        let mut output_slice = output_radixes
+            .as_mut_slice(0..total_radixes_size, streams.gpu_indexes[0])
+            .unwrap();
+
+        match &self.bootstrapping_key {
+            CudaBootstrappingKey::Classic(d_bsk) => {
+                apply_many_univariate_lut_kb_async(
+                    streams,
+                    &mut output_slice,
+                    &input_slice,
+                    lut.acc.as_ref(),
+                    &d_bsk.d_vec,
+                    &self.key_switching_key.d_vec,
+                    self.key_switching_key
+                        .output_key_lwe_size()
+                        .to_lwe_dimension(),
+                    d_bsk.glwe_dimension,
+                    d_bsk.polynomial_size,
+                    self.key_switching_key.decomposition_level_count(),
+                    self.key_switching_key.decomposition_base_log(),
+                    d_bsk.decomp_level_count,
+                    d_bsk.decomp_base_log,
+                    num_ct_blocks as u32,
+                    self.message_modulus,
+                    self.carry_modulus,
+                    PBSType::Classical,
+                    LweBskGroupingFactor(0),
+                    function_count as u32,
+                    lut.sample_extraction_stride as u32,
+                );
+            }
+            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                apply_many_univariate_lut_kb_async(
+                    streams,
+                    &mut output_slice,
+                    &input_slice,
+                    lut.acc.as_ref(),
+                    &d_multibit_bsk.d_vec,
+                    &self.key_switching_key.d_vec,
+                    self.key_switching_key
+                        .output_key_lwe_size()
+                        .to_lwe_dimension(),
+                    d_multibit_bsk.glwe_dimension,
+                    d_multibit_bsk.polynomial_size,
+                    self.key_switching_key.decomposition_level_count(),
+                    self.key_switching_key.decomposition_base_log(),
+                    d_multibit_bsk.decomp_level_count,
+                    d_multibit_bsk.decomp_base_log,
+                    num_ct_blocks as u32,
+                    self.message_modulus,
+                    self.carry_modulus,
+                    PBSType::MultiBit,
+                    d_multibit_bsk.grouping_factor,
+                    function_count as u32,
+                    lut.sample_extraction_stride as u32,
+                );
+            }
+        };
+
+        let mut ciphertexts = Vec::<CudaRadixCiphertext>::with_capacity(function_count);
+
+        // Each function's results occupy one contiguous slice of the output.
+        let slice_size = num_ct_blocks * lwe_size;
+        for i in 0..function_count {
+            let mut ct = input.duplicate(streams);
+            let mut ct_slice = ct
+                .d_blocks
+                .0
+                .d_vec
+                .as_mut_slice(0..slice_size, streams.gpu_indexes[0])
+                .unwrap();
+
+            let output_slice = output_radixes
+                .as_mut_slice(slice_size * i..slice_size * (i + 1), streams.gpu_indexes[0])
+                .unwrap();
+
+            ct_slice.copy_from_gpu_async(&output_slice, streams, 0);
+
+            for info in ct.info.blocks.iter_mut() {
+                info.degree = lut.per_function_output_degree[i];
+                info.noise_level = NoiseLevel::NOMINAL;
+            }
+
+            ciphertexts.push(ct);
+        }
+
+        ciphertexts
+    }
 
     /// # Safety
     ///
diff --git a/tfhe/src/shortint/engine/mod.rs b/tfhe/src/shortint/engine/mod.rs
index bc4a84ba01..c87ebe585e 100644
--- a/tfhe/src/shortint/engine/mod.rs
+++ b/tfhe/src/shortint/engine/mod.rs
@@ -162,30 +162,27 @@ pub(crate) fn fill_accumulator_no_encoding(
 /// Fills a GlweCiphertext for use in a ManyLookupTable setting
 pub(crate) fn fill_many_lut_accumulator<C>(
     accumulator: &mut GlweCiphertext<C>,
-    server_key: &ServerKey,
+    polynomial_size: PolynomialSize,
+    glwe_size: GlweSize,
+    message_modulus: MessageModulus,
+    carry_modulus: CarryModulus,
     functions: &[&dyn Fn(u64) -> u64],
 ) -> (MaxDegree, usize, Vec<Degree>)
 where
     C: ContainerMut<Element = u64>,
 {
-    assert_eq!(
-        accumulator.polynomial_size(),
-        server_key.bootstrapping_key.polynomial_size()
-    );
-    assert_eq!(
-        accumulator.glwe_size(),
-        server_key.bootstrapping_key.glwe_size()
-    );
+    assert_eq!(accumulator.polynomial_size(), polynomial_size);
+    assert_eq!(accumulator.glwe_size(), glwe_size);
 
     let mut accumulator_view = accumulator.as_mut_view();
 
     accumulator_view.get_mut_mask().as_mut().fill(0);
 
     // Modulus of the msg contained in the msg bits and operations buffer
-    let modulus_sup = server_key.message_modulus.0 * server_key.carry_modulus.0;
+    let modulus_sup = message_modulus.0 * carry_modulus.0;
 
     // N/(p/2) = size of each block
-    let box_size = server_key.bootstrapping_key.polynomial_size().0 / modulus_sup;
+    let box_size = polynomial_size.0 / modulus_sup;
 
     // Value of the delta we multiply our messages by
     let delta = (1_u64 << 63) / (modulus_sup as u64);
diff --git a/tfhe/src/shortint/server_key/mod.rs b/tfhe/src/shortint/server_key/mod.rs
index 351b0b2cc7..bfcb87bc5d 100644
--- a/tfhe/src/shortint/server_key/mod.rs
+++ b/tfhe/src/shortint/server_key/mod.rs
@@ -883,7 +883,14 @@ impl ServerKey {
             self.ciphertext_modulus,
         );
         let (input_max_degree, sample_extraction_stride, per_function_output_degree) =
-            fill_many_lut_accumulator(&mut acc, self, functions);
+            fill_many_lut_accumulator(
+                &mut acc,
+                self.bootstrapping_key.polynomial_size(),
+                self.bootstrapping_key.glwe_size(),
+                self.message_modulus,
+                self.carry_modulus,
+                functions,
+            );
 
         ManyLookupTableOwned {
             acc,
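
To make the accumulator layout concrete, a worked example of the quantities fill_many_lut_accumulator computes, for hypothetical PARAM_MESSAGE_2_CARRY_2-style parameters (message_modulus = 4, carry_modulus = 4, polynomial_size N = 2048):

// All values follow directly from the formulas above.
uint32_t modulus_sup = 4 * 4;                // 16 encodable (msg, carry) values
uint32_t box_size = 2048 / modulus_sup;      // 128 coefficients per value
uint64_t delta = (1ULL << 63) / modulus_sup; // 2^59, the encoding scale
// During sample extraction, function i's result is read at an extra rotation
// of i * sample_extraction_stride (the lut_stride passed down to the kernels).
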