Skip to content

Commit

Permalink
feat(gpu): signed scalar div
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Sep 11, 2024
1 parent 2173bb6 commit 9104781
Show file tree
Hide file tree
Showing 8 changed files with 605 additions and 148 deletions.
14 changes: 14 additions & 0 deletions tfhe/benches/integer/signed_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1813,6 +1813,12 @@ mod cuda {
rng_func: default_signed_scalar
);

// Benchmark entry for the `unchecked_signed_scalar_div_rem` method, reported under the
// display name `div_rem`.
// NOTE(review): the macro definition is not visible here; it presumably expands to the
// `cuda_unchecked_signed_scalar_div_rem` bench function that a criterion group further down
// lists, and `div_scalar` presumably supplies the random clear divisors — confirm against
// the macro and the rng helper.
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_signed_scalar_div_rem,
display_name: div_rem,
rng_func: div_scalar
);

//===========================================
// Default
//===========================================
Expand Down Expand Up @@ -2035,6 +2041,12 @@ mod cuda {
rng_func: default_signed_scalar
);

// Benchmark entry for the default (carry-propagating) `signed_scalar_div_rem` method,
// reported under the display name `div_rem`.
// NOTE(review): the macro definition is not visible here; it presumably expands to the
// `cuda_signed_scalar_div_rem` bench function that a criterion group further down lists,
// and `div_scalar` presumably supplies the random clear divisors — confirm against the
// macro and the rng helper.
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: signed_scalar_div_rem,
display_name: div_rem,
rng_func: div_scalar
);

