diff --git a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h index c72ee20b15..4b74189d5e 100644 --- a/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h +++ b/backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap_multibit.h @@ -17,8 +17,7 @@ void cuda_convert_lwe_multi_bit_programmable_bootstrap_key_64( void scratch_cuda_multi_bit_programmable_bootstrap_64( void *stream, uint32_t gpu_index, int8_t **pbs_buffer, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t grouping_factor, + uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory); void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64( @@ -130,7 +129,7 @@ template struct pbs_buffer { int8_t *d_mem_acc_step_two = NULL; int8_t *d_mem_acc_cg = NULL; int8_t *d_mem_acc_tbc = NULL; - + uint32_t lwe_chunk_size; double2 *keybundle_fft; Torus *global_accumulator; double2 *global_accumulator_fft; @@ -142,6 +141,7 @@ template struct pbs_buffer { uint32_t input_lwe_ciphertext_count, uint32_t lwe_chunk_size, PBS_VARIANT pbs_variant, bool allocate_gpu_memory) { this->pbs_variant = pbs_variant; + this->lwe_chunk_size = lwe_chunk_size; auto max_shared_memory = cuda_get_max_shared_memory(gpu_index); // default diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh index e3bf1471b7..459a496d11 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh @@ -271,9 +271,8 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index, if (grouping_factor == 0) PANIC("Multi-bit PBS error: grouping factor should be > 0.") scratch_cuda_multi_bit_programmable_bootstrap_64( - stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension, - polynomial_size, level_count, grouping_factor, - input_lwe_ciphertext_count, allocate_gpu_memory); + stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + level_count, input_lwe_ciphertext_count, allocate_gpu_memory); break; case CLASSICAL: scratch_cuda_programmable_bootstrap_64( diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh index d17a953151..f26f45b810 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh @@ -229,9 +229,9 @@ __host__ void execute_cg_external_product_loop( pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, - uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count, - uint32_t lut_stride) { + uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) { + auto lwe_chunk_size = buffer->lwe_chunk_size; uint64_t full_dm = get_buffer_size_full_sm_cg_multibit_programmable_bootstrap( polynomial_size); @@ -314,8 +314,7 @@ __host__ void host_cg_multi_bit_programmable_bootstrap( uint32_t base_log, uint32_t level_count, uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { - auto lwe_chunk_size = get_lwe_chunk_size( - gpu_index, num_samples, polynomial_size); + auto lwe_chunk_size = buffer->lwe_chunk_size; for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); lwe_offset += lwe_chunk_size) { @@ -324,15 +323,15 @@ __host__ void host_cg_multi_bit_programmable_bootstrap( execute_compute_keybundle( stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset); + grouping_factor, level_count, lwe_offset); // Accumulate execute_cg_external_product_loop( stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset, - lut_count, lut_stride); + grouping_factor, base_log, level_count, lwe_offset, lut_count, + lut_stride); } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu index 98f8074a09..49a1c75350 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu @@ -331,52 +331,51 @@ void scratch_cuda_cg_multi_bit_programmable_bootstrap( template void scratch_cuda_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, pbs_buffer **buffer, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t level_count, uint32_t grouping_factor, + uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { switch (polynomial_size) { case 256: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 512: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 1024: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 2048: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 4096: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 8192: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; case 16384: scratch_multi_bit_programmable_bootstrap>( - static_cast(stream), gpu_index, buffer, lwe_dimension, - glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, grouping_factor, allocate_gpu_memory); + static_cast(stream), gpu_index, buffer, glwe_dimension, + polynomial_size, level_count, input_lwe_ciphertext_count, + allocate_gpu_memory); break; default: PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. Supported " @@ -386,10 +385,9 @@ void scratch_cuda_multi_bit_programmable_bootstrap( } void scratch_cuda_multi_bit_programmable_bootstrap_64( - void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t lwe_dimension, - uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, - uint32_t grouping_factor, uint32_t input_lwe_ciphertext_count, - bool allocate_gpu_memory) { + void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t level_count, + uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { #if (CUDA_ARCH >= 900) if (has_support_to_cuda_programmable_bootstrap_tbc_multi_bit( @@ -411,8 +409,8 @@ void scratch_cuda_multi_bit_programmable_bootstrap_64( else scratch_cuda_multi_bit_programmable_bootstrap( stream, gpu_index, (pbs_buffer **)buffer, - lwe_dimension, glwe_dimension, polynomial_size, level_count, - grouping_factor, input_lwe_ciphertext_count, allocate_gpu_memory); + glwe_dimension, polynomial_size, level_count, + input_lwe_ciphertext_count, allocate_gpu_memory); } void cleanup_cuda_multi_bit_programmable_bootstrap(void *stream, @@ -490,10 +488,9 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs, template void scratch_cuda_multi_bit_programmable_bootstrap( void *stream, uint32_t gpu_index, - pbs_buffer **pbs_buffer, uint32_t lwe_dimension, - uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, - uint32_t grouping_factor, uint32_t input_lwe_ciphertext_count, - bool allocate_gpu_memory); + pbs_buffer **pbs_buffer, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t level_count, + uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory); template void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector( diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh index 74b3669479..455233f057 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh @@ -385,10 +385,9 @@ uint64_t get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( template __host__ void scratch_multi_bit_programmable_bootstrap( cudaStream_t stream, uint32_t gpu_index, - pbs_buffer **buffer, uint32_t lwe_dimension, - uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count, - uint32_t input_lwe_ciphertext_count, uint32_t grouping_factor, - bool allocate_gpu_memory) { + pbs_buffer **buffer, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t level_count, + uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { auto lwe_chunk_size = get_lwe_chunk_size( gpu_index, input_lwe_ciphertext_count, polynomial_size); @@ -404,9 +403,9 @@ __host__ void execute_compute_keybundle( Torus *lwe_input_indexes, Torus *bootstrapping_key, pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, - uint32_t lwe_chunk_size, uint32_t lwe_offset) { + uint32_t grouping_factor, uint32_t level_count, uint32_t lwe_offset) { + auto lwe_chunk_size = buffer->lwe_chunk_size; uint32_t chunk_size = std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); if (chunk_size == 0) @@ -507,9 +506,9 @@ __host__ void execute_step_two( Torus *lwe_output_indexes, pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, int32_t grouping_factor, uint32_t level_count, - uint32_t j, uint32_t lwe_offset, uint32_t lwe_chunk_size, - uint32_t lut_count, uint32_t lut_stride) { + uint32_t j, uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) { + auto lwe_chunk_size = buffer->lwe_chunk_size; uint64_t full_sm_accumulate_step_two = get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two( polynomial_size); @@ -555,8 +554,7 @@ __host__ void host_multi_bit_programmable_bootstrap( uint32_t base_log, uint32_t level_count, uint32_t num_samples, uint32_t lut_count, uint32_t lut_stride) { - auto lwe_chunk_size = get_lwe_chunk_size( - gpu_index, num_samples, polynomial_size); + auto lwe_chunk_size = buffer->lwe_chunk_size; for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); lwe_offset += lwe_chunk_size) { @@ -565,7 +563,7 @@ __host__ void host_multi_bit_programmable_bootstrap( execute_compute_keybundle( stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset); + grouping_factor, level_count, lwe_offset); // Accumulate uint32_t chunk_size = std::min( lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); @@ -578,8 +576,7 @@ __host__ void host_multi_bit_programmable_bootstrap( execute_step_two( stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, level_count, j, lwe_offset, lwe_chunk_size, - lut_count, lut_stride); + grouping_factor, level_count, j, lwe_offset, lut_count, lut_stride); } } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh index b1fac308ac..21d007a1cc 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh @@ -220,9 +220,9 @@ __host__ void execute_tbc_external_product_loop( pbs_buffer *buffer, uint32_t num_samples, uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log, uint32_t level_count, - uint32_t lwe_chunk_size, uint32_t lwe_offset, uint32_t lut_count, - uint32_t lut_stride) { + uint32_t lwe_offset, uint32_t lut_count, uint32_t lut_stride) { + auto lwe_chunk_size = buffer->lwe_chunk_size; auto supports_dsm = supports_distributed_shared_memory_on_multibit_programmable_bootstrap< Torus>(polynomial_size); @@ -325,9 +325,7 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap( uint32_t lut_count, uint32_t lut_stride) { cudaSetDevice(gpu_index); - auto lwe_chunk_size = get_lwe_chunk_size( - gpu_index, num_samples, polynomial_size); - + auto lwe_chunk_size = buffer->lwe_chunk_size; for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); lwe_offset += lwe_chunk_size) { @@ -335,15 +333,15 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap( execute_compute_keybundle( stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset); + grouping_factor, level_count, lwe_offset); // Accumulate execute_tbc_external_product_loop( stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, base_log, level_count, lwe_chunk_size, lwe_offset, - lut_count, lut_stride); + grouping_factor, base_log, level_count, lwe_offset, lut_count, + lut_stride); } } diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index 1ddeb12d2f..add83c92e6 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -1192,11 +1192,9 @@ extern "C" { stream: *mut c_void, gpu_index: u32, pbs_buffer: *mut *mut i8, - lwe_dimension: u32, glwe_dimension: u32, polynomial_size: u32, level_count: u32, - grouping_factor: u32, input_lwe_ciphertext_count: u32, allocate_gpu_memory: bool, ); diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index 2fecf70271..d1fde41f56 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -180,11 +180,9 @@ pub unsafe fn programmable_bootstrap_multi_bit_async( streams.ptr[0], streams.gpu_indexes[0], std::ptr::addr_of_mut!(pbs_buffer), - lwe_dimension.0 as u32, glwe_dimension.0 as u32, polynomial_size.0 as u32, level.0 as u32, - grouping_factor.0 as u32, num_samples, true, );