Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(tfhe): update dependencies with breaking changes #1497

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tfhe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ lazy_static = { version = "1.4.0", optional = true }
serde = { version = "1.0", features = ["derive"] }
rayon = { version = "1.5.0" }
bincode = "1.3.3"
concrete-fft = { version = "0.4.1", features = ["serde", "fft128"] }
concrete-ntt = { version = "0.1.2" }
concrete-fft = { version = "0.5.0", features = ["serde", "fft128"] }
concrete-ntt = { version = "0.2.0" }
pulp = "0.18.22"
tfhe-cuda-backend = { version = "0.4.0-alpha.0", path = "../backends/tfhe-cuda-backend", optional = true }
aligned-vec = { version = "0.5", features = ["serde"] }
dyn-stack = { version = "0.9" }
dyn-stack = { version = "0.10" }
paste = "1.0.7"
fs2 = { version = "0.4.3", optional = true }
# Used for OPRF in shortint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,10 @@ pub fn blind_rotate_ntt64_assign_mem_optimized<InputCont, OutputCont, KeyCont>(
if *lwe_mask_element != 0u64 {
let stack = stack.rb_mut();
// We copy ct_0 to ct_1
let (mut ct1, stack) =
let (ct1, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied());
let mut ct1 = GlweCiphertextMutView::from_container(
&mut *ct1,
lut_poly_size,
ciphertext_modulus,
);
let mut ct1 =
GlweCiphertextMutView::from_container(ct1, lut_poly_size, ciphertext_modulus);

// We rotate ct_1 by performing ct_1 <- ct_1 * X^{a_hat}
for mut poly in ct1.as_mut_polynomial_list().iter_mut() {
Expand Down Expand Up @@ -503,10 +500,10 @@ pub fn programmable_bootstrap_ntt64_lwe_ciphertext_mem_optimized<
accumulator.ciphertext_modulus()
);

let (mut local_accumulator_data, stack) =
let (local_accumulator_data, stack) =
stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied());
let mut local_accumulator = GlweCiphertextMutView::from_container(
&mut *local_accumulator_data,
local_accumulator_data,
accumulator.polynomial_size(),
accumulator.ciphertext_modulus(),
);
Expand Down Expand Up @@ -568,12 +565,11 @@ pub(crate) fn add_external_product_ntt64_assign<InputGlweCont>(
out.ciphertext_modulus(),
);

let (mut output_fft_buffer, mut substack0) =
let (output_fft_buffer, mut substack0) =
stack.make_aligned_raw::<u64>(poly_size * ggsw.glwe_size().0, align);
// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
// it has been fully initialized for the first time.
let output_fft_buffer = &mut *output_fft_buffer;
let mut is_output_uninit = true;

