Skip to content

Commit

Permalink
chore(tfhe): update dependencies with breaking changes
Browse files Browse the repository at this point in the history
- concrete-fft to 0.5 and concrete-ntt to 0.2.0 due to a Rust AVX512 breaking
change (fix for bad arguments in a function)
- dyn-stack to 0.10 due to concrete-fft update
  • Loading branch information
IceTDrinker committed Aug 29, 2024
1 parent 6e2908a commit 0e0d860
Show file tree
Hide file tree
Showing 17 changed files with 146 additions and 189 deletions.
6 changes: 3 additions & 3 deletions tfhe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ lazy_static = { version = "1.4.0", optional = true }
serde = { version = "1.0", features = ["derive"] }
rayon = { version = "1.5.0" }
bincode = "1.3.3"
concrete-fft = { version = "0.4.1", features = ["serde", "fft128"] }
concrete-ntt = { version = "0.1.2" }
concrete-fft = { version = "0.5.0", features = ["serde", "fft128"] }
concrete-ntt = { version = "0.2.0" }
pulp = "0.18.22"
tfhe-cuda-backend = { version = "0.4.0-alpha.0", path = "../backends/tfhe-cuda-backend", optional = true }
aligned-vec = { version = "0.5", features = ["serde"] }
dyn-stack = { version = "0.9" }
dyn-stack = { version = "0.10" }
paste = "1.0.7"
fs2 = { version = "0.4.3", optional = true }
# Used for OPRF in shortint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,10 @@ pub fn blind_rotate_ntt64_assign_mem_optimized<InputCont, OutputCont, KeyCont>(
if *lwe_mask_element != 0u64 {
let stack = stack.rb_mut();
// We copy ct_0 to ct_1
let (mut ct1, stack) =
let (ct1, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied());
let mut ct1 = GlweCiphertextMutView::from_container(
&mut *ct1,
lut_poly_size,
ciphertext_modulus,
);
let mut ct1 =
GlweCiphertextMutView::from_container(ct1, lut_poly_size, ciphertext_modulus);

// We rotate ct_1 by performing ct_1 <- ct_1 * X^{a_hat}
for mut poly in ct1.as_mut_polynomial_list().iter_mut() {
Expand Down Expand Up @@ -503,10 +500,10 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized<
accumulator.ciphertext_modulus()
);

let (mut local_accumulator_data, stack) =
let (local_accumulator_data, stack) =
stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied());
let mut local_accumulator = GlweCiphertextMutView::from_container(
&mut *local_accumulator_data,
local_accumulator_data,
accumulator.polynomial_size(),
accumulator.ciphertext_modulus(),
);
Expand Down Expand Up @@ -568,12 +565,11 @@ pub(crate) fn add_external_product_ntt64_assign<InputGlweCont>(
out.ciphertext_modulus(),
);

let (mut output_fft_buffer, mut substack0) =
let (output_fft_buffer, mut substack0) =
stack.make_aligned_raw::<u64>(poly_size * ggsw.glwe_size().0, align);
// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
// it has been fully initialized for the first time.
let output_fft_buffer = &mut *output_fft_buffer;
let mut is_output_uninit = true;

