From 799c222d7ce7c8f149a27975dbee8b16fa89c12f Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Tue, 10 Dec 2024 10:55:31 +0100 Subject: [PATCH] fix(gpu) update degree and noise after more ops --- tfhe/src/integer/gpu/ciphertext/info.rs | 86 +++++++++++++++++-- tfhe/src/integer/gpu/server_key/radix/cmux.rs | 1 + .../gpu/server_key/radix/comparison.rs | 3 +- .../integer/gpu/server_key/radix/rotate.rs | 2 + .../gpu/server_key/radix/scalar_add.rs | 4 +- .../gpu/server_key/radix/scalar_comparison.rs | 2 +- .../gpu/server_key/radix/scalar_rotate.rs | 2 + .../gpu/server_key/radix/scalar_shift.rs | 2 + .../gpu/server_key/radix/scalar_sub.rs | 3 +- .../src/integer/gpu/server_key/radix/shift.rs | 2 + .../tests_long_run/test_random_op_sequence.rs | 26 +++--- 11 files changed, 111 insertions(+), 22 deletions(-) diff --git a/tfhe/src/integer/gpu/ciphertext/info.rs b/tfhe/src/integer/gpu/ciphertext/info.rs index 92e90d8030..586d05467e 100644 --- a/tfhe/src/integer/gpu/ciphertext/info.rs +++ b/tfhe/src/integer/gpu/ciphertext/info.rs @@ -134,7 +134,7 @@ impl CudaRadixCiphertextInfo { .blocks .iter() .map(|info| CudaBlockInfo { - degree: info.degree, + degree: Degree::new(info.message_modulus.0 - 1), message_modulus: info.message_modulus, carry_modulus: info.carry_modulus, pbs_order: info.pbs_order, @@ -167,11 +167,11 @@ impl CudaRadixCiphertextInfo { .iter() .zip(&other.blocks) .map(|(left, _)| CudaBlockInfo { - degree: left.degree, + degree: Degree::new(left.message_modulus.0 - 1), message_modulus: left.message_modulus, carry_modulus: left.carry_modulus, pbs_order: left.pbs_order, - noise_level: left.noise_level, + noise_level: NoiseLevel::NOMINAL, }) .collect(), } @@ -183,11 +183,87 @@ impl CudaRadixCiphertextInfo { .iter() .zip(&other.blocks) .map(|(left, _)| CudaBlockInfo { - degree: left.degree, + degree: Degree::new(left.message_modulus.0 - 1), message_modulus: left.message_modulus, carry_modulus: left.carry_modulus, pbs_order: left.pbs_order, - noise_level: left.noise_level, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_if_then_else(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|b| CudaBlockInfo { + degree: Degree::new(b.message_modulus.0 - 1), + message_modulus: b.message_modulus, + carry_modulus: b.carry_modulus, + pbs_order: b.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_overflowing_scalar_add_sub(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|b| CudaBlockInfo { + degree: Degree::new(b.message_modulus.0 - 1), + message_modulus: b.message_modulus, + carry_modulus: b.carry_modulus, + pbs_order: b.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_rotate(&self, other: &Self) -> Self { + Self { + blocks: self + .blocks + .iter() + .zip(&other.blocks) + .map(|(left, _)| CudaBlockInfo { + degree: Degree::new(left.message_modulus.0 - 1), + message_modulus: left.message_modulus, + carry_modulus: left.carry_modulus, + pbs_order: left.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_scalar_rotate(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|left| CudaBlockInfo { + degree: Degree::new(left.message_modulus.0 - 1), + message_modulus: left.message_modulus, + carry_modulus: left.carry_modulus, + pbs_order: left.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_min_max(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|left| CudaBlockInfo { + degree: Degree::new(left.message_modulus.0 - 1), + message_modulus: left.message_modulus, + carry_modulus: left.carry_modulus, + pbs_order: left.pbs_order, + noise_level: NoiseLevel::NOMINAL, }) .collect(), } diff --git a/tfhe/src/integer/gpu/server_key/radix/cmux.rs b/tfhe/src/integer/gpu/server_key/radix/cmux.rs index f02a1e1996..864a7914cd 100644 --- a/tfhe/src/integer/gpu/server_key/radix/cmux.rs +++ b/tfhe/src/integer/gpu/server_key/radix/cmux.rs @@ -79,6 +79,7 @@ impl CudaServerKey { ); } } + result.as_mut().info = true_ct.as_ref().info.after_if_then_else(); result } diff --git a/tfhe/src/integer/gpu/server_key/radix/comparison.rs b/tfhe/src/integer/gpu/server_key/radix/comparison.rs index a06a8dbd88..26d728747d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/comparison.rs @@ -1030,6 +1030,7 @@ impl CudaServerKey { ); } } + result.as_mut().info = ct_left.as_ref().info.after_min_max(); result } @@ -1129,7 +1130,7 @@ impl CudaServerKey { ); } } - + result.as_mut().info = ct_left.as_ref().info.after_min_max(); result } diff --git a/tfhe/src/integer/gpu/server_key/radix/rotate.rs b/tfhe/src/integer/gpu/server_key/radix/rotate.rs index 4f73f14060..b7f6348454 100644 --- a/tfhe/src/integer/gpu/server_key/radix/rotate.rs +++ b/tfhe/src/integer/gpu/server_key/radix/rotate.rs @@ -79,6 +79,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_rotate(&rotate.as_ref().info); } /// # Safety @@ -198,6 +199,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_rotate(&rotate.as_ref().info); } /// # Safety diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs index 40ad82add2..597d0e2e55 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs @@ -271,6 +271,7 @@ impl CudaServerKey { self.propagate_single_carry_assign_async(ct_left, stream, None, OutputFlag::Carry); } stream.synchronize(); + ct_left.as_mut().info = ct_left.as_ref().info.after_overflowing_scalar_add_sub(); let num_scalar_blocks = BlockDecomposer::with_early_stop_at_zero(scalar, self.message_modulus.0.ilog2()) @@ -345,7 +346,8 @@ impl CudaServerKey { ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0, streams, ); - let (result, overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams); + let (mut result, overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams); + result.as_mut().info = tmp_lhs.as_ref().info.after_overflowing_scalar_add_sub(); let mut extra_scalar_block_iter = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2()) diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs index 93652f4b32..b98726d185 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs @@ -395,7 +395,7 @@ impl CudaServerKey { ); } } - + result.as_mut().info = ct.as_ref().info.after_min_max(); result } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs index 6449123f43..240e2144f4 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs @@ -97,6 +97,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_scalar_rotate(); } pub fn unchecked_scalar_rotate_left( @@ -204,6 +205,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_scalar_rotate(); } pub fn unchecked_scalar_rotate_right( diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs index f651dad71f..a50a04e063 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs @@ -99,6 +99,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_scalar_rotate(); } /// Computes homomorphically a left shift by a scalar. @@ -300,6 +301,7 @@ impl CudaServerKey { } } } + ct.as_mut().info = ct.as_ref().info.after_scalar_rotate(); } /// Computes homomorphically a right shift by a scalar. diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs index ae0fb950e6..cefdf2dfde 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs @@ -230,7 +230,8 @@ impl CudaServerKey { ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0, streams, ); - let (result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams); + let (mut result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams); + result.as_mut().info = tmp_lhs.as_ref().info.after_overflowing_scalar_add_sub(); let mut extra_scalar_block_iter = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2()) diff --git a/tfhe/src/integer/gpu/server_key/radix/shift.rs b/tfhe/src/integer/gpu/server_key/radix/shift.rs index e775aa914c..1a7f58f69d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/shift.rs +++ b/tfhe/src/integer/gpu/server_key/radix/shift.rs @@ -79,6 +79,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_rotate(&shift.as_ref().info); } /// # Safety @@ -196,6 +197,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_rotate(&shift.as_ref().info); } /// # Safety diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs index 903776f5c5..a4970ff6af 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs @@ -580,7 +580,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index}.", + "Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.", ); } else if i < binary_ops.len() + unary_ops.len() { let index = i - binary_ops.len(); @@ -629,7 +629,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on unary op {fn_name} with clear input {clear_input} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on unary op {fn_name} with clear input {clear_input} at iteration {fn_index}.", ); } else if i < binary_ops.len() + unary_ops.len() + scalar_binary_ops.len() { let index = i - binary_ops.len() - unary_ops.len(); @@ -671,7 +671,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.", ); } else if i < binary_ops.len() + unary_ops.len() @@ -727,7 +727,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index}.", + "Invalid result on op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.", ); assert_eq!( decrypted_overflow, expected_overflow, @@ -794,7 +794,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.", ); assert_eq!( decrypted_overflow, expected_overflow, @@ -836,7 +836,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.", ); let res_ct: RadixCiphertext = res.into_radix(1, &sks); @@ -886,7 +886,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.", ); let res_ct: RadixCiphertext = res.into_radix(1, &sks); if i % 2 == 0 { @@ -945,7 +945,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} and {clear_bool} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} and {clear_bool} at iteration {fn_index}.", ); if i % 2 == 0 { left_vec[j] = res.clone(); @@ -1020,11 +1020,11 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res_q, expected_res_q, - "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration {fn_index}.", ); assert_eq!( decrypted_res_r, expected_res_r, - "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration {fn_index}.", ); if i % 2 == 0 { left_vec[j] = res_q.clone(); @@ -1103,11 +1103,11 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res_q, expected_res_q, - "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration {fn_index}.", ); assert_eq!( decrypted_res_r, expected_res_r, - "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration {fn_index}.", ); if i % 2 == 0 { left_vec[j] = res_r.clone(); @@ -1181,7 +1181,7 @@ pub(crate) fn random_op_sequence_test

( // Correctness check assert_eq!( decrypted_res, expected_res, - "Invalid result on op {fn_name} with clear input {clear_input} at iteration{fn_index} at iteration{fn_index}.", + "Invalid result on op {fn_name} with clear input {clear_input} at iteration {fn_index}.", ); if i % 2 == 0 { left_vec[j] = cast_res.clone();