From 7503b30d81b25ef84100c1e889723c5a25433f68 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Mon, 2 Sep 2024 17:31:09 +0200 Subject: [PATCH] feat(hlapi): bind cuda's trailing/leading_ones/zeros, ilog2 --- .../high_level_api/integers/signed/base.rs | 82 +++++++++++++++---- .../high_level_api/integers/unsigned/base.rs | 82 +++++++++++++++---- .../integers/unsigned/tests/gpu.rs | 24 ++++++ 3 files changed, 152 insertions(+), 36 deletions(-) diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index 4d735898d7..89978ebb3a 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -260,9 +260,17 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support leading_zeros yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .leading_zeros(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -296,9 +304,17 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support leading_ones yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .leading_ones(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -332,9 +348,17 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support trailing_zeros yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .trailing_zeros(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -368,9 +392,17 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support trailing_ones yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .trailing_ones(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -406,9 +438,15 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support ilog2 yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key.key.ilog2(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -448,9 +486,17 @@ where (crate::FheUint32::new(result), FheBool::new(is_ok)) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support checked_ilog2 yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let (result, is_ok) = cuda_key + .key + .checked_ilog2(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + (crate::FheUint32::new(result), FheBool::new(is_ok)) + }), }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index 69bc5a1b6f..1d5367a21d 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -362,9 +362,17 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support leading_zeros yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .leading_zeros(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -398,9 +406,17 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support leading_ones yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .leading_ones(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -434,9 +450,17 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support trailing_zeros yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .trailing_zeros(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -470,9 +494,17 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support trailing_ones yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .trailing_ones(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -508,9 +540,15 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support ilog2 yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key.key.ilog2(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -550,9 +588,17 @@ where (super::FheUint32::new(result), FheBool::new(is_ok)) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support checked_ilog2 yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let (result, is_ok) = cuda_key + .key + .checked_ilog2(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + (super::FheUint32::new(result), FheBool::new(is_ok)) + }), }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs index 536f6836a3..9c927051ab 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs @@ -121,3 +121,27 @@ fn test_is_even_is_odd_gpu_multibit() { let client_key = setup_gpu(Some(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS)); super::test_case_is_even_is_odd(&client_key); } + +#[test] +fn test_leading_trailing_zeros_ones_gpu() { + let client_key = setup_default_gpu(); + super::test_case_leading_trailing_zeros_ones(&client_key); +} + +#[test] +fn test_leading_trailing_zeros_ones_gpu_multibit() { + let client_key = setup_gpu(Some(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS)); + super::test_case_leading_trailing_zeros_ones(&client_key); +} + +#[test] +fn test_ilog2_gpu() { + let client_key = setup_default_gpu(); + super::test_case_ilog2(&client_key); +} + +#[test] +fn test_ilog2_multibit() { + let client_key = setup_gpu(Some(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS)); + super::test_case_ilog2(&client_key); +}