diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index 4d735898d7..89978ebb3a 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -260,9 +260,17 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support leading_zeros yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .leading_zeros(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -296,9 +304,17 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support leading_ones yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .leading_ones(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -332,9 +348,17 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support trailing_zeros yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .trailing_zeros(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -368,9 +392,17 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support 
trailing_ones yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .trailing_ones(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -406,9 +438,15 @@ where crate::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support ilog2 yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key.key.ilog2(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + crate::FheUint32::new(result) + }), }) } @@ -448,9 +486,17 @@ where (crate::FheUint32::new(result), FheBool::new(is_ok)) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support checked_ilog2 yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let (result, is_ok) = cuda_key + .key + .checked_ilog2(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + crate::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + (crate::FheUint32::new(result), FheBool::new(is_ok)) + }), }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index 69bc5a1b6f..1d5367a21d 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -362,9 +362,17 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support leading_zeros yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = 
cuda_key + .key + .leading_zeros(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -398,9 +406,17 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support leading_ones yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .leading_ones(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -434,9 +450,17 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support trailing_zeros yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .trailing_zeros(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -470,9 +494,17 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support trailing_ones yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let result = cuda_key + .key + .trailing_ones(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key.key.cast_to_unsigned( + result, + super::FheUint32Id::num_blocks(cuda_key.key.message_modulus), + streams, + ); + super::FheUint32::new(result) + }), }) } @@ -508,9 +540,15 @@ where super::FheUint32::new(result) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support ilog2 yet"); 
-            }
+            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
+                let result = cuda_key.key.ilog2(&*self.ciphertext.on_gpu(), streams);
+                let result = cuda_key.key.cast_to_unsigned(
+                    result,
+                    super::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
+                    streams,
+                );
+                super::FheUint32::new(result)
+            }),
         })
     }
 
@@ -550,9 +588,17 @@ where
                 (super::FheUint32::new(result), FheBool::new(is_ok))
             }
             #[cfg(feature = "gpu")]
-            InternalServerKey::Cuda(_) => {
-                panic!("Cuda devices do not support checked_ilog2 yet");
-            }
+            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
+                let (result, is_ok) = cuda_key
+                    .key
+                    .checked_ilog2(&*self.ciphertext.on_gpu(), streams);
+                let result = cuda_key.key.cast_to_unsigned(
+                    result,
+                    super::FheUint32Id::num_blocks(cuda_key.key.message_modulus),
+                    streams,
+                );
+                (super::FheUint32::new(result), FheBool::new(is_ok))
+            }),
         })
     }
 
diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs
index 536f6836a3..9c927051ab 100644
--- a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs
@@ -121,3 +121,31 @@ fn test_is_even_is_odd_gpu_multibit() {
     let client_key = setup_gpu(Some(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
     super::test_case_is_even_is_odd(&client_key);
 }
+
+/// Runs the shared leading/trailing zeros/ones test case with the default GPU setup.
+#[test]
+fn test_leading_trailing_zeros_ones_gpu() {
+    let client_key = setup_default_gpu();
+    super::test_case_leading_trailing_zeros_ones(&client_key);
+}
+
+/// Runs the shared leading/trailing zeros/ones test case with the multi-bit parameter set.
+#[test]
+fn test_leading_trailing_zeros_ones_gpu_multibit() {
+    let client_key = setup_gpu(Some(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
+    super::test_case_leading_trailing_zeros_ones(&client_key);
+}
+
+/// Runs the shared ilog2 test case with the default GPU setup.
+#[test]
+fn test_ilog2_gpu() {
+    let client_key = setup_default_gpu();
+    super::test_case_ilog2(&client_key);
+}
+
+/// Runs the shared ilog2 test case with the multi-bit parameter set.
+#[test]
+fn test_ilog2_gpu_multibit() {
+    let client_key = setup_gpu(Some(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
+    super::test_case_ilog2(&client_key);
+}