Skip to content

Commit

Permalink
fix(gpu) update degree and noise after more ops
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Dec 11, 2024
1 parent caaa97a commit c9b4a58
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 20 deletions.
56 changes: 51 additions & 5 deletions tfhe/src/integer/gpu/ciphertext/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl CudaRadixCiphertextInfo {
.blocks
.iter()
.map(|info| CudaBlockInfo {
degree: info.degree,
degree: Degree::new(info.message_modulus.0 - 1),
message_modulus: info.message_modulus,
carry_modulus: info.carry_modulus,
pbs_order: info.pbs_order,
Expand Down Expand Up @@ -167,11 +167,11 @@ impl CudaRadixCiphertextInfo {
.iter()
.zip(&other.blocks)
.map(|(left, _)| CudaBlockInfo {
degree: left.degree,
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: left.noise_level,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
Expand All @@ -183,11 +183,57 @@ impl CudaRadixCiphertextInfo {
.iter()
.zip(&other.blocks)
.map(|(left, _)| CudaBlockInfo {
degree: left.degree,
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: left.noise_level,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}
pub(crate) fn after_rotate(&self, other: &Self) -> Self {
Self {
blocks: self
.blocks
.iter()
.zip(&other.blocks)
.map(|(left, _)| CudaBlockInfo {
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}
pub(crate) fn after_scalar_rotate(&self) -> Self {
Self {
blocks: self
.blocks
.iter()
.map(|left| CudaBlockInfo {
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
}
pub(crate) fn after_min_max(&self) -> Self {
Self {
blocks: self
.blocks
.iter()
.map(|left| CudaBlockInfo {
degree: Degree::new(left.message_modulus.0 - 1),
message_modulus: left.message_modulus,
carry_modulus: left.carry_modulus,
pbs_order: left.pbs_order,
noise_level: NoiseLevel::NOMINAL,
})
.collect(),
}
Expand Down
3 changes: 2 additions & 1 deletion tfhe/src/integer/gpu/server_key/radix/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,7 @@ impl CudaServerKey {
);
}
}
result.as_mut().info = ct_left.as_ref().info.after_min_max();

result
}
Expand Down Expand Up @@ -1129,7 +1130,7 @@ impl CudaServerKey {
);
}
}

result.as_mut().info = ct_left.as_ref().info.after_min_max();
result
}

Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/rotate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_rotate(&rotate.as_ref().info);
}

/// # Safety
Expand Down Expand Up @@ -198,6 +199,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_rotate(&rotate.as_ref().info);
}

/// # Safety
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ impl CudaServerKey {
);
}
}

result.as_mut().info = ct.as_ref().info.after_min_max();
result
}

Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_scalar_rotate();
}

pub fn unchecked_scalar_rotate_left<Scalar, T>(
Expand Down Expand Up @@ -204,6 +205,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_scalar_rotate();
}

pub fn unchecked_scalar_rotate_right<Scalar, T>(
Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_scalar_rotate();
}

/// Computes homomorphically a left shift by a scalar.
Expand Down Expand Up @@ -300,6 +301,7 @@ impl CudaServerKey {
}
}
}
ct.as_mut().info = ct.as_ref().info.after_scalar_rotate();
}

/// Computes homomorphically a right shift by a scalar.
Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_rotate(&shift.as_ref().info);
}

/// # Safety
Expand Down Expand Up @@ -196,6 +197,7 @@ impl CudaServerKey {
);
}
}
ct.as_mut().info = ct.as_ref().info.after_rotate(&shift.as_ref().info);
}

/// # Safety
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index}.",
"Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.",
);
} else if i < binary_ops.len() + unary_ops.len() {
let index = i - binary_ops.len();
Expand Down Expand Up @@ -629,7 +629,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on unary op {fn_name} with clear input {clear_input} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on unary op {fn_name} with clear input {clear_input} at iteration {fn_index}.",
);
} else if i < binary_ops.len() + unary_ops.len() + scalar_binary_ops.len() {
let index = i - binary_ops.len() - unary_ops.len();
Expand Down Expand Up @@ -671,7 +671,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.",
);
} else if i < binary_ops.len()
+ unary_ops.len()
Expand Down Expand Up @@ -727,7 +727,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index}.",
"Invalid result on op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.",
);
assert_eq!(
decrypted_overflow, expected_overflow,
Expand Down Expand Up @@ -794,7 +794,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.",
);
assert_eq!(
decrypted_overflow, expected_overflow,
Expand Down Expand Up @@ -836,7 +836,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.",
);

let res_ct: RadixCiphertext = res.into_radix(1, &sks);
Expand Down Expand Up @@ -886,7 +886,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on binary op {fn_name} with clear inputs {clear_left} and {clear_right} at iteration {fn_index}.",
);
let res_ct: RadixCiphertext = res.into_radix(1, &sks);
if i % 2 == 0 {
Expand Down Expand Up @@ -945,7 +945,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} and {clear_bool} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} and {clear_bool} at iteration {fn_index}.",
);
if i % 2 == 0 {
left_vec[j] = res.clone();
Expand Down Expand Up @@ -1020,11 +1020,11 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res_q, expected_res_q,
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration {fn_index}.",
);
assert_eq!(
decrypted_res_r, expected_res_r,
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration {fn_index}.",
);
if i % 2 == 0 {
left_vec[j] = res_q.clone();
Expand Down Expand Up @@ -1103,11 +1103,11 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res_q, expected_res_q,
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration {fn_index}.",
);
assert_eq!(
decrypted_res_r, expected_res_r,
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on op {fn_name} with clear inputs {clear_left}, {clear_right} at iteration {fn_index}.",
);
if i % 2 == 0 {
left_vec[j] = res_r.clone();
Expand Down Expand Up @@ -1181,7 +1181,7 @@ pub(crate) fn random_op_sequence_test<P>(
// Correctness check
assert_eq!(
decrypted_res, expected_res,
"Invalid result on op {fn_name} with clear input {clear_input} at iteration{fn_index} at iteration{fn_index}.",
"Invalid result on op {fn_name} with clear input {clear_input} at iteration {fn_index}.",
);
if i % 2 == 0 {
left_vec[j] = cast_res.clone();
Expand Down

0 comments on commit c9b4a58

Please sign in to comment.