Skip to content

Commit

Permalink
feat(gpu): signed scalar rotate
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Apr 4, 2024
1 parent 4c8528d commit 971b0cf
Show file tree
Hide file tree
Showing 12 changed files with 777 additions and 598 deletions.
28 changes: 28 additions & 0 deletions tfhe/benches/integer/signed_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1722,6 +1722,18 @@ mod cuda {
rng_func: shift_scalar
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_rotate_right,
display_name: rotate_right,
rng_func: shift_scalar
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_rotate_left,
display_name: rotate_left,
rng_func: shift_scalar
);

//===========================================
// Default
//===========================================
Expand Down Expand Up @@ -1874,6 +1886,18 @@ mod cuda {
rng_func: shift_scalar
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_rotate_left,
display_name: rotate_left,
rng_func: shift_scalar
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_rotate_right,
display_name: rotate_right,
rng_func: shift_scalar
);

criterion_group!(
unchecked_cuda_ops,
cuda_unchecked_add,
Expand Down Expand Up @@ -1908,6 +1932,8 @@ mod cuda {
cuda_unchecked_scalar_bitxor,
cuda_unchecked_scalar_left_shift,
cuda_unchecked_scalar_right_shift,
cuda_unchecked_scalar_rotate_left,
cuda_unchecked_scalar_rotate_right,
);

criterion_group!(
Expand Down Expand Up @@ -1944,6 +1970,8 @@ mod cuda {
cuda_scalar_bitxor,
cuda_scalar_left_shift,
cuda_scalar_right_shift,
cuda_scalar_rotate_left,
cuda_scalar_rotate_right,
);
}

Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ generic_integer_impl_scalar_operation!(
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_stream(|stream| {
cuda_key.key.scalar_rotate_left(
&lhs.ciphertext.on_gpu(), u64::cast_from(rhs), stream
&*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), stream
)
});
RadixCiphertext::Cuda(inner_result)
Expand Down Expand Up @@ -808,7 +808,7 @@ generic_integer_impl_scalar_operation!(
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_stream(|stream| {
cuda_key.key.scalar_rotate_right(
&lhs.ciphertext.on_gpu(), u64::cast_from(rhs), stream
&*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), stream
)
});
RadixCiphertext::Cuda(inner_result)
Expand Down
126 changes: 59 additions & 67 deletions tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::core_crypto::gpu::CudaStream;
use crate::core_crypto::prelude::CastFrom;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::CudaServerKey;

Expand All @@ -9,15 +9,16 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_rotate_left_async<T>(
pub unsafe fn unchecked_scalar_rotate_left_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
n: T,
ct: &T,
n: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
T: CastFrom<u32>,
u32: CastFrom<T>,
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let mut result = ct.duplicate_async(stream);
self.unchecked_scalar_rotate_left_assign_async(&mut result, n, stream);
Expand All @@ -28,14 +29,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_rotate_left_assign_async<T>(
pub unsafe fn unchecked_scalar_rotate_left_assign_async<Scalar, T>(
&self,
ct: &mut CudaUnsignedRadixCiphertext,
n: T,
ct: &mut T,
n: Scalar,
stream: &CudaStream,
) where
T: CastFrom<u32>,
u32: CastFrom<T>,
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
match &self.bootstrapping_key {
Expand Down Expand Up @@ -89,15 +91,16 @@ impl CudaServerKey {
}
}

pub fn unchecked_scalar_rotate_left<T>(
pub fn unchecked_scalar_rotate_left<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
n: T,
ct: &T,
n: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
T: CastFrom<u32>,
u32: CastFrom<T>,
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let result = unsafe { self.unchecked_scalar_rotate_left_async(ct, n, stream) };
stream.synchronize();
Expand All @@ -108,15 +111,16 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_rotate_right_async<T>(
pub unsafe fn unchecked_scalar_rotate_right_async<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
n: T,
ct: &T,
n: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
T: CastFrom<u32>,
u32: CastFrom<T>,
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let mut result = ct.duplicate_async(stream);
self.unchecked_scalar_rotate_right_assign_async(&mut result, n, stream);
Expand All @@ -127,14 +131,15 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_rotate_right_assign_async<T>(
pub unsafe fn unchecked_scalar_rotate_right_assign_async<Scalar, T>(
&self,
ct: &mut CudaUnsignedRadixCiphertext,
n: T,
ct: &mut T,
n: Scalar,
stream: &CudaStream,
) where
T: CastFrom<u32>,
u32: CastFrom<T>,
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
match &self.bootstrapping_key {
Expand Down Expand Up @@ -188,29 +193,27 @@ impl CudaServerKey {
}
}

pub fn unchecked_scalar_rotate_right<T>(
pub fn unchecked_scalar_rotate_right<Scalar, T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
n: T,
ct: &T,
n: Scalar,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
) -> T
where
T: CastFrom<u32>,
u32: CastFrom<T>,
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let result = unsafe { self.unchecked_scalar_rotate_right_async(ct, n, stream) };
stream.synchronize();
result
}

pub fn scalar_rotate_left_assign<T>(
&self,
ct: &mut CudaUnsignedRadixCiphertext,
n: T,
stream: &CudaStream,
) where
T: CastFrom<u32>,
u32: CastFrom<T>,
pub fn scalar_rotate_left_assign<Scalar, T>(&self, ct: &mut T, n: Scalar, stream: &CudaStream)
where
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
unsafe {
Expand All @@ -222,14 +225,11 @@ impl CudaServerKey {
stream.synchronize();
}

pub fn scalar_rotate_right_assign<T>(
&self,
ct: &mut CudaUnsignedRadixCiphertext,
n: T,
stream: &CudaStream,
) where
T: CastFrom<u32>,
u32: CastFrom<T>,
pub fn scalar_rotate_right_assign<Scalar, T>(&self, ct: &mut T, n: Scalar, stream: &CudaStream)
where
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
unsafe {
Expand All @@ -241,30 +241,22 @@ impl CudaServerKey {
stream.synchronize();
}

pub fn scalar_rotate_left<T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
shift: T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
pub fn scalar_rotate_left<Scalar, T>(&self, ct: &T, shift: Scalar, stream: &CudaStream) -> T
where
T: CastFrom<u32>,
u32: CastFrom<T>,
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let mut result = unsafe { ct.duplicate_async(stream) };
self.scalar_rotate_left_assign(&mut result, shift, stream);
result
}

pub fn scalar_rotate_right<T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
shift: T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
pub fn scalar_rotate_right<Scalar, T>(&self, ct: &T, shift: Scalar, stream: &CudaStream) -> T
where
T: CastFrom<u32>,
u32: CastFrom<T>,
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let mut result = unsafe { ct.duplicate_async(stream) };
self.scalar_rotate_right_assign(&mut result, shift, stream);
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub(crate) mod test_rotate;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_bitwise_op;
pub(crate) mod test_scalar_mul;
pub(crate) mod test_scalar_rotate;
pub(crate) mod test_scalar_shift;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_shift;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_rotate::{
signed_default_scalar_rotate_left_test, signed_default_scalar_rotate_right_test,
signed_unchecked_scalar_rotate_left_test, signed_unchecked_scalar_rotate_right_test,
};
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_signed_unchecked_scalar_rotate_left);
create_gpu_parametrized_test!(integer_signed_scalar_rotate_left);
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_rotate_right);
create_gpu_parametrized_test!(integer_signed_scalar_rotate_right);

fn integer_signed_unchecked_scalar_rotate_left<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_left);
signed_unchecked_scalar_rotate_left_test(param, executor);
}

fn integer_signed_scalar_rotate_left<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
signed_default_scalar_rotate_left_test(param, executor);
}

fn integer_signed_unchecked_scalar_rotate_right<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_right);
signed_unchecked_scalar_rotate_right_test(param, executor);
}

fn integer_signed_scalar_rotate_right<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
signed_default_scalar_rotate_right_test(param, executor);
}
19 changes: 1 addition & 18 deletions tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_bitwise_op;
pub(crate) mod test_scalar_comparison;
pub(crate) mod test_scalar_mul;
pub(crate) mod test_scalar_rotate;
pub(crate) mod test_scalar_shift;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_shift;
Expand Down Expand Up @@ -85,8 +86,6 @@ impl<F> GpuFunctionExecutor<F> {

// Unchecked operations
create_gpu_parametrized_test!(integer_unchecked_if_then_else);
create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_left);
create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_right);

// Default operations
create_gpu_parametrized_test!(integer_if_then_else);
Expand Down Expand Up @@ -477,22 +476,6 @@ where
}
}

fn integer_unchecked_scalar_rotate_left<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_left);
unchecked_scalar_rotate_left_test(param, executor);
}

fn integer_unchecked_scalar_rotate_right<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_right);
unchecked_scalar_rotate_right_test(param, executor);
}

fn integer_if_then_else<P>(param: P)
where
P: Into<PBSParameters> + Copy,
Expand Down
Loading

0 comments on commit 971b0cf

Please sign in to comment.