Skip to content

Commit

Permalink
feat(gpu): implement signed if_then_else
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves authored and agnesLeroy committed Apr 8, 2024
1 parent f9a3984 commit 86e629a
Show file tree
Hide file tree
Showing 15 changed files with 680 additions and 343 deletions.
7 changes: 3 additions & 4 deletions tfhe/benches/integer/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,7 @@ mod cuda {
use super::*;
use criterion::criterion_group;
use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::server_key::CudaServerKey;

Expand Down Expand Up @@ -1414,8 +1415,7 @@ mod cuda {

let encrypt_tree_values = || {
let clear_cond = rng.gen::<bool>();
let ct_cond =
cks.encrypt_radix(tfhe::integer::U256::from(clear_cond), num_block);
let ct_cond = cks.encrypt_bool(clear_cond);

let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
Expand All @@ -1427,8 +1427,7 @@ mod cuda {
let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh));
let ct_else = cks.encrypt_radix(clear_1, num_block);

let d_ct_cond =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_cond, &stream);
let d_ct_cond = CudaBooleanBlock::from_boolean_block(&ct_cond, &stream);
let d_ct_then =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_then, &stream);
let d_ct_else =
Expand Down
62 changes: 62 additions & 0 deletions tfhe/benches/integer/signed_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,7 @@ mod cuda {
use super::*;
use criterion::criterion_group;
use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
use tfhe::integer::gpu::server_key::CudaServerKey;

Expand Down Expand Up @@ -1555,6 +1556,66 @@ mod cuda {
}
);

fn cuda_if_then_else(c: &mut Criterion) {
let mut bench_group = c.benchmark_group("integer::cuda::signed::if_then_else");
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60));
let mut rng = rand::thread_rng();

let gpu_index = 0;
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() {
if bit_size > ScalarType::BITS as usize {
break;
}

let param_name = param.name();

let bench_id = format!("if_then_else:{param_name}::{bit_size}_bits_scalar_{bit_size}");
bench_group.bench_function(&bench_id, |b| {
let (cks, _cpu_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let gpu_sks = CudaServerKey::new(&cks, &stream);

let encrypt_tree_values = || {
let clear_cond = rng.gen::<bool>();
let ct_then = cks.encrypt_signed_radix(gen_random_i256(&mut rng), num_block);
let ct_else = cks.encrypt_signed_radix(gen_random_i256(&mut rng), num_block);
let ct_cond = cks.encrypt_bool(clear_cond);

let d_ct_cond = CudaBooleanBlock::from_boolean_block(&ct_cond, &stream);
let d_ct_then =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct_then, &stream);
let d_ct_else =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct_else, &stream);

(d_ct_cond, d_ct_then, d_ct_else)
};

b.iter_batched(
encrypt_tree_values,
|(ct_cond, ct_then, ct_else)| {
let _ = gpu_sks.if_then_else(&ct_cond, &ct_then, &ct_else, &stream);
},
criterion::BatchSize::SmallInput,
)
});

write_to_json::<u64, _>(
&bench_id,
param,
param.name(),
"if_then_else",
&OperatorType::Atomic,
bit_size as u32,
vec![param.message_modulus().0.ilog2(); num_block],
);
}

bench_group.finish()
}
// Functions used to apply different way of selecting a scalar based on the context.
fn default_signed_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
let clearlow = rng.gen::<u128>();
Expand Down Expand Up @@ -1958,6 +2019,7 @@ mod cuda {
cuda_le,
cuda_min,
cuda_max,
cuda_if_then_else,
);

