Skip to content

Commit

Permalink
feat(gpu): signed scalar sub
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Mar 12, 2024
1 parent d380144 commit 6f954bb
Show file tree
Hide file tree
Showing 11 changed files with 357 additions and 294 deletions.
2 changes: 1 addition & 1 deletion tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ generic_integer_impl_scalar_operation!(
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_stream(|stream| {
cuda_key.key.scalar_sub(
&lhs.ciphertext.on_gpu(), rhs, stream
&*lhs.ciphertext.on_gpu(), rhs, stream
)
});
RadixCiphertext::Cuda(inner_result)
Expand Down
62 changes: 27 additions & 35 deletions tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::core_crypto::gpu::CudaStream;
use crate::core_crypto::prelude::UnsignedNumeric;
use crate::core_crypto::prelude::Numeric;
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaServerKey;
use crate::integer::server_key::TwosComplementNegation;

Expand Down Expand Up @@ -43,14 +43,10 @@ impl CudaServerKey {
/// let dec: u64 = cks.decrypt(&ct_res);
/// assert_eq!(msg - scalar, dec);
/// ```
pub fn unchecked_scalar_sub<T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
pub fn unchecked_scalar_sub<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
where
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
T: CudaIntegerRadixCiphertext,
{
let mut result = unsafe { ct.duplicate_async(stream) };
self.unchecked_scalar_sub_assign(&mut result, scalar, stream);
Expand All @@ -61,26 +57,28 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_sub_assign_async<T>(
pub unsafe fn unchecked_scalar_sub_assign_async<Scalar, T>(
&self,
ct: &mut CudaUnsignedRadixCiphertext,
scalar: T,
ct: &mut T,
scalar: Scalar,
stream: &CudaStream,
) where
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
T: CudaIntegerRadixCiphertext,
{
let negated_scalar = scalar.twos_complement_negation();
self.unchecked_scalar_add_assign_async(ct, negated_scalar, stream);
ct.as_mut().info = ct.as_ref().info.after_scalar_sub(scalar);
}

pub fn unchecked_scalar_sub_assign<T>(
pub fn unchecked_scalar_sub_assign<Scalar, T>(
&self,
ct: &mut CudaUnsignedRadixCiphertext,
scalar: T,
ct: &mut T,
scalar: Scalar,
stream: &CudaStream,
) where
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_sub_assign_async(ct, scalar, stream);
Expand Down Expand Up @@ -125,14 +123,10 @@ impl CudaServerKey {
/// let dec: u64 = cks.decrypt(&ct_res);
/// assert_eq!(msg - scalar, dec);
/// ```
pub fn scalar_sub<T>(
&self,
ct: &CudaUnsignedRadixCiphertext,
scalar: T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext
pub fn scalar_sub<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
where
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
T: CudaIntegerRadixCiphertext,
{
let mut result = unsafe { ct.duplicate_async(stream) };
self.scalar_sub_assign(&mut result, scalar, stream);
Expand All @@ -143,13 +137,14 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn scalar_sub_assign_async<T>(
pub unsafe fn scalar_sub_assign_async<Scalar, T>(
&self,
ct: &mut CudaUnsignedRadixCiphertext,
scalar: T,
ct: &mut T,
scalar: Scalar,
stream: &CudaStream,
) where
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign_async(ct, stream);
Expand All @@ -159,13 +154,10 @@ impl CudaServerKey {
self.full_propagate_assign_async(ct, stream);
}

pub fn scalar_sub_assign<T>(
&self,
ct: &mut CudaUnsignedRadixCiphertext,
scalar: T,
stream: &CudaStream,
) where
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, stream: &CudaStream)
where
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_sub_assign_async(ct, scalar, stream);
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub(crate) mod test_add;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_sub;

use crate::core_crypto::gpu::CudaStream;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_cases_signed::signed_unchecked_scalar_sub_test;
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_signed_unchecked_scalar_sub);

fn integer_signed_unchecked_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
signed_unchecked_scalar_sub_test(param, executor);
}
19 changes: 1 addition & 18 deletions tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub(crate) mod test_add;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_sub;

use crate::core_crypto::gpu::{CudaDevice, CudaStream};
Expand Down Expand Up @@ -75,7 +76,6 @@ impl<F> GpuFunctionExecutor<F> {
}

// Unchecked operations
create_gpu_parametrized_test!(integer_unchecked_scalar_sub);
create_gpu_parametrized_test!(integer_unchecked_small_scalar_mul);
create_gpu_parametrized_test!(integer_unchecked_bitnot);
create_gpu_parametrized_test!(integer_unchecked_bitand);
Expand Down Expand Up @@ -107,7 +107,6 @@ create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_left);
create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_right);

// Default operations
create_gpu_parametrized_test!(integer_scalar_sub);
create_gpu_parametrized_test!(integer_small_scalar_mul);
create_gpu_parametrized_test!(integer_scalar_right_shift);
create_gpu_parametrized_test!(integer_scalar_left_shift);
Expand Down Expand Up @@ -318,14 +317,6 @@ where
unchecked_small_scalar_mul_test(param, executor);
}

fn integer_unchecked_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
unchecked_scalar_sub_test(param, executor);
}

fn integer_unchecked_bitnot<P>(param: P)
where
P: Into<PBSParameters>,
Expand Down Expand Up @@ -1496,14 +1487,6 @@ where
default_small_scalar_mul_test(param, executor);
}

fn integer_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_sub);
default_scalar_sub_test(param, executor);
}

fn integer_bitnot<P>(param: P)
where
P: Into<PBSParameters> + Copy,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{
default_scalar_sub_test, unchecked_scalar_sub_test,
};
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_unchecked_scalar_sub);
create_gpu_parametrized_test!(integer_scalar_sub);

fn integer_unchecked_scalar_sub<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
unchecked_scalar_sub_test(param, executor);
}

fn integer_scalar_sub<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_sub);
default_scalar_sub_test(param, executor);
}
Loading

0 comments on commit 6f954bb

Please sign in to comment.