From 0e0d860f7e267bfd4f0b5961e1476d2e3eb4fa2a Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 26 Aug 2024 15:20:20 +0200 Subject: [PATCH] chore(tfhe): update dependencies with breaking changes - concrete-fft to 0.5 and concrete-ntt 0.2.0 due to rust AVX512 breaking change (fix for bad args in function) - dyn-stack to 0.10 due to concrete-fft update --- tfhe/Cargo.toml | 6 +- .../lwe_programmable_bootstrapping/ntt64.rs | 23 ++--- .../commons/math/decomposition/iter.rs | 16 ++-- .../algorithms/glwe_fast_keyswitch.rs | 7 +- .../fft_impl/fft128/crypto/bootstrap.rs | 8 +- .../fft_impl/fft128/crypto/ggsw.rs | 51 +++++------ .../fft_impl/fft128/math/fft/mod.rs | 20 ++--- .../fft_impl/fft128/math/fft/tests.rs | 8 +- .../fft_impl/fft128_u128/crypto/bootstrap.rs | 10 +-- .../fft_impl/fft128_u128/crypto/ggsw.rs | 52 +++++------- .../fft_impl/fft128_u128/crypto/tests.rs | 2 +- .../fft_impl/fft128_u128/math/fft/mod.rs | 19 ++--- .../fft_impl/fft64/crypto/bootstrap.rs | 6 +- .../core_crypto/fft_impl/fft64/crypto/ggsw.rs | 12 +-- .../fft_impl/fft64/crypto/wop_pbs/mod.rs | 85 +++++++++---------- .../fft_impl/fft64/math/decomposition.rs | 4 +- .../fft_impl/fft64/math/fft/mod.rs | 6 +- 17 files changed, 146 insertions(+), 189 deletions(-) diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index d1a764cb2d..60a9ee26eb 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -62,12 +62,12 @@ lazy_static = { version = "1.4.0", optional = true } serde = { version = "1.0", features = ["derive"] } rayon = { version = "1.5.0" } bincode = "1.3.3" -concrete-fft = { version = "0.4.1", features = ["serde", "fft128"] } -concrete-ntt = { version = "0.1.2" } +concrete-fft = { version = "0.5.0", features = ["serde", "fft128"] } +concrete-ntt = { version = "0.2.0" } pulp = "0.18.22" tfhe-cuda-backend = { version = "0.4.0-alpha.0", path = "../backends/tfhe-cuda-backend", optional = true } aligned-vec = { version = "0.5", features = ["serde"] } -dyn-stack = { version = "0.9" } +dyn-stack = { version = "0.10" } paste = "1.0.7" fs2 = { version = "0.4.3", optional = true } # Used for OPRF in shortint diff --git a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64.rs b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64.rs index 2d07365ea2..3901282d14 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64.rs @@ -250,13 +250,10 @@ pub fn blind_rotate_ntt64_assign_mem_optimized( if *lwe_mask_element != 0u64 { let stack = stack.rb_mut(); // We copy ct_0 to ct_1 - let (mut ct1, stack) = + let (ct1, stack) = stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied()); - let mut ct1 = GlweCiphertextMutView::from_container( - &mut *ct1, - lut_poly_size, - ciphertext_modulus, - ); + let mut ct1 = + GlweCiphertextMutView::from_container(ct1, lut_poly_size, ciphertext_modulus); // We rotate ct_1 by performing ct_1 <- ct_1 * X^{a_hat} for mut poly in ct1.as_mut_polynomial_list().iter_mut() { @@ -503,10 +500,10 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized< accumulator.ciphertext_modulus() ); - let (mut local_accumulator_data, stack) = + let (local_accumulator_data, stack) = stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied()); let mut local_accumulator = GlweCiphertextMutView::from_container( - &mut *local_accumulator_data, + local_accumulator_data, accumulator.polynomial_size(), accumulator.ciphertext_modulus(), ); @@ -568,12 +565,11 @@ pub(crate) fn add_external_product_ntt64_assign( out.ciphertext_modulus(), ); - let (mut output_fft_buffer, mut substack0) = + let (output_fft_buffer, mut substack0) = stack.make_aligned_raw::(poly_size * ggsw.glwe_size().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once // it has been fully initialized for the first time. - let output_fft_buffer = &mut *output_fft_buffer; let mut is_output_uninit = true; { @@ -616,17 +612,16 @@ pub(crate) fn add_external_product_ntt64_assign( glwe_decomp_term.as_polynomial_list().iter() ) .for_each(|(ggsw_row, glwe_poly)| { - let (mut ntt_poly, _) = - substack2.rb_mut().make_aligned_raw::(poly_size, align); + let (ntt_poly, _) = substack2.rb_mut().make_aligned_raw::(poly_size, align); // We perform the forward ntt transform for the glwe polynomial - ntt.forward(PolynomialMutView::from_container(&mut ntt_poly), glwe_poly); + ntt.forward(PolynomialMutView::from_container(ntt_poly), glwe_poly); // Now we loop through the polynomials of the output, and add the // corresponding product of polynomials. update_with_fmadd_ntt64( output_fft_buffer, ggsw_row.as_ref(), - &ntt_poly, + ntt_poly, is_output_uninit, poly_size, ntt, diff --git a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs index 8122aac8f4..3d91888976 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs @@ -4,7 +4,7 @@ use crate::core_crypto::commons::math::decomposition::{ }; use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; -use dyn_stack::{DynArray, PodStack, ReborrowMut}; +use dyn_stack::{PodStack, ReborrowMut}; /// An iterator that yields the terms of the signed decomposition of an integer. /// @@ -288,9 +288,9 @@ pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> { // ...0001111 mod_b_mask: u64, // The internal states of each decomposition - states: DynArray<'buffers, u64>, + states: &'buffers mut [u64], // Corresponding input signs - input_signs: DynArray<'buffers, u8>, + input_signs: &'buffers mut [u8], // A flag which stores whether the iterator is a fresh one (for the recompose method). fresh: bool, ciphertext_modulus: u64, @@ -306,9 +306,9 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> { ) -> (Self, PodStack<'buffers>) { let shift = modulus.ceil_ilog2() as usize - decomposer.base_log * decomposer.level_count; let input_size = input.len(); - let (mut states, stack) = + let (states, stack) = stack.make_aligned_raw::(input_size, aligned_vec::CACHELINE_ALIGN); - let (mut input_signs, stack) = + let (input_signs, stack) = stack.make_aligned_raw::(input_size, aligned_vec::CACHELINE_ALIGN); for ((i, state), sign) in input @@ -393,11 +393,7 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> { &mut self, substack1: &'a mut PodStack, align: usize, - ) -> ( - DecompositionLevel, - dyn_stack::DynArray<'a, u64>, - PodStack<'a>, - ) { + ) -> (DecompositionLevel, &'a mut [u64], PodStack<'a>) { let (glwe_level, _, glwe_decomp_term) = self.next_term().unwrap(); let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term); diff --git a/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs b/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs index 597c488f5f..10d9ea52a9 100644 --- a/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs +++ b/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs @@ -193,12 +193,11 @@ pub fn glwe_fast_keyswitch( ggsw.decomposition_base_log(), ggsw.decomposition_level_count(), ); - let (mut output_fft_buffer, mut substack0) = + let (output_fft_buffer, mut substack0) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size_out().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once // it has been fully initialized for the first time. - let output_fft_buffer = &mut *output_fft_buffer; let mut is_output_uninit = true; { @@ -244,14 +243,14 @@ pub fn glwe_fast_keyswitch( glwe_decomp_term.get_mask().as_polynomial_list().iter() ) .for_each(|(ggsw_row, glwe_poly)| { - let (mut fourier, substack3) = substack2 + let (fourier, substack3) = substack2 .rb_mut() .make_aligned_raw::(fourier_poly_size, align); // We perform the forward fft transform for the glwe polynomial let fourier = fft .forward_as_integer( - FourierPolynomialMutView { data: &mut fourier }, + FourierPolynomialMutView { data: fourier }, glwe_poly, substack3, ) diff --git a/tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs b/tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs index 99939d5c12..14c0696313 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs @@ -289,10 +289,10 @@ where if *lwe_mask_element != Scalar::ZERO { let stack = stack.rb_mut(); // We copy ct_0 to ct_1 - let (mut ct1, stack) = + let (ct1, stack) = stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied()); let mut ct1 = GlweCiphertextMutView::from_container( - &mut *ct1, + ct1, ct0.polynomial_size(), ct0.ciphertext_modulus(), ); @@ -361,10 +361,10 @@ where return this.bootstrap_u128(&mut lwe_out, &lwe_in, &accumulator, fft, stack); } - let (mut local_accumulator_data, stack) = + let (local_accumulator_data, stack) = stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied()); let mut local_accumulator = GlweCiphertextMutView::from_container( - &mut *local_accumulator_data, + local_accumulator_data, accumulator.polynomial_size(), accumulator.ciphertext_modulus(), ); diff --git a/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs index 14aa239746..1e8dc852ea 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs @@ -397,13 +397,13 @@ pub fn add_external_product_assign( ggsw.decomposition_level_count(), ); - let (mut output_fft_buffer_re0, stack) = + let (output_fft_buffer_re0, stack) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); - let (mut output_fft_buffer_re1, stack) = + let (output_fft_buffer_re1, stack) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); - let (mut output_fft_buffer_im0, stack) = + let (output_fft_buffer_im0, stack) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); - let (mut output_fft_buffer_im1, mut substack0) = + let (output_fft_buffer_im1, mut substack0) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid @@ -455,30 +455,30 @@ pub fn add_external_product_assign( ) { let len = fourier_poly_size; let stack = substack2.rb_mut(); - let (mut fourier_re0, stack) = stack.make_aligned_raw::(len, align); - let (mut fourier_re1, stack) = stack.make_aligned_raw::(len, align); - let (mut fourier_im0, stack) = stack.make_aligned_raw::(len, align); - let (mut fourier_im1, _) = stack.make_aligned_raw::(len, align); + let (fourier_re0, stack) = stack.make_aligned_raw::(len, align); + let (fourier_re1, stack) = stack.make_aligned_raw::(len, align); + let (fourier_im0, stack) = stack.make_aligned_raw::(len, align); + let (fourier_im1, _) = stack.make_aligned_raw::(len, align); // We perform the forward fft transform for the glwe polynomial fft.forward_as_integer( - &mut fourier_re0, - &mut fourier_re1, - &mut fourier_im0, - &mut fourier_im1, + fourier_re0, + fourier_re1, + fourier_im0, + fourier_im1, glwe_poly.as_ref(), ); // Now we loop through the polynomials of the output, and add the // corresponding product of polynomials. update_with_fmadd( - &mut output_fft_buffer_re0, - &mut output_fft_buffer_re1, - &mut output_fft_buffer_im0, - &mut output_fft_buffer_im1, + output_fft_buffer_re0, + output_fft_buffer_re1, + output_fft_buffer_im0, + output_fft_buffer_im1, ggsw_row, - &fourier_re0, - &fourier_re1, - &fourier_im0, - &fourier_im1, + fourier_re0, + fourier_re1, + fourier_im0, + fourier_im1, is_output_uninit, fourier_poly_size, ); @@ -495,11 +495,6 @@ pub fn add_external_product_assign( // // We iterate over the polynomials in the output. if !is_output_uninit { - let output_fft_buffer_re0 = output_fft_buffer_re0; - let output_fft_buffer_re1 = output_fft_buffer_re1; - let output_fft_buffer_im0 = output_fft_buffer_im0; - let output_fft_buffer_im1 = output_fft_buffer_im1; - for (mut out, fourier_re0, fourier_re1, fourier_im0, fourier_im1) in izip!( out.as_mut_polynomial_list().iter_mut(), output_fft_buffer_re0.into_chunks(fourier_poly_size), @@ -532,11 +527,7 @@ fn collect_next_term<'a, Scalar: UnsignedTorus>( decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>, substack1: &'a mut PodStack, align: usize, -) -> ( - DecompositionLevel, - dyn_stack::DynArray<'a, Scalar>, - PodStack<'a>, -) { +) -> (DecompositionLevel, &'a mut [Scalar], PodStack<'a>) { let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap(); let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term); (glwe_level, glwe_decomp_term, substack2) diff --git a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs index c694f3fdbe..e04203fcad 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs @@ -495,27 +495,19 @@ impl<'a> Fft128View<'a> { debug_assert_eq!(n, 2 * fourier_im0.len()); debug_assert_eq!(n, 2 * fourier_im1.len()); - let (mut tmp_re0, stack) = + let (tmp_re0, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_re0.iter().copied()); - let (mut tmp_re1, stack) = + let (tmp_re1, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_re1.iter().copied()); - let (mut tmp_im0, stack) = + let (tmp_im0, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_im0.iter().copied()); - let (mut tmp_im1, _) = + let (tmp_im1, _) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_im1.iter().copied()); - self.plan - .inv(&mut tmp_re0, &mut tmp_re1, &mut tmp_im0, &mut tmp_im1); + self.plan.inv(tmp_re0, tmp_re1, tmp_im0, tmp_im1); let (standard_re, standard_im) = standard.split_at_mut(n / 2); - conv_fn( - standard_re, - standard_im, - &tmp_re0, - &tmp_re1, - &tmp_im0, - &tmp_im1, - ); + conv_fn(standard_re, standard_im, tmp_re0, tmp_re1, tmp_im0, tmp_im1); } } diff --git a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs index 25fd1b0c93..95ed7002e3 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs @@ -129,10 +129,10 @@ fn test_product() { ); for (f0_re0, f0_re1, f0_im0, f0_im1, f1_re0, f1_re1, f1_im0, f1_im1) in izip!( - &mut *fourier0_re0, - &mut *fourier0_re1, - &mut *fourier0_im0, - &mut *fourier0_im1, + fourier0_re0, + fourier0_re1, + fourier0_im0, + fourier0_im1, &*fourier1_re0, &*fourier1_re1, &*fourier1_im0, diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs index d6c0095dfc..676d15635c 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs @@ -105,9 +105,9 @@ where if *lwe_mask_element != 0 { let stack = stack.rb_mut(); // We copy ct_0 to ct_1 - let (mut ct1_lo, stack) = + let (ct1_lo, stack) = stack.collect_aligned(CACHELINE_ALIGN, ct0_lo.as_ref().iter().copied()); - let (mut ct1_hi, stack) = + let (ct1_hi, stack) = stack.collect_aligned(CACHELINE_ALIGN, ct0_hi.as_ref().iter().copied()); let mut ct1_lo = GlweCiphertextMutView::from_container( &mut *ct1_lo, @@ -177,9 +177,9 @@ where let align = CACHELINE_ALIGN; let ciphertext_modulus = accumulator.ciphertext_modulus(); - let (mut local_accumulator_lo, stack) = + let (local_accumulator_lo, stack) = stack.collect_aligned(align, accumulator.as_ref().iter().map(|i| *i as u64)); - let (mut local_accumulator_hi, mut stack) = stack.collect_aligned( + let (local_accumulator_hi, mut stack) = stack.collect_aligned( align, accumulator.as_ref().iter().map(|i| (*i >> 64) as u64), ); @@ -207,7 +207,7 @@ where fft, stack.rb_mut(), ); - let (mut local_accumulator, _) = stack.collect_aligned( + let (local_accumulator, _) = stack.collect_aligned( align, izip!(local_accumulator_lo.as_ref(), local_accumulator_hi.as_ref()) .map(|(&lo, &hi)| lo as u128 | ((hi as u128) << 64)), diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs index c961c69b65..29a053727d 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs @@ -63,32 +63,28 @@ pub fn add_external_product_assign_split(fourier_poly_size * ggsw.glwe_size().0, align); - let (mut output_fft_buffer_re1, stack) = + let (output_fft_buffer_re1, stack) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); - let (mut output_fft_buffer_im0, stack) = + let (output_fft_buffer_im0, stack) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); - let (mut output_fft_buffer_im1, mut substack0) = + let (output_fft_buffer_im1, mut substack0) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once // it has been fully initialized for the first time. - let output_fft_buffer_re0 = &mut *output_fft_buffer_re0; - let output_fft_buffer_re1 = &mut *output_fft_buffer_re1; - let output_fft_buffer_im0 = &mut *output_fft_buffer_im0; - let output_fft_buffer_im1 = &mut *output_fft_buffer_im1; let mut is_output_uninit = true; { // ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER // DOMAIN In this section, we perform the external product in the fourier // domain, and accumulate the result in the output_fft_buffer variable. - let (mut decomposition_states_lo, stack) = substack0 + let (decomposition_states_lo, stack) = substack0 .rb_mut() .make_aligned_raw::(poly_size * glwe_size, align); - let (mut decomposition_states_hi, mut substack1) = + let (decomposition_states_hi, mut substack1) = stack.make_aligned_raw::(poly_size * glwe_size, align); let shift = 128 - decomposer.base_log * decomposer.level_count; @@ -104,6 +100,7 @@ pub fn add_external_product_assign_split> 64) as u64; } + // Reborrow to avoid mut slices to be moved let decomposition_states_lo = &mut *decomposition_states_lo; let decomposition_states_hi = &mut *decomposition_states_hi; let mut current_level = decomposer.level_count; @@ -118,17 +115,17 @@ pub fn add_external_product_assign_split(poly_size * glwe_size, align); - let (mut glwe_decomp_term_hi, mut substack2) = + let (glwe_decomp_term_hi, mut substack2) = stack.make_aligned_raw::(poly_size * glwe_size, align); let base_log = decomposer.base_log; collect_next_term_split( - &mut glwe_decomp_term_lo, - &mut glwe_decomp_term_hi, + glwe_decomp_term_lo, + glwe_decomp_term_hi, decomposition_states_lo, decomposition_states_hi, mod_b_mask_lo, @@ -136,9 +133,6 @@ pub fn add_external_product_assign_split(len, align); - let (mut fourier_re1, stack) = stack.make_aligned_raw::(len, align); - let (mut fourier_im0, stack) = stack.make_aligned_raw::(len, align); - let (mut fourier_im1, _) = stack.make_aligned_raw::(len, align); + let (fourier_re0, stack) = stack.make_aligned_raw::(len, align); + let (fourier_re1, stack) = stack.make_aligned_raw::(len, align); + let (fourier_im0, stack) = stack.make_aligned_raw::(len, align); + let (fourier_im1, _) = stack.make_aligned_raw::(len, align); // We perform the forward fft transform for the glwe polynomial fft.forward_as_integer_split( - &mut fourier_re0, - &mut fourier_re1, - &mut fourier_im0, - &mut fourier_im1, + fourier_re0, + fourier_re1, + fourier_im0, + fourier_im1, glwe_poly_lo.as_ref(), glwe_poly_hi.as_ref(), ); @@ -192,10 +186,10 @@ pub fn add_external_product_assign_split, stack: PodStack<'_>, ) { - let (mut local_accumulator_data, stack) = + let (local_accumulator_data, stack) = stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied()); let mut local_accumulator = GlweCiphertextMutView::from_container( &mut *local_accumulator_data, diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs index b6ed146396..3ec0add2f6 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs @@ -1316,17 +1316,16 @@ impl<'a> Fft128View<'a> { debug_assert_eq!(n, 2 * fourier_im0.len()); debug_assert_eq!(n, 2 * fourier_im1.len()); - let (mut tmp_re0, stack) = + let (tmp_re0, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_re0.iter().copied()); - let (mut tmp_re1, stack) = + let (tmp_re1, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_re1.iter().copied()); - let (mut tmp_im0, stack) = + let (tmp_im0, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_im0.iter().copied()); - let (mut tmp_im1, _) = + let (tmp_im1, _) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_im1.iter().copied()); - self.plan - .inv(&mut tmp_re0, &mut tmp_re1, &mut tmp_im0, &mut tmp_im1); + self.plan.inv(tmp_re0, tmp_re1, tmp_im0, tmp_im1); let (standard_re_lo, standard_im_lo) = standard_lo.split_at_mut(n / 2); let (standard_re_hi, standard_im_hi) = standard_hi.split_at_mut(n / 2); @@ -1335,10 +1334,10 @@ impl<'a> Fft128View<'a> { standard_re_hi, standard_im_lo, standard_im_hi, - &tmp_re0, - &tmp_re1, - &tmp_im0, - &tmp_im1, + tmp_re0, + tmp_re1, + tmp_im0, + tmp_im1, ); } } diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs index c512a1f253..2884b54bc1 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs @@ -353,7 +353,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { lut.as_mut_polynomial_list() .iter_mut() .for_each(|mut poly| { - let (mut tmp_poly, _) = stack + let (tmp_poly, _) = stack .rb_mut() .make_aligned_raw(poly.as_ref().len(), CACHELINE_ALIGN); @@ -364,7 +364,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { // We initialize the ct_0 used for the successive cmuxes let mut ct0 = lut; - let (mut ct1, mut stack) = stack.make_aligned_raw(ct0.as_ref().len(), CACHELINE_ALIGN); + let (ct1, mut stack) = stack.make_aligned_raw(ct0.as_ref().len(), CACHELINE_ALIGN); let mut ct1 = GlweCiphertextMutView::from_container(&mut *ct1, lut_poly_size, ciphertext_modulus); @@ -437,7 +437,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { accumulator.ciphertext_modulus() ); - let (mut local_accumulator_data, stack) = + let (local_accumulator_data, stack) = stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied()); let mut local_accumulator = GlweCiphertextMutView::from_container( &mut *local_accumulator_data, diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs index cf589a9706..ba6b75ae68 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs @@ -588,7 +588,7 @@ pub fn add_external_product_assign( ggsw.decomposition_level_count(), ); - let (mut output_fft_buffer, mut substack0) = + let (output_fft_buffer, mut substack0) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once @@ -638,13 +638,13 @@ pub fn add_external_product_assign( glwe_decomp_term.as_polynomial_list().iter() ) .for_each(|(ggsw_row, glwe_poly)| { - let (mut fourier, substack3) = substack2 + let (fourier, substack3) = substack2 .rb_mut() .make_aligned_raw::(fourier_poly_size, align); // We perform the forward fft transform for the glwe polynomial let fourier = fft .forward_as_integer( - FourierPolynomialMutView { data: &mut fourier }, + FourierPolynomialMutView { data: fourier }, glwe_poly, substack3, ) @@ -691,11 +691,7 @@ pub(crate) fn collect_next_term<'a, Scalar: UnsignedTorus>( decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>, substack1: &'a mut PodStack, align: usize, -) -> ( - DecompositionLevel, - dyn_stack::DynArray<'a, Scalar>, - PodStack<'a>, -) { +) -> (DecompositionLevel, &'a mut [Scalar], PodStack<'a>) { let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap(); let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term); (glwe_level, glwe_decomp_term, substack2) diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/mod.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/mod.rs index c07a2c29a4..8a34d46193 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/mod.rs @@ -123,17 +123,16 @@ pub fn extract_bits>( let align = CACHELINE_ALIGN; - let (mut lwe_in_buffer_data, stack) = - stack.collect_aligned(align, lwe_in.as_ref().iter().copied()); + let (lwe_in_buffer_data, stack) = stack.collect_aligned(align, lwe_in.as_ref().iter().copied()); let mut lwe_in_buffer = LweCiphertext::from_container(&mut *lwe_in_buffer_data, lwe_in.ciphertext_modulus()); - let (mut lwe_out_ks_buffer_data, stack) = + let (lwe_out_ks_buffer_data, stack) = stack.make_aligned_with(ksk.output_lwe_size().0, align, |_| Scalar::ZERO); let mut lwe_out_ks_buffer = LweCiphertext::from_container(&mut *lwe_out_ks_buffer_data, ksk.ciphertext_modulus()); - let (mut pbs_accumulator_data, stack) = + let (pbs_accumulator_data, stack) = stack.make_aligned_with(glwe_size.0 * polynomial_size.0, align, |_| Scalar::ZERO); let mut pbs_accumulator = GlweCiphertextMutView::from_container( &mut *pbs_accumulator_data, @@ -144,7 +143,7 @@ pub fn extract_bits>( let lwe_size = glwe_dimension .to_equivalent_lwe_dimension(polynomial_size) .to_lwe_size(); - let (mut lwe_out_pbs_buffer_data, mut stack) = + let (lwe_out_pbs_buffer_data, mut stack) = stack.make_aligned_with(lwe_size.0, align, |_| Scalar::ZERO); let mut lwe_out_pbs_buffer = LweCiphertext::from_container( &mut *lwe_out_pbs_buffer_data, @@ -153,26 +152,27 @@ pub fn extract_bits>( // We iterate on the list in reverse as we want to store the extracted MSB at index 0 for (bit_idx, mut output_ct) in lwe_list_out.iter_mut().rev().enumerate() { - // Shift on padding bit - let (lwe_bit_left_shift_buffer_data, _) = stack.rb_mut().collect_aligned( - align, - lwe_in_buffer - .as_ref() - .iter() - .map(|s| *s << (ciphertext_n_bits - delta_log.0 - bit_idx - 1)), - ); - - // Key switch to input PBS key - keyswitch_lwe_ciphertext( - &ksk, - &LweCiphertext::from_container( - &*lwe_bit_left_shift_buffer_data, - lwe_in.ciphertext_modulus(), - ), - &mut lwe_out_ks_buffer, - ); + // Block to keep the lwe_bit_left_shift_buffer_data alive only as long as needed + { + // Shift on padding bit + let (lwe_bit_left_shift_buffer_data, _) = stack.rb_mut().collect_aligned( + align, + lwe_in_buffer + .as_ref() + .iter() + .map(|s| *s << (ciphertext_n_bits - delta_log.0 - bit_idx - 1)), + ); - drop(lwe_bit_left_shift_buffer_data); + // Key switch to input PBS key + keyswitch_lwe_ciphertext( + &ksk, + &LweCiphertext::from_container( + lwe_bit_left_shift_buffer_data, + lwe_in.ciphertext_modulus(), + ), + &mut lwe_out_ks_buffer, + ); + } // Store the keyswitch output unmodified to the output list (as we need to to do other // computations on the output of the keyswitch) @@ -306,7 +306,7 @@ pub fn circuit_bootstrap_boolean>( ); // Output for every bootstrapping - let (mut lwe_out_bs_buffer_data, mut stack) = stack.make_aligned_with( + let (lwe_out_bs_buffer_data, mut stack) = stack.make_aligned_with( fourier_bsk_output_lwe_dimension.to_lwe_size().0, CACHELINE_ALIGN, |_| Scalar::ZERO, @@ -384,7 +384,7 @@ pub fn homomorphic_shift_boolean>( let polynomial_size = fourier_bsk.polynomial_size(); let ciphertext_moudulus = lwe_out.ciphertext_modulus(); - let (mut lwe_left_shift_buffer_data, stack) = + let (lwe_left_shift_buffer_data, stack) = stack.make_aligned_with(lwe_in_size.0, CACHELINE_ALIGN, |_| Scalar::ZERO); let mut lwe_left_shift_buffer = LweCiphertext::from_container( &mut *lwe_left_shift_buffer_data, @@ -403,7 +403,7 @@ pub fn homomorphic_shift_boolean>( *shift_buffer_body.data = (*shift_buffer_body.data).wrapping_add(Scalar::ONE << (ciphertext_n_bits - 2)); - let (mut pbs_accumulator_data, stack) = stack.make_aligned_with( + let (pbs_accumulator_data, stack) = stack.make_aligned_with( polynomial_size.0 * fourier_bsk.glwe_size().0, CACHELINE_ALIGN, |_| Scalar::ZERO, @@ -486,31 +486,31 @@ pub fn cmux_tree_memory_optimized>( // At index 0 you have the lut that will be loaded, and then the result for each layer gets // computed at the next index, last layer result gets stored in `result`. // This allow to use memory space in C * nb_layer instead of C' * 2 ^ nb_layer - let (mut t_0_data, stack) = stack.make_aligned_with( + let (t_0_data, stack) = stack.make_aligned_with( polynomial_size.0 * glwe_size.0 * nb_layer, CACHELINE_ALIGN, |_| Scalar::ZERO, ); - let (mut t_1_data, stack) = stack.make_aligned_with( + let (t_1_data, stack) = stack.make_aligned_with( polynomial_size.0 * glwe_size.0 * nb_layer, CACHELINE_ALIGN, |_| Scalar::ZERO, ); let mut t_0 = GlweCiphertextList::from_container( - t_0_data.as_mut(), + t_0_data, glwe_size, polynomial_size, ciphertext_modulus, ); let mut t_1 = GlweCiphertextList::from_container( - t_1_data.as_mut(), + t_1_data, glwe_size, polynomial_size, ciphertext_modulus, ); - let (mut t_fill, mut stack) = stack.make_with(nb_layer, |_| 0_usize); + let (t_fill, mut stack) = stack.make_with(nb_layer, |_| 0_usize); let mut lut_polynomial_iter = lut_per_layer.iter(); loop { @@ -565,8 +565,6 @@ pub fn cmux_tree_memory_optimized>( t_fill[j + 1] += 1; t_fill[j] = 0; - drop(diff_data); - (j_counter, t0_j, t1_j) = (j_counter_plus_1, t_0_j_plus_1, t_1_j_plus_1); } else { assert_eq!(j, nb_layer - 1); @@ -680,7 +678,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing>( // the last blind rotation. let (cmux_ggsw, br_ggsw) = ggsw_list.split_at(log_number_of_luts_for_cmux_tree); - let (mut cmux_tree_lut_res_data, mut stack) = + let (cmux_tree_lut_res_data, mut stack) = stack.make_aligned_with(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN, |_| { Scalar::ZERO }); - let mut cmux_tree_lut_res = GlweCiphertext::from_container( - &mut *cmux_tree_lut_res_data, - polynomial_size, - ciphertext_modulus, - ); + let mut cmux_tree_lut_res = + GlweCiphertext::from_container(cmux_tree_lut_res_data, polynomial_size, ciphertext_modulus); cmux_tree_memory_optimized( cmux_tree_lut_res.as_mut_view(), @@ -866,7 +861,7 @@ pub fn blind_rotate_assign>( for ggsw in ggsw_list.into_ggsw_iter().rev() { let ct_0 = lut.as_mut_view(); - let (mut ct1_data, stack) = stack + let (ct1_data, stack) = stack .rb_mut() .collect_aligned(CACHELINE_ALIGN, ct_0.as_ref().iter().copied()); let mut ct_1 = GlweCiphertext::from_container( diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs index 8e3f7cd8fa..efc5212792 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs @@ -2,7 +2,7 @@ use crate::core_crypto::commons::math::decomposition::decompose_one_level; pub use crate::core_crypto::commons::math::decomposition::DecompositionLevel; use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; -use dyn_stack::{DynArray, PodStack}; +use dyn_stack::PodStack; use std::iter::Map; use std::slice::IterMut; @@ -18,7 +18,7 @@ pub struct TensorSignedDecompositionLendingIter<'buffers, Scalar: UnsignedIntege // ...0001111 mod_b_mask: Scalar, // The internal states of each decomposition - states: DynArray<'buffers, Scalar>, + states: &'buffers mut [Scalar], // A flag which stores whether the iterator is a fresh one (for the recompose method). fresh: bool, } diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs index 17d070a041..47738e1cba 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs @@ -532,12 +532,12 @@ impl<'a> FftView<'a> { let standard = standard.as_mut(); let n = standard.len(); debug_assert_eq!(n, 2 * fourier.len()); - let (mut tmp, stack) = + let (tmp, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier.iter().copied()); - self.plan.inv(&mut tmp, stack); + self.plan.inv(tmp, stack); let (standard_re, standard_im) = standard.split_at_mut(n / 2); - conv_fn(standard_re, standard_im, &tmp, self.twisties); + conv_fn(standard_re, standard_im, tmp, self.twisties); } fn backward_with_conv_in_place<