Skip to content

Commit

Permalink
fix(gpu): fix ct degree and noise level after some ops
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Dec 17, 2024
1 parent 8687b69 commit 241b737
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 11 deletions.
119 changes: 113 additions & 6 deletions tfhe/src/integer/gpu/ciphertext/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,23 @@ impl CudaRadixCiphertextInfo {
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: left.noise_level + NoiseLevel::NOMINAL,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}

pub(crate) fn after_ilog2(&self) -> Self {
Self {
blocks: self
.blocks
.iter()
.map(|info| CudaBlockInfo {
degree: Degree::new(info.message_modulus.0 - 1),
message_modulus: info.message_modulus,
carry_modulus: info.carry_modulus,
pbs_order: info.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
Expand Down Expand Up @@ -151,11 +167,11 @@ impl CudaRadixCiphertextInfo {
.iter()
.zip(&other.blocks)
.map(|(left, _)| CudaBlockInfo {
degree: left.degree,
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: left.noise_level,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
Expand All @@ -167,11 +183,87 @@ impl CudaRadixCiphertextInfo {
.iter()
.zip(&other.blocks)
.map(|(left, _)| CudaBlockInfo {
degree: left.degree,
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: left.noise_level,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}
pub(crate) fn after_if_then_else(&self) -> Self {
Self {
blocks: self
.blocks
.iter()
.map(|b| CudaBlockInfo {
degree: Degree::new(b.message_modulus.0 - 1),
message_modulus: b.message_modulus,
carry_modulus: b.carry_modulus,
pbs_order: b.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}
pub(crate) fn after_overflowing_scalar_add_sub(&self) -> Self {
Self {
blocks: self
.blocks
.iter()
.map(|b| CudaBlockInfo {
degree: Degree::new(b.message_modulus.0 - 1),
message_modulus: b.message_modulus,
carry_modulus: b.carry_modulus,
pbs_order: b.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}
pub(crate) fn after_rotate(&self, other: &Self) -> Self {
Self {
blocks: self
.blocks
.iter()
.zip(&other.blocks)
.map(|(left, _)| CudaBlockInfo {
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}
pub(crate) fn after_scalar_rotate(&self) -> Self {
Self {
blocks: self
.blocks
.iter()
.map(|left| CudaBlockInfo {
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}
pub(crate) fn after_min_max(&self) -> Self {
Self {
blocks: self
.blocks
.iter()
.map(|left| CudaBlockInfo {
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
Expand Down Expand Up @@ -229,7 +321,7 @@ impl CudaRadixCiphertextInfo {
message_modulus: info.message_modulus,
carry_modulus: info.carry_modulus,
pbs_order: info.pbs_order,
noise_level: info.noise_level + NoiseLevel::NOMINAL,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
Expand Down Expand Up @@ -310,6 +402,21 @@ impl CudaRadixCiphertextInfo {
.collect(),
}
}
pub(crate) fn after_bitnot(&self) -> Self {
Self {
blocks: self
.blocks
.iter()
.map(|b| CudaBlockInfo {
degree: Degree::new(b.message_modulus.0 - 1),
message_modulus: b.message_modulus,
carry_modulus: b.carry_modulus,
pbs_order: b.pbs_order,
noise_level: b.noise_level,
})
.collect(),
}
}
pub(crate) fn after_scalar_bitand<T>(&self, scalar: T) -> Self
where
T: DecomposableInto<u8>,
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ impl CudaServerKey {
&d_decomposed_scalar,
streams,
);
ct.as_mut().info = ct.as_ref().info.after_bitnot();
}

pub fn unchecked_bitnot_assign<T: CudaIntegerRadixCiphertext>(
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/gpu/server_key/radix/cmux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl CudaServerKey {
);
}
}
result.as_mut().info = true_ct.as_ref().info.after_if_then_else();

result
}
Expand Down
3 changes: 2 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,7 @@ impl CudaServerKey {
);
}
}
result.as_mut().info = ct_left.as_ref().info.after_min_max();

result
}
Expand Down Expand Up @@ -1129,7 +1130,7 @@ impl CudaServerKey {
);
}
}

result.as_mut().info = ct_left.as_ref().info.after_min_max();
result
}

Expand Down
4 changes: 3 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/ilog2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,9 @@ impl CudaServerKey {

let result = self.sum_ciphertexts_async(ciphertexts, streams).unwrap();

self.cast_to_unsigned_async(result, counter_num_blocks, streams)
let mut result_cast = self.cast_to_unsigned_async(result, counter_num_blocks, streams);
result_cast.as_mut().info = ct.as_ref().info.after_ilog2();
result_cast
}

/// Returns the number of trailing zeros in the binary representation of `ct`
Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/rotate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_rotate(&rotate.as_ref().info);
}

/// # Safety
Expand Down Expand Up @@ -198,6 +199,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_rotate(&rotate.as_ref().info);
}

/// # Safety
Expand Down
4 changes: 3 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/scalar_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ impl CudaServerKey {
self.propagate_single_carry_assign_async(ct_left, stream, None, OutputFlag::Carry);
}
stream.synchronize();
ct_left.as_mut().info = ct_left.as_ref().info.after_overflowing_scalar_add_sub();

let num_scalar_blocks =
BlockDecomposer::with_early_stop_at_zero(scalar, self.message_modulus.0.ilog2())
Expand Down Expand Up @@ -345,7 +346,8 @@ impl CudaServerKey {
ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0,
streams,
);
let (result, overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams);
let (mut result, overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams);
result.as_mut().info = tmp_lhs.as_ref().info.after_overflowing_scalar_add_sub();

let mut extra_scalar_block_iter =
BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ impl CudaServerKey {
);
}
}

result.as_mut().info = ct.as_ref().info.after_min_max();
result
}

Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_scalar_rotate();
}

pub fn unchecked_scalar_rotate_left<Scalar, T>(
Expand Down Expand Up @@ -204,6 +205,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_scalar_rotate();
}

pub fn unchecked_scalar_rotate_right<Scalar, T>(
Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_scalar_rotate();
}

/// Computes homomorphically a left shift by a scalar.
Expand Down Expand Up @@ -300,6 +301,7 @@ impl CudaServerKey {
}
}
}
ct.as_mut().info = ct.as_ref().info.after_scalar_rotate();
}

/// Computes homomorphically a right shift by a scalar.
Expand Down
3 changes: 2 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ impl CudaServerKey {
ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0,
streams,
);
let (result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams);
let (mut result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams);
result.as_mut().info = tmp_lhs.as_ref().info.after_overflowing_scalar_add_sub();

let mut extra_scalar_block_iter =
BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_rotate(&shift.as_ref().info);
}

/// # Safety
Expand Down Expand Up @@ -196,6 +197,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_rotate(&shift.as_ref().info);
}

/// # Safety
Expand Down

0 comments on commit 241b737

Please sign in to comment.