From 204ea237d914c40d9146c242eeb1449fda7d5553 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 11 Dec 2024 09:53:24 +0100 Subject: [PATCH] fix(shortint): fix compression encoding change not being taken into account - this maps better to what was optimized and will dramatically diminish the pfail as we now have 2 more bits for the LUT redundancy --- tfhe/src/shortint/engine/mod.rs | 37 +++++++++++-- .../shortint/list_compression/compression.rs | 49 +++++++++++++---- tfhe/src/shortint/server_key/mod.rs | 52 ++++++++++++++++++- 3 files changed, 122 insertions(+), 16 deletions(-) diff --git a/tfhe/src/shortint/engine/mod.rs b/tfhe/src/shortint/engine/mod.rs index 0e3b34f213..14e245c755 100644 --- a/tfhe/src/shortint/engine/mod.rs +++ b/tfhe/src/shortint/engine/mod.rs @@ -85,6 +85,33 @@ pub(crate) fn fill_accumulator( carry_modulus: CarryModulus, f: F, ) -> u64 +where + C: ContainerMut, + F: Fn(u64) -> u64, +{ + fill_accumulator_with_encoding( + accumulator, + polynomial_size, + glwe_size, + message_modulus, + carry_modulus, + message_modulus, + carry_modulus, + f, + ) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn fill_accumulator_with_encoding( + accumulator: &mut GlweCiphertext, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + input_message_modulus: MessageModulus, + input_carry_modulus: CarryModulus, + output_message_modulus: MessageModulus, + output_carry_modulus: CarryModulus, + f: F, +) -> u64 where C: ContainerMut, F: Fn(u64) -> u64, @@ -97,13 +124,13 @@ where accumulator_view.get_mut_mask().as_mut().fill(0); // Modulus of the msg contained in the msg bits and operations buffer - let modulus_sup = (message_modulus.0 * carry_modulus.0) as usize; + let input_modulus_sup = (input_message_modulus.0 * input_carry_modulus.0) as usize; // N/(p/2) = size of each block - let box_size = polynomial_size.0 / modulus_sup; + let box_size = polynomial_size.0 / input_modulus_sup; // Value of the shift we multiply our messages by - let delta = (1_u64 << 63) / (message_modulus.0 * carry_modulus.0); + let output_delta = (1_u64 << 63) / (output_message_modulus.0 * output_carry_modulus.0); let mut body = accumulator_view.get_mut_body(); let accumulator_u64 = body.as_mut(); @@ -111,11 +138,11 @@ where // Tracking the max value of the function to define the degree later let mut max_value = 0; - for i in 0..modulus_sup { + for i in 0..input_modulus_sup { let index = i * box_size; let f_eval = f(i as u64); max_value = max_value.max(f_eval); - accumulator_u64[index..index + box_size].fill(f_eval * delta); + accumulator_u64[index..index + box_size].fill(f_eval * output_delta); } let half_box_size = box_size / 2; diff --git a/tfhe/src/shortint/list_compression/compression.rs b/tfhe/src/shortint/list_compression/compression.rs index 9fbdd24e4a..114bb570c9 100644 --- a/tfhe/src/shortint/list_compression/compression.rs +++ b/tfhe/src/shortint/list_compression/compression.rs @@ -7,9 +7,9 @@ use crate::core_crypto::prelude::{ }; use crate::shortint::ciphertext::CompressedCiphertextList; use crate::shortint::engine::ShortintEngine; -use crate::shortint::parameters::NoiseLevel; +use crate::shortint::parameters::{CarryModulus, MessageModulus, NoiseLevel}; use crate::shortint::server_key::{ - apply_programmable_bootstrap, generate_lookup_table, unchecked_scalar_mul_assign, + apply_programmable_bootstrap, generate_lookup_table_with_encoding, unchecked_scalar_mul_assign, }; use crate::shortint::{Ciphertext, CiphertextModulus, MaxNoiseLevel}; use rayon::iter::ParallelIterator; @@ -126,18 +126,49 @@ impl CompressionKey { } impl DecompressionKey { - pub fn unpack(&self, packed: &CompressedCiphertextList, index: usize) -> Option { + pub fn unpack( + &self, + packed: &CompressedCiphertextList, + index: usize, + ) -> Result { + if packed.message_modulus.0 != packed.carry_modulus.0 { + return Err(crate::Error::new(format!( + "Tried to unpack values from a list where message modulus \ + ({:?}) is != carry modulus ({:?}), this is not supported.", + packed.message_modulus, packed.carry_modulus, + ))); + } + if index >= packed.count.0 { - return None; + return Err(crate::Error::new(format!( + "Tried getting index {index} for CompressedCiphertextList \ + with {} elements, out of bound access.", + packed.count.0 + ))); } - let carry_extract = generate_lookup_table( + let encryption_cleartext_modulus = packed.message_modulus.0 * packed.carry_modulus.0; + // We multiply by message_modulus during compression so the actual modulus for the + // compression is smaller + let compression_cleartext_modulus = encryption_cleartext_modulus / packed.message_modulus.0; + let effective_compression_message_modulus = MessageModulus(compression_cleartext_modulus); + let effective_compression_carry_modulus = CarryModulus(1); + + let decompression_rescale = generate_lookup_table_with_encoding( self.out_glwe_size(), self.out_polynomial_size(), packed.ciphertext_modulus, + // Input moduli are the effective compression ones + effective_compression_message_modulus, + effective_compression_carry_modulus, + // Output moduli are directly the ones stored in the list packed.message_modulus, packed.carry_modulus, - |x| x / packed.message_modulus.0, + // Here we do not divide by message_modulus + // Example: in the 2_2 case we are mapping a 2 bits message onto a 4 bits space, we + // want to keep the original 2 bits value in the 4 bits space, so we apply the identity + // and the encoding will rescale it for us. + |x| x, ); let polynomial_size = packed.modulus_switched_glwe_ciphertext_list[0].polynomial_size(); @@ -181,14 +212,14 @@ impl DecompressionKey { &self.blind_rotate_key, &intermediate_lwe, &mut output_br, - &carry_extract.acc, + &decompression_rescale.acc, buffers, ); }); - Some(Ciphertext::new( + Ok(Ciphertext::new( output_br, - carry_extract.degree, + decompression_rescale.degree, NoiseLevel::NOMINAL, packed.message_modulus, packed.carry_modulus, diff --git a/tfhe/src/shortint/server_key/mod.rs b/tfhe/src/shortint/server_key/mod.rs index 8571586a39..49963536da 100644 --- a/tfhe/src/shortint/server_key/mod.rs +++ b/tfhe/src/shortint/server_key/mod.rs @@ -43,16 +43,17 @@ use crate::core_crypto::prelude::ComputationBuffers; use crate::shortint::ciphertext::{Ciphertext, Degree, MaxDegree, MaxNoiseLevel, NoiseLevel}; use crate::shortint::client_key::ClientKey; use crate::shortint::engine::{ - fill_accumulator, fill_accumulator_no_encoding, fill_many_lut_accumulator, ShortintEngine, + fill_accumulator, fill_accumulator_no_encoding, fill_accumulator_with_encoding, + fill_many_lut_accumulator, ShortintEngine, }; use crate::shortint::parameters::{ CarryModulus, CiphertextConformanceParams, CiphertextModulus, MessageModulus, }; use crate::shortint::{EncryptionKeyChoice, PBSOrder}; -use ::tfhe_versionable::Versionize; use aligned_vec::ABox; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display, Formatter}; +use tfhe_versionable::Versionize; #[cfg(feature = "pbs-stats")] pub mod pbs_stats { @@ -1563,6 +1564,53 @@ where } } +/// Caller needs to ensure that the operation applied is coherent from an encoding perspective. +/// +/// For example: +/// +/// Input encoding has 2 bits and output encoding has 4 bits, applying the identity lut would map +/// the following: +/// +/// 0|00|xx -> 0|00|00 +/// 0|01|xx -> 0|00|01 +/// 0|10|xx -> 0|00|10 +/// 0|11|xx -> 0|00|11 +/// +/// The reason is the identity function is computed in the input space but the scaling is done in +/// the output space, as there are more bits in the output space, the delta is smaller hence the +/// apparent "division" happening. +#[allow(clippy::too_many_arguments)] +pub(crate) fn generate_lookup_table_with_encoding( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + input_message_modulus: MessageModulus, + input_carry_modulus: CarryModulus, + output_message_modulus: MessageModulus, + output_carry_modulus: CarryModulus, + f: F, +) -> LookupTableOwned +where + F: Fn(u64) -> u64, +{ + let mut acc = GlweCiphertext::new(0, glwe_size, polynomial_size, ciphertext_modulus); + let max_value = fill_accumulator_with_encoding( + &mut acc, + polynomial_size, + glwe_size, + input_message_modulus, + input_carry_modulus, + output_message_modulus, + output_carry_modulus, + f, + ); + + LookupTableOwned { + acc, + degree: Degree::new(max_value), + } +} + #[derive(Copy, Clone)] pub struct PBSConformanceParameters { pub in_lwe_dimension: LweDimension,