From c0d98394fa63fce8a136924342fdf4b71fcc5681 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Mon, 29 Jul 2024 11:49:45 +0200 Subject: [PATCH] refactor(integer): add compression key types --- .../backward_compatibility/keys.rs | 63 ++++++++++++++++-- tfhe/src/high_level_api/keys/client.rs | 2 +- tfhe/src/high_level_api/keys/inner.rs | 17 +++-- tfhe/src/high_level_api/keys/server.rs | 4 +- .../list_compression.rs | 30 +++++++++ .../src/integer/backward_compatibility/mod.rs | 1 + .../ciphertext/compressed_ciphertext_list.rs | 8 ++- tfhe/src/integer/client_key/mod.rs | 16 +++-- tfhe/src/integer/client_key/radix.rs | 14 ++-- tfhe/src/integer/compression_keys.rs | 66 +++++++++++++++++++ tfhe/src/integer/gpu/client_key/radix.rs | 12 ++-- .../compressed_server_keys.rs | 18 +++-- .../gpu/list_compression/server_keys.rs | 8 +-- tfhe/src/integer/mod.rs | 1 + 14 files changed, 215 insertions(+), 45 deletions(-) create mode 100644 tfhe/src/integer/backward_compatibility/list_compression.rs create mode 100644 tfhe/src/integer/compression_keys.rs diff --git a/tfhe/src/high_level_api/backward_compatibility/keys.rs b/tfhe/src/high_level_api/backward_compatibility/keys.rs index 4cefe99bcb..ae9b2e9dee 100644 --- a/tfhe/src/high_level_api/backward_compatibility/keys.rs +++ b/tfhe/src/high_level_api/backward_compatibility/keys.rs @@ -211,14 +211,35 @@ impl Upgrade for IntegerClientKeyV0 { } } -impl Upgrade for IntegerClientKeyV1 { +impl Upgrade for IntegerClientKeyV1 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(IntegerClientKeyV2 { + key: self.key, + dedicated_compact_private_key: self.dedicated_compact_private_key, + compression_key: self.compression_key, + }) + } +} + +#[derive(Version)] +pub(crate) struct IntegerClientKeyV2 { + pub(crate) key: crate::integer::ClientKey, + pub(crate) dedicated_compact_private_key: Option, + pub(crate) compression_key: Option, +} + +impl Upgrade for IntegerClientKeyV2 { type Error = Infallible; fn upgrade(self) -> Result { Ok(IntegerClientKey { key: self.key, dedicated_compact_private_key: self.dedicated_compact_private_key, - compression_key: self.compression_key, + compression_key: self + .compression_key + .map(|key| crate::integer::compression_keys::CompressionPrivateKeys { key }), }) } } @@ -228,7 +249,8 @@ impl Upgrade for IntegerClientKeyV1 { pub(crate) enum IntegerClientKeyVersions { V0(IntegerClientKeyV0), V1(IntegerClientKeyV1), - V2(IntegerClientKey), + V2(IntegerClientKeyV2), + V3(IntegerClientKey), } #[derive(Version)] @@ -261,11 +283,11 @@ impl Upgrade for IntegerServerKeyV0 { } } -impl Upgrade for IntegerServerKeyV1 { +impl Upgrade for IntegerServerKeyV1 { type Error = Infallible; - fn upgrade(self) -> Result { - Ok(IntegerServerKey { + fn upgrade(self) -> Result { + Ok(IntegerServerKeyV2 { key: self.key, cpk_key_switching_key_material: self.cpk_key_switching_key_material, compression_key: self.compression_key, @@ -274,11 +296,38 @@ impl Upgrade for IntegerServerKeyV1 { } } +#[derive(Version)] +pub struct IntegerServerKeyV2 { + pub(crate) key: crate::integer::ServerKey, + pub(crate) cpk_key_switching_key_material: + Option, + pub(crate) compression_key: Option, + pub(crate) decompression_key: Option, +} + +impl Upgrade for IntegerServerKeyV2 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(IntegerServerKey { + key: self.key, + cpk_key_switching_key_material: self.cpk_key_switching_key_material, + compression_key: self + .compression_key + .map(|key| crate::integer::compression_keys::CompressionKey { key }), + decompression_key: self + .decompression_key + .map(|key| crate::integer::compression_keys::DecompressionKey { key }), + }) + } +} + #[derive(VersionsDispatch)] pub enum IntegerServerKeyVersions { V0(IntegerServerKeyV0), V1(IntegerServerKeyV1), - V2(IntegerServerKey), + V2(IntegerServerKeyV2), + V3(IntegerServerKey), } #[derive(Version)] diff --git a/tfhe/src/high_level_api/keys/client.rs b/tfhe/src/high_level_api/keys/client.rs index 8ed6d68695..60c97cc50a 100644 --- a/tfhe/src/high_level_api/keys/client.rs +++ b/tfhe/src/high_level_api/keys/client.rs @@ -6,8 +6,8 @@ use super::{CompressedServerKey, ServerKey}; use crate::high_level_api::backward_compatibility::keys::ClientKeyVersions; use crate::high_level_api::config::Config; use crate::high_level_api::keys::{CompactPrivateKey, IntegerClientKey}; +use crate::integer::compression_keys::CompressionPrivateKeys; use crate::prelude::Tagged; -use crate::shortint::list_compression::CompressionPrivateKeys; use crate::shortint::MessageModulus; use crate::Tag; use concrete_csprng::seeders::Seed; diff --git a/tfhe/src/high_level_api/keys/inner.rs b/tfhe/src/high_level_api/keys/inner.rs index 09f1977d4c..9763559160 100644 --- a/tfhe/src/high_level_api/keys/inner.rs +++ b/tfhe/src/high_level_api/keys/inner.rs @@ -1,12 +1,12 @@ use crate::core_crypto::commons::generators::DeterministicSeeder; use crate::core_crypto::prelude::ActivatedRandomGenerator; use crate::high_level_api::backward_compatibility::keys::*; -use crate::integer::public_key::CompactPublicKey; -use crate::integer::CompressedCompactPublicKey; -use crate::shortint::list_compression::{ +use crate::integer::compression_keys::{ CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, CompressionPrivateKeys, DecompressionKey, }; +use crate::integer::public_key::CompactPublicKey; +use crate::integer::CompressedCompactPublicKey; use crate::shortint::parameters::list_compression::CompressionParameters; use crate::shortint::MessageModulus; use crate::Error; @@ -103,11 +103,11 @@ impl IntegerClientKey { let cks = crate::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder) .new_client_key(config.block_parameters.into()); + let key = crate::integer::ClientKey::from(cks); + let compression_key = config .compression_parameters - .map(|params| cks.new_compression_private_key(params)); - - let key = crate::integer::ClientKey::from(cks); + .map(|params| key.new_compression_private_key(params)); let dedicated_compact_private_key = config .dedicated_compact_public_key_parameters @@ -193,7 +193,7 @@ impl From for IntegerClientKey { let compression_key = config .compression_parameters - .map(|params| key.key.new_compression_private_key(params)); + .map(|params| key.new_compression_private_key(params)); Self { key, @@ -225,7 +225,7 @@ impl IntegerServerKey { || (None, None), |a| { let (compression_key, decompression_key) = - cks.key.new_compression_decompression_keys(a); + cks.new_compression_decompression_keys(a); (Some(compression_key), Some(decompression_key)) }, ); @@ -312,7 +312,6 @@ impl IntegerCompressedServerKey { .as_ref() .map_or((None, None), |compression_private_key| { let (compression_keys, decompression_keys) = client_key - .key .key .new_compressed_compression_decompression_keys(compression_private_key); diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index e911ca9677..340485de79 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -5,10 +5,10 @@ use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKey #[cfg(feature = "gpu")] use crate::core_crypto::gpu::{synchronize_devices, CudaStreams}; use crate::high_level_api::keys::{IntegerCompressedServerKey, IntegerServerKey}; -use crate::prelude::Tagged; -use crate::shortint::list_compression::{ +use crate::integer::compression_keys::{ CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, DecompressionKey, }; +use crate::prelude::Tagged; use crate::shortint::MessageModulus; use crate::Tag; use std::sync::Arc; diff --git a/tfhe/src/integer/backward_compatibility/list_compression.rs b/tfhe/src/integer/backward_compatibility/list_compression.rs new file mode 100644 index 0000000000..8ddc63a5d1 --- /dev/null +++ b/tfhe/src/integer/backward_compatibility/list_compression.rs @@ -0,0 +1,30 @@ +use crate::integer::compression_keys::{ + CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, CompressionPrivateKeys, + DecompressionKey, +}; +use tfhe_versionable::VersionsDispatch; + +#[derive(VersionsDispatch)] +pub enum CompressionKeyVersions { + V0(CompressionKey), +} + +#[derive(VersionsDispatch)] +pub enum DecompressionKeyVersions { + V0(DecompressionKey), +} + +#[derive(VersionsDispatch)] +pub enum CompressedCompressionKeyVersions { + V0(CompressedCompressionKey), +} + +#[derive(VersionsDispatch)] +pub enum CompressedDecompressionKeyVersions { + V0(CompressedDecompressionKey), +} + +#[derive(VersionsDispatch)] +pub enum CompressionPrivateKeysVersions { + V0(CompressionPrivateKeys), +} diff --git a/tfhe/src/integer/backward_compatibility/mod.rs b/tfhe/src/integer/backward_compatibility/mod.rs index f5d0f6ba14..fba2e6ebee 100644 --- a/tfhe/src/integer/backward_compatibility/mod.rs +++ b/tfhe/src/integer/backward_compatibility/mod.rs @@ -3,6 +3,7 @@ pub mod ciphertext; pub mod client_key; pub mod key_switching_key; +pub mod list_compression; pub mod public_key; pub mod server_key; pub mod wopbs; diff --git a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs index 72b81e7926..482db04c4f 100644 --- a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs @@ -1,8 +1,8 @@ use super::{DataKind, Expandable, RadixCiphertext, SignedRadixCiphertext}; use crate::integer::backward_compatibility::ciphertext::CompressedCiphertextListVersions; +use crate::integer::compression_keys::{CompressionKey, DecompressionKey}; use crate::integer::BooleanBlock; use crate::shortint::ciphertext::CompressedCiphertextList as ShortintCompressedCiphertextList; -use crate::shortint::list_compression::{CompressionKey, DecompressionKey}; use crate::shortint::Ciphertext; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -84,7 +84,9 @@ impl CompressedCiphertextListBuilder { } pub fn build(&self, comp_key: &CompressionKey) -> CompressedCiphertextList { - let packed_list = comp_key.compress_ciphertexts_into_list(&self.ciphertexts); + let packed_list = comp_key + .key + .compress_ciphertexts_into_list(&self.ciphertexts); CompressedCiphertextList { packed_list, @@ -128,7 +130,7 @@ impl CompressedCiphertextList { Some(( (start_block_index..end_block_index) .into_par_iter() - .map(|i| decomp_key.unpack(&self.packed_list, i).unwrap()) + .map(|i| decomp_key.key.unpack(&self.packed_list, i).unwrap()) .collect(), current_info, )) diff --git a/tfhe/src/integer/client_key/mod.rs b/tfhe/src/integer/client_key/mod.rs index 9d5ba26245..9752ab3f27 100644 --- a/tfhe/src/integer/client_key/mod.rs +++ b/tfhe/src/integer/client_key/mod.rs @@ -20,9 +20,9 @@ use crate::integer::block_decomposition::BlockRecomposer; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::{CompressedCrtCiphertext, CrtCiphertext}; use crate::integer::client_key::utils::i_crt; +use crate::integer::compression_keys::{CompressionKey, CompressionPrivateKeys, DecompressionKey}; use crate::integer::encryption::{encrypt_crt, encrypt_words_radix_impl}; use crate::shortint::ciphertext::Degree; -use crate::shortint::list_compression::{CompressionKey, CompressionPrivateKeys, DecompressionKey}; use crate::shortint::parameters::{CompressionParameters, MessageModulus}; use crate::shortint::{ Ciphertext, ClientKey as ShortintClientKey, ShortintParameterSet as ShortintParameters, @@ -720,14 +720,22 @@ impl ClientKey { &self, params: CompressionParameters, ) -> CompressionPrivateKeys { - self.key.new_compression_private_key(params) + CompressionPrivateKeys { + key: self.key.new_compression_private_key(params), + } } pub fn new_compression_decompression_keys( &self, private_compression_key: &CompressionPrivateKeys, ) -> (CompressionKey, DecompressionKey) { - self.key - .new_compression_decompression_keys(private_compression_key) + let (comp_key, decomp_key) = self + .key + .new_compression_decompression_keys(&private_compression_key.key); + + ( + CompressionKey { key: comp_key }, + DecompressionKey { key: decomp_key }, + ) } } diff --git a/tfhe/src/integer/client_key/radix.rs b/tfhe/src/integer/client_key/radix.rs index baa1f0a4f9..0ad5777ede 100644 --- a/tfhe/src/integer/client_key/radix.rs +++ b/tfhe/src/integer/client_key/radix.rs @@ -5,10 +5,10 @@ use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric}; use crate::integer::backward_compatibility::client_key::RadixClientKeyVersions; use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom}; use crate::integer::ciphertext::{RadixCiphertext, SignedRadixCiphertext}; -use crate::integer::BooleanBlock; -use crate::shortint::list_compression::{ +use crate::integer::compression_keys::{ CompressedCompressionKey, CompressedDecompressionKey, CompressionPrivateKeys, }; +use crate::integer::BooleanBlock; use crate::shortint::{Ciphertext as ShortintCiphertext, PBSParameters as ShortintParameters}; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; @@ -139,9 +139,15 @@ impl RadixClientKey { &self, private_compression_key: &CompressionPrivateKeys, ) -> (CompressedCompressionKey, CompressedDecompressionKey) { - self.key + let (comp_key, decomp_key) = self .key - .new_compressed_compression_decompression_keys(private_compression_key) + .key + .new_compressed_compression_decompression_keys(&private_compression_key.key); + + ( + CompressedCompressionKey { key: comp_key }, + CompressedDecompressionKey { key: decomp_key }, + ) } } diff --git a/tfhe/src/integer/compression_keys.rs b/tfhe/src/integer/compression_keys.rs new file mode 100644 index 0000000000..3051578b17 --- /dev/null +++ b/tfhe/src/integer/compression_keys.rs @@ -0,0 +1,66 @@ +use super::ClientKey; +use crate::integer::backward_compatibility::list_compression::*; +use serde::{Deserialize, Serialize}; +use tfhe_versionable::Versionize; + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(CompressionPrivateKeysVersions)] +pub struct CompressionPrivateKeys { + pub(crate) key: crate::shortint::list_compression::CompressionPrivateKeys, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(CompressionKeyVersions)] +pub struct CompressionKey { + pub(crate) key: crate::shortint::list_compression::CompressionKey, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(DecompressionKeyVersions)] +pub struct DecompressionKey { + pub(crate) key: crate::shortint::list_compression::DecompressionKey, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(CompressedCompressionKeyVersions)] +pub struct CompressedCompressionKey { + pub(crate) key: crate::shortint::list_compression::CompressedCompressionKey, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Versionize)] +#[versionize(CompressedDecompressionKeyVersions)] +pub struct CompressedDecompressionKey { + pub(crate) key: crate::shortint::list_compression::CompressedDecompressionKey, +} + +impl CompressedCompressionKey { + pub fn decompress(&self) -> CompressionKey { + CompressionKey { + key: self.key.decompress(), + } + } +} + +impl CompressedDecompressionKey { + pub fn decompress(&self) -> DecompressionKey { + DecompressionKey { + key: self.key.decompress(), + } + } +} + +impl ClientKey { + pub fn new_compressed_compression_decompression_keys( + &self, + private_compression_key: &CompressionPrivateKeys, + ) -> (CompressedCompressionKey, CompressedDecompressionKey) { + let (comp_key, decomp_key) = self + .key + .new_compressed_compression_decompression_keys(&private_compression_key.key); + + ( + CompressedCompressionKey { key: comp_key }, + CompressedDecompressionKey { key: decomp_key }, + ) + } +} diff --git a/tfhe/src/integer/gpu/client_key/radix.rs b/tfhe/src/integer/gpu/client_key/radix.rs index 96a5ef6310..fea28f0063 100644 --- a/tfhe/src/integer/gpu/client_key/radix.rs +++ b/tfhe/src/integer/gpu/client_key/radix.rs @@ -4,13 +4,13 @@ use crate::core_crypto::prelude::{ allocate_and_generate_new_lwe_packing_keyswitch_key, par_generate_lwe_bootstrap_key, LweBootstrapKey, }; +use crate::integer::compression_keys::{CompressionKey, CompressionPrivateKeys}; use crate::integer::gpu::list_compression::server_keys::{ CudaCompressionKey, CudaDecompressionKey, }; use crate::integer::gpu::server_key::CudaBootstrappingKey; use crate::integer::RadixClientKey; use crate::shortint::engine::ShortintEngine; -use crate::shortint::list_compression::{CompressionKey, CompressionPrivateKeys}; use crate::shortint::{ClassicPBSParameters, EncryptionKeyChoice, PBSParameters}; impl RadixClientKey { @@ -19,6 +19,8 @@ impl RadixClientKey { private_compression_key: &CompressionPrivateKeys, streams: &CudaStreams, ) -> (CudaCompressionKey, CudaDecompressionKey) { + let private_compression_key = &private_compression_key.key; + let cks_params: ClassicPBSParameters = match self.parameters() { PBSParameters::PBS(a) => a, PBSParameters::MultiBitPBS(_) => { @@ -47,9 +49,11 @@ impl RadixClientKey { }); let glwe_compression_key = CompressionKey { - packing_key_switching_key, - lwe_per_glwe: params.lwe_per_glwe, - storage_log_modulus: private_compression_key.params.storage_log_modulus, + key: crate::shortint::list_compression::CompressionKey { + packing_key_switching_key, + lwe_per_glwe: params.lwe_per_glwe, + storage_log_modulus: private_compression_key.params.storage_log_modulus, + }, }; let cuda_compression_key = diff --git a/tfhe/src/integer/gpu/list_compression/compressed_server_keys.rs b/tfhe/src/integer/gpu/list_compression/compressed_server_keys.rs index 0a89c6efba..450ffb9dbf 100644 --- a/tfhe/src/integer/gpu/list_compression/compressed_server_keys.rs +++ b/tfhe/src/integer/gpu/list_compression/compressed_server_keys.rs @@ -1,12 +1,12 @@ use crate::core_crypto::gpu::lwe_bootstrap_key::CudaLweBootstrapKey; use crate::core_crypto::gpu::CudaStreams; +use crate::integer::compression_keys::{ + CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, +}; use crate::integer::gpu::list_compression::server_keys::{ CudaCompressionKey, CudaDecompressionKey, }; use crate::integer::gpu::server_key::CudaBootstrappingKey; -use crate::shortint::list_compression::{ - CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, -}; use crate::shortint::PBSParameters; impl CompressedDecompressionKey { @@ -16,6 +16,7 @@ impl CompressedDecompressionKey { streams: &CudaStreams, ) -> CudaDecompressionKey { let h_bootstrap_key = self + .key .blind_rotate_key .as_view() .par_decompress_into_lwe_bootstrap_key(); @@ -27,7 +28,7 @@ impl CompressedDecompressionKey { CudaDecompressionKey { blind_rotate_key, - lwe_per_glwe: self.lwe_per_glwe, + lwe_per_glwe: self.key.lwe_per_glwe, parameters, } } @@ -36,14 +37,17 @@ impl CompressedDecompressionKey { impl CompressedCompressionKey { pub fn decompress_to_cuda(&self, streams: &CudaStreams) -> CudaCompressionKey { let packing_key_switching_key = self + .key .packing_key_switching_key .as_view() .decompress_into_lwe_packing_keyswitch_key(); let glwe_compression_key = CompressionKey { - packing_key_switching_key, - lwe_per_glwe: self.lwe_per_glwe, - storage_log_modulus: self.storage_log_modulus, + key: crate::shortint::list_compression::CompressionKey { + packing_key_switching_key, + lwe_per_glwe: self.key.lwe_per_glwe, + storage_log_modulus: self.key.storage_log_modulus, + }, }; CudaCompressionKey::from_compression_key(&glwe_compression_key, streams) diff --git a/tfhe/src/integer/gpu/list_compression/server_keys.rs b/tfhe/src/integer/gpu/list_compression/server_keys.rs index bd7d2e12d5..e0623564ad 100644 --- a/tfhe/src/integer/gpu/list_compression/server_keys.rs +++ b/tfhe/src/integer/gpu/list_compression/server_keys.rs @@ -4,13 +4,13 @@ use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList; use crate::core_crypto::gpu::vec::CudaVec; use crate::core_crypto::gpu::CudaStreams; use crate::core_crypto::prelude::{CiphertextModulusLog, GlweCiphertextCount, LweCiphertextCount}; +use crate::integer::compression_keys::CompressionKey; use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo}; use crate::integer::gpu::ciphertext::CudaRadixCiphertext; use crate::integer::gpu::server_key::CudaBootstrappingKey; use crate::integer::gpu::{ compress_integer_radix_async, cuda_memcpy_async_gpu_to_gpu, decompress_integer_radix_async, }; -use crate::shortint::list_compression::CompressionKey; use crate::shortint::PBSParameters; use itertools::Itertools; @@ -38,11 +38,11 @@ impl CudaCompressionKey { pub fn from_compression_key(compression_key: &CompressionKey, streams: &CudaStreams) -> Self { Self { packing_key_switching_key: CudaLwePackingKeyswitchKey::from_lwe_packing_keyswitch_key( - &compression_key.packing_key_switching_key, + &compression_key.key.packing_key_switching_key, streams, ), - lwe_per_glwe: compression_key.lwe_per_glwe, - storage_log_modulus: compression_key.storage_log_modulus, + lwe_per_glwe: compression_key.key.lwe_per_glwe, + storage_log_modulus: compression_key.key.storage_log_modulus, } } diff --git a/tfhe/src/integer/mod.rs b/tfhe/src/integer/mod.rs index dc8b6ffb1a..edad81f108 100755 --- a/tfhe/src/integer/mod.rs +++ b/tfhe/src/integer/mod.rs @@ -55,6 +55,7 @@ pub mod backward_compatibility; pub mod bigint; pub mod ciphertext; pub mod client_key; +pub mod compression_keys; pub mod key_switching_key; #[cfg(any(test, feature = "internal-keycache"))] pub mod keycache;