diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh index 25701aca92..92cf753d14 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh @@ -480,20 +480,30 @@ __host__ void host_programmable_bootstrap( double2 *global_join_buffer = pbs_buffer->global_join_buffer; int8_t *d_mem = pbs_buffer->d_mem; + bool graphCreated = false; + cudaGraph_t graph; + cudaGraphExec_t instance; for (int i = 0; i < lwe_dimension; i++) { - execute_step_one( - stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, - lwe_input_indexes, bootstrapping_key, global_accumulator, - global_join_buffer, input_lwe_ciphertext_count, lwe_dimension, - glwe_dimension, polynomial_size, base_log, level_count, d_mem, i, - partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one); - execute_step_two( - stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, - lut_vector_indexes, bootstrapping_key, global_accumulator, - global_join_buffer, input_lwe_ciphertext_count, lwe_dimension, - glwe_dimension, polynomial_size, base_log, level_count, d_mem, i, - partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two, - num_many_lut, lut_stride); + if (!graphCreated) { + cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); + execute_step_one( + stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, + lwe_input_indexes, bootstrapping_key, global_accumulator, + global_join_buffer, input_lwe_ciphertext_count, lwe_dimension, + glwe_dimension, polynomial_size, base_log, level_count, d_mem, i, + partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one); + execute_step_two( + stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector, + lut_vector_indexes, bootstrapping_key, global_accumulator, + global_join_buffer, input_lwe_ciphertext_count, lwe_dimension, + glwe_dimension, polynomial_size, base_log, level_count, d_mem, i, + partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two, + num_many_lut, lut_stride); + cudaStreamEndCapture(stream, &graph); + cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); + graphCreated = true; + } + cudaGraphLaunch(instance, stream); } } diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh index ba73d29bf7..2dc966a9ef 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh @@ -649,29 +649,41 @@ __host__ void host_multi_bit_programmable_bootstrap( auto lwe_chunk_size = buffer->lwe_chunk_size; + bool graphCreated = false; + cudaGraph_t graph; + cudaGraphExec_t instance; + for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor); lwe_offset += lwe_chunk_size) { - // Compute a keybundle - execute_compute_keybundle( - stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, - buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, level_count, lwe_offset); - // Accumulate - uint32_t chunk_size = std::min( - lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); - for (uint32_t j = 0; j < chunk_size; j++) { - execute_step_one( - stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, - lwe_input_indexes, buffer, num_samples, lwe_dimension, glwe_dimension, - polynomial_size, base_log, level_count, j, lwe_offset); - - execute_step_two( - stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer, - num_samples, lwe_dimension, glwe_dimension, polynomial_size, - grouping_factor, level_count, j, lwe_offset, num_many_lut, - lut_stride); + if (!graphCreated) { + cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); + // Compute a keybundle + execute_compute_keybundle( + stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, + buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, + grouping_factor, level_count, lwe_offset); + // Accumulate + uint32_t chunk_size = std::min( + lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset); + for (uint32_t j = 0; j < chunk_size; j++) { + execute_step_one( + stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, + lwe_input_indexes, buffer, num_samples, lwe_dimension, + glwe_dimension, polynomial_size, base_log, level_count, j, + lwe_offset); + + execute_step_two( + stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer, + num_samples, lwe_dimension, glwe_dimension, polynomial_size, + grouping_factor, level_count, j, lwe_offset, num_many_lut, + lut_stride); + } + cudaStreamEndCapture(stream, &graph); + cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); + graphCreated = true; } + cudaGraphLaunch(instance, stream); } } #endif // MULTIBIT_PBS_H