criterion_group!(
unchecked_cuda_ops,
cuda_unchecked_add,
Expand Down Expand Up @@ -2081,6 +2093,7 @@ mod cuda {
cuda_unchecked_scalar_le,
cuda_unchecked_scalar_min,
cuda_unchecked_scalar_max,
cuda_unchecked_signed_scalar_div_rem,
);

criterion_group!(
Expand Down Expand Up @@ -2146,6 +2159,7 @@ mod cuda {
cuda_scalar_max,
cuda_signed_overflowing_scalar_add,
cuda_signed_overflowing_scalar_sub,
cuda_signed_scalar_div_rem,
);

fn cuda_bench_server_key_signed_cast_function<F>(
Expand Down
281 changes: 279 additions & 2 deletions tfhe/src/integer/gpu/server_key/radix/scalar_div_mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::Numeric;
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::scalar_div_mod::choose_multiplier;
use crate::integer::server_key::radix_parallel::scalar_div_mod::{
choose_multiplier, SignedReciprocable,
};
use crate::integer::server_key::{MiniUnsignedInteger, Reciprocable, ScalarMultiplier};
use crate::prelude::{CastFrom, CastInto};

Expand Down Expand Up @@ -32,6 +36,21 @@ impl CudaServerKey {
result
}

/// Multiplies a signed ciphertext by a clear scalar and keeps only the upper half of the
/// double-width product (the `MULSH` primitive used by the signed scalar division).
///
/// The input is first widened to twice its block count with `extend_radix_with_sign_msb`,
/// multiplied in place by `rhs`, and then the lowest `block_count` blocks are trimmed
/// away, so the returned ciphertext has the same number of blocks as `lhs`.
fn signed_scalar_mul_high<Scalar>(
    &self,
    lhs: &CudaSignedRadixCiphertext,
    rhs: Scalar,
    streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
    Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
    // Number of blocks of the input; also the number of blocks added on the MSB side and
    // later removed on the LSB side.
    let block_count = lhs.as_ref().d_blocks.lwe_ciphertext_count().0;

    // Double-width copy of `lhs`, so the full product fits.
    let mut widened_product = self.extend_radix_with_sign_msb(lhs, block_count, streams);
    self.scalar_mul_assign(&mut widened_product, rhs, streams);

    // Drop the low half: what remains is the high part of the product.
    self.trim_radix_blocks_lsb(&widened_product, block_count, streams)
}

/// Computes homomorphically a division between a ciphertext and a scalar.
///
/// This function computes the operation without checking if it exceeds the capacity of the
Expand Down Expand Up @@ -403,4 +422,262 @@ impl CudaServerKey {

self.unchecked_scalar_rem(numerator, divisor, streams)
}

/// Computes homomorphically the quotient of an encrypted signed integer by a clear scalar.
///
/// This is the "unchecked" variant: the numerator is used as-is, without any carry
/// propagation being performed here first.
///
/// The division is replaced by a multiplication with a precomputed constant followed by
/// shifts, following the `Issue q = ...` recipes quoted from "the paper" in the comments
/// below (presumably Granlund & Montgomery, "Division by Invariant Integers using
/// Multiplication" — TODO confirm). Four cases are distinguished:
///
/// 1. `|divisor| == 1`: the numerator is copied (or negated for `-1`);
/// 2. `|divisor|` is a power of two: shifts and one addition only;
/// 3. the chosen multiplier is below `2^(N - 1)`: `SRA(MULSH(m, n), shpost) - XSIGN(n)`;
/// 4. otherwise: `SRA(n + MULSH(m - 2^N, n), shpost) - XSIGN(n)`.
///
/// where `N` is the number of message bits encrypted in `numerator`. The result is
/// negated at the end when `divisor` is negative.
///
/// # Panics
///
/// - if `divisor` is zero;
/// - if `Scalar` has fewer bits than the number of bits encrypted in `numerator`.
pub fn unchecked_signed_scalar_div<Scalar>(
    &self,
    numerator: &CudaSignedRadixCiphertext,
    divisor: Scalar,
    streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
    Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
    <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
    assert_ne!(divisor, Scalar::ZERO, "attempt to divide by 0");

    // N: bits per block (log2 of the message modulus) times the number of blocks.
    let numerator_bits = self.message_modulus.0.ilog2()
        * numerator.ciphertext.d_blocks.lwe_ciphertext_count().0 as u32;
    assert!(
        Scalar::BITS >= numerator_bits as usize,
        "The scalar divisor type must have a number of bits that is \
        >= to the number of bits encrypted in the ciphertext"
    );

    // wrapping_abs returns Scalar::MIN when its input is Scalar::MIN (since in signed
    // numbers Scalar::MIN's absolute value cannot be represented).
    // However, casting Scalar::MIN to the unsigned type will give the correct abs value
    // if Scalar and Scalar::Unsigned have the same number of bits.
    let absolute_divisor = Scalar::Unsigned::cast_from(divisor.wrapping_abs());

    if absolute_divisor == Scalar::Unsigned::ONE {
        // Strangely, the paper says: Issue q = d;
        return if divisor < Scalar::ZERO {
            // quotient = -numerator
            self.neg(numerator, streams)
        } else {
            numerator.duplicate(streams)
        };
    }

    let chosen_multiplier =
        choose_multiplier(absolute_divisor, numerator_bits - 1, numerator_bits);

    // `l` reaching the full bit width short-circuits to an encrypted zero.
    if chosen_multiplier.l >= numerator_bits {
        return self.create_trivial_zero_radix(
            numerator.ciphertext.d_blocks.lwe_ciphertext_count().0,
            streams,
        );
    }

    let quotient = if absolute_divisor
        == (Scalar::Unsigned::ONE << chosen_multiplier.l as usize)
    {
        // |divisor| is exactly 2^l.
        // Issue q = SRA(n + SRL(SRA(n, l − 1), N − l), l);
        let l = chosen_multiplier.l;

        // SRA(n, l − 1)
        let mut tmp = self.unchecked_scalar_right_shift(numerator, l - 1, streams);

        // SRL(SRA(n, l − 1), N − l)
        // SAFETY: the streams are synchronized right after the asynchronous call, before
        // `tmp` is used again.
        unsafe {
            self.unchecked_scalar_right_shift_logical_assign_async(
                &mut tmp,
                (numerator_bits - l) as usize,
                streams,
            );
        }
        streams.synchronize();
        // n + SRL(SRA(n, l − 1), N − l)
        self.add_assign(&mut tmp, numerator, streams);
        // SRA(n + SRL(SRA(n, l − 1), N − l), l);
        self.unchecked_scalar_right_shift(&tmp, l, streams)
    } else if chosen_multiplier.multiplier
        < (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE << (numerator_bits - 1))
    {
        // NOTE(review): the note that used to sit here lost its first line(s); what
        // survives is: "... in the condition above works (it makes more values take
        // this branch, but results still seemed correct)". Which alternative bound it
        // refers to cannot be told from this file — TODO recover the full note.

        // multiplier is less than the max possible value of Scalar
        // Issue q = SRA(MULSH(m, n), shpost) − XSIGN(n);

        let (mut tmp, xsign) = rayon::join(
            move || {
                // MULSH(m, n)
                let mut tmp = self.signed_scalar_mul_high(
                    numerator,
                    chosen_multiplier.multiplier,
                    streams,
                );

                // SRA(MULSH(m, n), shpost)
                // SAFETY: the streams are synchronized right after the asynchronous
                // call, before `tmp` is handed back.
                unsafe {
                    self.unchecked_scalar_right_shift_assign_async(
                        &mut tmp,
                        chosen_multiplier.shift_post,
                        streams,
                    );
                }
                streams.synchronize();
                tmp
            },
            || {
                // XSIGN(x) is: if x < 0 { -1 } else { 0 }
                // It is equivalent to SRA(x, N − 1)
                self.unchecked_scalar_right_shift(numerator, numerator_bits - 1, streams)
            },
        );

        self.sub_assign(&mut tmp, &xsign, streams);
        tmp
    } else {
        // Issue q = SRA(n + MULSH(m − 2^N , n), shpost) − XSIGN(n);
        // Note from the paper: m - 2^N is negative

        let (mut tmp, xsign) = rayon::join(
            move || {
                // The subtraction may overflow.
                // We then cast the result to a signed type.
                // Overall, this will work fine due to two's complement representation
                // NOTE(review): this relies on `-` wrapping for the double-precision
                // type; with overflow checks enabled a primitive integer would panic
                // instead — confirm how `DoublePrecision` implements `Sub`.
                let cst = chosen_multiplier.multiplier
                    - (<Scalar::Unsigned as Reciprocable>::DoublePrecision::ONE
                        << numerator_bits);
                let cst = Scalar::DoublePrecision::cast_from(cst);

                // MULSH(m - 2^N, n)
                let mut tmp = self.signed_scalar_mul_high(numerator, cst, streams);

                // n + MULSH(m − 2^N , n)
                self.add_assign(&mut tmp, numerator, streams);

                // SRA(n + MULSH(m - 2^N, n), shpost)
                tmp = self.unchecked_scalar_right_shift(
                    &tmp,
                    chosen_multiplier.shift_post,
                    streams,
                );

                tmp
            },
            || {
                // XSIGN(x) is: if x < 0 { -1 } else { 0 }
                // It is equivalent to SRA(x, N − 1)
                self.unchecked_scalar_right_shift(numerator, numerator_bits - 1, streams)
            },
        );

        self.sub_assign(&mut tmp, &xsign, streams);
        tmp
    };

    // The recipes above divide by |divisor|; restore the sign of the divisor.
    if divisor < Scalar::ZERO {
        self.neg(&quotient, streams)
    } else {
        quotient
    }
}

/// Computes homomorphically the quotient of an encrypted signed integer by a clear scalar.
///
/// When `numerator.block_carries_are_empty()` is false, the division runs on a fully
/// propagated copy of the numerator; the caller's ciphertext is left untouched. The
/// actual division is delegated to `unchecked_signed_scalar_div`.
pub fn signed_scalar_div<Scalar>(
    &self,
    numerator: &CudaSignedRadixCiphertext,
    divisor: Scalar,
    streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
    Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
    <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
    // Nothing to clean up: divide the caller's ciphertext directly.
    if numerator.block_carries_are_empty() {
        return self.unchecked_signed_scalar_div(numerator, divisor, streams);
    }

    // NOTE(review): no explicit synchronization follows these asynchronous calls; this
    // presumably relies on the division below being enqueued on the same `streams` —
    // confirm against the contract of the `_async` functions.
    let propagated_numerator = unsafe {
        let mut copy = numerator.duplicate_async(streams);
        self.full_propagate_assign_async(&mut copy, streams);
        copy
    };

    self.unchecked_signed_scalar_div(&propagated_numerator, divisor, streams)
}

/// Computes homomorphically both the quotient and the remainder of the division of an
/// encrypted signed integer by a clear scalar.
///
/// The quotient comes from `unchecked_signed_scalar_div`; the remainder is then derived
/// as `numerator - quotient * divisor`.
///
/// Returns `(quotient, remainder)`.
pub fn unchecked_signed_scalar_div_rem<Scalar>(
    &self,
    numerator: &CudaSignedRadixCiphertext,
    divisor: Scalar,
    streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
where
    Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
    <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
    let quotient = self.unchecked_signed_scalar_div(numerator, divisor, streams);

    let remainder = {
        // quotient * divisor, the part of the numerator "explained" by the quotient.
        let quotient_times_divisor = self.unchecked_scalar_mul(&quotient, divisor, streams);
        // What is left over is the remainder.
        self.sub(numerator, &quotient_times_divisor, streams)
    };

    (quotient, remainder)
}

/// Computes homomorphically both the quotient and the remainder of the division of an
/// encrypted signed integer by a clear scalar.
///
/// When `numerator.block_carries_are_empty()` is false, the computation runs on a fully
/// propagated copy of the numerator; the caller's ciphertext is left untouched. The
/// actual work is delegated to `unchecked_signed_scalar_div_rem`.
///
/// Returns `(quotient, remainder)`.
pub fn signed_scalar_div_rem<Scalar>(
    &self,
    numerator: &CudaSignedRadixCiphertext,
    divisor: Scalar,
    streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
where
    Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
    <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
    // Nothing to clean up: operate on the caller's ciphertext directly.
    if numerator.block_carries_are_empty() {
        return self.unchecked_signed_scalar_div_rem(numerator, divisor, streams);
    }

    // NOTE(review): no explicit synchronization follows these asynchronous calls; this
    // presumably relies on the work below being enqueued on the same `streams` —
    // confirm against the contract of the `_async` functions.
    let propagated_numerator = unsafe {
        let mut copy = numerator.duplicate_async(streams);
        self.full_propagate_assign_async(&mut copy, streams);
        copy
    };

    self.unchecked_signed_scalar_div_rem(&propagated_numerator, divisor, streams)
}

/// Computes homomorphically the remainder of the division of an encrypted signed integer
/// by a clear scalar.
///
/// This runs the full `unchecked_signed_scalar_div_rem` computation and keeps only the
/// remainder; the quotient is discarded.
pub fn unchecked_signed_scalar_rem<Scalar>(
    &self,
    numerator: &CudaSignedRadixCiphertext,
    divisor: Scalar,
    streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
    Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
    <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
    // Second element of the `(quotient, remainder)` pair.
    self.unchecked_signed_scalar_div_rem(numerator, divisor, streams)
        .1
}

/// Computes homomorphically the remainder of the division of an encrypted signed integer
/// by a clear scalar.
///
/// When `numerator.block_carries_are_empty()` is false, the computation runs on a fully
/// propagated copy of the numerator; the caller's ciphertext is left untouched. The
/// actual work is delegated to `unchecked_signed_scalar_rem`.
pub fn signed_scalar_rem<Scalar>(
    &self,
    numerator: &CudaSignedRadixCiphertext,
    divisor: Scalar,
    streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
    Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
    <<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
    // Nothing to clean up: operate on the caller's ciphertext directly.
    if numerator.block_carries_are_empty() {
        return self.unchecked_signed_scalar_rem(numerator, divisor, streams);
    }

    // NOTE(review): no explicit synchronization follows these asynchronous calls; this
    // presumably relies on the work below being enqueued on the same `streams` —
    // confirm against the contract of the `_async` functions.
    let propagated_numerator = unsafe {
        let mut copy = numerator.duplicate_async(streams);
        self.full_propagate_assign_async(&mut copy, streams);
        copy
    };

    self.unchecked_signed_scalar_rem(&propagated_numerator, divisor, streams)
}
}
Loading

0 comments on commit 9104781

Please sign in to comment.