criterion_group!(
Expand Down
8 changes: 5 additions & 3 deletions tfhe/src/high_level_api/booleans/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::traits::{FheEq, IfThenElse};
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::BooleanBlock;
use crate::named::Named;
use crate::shortint::ciphertext::NotTrivialCiphertextError;
Expand Down Expand Up @@ -178,9 +180,9 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner = cuda_key.key.if_then_else(
&self.ciphertext.on_gpu(),
&ct_then.ciphertext.on_gpu(),
&ct_else.ciphertext.on_gpu(),
&CudaBooleanBlock(self.ciphertext.on_gpu().duplicate(stream)),
&*ct_then.ciphertext.on_gpu(),
&*ct_else.ciphertext.on_gpu(),
stream,
);

Expand Down
48 changes: 20 additions & 28 deletions tfhe/src/integer/gpu/server_key/radix/cmux.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::core_crypto::gpu::CudaStream;
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::CudaServerKey;

Expand All @@ -8,22 +9,22 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_if_then_else_async(
pub unsafe fn unchecked_if_then_else_async<T: CudaIntegerRadixCiphertext>(
&self,
condition: &CudaUnsignedRadixCiphertext,
true_ct: &CudaUnsignedRadixCiphertext,
false_ct: &CudaUnsignedRadixCiphertext,
condition: &CudaBooleanBlock,
true_ct: &T,
false_ct: &T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext {
) -> T {
let lwe_ciphertext_count = true_ct.as_ref().d_blocks.lwe_ciphertext_count();
let mut result = self
let mut result: T = self
.create_trivial_zero_radix(true_ct.as_ref().d_blocks.lwe_ciphertext_count().0, stream);

match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
stream.unchecked_cmux_integer_radix_classic_kb_async(
&mut result.as_mut().d_blocks.0.d_vec,
&condition.as_ref().d_blocks.0.d_vec,
&condition.as_ref().ciphertext.d_blocks.0.d_vec,
&true_ct.as_ref().d_blocks.0.d_vec,
&false_ct.as_ref().d_blocks.0.d_vec,
&d_bsk.d_vec,
Expand All @@ -48,7 +49,7 @@ impl CudaServerKey {
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
stream.unchecked_cmux_integer_radix_multibit_kb_async(
&mut result.as_mut().d_blocks.0.d_vec,
&condition.as_ref().d_blocks.0.d_vec,
&condition.as_ref().ciphertext.d_blocks.0.d_vec,
&true_ct.as_ref().d_blocks.0.d_vec,
&false_ct.as_ref().d_blocks.0.d_vec,
&d_multibit_bsk.d_vec,
Expand All @@ -75,39 +76,30 @@ impl CudaServerKey {

result
}
pub fn unchecked_if_then_else(
pub fn unchecked_if_then_else<T: CudaIntegerRadixCiphertext>(
&self,
condition: &CudaUnsignedRadixCiphertext,
true_ct: &CudaUnsignedRadixCiphertext,
false_ct: &CudaUnsignedRadixCiphertext,
condition: &CudaBooleanBlock,
true_ct: &T,
false_ct: &T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext {
) -> T {
let result =
unsafe { self.unchecked_if_then_else_async(condition, true_ct, false_ct, stream) };
stream.synchronize();
result
}

pub fn if_then_else(
pub fn if_then_else<T: CudaIntegerRadixCiphertext>(
&self,
condition: &CudaUnsignedRadixCiphertext,
true_ct: &CudaUnsignedRadixCiphertext,
false_ct: &CudaUnsignedRadixCiphertext,
condition: &CudaBooleanBlock,
true_ct: &T,
false_ct: &T,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext {
let mut tmp_condition;
) -> T {
let mut tmp_true_ct;
let mut tmp_false_ct;

let result = unsafe {
let condition = if condition.block_carries_are_empty() {
condition
} else {
tmp_condition = condition.duplicate_async(stream);
self.full_propagate_assign_async(&mut tmp_condition, stream);
&tmp_condition
};

let true_ct = if true_ct.block_carries_are_empty() {
true_ct
} else {
Expand Down
10 changes: 6 additions & 4 deletions tfhe/src/integer/gpu/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl CudaServerKey {
///
/// ```rust
/// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::integer::{gen_keys_radix, RadixCiphertext};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
Expand All @@ -57,19 +58,20 @@ impl CudaServerKey {
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
///
/// let d_ctxt = sks.create_trivial_zero_radix(num_blocks, &mut stream);
/// let d_ctxt: CudaUnsignedRadixCiphertext =
/// sks.create_trivial_zero_radix(num_blocks, &mut stream);
/// let ctxt = d_ctxt.to_radix_ciphertext(&mut stream);
///
/// // Decrypt:
/// let dec: u64 = cks.decrypt(&ctxt);
/// assert_eq!(0, dec);
/// ```
pub fn create_trivial_zero_radix(
pub fn create_trivial_zero_radix<T: CudaIntegerRadixCiphertext>(
&self,
num_blocks: usize,
stream: &CudaStream,
) -> CudaUnsignedRadixCiphertext {
self.create_trivial_radix(0, num_blocks, stream)
) -> T {
T::from(self.create_trivial_radix(0, num_blocks, stream).ciphertext)
}

/// Create a trivial ciphertext on the GPU
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/gpu/server_key/radix/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ impl CudaServerKey {
stream: &CudaStream,
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
let num_blocks = lhs.as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
let mut tmp = self.create_trivial_zero_radix(1, stream);
let mut tmp: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, stream);
if lhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
&& rhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
{
Expand Down
55 changes: 55 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub(crate) mod test_add;
pub(crate) mod test_bitwise_op;
pub(crate) mod test_cmux;
pub(crate) mod test_comparison;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
Expand Down Expand Up @@ -293,3 +294,57 @@ where
d_res.to_signed_radix_ciphertext(&context.stream)
}
}

impl<'a, F>
FunctionExecutor<
(
&'a BooleanBlock,
&'a SignedRadixCiphertext,
&'a SignedRadixCiphertext,
),
SignedRadixCiphertext,
> for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaBooleanBlock,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStream,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (
&'a BooleanBlock,
&'a SignedRadixCiphertext,
&'a SignedRadixCiphertext,
),
) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: CudaBooleanBlock =
CudaBooleanBlock::from_boolean_block(input.0, &context.stream);
let d_ctxt_2: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.stream);
let d_ctxt_3: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.2, &context.stream);

let d_res = (self.func)(
&context.sks,
&d_ctxt_1,
&d_ctxt_2,
&d_ctxt_3,
&context.stream,
);

d_res.to_signed_radix_ciphertext(&context.stream)
}
}
27 changes: 27 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/tests_signed/test_cmux.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_cmux::{
signed_default_if_then_else_test, signed_unchecked_if_then_else_test,
};
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_unchecked_if_then_else);
create_gpu_parametrized_test!(integer_if_then_else);

fn integer_unchecked_if_then_else<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_if_then_else);
signed_unchecked_if_then_else_test(param, executor);
}

fn integer_if_then_else<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::if_then_else);
signed_default_if_then_else_test(param, executor);
}
Loading

0 comments on commit 86e629a

Please sign in to comment.