{
Expand Down Expand Up @@ -616,17 +612,16 @@ pub(crate) fn add_external_product_ntt64_assign<InputGlweCont>(
glwe_decomp_term.as_polynomial_list().iter()
)
.for_each(|(ggsw_row, glwe_poly)| {
let (mut ntt_poly, _) =
substack2.rb_mut().make_aligned_raw::<u64>(poly_size, align);
let (ntt_poly, _) = substack2.rb_mut().make_aligned_raw::<u64>(poly_size, align);
// We perform the forward ntt transform for the glwe polynomial
ntt.forward(PolynomialMutView::from_container(&mut ntt_poly), glwe_poly);
ntt.forward(PolynomialMutView::from_container(ntt_poly), glwe_poly);
// Now we loop through the polynomials of the output, and add the
// corresponding product of polynomials.

update_with_fmadd_ntt64(
output_fft_buffer,
ggsw_row.as_ref(),
&ntt_poly,
ntt_poly,
is_output_uninit,
poly_size,
ntt,
Expand Down
16 changes: 6 additions & 10 deletions tfhe/src/core_crypto/commons/math/decomposition/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::core_crypto::commons::math::decomposition::{
};
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
use dyn_stack::{DynArray, PodStack, ReborrowMut};
use dyn_stack::{PodStack, ReborrowMut};

/// An iterator that yields the terms of the signed decomposition of an integer.
///
Expand Down Expand Up @@ -288,9 +288,9 @@ pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> {
// ...0001111
mod_b_mask: u64,
// The internal states of each decomposition
states: DynArray<'buffers, u64>,
states: &'buffers mut [u64],
// Corresponding input signs
input_signs: DynArray<'buffers, u8>,
input_signs: &'buffers mut [u8],
// A flag which stores whether the iterator is a fresh one (for the recompose method).
fresh: bool,
ciphertext_modulus: u64,
Expand All @@ -306,9 +306,9 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> {
) -> (Self, PodStack<'buffers>) {
let shift = modulus.ceil_ilog2() as usize - decomposer.base_log * decomposer.level_count;
let input_size = input.len();
let (mut states, stack) =
let (states, stack) =
stack.make_aligned_raw::<u64>(input_size, aligned_vec::CACHELINE_ALIGN);
let (mut input_signs, stack) =
let (input_signs, stack) =
stack.make_aligned_raw::<u8>(input_size, aligned_vec::CACHELINE_ALIGN);

for ((i, state), sign) in input
Expand Down Expand Up @@ -393,11 +393,7 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> {
&mut self,
substack1: &'a mut PodStack,
align: usize,
) -> (
DecompositionLevel,
dyn_stack::DynArray<'a, u64>,
PodStack<'a>,
) {
) -> (DecompositionLevel, &'a mut [u64], PodStack<'a>) {
let (glwe_level, _, glwe_decomp_term) = self.next_term().unwrap();
let (glwe_decomp_term, substack2) =
substack1.rb_mut().collect_aligned(align, glwe_decomp_term);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,11 @@ pub fn glwe_fast_keyswitch<Scalar, OutputGlweCont, InputGlweCont, GgswCont>(
ggsw.decomposition_base_log(),
ggsw.decomposition_level_count(),
);
let (mut output_fft_buffer, mut substack0) =
let (output_fft_buffer, mut substack0) =
stack.make_aligned_raw::<c64>(fourier_poly_size * ggsw.glwe_size_out().0, align);
// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
// it has been fully initialized for the first time.
let output_fft_buffer = &mut *output_fft_buffer;
let mut is_output_uninit = true;

{
Expand Down Expand Up @@ -244,14 +243,14 @@ pub fn glwe_fast_keyswitch<Scalar, OutputGlweCont, InputGlweCont, GgswCont>(
glwe_decomp_term.get_mask().as_polynomial_list().iter()
)
.for_each(|(ggsw_row, glwe_poly)| {
let (mut fourier, substack3) = substack2
let (fourier, substack3) = substack2
.rb_mut()
.make_aligned_raw::<c64>(fourier_poly_size, align);

// We perform the forward fft transform for the glwe polynomial
let fourier = fft
.forward_as_integer(
FourierPolynomialMutView { data: &mut fourier },
FourierPolynomialMutView { data: fourier },
glwe_poly,
substack3,
)
Expand Down
8 changes: 4 additions & 4 deletions tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ where
if *lwe_mask_element != Scalar::ZERO {
let stack = stack.rb_mut();
// We copy ct_0 to ct_1
let (mut ct1, stack) =
let (ct1, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0.as_ref().iter().copied());
let mut ct1 = GlweCiphertextMutView::from_container(
&mut *ct1,
ct1,
ct0.polynomial_size(),
ct0.ciphertext_modulus(),
);
Expand Down Expand Up @@ -361,10 +361,10 @@ where
return this.bootstrap_u128(&mut lwe_out, &lwe_in, &accumulator, fft, stack);
}

let (mut local_accumulator_data, stack) =
let (local_accumulator_data, stack) =
stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied());
let mut local_accumulator = GlweCiphertextMutView::from_container(
&mut *local_accumulator_data,
local_accumulator_data,
accumulator.polynomial_size(),
accumulator.ciphertext_modulus(),
);
Expand Down
51 changes: 21 additions & 30 deletions tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,13 @@ pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
ggsw.decomposition_level_count(),
);

let (mut output_fft_buffer_re0, stack) =
let (output_fft_buffer_re0, stack) =
stack.make_aligned_raw::<f64>(fourier_poly_size * ggsw.glwe_size().0, align);
let (mut output_fft_buffer_re1, stack) =
let (output_fft_buffer_re1, stack) =
stack.make_aligned_raw::<f64>(fourier_poly_size * ggsw.glwe_size().0, align);
let (mut output_fft_buffer_im0, stack) =
let (output_fft_buffer_im0, stack) =
stack.make_aligned_raw::<f64>(fourier_poly_size * ggsw.glwe_size().0, align);
let (mut output_fft_buffer_im1, mut substack0) =
let (output_fft_buffer_im1, mut substack0) =
stack.make_aligned_raw::<f64>(fourier_poly_size * ggsw.glwe_size().0, align);

// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
Expand Down Expand Up @@ -455,30 +455,30 @@ pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
) {
let len = fourier_poly_size;
let stack = substack2.rb_mut();
let (mut fourier_re0, stack) = stack.make_aligned_raw::<f64>(len, align);
let (mut fourier_re1, stack) = stack.make_aligned_raw::<f64>(len, align);
let (mut fourier_im0, stack) = stack.make_aligned_raw::<f64>(len, align);
let (mut fourier_im1, _) = stack.make_aligned_raw::<f64>(len, align);
let (fourier_re0, stack) = stack.make_aligned_raw::<f64>(len, align);
let (fourier_re1, stack) = stack.make_aligned_raw::<f64>(len, align);
let (fourier_im0, stack) = stack.make_aligned_raw::<f64>(len, align);
let (fourier_im1, _) = stack.make_aligned_raw::<f64>(len, align);
// We perform the forward fft transform for the glwe polynomial
fft.forward_as_integer(
&mut fourier_re0,
&mut fourier_re1,
&mut fourier_im0,
&mut fourier_im1,
fourier_re0,
fourier_re1,
fourier_im0,
fourier_im1,
glwe_poly.as_ref(),
);
// Now we loop through the polynomials of the output, and add the
// corresponding product of polynomials.
update_with_fmadd(
&mut output_fft_buffer_re0,
&mut output_fft_buffer_re1,
&mut output_fft_buffer_im0,
&mut output_fft_buffer_im1,
output_fft_buffer_re0,
output_fft_buffer_re1,
output_fft_buffer_im0,
output_fft_buffer_im1,
ggsw_row,
&fourier_re0,
&fourier_re1,
&fourier_im0,
&fourier_im1,
fourier_re0,
fourier_re1,
fourier_im0,
fourier_im1,
is_output_uninit,
fourier_poly_size,
);
Expand All @@ -495,11 +495,6 @@ pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
//
// We iterate over the polynomials in the output.
if !is_output_uninit {
let output_fft_buffer_re0 = output_fft_buffer_re0;
let output_fft_buffer_re1 = output_fft_buffer_re1;
let output_fft_buffer_im0 = output_fft_buffer_im0;
let output_fft_buffer_im1 = output_fft_buffer_im1;

for (mut out, fourier_re0, fourier_re1, fourier_im0, fourier_im1) in izip!(
out.as_mut_polynomial_list().iter_mut(),
output_fft_buffer_re0.into_chunks(fourier_poly_size),
Expand Down Expand Up @@ -532,11 +527,7 @@ fn collect_next_term<'a, Scalar: UnsignedTorus>(
decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>,
substack1: &'a mut PodStack,
align: usize,
) -> (
DecompositionLevel,
dyn_stack::DynArray<'a, Scalar>,
PodStack<'a>,
) {
) -> (DecompositionLevel, &'a mut [Scalar], PodStack<'a>) {
let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap();
let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term);
(glwe_level, glwe_decomp_term, substack2)
Expand Down
20 changes: 6 additions & 14 deletions tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,27 +495,19 @@ impl<'a> Fft128View<'a> {
debug_assert_eq!(n, 2 * fourier_im0.len());
debug_assert_eq!(n, 2 * fourier_im1.len());

let (mut tmp_re0, stack) =
let (tmp_re0, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_re0.iter().copied());
let (mut tmp_re1, stack) =
let (tmp_re1, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_re1.iter().copied());
let (mut tmp_im0, stack) =
let (tmp_im0, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_im0.iter().copied());
let (mut tmp_im1, _) =
let (tmp_im1, _) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier_im1.iter().copied());

self.plan
.inv(&mut tmp_re0, &mut tmp_re1, &mut tmp_im0, &mut tmp_im1);
self.plan.inv(tmp_re0, tmp_re1, tmp_im0, tmp_im1);

let (standard_re, standard_im) = standard.split_at_mut(n / 2);
conv_fn(
standard_re,
standard_im,
&tmp_re0,
&tmp_re1,
&tmp_im0,
&tmp_im1,
);
conv_fn(standard_re, standard_im, tmp_re0, tmp_re1, tmp_im0, tmp_im1);
}
}

Expand Down
10 changes: 5 additions & 5 deletions tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ where
if *lwe_mask_element != 0 {
let stack = stack.rb_mut();
// We copy ct_0 to ct_1
let (mut ct1_lo, stack) =
let (ct1_lo, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0_lo.as_ref().iter().copied());
let (mut ct1_hi, stack) =
let (ct1_hi, stack) =
stack.collect_aligned(CACHELINE_ALIGN, ct0_hi.as_ref().iter().copied());
let mut ct1_lo = GlweCiphertextMutView::from_container(
&mut *ct1_lo,
Expand Down Expand Up @@ -177,9 +177,9 @@ where
let align = CACHELINE_ALIGN;
let ciphertext_modulus = accumulator.ciphertext_modulus();

let (mut local_accumulator_lo, stack) =
let (local_accumulator_lo, stack) =
stack.collect_aligned(align, accumulator.as_ref().iter().map(|i| *i as u64));
let (mut local_accumulator_hi, mut stack) = stack.collect_aligned(
let (local_accumulator_hi, mut stack) = stack.collect_aligned(
align,
accumulator.as_ref().iter().map(|i| (*i >> 64) as u64),
);
Expand Down Expand Up @@ -207,7 +207,7 @@ where
fft,
stack.rb_mut(),
);
let (mut local_accumulator, _) = stack.collect_aligned(
let (local_accumulator, _) = stack.collect_aligned(
align,
izip!(local_accumulator_lo.as_ref(), local_accumulator_hi.as_ref())
.map(|(&lo, &hi)| lo as u128 | ((hi as u128) << 64)),
Expand Down
Loading
Loading