From da069ad2d492797af7bf559670d710edc3164e05 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Mon, 2 Sep 2024 11:06:34 +0200 Subject: [PATCH] feat(hlapi): add tag system Tag The `Tag` allows to store bytes alongside of entities (keys, and ciphertext) the main purpose of this system is to `tag` / identify ciphertext with their keys. * When encrypted, a ciphertext gets the tag of the key used to encrypt it. * Ciphertexts resulting from operations (add, sub, etc.) get the tag from the ServerKey used * PublicKey gets its tag from the ClientKey that was used to create it * ServerKey gets its tag from the ClientKey that was used to create it User can change the tag of any entities at any point. BREAKING CHANGE: Many of the into_raw_parts and from_raw_parts changed to accommodate the addition of the `tag`` --- tfhe/src/high_level_api/array.rs | 4 +- .../backward_compatibility/booleans.rs | 72 ++- .../backward_compatibility/compact_list.rs | 22 +- .../compressed_ciphertext_list.rs | 22 +- .../backward_compatibility/integers.rs | 133 +++++- .../backward_compatibility/keys.rs | 139 +++++- .../backward_compatibility/mod.rs | 1 + .../backward_compatibility/tag.rs | 7 + tfhe/src/high_level_api/booleans/base.rs | 227 +++++---- .../src/high_level_api/booleans/compressed.rs | 73 ++- tfhe/src/high_level_api/booleans/encrypt.rs | 20 +- tfhe/src/high_level_api/booleans/mod.rs | 3 +- tfhe/src/high_level_api/compact_list.rs | 131 +++-- .../compressed_ciphertext_list.rs | 43 +- tfhe/src/high_level_api/global_state.rs | 17 +- tfhe/src/high_level_api/integers/oprf.rs | 43 +- .../high_level_api/integers/signed/base.rs | 87 ++-- .../integers/signed/compressed.rs | 70 ++- .../high_level_api/integers/signed/encrypt.rs | 17 +- .../src/high_level_api/integers/signed/mod.rs | 2 +- .../src/high_level_api/integers/signed/ops.rs | 95 ++-- .../integers/signed/overflowing_ops.rs | 55 ++- .../integers/signed/scalar_ops.rs | 37 +- .../high_level_api/integers/unsigned/base.rs | 87 ++-- .../integers/unsigned/compressed.rs | 68 ++- .../integers/unsigned/encrypt.rs | 10 +- .../high_level_api/integers/unsigned/mod.rs | 2 +- .../high_level_api/integers/unsigned/ops.rs | 131 ++--- .../integers/unsigned/overflowing_ops.rs | 43 +- .../integers/unsigned/scalar_ops.rs | 73 +-- tfhe/src/high_level_api/keys/client.rs | 21 +- .../high_level_api/keys/key_switching_key.rs | 12 +- tfhe/src/high_level_api/keys/public.rs | 88 +++- tfhe/src/high_level_api/keys/server.rs | 76 ++- tfhe/src/high_level_api/mod.rs | 3 + tfhe/src/high_level_api/prelude.rs | 2 +- tfhe/src/high_level_api/tag.rs | 449 ++++++++++++++++++ tfhe/src/high_level_api/tests/mod.rs | 2 + .../high_level_api/tests/tags_on_entities.rs | 324 +++++++++++++ tfhe/src/high_level_api/traits.rs | 8 +- tfhe/src/high_level_api/utils.rs | 25 +- .../backward_compatibility/high_level_api.rs | 2 +- 42 files changed, 2190 insertions(+), 556 deletions(-) create mode 100644 tfhe/src/high_level_api/backward_compatibility/tag.rs create mode 100644 tfhe/src/high_level_api/tag.rs create mode 100644 tfhe/src/high_level_api/tests/tags_on_entities.rs diff --git a/tfhe/src/high_level_api/array.rs b/tfhe/src/high_level_api/array.rs index 6bfc125ed9..9045c552be 100644 --- a/tfhe/src/high_level_api/array.rs +++ b/tfhe/src/high_level_api/array.rs @@ -16,7 +16,7 @@ pub fn fhe_uint_array_eq(lhs: &[FheUint], rhs: &[FheUint] let result = cpu_keys .pbs_key() .all_eq_slices_parallelized(&tmp_lhs, &tmp_rhs); - FheBool::new(result) + FheBool::new(result, cpu_keys.tag.clone()) }) } @@ -37,6 +37,6 @@ pub fn fhe_uint_array_contains_sub_slice( let result = cpu_keys .pbs_key() .contains_sub_slice_parallelized(&tmp_lhs, &tmp_pattern); - FheBool::new(result) + FheBool::new(result, cpu_keys.tag.clone()) }) } diff --git a/tfhe/src/high_level_api/backward_compatibility/booleans.rs b/tfhe/src/high_level_api/backward_compatibility/booleans.rs index 1b961b835d..2605b89659 100644 --- a/tfhe/src/high_level_api/backward_compatibility/booleans.rs +++ b/tfhe/src/high_level_api/backward_compatibility/booleans.rs @@ -1,11 +1,16 @@ #![allow(deprecated)] use serde::{Deserialize, Serialize}; -use tfhe_versionable::{Versionize, VersionsDispatch}; +use tfhe_versionable::{Upgrade, Version, Versionize, VersionsDispatch}; -use crate::high_level_api::booleans::InnerBooleanVersionOwned; +use crate::high_level_api::booleans::{ + InnerBoolean, InnerBooleanVersionOwned, InnerCompressedFheBool, +}; use crate::integer::ciphertext::{CompactCiphertextList, DataKind}; -use crate::{CompactCiphertextList as HlCompactCiphertextList, CompressedFheBool, Error, FheBool}; +use crate::{ + CompactCiphertextList as HlCompactCiphertextList, CompressedFheBool, Error, FheBool, Tag, +}; +use std::convert::Infallible; // Manual impl #[derive(Serialize, Deserialize)] @@ -14,9 +19,26 @@ pub(crate) enum InnerBooleanVersionedOwned { V0(InnerBooleanVersionOwned), } +#[derive(Version)] +pub struct FheBoolV0 { + pub(in crate::high_level_api) ciphertext: InnerBoolean, +} + +impl Upgrade for FheBoolV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(FheBool { + ciphertext: self.ciphertext, + tag: Tag::default(), + }) + } +} + #[derive(VersionsDispatch)] pub enum FheBoolVersions { - V0(FheBool), + V0(FheBoolV0), + V1(FheBool), } #[derive(VersionsDispatch)] @@ -24,9 +46,33 @@ pub enum CompactFheBoolVersions { V0(CompactFheBool), } +#[derive(VersionsDispatch)] +pub enum InnerCompressedFheBoolVersions { + V0(InnerCompressedFheBool), +} + +// Before V1 where we added the Tag, the CompressedFheBool +// was simply the inner enum +type CompressedFheBoolV0 = InnerCompressedFheBool; + +impl Upgrade for CompressedFheBoolV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(CompressedFheBool { + inner: match self { + Self::Seeded(s) => Self::Seeded(s), + Self::ModulusSwitched(m) => Self::ModulusSwitched(m), + }, + tag: Tag::default(), + }) + } +} + #[derive(VersionsDispatch)] pub enum CompressedFheBoolVersions { - V0(CompressedFheBool), + V0(CompressedFheBoolV0), + V1(CompressedFheBool), } #[derive(VersionsDispatch)] @@ -56,14 +102,18 @@ impl CompactFheBool { .iter_mut() .for_each(|info| *info = DataKind::Boolean); - let hl_list = HlCompactCiphertextList(self.list); + let hl_list = HlCompactCiphertextList { + inner: self.list, + tag: Tag::default(), + }; let list = hl_list.expand()?; let block = list + .inner .get::(0) .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; - let mut ciphertext = FheBool::new(block); + let mut ciphertext = FheBool::new(block, Tag::default()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); Ok(ciphertext) } @@ -86,17 +136,21 @@ impl CompactFheBoolList { .iter_mut() .for_each(|info| *info = DataKind::Boolean); - let hl_list = HlCompactCiphertextList(self.list); + let hl_list = HlCompactCiphertextList { + inner: self.list, + tag: Tag::default(), + }; let list = hl_list.expand()?; let len = list.len(); (0..len) .map(|idx| { let block = list + .inner .get::(idx) .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; - let mut ciphertext = FheBool::new(block); + let mut ciphertext = FheBool::new(block, Tag::default()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); Ok(ciphertext) }) diff --git a/tfhe/src/high_level_api/backward_compatibility/compact_list.rs b/tfhe/src/high_level_api/backward_compatibility/compact_list.rs index d531b094e3..b6f3bdf4f4 100644 --- a/tfhe/src/high_level_api/backward_compatibility/compact_list.rs +++ b/tfhe/src/high_level_api/backward_compatibility/compact_list.rs @@ -1,8 +1,24 @@ -use tfhe_versionable::VersionsDispatch; +use std::convert::Infallible; +use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; -use crate::CompactCiphertextList; +use crate::{CompactCiphertextList, Tag}; + +#[derive(Version)] +pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList); + +impl Upgrade for CompactCiphertextListV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(CompactCiphertextList { + inner: self.0, + tag: Tag::default(), + }) + } +} #[derive(VersionsDispatch)] pub enum CompactCiphertextListVersions { - V0(CompactCiphertextList), + V0(CompactCiphertextListV0), + V1(CompactCiphertextList), } diff --git a/tfhe/src/high_level_api/backward_compatibility/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/backward_compatibility/compressed_ciphertext_list.rs index f1087ffba0..0e81efc270 100644 --- a/tfhe/src/high_level_api/backward_compatibility/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/backward_compatibility/compressed_ciphertext_list.rs @@ -1,8 +1,24 @@ -use tfhe_versionable::VersionsDispatch; +use std::convert::Infallible; +use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; -use crate::CompressedCiphertextList; +use crate::{CompressedCiphertextList, Tag}; + +#[derive(Version)] +pub struct CompressedCiphertextListV0(crate::integer::ciphertext::CompressedCiphertextList); + +impl Upgrade for CompressedCiphertextListV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(CompressedCiphertextList { + inner: self.0, + tag: Tag::default(), + }) + } +} #[derive(VersionsDispatch)] pub enum CompressedCiphertextListVersions { - V0(CompressedCiphertextList), + V0(CompressedCiphertextListV0), + V1(CompressedCiphertextList), } diff --git a/tfhe/src/high_level_api/backward_compatibility/integers.rs b/tfhe/src/high_level_api/backward_compatibility/integers.rs index 6d8f92f064..de05f2b0a7 100644 --- a/tfhe/src/high_level_api/backward_compatibility/integers.rs +++ b/tfhe/src/high_level_api/backward_compatibility/integers.rs @@ -18,9 +18,12 @@ use crate::integer::ciphertext::{ }; use crate::shortint::ciphertext::CompressedModulusSwitchedCiphertext; use crate::shortint::{Ciphertext, ServerKey}; -use crate::{CompactCiphertextList as HlCompactCiphertextList, Error}; +use crate::{CompactCiphertextList as HlCompactCiphertextList, Error, Tag}; use serde::{Deserialize, Serialize}; +use self::signed::RadixCiphertext as SignedRadixCiphertext; +use self::unsigned::RadixCiphertext as UnsignedRadixCiphertext; + // Manual impl #[derive(Serialize, Deserialize)] #[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))] @@ -67,11 +70,11 @@ impl Upgrade for CompressedSignedRadixCiphertex let blocks = ct .blocks .par_iter() - .map(|a| old_sk_decompress(&sk.key.key, a)) + .map(|a| old_sk_decompress(&sk.key.key.key, a)) .collect(); let radix = BaseSignedRadixCiphertext { blocks }; - sk.key + sk.pbs_key() .switch_modulus_and_compress_signed_parallelized(&radix) }); Ok(CompressedSignedRadixCiphertext::ModulusSwitched(upgraded)) @@ -105,11 +108,12 @@ impl Upgrade for CompressedRadixCiphertextV0 { let blocks = ct .blocks .par_iter() - .map(|a| old_sk_decompress(&sk.key.key, a)) + .map(|a| old_sk_decompress(&sk.key.key.key, a)) .collect(); let radix = BaseRadixCiphertext { blocks }; - sk.key.switch_modulus_and_compress_parallelized(&radix) + sk.pbs_key() + .switch_modulus_and_compress_parallelized(&radix) }); Ok(CompressedRadixCiphertext::ModulusSwitched(upgraded)) } @@ -123,9 +127,28 @@ pub enum CompressedRadixCiphertextVersions { V1(CompressedRadixCiphertext), } +#[derive(Version)] +pub struct FheIntV0 { + pub(in crate::high_level_api) ciphertext: SignedRadixCiphertext, + pub(in crate::high_level_api) id: Id, +} + +impl Upgrade> for FheIntV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(FheInt { + ciphertext: self.ciphertext, + id: self.id, + tag: Tag::default(), + }) + } +} + #[derive(VersionsDispatch)] pub enum FheIntVersions { - V0(FheInt), + V0(FheIntV0), + V1(FheInt), } #[derive(VersionsDispatch)] @@ -133,9 +156,31 @@ pub enum CompactFheIntVersions { V0(CompactFheInt), } +#[derive(Version)] +pub struct CompressedFheIntV0 +where + Id: FheIntId, +{ + pub(in crate::high_level_api) ciphertext: CompressedSignedRadixCiphertext, + pub(in crate::high_level_api) id: Id, +} + +impl Upgrade> for CompressedFheIntV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(CompressedFheInt { + ciphertext: self.ciphertext, + id: self.id, + tag: Tag::default(), + }) + } +} + #[derive(VersionsDispatch)] pub enum CompressedFheIntVersions { - V0(CompressedFheInt), + V0(CompressedFheIntV0), + V1(CompressedFheInt), } #[derive(VersionsDispatch)] @@ -143,9 +188,28 @@ pub enum CompactFheIntListVersions { V0(CompactFheIntList), } +#[derive(Version)] +pub struct FheUintV0 { + pub(in crate::high_level_api) ciphertext: UnsignedRadixCiphertext, + pub(in crate::high_level_api) id: Id, +} + +impl Upgrade> for FheUintV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(FheUint { + ciphertext: self.ciphertext, + id: self.id, + tag: Tag::default(), + }) + } +} + #[derive(VersionsDispatch)] pub enum FheUintVersions { - V0(FheUint), + V0(FheUintV0), + V1(FheUint), } #[derive(VersionsDispatch)] @@ -153,9 +217,27 @@ pub enum CompactFheUintVersions { V0(CompactFheUint), } +#[derive(Version)] +pub struct CompressedFheUintV0 +where + Id: FheUintId, +{ + pub(in crate::high_level_api) ciphertext: CompressedRadixCiphertext, + pub(in crate::high_level_api) id: Id, +} + +impl Upgrade> for CompressedFheUintV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(CompressedFheUint::new(self.ciphertext, Tag::default())) + } +} + #[derive(VersionsDispatch)] pub enum CompressedFheUintVersions { - V0(CompressedFheUint), + V0(CompressedFheUintV0), + V1(CompressedFheUint), } #[derive(VersionsDispatch)] @@ -186,13 +268,17 @@ where .info .iter_mut() .for_each(|info| *info = DataKind::Signed(info.num_blocks())); - let hl_list = HlCompactCiphertextList(self.list); + let hl_list = HlCompactCiphertextList { + inner: self.list, + tag: Tag::default(), + }; let list = hl_list.expand()?; let ct = list + .inner .get::(0) .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; - Ok(FheInt::new(ct)) + Ok(FheInt::new(ct, Tag::default())) } } @@ -217,7 +303,10 @@ where .iter_mut() .for_each(|info| *info = DataKind::Signed(info.num_blocks())); - let hl_list = HlCompactCiphertextList(self.list); + let hl_list = HlCompactCiphertextList { + inner: self.list, + tag: Tag::default(), + }; let list = hl_list.expand()?; let len = list.len(); @@ -225,9 +314,10 @@ where (0..len) .map(|idx| { let ct = list + .inner .get::(idx) .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; - Ok(FheInt::new(ct)) + Ok(FheInt::new(ct, Tag::default())) }) .collect::, _>>() } @@ -254,16 +344,19 @@ where .iter_mut() .for_each(|info| *info = DataKind::Unsigned(info.num_blocks())); - let hl_list = HlCompactCiphertextList(self.list); + let hl_list = HlCompactCiphertextList { + inner: self.list, + tag: Tag::default(), + }; let list = hl_list.expand()?; let ct = list + .inner .get::(0) .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; - Ok(FheUint::new(ct)) + Ok(FheUint::new(ct, Tag::default())) } } - #[derive(Clone, Versionize)] #[versionize(CompactFheUintListVersions)] #[deprecated(since = "0.7.0", note = "Use CompactCiphertextList instead")] @@ -285,7 +378,10 @@ where .iter_mut() .for_each(|info| *info = DataKind::Unsigned(info.num_blocks())); - let hl_list = HlCompactCiphertextList(self.list); + let hl_list = HlCompactCiphertextList { + inner: self.list, + tag: Tag::default(), + }; let list = hl_list.expand()?; let len = list.len(); @@ -293,9 +389,10 @@ where (0..len) .map(|idx| { let ct = list + .inner .get::(idx) .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; - Ok(FheUint::new(ct)) + Ok(FheUint::new(ct, Tag::default())) }) .collect::, _>>() } diff --git a/tfhe/src/high_level_api/backward_compatibility/keys.rs b/tfhe/src/high_level_api/backward_compatibility/keys.rs index dd358fd63d..4cefe99bcb 100644 --- a/tfhe/src/high_level_api/backward_compatibility/keys.rs +++ b/tfhe/src/high_level_api/backward_compatibility/keys.rs @@ -1,12 +1,31 @@ use crate::high_level_api::keys::*; use crate::shortint::list_compression::{CompressionKey, CompressionPrivateKeys, DecompressionKey}; +use crate::Tag; use std::convert::Infallible; use std::sync::Arc; use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; #[derive(VersionsDispatch)] pub enum ClientKeyVersions { - V0(ClientKey), + V0(ClientKeyV0), + V1(ClientKey), +} + +#[derive(Version)] +pub struct ClientKeyV0 { + pub(crate) key: IntegerClientKey, +} + +impl Upgrade for ClientKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + let Self { key } = self; + Ok(ClientKey { + key, + tag: Tag::default(), + }) + } } // This type was previously versioned using a manual implementation with a conversion @@ -16,12 +35,28 @@ pub struct ServerKeyV0 { pub(crate) integer_key: Arc, } -impl Upgrade for ServerKeyV0 { +impl Upgrade for ServerKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(ServerKeyV1 { + key: self.integer_key, + }) + } +} + +#[derive(Version)] +pub struct ServerKeyV1 { + pub(crate) key: Arc, +} + +impl Upgrade for ServerKeyV1 { type Error = Infallible; fn upgrade(self) -> Result { Ok(ServerKey { - key: self.integer_key, + key: self.key, + tag: Tag::default(), }) } } @@ -29,32 +64,118 @@ impl Upgrade for ServerKeyV0 { #[derive(VersionsDispatch)] pub enum ServerKeyVersions { V0(ServerKeyV0), - V1(ServerKey), + V1(ServerKeyV1), + V2(ServerKey), +} + +#[derive(Version)] +pub struct CompressedServerKeyV0 { + pub(crate) integer_key: IntegerCompressedServerKey, +} + +impl Upgrade for CompressedServerKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(CompressedServerKey { + integer_key: self.integer_key, + tag: Tag::default(), + }) + } } #[derive(VersionsDispatch)] pub enum CompressedServerKeyVersions { - V0(CompressedServerKey), + V0(CompressedServerKeyV0), + V1(CompressedServerKey), +} + +#[derive(Version)] +pub struct PublicKeyV0 { + pub(in crate::high_level_api) key: crate::integer::PublicKey, +} + +impl Upgrade for PublicKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(PublicKey { + key: self.key, + tag: Tag::default(), + }) + } } #[derive(VersionsDispatch)] pub enum PublicKeyVersions { - V0(PublicKey), + V0(PublicKeyV0), + V1(PublicKey), +} + +#[derive(Version)] +pub struct CompactPublicKeyV0 { + pub(in crate::high_level_api) key: IntegerCompactPublicKey, +} + +impl Upgrade for CompactPublicKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(CompactPublicKey { + key: self.key, + tag: Tag::default(), + }) + } } #[derive(VersionsDispatch)] pub enum CompactPublicKeyVersions { - V0(CompactPublicKey), + V0(CompactPublicKeyV0), + V1(CompactPublicKey), +} + +#[derive(Version)] +pub struct CompressedPublicKeyV0 { + pub(in crate::high_level_api) key: crate::integer::CompressedPublicKey, +} + +impl Upgrade for CompressedPublicKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(CompressedPublicKey { + key: self.key, + tag: Tag::default(), + }) + } } #[derive(VersionsDispatch)] pub enum CompressedPublicKeyVersions { - V0(CompressedPublicKey), + V0(CompressedPublicKeyV0), + V1(CompressedPublicKey), +} + +#[derive(Version)] +pub struct CompressedCompactPublicKeyV0 { + pub(in crate::high_level_api) key: IntegerCompressedCompactPublicKey, +} + +impl Upgrade for CompressedCompactPublicKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + Ok(CompressedCompactPublicKey { + key: self.key, + tag: Tag::default(), + }) + } } #[derive(VersionsDispatch)] pub enum CompressedCompactPublicKeyVersions { - V0(CompressedCompactPublicKey), + V0(CompressedCompactPublicKeyV0), + V1(CompressedCompactPublicKey), } #[derive(VersionsDispatch)] diff --git a/tfhe/src/high_level_api/backward_compatibility/mod.rs b/tfhe/src/high_level_api/backward_compatibility/mod.rs index 97c49f4839..1907dfabf7 100644 --- a/tfhe/src/high_level_api/backward_compatibility/mod.rs +++ b/tfhe/src/high_level_api/backward_compatibility/mod.rs @@ -6,3 +6,4 @@ pub mod compressed_ciphertext_list; pub mod config; pub mod integers; pub mod keys; +pub mod tag; diff --git a/tfhe/src/high_level_api/backward_compatibility/tag.rs b/tfhe/src/high_level_api/backward_compatibility/tag.rs new file mode 100644 index 0000000000..728f830d4e --- /dev/null +++ b/tfhe/src/high_level_api/backward_compatibility/tag.rs @@ -0,0 +1,7 @@ +use crate::high_level_api::tag::Tag; +use tfhe_versionable::VersionsDispatch; + +#[derive(VersionsDispatch)] +pub enum TagVersions { + V0(Tag), +} diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs index 0723a0efe2..f66c653ceb 100644 --- a/tfhe/src/high_level_api/booleans/base.rs +++ b/tfhe/src/high_level_api/booleans/base.rs @@ -6,7 +6,7 @@ use crate::high_level_api::global_state; use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId}; use crate::high_level_api::keys::InternalServerKey; -use crate::high_level_api::traits::{FheEq, IfThenElse}; +use crate::high_level_api::traits::{FheEq, IfThenElse, Tagged}; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; #[cfg(feature = "gpu")] @@ -17,7 +17,7 @@ use crate::named::Named; use crate::shortint::ciphertext::NotTrivialCiphertextError; use crate::shortint::parameters::CiphertextConformanceParams; use crate::shortint::PBSParameters; -use crate::{Device, ServerKey}; +use crate::{Device, ServerKey, Tag}; use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign}; @@ -50,6 +50,7 @@ use tfhe_versionable::Versionize; #[versionize(FheBoolVersions)] pub struct FheBool { pub(in crate::high_level_api) ciphertext: InnerBoolean, + pub(crate) tag: Tag, } impl Named for FheBool { @@ -81,7 +82,7 @@ impl ParameterSetConformant for FheBool { type ParameterSet = FheBoolConformanceParams; fn is_conformant(&self, params: &FheBoolConformanceParams) -> bool { - let Self { ciphertext } = self; + let Self { ciphertext, tag: _ } = self; let BooleanBlock(block) = &*ciphertext.on_cpu(); @@ -90,9 +91,10 @@ impl ParameterSetConformant for FheBool { } impl FheBool { - pub(in crate::high_level_api) fn new>(ciphertext: T) -> Self { + pub(in crate::high_level_api) fn new>(ciphertext: T, tag: Tag) -> Self { Self { ciphertext: ciphertext.into(), + tag, } } @@ -196,7 +198,7 @@ where &*ct_then.ciphertext.on_cpu(), &*ct_else.ciphertext.on_cpu(), ); - FheUint::new(inner) + FheUint::new(inner, cpu_sks.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -207,12 +209,22 @@ where streams, ); - FheUint::new(inner) + FheUint::new(inner, cuda_key.tag.clone()) }), }) } } +impl Tagged for FheBool { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + impl IfThenElse> for FheBool { /// Conditional selection. /// @@ -222,38 +234,40 @@ impl IfThenElse> for FheBool { /// - if `self` is false, the output will have the value of `ct_else` fn if_then_else(&self, ct_then: &FheInt, ct_else: &FheInt) -> FheInt { let ct_condition = self; - let new_ct = global_state::with_internal_keys(|key| match key { - InternalServerKey::Cpu(key) => key.pbs_key().if_then_else_parallelized( - &ct_condition.ciphertext.on_cpu(), - &*ct_then.ciphertext.on_cpu(), - &*ct_else.ciphertext.on_cpu(), - ), + global_state::with_internal_keys(|key| match key { + InternalServerKey::Cpu(key) => { + let new_ct = key.pbs_key().if_then_else_parallelized( + &ct_condition.ciphertext.on_cpu(), + &*ct_then.ciphertext.on_cpu(), + &*ct_else.ciphertext.on_cpu(), + ); + FheInt::new(new_ct, key.tag.clone()) + } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { panic!("Cuda devices do not support signed integers") } - }); - - FheInt::new(new_ct) + }) } } impl IfThenElse for FheBool { fn if_then_else(&self, ct_then: &Self, ct_else: &Self) -> Self { let ct_condition = self; - let new_ct = global_state::with_internal_keys(|key| match key { - InternalServerKey::Cpu(key) => key.pbs_key().if_then_else_parallelized( - &ct_condition.ciphertext.on_cpu(), - &*ct_then.ciphertext.on_cpu(), - &*ct_else.ciphertext.on_cpu(), - ), + global_state::with_internal_keys(|key| match key { + InternalServerKey::Cpu(key) => { + let new_ct = key.pbs_key().if_then_else_parallelized( + &ct_condition.ciphertext.on_cpu(), + &*ct_then.ciphertext.on_cpu(), + &*ct_else.ciphertext.on_cpu(), + ); + Self::new(new_ct, key.tag.clone()) + } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { panic!("Cuda devices do not support signed integers") } - }); - - Self::new(new_ct) + }) } } @@ -281,13 +295,14 @@ where /// assert!(!decrypted); /// ``` fn eq(&self, other: B) -> Self { - let ciphertext = global_state::with_internal_keys(|key| match key { + global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner = key.pbs_key().key.equal( self.ciphertext.on_cpu().as_ref(), other.borrow().ciphertext.on_cpu().as_ref(), ); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner)) + let ciphertext = InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner)); + Self::new(ciphertext, key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -296,10 +311,10 @@ where &other.borrow().ciphertext.on_gpu(), streams, ); - InnerBoolean::Cuda(inner) + let ciphertext = InnerBoolean::Cuda(inner); + Self::new(ciphertext, cuda_key.tag.clone()) }), - }); - Self::new(ciphertext) + }) } /// Test for difference between two [FheBool] @@ -322,13 +337,14 @@ where /// assert_eq!(decrypted, true != false); /// ``` fn ne(&self, other: B) -> Self { - let ciphertext = global_state::with_internal_keys(|key| match key { + global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner = key.pbs_key().key.not_equal( self.ciphertext.on_cpu().as_ref(), other.borrow().ciphertext.on_cpu().as_ref(), ); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner)) + let ciphertext = InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner)); + Self::new(ciphertext, key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -337,10 +353,10 @@ where &other.borrow().ciphertext.on_gpu(), streams, ); - InnerBoolean::Cuda(inner) + let ciphertext = InnerBoolean::Cuda(inner); + Self::new(ciphertext, cuda_key.tag.clone()) }), - }); - Self::new(ciphertext) + }) } } @@ -364,13 +380,16 @@ impl FheEq for FheBool { /// assert!(!decrypted); /// ``` fn eq(&self, other: bool) -> FheBool { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner = key .pbs_key() .key .scalar_equal(self.ciphertext.on_cpu().as_ref(), u8::from(other)); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner)) + ( + InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner)), + key.tag.clone(), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -378,10 +397,10 @@ impl FheEq for FheBool { cuda_key .key .scalar_eq(&*self.ciphertext.on_gpu(), u8::from(other), streams); - InnerBoolean::Cuda(inner) + (InnerBoolean::Cuda(inner), cuda_key.tag.clone()) }), }); - Self::new(ciphertext) + Self::new(ciphertext, tag) } /// Test for equality between a [FheBool] and a [bool] @@ -403,13 +422,16 @@ impl FheEq for FheBool { /// assert_eq!(decrypted, true != false); /// ``` fn ne(&self, other: bool) -> FheBool { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner = key .pbs_key() .key .scalar_not_equal(self.ciphertext.on_cpu().as_ref(), u8::from(other)); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner)) + ( + InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner)), + key.tag.clone(), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -417,10 +439,10 @@ impl FheEq for FheBool { cuda_key .key .scalar_ne(&*self.ciphertext.on_gpu(), u8::from(other), streams); - InnerBoolean::Cuda(inner) + (InnerBoolean::Cuda(inner), cuda_key.tag.clone()) }), }); - Self::new(ciphertext) + Self::new(ciphertext, tag) } } @@ -477,12 +499,12 @@ where /// assert!(result); /// ``` fn bitand(self, rhs: B) -> Self::Output { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner_ct = key .pbs_key() .boolean_bitand(&self.ciphertext.on_cpu(), &rhs.borrow().ciphertext.on_cpu()); - InnerBoolean::Cpu(inner_ct) + (InnerBoolean::Cpu(inner_ct), key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -491,12 +513,16 @@ where &rhs.borrow().ciphertext.on_gpu(), streams, ); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( - inner_ct.ciphertext, - )) + + ( + InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( + inner_ct.ciphertext, + )), + cuda_key.tag.clone(), + ) }), }); - FheBool::new(ciphertext) + FheBool::new(ciphertext, tag) } } @@ -554,13 +580,16 @@ where /// assert_eq!(result, true | false); /// ``` fn bitor(self, rhs: B) -> Self::Output { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner_ct = key.pbs_key().key.bitor( self.ciphertext.on_cpu().as_ref(), rhs.borrow().ciphertext.on_cpu().as_ref(), ); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)) + ( + InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)), + key.tag.clone(), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -569,12 +598,15 @@ where &rhs.borrow().ciphertext.on_gpu(), streams, ); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( - inner_ct.ciphertext, - )) + ( + InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( + inner_ct.ciphertext, + )), + cuda_key.tag.clone(), + ) }), }); - FheBool::new(ciphertext) + FheBool::new(ciphertext, tag) } } @@ -632,13 +664,16 @@ where /// assert!(!result); /// ``` fn bitxor(self, rhs: B) -> Self::Output { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner_ct = key.pbs_key().key.bitxor( self.ciphertext.on_cpu().as_ref(), rhs.borrow().ciphertext.on_cpu().as_ref(), ); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)) + ( + InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)), + key.tag.clone(), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -647,12 +682,15 @@ where &rhs.borrow().ciphertext.on_gpu(), streams, ); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( - inner_ct.ciphertext, - )) + ( + InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( + inner_ct.ciphertext, + )), + cuda_key.tag.clone(), + ) }), }); - FheBool::new(ciphertext) + FheBool::new(ciphertext, tag) } } @@ -702,13 +740,16 @@ impl BitAnd for &FheBool { /// assert_eq!(decrypted, true & false); /// ``` fn bitand(self, rhs: bool) -> Self::Output { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner_ct = key .pbs_key() .key .scalar_bitand(self.ciphertext.on_cpu().as_ref(), u8::from(rhs)); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)) + ( + InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)), + key.tag.clone(), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -716,12 +757,15 @@ impl BitAnd for &FheBool { cuda_key .key .scalar_bitand(&*self.ciphertext.on_gpu(), u8::from(rhs), streams); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( - inner_ct.ciphertext, - )) + ( + InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( + inner_ct.ciphertext, + )), + cuda_key.tag.clone(), + ) }), }); - FheBool::new(ciphertext) + FheBool::new(ciphertext, tag) } } @@ -771,13 +815,16 @@ impl BitOr for &FheBool { /// assert_eq!(decrypted, true | false); /// ``` fn bitor(self, rhs: bool) -> Self::Output { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner_ct = key .pbs_key() .key .scalar_bitor(self.ciphertext.on_cpu().as_ref(), u8::from(rhs)); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)) + ( + InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)), + key.tag.clone(), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -785,12 +832,15 @@ impl BitOr for &FheBool { cuda_key .key .scalar_bitor(&*self.ciphertext.on_gpu(), u8::from(rhs), streams); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( - inner_ct.ciphertext, - )) + ( + InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( + inner_ct.ciphertext, + )), + cuda_key.tag.clone(), + ) }), }); - FheBool::new(ciphertext) + FheBool::new(ciphertext, tag) } } @@ -840,13 +890,16 @@ impl BitXor for &FheBool { /// assert_eq!(decrypted, true ^ false); /// ``` fn bitxor(self, rhs: bool) -> Self::Output { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner_ct = key .pbs_key() .key .scalar_bitxor(self.ciphertext.on_cpu().as_ref(), u8::from(rhs)); - InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)) + ( + InnerBoolean::Cpu(BooleanBlock::new_unchecked(inner_ct)), + key.tag.clone(), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -854,12 +907,15 @@ impl BitXor for &FheBool { cuda_key .key .scalar_bitxor(&*self.ciphertext.on_gpu(), u8::from(rhs), streams); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( - inner_ct.ciphertext, - )) + ( + InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( + inner_ct.ciphertext, + )), + cuda_key.tag.clone(), + ) }), }); - FheBool::new(ciphertext) + FheBool::new(ciphertext, tag) } } @@ -1303,21 +1359,24 @@ impl std::ops::Not for &FheBool { /// assert!(!result); /// ``` fn not(self) -> Self::Output { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { let inner = key.pbs_key().boolean_bitnot(&self.ciphertext.on_cpu()); - InnerBoolean::Cpu(inner) + (InnerBoolean::Cpu(inner), key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key .key .scalar_bitxor(&*self.ciphertext.on_gpu(), 1, streams); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( - inner.ciphertext, - )) + ( + InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( + inner.ciphertext, + )), + cuda_key.tag.clone(), + ) }), }); - FheBool::new(ciphertext) + FheBool::new(ciphertext, tag) } } diff --git a/tfhe/src/high_level_api/booleans/compressed.rs b/tfhe/src/high_level_api/booleans/compressed.rs index 95d9793b33..8f51b3a62f 100644 --- a/tfhe/src/high_level_api/booleans/compressed.rs +++ b/tfhe/src/high_level_api/booleans/compressed.rs @@ -1,15 +1,25 @@ -use crate::backward_compatibility::booleans::CompressedFheBoolVersions; +use crate::backward_compatibility::booleans::{ + CompressedFheBoolVersions, InnerCompressedFheBoolVersions, +}; use crate::conformance::ParameterSetConformant; use crate::high_level_api::global_state::with_cpu_internal_keys; +use crate::high_level_api::traits::Tagged; use crate::integer::BooleanBlock; use crate::named::Named; use crate::prelude::FheTryEncrypt; use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, Degree}; use crate::shortint::CompressedCiphertext; -use crate::{ClientKey, FheBool, FheBoolConformanceParams}; +use crate::{ClientKey, FheBool, FheBoolConformanceParams, Tag}; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; +#[derive(Clone, Serialize, Deserialize, Versionize)] +#[versionize(InnerCompressedFheBoolVersions)] +pub enum InnerCompressedFheBool { + Seeded(CompressedCiphertext), + ModulusSwitched(CompressedModulusSwitchedCiphertext), +} + /// Compressed [FheBool] /// /// Meant to save in storage space / transfer. @@ -33,26 +43,40 @@ use tfhe_versionable::Versionize; /// ``` #[derive(Clone, Serialize, Deserialize, Versionize)] #[versionize(CompressedFheBoolVersions)] -pub enum CompressedFheBool { - Seeded(CompressedCiphertext), - ModulusSwitched(CompressedModulusSwitchedCiphertext), +pub struct CompressedFheBool { + pub(in crate::high_level_api) inner: InnerCompressedFheBool, + pub(crate) tag: Tag, +} + +impl Tagged for CompressedFheBool { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } } impl CompressedFheBool { - pub(in crate::high_level_api) fn new(ciphertext: CompressedCiphertext) -> Self { - Self::Seeded(ciphertext) + pub(in crate::high_level_api) fn new(ciphertext: CompressedCiphertext, tag: Tag) -> Self { + Self { + inner: InnerCompressedFheBool::Seeded(ciphertext), + tag, + } } /// Decompresses itself into a [FheBool] /// /// See [CompressedFheBool] example. pub fn decompress(&self) -> FheBool { - let mut ciphertext = FheBool::new(BooleanBlock::new_unchecked(match self { - Self::Seeded(seeded) => seeded.decompress(), - Self::ModulusSwitched(modulus_switched) => { - with_cpu_internal_keys(|sk| sk.key.key.decompress(modulus_switched)) + let ciphertext = BooleanBlock::new_unchecked(match &self.inner { + InnerCompressedFheBool::Seeded(seeded) => seeded.decompress(), + InnerCompressedFheBool::ModulusSwitched(modulus_switched) => { + with_cpu_internal_keys(|sk| sk.pbs_key().key.decompress(modulus_switched)) } - })); + }); + let mut ciphertext = FheBool::new(ciphertext, self.tag.clone()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); @@ -67,7 +91,7 @@ impl FheTryEncrypt for CompressedFheBool { fn try_encrypt(value: bool, key: &ClientKey) -> Result { let mut ciphertext = key.key.key.key.encrypt_compressed(u64::from(value)); ciphertext.degree = Degree::new(1); - Ok(Self::new(ciphertext)) + Ok(Self::new(ciphertext, key.tag.clone())) } } @@ -75,9 +99,9 @@ impl ParameterSetConformant for CompressedFheBool { type ParameterSet = FheBoolConformanceParams; fn is_conformant(&self, params: &FheBoolConformanceParams) -> bool { - match self { - Self::Seeded(seeded) => seeded.is_conformant(¶ms.0), - Self::ModulusSwitched(ct) => ct.is_conformant(¶ms.0), + match &self.inner { + InnerCompressedFheBool::Seeded(seeded) => seeded.is_conformant(¶ms.0), + InnerCompressedFheBool::ModulusSwitched(ct) => ct.is_conformant(¶ms.0), } } } @@ -88,10 +112,17 @@ impl Named for CompressedFheBool { impl FheBool { pub fn compress(&self) -> CompressedFheBool { - CompressedFheBool::ModulusSwitched(with_cpu_internal_keys(|sk| { - sk.key - .key - .switch_modulus_and_compress(&self.ciphertext.on_cpu().0) - })) + with_cpu_internal_keys(|sk| { + let inner = InnerCompressedFheBool::ModulusSwitched( + sk.pbs_key() + .key + .switch_modulus_and_compress(&self.ciphertext.on_cpu().0), + ); + + CompressedFheBool { + inner, + tag: sk.tag.clone(), + } + }) } } diff --git a/tfhe/src/high_level_api/booleans/encrypt.rs b/tfhe/src/high_level_api/booleans/encrypt.rs index e3ba6fb326..2cffe180aa 100644 --- a/tfhe/src/high_level_api/booleans/encrypt.rs +++ b/tfhe/src/high_level_api/booleans/encrypt.rs @@ -16,7 +16,7 @@ impl FheTryEncrypt for FheBool { fn try_encrypt(value: bool, key: &ClientKey) -> Result { let integer_client_key = &key.key.key; - let mut ciphertext = Self::new(integer_client_key.encrypt_bool(value)); + let mut ciphertext = Self::new(integer_client_key.encrypt_bool(value), key.tag.clone()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); Ok(ciphertext) } @@ -57,8 +57,7 @@ impl FheTryEncrypt for FheBool { type Error = crate::Error; fn try_encrypt(value: bool, key: &CompressedPublicKey) -> Result { - let key = &key.key; - let mut ciphertext = Self::new(key.encrypt_bool(value)); + let mut ciphertext = Self::new(key.key.encrypt_bool(value), key.tag.clone()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); Ok(ciphertext) } @@ -68,8 +67,7 @@ impl FheTryEncrypt for FheBool { type Error = crate::Error; fn try_encrypt(value: bool, key: &PublicKey) -> Result { - let key = &key.key; - let mut ciphertext = Self::new(key.encrypt_bool(value)); + let mut ciphertext = Self::new(key.key.encrypt_bool(value), key.tag.clone()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); Ok(ciphertext) } @@ -86,9 +84,10 @@ impl FheTryTrivialEncrypt for FheBool { type Error = crate::Error; fn try_encrypt_trivial(value: bool) -> Result { - let ciphertext = global_state::with_internal_keys(|key| match key { + let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { - InnerBoolean::Cpu(key.pbs_key().create_trivial_boolean_block(value)) + let ct = InnerBoolean::Cpu(key.pbs_key().create_trivial_boolean_block(value)); + (ct, key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -96,11 +95,12 @@ impl FheTryTrivialEncrypt for FheBool { cuda_key .key .create_trivial_radix(u64::from(value), 1, streams); - InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( + let ct = InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( inner.into_inner(), - )) + )); + (ct, cuda_key.tag.clone()) }), }); - Ok(Self::new(ciphertext)) + Ok(Self::new(ciphertext, tag)) } } diff --git a/tfhe/src/high_level_api/booleans/mod.rs b/tfhe/src/high_level_api/booleans/mod.rs index 37b301ec38..56a18eaefa 100644 --- a/tfhe/src/high_level_api/booleans/mod.rs +++ b/tfhe/src/high_level_api/booleans/mod.rs @@ -1,7 +1,8 @@ pub use base::{FheBool, FheBoolConformanceParams}; pub use compressed::CompressedFheBool; -pub(in crate::high_level_api) use inner::InnerBooleanVersionOwned; +pub(in crate::high_level_api) use compressed::InnerCompressedFheBool; +pub(in crate::high_level_api) use inner::{InnerBoolean, InnerBooleanVersionOwned}; mod base; mod compressed; diff --git a/tfhe/src/high_level_api/compact_list.rs b/tfhe/src/high_level_api/compact_list.rs index 7c13aaa204..7a91b425a1 100644 --- a/tfhe/src/high_level_api/compact_list.rs +++ b/tfhe/src/high_level_api/compact_list.rs @@ -6,6 +6,7 @@ use crate::core_crypto::commons::math::random::{Deserialize, Serialize}; use crate::core_crypto::prelude::Numeric; use crate::high_level_api::global_state; use crate::high_level_api::keys::InternalServerKey; +use crate::high_level_api::traits::Tagged; use crate::integer::ciphertext::{Compactable, DataKind, Expandable}; use crate::integer::encryption::KnowsMessageModulus; use crate::integer::parameters::{ @@ -14,9 +15,12 @@ use crate::integer::parameters::{ }; use crate::named::Named; use crate::shortint::MessageModulus; +#[cfg(feature = "zk-pok")] +pub use zk::ProvenCompactCiphertextList; + #[cfg(feature = "zk-pok")] use crate::zk::{CompactPkePublicParams, ZkComputeLoad}; -use crate::CompactPublicKey; +use crate::{CompactPublicKey, Tag}; impl crate::FheTypes { fn from_data_kind(data_kind: DataKind, message_modulus: MessageModulus) -> Option { @@ -71,7 +75,10 @@ impl crate::FheTypes { #[derive(Clone, Serialize, Deserialize, Versionize)] #[versionize(CompactCiphertextListVersions)] -pub struct CompactCiphertextList(pub(crate) crate::integer::ciphertext::CompactCiphertextList); +pub struct CompactCiphertextList { + pub(crate) inner: crate::integer::ciphertext::CompactCiphertextList, + pub(crate) tag: Tag, +} impl Named for CompactCiphertextList { const NAME: &'static str = "high_level_api::CompactCiphertextList"; @@ -83,7 +90,7 @@ impl CompactCiphertextList { } pub fn len(&self) -> usize { - self.0.len() + self.inner.len() } pub fn is_empty(&self) -> bool { @@ -91,8 +98,8 @@ impl CompactCiphertextList { } pub fn get_kind_of(&self, index: usize) -> Option { - self.0.get_kind_of(index).and_then(|data_kind| { - crate::FheTypes::from_data_kind(data_kind, self.0.ct_list.message_modulus) + self.inner.get_kind_of(index).and_then(|data_kind| { + crate::FheTypes::from_data_kind(data_kind, self.inner.ct_list.message_modulus) }) } @@ -100,36 +107,40 @@ impl CompactCiphertextList { &self, sks: &crate::ServerKey, ) -> crate::Result { - self.0 + self.inner .expand( IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(sks.key.pbs_key()), IntegerCompactCiphertextListCastingMode::NoCasting, ) - .map(|inner| CompactCiphertextListExpander { inner }) + .map(|inner| CompactCiphertextListExpander { + inner, + tag: self.tag.clone(), + }) } pub fn expand(&self) -> crate::Result { // For WASM - if !self.0.is_packed() && !self.0.needs_casting() { - // No ServerKey required, shortcircuit to avoid the global state call + if !self.inner.is_packed() && !self.inner.needs_casting() { + // No ServerKey required, short-circuit to avoid the global state call return Ok(CompactCiphertextListExpander { - inner: self.0.expand( + inner: self.inner.expand( IntegerCompactCiphertextListUnpackingMode::NoUnpacking, IntegerCompactCiphertextListCastingMode::NoCasting, )?, + tag: self.tag.clone(), }); } global_state::try_with_internal_keys(|maybe_keys| match maybe_keys { None => Err(crate::high_level_api::errors::UninitializedServerKey.into()), Some(InternalServerKey::Cpu(cpu_key)) => { - let unpacking_mode = if self.0.is_packed() { + let unpacking_mode = if self.inner.is_packed() { IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(cpu_key.pbs_key()) } else { IntegerCompactCiphertextListUnpackingMode::NoUnpacking }; - let casting_mode = if self.0.needs_casting() { + let casting_mode = if self.inner.needs_casting() { IntegerCompactCiphertextListCastingMode::CastIfNecessary( cpu_key.cpk_casting_key().ok_or_else(|| { crate::Error::new( @@ -143,42 +154,70 @@ impl CompactCiphertextList { IntegerCompactCiphertextListCastingMode::NoCasting }; - self.0 + self.inner .expand(unpacking_mode, casting_mode) - .map(|inner| CompactCiphertextListExpander { inner }) + .map(|inner| CompactCiphertextListExpander { + inner, + tag: self.tag.clone(), + }) } #[cfg(feature = "gpu")] Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())), }) } } + +impl Tagged for CompactCiphertextList { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + impl ParameterSetConformant for CompactCiphertextList { type ParameterSet = CompactCiphertextListConformanceParams; fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { - let Self(list) = self; + let Self { inner, tag: _ } = self; - list.is_conformant(parameter_set) + inner.is_conformant(parameter_set) } } #[cfg(feature = "zk-pok")] -#[derive(Clone, Serialize, Deserialize)] -pub struct ProvenCompactCiphertextList(crate::integer::ciphertext::ProvenCompactCiphertextList); +mod zk { + use super::*; -#[cfg(feature = "zk-pok")] + #[derive(Clone, Serialize, Deserialize)] + pub struct ProvenCompactCiphertextList { + pub(crate) inner: crate::integer::ciphertext::ProvenCompactCiphertextList, + pub(crate) tag: Tag, + } +} + +impl Tagged for ProvenCompactCiphertextList { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} impl Named for ProvenCompactCiphertextList { const NAME: &'static str = "high_level_api::ProvenCompactCiphertextList"; } -#[cfg(feature = "zk-pok")] impl ProvenCompactCiphertextList { pub fn builder(pk: &CompactPublicKey) -> CompactCiphertextListBuilder { CompactCiphertextListBuilder::new(pk) } pub fn len(&self) -> usize { - self.0.len() + self.inner.len() } pub fn is_empty(&self) -> bool { @@ -186,8 +225,8 @@ impl ProvenCompactCiphertextList { } pub fn get_kind_of(&self, index: usize) -> Option { - self.0.get_kind_of(index).and_then(|data_kind| { - crate::FheTypes::from_data_kind(data_kind, self.0.ct_list.message_modulus()) + self.inner.get_kind_of(index).and_then(|data_kind| { + crate::FheTypes::from_data_kind(data_kind, self.inner.ct_list.message_modulus()) }) } @@ -197,28 +236,29 @@ impl ProvenCompactCiphertextList { pk: &CompactPublicKey, ) -> crate::Result { // For WASM - if !self.0.is_packed() && !self.0.needs_casting() { - // No ServerKey required, short-circuit to avoid the global state call + if !self.inner.is_packed() && !self.inner.needs_casting() { + // No ServerKey required, short circuit to avoid the global state call return Ok(CompactCiphertextListExpander { - inner: self.0.verify_and_expand( + inner: self.inner.verify_and_expand( public_params, &pk.key.key, IntegerCompactCiphertextListUnpackingMode::NoUnpacking, IntegerCompactCiphertextListCastingMode::NoCasting, )?, + tag: self.tag.clone(), }); } global_state::try_with_internal_keys(|maybe_keys| match maybe_keys { None => Err(crate::high_level_api::errors::UninitializedServerKey.into()), Some(InternalServerKey::Cpu(cpu_key)) => { - let unpacking_mode = if self.0.is_packed() { + let unpacking_mode = if self.inner.is_packed() { IntegerCompactCiphertextListUnpackingMode::UnpackIfNecessary(cpu_key.pbs_key()) } else { IntegerCompactCiphertextListUnpackingMode::NoUnpacking }; - let casting_mode = if self.0.needs_casting() { + let casting_mode = if self.inner.needs_casting() { IntegerCompactCiphertextListCastingMode::CastIfNecessary( cpu_key.cpk_casting_key().ok_or_else(|| { crate::Error::new( @@ -232,9 +272,12 @@ impl ProvenCompactCiphertextList { IntegerCompactCiphertextListCastingMode::NoCasting }; - self.0 + self.inner .verify_and_expand(public_params, &pk.key.key, unpacking_mode, casting_mode) - .map(|expander| CompactCiphertextListExpander { inner: expander }) + .map(|expander| CompactCiphertextListExpander { + inner: expander, + tag: self.tag.clone(), + }) } #[cfg(feature = "gpu")] Some(_) => Err(crate::Error::new("Expected a CPU server key".to_string())), @@ -243,7 +286,8 @@ impl ProvenCompactCiphertextList { } pub struct CompactCiphertextListExpander { - inner: crate::integer::ciphertext::CompactCiphertextListExpander, + pub(in crate::high_level_api) inner: crate::integer::ciphertext::CompactCiphertextListExpander, + tag: Tag, } impl CompactCiphertextListExpander { @@ -263,9 +307,13 @@ impl CompactCiphertextListExpander { pub fn get(&self, index: usize) -> Option> where - T: Expandable, + T: Expandable + Tagged, { - self.inner.get(index) + let mut expanded = self.inner.get::(index); + if let Some(Ok(inner)) = &mut expanded { + inner.tag_mut().set_data(self.tag.data()); + } + expanded } } @@ -283,12 +331,14 @@ fn num_bits_to_strict_num_blocks( pub struct CompactCiphertextListBuilder { inner: crate::integer::ciphertext::CompactCiphertextListBuilder, + tag: Tag, } impl CompactCiphertextListBuilder { pub fn new(pk: &CompactPublicKey) -> Self { Self { inner: crate::integer::ciphertext::CompactCiphertextListBuilder::new(&pk.key.key), + tag: pk.tag.clone(), } } @@ -333,13 +383,19 @@ impl CompactCiphertextListBuilder { } pub fn build(&self) -> CompactCiphertextList { - CompactCiphertextList(self.inner.build()) + CompactCiphertextList { + inner: self.inner.build(), + tag: self.tag.clone(), + } } pub fn build_packed(&self) -> CompactCiphertextList { self.inner .build_packed() - .map(CompactCiphertextList) + .map(|list| CompactCiphertextList { + inner: list, + tag: self.tag.clone(), + }) .expect("Internal error, invalid parameters should not have been allowed") } @@ -351,7 +407,10 @@ impl CompactCiphertextListBuilder { ) -> crate::Result { self.inner .build_with_proof_packed(public_params, compute_load) - .map(ProvenCompactCiphertextList) + .map(|proved_list| ProvenCompactCiphertextList { + inner: proved_list, + tag: self.tag.clone(), + }) } } diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index 56a43381b1..5cd5fac628 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -6,8 +6,9 @@ use crate::core_crypto::commons::math::random::{Deserialize, Serialize}; use crate::high_level_api::integers::{FheIntId, FheUintId}; use crate::integer::ciphertext::{Compressible, DataKind, Expandable}; use crate::named::Named; +use crate::prelude::Tagged; use crate::shortint::Ciphertext; -use crate::{FheBool, FheInt, FheUint}; +use crate::{FheBool, FheInt, FheUint, Tag}; impl Compressible for FheUint { fn compress_into(self, messages: &mut Vec) -> DataKind { @@ -58,12 +59,16 @@ impl CompressedCiphertextListBuilder { pub fn build(&self) -> crate::Result { crate::high_level_api::global_state::try_with_internal_keys(|keys| match keys { Some(InternalServerKey::Cpu(cpu_key)) => cpu_key + .key .compression_key .as_ref() .ok_or_else(|| { crate::Error::new("Compression key not set in server key".to_owned()) }) - .map(|compression_key| CompressedCiphertextList(self.inner.build(compression_key))), + .map(|compression_key| CompressedCiphertextList { + inner: self.inner.build(compression_key), + tag: cpu_key.tag.clone(), + }), _ => Err(crate::Error::new( "A Cpu server key is needed to be set to use compression".to_owned(), )), @@ -73,15 +78,28 @@ impl CompressedCiphertextListBuilder { #[derive(Clone, Serialize, Deserialize, Versionize)] #[versionize(CompressedCiphertextListVersions)] -pub struct CompressedCiphertextList(crate::integer::ciphertext::CompressedCiphertextList); +pub struct CompressedCiphertextList { + pub(in crate::high_level_api) inner: crate::integer::ciphertext::CompressedCiphertextList, + pub(in crate::high_level_api) tag: Tag, +} impl Named for CompressedCiphertextList { const NAME: &'static str = "high_level_api::CompressedCiphertextList"; } +impl Tagged for CompressedCiphertextList { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + impl CompressedCiphertextList { pub fn len(&self) -> usize { - self.0.len() + self.inner.len() } pub fn is_empty(&self) -> bool { @@ -89,9 +107,9 @@ impl CompressedCiphertextList { } pub fn get_kind_of(&self, index: usize) -> Option { - Some(match self.0.get_kind_of(index)? { + Some(match self.inner.get_kind_of(index)? { DataKind::Unsigned(n) => { - let num_bits_per_block = self.0.packed_list.message_modulus.0.ilog2() as usize; + let num_bits_per_block = self.inner.packed_list.message_modulus.0.ilog2() as usize; let num_bits = n * num_bits_per_block; match num_bits { 2 => crate::FheTypes::Uint2, @@ -111,7 +129,7 @@ impl CompressedCiphertextList { } } DataKind::Signed(n) => { - let num_bits_per_block = self.0.packed_list.message_modulus.0.ilog2() as usize; + let num_bits_per_block = self.inner.packed_list.message_modulus.0.ilog2() as usize; let num_bits = n * num_bits_per_block; match num_bits { 2 => crate::FheTypes::Int2, @@ -136,16 +154,23 @@ impl CompressedCiphertextList { pub fn get(&self, index: usize) -> crate::Result> where - T: Expandable, + T: Expandable + Tagged, { crate::high_level_api::global_state::try_with_internal_keys(|keys| match keys { Some(InternalServerKey::Cpu(cpu_key)) => cpu_key + .key .decompression_key .as_ref() .ok_or_else(|| { crate::Error::new("Compression key not set in server key".to_owned()) }) - .and_then(|decompression_key| self.0.get(index, decompression_key)), + .and_then(|decompression_key| { + let mut ct = self.inner.get::(index, decompression_key); + if let Ok(Some(ct_ref)) = &mut ct { + ct_ref.tag_mut().set_data(cpu_key.tag.data()) + } + ct + }), _ => Err(crate::Error::new( "A Cpu server key is needed to be set".to_string(), )), diff --git a/tfhe/src/high_level_api/global_state.rs b/tfhe/src/high_level_api/global_state.rs index 60d28d79d6..0ff11a9bc5 100644 --- a/tfhe/src/high_level_api/global_state.rs +++ b/tfhe/src/high_level_api/global_state.rs @@ -3,7 +3,7 @@ #[cfg(feature = "gpu")] use crate::core_crypto::gpu::CudaStreams; use crate::high_level_api::errors::{UninitializedServerKey, UnwrapResultExt}; -use crate::high_level_api::keys::{IntegerServerKey, InternalServerKey, ServerKey}; +use crate::high_level_api::keys::{InternalServerKey, ServerKey}; use std::cell::RefCell; /// We store the internal keys as thread local, meaning each thread has its own set of keys. /// @@ -122,10 +122,23 @@ pub(in crate::high_level_api) fn device_of_internal_keys() -> Option crate::Result { + INTERNAL_KEYS.with(|keys| { + let cell = keys.borrow(); + Ok(match cell.as_ref().ok_or(UninitializedServerKey)? { + InternalServerKey::Cpu(cpu_key) => cpu_key.tag.clone(), + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(cuda_key) => cuda_key.tag.clone(), + }) + }) +} + #[inline] pub(crate) fn with_cpu_internal_keys(func: F) -> T where - F: FnOnce(&IntegerServerKey) -> T, + F: FnOnce(&ServerKey) -> T, { // Should use `with_borrow` when its stabilized INTERNAL_KEYS.with(|keys| { diff --git a/tfhe/src/high_level_api/integers/oprf.rs b/tfhe/src/high_level_api/integers/oprf.rs index 6dbe1c7745..360b390dfa 100644 --- a/tfhe/src/high_level_api/integers/oprf.rs +++ b/tfhe/src/high_level_api/integers/oprf.rs @@ -27,21 +27,23 @@ impl FheUint { /// assert!(dec_result < (1 << random_bits_count)); /// ``` pub fn generate_oblivious_pseudo_random(seed: Seed, random_bits_count: u64) -> Self { - let ct = global_state::with_internal_keys(|key| match key { - InternalServerKey::Cpu(key) => key - .key - .par_generate_oblivious_pseudo_random_unsigned_integer( - seed, - random_bits_count, - Id::num_blocks(key.message_modulus()) as u64, - ), + global_state::with_internal_keys(|key| match key { + InternalServerKey::Cpu(key) => { + let ct = key + .pbs_key() + .par_generate_oblivious_pseudo_random_unsigned_integer( + seed, + random_bits_count, + Id::num_blocks(key.message_modulus()) as u64, + ); + + Self::new(ct, key.tag.clone()) + } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { todo!("Cuda devices do not yet support oblivious pseudo random generation") } - }); - - Self::new(ct) + }) } } @@ -73,20 +75,21 @@ impl FheInt { seed: Seed, randomizer: SignedRandomizationSpec, ) -> Self { - let ct = global_state::with_internal_keys(|key| match key { + global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { - key.key.par_generate_oblivious_pseudo_random_signed_integer( - seed, - randomizer, - Id::num_blocks(key.message_modulus()) as u64, - ) + let ct = key + .pbs_key() + .par_generate_oblivious_pseudo_random_signed_integer( + seed, + randomizer, + Id::num_blocks(key.message_modulus()) as u64, + ); + Self::new(ct, key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { todo!("Cuda devices do not yet support oblivious pseudo random generation") } - }); - - Self::new(ct) + }) } } diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index 4d735898d7..12c0c96445 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -6,13 +6,14 @@ use crate::conformance::ParameterSetConformant; use crate::high_level_api::global_state; use crate::high_level_api::integers::{FheUint, FheUintId, IntegerId}; use crate::high_level_api::keys::InternalServerKey; +use crate::high_level_api::traits::Tagged; use crate::integer::client_key::RecomposableSignedInteger; use crate::integer::parameters::RadixCiphertextConformanceParams; use crate::named::Named; use crate::prelude::CastFrom; use crate::shortint::ciphertext::NotTrivialCiphertextError; use crate::shortint::PBSParameters; -use crate::{Device, FheBool, ServerKey}; +use crate::{Device, FheBool, ServerKey, Tag}; use std::marker::PhantomData; #[cfg(feature = "gpu")] @@ -36,9 +37,9 @@ pub trait FheIntId: IntegerId {} #[derive(Clone, serde::Deserialize, serde::Serialize, Versionize)] #[versionize(FheIntVersions)] pub struct FheInt { - // pub(in crate::high_level_api) ciphertext: RadixCiphertext, - pub(in crate::high_level_api::integers) id: Id, + pub(in crate::high_level_api) id: Id, + pub(crate) tag: Tag, } pub struct FheIntConformanceParams { @@ -75,7 +76,11 @@ impl ParameterSetConformant for FheInt { type ParameterSet = FheIntConformanceParams; fn is_conformant(&self, params: &FheIntConformanceParams) -> bool { - let Self { ciphertext, id: _ } = self; + let Self { + ciphertext, + id: _, + tag: _, + } = self; ciphertext.on_cpu().is_conformant(¶ms.params) } @@ -85,26 +90,49 @@ impl Named for FheInt { const NAME: &'static str = "high_level_api::FheInt"; } +impl Tagged for FheInt +where + Id: FheIntId, +{ + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + impl FheInt where Id: FheIntId, { - pub(in crate::high_level_api) fn new(ciphertext: impl Into) -> Self { + pub(in crate::high_level_api) fn new(ciphertext: impl Into, tag: Tag) -> Self { Self { ciphertext: ciphertext.into(), id: Id::default(), + tag, } } - pub fn into_raw_parts(self) -> (crate::integer::SignedRadixCiphertext, Id) { - let Self { ciphertext, id } = self; - (ciphertext.into_cpu(), id) + pub fn into_raw_parts(self) -> (crate::integer::SignedRadixCiphertext, Id, Tag) { + let Self { + ciphertext, + id, + tag, + } = self; + (ciphertext.into_cpu(), id, tag) } - pub fn from_raw_parts(ciphertext: crate::integer::SignedRadixCiphertext, id: Id) -> Self { + pub fn from_raw_parts( + ciphertext: crate::integer::SignedRadixCiphertext, + id: Id, + tag: Tag, + ) -> Self { Self { ciphertext: ciphertext.into(), id, + tag, } } @@ -155,7 +183,7 @@ where let ciphertext = cpu_key .pbs_key() .abs_parallelized(&*self.ciphertext.on_cpu()); - Self::new(ciphertext) + Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -187,12 +215,12 @@ where let result = cpu_key .pbs_key() .is_even_parallelized(&*self.ciphertext.on_cpu()); - FheBool::new(result) + FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key.key.is_even(&*self.ciphertext.on_gpu(), streams); - FheBool::new(result) + FheBool::new(result, cuda_key.tag.clone()) }), }) } @@ -220,12 +248,12 @@ where let result = cpu_key .pbs_key() .is_odd_parallelized(&*self.ciphertext.on_cpu()); - FheBool::new(result) + FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key.key.is_odd(&*self.ciphertext.on_gpu(), streams); - FheBool::new(result) + FheBool::new(result, cuda_key.tag.clone()) }), }) } @@ -257,7 +285,7 @@ where result, crate::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - crate::FheUint32::new(result) + crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -293,7 +321,7 @@ where result, crate::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - crate::FheUint32::new(result) + crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -329,7 +357,7 @@ where result, crate::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - crate::FheUint32::new(result) + crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -365,7 +393,7 @@ where result, crate::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - crate::FheUint32::new(result) + crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -403,7 +431,7 @@ where result, crate::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - crate::FheUint32::new(result) + crate::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -445,7 +473,10 @@ where result, crate::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - (crate::FheUint32::new(result), FheBool::new(is_ok)) + ( + crate::FheUint32::new(result, cpu_key.tag.clone()), + FheBool::new(is_ok, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -519,7 +550,7 @@ where let ct = self.ciphertext.on_cpu(); - Self::new(sk.reverse_bits_parallelized(&*ct)) + Self::new(sk.reverse_bits_parallelized(&*ct), cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -558,7 +589,7 @@ where let new_ciphertext = cpu_key .pbs_key() .cast_to_signed(input.ciphertext.into_cpu(), target_num_blocks); - Self::new(new_ciphertext) + Self::new(new_ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -568,7 +599,7 @@ where target_num_blocks, streams, ); - Self::new(new_ciphertext) + Self::new(new_ciphertext, cuda_key.tag.clone()) }), }) } @@ -599,11 +630,11 @@ where fn cast_from(input: FheUint) -> Self { global_state::with_internal_keys(|keys| match keys { InternalServerKey::Cpu(cpu_key) => { - let new_ciphertext = cpu_key.key.cast_to_signed( + let new_ciphertext = cpu_key.pbs_key().cast_to_signed( input.ciphertext.on_cpu().to_owned(), IntoId::num_blocks(cpu_key.message_modulus()), ); - Self::new(new_ciphertext) + Self::new(new_ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -612,7 +643,7 @@ where IntoId::num_blocks(cuda_key.message_modulus()), streams, ); - Self::new(new_ciphertext) + Self::new(new_ciphertext, cuda_key.tag.clone()) }), }) } @@ -650,7 +681,7 @@ where Id::num_blocks(cpu_key.message_modulus()), cpu_key.pbs_key(), ); - Self::new(ciphertext) + Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -659,7 +690,7 @@ where Id::num_blocks(cuda_key.message_modulus()), streams, ); - Self::new(inner) + Self::new(inner, cuda_key.tag.clone()) }), }) } diff --git a/tfhe/src/high_level_api/integers/signed/compressed.rs b/tfhe/src/high_level_api/integers/signed/compressed.rs index 2529e75f63..9fbe046324 100644 --- a/tfhe/src/high_level_api/integers/signed/compressed.rs +++ b/tfhe/src/high_level_api/integers/signed/compressed.rs @@ -8,6 +8,7 @@ use crate::core_crypto::prelude::SignedNumeric; use crate::high_level_api::global_state::with_cpu_internal_keys; use crate::high_level_api::integers::signed::base::FheIntConformanceParams; use crate::high_level_api::integers::{FheInt, FheIntId}; +use crate::high_level_api::traits::Tagged; use crate::integer::block_decomposition::DecomposableInto; use crate::integer::ciphertext::{ CompressedModulusSwitchedSignedRadixCiphertext, @@ -16,7 +17,7 @@ use crate::integer::ciphertext::{ use crate::integer::parameters::RadixCiphertextConformanceParams; use crate::named::Named; use crate::prelude::FheTryEncrypt; -use crate::ClientKey; +use crate::{ClientKey, Tag}; /// Compressed [FheInt] /// @@ -47,28 +48,54 @@ pub struct CompressedFheInt where Id: FheIntId, { - pub(in crate::high_level_api::integers) ciphertext: CompressedSignedRadixCiphertext, - pub(in crate::high_level_api::integers) id: Id, + pub(in crate::high_level_api) ciphertext: CompressedSignedRadixCiphertext, + pub(in crate::high_level_api) id: Id, + pub(crate) tag: Tag, +} + +impl Tagged for CompressedFheInt +where + Id: FheIntId, +{ + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } } impl CompressedFheInt where Id: FheIntId, { - pub(in crate::high_level_api::integers) fn new(inner: CompressedSignedRadixCiphertext) -> Self { + pub(in crate::high_level_api::integers) fn new( + inner: CompressedSignedRadixCiphertext, + tag: Tag, + ) -> Self { Self { ciphertext: inner, id: Id::default(), + tag, } } - pub fn into_raw_parts(self) -> (CompressedSignedRadixCiphertext, Id) { - let Self { ciphertext, id } = self; - (ciphertext, id) + pub fn into_raw_parts(self) -> (CompressedSignedRadixCiphertext, Id, Tag) { + let Self { + ciphertext, + id, + tag, + } = self; + (ciphertext, id, tag) } - pub fn from_raw_parts(ciphertext: CompressedSignedRadixCiphertext, id: Id) -> Self { - Self { ciphertext, id } + pub fn from_raw_parts(ciphertext: CompressedSignedRadixCiphertext, id: Id, tag: Tag) -> Self { + Self { + ciphertext, + id, + tag, + } } } @@ -80,12 +107,13 @@ where /// /// See [CompressedFheInt] example. pub fn decompress(&self) -> FheInt { - FheInt::new(match &self.ciphertext { + let ciphertext = match &self.ciphertext { CompressedSignedRadixCiphertext::Seeded(ct) => ct.decompress(), CompressedSignedRadixCiphertext::ModulusSwitched(ct) => { - with_cpu_internal_keys(|sk| sk.key.decompress_signed_parallelized(ct)) + with_cpu_internal_keys(|sk| sk.pbs_key().decompress_signed_parallelized(ct)) } - }) + }; + FheInt::new(ciphertext, self.tag.clone()) } } @@ -100,7 +128,10 @@ where let integer_client_key = &key.key.key; let inner = integer_client_key .encrypt_signed_radix_compressed(value, Id::num_blocks(key.message_modulus())); - Ok(Self::new(CompressedSignedRadixCiphertext::Seeded(inner))) + Ok(Self::new( + CompressedSignedRadixCiphertext::Seeded(inner), + key.tag.clone(), + )) } } @@ -108,7 +139,11 @@ impl ParameterSetConformant for CompressedFheInt { type ParameterSet = FheIntConformanceParams; fn is_conformant(&self, params: &FheIntConformanceParams) -> bool { - let Self { ciphertext, id: _ } = self; + let Self { + ciphertext, + id: _, + tag: _, + } = self; ciphertext.is_conformant(¶ms.params) } @@ -141,10 +176,13 @@ where { pub fn compress(&self) -> CompressedFheInt { let a = with_cpu_internal_keys(|sk| { - sk.key + sk.pbs_key() .switch_modulus_and_compress_signed_parallelized(&self.ciphertext.on_cpu()) }); - CompressedFheInt::new(CompressedSignedRadixCiphertext::ModulusSwitched(a)) + CompressedFheInt::new( + CompressedSignedRadixCiphertext::ModulusSwitched(a), + self.tag.clone(), + ) } } diff --git a/tfhe/src/high_level_api/integers/signed/encrypt.rs b/tfhe/src/high_level_api/integers/signed/encrypt.rs index 695e25075f..9a75e5e1aa 100644 --- a/tfhe/src/high_level_api/integers/signed/encrypt.rs +++ b/tfhe/src/high_level_api/integers/signed/encrypt.rs @@ -50,7 +50,7 @@ where .key .key .encrypt_signed_radix(value, Id::num_blocks(key.message_modulus())); - Ok(Self::new(ciphertext)) + Ok(Self::new(ciphertext, key.tag.clone())) } } @@ -65,7 +65,7 @@ where let ciphertext = key .key .encrypt_signed_radix(value, Id::num_blocks(key.message_modulus())); - Ok(Self::new(ciphertext)) + Ok(Self::new(ciphertext, key.tag.clone())) } } @@ -80,7 +80,7 @@ where let ciphertext = key .key .encrypt_signed_radix(value, Id::num_blocks(key.message_modulus())); - Ok(Self::new(ciphertext)) + Ok(Self::new(ciphertext, key.tag.clone())) } } @@ -101,14 +101,15 @@ where /// Trivial encryptions become real encrypted data once used in an operation /// that involves a real ciphertext fn try_encrypt_trivial(value: T) -> Result { - let ciphertext = global_state::with_cpu_internal_keys(|sks| { - sks.pbs_key() + global_state::with_cpu_internal_keys(|sks| { + let ciphertext = sks + .pbs_key() .create_trivial_radix::( value, Id::num_blocks(sks.message_modulus()), - ) - }); - Ok(Self::new(ciphertext)) + ); + Ok(Self::new(ciphertext, sks.tag.clone())) + }) } } diff --git a/tfhe/src/high_level_api/integers/signed/mod.rs b/tfhe/src/high_level_api/integers/signed/mod.rs index b1447be8d7..9b3b2a534b 100644 --- a/tfhe/src/high_level_api/integers/signed/mod.rs +++ b/tfhe/src/high_level_api/integers/signed/mod.rs @@ -13,7 +13,7 @@ mod tests; pub use base::{FheInt, FheIntId}; pub use compressed::CompressedFheInt; pub(in crate::high_level_api) use compressed::CompressedSignedRadixCiphertext; -pub(in crate::high_level_api) use inner::RadixCiphertextVersionOwned; +pub(in crate::high_level_api) use inner::{RadixCiphertext, RadixCiphertextVersionOwned}; expand_pub_use_fhe_type!( pub use static_{ diff --git a/tfhe/src/high_level_api/integers/signed/ops.rs b/tfhe/src/high_level_api/integers/signed/ops.rs index d6c17be7a5..d8509e8437 100644 --- a/tfhe/src/high_level_api/integers/signed/ops.rs +++ b/tfhe/src/high_level_api/integers/signed/ops.rs @@ -58,12 +58,12 @@ where .map_or_else( || { let radix: crate::integer::SignedRadixCiphertext = - cpu_key.key.create_trivial_zero_radix(Id::num_blocks( + cpu_key.pbs_key().create_trivial_zero_radix(Id::num_blocks( cpu_key.message_modulus(), )); - Self::new(radix) + Self::new(radix, cpu_key.tag.clone()) }, - Self::new, + |ct| Self::new(ct, cpu_key.tag.clone()), ) } #[cfg(feature = "gpu")] @@ -105,7 +105,7 @@ where let inner_result = cpu_key .pbs_key() .max_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - Self::new(inner_result) + Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -114,7 +114,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - Self::new(inner_result) + Self::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -151,7 +151,7 @@ where let inner_result = cpu_key .pbs_key() .min_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - Self::new(inner_result) + Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -160,7 +160,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - Self::new(inner_result) + Self::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -208,7 +208,7 @@ where let inner_result = cpu_key .pbs_key() .eq_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -217,7 +217,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -247,7 +247,7 @@ where let inner_result = cpu_key .pbs_key() .ne_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -256,7 +256,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -312,7 +312,7 @@ where let inner_result = cpu_key .pbs_key() .lt_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -321,7 +321,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -351,7 +351,7 @@ where let inner_result = cpu_key .pbs_key() .le_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -360,7 +360,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -390,7 +390,7 @@ where let inner_result = cpu_key .pbs_key() .gt_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -399,7 +399,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -429,7 +429,7 @@ where let inner_result = cpu_key .pbs_key() .ge_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -438,7 +438,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -508,7 +508,10 @@ where let (q, r) = cpu_key .pbs_key() .div_rem_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - (FheInt::::new(q), FheInt::::new(r)) + ( + FheInt::::new(q, cpu_key.tag.clone()), + FheInt::::new(r, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -583,14 +586,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .add_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheInt::new(inner_result) + FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .add(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -626,14 +629,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .sub_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheInt::new(inner_result) + FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .sub(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -669,14 +672,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .mul_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheInt::new(inner_result) + FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .mul(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -710,14 +713,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .bitand_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheInt::new(inner_result) + FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .bitand(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -751,14 +754,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .bitor_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheInt::new(inner_result) + FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .bitor(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -792,14 +795,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .bitxor_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheInt::new(inner_result) + FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .bitxor(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -841,7 +844,7 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .div_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheInt::new(inner_result) + FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(_cuda_key) => { @@ -887,7 +890,7 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .rem_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheInt::new(inner_result) + FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(_cuda_key) => { @@ -996,14 +999,14 @@ generic_integer_impl_shift_rotate!( let ciphertext = cpu_key .pbs_key() .left_shift_parallelized(&*lhs.ciphertext.on_cpu(), &rhs.ciphertext.on_cpu()); - FheInt::new(ciphertext) + FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .left_shift(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } } @@ -1040,14 +1043,14 @@ generic_integer_impl_shift_rotate!( let ciphertext = cpu_key .pbs_key() .right_shift_parallelized(&*lhs.ciphertext.on_cpu(), &rhs.ciphertext.on_cpu()); - FheInt::new(ciphertext) + FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .right_shift(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } } @@ -1084,14 +1087,14 @@ generic_integer_impl_shift_rotate!( let ciphertext = cpu_key .pbs_key() .rotate_left_parallelized(&*lhs.ciphertext.on_cpu(), &rhs.ciphertext.on_cpu()); - FheInt::new(ciphertext) + FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .rotate_left(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } } @@ -1128,14 +1131,14 @@ generic_integer_impl_shift_rotate!( let ciphertext = cpu_key .pbs_key() .rotate_right_parallelized(&*lhs.ciphertext.on_cpu(), &rhs.ciphertext.on_cpu()); - FheInt::new(ciphertext) + FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .rotate_right(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }) } } @@ -1789,12 +1792,12 @@ where let ciphertext = cpu_key .pbs_key() .neg_parallelized(&*self.ciphertext.on_cpu()); - FheInt::new(ciphertext) + FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.neg(&*self.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -1855,12 +1858,12 @@ where global_state::with_internal_keys(|keys| match keys { InternalServerKey::Cpu(cpu_key) => { let ciphertext = cpu_key.pbs_key().bitnot(&*self.ciphertext.on_cpu()); - FheInt::new(ciphertext) + FheInt::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.bitnot(&*self.ciphertext.on_gpu(), streams); - FheInt::new(inner_result) + FheInt::new(inner_result, cuda_key.tag.clone()) }), }) } diff --git a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs index 52181a985c..7d91a14a20 100644 --- a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs @@ -43,11 +43,14 @@ where fn overflowing_add(self, other: Self) -> (Self::Output, FheBool) { global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let (result, overflow) = cpu_key.key.signed_overflowing_add_parallelized( + let (result, overflow) = cpu_key.pbs_key().signed_overflowing_add_parallelized( &self.ciphertext.on_cpu(), &other.ciphertext.on_cpu(), ); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -56,7 +59,10 @@ where &other.ciphertext.on_gpu(), streams, ); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cuda_key.tag.clone()), + FheBool::new(overflow, cuda_key.tag.clone()), + ) }), }) } @@ -135,9 +141,12 @@ where global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { let (result, overflow) = cpu_key - .key + .pbs_key() .signed_overflowing_scalar_add_parallelized(&self.ciphertext.on_cpu(), other); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -146,7 +155,10 @@ where other, streams, ); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cuda_key.tag.clone()), + FheBool::new(overflow, cuda_key.tag.clone()), + ) }), }) } @@ -261,11 +273,14 @@ where fn overflowing_sub(self, other: Self) -> (Self::Output, FheBool) { global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let (result, overflow) = cpu_key.key.signed_overflowing_sub_parallelized( + let (result, overflow) = cpu_key.pbs_key().signed_overflowing_sub_parallelized( &self.ciphertext.on_cpu(), &other.ciphertext.on_cpu(), ); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -274,7 +289,10 @@ where &other.ciphertext.on_gpu(), streams, ); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cuda_key.tag.clone()), + FheBool::new(overflow, cuda_key.tag.clone()), + ) }), }) } @@ -352,9 +370,12 @@ where global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { let (result, overflow) = cpu_key - .key + .pbs_key() .signed_overflowing_scalar_sub_parallelized(&self.ciphertext.on_cpu(), other); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -363,7 +384,10 @@ where other, streams, ); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cuda_key.tag.clone()), + FheBool::new(overflow, cuda_key.tag.clone()), + ) }), }) } @@ -440,11 +464,14 @@ where fn overflowing_mul(self, other: Self) -> (Self::Output, FheBool) { global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let (result, overflow) = cpu_key.key.signed_overflowing_mul_parallelized( + let (result, overflow) = cpu_key.pbs_key().signed_overflowing_mul_parallelized( &self.ciphertext.on_cpu(), &other.ciphertext.on_cpu(), ); - (FheInt::new(result), FheBool::new(overflow)) + ( + FheInt::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { diff --git a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs index d8960fc8c4..43779dda08 100644 --- a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs @@ -1,5 +1,6 @@ #[cfg(feature = "gpu")] use crate::core_crypto::commons::numeric::CastFrom; +use crate::high_level_api::errors::UnwrapResultExt; use crate::high_level_api::global_state; #[cfg(feature = "gpu")] use crate::high_level_api::global_state::with_thread_local_cuda_streams; @@ -50,7 +51,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_max_parallelized(&*self.ciphertext.on_cpu(), rhs); - Self::new(inner_result) + Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { @@ -59,7 +60,7 @@ where cuda_key .key .scalar_max(&*self.ciphertext.on_gpu(), rhs, streams); - Self::new(inner_result) + Self::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -97,7 +98,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_min_parallelized(&*self.ciphertext.on_cpu(), rhs); - Self::new(inner_result) + Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { @@ -106,7 +107,7 @@ where cuda_key .key .scalar_min(&*self.ciphertext.on_gpu(), rhs, streams); - Self::new(inner_result) + Self::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -143,7 +144,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_eq_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { @@ -152,7 +153,7 @@ where cuda_key .key .scalar_eq(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -183,7 +184,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_ne_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { @@ -192,7 +193,7 @@ where cuda_key .key .scalar_ne(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -228,7 +229,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_lt_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { @@ -237,7 +238,7 @@ where cuda_key .key .scalar_lt(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -267,7 +268,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_le_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { @@ -276,7 +277,7 @@ where cuda_key .key .scalar_le(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -306,7 +307,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_gt_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { @@ -315,7 +316,7 @@ where cuda_key .key .scalar_gt(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -345,7 +346,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_ge_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { @@ -354,7 +355,7 @@ where cuda_key .key .scalar_ge(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -394,8 +395,8 @@ macro_rules! generic_integer_impl_scalar_div_rem { .pbs_key() .$key_method(&*self.ciphertext.on_cpu(), rhs); ( - <$concrete_type>::new(q), - <$concrete_type>::new(r) + <$concrete_type>::new(q, cpu_key.tag.clone()), + <$concrete_type>::new(r, cpu_key.tag.clone()) ) } #[cfg(feature = "gpu")] diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index 69bc5a1b6f..8a723e3161 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -9,6 +9,7 @@ use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::signed::{FheInt, FheIntId}; use crate::high_level_api::integers::IntegerId; use crate::high_level_api::keys::InternalServerKey; +use crate::high_level_api::traits::Tagged; use crate::high_level_api::{global_state, Device}; use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom}; use crate::integer::parameters::RadixCiphertextConformanceParams; @@ -17,7 +18,7 @@ use crate::named::Named; use crate::prelude::CastInto; use crate::shortint::ciphertext::NotTrivialCiphertextError; use crate::shortint::PBSParameters; -use crate::{FheBool, ServerKey}; +use crate::{FheBool, ServerKey, Tag}; use std::marker::PhantomData; #[derive(Debug)] @@ -77,7 +78,8 @@ pub trait FheUintId: IntegerId {} #[versionize(FheUintVersions)] pub struct FheUint { pub(in crate::high_level_api) ciphertext: RadixCiphertext, - pub(in crate::high_level_api::integers) id: Id, + pub(in crate::high_level_api) id: Id, + pub(crate) tag: Tag, } pub struct FheUintConformanceParams { @@ -114,7 +116,11 @@ impl ParameterSetConformant for FheUint { type ParameterSet = FheUintConformanceParams; fn is_conformant(&self, params: &FheUintConformanceParams) -> bool { - let Self { ciphertext, id: _ } = self; + let Self { + ciphertext, + id: _, + tag: _, + } = self; ciphertext.on_cpu().is_conformant(¶ms.params) } @@ -124,32 +130,51 @@ impl Named for FheUint { const NAME: &'static str = "high_level_api::FheUint"; } +impl Tagged for FheUint +where + Id: FheUintId, +{ + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + impl FheUint where Id: FheUintId, { - pub(in crate::high_level_api) fn new(ciphertext: T) -> Self + pub(in crate::high_level_api) fn new(ciphertext: T, tag: Tag) -> Self where T: Into, { Self { ciphertext: ciphertext.into(), id: Id::default(), + tag, } } - pub fn into_raw_parts(self) -> (crate::integer::RadixCiphertext, Id) { - let Self { ciphertext, id } = self; + pub fn into_raw_parts(self) -> (crate::integer::RadixCiphertext, Id, Tag) { + let Self { + ciphertext, + id, + tag, + } = self; let ciphertext = ciphertext.into_cpu(); - (ciphertext, id) + (ciphertext, id, tag) } - pub fn from_raw_parts(ciphertext: crate::integer::RadixCiphertext, id: Id) -> Self { + pub fn from_raw_parts(ciphertext: crate::integer::RadixCiphertext, id: Id, tag: Tag) -> Self { Self { ciphertext: RadixCiphertext::Cpu(ciphertext), id, + tag, } } @@ -196,12 +221,12 @@ where let result = cpu_key .pbs_key() .is_even_parallelized(&*self.ciphertext.on_cpu()); - FheBool::new(result) + FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key.key.is_even(&*self.ciphertext.on_gpu(), streams); - FheBool::new(result) + FheBool::new(result, cuda_key.tag.clone()) }), }) } @@ -229,12 +254,12 @@ where let result = cpu_key .pbs_key() .is_odd_parallelized(&*self.ciphertext.on_cpu()); - FheBool::new(result) + FheBool::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key.key.is_odd(&*self.ciphertext.on_gpu(), streams); - FheBool::new(result) + FheBool::new(result, cuda_key.tag.clone()) }), }) } @@ -359,7 +384,7 @@ where result, super::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - super::FheUint32::new(result) + super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -395,7 +420,7 @@ where result, super::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - super::FheUint32::new(result) + super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -431,7 +456,7 @@ where result, super::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - super::FheUint32::new(result) + super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -467,7 +492,7 @@ where result, super::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - super::FheUint32::new(result) + super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -505,7 +530,7 @@ where result, super::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - super::FheUint32::new(result) + super::FheUint32::new(result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -547,7 +572,10 @@ where result, super::FheUint32Id::num_blocks(cpu_key.pbs_key().message_modulus()), ); - (super::FheUint32::new(result), FheBool::new(is_ok)) + ( + super::FheUint32::new(result, cpu_key.tag.clone()), + FheBool::new(is_ok, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -611,7 +639,10 @@ where let result = cpu_key .pbs_key() .cast_to_unsigned(result, target_num_blocks); - Ok((FheUint::new(result), FheBool::new(matched))) + Ok(( + FheUint::new(result, cpu_key.tag.clone()), + FheBool::new(matched, cpu_key.tag.clone()), + )) } else { Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string())) } @@ -676,7 +707,7 @@ where let result = cpu_key .pbs_key() .cast_to_unsigned(result, target_num_blocks); - Ok(FheUint::new(result)) + Ok(FheUint::new(result, cpu_key.tag.clone())) } else { Err(crate::Error::new("Output type does not have enough bits to represent all possible output values".to_string())) } @@ -715,7 +746,7 @@ where let ct = self.ciphertext.on_cpu(); - Self::new(sk.reverse_bits_parallelized(&*ct)) + Self::new(sk.reverse_bits_parallelized(&*ct), cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -771,7 +802,7 @@ where } } - let mut ciphertext = Self::new(other); + let mut ciphertext = Self::new(other, Tag::default()); ciphertext.move_to_device_of_server_key_if_set(); Ok(ciphertext) } @@ -818,7 +849,7 @@ where input.ciphertext.into_cpu(), IntoId::num_blocks(cpu_key.message_modulus()), ); - Self::new(casted) + Self::new(casted, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -827,7 +858,7 @@ where IntoId::num_blocks(cuda_key.message_modulus()), streams, ); - Self::new(casted) + Self::new(casted, cuda_key.tag.clone()) }), }) } @@ -862,7 +893,7 @@ where input.ciphertext.on_cpu().to_owned(), IntoId::num_blocks(cpu_key.message_modulus()), ); - Self::new(casted) + Self::new(casted, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -871,7 +902,7 @@ where IntoId::num_blocks(cuda_key.message_modulus()), streams, ); - Self::new(casted) + Self::new(casted, cuda_key.tag.clone()) }), }) } @@ -906,7 +937,7 @@ where .on_cpu() .into_owned() .into_radix(Id::num_blocks(cpu_key.message_modulus()), cpu_key.pbs_key()); - Self::new(ciphertext) + Self::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -915,7 +946,7 @@ where Id::num_blocks(cuda_key.message_modulus()), streams, ); - Self::new(inner) + Self::new(inner, cuda_key.tag.clone()) }), }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/compressed.rs b/tfhe/src/high_level_api/integers/unsigned/compressed.rs index 078c6be333..6a667f2095 100644 --- a/tfhe/src/high_level_api/integers/unsigned/compressed.rs +++ b/tfhe/src/high_level_api/integers/unsigned/compressed.rs @@ -9,7 +9,7 @@ use crate::high_level_api::global_state::with_cpu_internal_keys; use crate::high_level_api::integers::unsigned::base::{ FheUint, FheUintConformanceParams, FheUintId, }; -use crate::high_level_api::traits::FheTryEncrypt; +use crate::high_level_api::traits::{FheTryEncrypt, Tagged}; use crate::high_level_api::ClientKey; use crate::integer::block_decomposition::DecomposableInto; use crate::integer::ciphertext::{ @@ -18,6 +18,7 @@ use crate::integer::ciphertext::{ }; use crate::integer::parameters::RadixCiphertextConformanceParams; use crate::named::Named; +use crate::Tag; /// Compressed [FheUint] /// @@ -49,26 +50,49 @@ where { pub(in crate::high_level_api::integers) ciphertext: CompressedRadixCiphertext, pub(in crate::high_level_api::integers) id: Id, + pub(crate) tag: Tag, +} + +impl Tagged for CompressedFheUint +where + Id: FheUintId, +{ + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } } impl CompressedFheUint where Id: FheUintId, { - pub(in crate::high_level_api::integers) fn new(inner: CompressedRadixCiphertext) -> Self { + pub(in crate::high_level_api) fn new(inner: CompressedRadixCiphertext, tag: Tag) -> Self { Self { ciphertext: inner, id: Id::default(), + tag, } } - pub fn into_raw_parts(self) -> (CompressedRadixCiphertext, Id) { - let Self { ciphertext, id } = self; - (ciphertext, id) + pub fn into_raw_parts(self) -> (CompressedRadixCiphertext, Id, Tag) { + let Self { + ciphertext, + id, + tag, + } = self; + (ciphertext, id, tag) } - pub fn from_raw_parts(ciphertext: CompressedRadixCiphertext, id: Id) -> Self { - Self { ciphertext, id } + pub fn from_raw_parts(ciphertext: CompressedRadixCiphertext, id: Id, tag: Tag) -> Self { + Self { + ciphertext, + id, + tag, + } } } @@ -80,12 +104,14 @@ where /// /// See [CompressedFheUint] example. pub fn decompress(&self) -> FheUint { - let mut ciphertext = FheUint::new(match &self.ciphertext { + let inner = match &self.ciphertext { CompressedRadixCiphertext::Seeded(ct) => ct.decompress(), CompressedRadixCiphertext::ModulusSwitched(ct) => { - with_cpu_internal_keys(|sk| sk.key.decompress_parallelized(ct)) + with_cpu_internal_keys(|sk| sk.pbs_key().decompress_parallelized(ct)) } - }); + }; + + let mut ciphertext = FheUint::new(inner, self.tag.clone()); ciphertext.move_to_device_of_server_key_if_set(); ciphertext @@ -104,7 +130,10 @@ where .key .key .encrypt_radix_compressed(value, Id::num_blocks(key.message_modulus())); - Ok(Self::new(CompressedRadixCiphertext::Seeded(inner))) + Ok(Self::new( + CompressedRadixCiphertext::Seeded(inner), + key.tag.clone(), + )) } } @@ -112,7 +141,11 @@ impl ParameterSetConformant for CompressedFheUint { type ParameterSet = FheUintConformanceParams; fn is_conformant(&self, params: &FheUintConformanceParams) -> bool { - let Self { ciphertext, id: _ } = self; + let Self { + ciphertext, + id: _, + tag: _, + } = self; ciphertext.is_conformant(¶ms.params) } @@ -144,12 +177,11 @@ where Id: FheUintId, { pub fn compress(&self) -> CompressedFheUint { - CompressedFheUint::new(CompressedRadixCiphertext::ModulusSwitched( - with_cpu_internal_keys(|sk| { - sk.key - .switch_modulus_and_compress_parallelized(&self.ciphertext.on_cpu()) - }), - )) + let ciphertext = CompressedRadixCiphertext::ModulusSwitched(with_cpu_internal_keys(|sk| { + sk.pbs_key() + .switch_modulus_and_compress_parallelized(&self.ciphertext.on_cpu()) + })); + CompressedFheUint::new(ciphertext, self.tag.clone()) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs index d702f7d2ee..b34bca761a 100644 --- a/tfhe/src/high_level_api/integers/unsigned/encrypt.rs +++ b/tfhe/src/high_level_api/integers/unsigned/encrypt.rs @@ -54,7 +54,7 @@ where .key .key .encrypt_radix(value, Id::num_blocks(key.message_modulus())); - let mut ciphertext = Self::new(cpu_ciphertext); + let mut ciphertext = Self::new(cpu_ciphertext, key.tag.clone()); ciphertext.move_to_device_of_server_key_if_set(); @@ -73,7 +73,7 @@ where let cpu_ciphertext = key .key .encrypt_radix(value, Id::num_blocks(key.message_modulus())); - let mut ciphertext = Self::new(cpu_ciphertext); + let mut ciphertext = Self::new(cpu_ciphertext, key.tag.clone()); ciphertext.move_to_device_of_server_key_if_set(); @@ -92,7 +92,7 @@ where let cpu_ciphertext = key .key .encrypt_radix(value, Id::num_blocks(key.message_modulus())); - let mut ciphertext = Self::new(cpu_ciphertext); + let mut ciphertext = Self::new(cpu_ciphertext, key.tag.clone()); ciphertext.move_to_device_of_server_key_if_set(); Ok(ciphertext) @@ -112,7 +112,7 @@ where let ciphertext: crate::integer::RadixCiphertext = key .pbs_key() .create_trivial_radix(value, Id::num_blocks(key.message_modulus())); - Ok(Self::new(ciphertext)) + Ok(Self::new(ciphertext, key.tag.clone())) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -121,7 +121,7 @@ where Id::num_blocks(cuda_key.key.message_modulus), streams, ); - Ok(Self::new(inner)) + Ok(Self::new(inner, cuda_key.tag.clone())) }), }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/mod.rs b/tfhe/src/high_level_api/integers/unsigned/mod.rs index 59ec71f492..bde4180751 100644 --- a/tfhe/src/high_level_api/integers/unsigned/mod.rs +++ b/tfhe/src/high_level_api/integers/unsigned/mod.rs @@ -10,7 +10,7 @@ expand_pub_use_fhe_type!( pub use compressed::CompressedFheUint; pub(in crate::high_level_api) use compressed::CompressedRadixCiphertext; -pub(in crate::high_level_api) use inner::RadixCiphertextVersionOwned; +pub(in crate::high_level_api) use inner::{RadixCiphertext, RadixCiphertextVersionOwned}; mod base; mod compressed; diff --git a/tfhe/src/high_level_api/integers/unsigned/ops.rs b/tfhe/src/high_level_api/integers/unsigned/ops.rs index 6fd26b4308..cd0778c6f9 100644 --- a/tfhe/src/high_level_api/integers/unsigned/ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/ops.rs @@ -58,15 +58,18 @@ where InternalServerKey::Cpu(cpu_key) => { let ciphertexts = iter.map(|elem| elem.ciphertext.into_cpu()).collect(); cpu_key - .key + .pbs_key() .unchecked_sum_ciphertexts_vec_parallelized(ciphertexts) .map_or_else( || { - Self::new(RadixCiphertext::Cpu(cpu_key.key.create_trivial_zero_radix( - Id::num_blocks(cpu_key.message_modulus()), - ))) + Self::new( + RadixCiphertext::Cpu(cpu_key.pbs_key().create_trivial_zero_radix( + Id::num_blocks(cpu_key.message_modulus()), + )), + cpu_key.tag.clone(), + ) }, - Self::new, + |ct| Self::new(ct, cpu_key.tag.clone()), ) } #[cfg(feature = "gpu")] @@ -85,7 +88,7 @@ where streams, ) }); - Self::new(inner) + Self::new(inner, cuda_key.tag.clone()) }), }) } @@ -130,17 +133,20 @@ where .collect(); let msg_mod = cpu_key.pbs_key().message_modulus(); cpu_key - .key + .pbs_key() .unchecked_sum_ciphertexts_vec_parallelized(ciphertexts) .map_or_else( || { - Self::new(RadixCiphertext::Cpu( - cpu_key - .key - .create_trivial_zero_radix(Id::num_blocks(msg_mod)), - )) + Self::new( + RadixCiphertext::Cpu( + cpu_key + .pbs_key() + .create_trivial_zero_radix(Id::num_blocks(msg_mod)), + ), + cpu_key.tag.clone(), + ) }, - Self::new, + |ct| Self::new(ct, cpu_key.tag.clone()), ) } #[cfg(feature = "gpu")] @@ -173,7 +179,7 @@ where streams, ) }); - Self::new(inner) + Self::new(inner, cuda_key.tag.clone()) }) } }) @@ -211,7 +217,7 @@ where let inner_result = cpu_key .pbs_key() .max_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - Self::new(inner_result) + Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -220,7 +226,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - Self::new(inner_result) + Self::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -257,7 +263,7 @@ where let inner_result = cpu_key .pbs_key() .min_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - Self::new(inner_result) + Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -266,7 +272,7 @@ where &*rhs.ciphertext.on_gpu(), streams, ); - Self::new(inner_result) + Self::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -314,7 +320,7 @@ where let inner_result = cpu_key .pbs_key() .eq_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -323,7 +329,7 @@ where &rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -353,7 +359,7 @@ where let inner_result = cpu_key .pbs_key() .ne_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -362,7 +368,7 @@ where &rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -418,7 +424,7 @@ where let inner_result = cpu_key .pbs_key() .lt_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -427,7 +433,7 @@ where &rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -457,7 +463,7 @@ where let inner_result = cpu_key .pbs_key() .le_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -466,7 +472,7 @@ where &rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -496,7 +502,7 @@ where let inner_result = cpu_key .pbs_key() .gt_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -505,7 +511,7 @@ where &rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -535,7 +541,7 @@ where let inner_result = cpu_key .pbs_key() .ge_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -544,7 +550,7 @@ where &rhs.ciphertext.on_gpu(), streams, ); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -615,7 +621,10 @@ where let (q, r) = cpu_key .pbs_key() .div_rem_parallelized(&*self.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - (FheUint::::new(q), FheUint::::new(r)) + ( + FheUint::::new(q, cpu_key.tag.clone()), + FheUint::::new(r, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -625,8 +634,8 @@ where streams, ); ( - FheUint::::new(inner_result.0), - FheUint::::new(inner_result.1), + FheUint::::new(inner_result.0, cuda_key.tag.clone()), + FheUint::::new(inner_result.1, cuda_key.tag.clone()), ) }), }) @@ -700,14 +709,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .add_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheUint::new(inner_result) + FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .add(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -743,14 +752,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .sub_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheUint::new(inner_result) + FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .sub(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -786,14 +795,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .mul_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheUint::new(inner_result) + FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .mul(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -827,14 +836,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .bitand_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheUint::new(inner_result) + FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .bitand(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -868,14 +877,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .bitor_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheUint::new(inner_result) + FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .bitor(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -909,14 +918,14 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .bitxor_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheUint::new(inner_result) + FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .bitxor(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } }) @@ -958,7 +967,7 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .div_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheUint::new(inner_result) + FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -966,7 +975,7 @@ generic_integer_impl_operation!( cuda_key .key .div(&lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -1008,7 +1017,7 @@ generic_integer_impl_operation!( let inner_result = cpu_key .pbs_key() .rem_parallelized(&*lhs.ciphertext.on_cpu(), &*rhs.ciphertext.on_cpu()); - FheUint::new(inner_result) + FheUint::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -1016,7 +1025,7 @@ generic_integer_impl_operation!( cuda_key .key .rem(&lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -1121,14 +1130,14 @@ generic_integer_impl_shift_rotate!( let ciphertext = cpu_key .pbs_key() .left_shift_parallelized(&*lhs.ciphertext.on_cpu(), &rhs.ciphertext.on_cpu()); - FheUint::new(ciphertext) + FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .left_shift(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } } @@ -1165,14 +1174,14 @@ generic_integer_impl_shift_rotate!( let ciphertext = cpu_key .pbs_key() .right_shift_parallelized(&*lhs.ciphertext.on_cpu(), &rhs.ciphertext.on_cpu()); - FheUint::new(ciphertext) + FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .right_shift(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } } @@ -1209,14 +1218,14 @@ generic_integer_impl_shift_rotate!( let ciphertext = cpu_key .pbs_key() .rotate_left_parallelized(&*lhs.ciphertext.on_cpu(), &rhs.ciphertext.on_cpu()); - FheUint::new(ciphertext) + FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .rotate_left(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } } @@ -1253,14 +1262,14 @@ generic_integer_impl_shift_rotate!( let ciphertext = cpu_key .pbs_key() .rotate_right_parallelized(&*lhs.ciphertext.on_cpu(), &rhs.ciphertext.on_cpu()); - FheUint::new(ciphertext) + FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key .rotate_right(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }) } } @@ -1916,12 +1925,12 @@ where let ciphertext = cpu_key .pbs_key() .neg_parallelized(&*self.ciphertext.on_cpu()); - FheUint::new(ciphertext) + FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.neg(&*self.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -1982,12 +1991,12 @@ where global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { let ciphertext = cpu_key.pbs_key().bitnot(&*self.ciphertext.on_cpu()); - FheUint::new(ciphertext) + FheUint::new(ciphertext, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.bitnot(&*self.ciphertext.on_gpu(), streams); - FheUint::new(inner_result) + FheUint::new(inner_result, cuda_key.tag.clone()) }), }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs index 6531295c1b..709c91cafa 100644 --- a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs @@ -43,11 +43,14 @@ where fn overflowing_add(self, other: Self) -> (Self::Output, FheBool) { global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let (result, overflow) = cpu_key.key.unsigned_overflowing_add_parallelized( + let (result, overflow) = cpu_key.pbs_key().unsigned_overflowing_add_parallelized( &self.ciphertext.on_cpu(), &other.ciphertext.on_cpu(), ); - (FheUint::new(result), FheBool::new(overflow)) + ( + FheUint::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -57,8 +60,8 @@ where streams, ); ( - FheUint::::new(inner_result.0), - FheBool::new(inner_result.1), + FheUint::::new(inner_result.0, cuda_key.tag.clone()), + FheBool::new(inner_result.1, cuda_key.tag.clone()), ) }), }) @@ -138,9 +141,12 @@ where global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { let (result, overflow) = cpu_key - .key + .pbs_key() .unsigned_overflowing_scalar_add_parallelized(&self.ciphertext.on_cpu(), other); - (FheUint::new(result), FheBool::new(overflow)) + ( + FheUint::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -150,8 +156,8 @@ where streams, ); ( - FheUint::::new(inner_result.0), - FheBool::new(inner_result.1), + FheUint::::new(inner_result.0, cuda_key.tag.clone()), + FheBool::new(inner_result.1, cuda_key.tag.clone()), ) }), }) @@ -269,11 +275,14 @@ where fn overflowing_sub(self, other: Self) -> (Self::Output, FheBool) { global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let (result, overflow) = cpu_key.key.unsigned_overflowing_sub_parallelized( + let (result, overflow) = cpu_key.pbs_key().unsigned_overflowing_sub_parallelized( &self.ciphertext.on_cpu(), &other.ciphertext.on_cpu(), ); - (FheUint::new(result), FheBool::new(overflow)) + ( + FheUint::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -356,9 +365,12 @@ where global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { let (result, overflow) = cpu_key - .key + .pbs_key() .unsigned_overflowing_scalar_sub_parallelized(&self.ciphertext.on_cpu(), other); - (FheUint::new(result), FheBool::new(overflow)) + ( + FheUint::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -438,11 +450,14 @@ where fn overflowing_mul(self, other: Self) -> (Self::Output, FheBool) { global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let (result, overflow) = cpu_key.key.unsigned_overflowing_mul_parallelized( + let (result, overflow) = cpu_key.pbs_key().unsigned_overflowing_mul_parallelized( &self.ciphertext.on_cpu(), &other.ciphertext.on_cpu(), ); - (FheUint::new(result), FheBool::new(overflow)) + ( + FheUint::new(result, cpu_key.tag.clone()), + FheBool::new(overflow, cpu_key.tag.clone()), + ) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 9da79fdc03..451093555a 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -5,6 +5,7 @@ use super::base::FheUint; use super::inner::RadixCiphertext; use crate::error::InvalidRangeError; +use crate::high_level_api::errors::UnwrapResultExt; use crate::high_level_api::global_state; #[cfg(feature = "gpu")] use crate::high_level_api::global_state::with_thread_local_cuda_streams; @@ -57,14 +58,14 @@ where let inner_result = cpu_key .pbs_key() .scalar_eq_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .scalar_eq(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -94,14 +95,14 @@ where let inner_result = cpu_key .pbs_key() .scalar_ne_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .scalar_ne(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -137,14 +138,14 @@ where let inner_result = cpu_key .pbs_key() .scalar_lt_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .scalar_lt(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -174,14 +175,14 @@ where let inner_result = cpu_key .pbs_key() .scalar_le_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .scalar_le(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -211,14 +212,14 @@ where let inner_result = cpu_key .pbs_key() .scalar_gt_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .scalar_gt(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -248,14 +249,14 @@ where let inner_result = cpu_key .pbs_key() .scalar_ge_parallelized(&*self.ciphertext.on_cpu(), rhs); - FheBool::new(inner_result) + FheBool::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key .key .scalar_ge(&*self.ciphertext.on_gpu(), rhs, streams); - FheBool::new(inner_result) + FheBool::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -293,7 +294,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_max_parallelized(&*self.ciphertext.on_cpu(), rhs); - Self::new(inner_result) + Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -301,7 +302,7 @@ where cuda_key .key .scalar_max(&*self.ciphertext.on_gpu(), rhs, streams); - Self::new(inner_result) + Self::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -339,7 +340,7 @@ where let inner_result = cpu_key .pbs_key() .scalar_min_parallelized(&*self.ciphertext.on_cpu(), rhs); - Self::new(inner_result) + Self::new(inner_result, cpu_key.tag.clone()) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { @@ -347,7 +348,7 @@ where cuda_key .key .scalar_min(&*self.ciphertext.on_gpu(), rhs, streams); - Self::new(inner_result) + Self::new(inner_result, cuda_key.tag.clone()) }), }) } @@ -391,9 +392,9 @@ where global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { let result = cpu_key - .key + .pbs_key() .scalar_bitslice_parallelized(&self.ciphertext.on_cpu(), range)?; - Ok(FheUint::new(result)) + Ok(FheUint::new(result, cpu_key.tag.clone())) } #[cfg(feature = "gpu")] InternalServerKey::Cuda(_) => { @@ -469,23 +470,21 @@ macro_rules! generic_integer_impl_scalar_div_rem { type Output = ($concrete_type, $concrete_type); fn div_rem(self, rhs: $scalar_type) -> Self::Output { - let (q, r) = - global_state::with_internal_keys(|key| { - match key { - InternalServerKey::Cpu(cpu_key) => { - cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs) - } - #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support div_rem yet"); - } + global_state::with_internal_keys(|key| { + match key { + InternalServerKey::Cpu(cpu_key) => { + let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs); + ( + <$concrete_type>::new(q, cpu_key.tag.clone()), + <$concrete_type>::new(r, cpu_key.tag.clone()) + ) } - }); - - ( - <$concrete_type>::new(q), - <$concrete_type>::new(r) - ) + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(_) => { + panic!("Cuda devices do not support div_rem yet"); + } + } + }) } } )* // Closing second repeating pattern @@ -545,7 +544,8 @@ macro_rules! generic_integer_impl_scalar_operation { fn $rust_trait_method(self, rhs: $scalar_type) -> Self::Output { let inner_result = $closure(self, rhs); - <$concrete_type>::new(inner_result) + let tag = global_state::tag_of_internal_server_key().unwrap_display(); + <$concrete_type>::new(inner_result, tag) } } )* // Closing second repeating pattern @@ -1095,7 +1095,8 @@ macro_rules! generic_integer_impl_scalar_left_operation { $(#[$doc])* fn $rust_trait_method(self, rhs: &$concrete_type) -> Self::Output { let inner_result = $closure(*self, rhs); - <$concrete_type>::new(inner_result) + let tag = global_state::tag_of_internal_server_key().unwrap_display(); + <$concrete_type>::new(inner_result, tag) } } )* // Closing second repeating pattern diff --git a/tfhe/src/high_level_api/keys/client.rs b/tfhe/src/high_level_api/keys/client.rs index ce139d4320..8ed6d68695 100644 --- a/tfhe/src/high_level_api/keys/client.rs +++ b/tfhe/src/high_level_api/keys/client.rs @@ -6,8 +6,10 @@ use super::{CompressedServerKey, ServerKey}; use crate::high_level_api::backward_compatibility::keys::ClientKeyVersions; use crate::high_level_api::config::Config; use crate::high_level_api::keys::{CompactPrivateKey, IntegerClientKey}; +use crate::prelude::Tagged; use crate::shortint::list_compression::CompressionPrivateKeys; use crate::shortint::MessageModulus; +use crate::Tag; use concrete_csprng::seeders::Seed; use tfhe_versionable::Versionize; @@ -21,6 +23,7 @@ use tfhe_versionable::Versionize; #[versionize(ClientKeyVersions)] pub struct ClientKey { pub(crate) key: IntegerClientKey, + pub(crate) tag: Tag, } impl ClientKey { @@ -29,6 +32,7 @@ impl ClientKey { let config: Config = config.into(); Self { key: IntegerClientKey::from(config.inner), + tag: Tag::default(), } } @@ -61,6 +65,7 @@ impl ClientKey { let config: Config = config.into(); Self { key: IntegerClientKey::with_seed(config.inner, seed), + tag: Tag::default(), } } @@ -70,8 +75,10 @@ impl ClientKey { crate::integer::ClientKey, Option, Option, + Tag, ) { - self.key.into_raw_parts() + let (cks, cpk, cppk) = self.key.into_raw_parts(); + (cks, cpk, cppk, self.tag) } pub fn from_raw_parts( @@ -81,6 +88,7 @@ impl ClientKey { crate::shortint::parameters::key_switching::ShortintKeySwitchingParameters, )>, compression_key: Option, + tag: Tag, ) -> Self { Self { key: IntegerClientKey::from_raw_parts( @@ -88,6 +96,7 @@ impl ClientKey { dedicated_compact_private_key, compression_key, ), + tag, } } @@ -109,6 +118,16 @@ impl ClientKey { } } +impl Tagged for ClientKey { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + impl AsRef for ClientKey { fn as_ref(&self) -> &crate::integer::ClientKey { &self.key.key diff --git a/tfhe/src/high_level_api/keys/key_switching_key.rs b/tfhe/src/high_level_api/keys/key_switching_key.rs index abb5623183..c8f927a640 100644 --- a/tfhe/src/high_level_api/keys/key_switching_key.rs +++ b/tfhe/src/high_level_api/keys/key_switching_key.rs @@ -5,7 +5,7 @@ use crate::high_level_api::integers::{FheIntId, FheUintId}; use crate::integer::BooleanBlock; use crate::prelude::FheKeyswitch; pub use crate::shortint::parameters::key_switching::ShortintKeySwitchingParameters; -use crate::{ClientKey, FheBool, FheInt, FheUint, ServerKey}; +use crate::{ClientKey, FheBool, FheInt, FheUint, ServerKey, Tag}; use std::fmt::{Display, Formatter}; #[derive(Copy, Clone, Debug)] @@ -23,6 +23,8 @@ impl std::error::Error for IncompatibleParameters {} #[versionize(KeySwitchingKeyVersions)] pub struct KeySwitchingKey { key: crate::integer::key_switching_key::KeySwitchingKey, + tag_in: Tag, + tag_out: Tag, } impl KeySwitchingKey { @@ -58,6 +60,8 @@ impl KeySwitchingKey { (&key_pair_to.0.key.key, &key_pair_to.1.key.key), params, ), + tag_in: key_pair_from.0.tag.clone(), + tag_out: key_pair_to.0.tag.clone(), } } } @@ -69,7 +73,7 @@ where fn keyswitch(&self, input: &FheUint) -> FheUint { let radix = input.ciphertext.on_cpu(); let casted = self.key.cast(&*radix); - FheUint::new(casted) + FheUint::new(casted, self.tag_out.clone()) } } @@ -80,7 +84,7 @@ where fn keyswitch(&self, input: &FheInt) -> FheInt { let radix = input.ciphertext.on_cpu(); let casted = self.key.cast(&*radix); - FheInt::new(casted) + FheInt::new(casted, self.tag_out.clone()) } } @@ -88,6 +92,6 @@ impl FheKeyswitch for KeySwitchingKey { fn keyswitch(&self, input: &FheBool) -> FheBool { let boolean_block = input.ciphertext.on_cpu(); let casted = self.key.key.cast(boolean_block.as_ref()); - FheBool::new(BooleanBlock::new_unchecked(casted)) + FheBool::new(BooleanBlock::new_unchecked(casted), self.tag_out.clone()) } } diff --git a/tfhe/src/high_level_api/keys/public.rs b/tfhe/src/high_level_api/keys/public.rs index 97e04d704c..722e7a4357 100644 --- a/tfhe/src/high_level_api/keys/public.rs +++ b/tfhe/src/high_level_api/keys/public.rs @@ -21,8 +21,9 @@ use crate::backward_compatibility::keys::{ PublicKeyVersions, }; use crate::high_level_api::keys::{IntegerCompactPublicKey, IntegerCompressedCompactPublicKey}; +use crate::prelude::Tagged; use crate::shortint::MessageModulus; -use crate::Error; +use crate::{Error, Tag}; /// Classical public key. /// @@ -31,6 +32,7 @@ use crate::Error; #[versionize(PublicKeyVersions)] pub struct PublicKey { pub(in crate::high_level_api) key: crate::integer::PublicKey, + pub(crate) tag: Tag, } impl PublicKey { @@ -39,15 +41,16 @@ impl PublicKey { let base_integer_key = crate::integer::PublicKey::new(&client_key.key.key); Self { key: base_integer_key, + tag: client_key.tag.clone(), } } - pub fn into_raw_parts(self) -> crate::integer::PublicKey { - self.key + pub fn into_raw_parts(self) -> (crate::integer::PublicKey, Tag) { + (self.key, self.tag) } - pub fn from_raw_parts(key: crate::integer::PublicKey) -> Self { - Self { key } + pub fn from_raw_parts(key: crate::integer::PublicKey, tag: Tag) -> Self { + Self { key, tag } } pub(crate) fn message_modulus(&self) -> MessageModulus { @@ -55,11 +58,22 @@ impl PublicKey { } } +impl Tagged for PublicKey { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + /// Compressed classical public key. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(CompressedPublicKeyVersions)] pub struct CompressedPublicKey { pub(in crate::high_level_api) key: crate::integer::CompressedPublicKey, + pub(crate) tag: Tag, } impl CompressedPublicKey { @@ -67,20 +81,22 @@ impl CompressedPublicKey { let base_integer_key = crate::integer::CompressedPublicKey::new(&client_key.key.key); Self { key: base_integer_key, + tag: client_key.tag.clone(), } } - pub fn into_raw_parts(self) -> crate::integer::CompressedPublicKey { - self.key + pub fn into_raw_parts(self) -> (crate::integer::CompressedPublicKey, Tag) { + (self.key, self.tag) } - pub fn from_raw_parts(key: crate::integer::CompressedPublicKey) -> Self { - Self { key } + pub fn from_raw_parts(key: crate::integer::CompressedPublicKey, tag: Tag) -> Self { + Self { key, tag } } pub fn decompress(&self) -> PublicKey { PublicKey { key: self.key.decompress(), + tag: self.tag.clone(), } } @@ -89,6 +105,16 @@ impl CompressedPublicKey { } } +impl Tagged for CompressedPublicKey { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + /// A more compact public key /// /// Compared to the [PublicKey], this one is much smaller @@ -97,6 +123,7 @@ impl CompressedPublicKey { #[versionize(CompactPublicKeyVersions)] pub struct CompactPublicKey { pub(in crate::high_level_api) key: IntegerCompactPublicKey, + pub(crate) tag: Tag, } impl CompactPublicKey { @@ -108,24 +135,39 @@ impl CompactPublicKey { pub fn new(client_key: &ClientKey) -> Self { Self { key: IntegerCompactPublicKey::new(&client_key.key), + tag: client_key.tag.clone(), } } pub fn try_new(client_key: &ClientKey) -> Result { - IntegerCompactPublicKey::try_new(&client_key.key).map(|key| Self { key }) + IntegerCompactPublicKey::try_new(&client_key.key).map(|key| Self { + key, + tag: client_key.tag.clone(), + }) } - pub fn into_raw_parts(self) -> crate::integer::public_key::CompactPublicKey { - self.key.into_raw_parts() + pub fn into_raw_parts(self) -> (crate::integer::public_key::CompactPublicKey, Tag) { + (self.key.into_raw_parts(), self.tag) } - pub fn from_raw_parts(key: crate::integer::public_key::CompactPublicKey) -> Self { + pub fn from_raw_parts(key: crate::integer::public_key::CompactPublicKey, tag: Tag) -> Self { Self { key: IntegerCompactPublicKey::from_raw_parts(key), + tag, } } } +impl Tagged for CompactPublicKey { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + /// Compressed variant of [CompactPublicKey] /// /// The compression of [CompactPublicKey] allows to save disk space @@ -134,6 +176,7 @@ impl CompactPublicKey { #[versionize(CompressedCompactPublicKeyVersions)] pub struct CompressedCompactPublicKey { pub(in crate::high_level_api) key: IntegerCompressedCompactPublicKey, + pub(crate) tag: Tag, } impl CompressedCompactPublicKey { @@ -145,18 +188,20 @@ impl CompressedCompactPublicKey { pub fn new(client_key: &ClientKey) -> Self { Self { key: IntegerCompressedCompactPublicKey::new(&client_key.key), + tag: client_key.tag.clone(), } } /// Deconstruct a [`CompressedCompactPublicKey`] into its constituents. - pub fn into_raw_parts(self) -> crate::integer::CompressedCompactPublicKey { - self.key.into_raw_parts() + pub fn into_raw_parts(self) -> (crate::integer::CompressedCompactPublicKey, Tag) { + (self.key.into_raw_parts(), self.tag) } /// Construct a [`CompressedCompactPublicKey`] from its constituents. - pub fn from_raw_parts(key: crate::integer::CompressedCompactPublicKey) -> Self { + pub fn from_raw_parts(key: crate::integer::CompressedCompactPublicKey, tag: Tag) -> Self { Self { key: IntegerCompressedCompactPublicKey::from_raw_parts(key), + tag, } } @@ -164,6 +209,17 @@ impl CompressedCompactPublicKey { pub fn decompress(&self) -> CompactPublicKey { CompactPublicKey { key: self.key.decompress(), + tag: self.tag.clone(), } } } + +impl Tagged for CompressedCompactPublicKey { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index 3875791af3..e911ca9677 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -1,16 +1,18 @@ use tfhe_versionable::Versionize; +use super::ClientKey; use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKeyVersions}; #[cfg(feature = "gpu")] use crate::core_crypto::gpu::{synchronize_devices, CudaStreams}; use crate::high_level_api::keys::{IntegerCompressedServerKey, IntegerServerKey}; +use crate::prelude::Tagged; use crate::shortint::list_compression::{ CompressedCompressionKey, CompressedDecompressionKey, CompressionKey, DecompressionKey, }; +use crate::shortint::MessageModulus; +use crate::Tag; use std::sync::Arc; -use super::ClientKey; - /// Key of the server /// /// This key contains the different keys needed to be able to do computations for @@ -25,12 +27,14 @@ use super::ClientKey; #[versionize(ServerKeyVersions)] pub struct ServerKey { pub(crate) key: Arc, + pub(crate) tag: Tag, } impl ServerKey { pub fn new(keys: &ClientKey) -> Self { Self { key: Arc::new(IntegerServerKey::new(&keys.key)), + tag: keys.tag.clone(), } } @@ -41,6 +45,7 @@ impl ServerKey { Option, Option, Option, + Tag, ) { let IntegerServerKey { key, @@ -54,6 +59,7 @@ impl ServerKey { cpk_key_switching_key_material, compression_key, decompression_key, + self.tag, ) } @@ -64,6 +70,7 @@ impl ServerKey { >, compression_key: Option, decompression_key: Option, + tag: Tag, ) -> Self { Self { key: Arc::new(IntegerServerKey { @@ -72,8 +79,33 @@ impl ServerKey { compression_key, decompression_key, }), + tag, } } + + pub(in crate::high_level_api) fn pbs_key(&self) -> &crate::integer::ServerKey { + self.key.pbs_key() + } + + pub(in crate::high_level_api) fn cpk_casting_key( + &self, + ) -> Option { + self.key.cpk_casting_key() + } + + pub(in crate::high_level_api) fn message_modulus(&self) -> MessageModulus { + self.key.message_modulus() + } +} + +impl Tagged for ServerKey { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } } impl AsRef for ServerKey { @@ -100,6 +132,7 @@ impl AsRef for ServerKey { #[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))] struct SerializableServerKey<'a> { pub(crate) integer_key: &'a IntegerServerKey, + pub(crate) tag: &'a Tag, } impl serde::Serialize for ServerKey { @@ -109,6 +142,7 @@ impl serde::Serialize for ServerKey { { SerializableServerKey { integer_key: &self.key, + tag: &self.tag, } .serialize(serializer) } @@ -117,6 +151,7 @@ impl serde::Serialize for ServerKey { #[derive(serde::Deserialize)] struct DeserializableServerKey { pub(crate) integer_key: IntegerServerKey, + pub(crate) tag: Tag, } impl<'de> serde::Deserialize<'de> for ServerKey { @@ -126,6 +161,7 @@ impl<'de> serde::Deserialize<'de> for ServerKey { { DeserializableServerKey::deserialize(deserializer).map(|deserialized| Self { key: Arc::new(deserialized.integer_key), + tag: deserialized.tag, }) } } @@ -142,12 +178,14 @@ impl<'de> serde::Deserialize<'de> for ServerKey { #[versionize(CompressedServerKeyVersions)] pub struct CompressedServerKey { pub(crate) integer_key: IntegerCompressedServerKey, + pub(crate) tag: Tag, } impl CompressedServerKey { pub fn new(keys: &ClientKey) -> Self { Self { integer_key: IntegerCompressedServerKey::new(&keys.key), + tag: keys.tag.clone(), } } @@ -158,8 +196,10 @@ impl CompressedServerKey { Option, Option, Option, + Tag, ) { - self.integer_key.into_raw_parts() + let (a, b, c, d) = self.integer_key.into_raw_parts(); + (a, b, c, d, self.tag) } pub fn from_raw_parts( @@ -169,6 +209,7 @@ impl CompressedServerKey { >, compression_key: Option, decompression_key: Option, + tag: Tag, ) -> Self { Self { integer_key: IntegerCompressedServerKey::from_raw_parts( @@ -177,12 +218,14 @@ impl CompressedServerKey { compression_key, decompression_key, ), + tag, } } pub fn decompress(&self) -> ServerKey { ServerKey { key: Arc::new(self.integer_key.decompress()), + tag: self.tag.clone(), } } @@ -197,14 +240,26 @@ impl CompressedServerKey { synchronize_devices(streams.len() as u32); CudaServerKey { key: Arc::new(cuda_key), + tag: self.tag.clone(), } } } +impl Tagged for CompressedServerKey { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + #[cfg(feature = "gpu")] #[derive(Clone)] pub struct CudaServerKey { pub(crate) key: Arc, + pub(crate) tag: Tag, } #[cfg(feature = "gpu")] @@ -214,15 +269,26 @@ impl CudaServerKey { } } +#[cfg(feature = "gpu")] +impl Tagged for CudaServerKey { + fn tag(&self) -> &Tag { + &self.tag + } + + fn tag_mut(&mut self) -> &mut Tag { + &mut self.tag + } +} + pub enum InternalServerKey { - Cpu(Arc), + Cpu(ServerKey), #[cfg(feature = "gpu")] Cuda(CudaServerKey), } impl From for InternalServerKey { fn from(value: ServerKey) -> Self { - Self::Cpu(value.key) + Self::Cpu(value) } } #[cfg(feature = "gpu")] diff --git a/tfhe/src/high_level_api/mod.rs b/tfhe/src/high_level_api/mod.rs index 4588e85dde..3b61ba1fb2 100644 --- a/tfhe/src/high_level_api/mod.rs +++ b/tfhe/src/high_level_api/mod.rs @@ -60,6 +60,8 @@ pub use compact_list::{ pub use compressed_ciphertext_list::{CompressedCiphertextList, CompressedCiphertextListBuilder}; pub use safe_serialize::{safe_serialize, safe_serialize_versioned}; +pub use tag::Tag; + mod booleans; mod compressed_ciphertext_list; mod config; @@ -73,6 +75,7 @@ mod utils; pub mod array; pub mod backward_compatibility; mod compact_list; +mod tag; pub(in crate::high_level_api) mod details; /// The tfhe prelude. diff --git a/tfhe/src/high_level_api/prelude.rs b/tfhe/src/high_level_api/prelude.rs index 3b4365d774..06c1fd84de 100644 --- a/tfhe/src/high_level_api/prelude.rs +++ b/tfhe/src/high_level_api/prelude.rs @@ -9,7 +9,7 @@ pub use crate::high_level_api::traits::{ BitSlice, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin, FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse, OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, RotateLeftAssign, RotateRight, - RotateRightAssign, + RotateRightAssign, Tagged, }; pub use crate::conformance::ParameterSetConformant; diff --git a/tfhe/src/high_level_api/tag.rs b/tfhe/src/high_level_api/tag.rs new file mode 100644 index 0000000000..b41c6c73c1 --- /dev/null +++ b/tfhe/src/high_level_api/tag.rs @@ -0,0 +1,449 @@ +use crate::high_level_api::backward_compatibility::tag::TagVersions; +use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; + +const STACK_ARRAY_SIZE: usize = std::mem::size_of::>() - 1; + +/// Simple short optimized vec, where if the data is small enough +/// (<= std::mem::size_of::>() - 1) the data will be stored on the stack +/// +/// Once a true heap allocated Vec was needed, it won't be deallocated in favor +/// of stack data. +#[derive(Clone, Debug)] +pub(in crate::high_level_api) enum SmallVec { + Stack { + bytes: [u8; STACK_ARRAY_SIZE], + // The array has a fixed size, but the user may not use all of it + // so we keep track of the actual len + len: u8, + }, + Heap(Vec), +} + +impl Default for SmallVec { + fn default() -> Self { + Self::Stack { + bytes: Default::default(), + len: 0, + } + } +} +impl PartialEq for SmallVec { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::Stack { + bytes: l_bytes, + len: l_len, + }, + Self::Stack { + bytes: r_bytes, + len: r_len, + }, + ) => l_len == r_len && l_bytes[..usize::from(*l_len)] == r_bytes[..usize::from(*l_len)], + (Self::Heap(l_vec), Self::Heap(r_vec)) => l_vec == r_vec, + ( + Self::Heap(l_vec), + Self::Stack { + bytes: r_bytes, + len: r_len, + }, + ) => l_vec.len() == usize::from(*r_len) && l_vec == &r_bytes[..usize::from(*r_len)], + ( + Self::Stack { + bytes: l_bytes, + len: l_len, + }, + Self::Heap(r_vec), + ) => usize::from(*l_len) == r_vec.len() && &l_bytes[..usize::from(*l_len)] == r_vec, + } + } +} + +impl Eq for SmallVec {} + +impl SmallVec { + /// Returns a slice to the bytes stored + pub fn data(&self) -> &[u8] { + match self { + Self::Stack { bytes, len } => &bytes[..usize::from(*len)], + Self::Heap(vec) => vec.as_slice(), + } + } + + /// Returns a slice to the bytes stored (same a [Self::data]) + pub fn as_slice(&self) -> &[u8] { + self.data() + } + + /// Returns a mutable slice to the bytes stored + pub fn as_mut_slice(&mut self) -> &mut [u8] { + match self { + Self::Stack { bytes, len } => &mut bytes[..usize::from(*len)], + Self::Heap(vec) => vec.as_mut_slice(), + } + } + + /// Returns the len, i.e. the number of bytes stored + pub fn len(&self) -> usize { + match self { + Self::Stack { len, .. } => usize::from(*len), + Self::Heap(vec) => vec.len(), + } + } + + /// Returns whether self is empty + pub fn is_empty(&self) -> bool { + match self { + Self::Stack { len, .. } => *len == 0, + Self::Heap(vec) => vec.is_empty(), + } + } + + /// Return the u64 value when interpreting the bytes as a `u64` + /// + /// * Bytes are interpreted in little endian + /// * Bytes above the 8th are ignored + pub fn as_u64(&self) -> u64 { + let mut le_bytes = [0u8; u64::BITS as usize / 8]; + let data = self.data(); + let smallest = le_bytes.len().min(data.len()); + le_bytes[..smallest].copy_from_slice(&data[..smallest]); + + u64::from_le_bytes(le_bytes) + } + + /// Return the u128 value when interpreting the bytes as a `u128` + /// + /// * Bytes are interpreted in little endian + /// * Bytes above the 16th are ignored + pub fn as_u128(&self) -> u128 { + let mut le_bytes = [0u8; u128::BITS as usize / 8]; + let data = self.data(); + let smallest = le_bytes.len().min(data.len()); + le_bytes[..smallest].copy_from_slice(&data[..smallest]); + + u128::from_le_bytes(le_bytes) + } + + /// Sets the data stored in the tag + /// + /// This overwrites existing data stored + pub fn set_data(&mut self, data: &[u8]) { + match self { + Self::Stack { bytes, len } => { + if data.len() > bytes.len() { + // There is not enough space, so we have to allocate + // a Vec + *self = Self::Heap(data.to_vec()); + } else { + bytes[..data.len()].copy_from_slice(data); + *len = data.len() as u8; + } + } + Self::Heap(vec) => { + // Even if the data could fit in the Stack array, + // Since, we already have a vec allocated we use it instead. + // + // And in that case, there won't be any allocations since, + // to have a vec in the first place, the allocated size is > + // size_of::> + // + // But of course, if the new data is larger than the vec, a new + // allocation will be made + vec.clear(); + vec.extend_from_slice(data); + } + } + } + + /// Sets the tag with the given u64 value + /// + /// * Bytes are stored in little endian + /// * This overwrites existing data stored + pub fn set_u64(&mut self, value: u64) { + let le_bytes = value.to_le_bytes(); + self.set_data(le_bytes.as_slice()); + } + + /// Sets the tag with the given u128 value + /// + /// * Bytes are stored in little endian + /// * This overwrites existing data stored + pub fn set_u128(&mut self, value: u128) { + let le_bytes = value.to_le_bytes(); + self.set_data(le_bytes.as_slice()); + } + + // Creates a SmallVec from the vec, but, only re-uses the vec + // if its len would not fit on the stack part. + // + // Meant for versioning and deserializing + fn from_vec_conservative(vec: Vec) -> Self { + // We only re-use the versioned vec, if the SmallVec would actually + // have had its data on the heap, otherwise we prefer to keep data on stack + // as its cheaper in memory and copies + if vec.len() > STACK_ARRAY_SIZE { + Self::Heap(vec) + } else { + let mut data = Self::default(); + data.set_data(vec.as_slice()); + data + } + } +} + +impl serde::Serialize for SmallVec { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(self.data()) + } +} + +struct SmallVecVisitor; + +impl<'de> serde::de::Visitor<'de> for SmallVecVisitor { + type Value = SmallVec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a slice of bytes (&[u8]) or Vec") + } + + fn visit_bytes(self, bytes: &[u8]) -> Result + where + E: serde::de::Error, + { + let mut vec = SmallVec::default(); + vec.set_data(bytes); + Ok(vec) + } + + fn visit_byte_buf(self, bytes: Vec) -> Result + where + E: serde::de::Error, + { + Ok(SmallVec::from_vec_conservative(bytes)) + } +} + +impl Versionize for SmallVec { + type Versioned<'vers> = &'vers [u8] + where + Self: 'vers; + + fn versionize(&self) -> Self::Versioned<'_> { + self.data() + } +} + +impl VersionizeOwned for SmallVec { + type VersionedOwned = Vec; + + fn versionize_owned(self) -> Self::VersionedOwned { + match self { + Self::Stack { bytes, len } => bytes[..usize::from(len)].to_vec(), + Self::Heap(vec) => vec, + } + } +} + +impl Unversionize for SmallVec { + fn unversionize(versioned: Self::VersionedOwned) -> Result { + Ok(Self::from_vec_conservative(versioned)) + } +} + +impl<'de> serde::Deserialize<'de> for SmallVec { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_bytes(SmallVecVisitor) + } +} + +/// Tag +/// +/// The `Tag` allows to store bytes alongside entities (keys, and ciphertexts) +/// the main purpose of this system is to `tag` / identify ciphertext with their keys. +/// +/// TFHE-RS does not interpret or check this data, it only stores it and passes it around +/// like so: +/// +/// * When encrypted, a ciphertext gets the tag of the key used to encrypt it. +/// * Ciphertexts resulting from operations (add, sub, etc.) get the tag from the ServerKey used +/// * PublicKey gets its tag from the ClientKey that was used to create it +/// * ServerKey gets its tag from the ClientKey that was used to create it +/// +/// User can change the tag of any entities at any point. +/// +/// # Example +/// +/// ``` +/// use rand::random; +/// use tfhe::prelude::*; +/// use tfhe::{ClientKey, ConfigBuilder, FheUint32, ServerKey}; +/// +/// // Generate the client key then set its tag +/// let mut cks = ClientKey::generate(ConfigBuilder::default()); +/// let tag_value = random(); +/// cks.tag_mut().set_u64(tag_value); +/// assert_eq!(cks.tag().as_u64(), tag_value); +/// +/// // The server key inherits the client key tag +/// let sks = ServerKey::new(&cks); +/// assert_eq!(sks.tag(), cks.tag()); +/// +/// // Encrypted data inherits the tag of the encryption key +/// let a = FheUint32::encrypt(32832u32, &cks); +/// assert_eq!(a.tag(), cks.tag()); +/// ``` +#[derive( + Default, Clone, Debug, serde::Serialize, serde::Deserialize, Versionize, PartialEq, Eq, +)] +#[versionize(TagVersions)] +pub struct Tag { + // We don't want the enum to be public + inner: SmallVec, +} + +impl Tag { + /// Returns a slice to the bytes stored + pub fn data(&self) -> &[u8] { + self.inner.data() + } + + /// Returns a slice to the bytes stored (same a [Self::data]) + pub fn as_slice(&self) -> &[u8] { + self.inner.as_slice() + } + + /// Returns a mutable slice to the bytes stored + pub fn as_mut_slice(&mut self) -> &mut [u8] { + self.inner.as_mut_slice() + } + + /// Returns the len, i.e. the number of bytes stored in the tag + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Returns whether the tag is empty + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Return the u64 value when interpreting the bytes as a `u64` + /// + /// * Bytes are interpreted in little endian + /// * Bytes above the 8th are ignored + pub fn as_u64(&self) -> u64 { + self.inner.as_u64() + } + + /// Return the u128 value when interpreting the bytes as a `u128` + /// + /// * Bytes are interpreted in little endian + /// * Bytes above the 16th are ignored + pub fn as_u128(&self) -> u128 { + self.inner.as_u128() + } + + /// Sets the data stored in the tag + /// + /// This overwrites existing data stored + pub fn set_data(&mut self, data: &[u8]) { + self.inner.set_data(data); + } + + /// Sets the tag with the given u64 value + /// + /// * Bytes are stored in little endian + /// * This overwrites existing data stored + pub fn set_u64(&mut self, value: u64) { + self.inner.set_u64(value); + } + + /// Sets the tag with the given u128 value + /// + /// * Bytes are stored in little endian + /// * This overwrites existing data stored + pub fn set_u128(&mut self, value: u128) { + self.inner.set_u128(value); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[test] + fn test_small_vec() { + let mut vec_1 = SmallVec::default(); + vec_1.set_data(&[1, 2, 3, 4, 5]); + + let mut vec_2 = SmallVec::default(); + vec_2.set_data(vec_1.data()); + + assert!(matches!(vec_1, SmallVec::Stack { .. })); + assert!(matches!(vec_2, SmallVec::Stack { .. })); + assert_eq!(vec_2.len(), vec_1.len()); + assert_eq!(vec_1.len(), 5); + assert_eq!(vec_1, vec_2); // Test both ways + assert_eq!(vec_2, vec_1); + + // Put something big in vec_1, we expect the data to be on the heap now + let big_data = (0..500u64).map(|x| (x % 256) as u8).collect::>(); + vec_1.set_data(&big_data); + assert!(matches!(vec_1, SmallVec::Heap(_))); + assert!(matches!(vec_2, SmallVec::Stack { .. })); + assert_ne!(vec_2.len(), vec_1.len()); + assert_eq!(vec_1.len(), big_data.len()); + assert_ne!(vec_1, vec_2); + assert_ne!(vec_2, vec_1); + + // Put something the same big data in vec_2, + // we also expect the data to be on the heap now + vec_2.set_data(&big_data); + assert!(matches!(vec_1, SmallVec::Heap(_))); + assert!(matches!(vec_2, SmallVec::Heap(_))); + assert_eq!(vec_2.len(), vec_1.len()); + assert_eq!(vec_1.len(), big_data.len()); + assert_eq!(vec_1, vec_2); // Test both ways + assert_eq!(vec_2, vec_1); + + // Now put back something small in vec 1 + // We expect the data to still be on the heap, since + // the heap was allocated to store the previous big data + vec_1.set_data(&[1, 2, 3, 4, 5]); + assert!(matches!(vec_1, SmallVec::Heap(_))); + assert_eq!(vec_1.len(), 5); + assert_eq!(vec_1.data(), &[1, 2, 3, 4, 5]); + assert_ne!(vec_1, vec_2); + assert_ne!(vec_2, vec_1); + } + + #[test] + fn test_small_vec_u64_u128() { + let mut rng = rand::thread_rng(); + + let mut vec = SmallVec::default(); + { + let value = rng.gen(); + vec.set_u64(value); + assert_eq!(vec.as_u64(), value); + + assert_eq!(vec.as_u128(), u128::from(value)); + } + + { + let value = rng.gen(); + vec.set_u128(value); + assert_eq!(vec.as_u128(), value); + + assert_eq!(vec.as_u64(), value as u64); + } + } +} diff --git a/tfhe/src/high_level_api/tests/mod.rs b/tfhe/src/high_level_api/tests/mod.rs index 33d17238d3..5624685a95 100644 --- a/tfhe/src/high_level_api/tests/mod.rs +++ b/tfhe/src/high_level_api/tests/mod.rs @@ -1,3 +1,5 @@ +mod tags_on_entities; + use crate::high_level_api::prelude::*; use crate::high_level_api::{ generate_keys, ClientKey, ConfigBuilder, FheBool, FheUint256, FheUint8, PublicKey, diff --git a/tfhe/src/high_level_api/tests/tags_on_entities.rs b/tfhe/src/high_level_api/tests/tags_on_entities.rs new file mode 100644 index 0000000000..0945164312 --- /dev/null +++ b/tfhe/src/high_level_api/tests/tags_on_entities.rs @@ -0,0 +1,324 @@ +use crate::prelude::*; +use crate::shortint::parameters::compact_public_key_only::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; +use crate::shortint::parameters::key_switching::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; +use crate::shortint::parameters::*; +use crate::shortint::ClassicPBSParameters; +use crate::{ + set_server_key, ClientKey, CompactCiphertextList, CompactCiphertextListExpander, + CompactPublicKey, CompressedCiphertextList, CompressedCiphertextListBuilder, CompressedFheBool, + CompressedFheInt32, CompressedFheUint32, CompressedServerKey, ConfigBuilder, Device, FheBool, + FheInt32, FheInt64, FheUint32, ServerKey, +}; +use rand::random; + +#[test] +fn test_tag_propagation_cpu() { + test_tag_propagation( + Device::Cpu, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + Some(( + PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + )), + Some(COMP_PARAM_MESSAGE_2_CARRY_2), + ) +} + +#[test] +#[cfg(feature = "zk-pok")] +fn test_tag_propagation_zk_pok() { + use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + use crate::ProvenCompactCiphertextList; + + let config = + ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64).build(); + let crs = crate::zk::CompactPkeCrs::from_config(config, (2 * 32) + (2 * 64) + 2).unwrap(); + + let mut cks = ClientKey::generate(config); + let tag_value = random(); + cks.tag_mut().set_u64(tag_value); + let cks = serialize_then_deserialize(cks); + assert_eq!(cks.tag().as_u64(), tag_value); + + let sks = ServerKey::new(&cks); + set_server_key(sks); + + let cpk = CompactPublicKey::new(&cks); + assert_eq!(cpk.tag(), cks.tag()); + + let mut builder = CompactCiphertextList::builder(&cpk); + + let list_packed = builder + .push(32u32) + .push(1u32) + .push(-1i64) + .push(i64::MIN) + .push(false) + .push(true) + .build_with_proof_packed(crs.public_params(), crate::zk::ZkComputeLoad::Proof) + .unwrap(); + + let list_packed: ProvenCompactCiphertextList = serialize_then_deserialize(list_packed); + assert_eq!(list_packed.tag(), cks.tag()); + + let expander = list_packed + .verify_and_expand(crs.public_params(), &cpk) + .unwrap(); + + { + let au32: FheUint32 = expander.get(0).unwrap().unwrap(); + let bu32: FheUint32 = expander.get(1).unwrap().unwrap(); + assert_eq!(au32.tag(), cks.tag()); + assert_eq!(bu32.tag(), cks.tag()); + + let cu32 = au32 + bu32; + assert_eq!(cu32.tag(), cks.tag()); + } + + { + let ai64: FheInt64 = expander.get(2).unwrap().unwrap(); + let bi64: FheInt64 = expander.get(3).unwrap().unwrap(); + assert_eq!(ai64.tag(), cks.tag()); + assert_eq!(bi64.tag(), cks.tag()); + + let ci64 = ai64 + bi64; + assert_eq!(ci64.tag(), cks.tag()); + } + + { + let abool: FheBool = expander.get(4).unwrap().unwrap(); + let bbool: FheBool = expander.get(5).unwrap().unwrap(); + assert_eq!(abool.tag(), cks.tag()); + assert_eq!(bbool.tag(), cks.tag()); + + let cbool = abool & bbool; + assert_eq!(cbool.tag(), cks.tag()); + } +} + +#[test] +#[cfg(feature = "gpu")] +fn test_tag_propagation_gpu() { + test_tag_propagation( + Device::CudaGpu, + PARAM_MESSAGE_2_CARRY_2, + None, + Some(COMP_PARAM_MESSAGE_2_CARRY_2), + ) +} + +fn serialize_then_deserialize(value: T) -> T +where + T: serde::Serialize + for<'a> serde::de::Deserialize<'a>, +{ + let serialized = bincode::serialize(&value).unwrap(); + bincode::deserialize(&serialized).unwrap() +} + +fn test_tag_propagation( + device: Device, + pbs_parameters: ClassicPBSParameters, + dedicated_compact_public_key_parameters: Option<( + CompactPublicKeyEncryptionParameters, + ShortintKeySwitchingParameters, + )>, + comp_parameters: Option, +) { + let mut builder = ConfigBuilder::with_custom_parameters(pbs_parameters); + if let Some(parameters) = dedicated_compact_public_key_parameters { + builder = builder.use_dedicated_compact_public_key_parameters(parameters); + } + if let Some(parameters) = comp_parameters { + builder = builder.enable_compression(parameters); + } + let config = builder.build(); + + let mut cks = ClientKey::generate(config); + let tag_value = random(); + cks.tag_mut().set_u64(tag_value); + let cks = serialize_then_deserialize(cks); + assert_eq!(cks.tag().as_u64(), tag_value); + + let compressed_sks = CompressedServerKey::new(&cks); + let compressed_sks = serialize_then_deserialize(compressed_sks); + assert_eq!(compressed_sks.tag(), cks.tag()); + + match device { + Device::Cpu => { + let sks = ServerKey::new(&cks); + let sks = serialize_then_deserialize(sks); + assert_eq!(sks.tag(), cks.tag()); + + // Now test when the sks comes from a compressed one + let sks = compressed_sks.decompress(); + let sks = serialize_then_deserialize(sks); + assert_eq!(sks.tag(), cks.tag()); + + set_server_key(sks); + } + #[cfg(feature = "gpu")] + Device::CudaGpu => { + let sks = compressed_sks.decompress_to_gpu(); + assert_eq!(sks.tag(), cks.tag()); + + set_server_key(sks); + } + } + + // Check encrypting regular ct with client key + { + let mut compression_builder = CompressedCiphertextListBuilder::new(); + + // Check FheUint have a tag + { + let ct_a = FheUint32::encrypt(8182u32, &cks); + let ct_a = serialize_then_deserialize(ct_a); + assert_eq!(ct_a.tag(), cks.tag()); + + let ct_b = FheUint32::encrypt(8182u32, &cks); + assert_eq!(ct_b.tag(), cks.tag()); + + let ct_c = ct_a + ct_b; + assert_eq!(ct_c.tag(), cks.tag()); + + compression_builder.push(ct_c); + } + + // Check FheInt have a tag + { + let ct_a = FheInt32::encrypt(-1i32, &cks); + let ct_a = serialize_then_deserialize(ct_a); + assert_eq!(ct_a.tag(), cks.tag()); + + let ct_b = FheInt32::encrypt(i32::MIN, &cks); + assert_eq!(ct_b.tag(), cks.tag()); + + let ct_c = ct_a + ct_b; + assert_eq!(ct_c.tag(), cks.tag()); + + compression_builder.push(ct_c); + } + + // Check FheBool have a tag + { + let ct_a = FheBool::encrypt(false, &cks); + let ct_a = serialize_then_deserialize(ct_a); + assert_eq!(ct_a.tag(), cks.tag()); + + let ct_b = FheBool::encrypt(true, &cks); + assert_eq!(ct_b.tag(), cks.tag()); + + let ct_c = ct_a | ct_b; + assert_eq!(ct_c.tag(), cks.tag()); + + compression_builder.push(ct_c); + } + + if device == Device::Cpu { + // Cuda do no yet support compressing + let compressed_list = compression_builder.build().unwrap(); + assert_eq!(compressed_list.tag(), cks.tag()); + + let serialized = bincode::serialize(&compressed_list).unwrap(); + let compressed_list: CompressedCiphertextList = + bincode::deserialize(&serialized).unwrap(); + assert_eq!(compressed_list.tag(), cks.tag()); + + let a: FheUint32 = compressed_list.get(0).unwrap().unwrap(); + assert_eq!(a.tag(), cks.tag()); + let b: FheInt32 = compressed_list.get(1).unwrap().unwrap(); + assert_eq!(b.tag(), cks.tag()); + let c: FheBool = compressed_list.get(2).unwrap().unwrap(); + assert_eq!(c.tag(), cks.tag()); + } + } + + // Check compressed encryption + { + { + let ct_a = CompressedFheUint32::encrypt(8182u32, &cks); + assert_eq!(ct_a.tag(), cks.tag()); + + let ct_a = ct_a.decompress(); + assert_eq!(ct_a.tag(), cks.tag()); + } + + { + let ct_a = CompressedFheInt32::encrypt(-1i32, &cks); + assert_eq!(ct_a.tag(), cks.tag()); + + let ct_a = ct_a.decompress(); + assert_eq!(ct_a.tag(), cks.tag()); + } + + { + let ct_a = CompressedFheBool::encrypt(false, &cks); + assert_eq!(ct_a.tag(), cks.tag()); + + let ct_a = ct_a.decompress(); + assert_eq!(ct_a.tag(), cks.tag()); + } + } + + // Test compact public key stuff + if device == Device::Cpu { + let cpk = CompactPublicKey::new(&cks); + let cpk = serialize_then_deserialize(cpk); + assert_eq!(cpk.tag(), cks.tag()); + + let mut builder = CompactCiphertextList::builder(&cpk); + builder + .push(32u32) + .push(1u32) + .push(-1i64) + .push(i64::MIN) + .push(false) + .push(true); + + let expand_and_check_tags = |expander: CompactCiphertextListExpander, cks: &ClientKey| { + { + let au32: FheUint32 = expander.get(0).unwrap().unwrap(); + let bu32: FheUint32 = expander.get(1).unwrap().unwrap(); + assert_eq!(au32.tag(), cks.tag()); + assert_eq!(bu32.tag(), cks.tag()); + + let cu32 = au32 + bu32; + assert_eq!(cu32.tag(), cks.tag()); + } + + { + let ai64: FheInt64 = expander.get(2).unwrap().unwrap(); + let bi64: FheInt64 = expander.get(3).unwrap().unwrap(); + assert_eq!(ai64.tag(), cks.tag()); + assert_eq!(bi64.tag(), cks.tag()); + + let ci64 = ai64 + bi64; + assert_eq!(ci64.tag(), cks.tag()); + } + + { + let abool: FheBool = expander.get(4).unwrap().unwrap(); + let bbool: FheBool = expander.get(5).unwrap().unwrap(); + assert_eq!(abool.tag(), cks.tag()); + assert_eq!(bbool.tag(), cks.tag()); + + let cbool = abool & bbool; + assert_eq!(cbool.tag(), cks.tag()); + } + }; + + { + let list = builder.build(); + let list: CompactCiphertextList = serialize_then_deserialize(list); + assert_eq!(list.tag(), cks.tag()); + expand_and_check_tags(list.expand().unwrap(), &cks); + } + + { + let list_packed = builder.build_packed(); + let list_packed: CompactCiphertextList = serialize_then_deserialize(list_packed); + assert_eq!(list_packed.tag(), cks.tag()); + expand_and_check_tags(list_packed.expand().unwrap(), &cks); + } + } +} diff --git a/tfhe/src/high_level_api/traits.rs b/tfhe/src/high_level_api/traits.rs index 4240dda9de..850ec3c952 100644 --- a/tfhe/src/high_level_api/traits.rs +++ b/tfhe/src/high_level_api/traits.rs @@ -2,7 +2,7 @@ use std::ops::RangeBounds; use crate::error::InvalidRangeError; use crate::high_level_api::ClientKey; -use crate::FheBool; +use crate::{FheBool, Tag}; /// Trait used to have a generic way of creating a value of a FHE type /// from a native value. @@ -193,3 +193,9 @@ pub trait BitSlice { where R: RangeBounds; } + +pub trait Tagged { + fn tag(&self) -> &Tag; + + fn tag_mut(&mut self) -> &mut Tag; +} diff --git a/tfhe/src/high_level_api/utils.rs b/tfhe/src/high_level_api/utils.rs index 2aeb5ffbd2..c4da0f932e 100644 --- a/tfhe/src/high_level_api/utils.rs +++ b/tfhe/src/high_level_api/utils.rs @@ -3,7 +3,7 @@ use crate::high_level_api::integers::unsigned::FheUintId; use crate::integer::ciphertext::{DataKind, Expandable}; use crate::integer::BooleanBlock; use crate::shortint::Ciphertext; -use crate::{FheBool, FheInt, FheUint}; +use crate::{FheBool, FheInt, FheUint, Tag}; fn num_bits_of_blocks(blocks: &[Ciphertext]) -> u32 { blocks @@ -18,7 +18,11 @@ impl Expandable for FheUint { DataKind::Unsigned(_) => { let stored_num_bits = num_bits_of_blocks(&blocks) as usize; if stored_num_bits == Id::num_bits() { - Ok(Self::new(crate::integer::RadixCiphertext::from(blocks))) + // The expander will be responsible for setting the correct tag + Ok(Self::new( + crate::integer::RadixCiphertext::from(blocks), + Tag::default(), + )) } else { Err(crate::Error::new(format!( "Tried to expand a FheUint{} while a FheUint{} is stored in this slot", @@ -57,9 +61,11 @@ impl Expandable for FheInt { DataKind::Signed(_) => { let stored_num_bits = num_bits_of_blocks(&blocks) as usize; if stored_num_bits == Id::num_bits() { - Ok(Self::new(crate::integer::SignedRadixCiphertext::from( - blocks, - ))) + // The expander will be responsible for setting the correct tag + Ok(Self::new( + crate::integer::SignedRadixCiphertext::from(blocks), + Tag::default(), + )) } else { Err(crate::Error::new(format!( "Tried to expand a FheInt{} while a FheInt{} is stored in this slot", @@ -91,7 +97,14 @@ impl Expandable for FheBool { "Tried to expand a FheBool while a FheInt{stored_num_bits} is stored in this slot", ))) } - DataKind::Boolean => Ok(Self::new(BooleanBlock::new_unchecked(blocks[0].clone()))), + DataKind::Boolean => { + let mut boolean_block = BooleanBlock::new_unchecked(blocks[0].clone()); + // We know the value is a boolean one (via the data kind) + boolean_block.0.degree = crate::shortint::ciphertext::Degree::new(1); + + // The expander will be responsible for setting the correct tag + Ok(Self::new(boolean_block, Tag::default())) + } } } } diff --git a/tfhe/tests/backward_compatibility/high_level_api.rs b/tfhe/tests/backward_compatibility/high_level_api.rs index 685cee95a1..86385d32ee 100644 --- a/tfhe/tests/backward_compatibility/high_level_api.rs +++ b/tfhe/tests/backward_compatibility/high_level_api.rs @@ -392,7 +392,7 @@ pub fn test_hl_clientkey( let test_params = load_hl_params(&test.parameters); let key: ClientKey = load_and_unversionize(dir, test, format)?; - let (integer_key, _, _) = key.into_raw_parts(); + let (integer_key, _, _, _) = key.into_raw_parts(); let key_params = integer_key.parameters(); if test_params != key_params {