Skip to content

Commit

Permalink
chore(gpu): add scalar div and signed scalar div to hl api
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Sep 11, 2024
1 parent 9104781 commit 7f2d7f1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
18 changes: 13 additions & 5 deletions tfhe/src/high_level_api/integers/signed/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ where
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
key_method: $key_method:ident,
// A 'list' of tuple, where the first element is the concrete Fhe type
// e.g (FheUint8 and the rest is scalar types (u8, u16, etc)
fhe_and_scalar_type: $(
Expand Down Expand Up @@ -393,15 +392,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
InternalServerKey::Cpu(cpu_key) => {
let (q, r) = cpu_key
.pbs_key()
.$key_method(&*self.ciphertext.on_cpu(), rhs);
.signed_scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
(
<$concrete_type>::new(q, cpu_key.tag.clone()),
<$concrete_type>::new(r, cpu_key.tag.clone())
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices does not support div rem yet")
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
cuda_key.key.signed_scalar_div_rem(
&*self.ciphertext.on_gpu(), rhs, streams
)
});
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
<$concrete_type>::new(r, cuda_key.tag.clone())
)
}
})
}
Expand All @@ -410,8 +418,8 @@ macro_rules! generic_integer_impl_scalar_div_rem {
)* // Closing first repeating pattern
};
}

generic_integer_impl_scalar_div_rem!(
key_method: signed_scalar_div_rem_parallelized,
fhe_and_scalar_type:
(super::FheInt2, i8),
(super::FheInt4, i8),
Expand Down
17 changes: 12 additions & 5 deletions tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,6 @@ where
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
key_method: $key_method:ident,
// A 'list' of tuple, where the first element is the concrete Fhe type
// e.g (FheUint8 and the rest is scalar types (u8, u16, etc)
fhe_and_scalar_type: $(
Expand All @@ -473,15 +472,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
global_state::with_internal_keys(|key| {
match key {
InternalServerKey::Cpu(cpu_key) => {
let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs);
let (q, r) = cpu_key.pbs_key().scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
(
<$concrete_type>::new(q, cpu_key.tag.clone()),
<$concrete_type>::new(r, cpu_key.tag.clone())
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support div_rem yet");
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
cuda_key.key.scalar_div_rem(
&*self.ciphertext.on_gpu(), rhs, streams
)
});
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
<$concrete_type>::new(r, cuda_key.tag.clone())
)
}
}
})
Expand All @@ -492,7 +500,6 @@ macro_rules! generic_integer_impl_scalar_div_rem {
};
}
generic_integer_impl_scalar_div_rem!(
key_method: scalar_div_rem_parallelized,
fhe_and_scalar_type:
(super::FheUint2, u8),
(super::FheUint4, u8),
Expand Down

0 comments on commit 7f2d7f1

Please sign in to comment.