diff --git a/Cargo.toml b/Cargo.toml index 4fbc29fb5d..774ce1123d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,14 @@ exclude = [ "utils/cargo-tfhe-lints" ] [workspace.dependencies] -aligned-vec = { version = "0.5", default-features = false } +aligned-vec = { version = "0.6", default-features = false } bytemuck = "1.14.3" -dyn-stack = { version = "0.10", default-features = false } +dyn-stack = { version = "0.11", default-features = false } +itertools = "0.13" num-complex = "0.4" -pulp = { version = "0.18.22", default-features = false } +pulp = { version = "0.19.6", default-features = false } +rand = "0.8" +rayon = "1" serde = { version = "1.0", default-features = false } wasm-bindgen = ">=0.2.86,<0.2.94" diff --git a/apps/trivium/Cargo.toml b/apps/trivium/Cargo.toml index f309a54b5b..38aa519b85 100644 --- a/apps/trivium/Cargo.toml +++ b/apps/trivium/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -rayon = { version = "1.7.0"} +rayon = { workspace = true } [target.'cfg(target_arch = "x86_64")'.dependencies.tfhe] path = "../../tfhe" diff --git a/tfhe-csprng/Cargo.toml b/tfhe-csprng/Cargo.toml index f7b2db91e1..7f0fec4388 100644 --- a/tfhe-csprng/Cargo.toml +++ b/tfhe-csprng/Cargo.toml @@ -13,13 +13,13 @@ rust-version = "1.72" [dependencies] aes = "0.8.2" -rayon = { version = "1.5.0", optional = true } +rayon = { workspace = true , optional = true } [target.'cfg(target_os = "macos")'.dependencies] libc = "0.2.133" [dev-dependencies] -rand = "0.8.3" +rand = { workspace = true } criterion = "0.5.1" clap = "=4.4.4" diff --git a/tfhe-fft/Cargo.toml b/tfhe-fft/Cargo.toml index 8e34ff174a..a94331d3f1 100644 --- a/tfhe-fft/Cargo.toml +++ b/tfhe-fft/Cargo.toml @@ -29,7 +29,7 @@ serde = ["dep:serde", "num-complex/serde"] [dev-dependencies] rustfft = "6.0" -rand = "0.8" +rand = { workspace = true } bincode = "1.3" more-asserts = "0.3.1" serde_json = "1.0.96" diff --git a/tfhe-fft/README.md b/tfhe-fft/README.md index 50a0580d63..04388660aa 100644 --- a/tfhe-fft/README.md +++ b/tfhe-fft/README.md @@ -40,7 +40,7 @@ Additionally, an optional 128-bit negacyclic FFT module is provided. ```rust use tfhe_fft::c64; use tfhe_fft::ordered::{Method, Plan}; -use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; +use dyn_stack::{GlobalPodBuffer, PodStack}; use num_complex::ComplexFloat; use std::time::Duration; @@ -48,7 +48,7 @@ fn main() { const N: usize = 4; let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); let mut scratch_memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap()); - let mut stack = PodStack::new(&mut scratch_memory); + let stack = PodStack::new(&mut scratch_memory); let data = [ c64::new(1.0, 0.0), @@ -58,10 +58,10 @@ fn main() { ]; let mut transformed_fwd = data; - plan.fwd(&mut transformed_fwd, stack.rb_mut()); + plan.fwd(&mut transformed_fwd, stack); let mut transformed_inv = transformed_fwd; - plan.inv(&mut transformed_inv, stack.rb_mut()); + plan.inv(&mut transformed_inv, stack); for (actual, expected) in transformed_inv.iter().map(|z| z / N as f64).zip(data) { assert!((expected - actual).abs() < 1e-9); diff --git a/tfhe-fft/benches/fft.rs b/tfhe-fft/benches/fft.rs index dc573112c5..9f61565946 100644 --- a/tfhe-fft/benches/fft.rs +++ b/tfhe-fft/benches/fft.rs @@ -1,6 +1,6 @@ use core::ptr::NonNull; use criterion::{criterion_group, criterion_main, Criterion}; -use dyn_stack::{PodStack, ReborrowMut, StackReq}; +use dyn_stack::{PodStack, StackReq}; use serde::Serialize; use std::{fs, path::PathBuf}; use tfhe_fft::c64; @@ -129,7 +129,7 @@ pub fn bench_ffts(c: &mut Criterion) { StackReq::new_aligned::(n, 256), // src StackReq::new_aligned::(n, 256), // dst ])); - let mut stack = PodStack::new(&mut mem); + let stack = PodStack::new(&mut mem); let z = c64::new(0.0, 0.0); use rustfft::FftPlannerAvx; @@ -139,8 +139,8 @@ pub fn bench_ffts(c: &mut Criterion) { let unordered = tfhe_fft::unordered::Plan::new(n, tfhe_fft::unordered::Method::Measure(bench_duration)); - let (dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - let (src, mut stack) = stack.make_aligned_with::(n, 64, |_| z); + let (dst, stack) = stack.make_aligned_with::(n, 64, |_| z); + let (src, stack) = stack.make_aligned_with::(n, 64, |_| z); let bench_id = format!("rustfft-fwd-{n}"); c.bench_function(&bench_id, |b| { @@ -164,19 +164,19 @@ pub fn bench_ffts(c: &mut Criterion) { tfhe_fft::ordered::Plan::new(n, tfhe_fft::ordered::Method::Measure(bench_duration)); let bench_id = format!("tfhe-ordered-fwd-{n}"); - c.bench_function(&bench_id, |b| b.iter(|| ordered.fwd(dst, stack.rb_mut()))); + c.bench_function(&bench_id, |b| b.iter(|| ordered.fwd(dst, stack))); write_to_json(&bench_id, "tfhe-ordered-fwd", n); } let bench_id = format!("tfhe-unordered-fwd-{n}"); c.bench_function(&bench_id, |b| { - b.iter(|| unordered.fwd(dst, stack.rb_mut())); + b.iter(|| unordered.fwd(dst, stack)); }); write_to_json(&bench_id, "tfhe-unordered-fwd", n); let bench_id = format!("tfhe-unordered-inv-{n}"); c.bench_function(&bench_id, |b| { - b.iter(|| unordered.inv(dst, stack.rb_mut())); + b.iter(|| unordered.inv(dst, stack)); }); write_to_json(&bench_id, "tfhe-unordered-inv", n); diff --git a/tfhe-fft/src/fft128/f128_ops.rs b/tfhe-fft/src/fft128/f128_ops.rs index 84c35dd96b..75e78d90db 100644 --- a/tfhe-fft/src/fft128/f128_ops.rs +++ b/tfhe-fft/src/fft128/f128_ops.rs @@ -645,7 +645,7 @@ pub mod x86 { #[inline(always)] pub(crate) fn two_diff_f64x4(simd: V3, a: f64x4, b: f64x4) -> (f64x4, f64x4) { - two_sum_f64x4(simd, a, simd.f64s_neg(b)) + two_sum_f64x4(simd, a, simd.neg_f64s(b)) } #[inline(always)] @@ -677,7 +677,7 @@ pub mod x86 { #[inline(always)] #[cfg(feature = "nightly")] pub(crate) fn two_diff_f64x8(simd: V4, a: f64x8, b: f64x8) -> (f64x8, f64x8) { - two_sum_f64x8(simd, a, simd.f64s_neg(b)) + two_sum_f64x8(simd, a, simd.neg_f64s(b)) } #[cfg(feature = "nightly")] @@ -714,8 +714,8 @@ pub mod x86 { simd, a, f64x16 { - lo: simd.f64s_neg(b.lo), - hi: simd.f64s_neg(b.hi), + lo: simd.neg_f64s(b.lo), + hi: simd.neg_f64s(b.hi), }, ) } diff --git a/tfhe-fft/src/lib.rs b/tfhe-fft/src/lib.rs index 253b96195d..337383b0e9 100644 --- a/tfhe-fft/src/lib.rs +++ b/tfhe-fft/src/lib.rs @@ -36,14 +36,14 @@ #![cfg_attr(not(feature = "std"), doc = "```ignore")] //! use tfhe_fft::c64; //! use tfhe_fft::ordered::{Plan, Method}; -//! use dyn_stack::{PodStack, GlobalPodBuffer, ReborrowMut}; +//! use dyn_stack::{PodStack, GlobalPodBuffer}; //! use num_complex::ComplexFloat; //! use std::time::Duration; //! //! const N: usize = 4; //! let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); //! let mut scratch_memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap()); -//! let mut stack = PodStack::new(&mut scratch_memory); +//! let stack = PodStack::new(&mut scratch_memory); //! //! let data = [ //! c64::new(1.0, 0.0), @@ -53,10 +53,10 @@ //! ]; //! //! let mut transformed_fwd = data; -//! plan.fwd(&mut transformed_fwd, stack.rb_mut()); +//! plan.fwd(&mut transformed_fwd, stack); //! //! let mut transformed_inv = transformed_fwd; -//! plan.inv(&mut transformed_inv, stack.rb_mut()); +//! plan.inv(&mut transformed_inv, stack); //! //! for (actual, expected) in transformed_inv.iter().map(|z| z / N as f64).zip(data) { //! assert!((expected - actual).abs() < 1e-9); diff --git a/tfhe-fft/src/ordered.rs b/tfhe-fft/src/ordered.rs index ead6d4131d..4f66238364 100644 --- a/tfhe-fft/src/ordered.rs +++ b/tfhe-fft/src/ordered.rs @@ -16,7 +16,7 @@ use aligned_vec::{avec, ABox, CACHELINE_ALIGN}; #[cfg(feature = "std")] use core::time::Duration; #[cfg(feature = "std")] -use dyn_stack::{GlobalPodBuffer, ReborrowMut}; +use dyn_stack::GlobalPodBuffer; use dyn_stack::{PodStack, SizeOverflow, StackReq}; /// Internal FFT algorithm. @@ -65,7 +65,7 @@ fn measure_n_runs( buf: &mut [c64], twiddles_init: &[c64], twiddles: &[c64], - stack: PodStack, + stack: &mut PodStack, ) -> Duration { let n = buf.len(); let (scratch, _) = stack.make_aligned_raw::(n, CACHELINE_ALIGN); @@ -99,7 +99,7 @@ pub(crate) fn measure_fastest_scratch(n: usize) -> StackReq { pub(crate) fn measure_fastest( min_bench_duration_per_algo: Duration, n: usize, - stack: PodStack, + stack: &mut PodStack, ) -> (FftAlgo, Duration) { const N_ALGOS: usize = 8; const MIN_DURATION: Duration = if cfg!(target_arch = "wasm32") { @@ -116,14 +116,14 @@ pub(crate) fn measure_fastest( let f = |_| c64 { re: 0.0, im: 0.0 }; - let (twiddles, stack) = stack.make_aligned_with::(2 * n, align, f); + let (twiddles, stack) = stack.make_aligned_with::(2 * n, align, f); let twiddles_init = &twiddles[..n]; let twiddles = &twiddles[n..]; - let (buf, mut stack) = stack.make_aligned_with::(n, align, f); + let (buf, stack) = stack.make_aligned_with::(n, align, f); { // initialize scratch to load it in the cpu cache - drop(stack.rb_mut().make_aligned_with::(n, align, f)); + drop(stack.make_aligned_with::(n, align, f)); } let mut avg_durations = [Duration::ZERO; N_ALGOS]; @@ -149,8 +149,7 @@ pub(crate) fn measure_fastest( let mut n_runs: u128 = 1; loop { - let duration = - measure_n_runs(n_runs, algo, buf, twiddles_init, twiddles, stack.rb_mut()); + let duration = measure_n_runs(n_runs, algo, buf, twiddles_init, twiddles, stack); if duration < MIN_DURATION { n_runs *= 2; @@ -165,8 +164,7 @@ pub(crate) fn measure_fastest( *avg = if n_runs <= init_n_runs { approx_duration } else { - let duration = - measure_n_runs(n_runs, algo, buf, twiddles_init, twiddles, stack.rb_mut()); + let duration = measure_n_runs(n_runs, algo, buf, twiddles_init, twiddles, stack); duration_div_f64(duration, n_runs as f64) }; } @@ -339,7 +337,7 @@ impl Plan { /// let mut buf = [c64::default(); 4]; /// plan.fwd(&mut buf, stack); /// ``` - pub fn fwd(&self, buf: &mut [c64], stack: PodStack) { + pub fn fwd(&self, buf: &mut [c64], stack: &mut PodStack) { let n = self.fft_size(); let (scratch, _) = stack.make_aligned_raw::(n, CACHELINE_ALIGN); let (w_init, w) = split_2(&self.twiddles); @@ -353,19 +351,19 @@ impl Plan { #[cfg_attr(not(feature = "std"), doc = " ```ignore")] /// use tfhe_fft::c64; /// use tfhe_fft::ordered::{Method, Plan}; - /// use dyn_stack::{PodStack, GlobalPodBuffer, ReborrowMut}; + /// use dyn_stack::{PodStack, GlobalPodBuffer}; /// use core::time::Duration; /// /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); /// /// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap()); - /// let mut stack = PodStack::new(&mut memory); + /// let stack = PodStack::new(&mut memory); /// /// let mut buf = [c64::default(); 4]; - /// plan.fwd(&mut buf, stack.rb_mut()); + /// plan.fwd(&mut buf, stack); /// plan.inv(&mut buf, stack); /// ``` - pub fn inv(&self, buf: &mut [c64], stack: PodStack) { + pub fn inv(&self, buf: &mut [c64], stack: &mut PodStack) { let n = self.fft_size(); let (scratch, _) = stack.make_aligned_raw::(n, CACHELINE_ALIGN); let (w_init, w) = split_2(&self.twiddles_inv); diff --git a/tfhe-fft/src/unordered.rs b/tfhe-fft/src/unordered.rs index 27c0b908db..c9336fb11d 100644 --- a/tfhe-fft/src/unordered.rs +++ b/tfhe-fft/src/unordered.rs @@ -18,7 +18,7 @@ use aligned_vec::{avec, ABox, CACHELINE_ALIGN}; #[cfg(feature = "std")] use core::time::Duration; #[cfg(feature = "std")] -use dyn_stack::{GlobalPodBuffer, ReborrowMut}; +use dyn_stack::GlobalPodBuffer; use dyn_stack::{PodStack, SizeOverflow, StackReq}; #[inline(always)] @@ -553,7 +553,7 @@ fn measure_fastest_scratch(n: usize) -> StackReq { fn measure_fastest( mut min_bench_duration_per_algo: Duration, n: usize, - mut stack: PodStack, + stack: &mut PodStack, ) -> (FftAlgo, usize, Duration) { const MIN_DURATION: Duration = Duration::from_millis(1); min_bench_duration_per_algo = min_bench_duration_per_algo.max(MIN_DURATION); @@ -581,11 +581,8 @@ fn measure_fastest( n_algos += 1; // we'll measure the corresponding plan - let (base_algo, duration) = crate::ordered::measure_fastest( - min_bench_duration_per_algo, - base_n, - stack.rb_mut(), - ); + let (base_algo, duration) = + crate::ordered::measure_fastest(min_bench_duration_per_algo, base_n, stack); algos[i] = Some(base_algo); @@ -599,11 +596,9 @@ fn measure_fastest( let f = |_| c64 { re: 0.0, im: 0.0 }; let align = CACHELINE_ALIGN; - let (w, stack) = stack - .rb_mut() - .make_aligned_with::(n + base_n, align, f); - let (scratch, stack) = stack.make_aligned_with::(base_n, align, f); - let (z, _) = stack.make_aligned_with::(n, align, f); + let (w, stack) = stack.make_aligned_with::(n + base_n, align, f); + let (scratch, stack) = stack.make_aligned_with::(base_n, align, f); + let (z, _) = stack.make_aligned_with::(n, align, f); let n_runs = min_bench_duration_per_algo.as_secs_f64() / (duration.as_secs_f64() * (n / base_n) as f64); @@ -823,7 +818,7 @@ impl Plan { /// let mut buf = [c64::default(); 4]; /// plan.fwd(&mut buf, stack); /// ``` - pub fn fwd(&self, buf: &mut [c64], stack: PodStack) { + pub fn fwd(&self, buf: &mut [c64], stack: &mut PodStack) { assert_eq!(self.fft_size(), buf.len()); let (scratch, _) = stack.make_aligned_raw::(self.algo().1, CACHELINE_ALIGN); fwd_depth( @@ -912,19 +907,19 @@ impl Plan { #[cfg_attr(not(feature = "std"), doc = " ```ignore")] /// use tfhe_fft::c64; /// use tfhe_fft::unordered::{Method, Plan}; - /// use dyn_stack::{PodStack, GlobalPodBuffer, ReborrowMut}; + /// use dyn_stack::{PodStack, GlobalPodBuffer}; /// use core::time::Duration; /// /// let plan = Plan::new(4, Method::Measure(Duration::from_millis(10))); /// /// let mut memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap()); - /// let mut stack = PodStack::new(&mut memory); + /// let stack = PodStack::new(&mut memory); /// /// let mut buf = [c64::default(); 4]; - /// plan.fwd(&mut buf, stack.rb_mut()); + /// plan.fwd(&mut buf, stack); /// plan.inv(&mut buf, stack); /// ``` - pub fn inv(&self, buf: &mut [c64], stack: PodStack) { + pub fn inv(&self, buf: &mut [c64], stack: &mut PodStack) { assert_eq!(self.fft_size(), buf.len()); let (scratch, _) = stack.make_aligned_raw::(self.algo().1, CACHELINE_ALIGN); inv_depth( @@ -1062,7 +1057,7 @@ fn bit_rev_twice_inv(nbits: u32, base_nbits: u32, i: usize) -> usize { mod tests { use super::*; use alloc::vec; - use dyn_stack::{GlobalPodBuffer, ReborrowMut}; + use dyn_stack::GlobalPodBuffer; use num_complex::ComplexFloat; use rand::random; @@ -1157,8 +1152,8 @@ mod tests { }, ); let mut mem = GlobalPodBuffer::new(plan.fft_scratch().unwrap()); - let mut stack = PodStack::new(&mut mem); - plan.fwd(&mut z, stack.rb_mut()); + let stack = PodStack::new(&mut mem); + plan.fwd(&mut z, stack); plan.inv(&mut z, stack); for z in &mut z { @@ -9400,7 +9395,7 @@ mod tests { mod tests_serde { use super::*; use alloc::{vec, vec::Vec}; - use dyn_stack::{GlobalPodBuffer, ReborrowMut}; + use dyn_stack::GlobalPodBuffer; use num_complex::ComplexFloat; use rand::random; @@ -9440,9 +9435,9 @@ mod tests_serde { .unwrap() .or(plan2.fft_scratch().unwrap()), ); - let mut stack = PodStack::new(&mut mem); + let stack = PodStack::new(&mut mem); - plan1.fwd(&mut z, stack.rb_mut()); + plan1.fwd(&mut z, stack); let mut buf = Vec::::new(); let mut serializer = bincode::Serializer::new(&mut buf, bincode::options()); diff --git a/tfhe-ntt/Cargo.toml b/tfhe-ntt/Cargo.toml index b326cb0e3d..ba30f562e9 100644 --- a/tfhe-ntt/Cargo.toml +++ b/tfhe-ntt/Cargo.toml @@ -23,7 +23,7 @@ nightly = ["pulp/nightly"] [dev-dependencies] criterion = "0.4" -rand = "0.8" +rand = { workspace = true } serde = "1.0.163" serde_json = "1.0.96" diff --git a/tfhe-zk-pok/Cargo.toml b/tfhe-zk-pok/Cargo.toml index 56c792ca17..dc50fd9564 100644 --- a/tfhe-zk-pok/Cargo.toml +++ b/tfhe-zk-pok/Cargo.toml @@ -16,8 +16,8 @@ ark-bls12-381 = "0.5.0" ark-ec = { version = "0.5.0", features = ["parallel"] } ark-ff = { version = "0.5.0", features = ["parallel"] } ark-poly = { version = "0.5.0", features = ["parallel"] } -rand = "0.8.5" -rayon = "1.8.0" +rand = { workspace = true } +rayon = { workspace = true } sha3 = "0.10.8" serde = { workspace = true, features = ["default", "derive"] } zeroize = "1.7.0" @@ -26,7 +26,7 @@ tfhe-versionable = { version = "0.3.2", path = "../utils/tfhe-versionable" } [dev-dependencies] serde_json = "~1.0" -itertools = "0.11.0" +itertools = { workspace = true } bincode = "1.3.3" criterion = "0.5.1" diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index 64b60602fc..7324ad5249 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -17,12 +17,12 @@ exclude = [ "/js_on_wasm_tests/", "/web_wasm_parallel_tests/", ] -rust-version = "1.81" +rust-version = "1.82" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dev-dependencies] -rand = "0.8.5" +rand = { workspace = true } rand_distr = "0.4.3" lazy_static = { version = "1.4.0" } criterion = "0.5.1" @@ -33,7 +33,6 @@ serde_json = "1.0.94" clap = { version = "=4.4.4", features = ["derive"] } # Used in user documentation fs2 = { version = "0.4.3" } -itertools = "0.11.0" statrs = "0.16" # For erf and normality test libm = "0.2.6" @@ -60,7 +59,7 @@ tfhe-csprng = { version = "0.4.1", path = "../tfhe-csprng", features = [ ] } lazy_static = { version = "1.4.0", optional = true } serde = { workspace = true, features = ["default", "derive"] } -rayon = { version = "1.5.0" } +rayon = { workspace = true } bincode = "1.3.3" tfhe-fft = { version = "0.6.0", path = "../tfhe-fft", features = [ "serde", @@ -75,8 +74,7 @@ paste = "1.0.7" fs2 = { version = "0.4.3", optional = true } # Used for OPRF in shortint sha3 = { version = "0.10", optional = true } -# While we wait for repeat_n in rust standard library -itertools = "0.11.0" +itertools = { workspace = true } rand_core = { version = "0.6.4", features = ["std"] } tfhe-zk-pok = { version = "0.3.1", path = "../tfhe-zk-pok", optional = true } tfhe-versionable = { version = "0.3.2", path = "../utils/tfhe-versionable" } diff --git a/tfhe/src/core_crypto/algorithms/ggsw_conversion.rs b/tfhe/src/core_crypto/algorithms/ggsw_conversion.rs index 49d0b357c0..97638b39ad 100644 --- a/tfhe/src/core_crypto/algorithms/ggsw_conversion.rs +++ b/tfhe/src/core_crypto/algorithms/ggsw_conversion.rs @@ -49,7 +49,7 @@ pub fn convert_standard_ggsw_ciphertext_to_fourier_mem_optimized, output_ggsw: &mut FourierGgswCiphertext, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, InputCont: Container, diff --git a/tfhe/src/core_crypto/algorithms/lwe_bootstrap_key_conversion.rs b/tfhe/src/core_crypto/algorithms/lwe_bootstrap_key_conversion.rs index 0767647c93..700a9a151d 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_bootstrap_key_conversion.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_bootstrap_key_conversion.rs @@ -46,7 +46,7 @@ pub fn convert_standard_lwe_bootstrap_key_to_fourier_mem_optimized, output_bsk: &mut FourierLweBootstrapKey, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, InputCont: Container, diff --git a/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_conversion.rs b/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_conversion.rs index 32d3076d9e..0cb809616b 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_conversion.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_conversion.rs @@ -8,7 +8,7 @@ use crate::core_crypto::entities::*; use crate::core_crypto::fft_impl::fft64::math::fft::{ par_convert_polynomials_list_to_fourier, Fft, FftView, }; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use tfhe_fft::c64; /// Convert an [`LWE multi_bit bootstrap key`](`LweMultiBitBootstrapKey`) with standard @@ -50,7 +50,7 @@ pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized< input_bsk: &LweMultiBitBootstrapKey, output_bsk: &mut FourierLweMultiBitBootstrapKey, fft: FftView<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, InputCont: Container, @@ -69,7 +69,7 @@ pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized< .zip(input_bsk_as_polynomial_list.iter()) { // SAFETY: forward_as_torus doesn't write any uninitialized values into its output - fft.forward_as_torus(fourier_poly, coef_poly, stack.rb_mut()); + fft.forward_as_torus(fourier_poly, coef_poly, stack); } } diff --git a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/fft128.rs b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/fft128.rs index 8fbf7ad590..f95fad5c4a 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/fft128.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/fft128.rs @@ -233,7 +233,7 @@ pub fn programmable_bootstrap_f128_lwe_ciphertext_mem_optimized< accumulator: &GlweCiphertext, fourier_bsk: &Fourier128LweBootstrapKey, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize Scalar: UnsignedTorus + CastInto, diff --git a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/fft64.rs b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/fft64.rs index 546cfb0ba9..9a79650ec2 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/fft64.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/fft64.rs @@ -233,7 +233,7 @@ pub fn blind_rotate_assign_mem_optimized< lut: &mut GlweCiphertext, fourier_bsk: &FourierLweBootstrapKey, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize InputScalar: UnsignedTorus + CastInto, @@ -455,7 +455,7 @@ pub fn add_external_product_assign_mem_optimized, glwe: &GlweCiphertext, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, OutputGlweCont: ContainerMut, @@ -746,7 +746,7 @@ pub fn cmux_assign_mem_optimized( ct1: &mut GlweCiphertext, ggsw: &FourierGgswCiphertext, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, Cont0: ContainerMut, @@ -1020,7 +1020,7 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized< accumulator: &GlweCiphertext, fourier_bsk: &FourierLweBootstrapKey, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize InputScalar: UnsignedTorus + CastInto, @@ -1091,7 +1091,7 @@ pub fn batch_programmable_bootstrap_lwe_ciphertext_mem_optimized< accumulator: &GlweCiphertextList, fourier_bsk: &FourierLweBootstrapKey, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize InputScalar: UnsignedTorus + CastInto, diff --git a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64.rs b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64.rs index ae41dd2a89..83dbc1d3ff 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64.rs @@ -18,7 +18,7 @@ use crate::core_crypto::commons::traits::*; use crate::core_crypto::commons::utils::izip; use crate::core_crypto::entities::*; use aligned_vec::CACHELINE_ALIGN; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; /// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up /// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap @@ -209,7 +209,7 @@ pub fn blind_rotate_ntt64_assign_mem_optimized( lut: &mut GlweCiphertext, bsk: &NttLweBootstrapKey, ntt: Ntt64View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where InputCont: Container, OutputCont: ContainerMut, @@ -220,7 +220,7 @@ pub fn blind_rotate_ntt64_assign_mem_optimized( mut lut: GlweCiphertextMutView<'_, u64>, lwe: &[u64], ntt: Ntt64View<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) { let (lwe_body, lwe_mask) = lwe.split_last().unwrap(); let modulus = ntt.custom_modulus(); @@ -248,7 +248,7 @@ pub fn blind_rotate_ntt64_assign_mem_optimized( for (lwe_mask_element, bootstrap_key_ggsw) in izip!(lwe_mask.iter(), bsk.into_ggsw_iter()) { if *lwe_mask_element != 0u64 { - let stack = stack.rb_mut(); + let stack = &mut *stack; // We copy ct_0 to ct_1 let (ct1, stack) = stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied()); @@ -479,7 +479,7 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized< accumulator: &GlweCiphertext, bsk: &NttLweBootstrapKey, ntt: Ntt64View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where InputCont: Container, OutputCont: ContainerMut, @@ -492,7 +492,7 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized< lwe_in: LweCiphertextView<'_, u64>, accumulator: GlweCiphertextView<'_, u64>, ntt: Ntt64View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert_eq!(lwe_out.ciphertext_modulus(), lwe_in.ciphertext_modulus()); debug_assert_eq!( @@ -544,7 +544,7 @@ pub(crate) fn add_external_product_ntt64_assign( ggsw: NttGgswCiphertextView<'_, u64>, glwe: &GlweCiphertext, ntt: Ntt64View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where InputGlweCont: Container, { @@ -565,7 +565,7 @@ pub(crate) fn add_external_product_ntt64_assign( out.ciphertext_modulus(), ); - let (output_fft_buffer, mut substack0) = + let (output_fft_buffer, substack0) = stack.make_aligned_raw::(poly_size * ggsw.glwe_size().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once @@ -576,18 +576,18 @@ pub(crate) fn add_external_product_ntt64_assign( // ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER DOMAIN // In this section, we perform the external product in the ntt domain, and accumulate // the result in the output_fft_buffer variable. - let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIterNonNative::new( + let (mut decomposition, substack1) = TensorSignedDecompositionLendingIterNonNative::new( &decomposer, glwe.as_ref(), ntt.custom_modulus(), - substack0.rb_mut(), + substack0, ); // We loop through the levels (we reverse to match the order of the decomposition iterator.) ggsw.into_levels().for_each(|ggsw_decomp_matrix| { // We retrieve the decomposition of this level. - let (glwe_level, glwe_decomp_term, mut substack2) = - decomposition.collect_next_term(&mut substack1, align); + let (glwe_level, glwe_decomp_term, substack2) = + decomposition.collect_next_term(substack1, align); let glwe_decomp_term = GlweCiphertextView::from_container( &*glwe_decomp_term, ggsw.polynomial_size(), @@ -612,7 +612,7 @@ pub(crate) fn add_external_product_ntt64_assign( glwe_decomp_term.as_polynomial_list().iter() ) .for_each(|(ggsw_row, glwe_poly)| { - let (ntt_poly, _) = substack2.rb_mut().make_aligned_raw::(poly_size, align); + let (ntt_poly, _) = substack2.make_aligned_raw::(poly_size, align); // We perform the forward ntt transform for the glwe polynomial ntt.forward(PolynomialMutView::from_container(ntt_poly), glwe_poly); // Now we loop through the polynomials of the output, and add the @@ -657,7 +657,7 @@ pub(crate) fn cmux_ntt64_assign( mut ct1: GlweCiphertextMutView<'_, u64>, ggsw: NttGgswCiphertextView<'_, u64>, ntt: Ntt64View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { izip!(ct1.as_mut(), ct0.as_ref(),).for_each(|(c1, c0)| { *c1 = c1.wrapping_sub_custom_mod(*c0, ntt.custom_modulus()); diff --git a/tfhe/src/core_crypto/algorithms/lwe_wopbs.rs b/tfhe/src/core_crypto/algorithms/lwe_wopbs.rs index ca5542c359..513baa2b74 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_wopbs.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_wopbs.rs @@ -327,7 +327,7 @@ pub fn extract_bits_from_lwe_ciphertext_mem_optimized< delta_log: DeltaLog, number_of_bits_to_extract: ExtractedBitsCount, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize Scalar: UnsignedTorus + CastInto, @@ -661,7 +661,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list_mem_optimi base_log_cbs: DecompositionBaseLog, level_cbs: DecompositionLevelCount, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize Scalar: UnsignedTorus + CastInto, diff --git a/tfhe/src/core_crypto/commons/computation_buffers.rs b/tfhe/src/core_crypto/commons/computation_buffers.rs index 587e33f54b..80dbd7996a 100644 --- a/tfhe/src/core_crypto/commons/computation_buffers.rs +++ b/tfhe/src/core_crypto/commons/computation_buffers.rs @@ -23,7 +23,7 @@ impl ComputationBuffers { /// Return a `PodStack` borrowoing from the managed memory buffer for use with optimized fft /// primitives or other functions using `PodStack` to manage temporary memory. - pub fn stack(&mut self) -> PodStack<'_> { + pub fn stack(&mut self) -> &mut PodStack { PodStack::new(&mut self.memory) } } diff --git a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs index 865d51e794..b5f7fffd0f 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs @@ -4,7 +4,7 @@ use crate::core_crypto::commons::math::decomposition::{ }; use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; -use dyn_stack::{PodStack, ReborrowMut}; +use dyn_stack::PodStack; /// An iterator that yields the terms of the signed decomposition of an integer. /// @@ -318,8 +318,8 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> { decomposer: &SignedDecomposerNonNative, input: &[u64], modulus: u64, - stack: PodStack<'buffers>, - ) -> (Self, PodStack<'buffers>) { + stack: &'buffers mut PodStack, + ) -> (Self, &'buffers mut PodStack) { let shift = modulus.ceil_ilog2() as usize - decomposer.base_log * decomposer.level_count; let input_size = input.len(); let (states, stack) = @@ -409,10 +409,9 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> { &mut self, substack1: &'a mut PodStack, align: usize, - ) -> (DecompositionLevel, &'a mut [u64], PodStack<'a>) { + ) -> (DecompositionLevel, &'a mut [u64], &'a mut PodStack) { let (glwe_level, _, glwe_decomp_term) = self.next_term().unwrap(); - let (glwe_decomp_term, substack2) = - substack1.rb_mut().collect_aligned(align, glwe_decomp_term); + let (glwe_decomp_term, substack2) = substack1.collect_aligned(align, glwe_decomp_term); (glwe_level, glwe_decomp_term, substack2) } } diff --git a/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs b/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs index 7e27a694d9..89e83d9c36 100644 --- a/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs +++ b/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs @@ -15,7 +15,7 @@ type WrappingFunction<'data, Element, WrappingType> = fn( type ChunksWrappingLendingIterator<'data, Element, WrappingType> = std::iter::Map< std::iter::Zip< std::slice::Chunks<'data, Element>, - itertools::RepeatN<>::Metadata>, + core::iter::RepeatN<>::Metadata>, >, WrappingFunction<'data, Element, WrappingType>, >; @@ -23,7 +23,7 @@ type ChunksWrappingLendingIterator<'data, Element, WrappingType> = std::iter::Ma type ChunksExactWrappingLendingIterator<'data, Element, WrappingType> = std::iter::Map< std::iter::Zip< std::slice::ChunksExact<'data, Element>, - itertools::RepeatN<>::Metadata>, + core::iter::RepeatN<>::Metadata>, >, WrappingFunction<'data, Element, WrappingType>, >; @@ -54,7 +54,7 @@ type WrappingFunctionMut<'data, Element, WrappingType> = fn( type ChunksWrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter::Map< std::iter::Zip< std::slice::ChunksMut<'data, Element>, - itertools::RepeatN<>::Metadata>, + core::iter::RepeatN<>::Metadata>, >, WrappingFunctionMut<'data, Element, WrappingType>, >; @@ -62,7 +62,7 @@ type ChunksWrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter: type ChunksExactWrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter::Map< std::iter::Zip< std::slice::ChunksExactMut<'data, Element>, - itertools::RepeatN<>::Metadata>, + core::iter::RepeatN<>::Metadata>, >, WrappingFunctionMut<'data, Element, WrappingType>, >; @@ -130,7 +130,7 @@ pub trait ContiguousEntityContainer: AsRef<[Self::Element]> { let entity_view_pod_size = self.get_entity_view_pod_size(); self.as_ref() .chunks_exact(entity_view_pod_size) - .zip(itertools::repeat_n(meta, entity_count)) + .zip(core::iter::repeat_n(meta, entity_count)) .map(|(elt, meta)| Self::EntityView::<'_>::create_from(elt, meta)) } @@ -219,7 +219,7 @@ pub trait ContiguousEntityContainer: AsRef<[Self::Element]> { let meta = self.get_self_view_creation_metadata(); self.as_ref() .chunks(pod_chunk_size) - .zip(itertools::repeat_n(meta, entity_count)) + .zip(core::iter::repeat_n(meta, entity_count)) .map(|(elt, meta)| Self::SelfView::<'_>::create_from(elt, meta)) } @@ -240,7 +240,7 @@ pub trait ContiguousEntityContainer: AsRef<[Self::Element]> { let meta = self.get_self_view_creation_metadata(); self.as_ref() .chunks_exact(pod_chunk_size) - .zip(itertools::repeat_n(meta, entity_count)) + .zip(core::iter::repeat_n(meta, entity_count)) .map(|(elt, meta)| Self::SelfView::<'_>::create_from(elt, meta)) } @@ -341,7 +341,7 @@ pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self: let entity_view_pod_size = self.get_entity_view_pod_size(); self.as_mut() .chunks_exact_mut(entity_view_pod_size) - .zip(itertools::repeat_n(meta, entity_count)) + .zip(core::iter::repeat_n(meta, entity_count)) .map(|(elt, meta)| Self::EntityMutView::<'_>::create_from(elt, meta)) } @@ -417,7 +417,7 @@ pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self: let meta = self.get_self_view_creation_metadata(); self.as_mut() .chunks_mut(pod_chunk_size) - .zip(itertools::repeat_n(meta, entity_count)) + .zip(core::iter::repeat_n(meta, entity_count)) .map(|(elt, meta)| Self::SelfMutView::<'_>::create_from(elt, meta)) } @@ -439,7 +439,7 @@ pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self: let meta = self.get_self_view_creation_metadata(); self.as_mut() .chunks_exact_mut(pod_chunk_size) - .zip(itertools::repeat_n(meta, entity_count)) + .zip(core::iter::repeat_n(meta, entity_count)) .map(|(elt, meta)| Self::SelfMutView::<'_>::create_from(elt, meta)) } diff --git a/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs b/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs index 2bac23057f..49d3006676 100644 --- a/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs +++ b/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs @@ -15,7 +15,7 @@ use crate::core_crypto::fft_impl::fft64::math::polynomial::{ FourierPolynomialMutView, FourierPolynomialView, }; use aligned_vec::CACHELINE_ALIGN; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use tfhe_fft::c64; /// The caller must provide a properly configured [`FftView`] object and a `PodStack` used as a @@ -156,7 +156,7 @@ pub fn glwe_fast_keyswitch( pseudo_ggsw: &PseudoFourierGgswCiphertext, glwe: &GlweCiphertext, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, OutputGlweCont: ContainerMut, @@ -174,7 +174,7 @@ pub fn glwe_fast_keyswitch( ggsw: PseudoFourierGgswCiphertextView<'_>, glwe: &GlweCiphertext, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, InputGlweCont: Container, @@ -193,7 +193,7 @@ pub fn glwe_fast_keyswitch( ggsw.decomposition_base_log(), ggsw.decomposition_level_count(), ); - let (output_fft_buffer, mut substack0) = + let (output_fft_buffer, substack0) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size_out().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once @@ -204,21 +204,21 @@ pub fn glwe_fast_keyswitch( // ------------ EXTERNAL PRODUCT IN FOURIER DOMAIN // In this section, we perform the external product in the fourier // domain, and accumulate the result in the output_fft_buffer variable. - let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( + let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), - substack0.rb_mut(), + substack0, ); // We loop through the levels (we reverse to match the order of the decomposition // iterator.) ggsw.into_levels().for_each(|ggsw_decomp_matrix| { // We retrieve the decomposition of this level. - let (glwe_level, glwe_decomp_term, mut substack2) = - collect_next_term(&mut decomposition, &mut substack1, align); + let (glwe_level, glwe_decomp_term, substack2) = + collect_next_term(&mut decomposition, substack1, align); let glwe_decomp_term = GlweCiphertextView::from_container( &*glwe_decomp_term, ggsw.polynomial_size(), @@ -243,9 +243,8 @@ pub fn glwe_fast_keyswitch( glwe_decomp_term.get_mask().as_polynomial_list().iter() ) .for_each(|(ggsw_row, glwe_poly)| { - let (fourier, substack3) = substack2 - .rb_mut() - .make_aligned_raw::(fourier_poly_size, align); + let (fourier, substack3) = + substack2.make_aligned_raw::(fourier_poly_size, align); // We perform the forward fft transform for the glwe polynomial let fourier = fft @@ -285,7 +284,7 @@ pub fn glwe_fast_keyswitch( .map(|slice| FourierPolynomialView { data: slice }), ) .for_each(|(out, fourier)| { - fft.add_backward_as_torus(out, fourier, substack0.rb_mut()); + fft.add_backward_as_torus(out, fourier, substack0); }); } diff --git a/tfhe/src/core_crypto/experimental/algorithms/pseudo_ggsw_conversion.rs b/tfhe/src/core_crypto/experimental/algorithms/pseudo_ggsw_conversion.rs index 05d3547c34..f0dc6f5ba4 100644 --- a/tfhe/src/core_crypto/experimental/algorithms/pseudo_ggsw_conversion.rs +++ b/tfhe/src/core_crypto/experimental/algorithms/pseudo_ggsw_conversion.rs @@ -52,7 +52,7 @@ pub fn convert_standard_pseudo_ggsw_ciphertext_to_fourier_mem_optimized< input_ggsw: &PseudoGgswCiphertext, output_ggsw: &mut PseudoFourierGgswCiphertext, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, InputCont: Container, diff --git a/tfhe/src/core_crypto/experimental/entities/fourier_pseudo_ggsw_ciphertext.rs b/tfhe/src/core_crypto/experimental/entities/fourier_pseudo_ggsw_ciphertext.rs index b6f0f74f78..39d386a405 100644 --- a/tfhe/src/core_crypto/experimental/entities/fourier_pseudo_ggsw_ciphertext.rs +++ b/tfhe/src/core_crypto/experimental/entities/fourier_pseudo_ggsw_ciphertext.rs @@ -11,7 +11,7 @@ use crate::core_crypto::fft_impl::fft64::math::decomposition::DecompositionLevel use crate::core_crypto::fft_impl::fft64::math::fft::{FftView, FourierPolynomialList}; use crate::core_crypto::fft_impl::fft64::math::polynomial::FourierPolynomialMutView; use aligned_vec::{avec, ABox}; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use tfhe_fft::c64; /// A pseudo GGSW ciphertext in the Fourier domain. @@ -273,7 +273,7 @@ impl<'a> PseudoFourierGgswCiphertextMutView<'a> { self, coef_ggsw: &PseudoGgswCiphertext, fft: FftView<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert_eq!(coef_ggsw.polynomial_size(), self.polynomial_size()); let fourier_poly_size = coef_ggsw.polynomial_size().to_fourier_polynomial_size().0; @@ -285,7 +285,7 @@ impl<'a> PseudoFourierGgswCiphertextMutView<'a> { fft.forward_as_torus( FourierPolynomialMutView { data: fourier_poly }, coef_poly, - stack.rb_mut(), + stack, ); } } diff --git a/tfhe/src/core_crypto/fft_impl/common.rs b/tfhe/src/core_crypto/fft_impl/common.rs index cd4d123891..70f4d96d87 100644 --- a/tfhe/src/core_crypto/fft_impl/common.rs +++ b/tfhe/src/core_crypto/fft_impl/common.rs @@ -43,7 +43,7 @@ pub trait FourierBootstrapKey { &mut self, coef_bsk: &LweBootstrapKey, fft: &Self::Fft, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContBsk: Container; @@ -59,7 +59,7 @@ pub trait FourierBootstrapKey { lwe_in: &LweCiphertext, accumulator: &GlweCiphertext, fft: &Self::Fft, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContLweOut: ContainerMut, ContLweIn: Container, diff --git a/tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs b/tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs index 14c0696313..ffcf7815cb 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs @@ -20,7 +20,7 @@ use crate::core_crypto::prelude::ContainerMut; use aligned_vec::{avec, ABox, CACHELINE_ALIGN}; use core::any::TypeId; use core::mem::transmute; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use tfhe_versionable::Versionize; #[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize, Versionize)] @@ -250,7 +250,7 @@ where lut: &mut GlweCiphertext, lwe: &LweCiphertext, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus + CastInto, ContLut: ContainerMut, @@ -261,7 +261,7 @@ where mut lut: GlweCiphertext<&mut [Scalar]>, lwe: LweCiphertext<&[Scalar]>, fft: Fft128View<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) { let lwe = lwe.as_ref(); let (lwe_body, lwe_mask) = lwe.split_last().unwrap(); @@ -287,7 +287,7 @@ where izip!(lwe_mask.iter(), this.into_ggsw_iter()) { if *lwe_mask_element != Scalar::ZERO { - let stack = stack.rb_mut(); + let stack = &mut *stack; // We copy ct_0 to ct_1 let (ct1, stack) = stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied()); @@ -335,7 +335,7 @@ where lwe_in: &LweCiphertext, accumulator: &GlweCiphertext, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize Scalar: UnsignedTorus + CastInto, @@ -349,7 +349,7 @@ where lwe_in: LweCiphertext<&[Scalar]>, accumulator: GlweCiphertext<&[Scalar]>, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { // We type check dynamically with TypeId #[allow(clippy::transmute_undefined_repr)] @@ -417,7 +417,7 @@ where &mut self, coef_bsk: &LweBootstrapKey, fft: &Self::Fft, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContBsk: Container, { @@ -440,7 +440,7 @@ where lwe_in: &LweCiphertext, accumulator: &GlweCiphertext, fft: &Self::Fft, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContLweOut: ContainerMut, ContLweIn: Container, diff --git a/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs index 6b0c6d0e8d..6d909343a9 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs @@ -16,7 +16,7 @@ use crate::core_crypto::entities::glwe_ciphertext::{GlweCiphertext, GlweCipherte use crate::core_crypto::fft_impl::fft64::math::decomposition::TensorSignedDecompositionLendingIter; use crate::core_crypto::prelude::ContainerMut; use aligned_vec::CACHELINE_ALIGN; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use tfhe_fft::fft128::f128; use tfhe_versionable::Versionize; @@ -365,7 +365,7 @@ pub fn add_external_product_assign( ggsw: &Fourier128GgswCiphertext, glwe: &GlweCiphertext, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, ContOut: ContainerMut, @@ -377,7 +377,7 @@ pub fn add_external_product_assign( ggsw: Fourier128GgswCiphertext<&[f64]>, glwe: GlweCiphertext<&[Scalar]>, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { // we check that the polynomial sizes match debug_assert_eq!(ggsw.polynomial_size(), glwe.polynomial_size()); @@ -404,7 +404,7 @@ pub fn add_external_product_assign( stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); let (output_fft_buffer_im0, stack) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); - let (output_fft_buffer_im1, mut substack0) = + let (output_fft_buffer_im1, substack0) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid @@ -416,21 +416,21 @@ pub fn add_external_product_assign( // ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER // DOMAIN In this section, we perform the external product in the fourier // domain, and accumulate the result in the output_fft_buffer variable. - let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( + let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), - substack0.rb_mut(), + substack0, ); // We loop through the levels (we reverse to match the order of the decomposition // iterator.) for ggsw_decomp_matrix in ggsw.into_levels() { // We retrieve the decomposition of this level. - let (glwe_level, glwe_decomp_term, mut substack2) = - collect_next_term(&mut decomposition, &mut substack1, align); + let (glwe_level, glwe_decomp_term, substack2) = + collect_next_term(&mut decomposition, substack1, align); let glwe_decomp_term = GlweCiphertextView::from_container( &*glwe_decomp_term, ggsw.polynomial_size(), @@ -455,7 +455,7 @@ pub fn add_external_product_assign( glwe_decomp_term.as_polynomial_list().iter() ) { let len = fourier_poly_size; - let stack = substack2.rb_mut(); + let stack = &mut *substack2; let (fourier_re0, stack) = stack.make_aligned_raw::(len, align); let (fourier_re1, stack) = stack.make_aligned_raw::(len, align); let (fourier_im0, stack) = stack.make_aligned_raw::(len, align); @@ -509,7 +509,7 @@ pub fn add_external_product_assign( fourier_re1, fourier_im0, fourier_im1, - substack0.rb_mut(), + substack0, ); } } @@ -528,9 +528,9 @@ fn collect_next_term<'a, Scalar: UnsignedTorus>( decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>, substack1: &'a mut PodStack, align: usize, -) -> (DecompositionLevel, &'a mut [Scalar], PodStack<'a>) { +) -> (DecompositionLevel, &'a mut [Scalar], &'a mut PodStack) { let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap(); - let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term); + let (glwe_decomp_term, substack2) = substack1.collect_aligned(align, glwe_decomp_term); (glwe_level, glwe_decomp_term, substack2) } @@ -767,7 +767,7 @@ pub fn cmux( ct1: &mut GlweCiphertext, ggsw: &Fourier128GgswCiphertext, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, ContCt0: ContainerMut, @@ -779,7 +779,7 @@ pub fn cmux( mut ct1: GlweCiphertext<&mut [Scalar]>, ggsw: Fourier128GgswCiphertext<&[f64]>, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { for (c1, c0) in izip!(ct1.as_mut(), ct0.as_ref()) { *c1 = c1.wrapping_sub(*c0); diff --git a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs index 7be0d1bbb2..14e9f888c3 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs @@ -437,7 +437,7 @@ impl<'a> Fft128View<'a> { fourier_re1: &[f64], fourier_im0: &[f64], fourier_im1: &[f64], - stack: PodStack<'_>, + stack: &mut PodStack, ) { self.backward_with_conv( standard, @@ -463,7 +463,7 @@ impl<'a> Fft128View<'a> { fourier_re1: &[f64], fourier_im0: &[f64], fourier_im1: &[f64], - stack: PodStack<'_>, + stack: &mut PodStack, ) { self.backward_with_conv( standard, @@ -487,7 +487,7 @@ impl<'a> Fft128View<'a> { fourier_im0: &[f64], fourier_im1: &[f64], conv_fn: F, - stack: PodStack<'_>, + stack: &mut PodStack, ) { let n = standard.len(); debug_assert_eq!(n, 2 * fourier_re0.len()); diff --git a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs index 25fd1b0c93..9319517b79 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::core_crypto::commons::test_tools::{modular_distance, new_random_generator}; use aligned_vec::avec; -use dyn_stack::{GlobalPodBuffer, ReborrowMut}; +use dyn_stack::GlobalPodBuffer; fn test_roundtrip() { let mut generator = new_random_generator(); @@ -24,7 +24,7 @@ fn test_roundtrip() { } let mut mem = GlobalPodBuffer::new(fft.backward_scratch().unwrap()); - let mut stack = PodStack::new(&mut mem); + let stack = PodStack::new(&mut mem); fft.forward_as_torus( &mut fourier_re0, @@ -39,7 +39,7 @@ fn test_roundtrip() { &fourier_re1, &fourier_im0, &fourier_im1, - stack.rb_mut(), + stack, ); for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) { @@ -111,7 +111,7 @@ fn test_product() { } let mut mem = GlobalPodBuffer::new(fft.backward_scratch().unwrap()); - let mut stack = PodStack::new(&mut mem); + let stack = PodStack::new(&mut mem); fft.forward_as_torus( &mut fourier0_re0, @@ -153,7 +153,7 @@ fn test_product() { &fourier0_re1, &fourier0_im0, &fourier0_im1, - stack.rb_mut(), + stack, ); convolution_naive( convolution_from_naive.as_mut(), diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs index 676d15635c..51dde00e9e 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs @@ -11,7 +11,7 @@ use crate::core_crypto::entities::*; use crate::core_crypto::fft_impl::common::pbs_modulus_switch; use crate::core_crypto::prelude::{Container, ContainerMut}; use aligned_vec::CACHELINE_ALIGN; -use dyn_stack::{PodStack, ReborrowMut}; +use dyn_stack::PodStack; pub fn polynomial_wrapping_monic_monomial_mul_assign_split( output_lo: Polynomial<&mut [u64]>, @@ -64,7 +64,7 @@ where lut_hi: &mut GlweCiphertext, lwe: &LweCiphertext, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContLutLo: ContainerMut, ContLutHi: ContainerMut, @@ -76,7 +76,7 @@ where mut lut_hi: GlweCiphertext<&mut [u64]>, lwe: LweCiphertext<&[u128]>, fft: Fft128View<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) { let lwe = lwe.as_ref(); let (lwe_body, lwe_mask) = lwe.split_last().unwrap(); @@ -103,7 +103,7 @@ where izip!(lwe_mask.iter(), this.into_ggsw_iter()) { if *lwe_mask_element != 0 { - let stack = stack.rb_mut(); + let stack = &mut *stack; // We copy ct_0 to ct_1 let (ct1_lo, stack) = stack.collect_aligned(CACHELINE_ALIGN, ct0_lo.as_ref().iter().copied()); @@ -160,7 +160,7 @@ where lwe_in: &LweCiphertext, accumulator: &GlweCiphertext, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContLweOut: ContainerMut, ContLweIn: Container, @@ -172,14 +172,14 @@ where lwe_in: LweCiphertext<&[u128]>, accumulator: GlweCiphertext<&[u128]>, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { let align = CACHELINE_ALIGN; let ciphertext_modulus = accumulator.ciphertext_modulus(); let (local_accumulator_lo, stack) = stack.collect_aligned(align, accumulator.as_ref().iter().map(|i| *i as u64)); - let (local_accumulator_hi, mut stack) = stack.collect_aligned( + let (local_accumulator_hi, stack) = stack.collect_aligned( align, accumulator.as_ref().iter().map(|i| (*i >> 64) as u64), ); @@ -205,7 +205,7 @@ where &mut local_accumulator_hi, &lwe_in, fft, - stack.rb_mut(), + stack, ); let (local_accumulator, _) = stack.collect_aligned( align, diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs index d28b05b927..61bec423b1 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/ggsw.rs @@ -9,7 +9,7 @@ use crate::core_crypto::entities::*; use crate::core_crypto::fft_impl::fft128::crypto::ggsw::update_with_fmadd; use crate::core_crypto::prelude::{Container, ContainerMut, SignedDecomposer}; use aligned_vec::CACHELINE_ALIGN; -use dyn_stack::{PodStack, ReborrowMut}; +use dyn_stack::PodStack; #[cfg_attr(feature = "__profiling", inline(never))] pub fn add_external_product_assign_split( @@ -19,7 +19,7 @@ pub fn add_external_product_assign_split, glwe_hi: &GlweCiphertext, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContOutLo: ContainerMut, ContOutHi: ContainerMut, @@ -34,7 +34,7 @@ pub fn add_external_product_assign_split, glwe_hi: GlweCiphertext<&[u64]>, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { // we check that the polynomial sizes match debug_assert_eq!(ggsw.polynomial_size(), glwe_lo.polynomial_size()); @@ -69,7 +69,7 @@ pub fn add_external_product_assign_split(fourier_poly_size * ggsw.glwe_size().0, align); let (output_fft_buffer_im0, stack) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); - let (output_fft_buffer_im1, mut substack0) = + let (output_fft_buffer_im1, substack0) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid @@ -81,10 +81,9 @@ pub fn add_external_product_assign_split(poly_size * glwe_size, align); - let (decomposition_states_hi, mut substack1) = + let (decomposition_states_lo, stack) = + substack0.make_aligned_raw::(poly_size * glwe_size, align); + let (decomposition_states_hi, substack1) = stack.make_aligned_raw::(poly_size * glwe_size, align); for (out_lo, out_hi, in_lo, in_hi) in izip!( @@ -113,10 +112,9 @@ pub fn add_external_product_assign_split(poly_size * glwe_size, align); - let (glwe_decomp_term_hi, mut substack2) = + let (glwe_decomp_term_lo, stack) = + substack1.make_aligned_raw::(poly_size * glwe_size, align); + let (glwe_decomp_term_hi, substack2) = stack.make_aligned_raw::(poly_size * glwe_size, align); let base_log = decomposer.base_log; @@ -161,7 +159,7 @@ pub fn add_external_product_assign_split(len, align); let (fourier_re1, stack) = stack.make_aligned_raw::(len, align); let (fourier_im0, stack) = stack.make_aligned_raw::(len, align); @@ -219,7 +217,7 @@ pub fn add_external_product_assign_split( ct1_hi: &mut GlweCiphertext, ggsw: &Fourier128GgswCiphertext, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContCt0Lo: ContainerMut, ContCt0Hi: ContainerMut, @@ -627,7 +625,7 @@ pub fn cmux_split( mut ct1_hi: GlweCiphertext<&mut [u64]>, ggsw: Fourier128GgswCiphertext<&[f64]>, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { for (c1_lo, c1_hi, c0_lo, c0_hi) in izip!( ct1_lo.as_mut(), diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/tests.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/tests.rs index f69a527289..d42a2fc998 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/tests.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/tests.rs @@ -6,7 +6,7 @@ use crate::core_crypto::fft_impl::common::tests::{ use crate::core_crypto::prelude::test::{TestResources, FFT128_U128_PARAMS}; use crate::core_crypto::prelude::*; use aligned_vec::CACHELINE_ALIGN; -use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; +use dyn_stack::{GlobalPodBuffer, PodStack}; #[test] fn test_split_external_product() { @@ -177,7 +177,7 @@ fn test_split_pbs() { ) .unwrap(), ); - let mut stack = PodStack::new(&mut mem); + let stack = PodStack::new(&mut mem); for _ in 0..20 { for x in lwe_in.as_mut() { @@ -203,7 +203,7 @@ fn test_split_pbs() { lwe_in: LweCiphertext<&[Scalar]>, accumulator: GlweCiphertext<&[Scalar]>, fft: Fft128View<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { let (local_accumulator_data, stack) = stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied()); @@ -226,7 +226,7 @@ fn test_split_pbs() { lwe_in.as_view(), accumulator.as_view(), fft, - stack.rb_mut(), + stack, ); let mut lwe_out_split = LweCiphertext::new( @@ -236,13 +236,7 @@ fn test_split_pbs() { .to_lwe_size(), ciphertext_modulus, ); - fourier_bsk.bootstrap_u128( - &mut lwe_out_split, - &lwe_in, - &accumulator, - fft, - stack.rb_mut(), - ); + fourier_bsk.bootstrap_u128(&mut lwe_out_split, &lwe_in, &accumulator, fft, stack); assert_eq!(lwe_out_split, lwe_out_non_split); } diff --git a/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs index f04d6d7a28..c98eceda28 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128_u128/math/fft/mod.rs @@ -1253,7 +1253,7 @@ impl<'a> Fft128View<'a> { fourier_re1: &[f64], fourier_im0: &[f64], fourier_im1: &[f64], - stack: PodStack<'_>, + stack: &mut PodStack, ) { self.backward_with_conv_split( standard_lo, @@ -1308,7 +1308,7 @@ impl<'a> Fft128View<'a> { fourier_im0: &[f64], fourier_im1: &[f64], conv_fn: impl Fn(&mut [u64], &mut [u64], &mut [u64], &mut [u64], &[f64], &[f64], &[f64], &[f64]), - stack: PodStack<'_>, + stack: &mut PodStack, ) { let n = standard_lo.len(); debug_assert_eq!(n, 2 * fourier_re0.len()); diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs index 40e4a4f8fc..ffe58a65aa 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs @@ -20,7 +20,7 @@ use crate::core_crypto::fft_impl::common::{pbs_modulus_switch, FourierBootstrapK use crate::core_crypto::fft_impl::fft64::math::fft::par_convert_polynomials_list_to_fourier; use crate::core_crypto::prelude::{CiphertextCount, CiphertextModulus, ContainerMut}; use aligned_vec::{avec, ABox, CACHELINE_ALIGN}; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use tfhe_fft::c64; use tfhe_versionable::Versionize; @@ -191,12 +191,12 @@ impl<'a> FourierLweBootstrapKeyMutView<'a> { mut self, coef_bsk: LweBootstrapKey<&'_ [Scalar]>, fft: FftView<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) { for (fourier_ggsw, standard_ggsw) in izip!(self.as_mut_view().into_ggsw_iter(), coef_bsk.iter()) { - fourier_ggsw.fill_with_forward_fourier(standard_ggsw, fft, stack.rb_mut()); + fourier_ggsw.fill_with_forward_fourier(standard_ggsw, fft, stack); } } /// Fill a bootstrapping key with the Fourier transform of a bootstrapping key in the standard @@ -288,7 +288,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { mut lut: GlweCiphertextMutView<'_, OutputScalar>, lwe: LweCiphertextView<'_, InputScalar>, fft: FftView<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) where InputScalar: UnsignedTorus + CastInto, OutputScalar: UnsignedTorus, @@ -303,9 +303,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { lut.as_mut_polynomial_list() .iter_mut() .for_each(|mut poly| { - let (tmp_poly, _) = stack - .rb_mut() - .make_aligned_raw(poly.as_ref().len(), CACHELINE_ALIGN); + let (tmp_poly, _) = stack.make_aligned_raw(poly.as_ref().len(), CACHELINE_ALIGN); let mut tmp_poly = Polynomial::from_container(&mut *tmp_poly); tmp_poly.as_mut().copy_from_slice(poly.as_ref()); @@ -314,7 +312,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { // We initialize the ct_0 used for the successive cmuxes let mut ct0 = lut; - let (ct1, mut stack) = stack.make_aligned_raw(ct0.as_ref().len(), CACHELINE_ALIGN); + let (ct1, stack) = stack.make_aligned_raw(ct0.as_ref().len(), CACHELINE_ALIGN); let mut ct1 = GlweCiphertextMutView::from_container(&mut *ct1, lut_poly_size, ciphertext_modulus); @@ -349,7 +347,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { bootstrap_key_ggsw, ct1.as_view(), fft, - stack.rb_mut(), + stack, ); } } @@ -375,7 +373,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { mut lut_list: GlweCiphertextListMutView<'_, OutputScalar>, lwe_list: LweCiphertextListView<'_, InputScalar>, fft: FftView<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) where InputScalar: UnsignedTorus + CastInto, OutputScalar: UnsignedTorus, @@ -393,9 +391,8 @@ impl<'a> FourierLweBootstrapKeyView<'a> { lut.as_mut_polynomial_list() .iter_mut() .for_each(|mut poly| { - let (tmp_poly, _) = stack - .rb_mut() - .make_aligned_raw(poly.as_ref().len(), CACHELINE_ALIGN); + let (tmp_poly, _) = + stack.make_aligned_raw(poly.as_ref().len(), CACHELINE_ALIGN); let mut tmp_poly = Polynomial::from_container(&mut *tmp_poly); tmp_poly.as_mut().copy_from_slice(poly.as_ref()); @@ -405,8 +402,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { // We initialize the ct_0 used for the successive cmuxes let mut ct0_list = lut_list; - let (ct1_list, mut stack) = - stack.make_aligned_raw(ct0_list.as_ref().len(), CACHELINE_ALIGN); + let (ct1_list, stack) = stack.make_aligned_raw(ct0_list.as_ref().len(), CACHELINE_ALIGN); let mut ct1_list = GlweCiphertextListMutView::from_container( &mut *ct1_list, ct0_list.glwe_size(), @@ -450,7 +446,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { bootstrap_key_ggsw, ct1.as_view(), fft, - stack.rb_mut(), + stack, ); } } @@ -478,7 +474,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { lwe_in: LweCiphertextView<'_, InputScalar>, accumulator: GlweCiphertextView<'_, OutputScalar>, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize InputScalar: UnsignedTorus + CastInto, @@ -518,7 +514,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> { lwe_in: LweCiphertextListView<'_, InputScalar>, accumulator: &GlweCiphertextListView<'_, OutputScalar>, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where // CastInto required for PBS modulus switch which returns a usize InputScalar: UnsignedTorus + CastInto, @@ -586,7 +582,7 @@ where &mut self, coef_bsk: &LweBootstrapKey, fft: &Self::Fft, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContBsk: Container, { @@ -608,7 +604,7 @@ where lwe_in: &LweCiphertext, accumulator: &GlweCiphertext, fft: &Self::Fft, - stack: PodStack<'_>, + stack: &mut PodStack, ) where ContLweOut: ContainerMut, ContLweIn: Container, diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs index 1620328e6a..89cbcf4e63 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs @@ -16,7 +16,7 @@ use crate::core_crypto::entities::ggsw_ciphertext::{ }; use crate::core_crypto::entities::glwe_ciphertext::{GlweCiphertextMutView, GlweCiphertextView}; use aligned_vec::{avec, ABox, CACHELINE_ALIGN}; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use tfhe_fft::c64; use tfhe_versionable::Versionize; @@ -257,7 +257,7 @@ impl<'a> FourierGgswCiphertextMutView<'a> { self, coef_ggsw: GgswCiphertextView<'_, Scalar>, fft: FftView<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert_eq!(coef_ggsw.polynomial_size(), self.polynomial_size()); let fourier_poly_size = coef_ggsw.polynomial_size().to_fourier_polynomial_size().0; @@ -269,7 +269,7 @@ impl<'a> FourierGgswCiphertextMutView<'a> { fft.forward_as_torus( FourierPolynomialMutView { data: fourier_poly }, coef_poly, - stack.rb_mut(), + stack, ); } } @@ -483,7 +483,7 @@ pub fn add_external_product_assign( ggsw: FourierGgswCiphertextView<'_>, glwe: GlweCiphertextView, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) where Scalar: UnsignedTorus, { @@ -503,7 +503,7 @@ pub fn add_external_product_assign( ggsw.decomposition_level_count(), ); - let (output_fft_buffer, mut substack0) = + let (output_fft_buffer, substack0) = stack.make_aligned_raw::(fourier_poly_size * ggsw.glwe_size().0, align); // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once @@ -515,20 +515,20 @@ pub fn add_external_product_assign( // ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER DOMAIN // In this section, we perform the external product in the fourier domain, and accumulate // the result in the output_fft_buffer variable. - let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( + let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), - substack0.rb_mut(), + substack0, ); // We loop through the levels (we reverse to match the order of the decomposition iterator.) ggsw.into_levels().for_each(|ggsw_decomp_matrix| { // We retrieve the decomposition of this level. - let (glwe_level, glwe_decomp_term, mut substack2) = - collect_next_term(&mut decomposition, &mut substack1, align); + let (glwe_level, glwe_decomp_term, substack2) = + collect_next_term(&mut decomposition, substack1, align); let glwe_decomp_term = GlweCiphertextView::from_container( &*glwe_decomp_term, ggsw.polynomial_size(), @@ -553,9 +553,8 @@ pub fn add_external_product_assign( glwe_decomp_term.as_polynomial_list().iter() ) .for_each(|(ggsw_row, glwe_poly)| { - let (fourier, substack3) = substack2 - .rb_mut() - .make_aligned_raw::(fourier_poly_size, align); + let (fourier, substack3) = + substack2.make_aligned_raw::(fourier_poly_size, align); // We perform the forward fft transform for the glwe polynomial let fourier = fft .forward_as_integer( @@ -596,7 +595,7 @@ pub fn add_external_product_assign( .for_each(|(out, fourier)| { // The fourier buffer is not re-used afterwards so we can use the in-place version of // the add_backward_as_torus function - fft.add_backward_in_place_as_torus(out, fourier, substack0.rb_mut()); + fft.add_backward_in_place_as_torus(out, fourier, substack0); }); } } @@ -606,9 +605,9 @@ pub(crate) fn collect_next_term<'a, Scalar: UnsignedTorus>( decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>, substack1: &'a mut PodStack, align: usize, -) -> (DecompositionLevel, &'a mut [Scalar], PodStack<'a>) { +) -> (DecompositionLevel, &'a mut [Scalar], &'a mut PodStack) { let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap(); - let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term); + let (glwe_decomp_term, substack2) = substack1.collect_aligned(align, glwe_decomp_term); (glwe_level, glwe_decomp_term, substack2) } @@ -647,18 +646,18 @@ pub(crate) fn update_with_fmadd( is_output_uninit: bool, fourier_poly_size: usize, ) { - let rhs = S::c64s_as_simd(fourier).0; + let rhs = S::as_simd_c64s(fourier).0; if is_output_uninit { for (output_fourier, ggsw_poly) in izip!( output_fft_buffer.into_chunks(fourier_poly_size), lhs_polynomial_list.into_chunks(fourier_poly_size) ) { - let out = S::c64s_as_mut_simd(output_fourier).0; - let lhs = S::c64s_as_simd(ggsw_poly).0; + let out = S::as_mut_simd_c64s(output_fourier).0; + let lhs = S::as_simd_c64s(ggsw_poly).0; for (out, lhs, rhs) in izip!(out, lhs, rhs) { - *out = simd.c64s_mul(*lhs, *rhs); + *out = simd.mul_c64s(*lhs, *rhs); } } } else { @@ -666,11 +665,11 @@ pub(crate) fn update_with_fmadd( output_fft_buffer.into_chunks(fourier_poly_size), lhs_polynomial_list.into_chunks(fourier_poly_size) ) { - let out = S::c64s_as_mut_simd(output_fourier).0; - let lhs = S::c64s_as_simd(ggsw_poly).0; + let out = S::as_mut_simd_c64s(output_fourier).0; + let lhs = S::as_simd_c64s(ggsw_poly).0; for (out, lhs, rhs) in izip!(out, lhs, rhs) { - *out = simd.c64s_mul_add_e(*lhs, *rhs, *out); + *out = simd.mul_add_c64s(*lhs, *rhs, *out); } } } @@ -718,25 +717,25 @@ pub(crate) fn update_with_fmadd_factor( #[inline(always)] fn with_simd(self, simd: S) -> Self::Output { - let factor = simd.c64s_splat(self.factor); + let factor = simd.splat_c64s(self.factor); for (output_fourier, ggsw_poly) in izip!( self.output_fft_buffer.into_chunks(self.fourier_poly_size), self.lhs_polynomial_list.into_chunks(self.fourier_poly_size) ) { - let out = S::c64s_as_mut_simd(output_fourier).0; - let lhs = S::c64s_as_simd(ggsw_poly).0; - let rhs = S::c64s_as_simd(self.fourier).0; + let out = S::as_mut_simd_c64s(output_fourier).0; + let lhs = S::as_simd_c64s(ggsw_poly).0; + let rhs = S::as_simd_c64s(self.fourier).0; if self.is_output_uninit { for (out, &lhs, &rhs) in izip!(out, lhs, rhs) { // NOTE: factor * (lhs * rhs) is more efficient than (lhs * rhs) * factor - *out = simd.c64s_mul(factor, simd.c64s_mul(lhs, rhs)); + *out = simd.mul_c64s(factor, simd.mul_c64s(lhs, rhs)); } } else { for (out, &lhs, &rhs) in izip!(out, lhs, rhs) { // NOTE: see above - *out = simd.c64s_mul_add_e(factor, simd.c64s_mul(lhs, rhs), *out); + *out = simd.mul_add_c64s(factor, simd.mul_c64s(lhs, rhs), *out); } } } @@ -768,7 +767,7 @@ pub fn cmux( mut ct1: GlweCiphertextMutView<'_, Scalar>, ggsw: FourierGgswCiphertextView<'_>, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { izip!(ct1.as_mut(), ct0.as_ref()).for_each(|(c1, c0)| { *c1 = c1.wrapping_sub(*c0); diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/mod.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/mod.rs index cea61c85c2..d3d851e6d6 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/mod.rs @@ -15,7 +15,7 @@ use crate::core_crypto::commons::traits::*; use crate::core_crypto::commons::utils::izip; use crate::core_crypto::entities::*; use aligned_vec::CACHELINE_ALIGN; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use tfhe_fft::c64; pub fn extract_bits_scratch( @@ -68,7 +68,7 @@ pub fn extract_bits>( delta_log: DeltaLog, number_of_bits_to_extract: ExtractedBitsCount, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert!(lwe_list_out.ciphertext_modulus() == lwe_in.ciphertext_modulus()); debug_assert!(lwe_in.ciphertext_modulus() == ksk.ciphertext_modulus()); @@ -143,7 +143,7 @@ pub fn extract_bits>( let lwe_size = glwe_dimension .to_equivalent_lwe_dimension(polynomial_size) .to_lwe_size(); - let (lwe_out_pbs_buffer_data, mut stack) = + let (lwe_out_pbs_buffer_data, stack) = stack.make_aligned_with(lwe_size.0, align, |_| Scalar::ZERO); let mut lwe_out_pbs_buffer = LweCiphertext::from_container( &mut *lwe_out_pbs_buffer_data, @@ -155,7 +155,7 @@ pub fn extract_bits>( // Block to keep the lwe_bit_left_shift_buffer_data alive only as long as needed { // Shift on padding bit - let (lwe_bit_left_shift_buffer_data, _) = stack.rb_mut().collect_aligned( + let (lwe_bit_left_shift_buffer_data, _) = stack.collect_aligned( align, lwe_in_buffer .as_ref() @@ -206,7 +206,7 @@ pub fn extract_bits>( lwe_out_ks_buffer.as_view(), pbs_accumulator.as_view(), fft, - stack.rb_mut(), + stack, ); // Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0 if the @@ -244,7 +244,7 @@ pub fn circuit_bootstrap_boolean>( delta_log: DeltaLog, pfpksk_list: LwePrivateFunctionalPackingKeyswitchKeyList<&[Scalar]>, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert!(lwe_in.ciphertext_modulus() == ggsw_out.ciphertext_modulus()); debug_assert!(ggsw_out.ciphertext_modulus() == pfpksk_list.ciphertext_modulus()); @@ -306,7 +306,7 @@ pub fn circuit_bootstrap_boolean>( ); // Output for every bootstrapping - let (lwe_out_bs_buffer_data, mut stack) = stack.make_aligned_with( + let (lwe_out_bs_buffer_data, stack) = stack.make_aligned_with( fourier_bsk_output_lwe_dimension.to_lwe_size().0, CACHELINE_ALIGN, |_| Scalar::ZERO, @@ -324,7 +324,7 @@ pub fn circuit_bootstrap_boolean>( base_log_cbs, delta_log, fft, - stack.rb_mut(), + stack, ); for (pfpksk, mut glwe_out) in pfpksk_list @@ -371,7 +371,7 @@ pub fn homomorphic_shift_boolean>( base_log_cbs: DecompositionBaseLog, delta_log: DeltaLog, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert!(lwe_out.ciphertext_modulus() == lwe_in.ciphertext_modulus()); debug_assert!( @@ -467,7 +467,7 @@ pub fn cmux_tree_memory_optimized>( lut_per_layer: PolynomialList<&[Scalar]>, ggsw_list: FourierGgswCiphertextListView<'_>, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert!(lut_per_layer.polynomial_count().0 == 1 << ggsw_list.count()); @@ -510,7 +510,7 @@ pub fn cmux_tree_memory_optimized>( ciphertext_modulus, ); - let (t_fill, mut stack) = stack.make_with(nb_layer, |_| 0_usize); + let (t_fill, stack) = stack.make_with(nb_layer, |_| 0_usize); let mut lut_polynomial_iter = lut_per_layer.iter(); loop { @@ -537,7 +537,7 @@ pub fn cmux_tree_memory_optimized>( for (j, ggsw) in ggsw_list.into_ggsw_iter().rev().enumerate() { if t_fill[j] == 2 { - let (diff_data, stack) = stack.rb_mut().collect_aligned( + let (diff_data, stack) = stack.collect_aligned( CACHELINE_ALIGN, izip!(t1_j.as_ref(), t0_j.as_ref()).map(|(&a, &b)| a.wrapping_sub(b)), ); @@ -648,7 +648,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing, - stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert!(stack.can_hold( circuit_bootstrap_boolean_vertical_packing_scratch::( @@ -686,7 +686,7 @@ pub fn circuit_bootstrap_boolean_vertical_packing>( mut lwe_out: LweCiphertext<&mut [Scalar]>, ggsw_list: FourierGgswCiphertextListView<'_>, fft: FftView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { debug_assert!( lwe_out.ciphertext_modulus().is_native_modulus(), @@ -815,26 +815,15 @@ pub fn vertical_packing>( // the last blind rotation. let (cmux_ggsw, br_ggsw) = ggsw_list.split_at(log_number_of_luts_for_cmux_tree); - let (cmux_tree_lut_res_data, mut stack) = + let (cmux_tree_lut_res_data, stack) = stack.make_aligned_with(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN, |_| { Scalar::ZERO }); let mut cmux_tree_lut_res = GlweCiphertext::from_container(cmux_tree_lut_res_data, polynomial_size, ciphertext_modulus); - cmux_tree_memory_optimized( - cmux_tree_lut_res.as_mut_view(), - lut, - cmux_ggsw, - fft, - stack.rb_mut(), - ); - blind_rotate_assign( - cmux_tree_lut_res.as_mut_view(), - br_ggsw, - fft, - stack.rb_mut(), - ); + cmux_tree_memory_optimized(cmux_tree_lut_res.as_mut_view(), lut, cmux_ggsw, fft, stack); + blind_rotate_assign(cmux_tree_lut_res.as_mut_view(), br_ggsw, fft, stack); // sample extract of the RLWE of the Vertical packing extract_lwe_sample_from_glwe_ciphertext(&cmux_tree_lut_res, &mut lwe_out, MonomialDegree(0)); @@ -855,15 +844,14 @@ pub fn blind_rotate_assign>( mut lut: GlweCiphertext<&mut [Scalar]>, ggsw_list: FourierGgswCiphertextListView<'_>, fft: FftView<'_>, - mut stack: PodStack<'_>, + stack: &mut PodStack, ) { let mut monomial_degree = MonomialDegree(1); for ggsw in ggsw_list.into_ggsw_iter().rev() { let ct_0 = lut.as_mut_view(); - let (ct1_data, stack) = stack - .rb_mut() - .collect_aligned(CACHELINE_ALIGN, ct_0.as_ref().iter().copied()); + let (ct1_data, stack) = + stack.collect_aligned(CACHELINE_ALIGN, ct_0.as_ref().iter().copied()); let mut ct_1 = GlweCiphertext::from_container( &mut *ct1_data, ct_0.polynomial_size(), diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/tests.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/tests.rs index 74dd3333f3..a9a01b5d30 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/tests.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/wop_pbs/tests.rs @@ -172,7 +172,7 @@ pub fn test_extract_bits() { }; let req = req().unwrap(); let mut mem = GlobalPodBuffer::new(req); - let mut stack = PodStack::new(&mut mem); + let stack = PodStack::new(&mut mem); fourier_bsk .as_mut_view() @@ -225,7 +225,7 @@ pub fn test_extract_bits() { delta_log, number_values_to_extract, fft, - stack.rb_mut(), + stack, ); // Decryption of extracted bit diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs index 5eabd8f344..639619c1ab 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs @@ -29,8 +29,8 @@ impl<'buffers, Scalar: UnsignedInteger> TensorSignedDecompositionLendingIter<'bu input: impl Iterator, base_log: DecompositionBaseLog, level: DecompositionLevelCount, - stack: PodStack<'buffers>, - ) -> (Self, PodStack<'buffers>) { + stack: &'buffers mut PodStack, + ) -> (Self, &'buffers mut PodStack) { let (states, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input); ( TensorSignedDecompositionLendingIter { diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs index e4156084fc..c315ef31b8 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs @@ -9,7 +9,7 @@ use crate::core_crypto::commons::traits::{Container, ContainerMut, IntoContainer use crate::core_crypto::commons::utils::izip; use crate::core_crypto::entities::*; use aligned_vec::{avec, ABox}; -use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use rayon::prelude::*; use std::any::TypeId; use std::collections::hash_map::Entry; @@ -383,7 +383,7 @@ impl<'a> FftView<'a> { self, fourier: FourierPolynomialMutView<'out>, standard: PolynomialView<'_, Scalar>, - stack: PodStack<'_>, + stack: &mut PodStack, ) -> FourierPolynomialMutView<'out> { self.forward_with_conv(fourier, standard, convert_forward_torus, stack) } @@ -403,7 +403,7 @@ impl<'a> FftView<'a> { self, fourier: FourierPolynomialMutView<'out>, standard: PolynomialView<'_, Scalar>, - stack: PodStack<'_>, + stack: &mut PodStack, ) -> FourierPolynomialMutView<'out> { self.forward_with_conv(fourier, standard, convert_forward_integer, stack) } @@ -462,7 +462,7 @@ impl<'a> FftView<'a> { self, standard: PolynomialMutView<'_, Scalar>, fourier: FourierPolynomialView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { self.backward_with_conv(standard, fourier, convert_backward_torus, stack); } @@ -481,7 +481,7 @@ impl<'a> FftView<'a> { self, standard: PolynomialMutView<'_, Scalar>, fourier: FourierPolynomialView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { self.backward_with_conv(standard, fourier, convert_add_backward_torus, stack); } @@ -492,7 +492,7 @@ impl<'a> FftView<'a> { self, standard: PolynomialMutView<'_, Scalar>, fourier: FourierPolynomialMutView<'_>, - stack: PodStack<'_>, + stack: &mut PodStack, ) { self.backward_with_conv_in_place(standard, fourier, convert_add_backward_torus, stack); } @@ -506,7 +506,7 @@ impl<'a> FftView<'a> { fourier: FourierPolynomialMutView<'out>, standard: PolynomialView<'_, Scalar>, conv_fn: F, - stack: PodStack<'_>, + stack: &mut PodStack, ) -> FourierPolynomialMutView<'out> { let fourier = fourier.data; let standard = standard.as_ref(); @@ -526,7 +526,7 @@ impl<'a> FftView<'a> { mut standard: PolynomialMutView<'_, Scalar>, fourier: FourierPolynomialView<'_>, conv_fn: F, - stack: PodStack<'_>, + stack: &mut PodStack, ) { let fourier = fourier.data; let standard = standard.as_mut(); @@ -548,7 +548,7 @@ impl<'a> FftView<'a> { mut standard: PolynomialMutView<'_, Scalar>, fourier: FourierPolynomialMutView<'_>, conv_fn: F, - stack: PodStack<'_>, + stack: &mut PodStack, ) { let fourier = fourier.data; let standard = standard.as_mut(); @@ -771,9 +771,9 @@ pub fn par_convert_polynomials_list_to_fourier( .unwrap() .try_unaligned_bytes_required() .unwrap(); - let mut stack = vec![0; stack_len]; + let mut mem = vec![0; stack_len]; - let mut stack = PodStack::new(&mut stack); + let stack = PodStack::new(&mut mem); for (fourier_poly, standard_poly) in izip!( fourier_poly_chunk.chunks_exact_mut(f_polynomial_size), @@ -782,7 +782,7 @@ pub fn par_convert_polynomials_list_to_fourier( fft.forward_as_torus( FourierPolynomialMutView { data: fourier_poly }, PolynomialView::from_container(standard_poly), - stack.rb_mut(), + stack, ); } }); diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/tests.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/tests.rs index 106e433c9e..68036f3089 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/tests.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/tests.rs @@ -28,11 +28,11 @@ fn test_roundtrip() { .unwrap() .and(fft.backward_scratch().unwrap()), ); - let mut stack = PodStack::new(&mut mem); + let stack = PodStack::new(&mut mem); // Simple roundtrip - fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut()); - fft.backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack.rb_mut()); + fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack); + fft.backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack); for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) { if Scalar::BITS == 32 { @@ -45,8 +45,8 @@ fn test_roundtrip() { // Simple add roundtrip // Need to zero out the buffer to have a correct result as we will be adding the result roundtrip.as_mut().fill(Scalar::ZERO); - fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut()); - fft.add_backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack.rb_mut()); + fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack); + fft.add_backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack); for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) { if Scalar::BITS == 32 { @@ -59,12 +59,8 @@ fn test_roundtrip() { // Forward, then add backward in place // Need to zero out the buffer to have a correct result as we will be adding the result roundtrip.as_mut().fill(Scalar::ZERO); - fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut()); - fft.add_backward_in_place_as_torus( - roundtrip.as_mut_view(), - fourier.as_mut_view(), - stack.rb_mut(), - ); + fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack); + fft.add_backward_in_place_as_torus(roundtrip.as_mut_view(), fourier.as_mut_view(), stack); for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) { if Scalar::BITS == 32 { @@ -134,10 +130,10 @@ fn test_product() { .unwrap() .and(fft.backward_scratch().unwrap()), ); - let mut stack = PodStack::new(&mut mem); + let stack = PodStack::new(&mut mem); - fft.forward_as_torus(fourier0.as_mut_view(), poly0.as_view(), stack.rb_mut()); - fft.forward_as_integer(fourier1.as_mut_view(), poly1.as_view(), stack.rb_mut()); + fft.forward_as_torus(fourier0.as_mut_view(), poly0.as_view(), stack); + fft.forward_as_integer(fourier1.as_mut_view(), poly1.as_view(), stack); for (f0, f1) in izip!(&mut *fourier0.data, &*fourier1.data) { *f0 *= *f1; @@ -153,7 +149,7 @@ fn test_product() { fft.backward_as_torus( convolution_from_fft.as_mut_view(), fourier0.as_view(), - stack.rb_mut(), + stack, ); for (expected, actual) in izip!( @@ -175,7 +171,7 @@ fn test_product() { fft.add_backward_as_torus( convolution_from_fft.as_mut_view(), fourier0.as_view(), - stack.rb_mut(), + stack, ); for (expected, actual) in izip!( @@ -199,7 +195,7 @@ fn test_product() { fft.add_backward_in_place_as_torus( convolution_from_fft.as_mut_view(), fourier0.as_mut_view(), - stack.rb_mut(), + stack, ); for (expected, actual) in izip!(