diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index d6b56ab5e5..0943b10003 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -976,28 +976,52 @@ template struct int_shift_and_rotate_buffer { (params.big_lwe_dimension + 1) * sizeof(Torus), streams[0], gpu_indexes[0]); + cuda_memset_async(tmp_bits, 0, + bits_per_block * num_radix_blocks * + (params.big_lwe_dimension + 1) * sizeof(Torus), + streams[0], gpu_indexes[0]); tmp_shift_bits = (Torus *)cuda_malloc_async( max_num_bits_that_tell_shift * num_radix_blocks * (params.big_lwe_dimension + 1) * sizeof(Torus), streams[0], gpu_indexes[0]); + cuda_memset_async(tmp_shift_bits, 0, + max_num_bits_that_tell_shift * num_radix_blocks * + (params.big_lwe_dimension + 1) * sizeof(Torus), + streams[0], gpu_indexes[0]); tmp_rotated = (Torus *)cuda_malloc_async( bits_per_block * num_radix_blocks * (params.big_lwe_dimension + 1) * sizeof(Torus), streams[0], gpu_indexes[0]); + cuda_memset_async(tmp_rotated, 0, + bits_per_block * num_radix_blocks * + (params.big_lwe_dimension + 1) * sizeof(Torus), + streams[0], gpu_indexes[0]); tmp_input_bits_a = (Torus *)cuda_malloc_async( bits_per_block * num_radix_blocks * (params.big_lwe_dimension + 1) * sizeof(Torus), streams[0], gpu_indexes[0]); + cuda_memset_async(tmp_input_bits_a, 0, + bits_per_block * num_radix_blocks * + (params.big_lwe_dimension + 1) * sizeof(Torus), + streams[0], gpu_indexes[0]); tmp_input_bits_b = (Torus *)cuda_malloc_async( bits_per_block * num_radix_blocks * (params.big_lwe_dimension + 1) * sizeof(Torus), streams[0], gpu_indexes[0]); + cuda_memset_async(tmp_input_bits_b, 0, + bits_per_block * num_radix_blocks * + (params.big_lwe_dimension + 1) * sizeof(Torus), + streams[0], gpu_indexes[0]); tmp_mux_inputs = (Torus *)cuda_malloc_async( bits_per_block * num_radix_blocks * (params.big_lwe_dimension + 1) * sizeof(Torus), streams[0], gpu_indexes[0]); + cuda_memset_async(tmp_mux_inputs, 0, + bits_per_block * num_radix_blocks * + (params.big_lwe_dimension + 1) * sizeof(Torus), + streams[0], gpu_indexes[0]); auto mux_lut_f = [](Torus x) -> Torus { // x is expected to be x = 0bcba @@ -1157,6 +1181,11 @@ template struct int_sc_prop_memory { num_radix_blocks * big_lwe_size_bytes, streams[0], gpu_indexes[0]); step_output = (Torus *)cuda_malloc_async( num_radix_blocks * big_lwe_size_bytes, streams[0], gpu_indexes[0]); + cuda_memset_async(generates_or_propagates, 0, + num_radix_blocks * big_lwe_size_bytes, streams[0], + gpu_indexes[0]); + cuda_memset_async(step_output, 0, num_radix_blocks * big_lwe_size_bytes, + streams[0], gpu_indexes[0]); // declare functions for lut generation auto f_lut_does_block_generate_carry = [message_modulus](Torus x) -> Torus { @@ -1273,6 +1302,11 @@ template struct int_overflowing_sub_memory { num_radix_blocks * big_lwe_size_bytes, streams[0], gpu_indexes[0]); step_output = (Torus *)cuda_malloc_async( num_radix_blocks * big_lwe_size_bytes, streams[0], gpu_indexes[0]); + cuda_memset_async(generates_or_propagates, 0, + num_radix_blocks * big_lwe_size_bytes, streams[0], + gpu_indexes[0]); + cuda_memset_async(step_output, 0, num_radix_blocks * big_lwe_size_bytes, + streams[0], gpu_indexes[0]); // declare functions for lut generation auto f_lut_does_block_generate_carry = [message_modulus](Torus x) -> Torus { @@ -1399,11 +1433,31 @@ template struct int_sum_ciphertexts_vec_memory { small_lwe_vector = (Torus *)cuda_malloc_async( max_pbs_count * (params.small_lwe_dimension + 1) * sizeof(Torus), streams[0], gpu_indexes[0]); + cuda_memset_async(new_blocks, 0, + max_pbs_count * (params.big_lwe_dimension + 1) * + sizeof(Torus), + streams[0], gpu_indexes[0]); + cuda_memset_async(new_blocks_copy, 0, + max_pbs_count * (params.big_lwe_dimension + 1) * + sizeof(Torus), + streams[0], gpu_indexes[0]); + cuda_memset_async(old_blocks, 0, + max_pbs_count * (params.big_lwe_dimension + 1) * + sizeof(Torus), + streams[0], gpu_indexes[0]); + cuda_memset_async(small_lwe_vector, 0, + max_pbs_count * (params.small_lwe_dimension + 1) * + sizeof(Torus), + streams[0], gpu_indexes[0]); d_smart_copy_in = (int32_t *)cuda_malloc_async( max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]); d_smart_copy_out = (int32_t *)cuda_malloc_async( max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]); + cuda_memset_async(d_smart_copy_in, 0, max_pbs_count * sizeof(int32_t), + streams[0], gpu_indexes[0]); + cuda_memset_async(d_smart_copy_out, 0, max_pbs_count * sizeof(int32_t), + streams[0], gpu_indexes[0]); } int_sum_ciphertexts_vec_memory(cudaStream_t *streams, uint32_t *gpu_indexes, @@ -1427,11 +1481,19 @@ template struct int_sum_ciphertexts_vec_memory { new_blocks_copy = (Torus *)cuda_malloc_async( max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus), streams[0], gpu_indexes[0]); + cuda_memset_async(new_blocks_copy, 0, + max_pbs_count * (params.big_lwe_dimension + 1) * + sizeof(Torus), + streams[0], gpu_indexes[0]); d_smart_copy_in = (int32_t *)cuda_malloc_async( max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]); d_smart_copy_out = (int32_t *)cuda_malloc_async( max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]); + cuda_memset_async(d_smart_copy_in, 0, max_pbs_count * sizeof(int32_t), + streams[0], gpu_indexes[0]); + cuda_memset_async(d_smart_copy_out, 0, max_pbs_count * sizeof(int32_t), + streams[0], gpu_indexes[0]); } void release(cudaStream_t *streams, uint32_t *gpu_indexes, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh index ffd70a75e5..92f61b830f 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh @@ -135,8 +135,6 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace( // host_pack bits into one block so that we have // control_bit|b|a - cuda_memset_async(mux_inputs, 0, total_nb_bits * big_lwe_size_bytes, - streams[0], gpu_indexes[0]); // Do we need this? pack_bivariate_blocks(streams, gpu_indexes, gpu_count, mux_inputs, mux_lut->lwe_indexes_out, rotated_input, input_bits_a, mux_lut->lwe_indexes_in,