From 33bec2aaa4d342bedd8ecbfaf0ec79ed724f2a84 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Mon, 16 Dec 2024 15:14:52 +0100 Subject: [PATCH] fix(gpu): fix ct degree and noise level after some ops --- tfhe/src/integer/gpu/ciphertext/info.rs | 119 +++++++++++++++++- .../gpu/server_key/radix/bitwise_op.rs | 1 + tfhe/src/integer/gpu/server_key/radix/cmux.rs | 1 + .../gpu/server_key/radix/comparison.rs | 3 +- .../src/integer/gpu/server_key/radix/ilog2.rs | 4 +- .../integer/gpu/server_key/radix/rotate.rs | 2 + .../gpu/server_key/radix/scalar_add.rs | 4 +- .../gpu/server_key/radix/scalar_comparison.rs | 2 +- .../gpu/server_key/radix/scalar_rotate.rs | 2 + .../gpu/server_key/radix/scalar_shift.rs | 2 + .../gpu/server_key/radix/scalar_sub.rs | 3 +- .../src/integer/gpu/server_key/radix/shift.rs | 2 + 12 files changed, 134 insertions(+), 11 deletions(-) diff --git a/tfhe/src/integer/gpu/ciphertext/info.rs b/tfhe/src/integer/gpu/ciphertext/info.rs index 132f55c877..260dfcffd5 100644 --- a/tfhe/src/integer/gpu/ciphertext/info.rs +++ b/tfhe/src/integer/gpu/ciphertext/info.rs @@ -122,7 +122,23 @@ impl CudaRadixCiphertextInfo { message_modulus: left.message_modulus, carry_modulus: left.carry_modulus, pbs_order: left.pbs_order, - noise_level: left.noise_level + NoiseLevel::NOMINAL, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + + pub(crate) fn after_ilog2(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|info| CudaBlockInfo { + degree: Degree::new(info.message_modulus.0 - 1), + message_modulus: info.message_modulus, + carry_modulus: info.carry_modulus, + pbs_order: info.pbs_order, + noise_level: NoiseLevel::NOMINAL, }) .collect(), } @@ -151,11 +167,11 @@ impl CudaRadixCiphertextInfo { .iter() .zip(&other.blocks) .map(|(left, _)| CudaBlockInfo { - degree: left.degree, + degree: Degree::new(left.message_modulus.0 - 1), message_modulus: left.message_modulus, carry_modulus: left.carry_modulus, pbs_order: left.pbs_order, - noise_level: left.noise_level, + noise_level: NoiseLevel::NOMINAL, }) .collect(), } @@ -167,11 +183,87 @@ impl CudaRadixCiphertextInfo { .iter() .zip(&other.blocks) .map(|(left, _)| CudaBlockInfo { - degree: left.degree, + degree: Degree::new(left.message_modulus.0 - 1), message_modulus: left.message_modulus, carry_modulus: left.carry_modulus, pbs_order: left.pbs_order, - noise_level: left.noise_level, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_if_then_else(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|b| CudaBlockInfo { + degree: Degree::new(b.message_modulus.0 - 1), + message_modulus: b.message_modulus, + carry_modulus: b.carry_modulus, + pbs_order: b.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_overflowing_scalar_add_sub(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|b| CudaBlockInfo { + degree: Degree::new(b.message_modulus.0 - 1), + message_modulus: b.message_modulus, + carry_modulus: b.carry_modulus, + pbs_order: b.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_rotate(&self, other: &Self) -> Self { + Self { + blocks: self + .blocks + .iter() + .zip(&other.blocks) + .map(|(left, _)| CudaBlockInfo { + degree: Degree::new(left.message_modulus.0 - 1), + message_modulus: left.message_modulus, + carry_modulus: left.carry_modulus, + pbs_order: left.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_scalar_rotate(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|left| CudaBlockInfo { + degree: Degree::new(left.message_modulus.0 - 1), + message_modulus: left.message_modulus, + carry_modulus: left.carry_modulus, + pbs_order: left.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_min_max(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|left| CudaBlockInfo { + degree: Degree::new(left.message_modulus.0 - 1), + message_modulus: left.message_modulus, + carry_modulus: left.carry_modulus, + pbs_order: left.pbs_order, + noise_level: NoiseLevel::NOMINAL, }) .collect(), } @@ -229,7 +321,7 @@ impl CudaRadixCiphertextInfo { message_modulus: info.message_modulus, carry_modulus: info.carry_modulus, pbs_order: info.pbs_order, - noise_level: info.noise_level + NoiseLevel::NOMINAL, + noise_level: NoiseLevel::NOMINAL, }) .collect(), } @@ -310,6 +402,21 @@ impl CudaRadixCiphertextInfo { .collect(), } } + pub(crate) fn after_bitnot(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|b| CudaBlockInfo { + degree: Degree::new(b.message_modulus.0 - 1), + message_modulus: b.message_modulus, + carry_modulus: b.carry_modulus, + pbs_order: b.pbs_order, + noise_level: b.noise_level, + }) + .collect(), + } + } pub(crate) fn after_scalar_bitand(&self, scalar: T) -> Self where T: DecomposableInto, diff --git a/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs b/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs index 5dfc2b3238..0e77fda595 100644 --- a/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs +++ b/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs @@ -91,6 +91,7 @@ impl CudaServerKey { &d_decomposed_scalar, streams, ); + ct.as_mut().info = ct.as_ref().info.after_bitnot(); } pub fn unchecked_bitnot_assign( diff --git a/tfhe/src/integer/gpu/server_key/radix/cmux.rs b/tfhe/src/integer/gpu/server_key/radix/cmux.rs index f02a1e1996..864a7914cd 100644 --- a/tfhe/src/integer/gpu/server_key/radix/cmux.rs +++ b/tfhe/src/integer/gpu/server_key/radix/cmux.rs @@ -79,6 +79,7 @@ impl CudaServerKey { ); } } + result.as_mut().info = true_ct.as_ref().info.after_if_then_else(); result } diff --git a/tfhe/src/integer/gpu/server_key/radix/comparison.rs b/tfhe/src/integer/gpu/server_key/radix/comparison.rs index a06a8dbd88..26d728747d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/comparison.rs @@ -1030,6 +1030,7 @@ impl CudaServerKey { ); } } + result.as_mut().info = ct_left.as_ref().info.after_min_max(); result } @@ -1129,7 +1130,7 @@ impl CudaServerKey { ); } } - + result.as_mut().info = ct_left.as_ref().info.after_min_max(); result } diff --git a/tfhe/src/integer/gpu/server_key/radix/ilog2.rs b/tfhe/src/integer/gpu/server_key/radix/ilog2.rs index 9c0fb78920..e9265612b0 100644 --- a/tfhe/src/integer/gpu/server_key/radix/ilog2.rs +++ b/tfhe/src/integer/gpu/server_key/radix/ilog2.rs @@ -786,7 +786,9 @@ impl CudaServerKey { let result = self.sum_ciphertexts_async(ciphertexts, streams).unwrap(); - self.cast_to_unsigned_async(result, counter_num_blocks, streams) + let mut result_cast = self.cast_to_unsigned_async(result, counter_num_blocks, streams); + result_cast.as_mut().info = ct.as_ref().info.after_ilog2(); + result_cast } /// Returns the number of trailing zeros in the binary representation of `ct` diff --git a/tfhe/src/integer/gpu/server_key/radix/rotate.rs b/tfhe/src/integer/gpu/server_key/radix/rotate.rs index 4f73f14060..b7f6348454 100644 --- a/tfhe/src/integer/gpu/server_key/radix/rotate.rs +++ b/tfhe/src/integer/gpu/server_key/radix/rotate.rs @@ -79,6 +79,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_rotate(&rotate.as_ref().info); } /// # Safety @@ -198,6 +199,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_rotate(&rotate.as_ref().info); } /// # Safety diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs index 40ad82add2..597d0e2e55 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs @@ -271,6 +271,7 @@ impl CudaServerKey { self.propagate_single_carry_assign_async(ct_left, stream, None, OutputFlag::Carry); } stream.synchronize(); + ct_left.as_mut().info = ct_left.as_ref().info.after_overflowing_scalar_add_sub(); let num_scalar_blocks = BlockDecomposer::with_early_stop_at_zero(scalar, self.message_modulus.0.ilog2()) @@ -345,7 +346,8 @@ impl CudaServerKey { ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0, streams, ); - let (result, overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams); + let (mut result, overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams); + result.as_mut().info = tmp_lhs.as_ref().info.after_overflowing_scalar_add_sub(); let mut extra_scalar_block_iter = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2()) diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs index b9b44b4353..ff6049cbad 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs @@ -395,7 +395,7 @@ impl CudaServerKey { ); } } - + result.as_mut().info = ct.as_ref().info.after_min_max(); result } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs index 6449123f43..240e2144f4 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs @@ -97,6 +97,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_scalar_rotate(); } pub fn unchecked_scalar_rotate_left( @@ -204,6 +205,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_scalar_rotate(); } pub fn unchecked_scalar_rotate_right( diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs index f651dad71f..a50a04e063 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs @@ -99,6 +99,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_scalar_rotate(); } /// Computes homomorphically a left shift by a scalar. @@ -300,6 +301,7 @@ impl CudaServerKey { } } } + ct.as_mut().info = ct.as_ref().info.after_scalar_rotate(); } /// Computes homomorphically a right shift by a scalar. diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs index ae0fb950e6..cefdf2dfde 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs @@ -230,7 +230,8 @@ impl CudaServerKey { ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0, streams, ); - let (result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams); + let (mut result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams); + result.as_mut().info = tmp_lhs.as_ref().info.after_overflowing_scalar_add_sub(); let mut extra_scalar_block_iter = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2()) diff --git a/tfhe/src/integer/gpu/server_key/radix/shift.rs b/tfhe/src/integer/gpu/server_key/radix/shift.rs index e775aa914c..1a7f58f69d 100644 --- a/tfhe/src/integer/gpu/server_key/radix/shift.rs +++ b/tfhe/src/integer/gpu/server_key/radix/shift.rs @@ -79,6 +79,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_rotate(&shift.as_ref().info); } /// # Safety @@ -196,6 +197,7 @@ impl CudaServerKey { ); } } + ct.as_mut().info = ct.as_ref().info.after_rotate(&shift.as_ref().info); } /// # Safety