Skip to content

Commit

Permalink
refactor(all): decompress takes shared reference
Browse files Browse the repository at this point in the history
remove from/into decompression
  • Loading branch information
mayeul-zama committed Mar 28, 2024
1 parent c20eccf commit cc1a8e4
Show file tree
Hide file tree
Showing 31 changed files with 80 additions and 174 deletions.
2 changes: 1 addition & 1 deletion tfhe/benches/shortint/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ fn server_key_from_compressed_key(c: &mut Criterion) {
b.iter_batched(
clone_compressed_key,
|sks_cloned| {
let _ = ServerKey::from(sks_cloned);
let _ = sks_cloned.decompress();
},
criterion::BatchSize::PerIteration,
)
Expand Down
8 changes: 3 additions & 5 deletions tfhe/src/boolean/ciphertext/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@ pub struct CompressedCiphertext {
pub(crate) ciphertext: SeededLweCiphertext<u32>,
}

impl From<CompressedCiphertext> for Ciphertext {
fn from(value: CompressedCiphertext) -> Self {
Self::Encrypted(value.ciphertext.decompress_into_lwe_ciphertext())
impl CompressedCiphertext {
pub fn decompress(&self) -> Ciphertext {
Ciphertext::Encrypted(self.ciphertext.decompress_into_lwe_ciphertext())
}
}

impl CompressedCiphertext {
/// Deconstruct a [`CompressedCiphertext`] into its constituents.
pub fn into_raw_parts(self) -> SeededLweCiphertext<u32> {
self.ciphertext
Expand Down
6 changes: 0 additions & 6 deletions tfhe/src/boolean/engine/bootstrapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,3 @@ impl ServerKey {
output
}
}

impl From<CompressedServerKey> for ServerKey {
fn from(compressed_server_key: CompressedServerKey) -> Self {
compressed_server_key.decompress()
}
}
13 changes: 7 additions & 6 deletions tfhe/src/boolean/public_key/standard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,16 @@ impl PublicKey {
}
}

impl From<CompressedPublicKey> for PublicKey {
fn from(compressed_public_key: CompressedPublicKey) -> Self {
let parameters = compressed_public_key.parameters;
impl CompressedPublicKey {
pub fn decompress(&self) -> PublicKey {
let parameters = self.parameters;

let decompressed_public_key = compressed_public_key
let decompressed_public_key = self
.compressed_lwe_public_key
.as_view()
.par_decompress_into_lwe_public_key();

Self {
PublicKey {
lwe_public_key: decompressed_public_key,
parameters,
}
Expand Down Expand Up @@ -194,7 +195,7 @@ mod tests {
let keys = KEY_CACHE.get_from_param(parameters);
let (cks, sks) = (keys.client_key(), keys.server_key());
let cpks = CompressedPublicKey::new(cks);
let pks = PublicKey::from(cpks);
let pks = cpks.decompress();

for _ in 0..NB_TESTS {
let b1 = random_boolean();
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/c_api/boolean/ciphertext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub unsafe extern "C" fn boolean_decompress_ciphertext(

let compressed_ciphertext = get_mut_checked(compressed_ciphertext).unwrap();

let ciphertext = compressed_ciphertext.0.clone().into();
let ciphertext = compressed_ciphertext.0.decompress();

let heap_allocated_ciphertext = Box::new(BooleanCiphertext(ciphertext));

Expand Down
5 changes: 2 additions & 3 deletions tfhe/src/c_api/boolean/server_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,9 +677,8 @@ pub unsafe extern "C" fn boolean_decompress_server_key(

let compressed_server_key = get_ref_checked(compressed_server_key).unwrap();

let heap_allocated_public_key = Box::new(BooleanServerKey(
boolean::server_key::ServerKey::from(compressed_server_key.0.clone()),
));
let heap_allocated_public_key =
Box::new(BooleanServerKey(compressed_server_key.0.decompress()));

*result = Box::into_raw(heap_allocated_public_key);
})
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/c_api/high_level_api/integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ macro_rules! create_integer_wrapper_type {
$crate::c_api::utils::catch_panic(|| {
let compressed = $crate::c_api::utils::get_ref_checked(sself).unwrap();

let decompressed_inner = compressed.0.clone().into();
let decompressed_inner = compressed.0.decompress();
*result = Box::into_raw(Box::new($name(decompressed_inner)));
})
}
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/c_api/shortint/ciphertext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub unsafe extern "C" fn shortint_decompress_ciphertext(

let compressed_ciphertext = get_ref_checked(compressed_ciphertext).unwrap();

let ciphertext = compressed_ciphertext.0.clone().into();
let ciphertext = compressed_ciphertext.0.decompress();

let heap_allocated_ciphertext = Box::new(ShortintCiphertext(ciphertext));

Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/c_api/shortint/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ pub unsafe extern "C" fn shortint_decompress_public_key(
let compressed_public_key = get_ref_checked(compressed_public_key).unwrap();

let heap_allocated_public_key =
Box::new(ShortintPublicKey(compressed_public_key.0.clone().into()));
Box::new(ShortintPublicKey(compressed_public_key.0.decompress()));

*result = Box::into_raw(heap_allocated_public_key);
})
Expand Down
5 changes: 2 additions & 3 deletions tfhe/src/c_api/shortint/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,8 @@ pub unsafe extern "C" fn shortint_decompress_server_key(

let compressed_server_key = get_ref_checked(compressed_server_key).unwrap();

let heap_allocated_public_key = Box::new(ShortintServerKey(
shortint::server_key::ServerKey::from(compressed_server_key.0.clone()),
));
let heap_allocated_public_key =
Box::new(ShortintServerKey(compressed_server_key.0.decompress()));

*result = Box::into_raw(heap_allocated_public_key);
})
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/entities/seeded_lwe_ciphertext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::misc::check_encrypted_content_respects_mod;

/// A [`seeded GLWE ciphertext`](`SeededLweCiphertext`).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SeededLweCiphertext<Scalar: UnsignedInteger> {
data: Scalar,
lwe_size: LweSize,
Expand Down
15 changes: 1 addition & 14 deletions tfhe/src/high_level_api/booleans/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use super::inner::InnerBoolean;
use crate::conformance::ParameterSetConformant;
use crate::high_level_api::booleans::compressed::CompressedFheBool;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_stream;
Expand All @@ -11,7 +10,7 @@ use crate::integer::parameters::RadixCiphertextConformanceParams;
use crate::integer::BooleanBlock;
use crate::named::Named;
use crate::shortint::ciphertext::NotTrivialCiphertextError;
use crate::{CompactFheBool, Device};
use crate::Device;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign};
Expand Down Expand Up @@ -59,18 +58,6 @@ impl ParameterSetConformant for FheBool {
}
}

impl From<CompressedFheBool> for FheBool {
fn from(value: CompressedFheBool) -> Self {
value.decompress()
}
}

impl From<CompactFheBool> for FheBool {
fn from(value: CompactFheBool) -> Self {
value.expand()
}
}

impl FheBool {
pub(in crate::high_level_api) fn new<T: Into<InnerBoolean>>(ciphertext: T) -> Self {
Self {
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/high_level_api/booleans/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ fn compressed_bool_test_case(setup_fn: impl FnOnce() -> (ClientKey, Device)) {
let cttrue = CompressedFheBool::encrypt(true, &cks);
let cffalse = CompressedFheBool::encrypt(false, &cks);

let a = FheBool::from(cttrue);
let b = FheBool::from(cffalse);
let a = cttrue.decompress();
let b = cffalse.decompress();

assert_degree_is_ok(&a);
assert_degree_is_ok(&b);
Expand Down
20 changes: 1 addition & 19 deletions tfhe/src/high_level_api/integers/signed/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::integer::SignedRadixCiphertext;
use crate::named::Named;
use crate::prelude::CastFrom;
use crate::shortint::ciphertext::NotTrivialCiphertextError;
use crate::{CompactFheInt, CompressedFheInt, FheBool};
use crate::FheBool;

pub trait FheIntId: IntegerId {}

Expand All @@ -32,24 +32,6 @@ pub struct FheInt<Id: FheIntId> {
pub(in crate::high_level_api::integers) id: Id,
}

impl<Id> From<CompressedFheInt<Id>> for FheInt<Id>
where
Id: FheIntId,
{
fn from(value: CompressedFheInt<Id>) -> Self {
value.decompress()
}
}

impl<Id> From<CompactFheInt<Id>> for FheInt<Id>
where
Id: FheIntId,
{
fn from(value: CompactFheInt<Id>) -> Self {
value.expand()
}
}

impl<Id: FheIntId> ParameterSetConformant for FheInt<Id> {
type ParameterSet = RadixCiphertextConformanceParams;
fn is_conformant(&self, params: &RadixCiphertextConformanceParams) -> bool {
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/high_level_api/integers/signed/compressed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ where
/// Decompress to a [FheInt]
///
/// See [CompressedFheInt] example.
pub fn decompress(self) -> FheInt<Id> {
let inner = self.ciphertext.into();
pub fn decompress(&self) -> FheInt<Id> {
let inner = self.ciphertext.decompress();
FheInt::new(inner)
}
}
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/high_level_api/integers/signed/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn test_signed_integer_compressed() {

let clear = -1234i16;
let compressed = CompressedFheInt16::try_encrypt(clear, &client_key).unwrap();
let decompressed = FheInt16::from(compressed);
let decompressed = compressed.decompress();
let clear_decompressed: i16 = decompressed.decrypt(&client_key);
assert_eq!(clear_decompressed, clear);
}
Expand All @@ -28,7 +28,7 @@ fn test_integer_compressed_small() {

let clear = rng.gen::<i16>();
let compressed = CompressedFheInt16::try_encrypt(clear, &client_key).unwrap();
let decompressed = FheInt16::from(compressed);
let decompressed = compressed.decompress();
let clear_decompressed: i16 = decompressed.decrypt(&client_key);
assert_eq!(clear_decompressed, clear);
}
Expand Down
20 changes: 1 addition & 19 deletions tfhe/src/high_level_api/integers/unsigned/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::integer::block_decomposition::RecomposableFrom;
use crate::integer::parameters::RadixCiphertextConformanceParams;
use crate::named::Named;
use crate::shortint::ciphertext::NotTrivialCiphertextError;
use crate::{CompactFheUint, CompressedFheUint, FheBool};
use crate::FheBool;

#[derive(Debug)]
pub enum GenericIntegerBlockError {
Expand Down Expand Up @@ -70,24 +70,6 @@ pub struct FheUint<Id: FheUintId> {
pub(in crate::high_level_api::integers) id: Id,
}

impl<Id> From<CompressedFheUint<Id>> for FheUint<Id>
where
Id: FheUintId,
{
fn from(value: CompressedFheUint<Id>) -> Self {
value.decompress()
}
}

impl<Id> From<CompactFheUint<Id>> for FheUint<Id>
where
Id: FheUintId,
{
fn from(value: CompactFheUint<Id>) -> Self {
value.expand()
}
}

impl<Id: FheUintId> ParameterSetConformant for FheUint<Id> {
type ParameterSet = RadixCiphertextConformanceParams;
fn is_conformant(&self, params: &RadixCiphertextConformanceParams) -> bool {
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/high_level_api/integers/unsigned/compressed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ where
/// Decompress to a [FheUint]
///
/// See [CompressedFheUint] example.
pub fn decompress(self) -> FheUint<Id> {
let inner: crate::integer::RadixCiphertext = self.ciphertext.into();
pub fn decompress(&self) -> FheUint<Id> {
let inner: crate::integer::RadixCiphertext = self.ciphertext.decompress();
let mut ciphertext = FheUint::new(inner);
ciphertext.move_to_device_of_server_key_if_set();
ciphertext
Expand Down
6 changes: 3 additions & 3 deletions tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn test_integer_compressed_can_be_serialized() {
let bytes = bincode::serialize(&compressed).unwrap();
let deserialized: CompressedFheUint256 = bincode::deserialize_from(bytes.as_slice()).unwrap();

let decompressed = FheUint256::from(deserialized);
let decompressed = FheUint256::from(deserialized.decompress());
let clear_decompressed: U256 = decompressed.decrypt(&client_key);
assert_eq!(clear_decompressed, clear);
}
Expand All @@ -52,7 +52,7 @@ fn test_integer_compressed() {

let clear = 12_837u16;
let compressed = CompressedFheUint16::try_encrypt(clear, &client_key).unwrap();
let decompressed = FheUint16::from(compressed);
let decompressed = FheUint16::from(compressed.decompress());
let clear_decompressed: u16 = decompressed.decrypt(&client_key);
assert_eq!(clear_decompressed, clear);
}
Expand All @@ -64,7 +64,7 @@ fn test_integer_compressed_small() {

let clear = 12_837u16;
let compressed = CompressedFheUint16::try_encrypt(clear, &client_key).unwrap();
let decompressed = FheUint16::from(compressed);
let decompressed = FheUint16::from(compressed.decompress());
let clear_decompressed: u16 = decompressed.decrypt(&client_key);
assert_eq!(clear_decompressed, clear);
}
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/high_level_api/keys/public.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ impl CompressedPublicKey {
Self { key }
}

pub fn decompress(self) -> PublicKey {
pub fn decompress(&self) -> PublicKey {
PublicKey {
key: crate::integer::PublicKey::from(self.key),
key: self.key.decompress(),
}
}

Expand Down
6 changes: 0 additions & 6 deletions tfhe/src/high_level_api/keys/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,6 @@ impl CompressedServerKey {
}
}

impl From<CompressedServerKey> for ServerKey {
fn from(value: CompressedServerKey) -> Self {
value.decompress()
}
}

#[cfg(feature = "gpu")]
#[derive(Clone)]
pub struct CudaServerKey {
Expand Down
Loading

0 comments on commit cc1a8e4

Please sign in to comment.