Skip to content

Commit

Permalink
refactor(tfhe)!: update key level order for better performance
Browse files Browse the repository at this point in the history
- use natural order for decomposition levels in bsk

co-authored-by: Agnes Leroy <[email protected]>
  • Loading branch information
IceTDrinker and agnesLeroy committed Nov 5, 2024
1 parent dda9388 commit 615ed3d
Show file tree
Hide file tree
Showing 45 changed files with 361 additions and 269 deletions.
6 changes: 2 additions & 4 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
level_count);
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);

for (int j = level_count - 1; j >= 0; j--) {
// Levels are stored in reverse order
for (int j = 0; j < level_count; j++) {
auto ksk_block =
get_ith_block(ksk, i, j, lwe_dimension_out, level_count);
Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
Expand Down Expand Up @@ -209,8 +208,7 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(

// block of key for current lwe coefficient (cur_input_lwe[i])
auto ksk_block = &fp_ksk[i * ksk_block_size];
for (int j = level_count - 1; j >= 0; j--) {
// Levels are stored in reverse order
for (int j = 0; j < level_count; j++) {
auto ksk_glwe = &ksk_block[j * glwe_size * polynomial_size];
// Iterate through each level and multiply by the ksk piece
auto ksk_glwe_chunk = &ksk_glwe[poly_id * coef_per_block];
Expand Down
12 changes: 6 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/pbs/bootstrapping_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ __device__ const T *get_ith_mask_kth_block(const T *ptr, int i, int k,
uint32_t level_count) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
level_count) +
level * polynomial_size / 2 * (glwe_dimension + 1) *
(glwe_dimension + 1) +
(level_count - level - 1) * polynomial_size / 2 *
(glwe_dimension + 1) * (glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1)];
}

Expand All @@ -35,8 +35,8 @@ __device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
int glwe_dimension, uint32_t level_count) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
level_count) +
level * polynomial_size / 2 * (glwe_dimension + 1) *
(glwe_dimension + 1) +
(level_count - level - 1) * polynomial_size / 2 *
(glwe_dimension + 1) * (glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1)];
}
template <typename T>
Expand All @@ -45,8 +45,8 @@ __device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
int glwe_dimension, uint32_t level_count) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
level_count) +
level * polynomial_size / 2 * (glwe_dimension + 1) *
(glwe_dimension + 1) +
(level_count - level - 1) * polynomial_size / 2 *
(glwe_dimension + 1) * (glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1) +
glwe_dimension * polynomial_size / 2];
}
Expand Down
22 changes: 13 additions & 9 deletions tfhe/src/core_crypto/algorithms/ggsw_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ pub fn encrypt_constant_ggsw_ciphertext<Scalar, NoiseDistribution, KeyCont, Outp
.expect("Failed to split generator into ggsw levels");

let decomp_base_log = output.decomposition_base_log();
let decomp_level_count = output.decomposition_level_count();
let ciphertext_modulus = output.ciphertext_modulus();

