From 7534b68e5c6f26d3f66c2033d215e5f6e29cf743 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 11 Oct 2023 15:04:07 +0200 Subject: [PATCH] test(core): use polynomial tests from NTT PR - initial work done in https://github.com/zama-ai/tfhe-rs/pull/394 - useful reworks of the tests have been waiting in that PR, this is to have those tests while NTT usage gets validated co-authored-by: sarah-ek --- .../algorithms/polynomial_algorithms.rs | 206 ++++++++++++++---- 1 file changed, 168 insertions(+), 38 deletions(-) diff --git a/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs b/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs index 3f25e64d3b..cd6d915021 100644 --- a/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs +++ b/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs @@ -105,6 +105,45 @@ pub fn polynomial_wrapping_add_multisum_assign( + output: &mut Polynomial, + lhs: &Polynomial, + rhs: &Polynomial, +) where + Scalar: UnsignedInteger, + OutputCont: ContainerMut, + InputCont1: Container, + InputCont2: Container, +{ + fn implementation( + mut output: Polynomial<&mut [Scalar]>, + lhs: Polynomial<&[Scalar]>, + rhs: Polynomial<&[Scalar]>, + ) { + let polynomial_size = output.polynomial_size(); + let degree = output.degree(); + + for (lhs_degree, &lhs_coeff) in lhs.iter().enumerate() { + for (rhs_degree, &rhs_coeff) in rhs.iter().enumerate() { + let target_degree = lhs_degree + rhs_degree; + if target_degree <= degree { + let output_coefficient = &mut output.as_mut()[target_degree]; + + *output_coefficient = + (*output_coefficient).wrapping_add(lhs_coeff.wrapping_mul(rhs_coeff)); + } else { + let target_degree = target_degree % polynomial_size.0; + let output_coefficient = &mut output.as_mut()[target_degree]; + + *output_coefficient = + (*output_coefficient).wrapping_sub(lhs_coeff.wrapping_mul(rhs_coeff)); + } + } + } + } + implementation(output.as_mut_view(), lhs.as_view(), rhs.as_view()); +} + /// Add the result of the product between two polynomials, reduced modulo $(X^{N}+1)$, to the /// output polynomial. /// @@ -155,24 +194,7 @@ pub fn polynomial_wrapping_add_mul_assign( + output: &mut Polynomial, + lhs: &Polynomial, + rhs: &Polynomial, +) where + Scalar: UnsignedInteger, + OutputCont: ContainerMut, + InputCont1: Container, + InputCont2: Container, +{ + fn implementation( + mut output: Polynomial<&mut [Scalar]>, + lhs: Polynomial<&[Scalar]>, + rhs: Polynomial<&[Scalar]>, + ) { + let polynomial_size = output.polynomial_size(); + let degree = output.degree(); + + for (lhs_degree, &lhs_coeff) in lhs.iter().enumerate() { + for (rhs_degree, &rhs_coeff) in rhs.iter().enumerate() { + let target_degree = lhs_degree + rhs_degree; + if target_degree <= degree { + let output_coefficient = &mut output.as_mut()[target_degree]; + + *output_coefficient = + (*output_coefficient).wrapping_sub(lhs_coeff.wrapping_mul(rhs_coeff)); + } else { + let target_degree = target_degree % polynomial_size.0; + let output_coefficient = &mut output.as_mut()[target_degree]; + + *output_coefficient = + (*output_coefficient).wrapping_add(lhs_coeff.wrapping_mul(rhs_coeff)); + } + } + } + } + implementation(output.as_mut_view(), lhs.as_view(), rhs.as_view()); +} + /// Subtract the result of the product between two polynomials, reduced modulo $(X^{N}+1)$, to the /// output polynomial. /// @@ -559,24 +620,7 @@ pub fn polynomial_wrapping_sub_mul_assign() { // 50 times the test - for _i in 0..50 { + for _ in 0..50 { // random source let mut rng = rand::thread_rng(); @@ -820,7 +864,7 @@ mod test { let mut ka_mul = Polynomial::new(T::ZERO, polynomial_size); // compute the schoolbook - polynomial_wrapping_mul(&mut sb_mul, &poly_1, &poly_2); + polynomial_wrapping_add_mul_assign_schoolbook(&mut sb_mul, &poly_1, &poly_2); // compute the karatsuba polynomial_karatsuba_wrapping_mul(&mut ka_mul, &poly_1, &poly_2); @@ -830,6 +874,72 @@ mod test { } } + /// test if we have the same result when using schoolbook or ntt/karatsuba + /// for random polynomial multiplication + fn test_add_mul() { + for polynomial_log in 4..=12 { + for _ in 0..10 { + let polynomial_size = PolynomialSize(1 << polynomial_log); + let mut generator = new_random_generator(); + + // generate two random Torus polynomials + let mut poly_1 = Polynomial::new(T::ZERO, polynomial_size); + generator.fill_slice_with_random_uniform::(poly_1.as_mut()); + let poly_1 = poly_1; + + let mut poly_2 = Polynomial::new(T::ZERO, polynomial_size); + generator.fill_slice_with_random_uniform::(poly_2.as_mut()); + let poly_2 = poly_2; + + // copy this polynomial + let mut sb_mul = Polynomial::new(T::ZERO, polynomial_size); + let mut ka_mul = Polynomial::new(T::ZERO, polynomial_size); + + // compute the schoolbook + polynomial_wrapping_add_mul_assign_schoolbook(&mut sb_mul, &poly_1, &poly_2); + + // compute the ntt/karatsuba + polynomial_wrapping_add_mul_assign(&mut ka_mul, &poly_1, &poly_2); + + // test + assert_eq!(&sb_mul, &ka_mul); + } + } + } + + /// test if we have the same result when using schoolbook or ntt/karatsuba + /// for random polynomial multiplication + fn test_sub_mul() { + for polynomial_log in 4..=12 { + for _ in 0..10 { + let polynomial_size = PolynomialSize(1 << polynomial_log); + let mut generator = new_random_generator(); + + // generate two random Torus polynomials + let mut poly_1 = Polynomial::new(T::ZERO, polynomial_size); + generator.fill_slice_with_random_uniform::(poly_1.as_mut()); + let poly_1 = poly_1; + + let mut poly_2 = Polynomial::new(T::ZERO, polynomial_size); + generator.fill_slice_with_random_uniform::(poly_2.as_mut()); + let poly_2 = poly_2; + + // copy this polynomial + let mut sb_mul = Polynomial::new(T::ZERO, polynomial_size); + let mut ka_mul = Polynomial::new(T::ZERO, polynomial_size); + + // compute the schoolbook + polynomial_wrapping_sub_mul_assign_schoolbook(&mut sb_mul, &poly_1, &poly_2); + + // compute the ntt/karatsuba + polynomial_wrapping_sub_mul_assign(&mut ka_mul, &poly_1, &poly_2); + + // test + assert_eq!(&sb_mul, &ka_mul); + } + } + } + #[test] pub fn test_multiply_divide_unit_monomial_u32() { test_multiply_divide_unit_monomial::() @@ -849,4 +959,24 @@ mod test { pub fn test_multiply_karatsuba_u64() { test_multiply_karatsuba::() } + + #[test] + pub fn test_add_mul_u32() { + test_add_mul::() + } + + #[test] + pub fn test_add_mul_u64() { + test_add_mul::() + } + + #[test] + pub fn test_sub_mul_u32() { + test_sub_mul::() + } + + #[test] + pub fn test_sub_mul_u64() { + test_sub_mul::() + } }