{
Expand Down Expand Up @@ -616,17 +612,16 @@ pub(crate) fn add_external_product_ntt64_assign<InputGlweCont>(
glwe_decomp_term.as_polynomial_list().iter()
)
.for_each(|(ggsw_row, glwe_poly)| {
let (mut ntt_poly, _) =
substack2.rb_mut().make_aligned_raw::<u64>(poly_size, align);
let (ntt_poly, _) = substack2.rb_mut().make_aligned_raw::<u64>(poly_size, align);
// We perform the forward ntt transform for the glwe polynomial
ntt.forward(PolynomialMutView::from_container(&mut ntt_poly), glwe_poly);
ntt.forward(PolynomialMutView::from_container(ntt_poly), glwe_poly);
// Now we loop through the polynomials of the output, and add the
// corresponding product of polynomials.

update_with_fmadd_ntt64(
output_fft_buffer,
ggsw_row.as_ref(),
&ntt_poly,
ntt_poly,
is_output_uninit,
poly_size,
ntt,
Expand Down
16 changes: 6 additions & 10 deletions tfhe/src/core_crypto/commons/math/decomposition/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::core_crypto::commons::math::decomposition::{
};
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
use dyn_stack::{DynArray, PodStack, ReborrowMut};
use dyn_stack::{PodStack, ReborrowMut};

/// An iterator that yields the terms of the signed decomposition of an integer.
///
Expand Down Expand Up @@ -288,9 +288,9 @@ pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> {
// ...0001111
mod_b_mask: u64,
// The internal states of each decomposition
states: DynArray<'buffers, u64>,
states: &'buffers mut [u64],
// Corresponding input signs
input_signs: DynArray<'buffers, u8>,
input_signs: &'buffers mut [u8],
// A flag which stores whether the iterator is a fresh one (for the recompose method).
fresh: bool,
ciphertext_modulus: u64,
Expand All @@ -306,9 +306,9 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> {
) -> (Self, PodStack<'buffers>) {
let shift = modulus.ceil_ilog2() as usize - decomposer.base_log * decomposer.level_count;
let input_size = input.len();
let (mut states, stack) =
let (states, stack) =
stack.make_aligned_raw::<u64>(input_size, aligned_vec::CACHELINE_ALIGN);
let (mut input_signs, stack) =
let (input_signs, stack) =
stack.make_aligned_raw::<u8>(input_size, aligned_vec::CACHELINE_ALIGN);

for ((i, state), sign) in input
Expand Down Expand Up @@ -393,11 +393,7 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> {
&mut self,
substack1: &'a mut PodStack,
align: usize,
) -> (
DecompositionLevel,
dyn_stack::DynArray<'a, u64>,
PodStack<'a>,
) {
) -> (DecompositionLevel, &'a mut [u64], PodStack<'a>) {
let (glwe_level, _, glwe_decomp_term) = self.next_term().unwrap();
let (glwe_decomp_term, substack2) =
substack1.rb_mut().collect_aligned(align, glwe_decomp_term);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,11 @@ pub fn glwe_fast_keyswitch<Scalar, OutputGlweCont, InputGlweCont, GgswCont>(
ggsw.decomposition_base_log(),
ggsw.decomposition_level_count(),
);
let (mut output_fft_buffer, mut substack0) =
let (output_fft_buffer, mut substack0) =
stack.make_aligned_raw::<c64>(fourier_poly_size * ggsw.glwe_size_out().0, align);
// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
// it has been fully initialized for the first time.
let output_fft_buffer = &mut *output_fft_buffer;
let mut is_output_uninit = true;

{
Expand Down Expand Up @@ -244,14 +243,14 @@ pub fn glwe_fast_keyswitch<Scalar, OutputGlweCont, InputGlweCont, GgswCont>(
glwe_decomp_term.get_mask().as_polynomial_list().iter()
)
.for_each(|(ggsw_row, glwe_poly)| {
let (mut fourier, substack3) = substack2
let (fourier, substack3) = substack2
.rb_mut()
.make_aligned_raw::<c64>(fourier_poly_size, align);

// We perform the forward fft transform for the glwe polynomial
let fourier = fft
.forward_as_integer(
FourierPolynomialMutView { data: &mut fourier },
FourierPolynomialMutView { data: fourier },
glwe_poly,
substack3,
)
Expand Down
8 changes: 4 additions & 4 deletions tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ where
if *lwe_mask_element != Scalar::ZERO {
let stack = stack.rb_mut();
// We copy ct_0 to ct_1
let (mut ct1, stack) =
let (ct1, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied());
let mut ct1 = GlweCiphertextMutView::from_container(
&mut *ct1,
ct1,
ct0.polynomial_size(),
ct0.ciphertext_modulus(),
);
Expand Down Expand Up @@ -361,10 +361,10 @@ where
return this.bootstrap_u128(&mut lwe_out, &lwe_in, &accumulator, fft, stack);
}

let (mut local_accumulator_data, stack) =
let (local_accumulator_data, stack) =
stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied());
let mut local_accumulator = GlweCiphertextMutView::from_container(
&mut *local_accumulator_data,
local_accumulator_data,
accumulator.polynomial_size(),
accumulator.ciphertext_modulus(),
);
Expand Down
51 changes: 21 additions & 30 deletions tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,13 @@ pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
ggsw.decomposition_level_count(),
);

let (mut output_fft_buffer_re0, stack) =
let (output_fft_buffer_re0, stack) =
stack.make_aligned_raw::<f64>(fourier_poly_size * ggsw.glwe_size().0, align);
let (mut output_fft_buffer_re1, stack) =
let (output_fft_buffer_re1, stack) =
stack.make_aligned_raw::<f64>(fourier_poly_size * ggsw.glwe_size().0, align);
let (mut output_fft_buffer_im0, stack) =
let (output_fft_buffer_im0, stack) =
stack.make_aligned_raw::<f64>(fourier_poly_size * ggsw.glwe_size().0, align);
let (mut output_fft_buffer_im1, mut substack0) =
let (output_fft_buffer_im1, mut substack0) =
stack.make_aligned_raw::<f64>(fourier_poly_size * ggsw.glwe_size().0, align);

// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
Expand Down Expand Up @@ -455,30 +455,30 @@ pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
) {
let len = fourier_poly_size;
let stack = substack2.rb_mut();
let (mut fourier_re0, stack) = stack.make_aligned_raw::<f64>(len, align);
let (mut fourier_re1, stack) = stack.make_aligned_raw::<f64>(len, align);
let (mut fourier_im0, stack) = stack.make_aligned_raw::<f64>(len, align);
let (mut fourier_im1, _) = stack.make_aligned_raw::<f64>(len, align);
let (fourier_re0, stack) = stack.make_aligned_raw::<f64>(len, align);
let (fourier_re1, stack) = stack.make_aligned_raw::<f64>(len, align);
let (fourier_im0, stack) = stack.make_aligned_raw::<f64>(len, align);
let (fourier_im1, _) = stack.make_aligned_raw::<f64>(len, align);
// We perform the forward fft transform for the glwe polynomial
fft.forward_as_integer(
&mut fourier_re0,
&mut fourier_re1,
&mut fourier_im0,
&mut fourier_im1,
fourier_re0,
fourier_re1,
fourier_im0,
fourier_im1,
glwe_poly.as_ref(),
);
// Now we loop through the polynomials of the output, and add the
// corresponding product of polynomials.
update_with_fmadd(
&mut output_fft_buffer_re0,
&mut output_fft_buffer_re1,
&mut output_fft_buffer_im0,
&mut output_fft_buffer_im1,
output_fft_buffer_re0,
output_fft_buffer_re1,
output_fft_buffer_im0,
output_fft_buffer_im1,
ggsw_row,
&fourier_re0,
&fourier_re1,
&fourier_im0,
&fourier_im1,
fourier_re0,
fourier_re1,
fourier_im0,
fourier_im1,
is_output_uninit,
fourier_poly_size,
);
Expand All @@ -495,11 +495,6 @@ pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
//
// We iterate over the polynomials in the output.
if !is_output_uninit {
let output_fft_buffer_re0 = output_fft_buffer_re0;
let output_fft_buffer_re1 = output_fft_buffer_re1;
let output_fft_buffer_im0 = output_fft_buffer_im0;
let output_fft_buffer_im1 = output_fft_buffer_im1;

for (mut out, fourier_re0, fourier_re1, fourier_im0, fourier_im1) in izip!(
out.as_mut_polynomial_list().iter_mut(),
output_fft_buffer_re0.into_chunks(fourier_poly_size),
Expand Down Expand Up @@ -532,11 +527,7 @@ fn collect_next_term<'a, Scalar: UnsignedTorus>(
decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>,
substack1: &'a mut PodStack,
align: usize,
) -> (
DecompositionLevel,
dyn_stack::DynArray<'a, Scalar>,
PodStack<'a>,
) {
) -> (DecompositionLevel, &'a mut [Scalar], PodStack<'a>) {
let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap();
let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term);
(glwe_level, glwe_decomp_term, substack2)
Expand Down
20 changes: 6 additions & 14 deletions tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,27 +495,19 @@ impl<'a> Fft128View<'a> {
debug_assert_eq!(n, 2 * fourier_im0.len());
debug_assert_eq!(n, 2 * fourier_im1.len());

let (mut tmp_re0, stack) =
let (tmp_re0, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_re0.iter().copied());
let (mut tmp_re1, stack) =
let (tmp_re1, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_re1.iter().copied());
let (mut tmp_im0, stack) =
let (tmp_im0, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_im0.iter().copied());
let (mut tmp_im1, _) =
let (tmp_im1, _) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_im1.iter().copied());

self.plan
.inv(&mut tmp_re0, &mut tmp_re1, &mut tmp_im0, &mut tmp_im1);
self.plan.inv(tmp_re0, tmp_re1, tmp_im0, tmp_im1);

let (standard_re, standard_im) = standard.split_at_mut(n / 2);
conv_fn(
standard_re,
standard_im,
&tmp_re0,
&tmp_re1,
&tmp_im0,
&tmp_im1,
);
conv_fn(standard_re, standard_im, tmp_re0, tmp_re1, tmp_im0, tmp_im1);
}
}

Expand Down
8 changes: 4 additions & 4 deletions tfhe/src/core_crypto/fft_impl/fft128/math/fft/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ fn test_product<Scalar: UnsignedTorus>() {
);

for (f0_re0, f0_re1, f0_im0, f0_im1, f1_re0, f1_re1, f1_im0, f1_im1) in izip!(
&mut *fourier0_re0,
&mut *fourier0_re1,
&mut *fourier0_im0,
&mut *fourier0_im1,
fourier0_re0,
fourier0_re1,
fourier0_im0,
fourier0_im1,
&*fourier1_re0,
&*fourier1_re1,
&*fourier1_im0,
Expand Down
10 changes: 5 additions & 5 deletions tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ where
if *lwe_mask_element != 0 {
let stack = stack.rb_mut();
// We copy ct_0 to ct_1
let (mut ct1_lo, stack) =
let (ct1_lo, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0_lo.as_ref().iter().copied());
let (mut ct1_hi, stack) =
let (ct1_hi, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0_hi.as_ref().iter().copied());
let mut ct1_lo = GlweCiphertextMutView::from_container(
&mut *ct1_lo,
Expand Down Expand Up @@ -177,9 +177,9 @@ where
let align = CACHELINE_ALIGN;
let ciphertext_modulus = accumulator.ciphertext_modulus();

let (mut local_accumulator_lo, stack) =
let (local_accumulator_lo, stack) =
stack.collect_aligned(align, accumulator.as_ref().iter().map(|i| *i as u64));
let (mut local_accumulator_hi, mut stack) = stack.collect_aligned(
let (local_accumulator_hi, mut stack) = stack.collect_aligned(
align,
accumulator.as_ref().iter().map(|i| (*i >> 64) as u64),
);
Expand Down Expand Up @@ -207,7 +207,7 @@ where
fft,
stack.rb_mut(),
);
let (mut local_accumulator, _) = stack.collect_aligned(
let (local_accumulator, _) = stack.collect_aligned(
align,
izip!(local_accumulator_lo.as_ref(), local_accumulator_hi.as_ref())
.map(|(&lo, &hi)| lo as u128 | ((hi as u128) << 64)),
Expand Down
Loading

0 comments on commit 0e0d860

Please sign in to comment.