Skip to content

Commit

Permalink
chore: fix some overflow related panics
Browse files Browse the repository at this point in the history
Some overflow panics were occurring when
overflow-check=true

Most of them were expected/accepted, so this commit only
makes changes so that its now explicit that overflow is accepted.
  • Loading branch information
tmontaigu committed Oct 17, 2024
1 parent 1a5dfb3 commit 400ec4e
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 28 deletions.
8 changes: 4 additions & 4 deletions tfhe/src/integer/bigint/algorithms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,20 @@ pub(crate) fn bitxor_assign(lhs: &mut [u64], rhs: &[u64]) {
}

#[inline(always)]
pub(crate) fn add_with_carry<T: UnsignedInteger>(l: T, r: T, c: bool) -> (T, bool) {
pub(crate) fn wrapping_add_with_carry<T: UnsignedInteger>(l: T, r: T, c: bool) -> (T, bool) {
let (lr, o0) = l.overflowing_add(r);
let (lrc, o1) = lr.overflowing_add(T::cast_from(c));
(lrc, o0 | o1)
}

pub(crate) fn add_assign_words<T: UnsignedInteger>(lhs: &mut [T], rhs: &[T]) {
pub(crate) fn wrapping_add_assign_words<T: UnsignedInteger>(lhs: &mut [T], rhs: &[T]) {
let iter = lhs
.iter_mut()
.zip(rhs.iter().copied().chain(std::iter::repeat(T::ZERO)));

let mut carry = false;
for (lhs_block, rhs_block) in iter {
let (result, out_carry) = add_with_carry(*lhs_block, rhs_block, carry);
let (result, out_carry) = wrapping_add_with_carry(*lhs_block, rhs_block, carry);
*lhs_block = result;
carry = out_carry;
}
Expand Down Expand Up @@ -188,7 +188,7 @@ pub(crate) fn schoolbook_mul_assign(lhs: &mut [u64], rhs: &[u64]) {

let mut result = terms.pop().unwrap();
for term in terms {
add_assign_words(&mut result, &term);
wrapping_add_assign_words(&mut result, &term);
}

for (lhs_block, result_block) in lhs.iter_mut().zip(result) {
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/bigint/static_signed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl<const N: usize> std::ops::Add<Self> for StaticSignedBigInt<N> {

impl<const N: usize> std::ops::AddAssign<Self> for StaticSignedBigInt<N> {
fn add_assign(&mut self, rhs: Self) {
super::algorithms::add_assign_words(self.0.as_mut_slice(), rhs.0.as_slice());
super::algorithms::wrapping_add_assign_words(self.0.as_mut_slice(), rhs.0.as_slice());
}
}

Expand Down
12 changes: 11 additions & 1 deletion tfhe/src/integer/bigint/static_unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ impl<const N: usize> StaticUnsignedBigInt<N> {
pub fn ceil_ilog2(self) -> u32 {
self.ilog2() + u32::from(!self.is_power_of_two())
}

pub fn wrapping_sub(mut self, other: Self) -> Self {
let mut negated = !other;
super::algorithms::wrapping_add_assign_words(
negated.0.as_mut_slice(),
Self::from(1u64).0.as_slice(),
);
super::algorithms::wrapping_add_assign_words(self.0.as_mut_slice(), negated.0.as_slice());
self
}
}

#[cfg(test)]
Expand Down Expand Up @@ -107,7 +117,7 @@ impl<const N: usize> std::cmp::PartialOrd for StaticUnsignedBigInt<N> {

impl<const N: usize> std::ops::AddAssign<Self> for StaticUnsignedBigInt<N> {
fn add_assign(&mut self, rhs: Self) {
super::algorithms::add_assign_words(self.0.as_mut_slice(), rhs.0.as_slice());
super::algorithms::wrapping_add_assign_words(self.0.as_mut_slice(), rhs.0.as_slice());
}
}

Expand Down
13 changes: 8 additions & 5 deletions tfhe/src/integer/server_key/radix/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,13 @@ fn integer_smart_add_128_bits(param: ClassicPBSParameters) {
// add the two ciphertexts
let mut ct_res = sks.smart_add(&mut ctxt_0, &mut ctxt_1);

let mut clear_result = clear_0 + clear_1;
let mut clear_result = clear_0.wrapping_add(clear_1);

// println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1);
//add multiple times to raise the degree
for _ in 0..2 {
ct_res = sks.smart_add(&mut ct_res, &mut ctxt_0);
clear_result += clear_0;
clear_result = clear_result.wrapping_add(clear_0);

let dec_res: u128 = cks.decrypt_radix(&ct_res);
// println!("clear = {}, dec_res = {}", clear, dec_res);
Expand Down Expand Up @@ -629,7 +629,7 @@ fn integer_unchecked_scalar_decomposition_overflow(param: ClassicPBSParameters)
let ct_res = sks.unchecked_scalar_add(&ct_0, scalar);
let dec_res = cks.decrypt_radix(&ct_res);

assert_eq!((clear_0 + scalar as u128), dec_res);
assert_eq!(clear_0.wrapping_add(scalar as u128), dec_res);

// Check subtraction
// -----------------
Expand All @@ -640,7 +640,7 @@ fn integer_unchecked_scalar_decomposition_overflow(param: ClassicPBSParameters)
let ct_res = sks.unchecked_scalar_sub(&ct_0, scalar);
let dec_res = cks.decrypt_radix(&ct_res);

assert_eq!((clear_0 - scalar as u128), dec_res);
assert_eq!(clear_0.wrapping_sub(scalar as u128), dec_res);
}

#[test]
Expand All @@ -666,7 +666,7 @@ fn integer_smart_scalar_mul_decomposition_overflow() {
let ct_res = sks.smart_scalar_mul(&mut ct_0, scalar);
let dec_res = cks.decrypt_radix(&ct_res);

assert_eq!((clear_0 * scalar as u128), dec_res);
assert_eq!(clear_0.wrapping_mul(scalar as u128), dec_res);
}

fn integer_default_overflowing_sub<P>(param: P)
Expand Down Expand Up @@ -696,6 +696,9 @@ fn integer_create_trivial_min_max(param: impl Into<PBSParameters>) {
// If num_bits_in_one_block is not a multiple of bit_size, then
// the actual number of bits is not the same as bit size (we end up with more)
let actual_num_bits = num_blocks * num_bits_in_one_block;
if actual_num_bits >= i128::BITS {
break;
}

// Unsigned
{
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ impl ServerKey {
} else {
// u64::MAX is -1 in two's complement
// We apply the modulus including the padding bit
u64::MAX % (1 << (block_modulus + 1))
u64::MAX % (block_modulus * 2)
}
})]
};
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ impl ServerKey {
} else {
// u64::MAX is -1 in tow's complement
// We apply the modulus including the padding bit
u64::MAX % (1 << (block_modulus + 1))
u64::MAX % (block_modulus * 2)
}
})]
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -881,17 +881,17 @@ impl ServerKey {
let block_states = {
for i in 1..grouping_size {
let state_fn = |block| {
let r = (u64::MAX * u64::from(block != 0)) % (packed_modulus * 2);
r << (i - 1)
let r = u64::MAX * u64::from(block != 0);
(r << (i - 1)) % (packed_modulus * 2)
};
first_grouping_luts.push(self.key.generate_lookup_table(state_fn));
}

let other_block_state_luts = (0..grouping_size)
.map(|i| {
let state_fn = |block| {
let r = (u64::MAX * u64::from(block != 0)) % (packed_modulus * 2);
r << i
let r = u64::MAX * u64::from(block != 0);
(r << i) % (packed_modulus * 2)
};
self.key.generate_lookup_table(state_fn)
})
Expand Down
15 changes: 13 additions & 2 deletions tfhe/src/integer/server_key/radix_parallel/scalar_div_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pub trait MiniUnsignedInteger:
fn ilog2(self) -> u32;

fn is_power_of_two(self) -> bool;

fn wrapping_sub(self, other: Self) -> Self;
}

impl<T> MiniUnsignedInteger for T
Expand All @@ -59,6 +61,10 @@ where
fn is_power_of_two(self) -> bool {
<T as UnsignedInteger>::is_power_of_two(self)
}

fn wrapping_sub(self, other: Self) -> Self {
<T as UnsignedInteger>::wrapping_sub(self, other)
}
}

impl<const N: usize> MiniUnsignedInteger for StaticUnsignedBigInt<N> {
Expand All @@ -73,6 +79,10 @@ impl<const N: usize> MiniUnsignedInteger for StaticUnsignedBigInt<N> {
fn is_power_of_two(self) -> bool {
self.is_power_of_two()
}

fn wrapping_sub(self, other: Self) -> Self {
self.wrapping_sub(other)
}
}

pub trait Reciprocable: MiniUnsignedInteger {
Expand Down Expand Up @@ -496,8 +506,9 @@ impl ServerKey {
// The subtraction may overflow.
// We then cast the result to a signed type.
// Overall, this will work fine due to two's complement representation
let cst = chosen_multiplier.multiplier
- (<T::Unsigned as Reciprocable>::DoublePrecision::ONE << numerator_bits);
let cst = chosen_multiplier.multiplier.wrapping_sub(
<T::Unsigned as Reciprocable>::DoublePrecision::ONE << numerator_bits,
);
let cst = T::DoublePrecision::cast_from(cst);

// MULSH(m - 2^N, n)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1215,12 +1215,12 @@ where
let mut ctxt_0 = cks.encrypt(clear_0);

let mut ct_res = executor.execute((&mut ctxt_0, clear_1));
clear = (clear_0 - clear_1) % modulus;
clear = clear_0.wrapping_sub(clear_1) % modulus;

// Sub multiple times to raise the degree
for _ in 0..nb_tests_smaller {
ct_res = executor.execute((&mut ct_res, clear_1));
clear = (clear - clear_1) % modulus;
clear = clear.wrapping_sub(clear_1) % modulus;

let dec_res: u64 = cks.decrypt(&ct_res);

Expand Down
14 changes: 7 additions & 7 deletions tfhe/src/shortint/server_key/tests/parametrized_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ where

let dec_res = cks.decrypt(&ct_res);

assert_eq!((clear - scalar) % message_modulus, dec_res as u8);
assert_eq!(clear.wrapping_sub(scalar) % message_modulus, dec_res as u8);
}
}

Expand All @@ -1174,11 +1174,11 @@ where

let mut ct_res = sks.smart_scalar_sub(&mut ctxt_0, clear_1);

let mut clear = (clear_0 - clear_1) % modulus;
let mut clear = clear_0.wrapping_sub(clear_1) % modulus;

for _ in 0..NB_SUB_TEST_SMART {
ct_res = sks.smart_scalar_sub(&mut ct_res, clear_1);
clear = (clear - clear_1) % modulus;
clear = clear.wrapping_sub(clear_1) % modulus;

let dec_res = cks.decrypt(&ct_res);

Expand Down Expand Up @@ -1427,7 +1427,7 @@ where

let dec = cks.decrypt(&ct_tmp);

let clear_result = (clear1 - clear2) % modulus;
let clear_result = clear1.wrapping_sub(clear2) % modulus;
assert_eq!(clear_result, dec % modulus);
}
}
Expand All @@ -1452,10 +1452,10 @@ where

let mut ct_res = sks.smart_sub(&mut ct1, &mut ct2);

let mut clear_res = (clear1 - clear2) % modulus;
let mut clear_res = clear1.wrapping_sub(clear2) % modulus;
for _ in 0..NB_SUB_TEST_SMART {
ct_res = sks.smart_sub(&mut ct_res, &mut ct2);
clear_res = (clear_res - clear2) % modulus;
clear_res = clear_res.wrapping_sub(clear2) % modulus;
}

let dec_res = cks.decrypt(&ct_res);
Expand Down Expand Up @@ -1625,7 +1625,7 @@ where

let dec_res = cks.decrypt(&res);

let clear_mux = (msg_true - msg_false) * control_bit + msg_false;
let clear_mux = (msg_true.wrapping_sub(msg_false) * control_bit).wrapping_add(msg_false);
println!("(msg_true - msg_false) * control_bit + msg_false = {clear_mux}, res = {dec_res}");
assert_eq!(clear_mux, dec_res);
}
Expand Down

0 comments on commit 400ec4e

Please sign in to comment.