for (level_index, (mut level_matrix, mut generator)) in
for (output_index, (mut level_matrix, mut generator)) in
output.iter_mut().zip(gen_iter).enumerate()
{
let decomp_level = DecompositionLevel(level_index + 1);
let decomp_level = DecompositionLevel(decomp_level_count.0 - output_index);
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
Expand Down Expand Up @@ -269,11 +270,12 @@ pub fn par_encrypt_constant_ggsw_ciphertext<Scalar, NoiseDistribution, KeyCont,
.expect("Failed to split generator into ggsw levels");

let decomp_base_log = output.decomposition_base_log();
let decomp_level_count = output.decomposition_level_count();
let ciphertext_modulus = output.ciphertext_modulus();

output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
|(output_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(decomp_level_count.0 - output_index);
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
Expand Down Expand Up @@ -401,12 +403,13 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
.expect("Failed to split generator into ggsw levels");

let decomp_base_log = output.decomposition_base_log();
let decomp_level_count = output.decomposition_level_count();
let ciphertext_modulus = output.ciphertext_modulus();

for (level_index, (mut level_matrix, mut loop_generator)) in
for (output_index, (mut level_matrix, mut loop_generator)) in
output.iter_mut().zip(gen_iter).enumerate()
{
let decomp_level = DecompositionLevel(level_index + 1);
let decomp_level = DecompositionLevel(decomp_level_count.0 - output_index);
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
Expand Down Expand Up @@ -581,11 +584,12 @@ pub fn par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
.expect("Failed to split generator into ggsw levels");

let decomp_base_log = output.decomposition_base_log();
let decomp_level_count = output.decomposition_level_count();
let ciphertext_modulus = output.ciphertext_modulus();

output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
|(output_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(decomp_level_count.0 - output_index);
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
Expand Down Expand Up @@ -881,7 +885,7 @@ where
glwe_secret_key.glwe_dimension()
);

let level_matrix = ggsw_ciphertext.last().unwrap();
let level_matrix = ggsw_ciphertext.first().unwrap();
let level_matrix_as_glwe_list = level_matrix.as_glwe_list();
let last_row = level_matrix_as_glwe_list.last().unwrap();
let decomp_level = ggsw_ciphertext.decomposition_level_count();
Expand Down
13 changes: 5 additions & 8 deletions tfhe/src/core_crypto/algorithms/lwe_keyswitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ pub fn keyswitch_lwe_ciphertext_native_mod_compatible<Scalar, KSKCont, InputCont
{
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
output_lwe_ciphertext.as_mut(),
Expand Down Expand Up @@ -305,8 +304,7 @@ pub fn keyswitch_lwe_ciphertext_other_mod<Scalar, KSKCont, InputCont, OutputCont
{
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign_custom_modulus(
output_lwe_ciphertext.as_mut(),
Expand Down Expand Up @@ -438,8 +436,7 @@ pub fn keyswitch_lwe_ciphertext_with_scalar_change<
{
let decomposition_iter = input_decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
output_lwe_ciphertext.as_mut(),
Expand Down Expand Up @@ -802,7 +799,7 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible<
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
keyswitch_key_block.iter().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
buffer.as_mut(),
Expand Down Expand Up @@ -949,7 +946,7 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count_other_mod<
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
keyswitch_key_block.iter().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign_custom_modulus(
buffer.as_mut(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
// We fill the buffer with the powers of the key elements
for (level, message) in (1..=decomp_level_count.0)
.map(DecompositionLevel)
.rev()
.zip(decomposition_plaintexts_buffer.iter_mut())
{
// Here we take the decomposition term from the native torus, bring it to the torus we
Expand Down Expand Up @@ -234,6 +235,7 @@ pub fn generate_lwe_keyswitch_key_other_mod<
// We fill the buffer with the powers of the key elements
for (level, message) in (1..=decomp_level_count.0)
.map(DecompositionLevel)
.rev()
.zip(decomposition_plaintexts_buffer.iter_mut())
{
// Here we take the decomposition term from the native torus, bring it to the torus we
Expand Down Expand Up @@ -415,6 +417,7 @@ pub fn generate_seeded_lwe_keyswitch_key<
// We fill the buffer with the powers of the key elmements
for (level, message) in (1..=decomp_level_count.0)
.map(DecompositionLevel)
.rev()
.zip(decomposition_plaintexts_buffer.iter_mut())
{
// Here we take the decomposition term from the native torus, bring it to the torus we
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/algorithms/lwe_packing_keyswitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ pub fn keyswitch_lwe_ciphertext_into_glwe_ciphertext<Scalar, KeyCont, InputCont,
// Loop over the number of levels:
// We compute the multiplication of a ciphertext from the private functional
// keyswitching key with a piece of the decomposition and subtract it to the buffer
for (level_key_cipher, decomposed) in keyswitch_key_block.iter().rev().zip(decomp) {
for (level_key_cipher, decomposed) in keyswitch_key_block.iter().zip(decomp) {
slice_wrapping_sub_scalar_mul_assign(
output_glwe_ciphertext.as_mut(),
level_key_cipher.as_ref(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ pub fn generate_lwe_packing_keyswitch_key<
// We fill the buffer with the powers of the key elements
for (level, mut messages) in (1..=decomp_level_count.0)
.map(DecompositionLevel)
.rev()
.zip(decomposition_plaintexts_buffer.chunks_exact_mut(polynomial_size.0))
{
// Here we take the decomposition term from the native torus, bring it to the torus we
Expand Down Expand Up @@ -330,6 +331,7 @@ pub fn generate_seeded_lwe_packing_keyswitch_key<
// We fill the buffer with the powers of the key elements
for (level, mut messages) in (1..=decomp_level_count.0)
.map(DecompositionLevel)
.rev()
.zip(decomposition_plaintexts_buffer.chunks_exact_mut(polynomial_size.0))
{
// Here we take the decomposition term from the native torus, bring it to the torus we
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub fn private_functional_keyswitch_lwe_ciphertext_into_glwe_ciphertext<
// Loop over the number of levels:
// We compute the multiplication of a ciphertext from the private functional
// keyswitching key with a piece of the decomposition and subtract it to the buffer
for (level_key_cipher, decomposed) in keyswitch_key_block.iter().rev().zip(decomp) {
for (level_key_cipher, decomposed) in keyswitch_key_block.iter().zip(decomp) {
slice_wrapping_sub_scalar_mul_assign(
output_glwe_ciphertext.as_mut(),
level_key_cipher.as_ref(),
Expand Down Expand Up @@ -208,7 +208,7 @@ pub fn par_private_functional_keyswitch_lwe_ciphertext_into_glwe_ciphertext_with
// Loop over the number of levels:
// We compute the multiplication of a ciphertext from the private functional
// keyswitching key with a piece of the decomposition and subtract it to the buffer
for (level_key_cipher, decomposed) in keyswitch_key_block.iter().rev().zip(decomp) {
for (level_key_cipher, decomposed) in keyswitch_key_block.iter().zip(decomp) {
slice_wrapping_sub_scalar_mul_assign(
glwe_buffer.as_mut(),
level_key_cipher.as_ref(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pub fn generate_lwe_private_functional_packing_keyswitch_key<
// We fill the buffer with the powers of the key bits
for (level, mut message) in (1..=decomp_level_count.0)
.map(DecompositionLevel)
.rev()
.zip(messages.chunks_exact_mut(polynomial_size.0))
{
slice_wrapping_add_scalar_mul_assign(
Expand Down Expand Up @@ -219,6 +220,7 @@ pub fn par_generate_lwe_private_functional_packing_keyswitch_key<
// We fill the buffer with the powers of the key bits
for (level, mut message) in (1..=decomp_level_count.0)
.map(DecompositionLevel)
.rev()
.zip(messages.chunks_exact_mut(polynomial_size.0))
{
slice_wrapping_add_scalar_mul_assign(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ pub(crate) fn add_external_product_ntt64_assign<InputGlweCont>(
);

// We loop through the levels (we reverse to match the order of the decomposition iterator.)
ggsw.into_levels().rev().for_each(|ggsw_decomp_matrix| {
ggsw.into_levels().for_each(|ggsw_decomp_matrix| {
// We retrieve the decomposition of this level.
let (glwe_level, glwe_decomp_term, mut substack2) =
decomposition.collect_next_term(&mut substack1, align);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
use tfhe_versionable::deprecation::{Deprecable, Deprecated};
use tfhe_versionable::VersionsDispatch;

use crate::core_crypto::prelude::{Container, GgswCiphertext, UnsignedInteger};

impl<C: Container> Deprecable for GgswCiphertext<C>
where
C::Element: UnsignedInteger,
{
const TYPE_NAME: &'static str = "GgswCiphertext";
const MIN_SUPPORTED_APP_VERSION: &'static str = "TFHE-rs v0.10";
}

#[derive(VersionsDispatch)]
pub enum GgswCiphertextVersions<C: Container>
where
C::Element: UnsignedInteger,
{
V0(GgswCiphertext<C>),
V0(Deprecated<GgswCiphertext<C>>),
V1(GgswCiphertext<C>),
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
use tfhe_versionable::deprecation::{Deprecable, Deprecated};
use tfhe_versionable::VersionsDispatch;

use crate::core_crypto::prelude::{Container, GgswCiphertextList, UnsignedInteger};

impl<C: Container> Deprecable for GgswCiphertextList<C>
where
C::Element: UnsignedInteger,
{
const TYPE_NAME: &'static str = "GgswCiphertextList";
const MIN_SUPPORTED_APP_VERSION: &'static str = "TFHE-rs v0.10";
}

#[derive(VersionsDispatch)]
pub enum GgswCiphertextListVersions<C: Container>
where
C::Element: UnsignedInteger,
{
V0(GgswCiphertextList<C>),
V0(Deprecated<GgswCiphertextList<C>>),
V1(GgswCiphertextList<C>),
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
use tfhe_versionable::deprecation::{Deprecable, Deprecated};
use tfhe_versionable::VersionsDispatch;

use crate::core_crypto::prelude::{Container, LweBootstrapKey, UnsignedInteger};

impl<C: Container> Deprecable for LweBootstrapKey<C>
where
C::Element: UnsignedInteger,
{
const TYPE_NAME: &'static str = "LweBootstrapKey";
const MIN_SUPPORTED_APP_VERSION: &'static str = "TFHE-rs v0.10";
}

#[derive(VersionsDispatch)]
pub enum LweBootstrapKeyVersions<C: Container>
where
C::Element: UnsignedInteger,
{
V0(LweBootstrapKey<C>),
V0(Deprecated<LweBootstrapKey<C>>),
V1(LweBootstrapKey<C>),
}
Original file line number Diff line number Diff line change
@@ -1,57 +1,22 @@
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
use tfhe_versionable::deprecation::{Deprecable, Deprecated};
use tfhe_versionable::VersionsDispatch;

use crate::core_crypto::prelude::{
CiphertextModulus, Container, ContainerMut, ContiguousEntityContainerMut, DecompositionBaseLog,
DecompositionLevelCount, LweKeyswitchKey, LweSize, UnsignedInteger,
};
use crate::core_crypto::prelude::{Container, LweKeyswitchKey, UnsignedInteger};

#[derive(Version)]
pub struct LweKeyswitchKeyV0<C: Container>
impl<C: Container> Deprecable for LweKeyswitchKey<C>
where
C::Element: UnsignedInteger,
{
data: C,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
output_lwe_size: LweSize,
ciphertext_modulus: CiphertextModulus<C::Element>,
}

impl<Scalar: UnsignedInteger, C: ContainerMut<Element = Scalar>> Upgrade<LweKeyswitchKey<C>>
for LweKeyswitchKeyV0<C>
{
type Error = std::convert::Infallible;

fn upgrade(self) -> Result<LweKeyswitchKey<C>, Self::Error> {
let Self {
data,
decomp_base_log,
decomp_level_count,
output_lwe_size,
ciphertext_modulus,
} = self;
let mut new_ksk = LweKeyswitchKey::from_container(
data,
decomp_base_log,
decomp_level_count,
output_lwe_size,
ciphertext_modulus,
);

// Invert levels
for mut ksk_block in new_ksk.iter_mut() {
ksk_block.reverse();
}

Ok(new_ksk)
}
const TYPE_NAME: &'static str = "LweKeyswitchKey";
const MIN_SUPPORTED_APP_VERSION: &'static str = "TFHE-rs v0.10";
}

#[derive(VersionsDispatch)]
pub enum LweKeyswitchKeyVersions<C: Container>
where
C::Element: UnsignedInteger,
{
V0(LweKeyswitchKeyV0<C>),
V1(LweKeyswitchKey<C>),
V0(Deprecated<LweKeyswitchKey<C>>),
V1(Deprecated<LweKeyswitchKey<C>>),
V2(LweKeyswitchKey<C>),
}
Loading

0 comments on commit 615ed3d

Please sign in to comment.