From 4e8bdc4380e72aaa6d5c4d5977d42a2d970d33b3 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Mon, 17 Jun 2024 11:17:51 +0200 Subject: [PATCH] chore(gpu): add scalar div and signed scalar div to hl api --- .../integers/signed/scalar_ops.rs | 18 +++++++++++++----- .../integers/unsigned/scalar_ops.rs | 17 ++++++++++++----- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs index 43779dda08..97eb5ac28e 100644 --- a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs @@ -365,7 +365,6 @@ where // DivRem is a bit special as it returns a tuple of quotient and remainder macro_rules! generic_integer_impl_scalar_div_rem { ( - key_method: $key_method:ident, // A 'list' of tuple, where the first element is the concrete Fhe type // e.g (FheUint8 and the rest is scalar types (u8, u16, etc) fhe_and_scalar_type: $( @@ -393,15 +392,24 @@ macro_rules! generic_integer_impl_scalar_div_rem { InternalServerKey::Cpu(cpu_key) => { let (q, r) = cpu_key .pbs_key() - .$key_method(&*self.ciphertext.on_cpu(), rhs); + .signed_scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs); ( <$concrete_type>::new(q, cpu_key.tag.clone()), <$concrete_type>::new(r, cpu_key.tag.clone()) ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices does not support div rem yet") + InternalServerKey::Cuda(cuda_key) => { + let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { + cuda_key.key.signed_scalar_div_rem( + &*self.ciphertext.on_gpu(), rhs, streams + ) + }); + let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r)); + ( + <$concrete_type>::new(q, cuda_key.tag.clone()), + <$concrete_type>::new(r, cuda_key.tag.clone()) + ) } }) } @@ -410,8 +418,8 @@ macro_rules! generic_integer_impl_scalar_div_rem { )* // Closing first repeating pattern }; } + generic_integer_impl_scalar_div_rem!( - key_method: signed_scalar_div_rem_parallelized, fhe_and_scalar_type: (super::FheInt2, i8), (super::FheInt4, i8), diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 451093555a..2bb3d56d4a 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -446,7 +446,6 @@ where // DivRem is a bit special as it returns a tuple of quotient and remainder macro_rules! generic_integer_impl_scalar_div_rem { ( - key_method: $key_method:ident, // A 'list' of tuple, where the first element is the concrete Fhe type // e.g (FheUint8 and the rest is scalar types (u8, u16, etc) fhe_and_scalar_type: $( @@ -473,15 +472,24 @@ macro_rules! generic_integer_impl_scalar_div_rem { global_state::with_internal_keys(|key| { match key { InternalServerKey::Cpu(cpu_key) => { - let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs); + let (q, r) = cpu_key.pbs_key().scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs); ( <$concrete_type>::new(q, cpu_key.tag.clone()), <$concrete_type>::new(r, cpu_key.tag.clone()) ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support div_rem yet"); + InternalServerKey::Cuda(cuda_key) => { + let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { + cuda_key.key.scalar_div_rem( + &*self.ciphertext.on_gpu(), rhs, streams + ) + }); + let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r)); + ( + <$concrete_type>::new(q, cuda_key.tag.clone()), + <$concrete_type>::new(r, cuda_key.tag.clone()) + ) } } }) @@ -492,7 +500,6 @@ macro_rules! generic_integer_impl_scalar_div_rem { }; } generic_integer_impl_scalar_div_rem!( - key_method: scalar_div_rem_parallelized, fhe_and_scalar_type: (super::FheUint2, u8), (super::FheUint4, u8),