Skip to content

Commit

Permalink
fix(compression): update compression parameters, fix compression on G…
Browse files Browse the repository at this point in the history
…PU and improve test

- the new compression parameters went through a noise check to verify constraints
- CPU and GPU compression tests are improved and the same
- implement Debug, Eq, PartialEq to CompressedCiphertextList
- fix gpu compression when a radix ciphertext is split through more than one compact GLWE
  • Loading branch information
pdroalves authored and agnesLeroy committed Oct 10, 2024
1 parent c2aae98 commit e376049
Show file tree
Hide file tree
Showing 13 changed files with 372 additions and 221 deletions.
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,13 @@ test_integer_gpu: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer,gpu -p $(TFHE_SPEC) -- integer::gpu::server_key::

.PHONY: test_integer_compression
test_integer_compression: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer -p $(TFHE_SPEC) -- integer::ciphertext::compressed_ciphertext_list::tests::
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer -p $(TFHE_SPEC) -- integer::ciphertext::compress

.PHONY: test_integer_compression_gpu
test_integer_compression_gpu: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,6 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
auto out_len = (number_bits_to_pack + nbits - 1) / nbits;

// Last GLWE
auto last_body_count = num_lwes % compression_params.polynomial_size;
in_len =
compression_params.glwe_dimension * compression_params.polynomial_size +
last_body_count;
number_bits_to_pack = in_len * log_modulus;
auto last_out_len = (number_bits_to_pack + nbits - 1) / nbits;

Expand All @@ -75,10 +71,6 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,

dim3 grid(num_blocks);
dim3 threads(num_threads);
cuda_memset_async(array_out, 0,
num_glwes * (compression_params.glwe_dimension + 1) *
compression_params.polynomial_size * sizeof(Torus),
stream, gpu_index);
pack<Torus><<<grid, threads, 0, stream>>>(array_out, array_in, log_modulus,
num_coeffs, in_len, out_len);
check_cuda_error(cudaGetLastError());
Expand Down Expand Up @@ -294,7 +286,7 @@ host_integer_decompress(cudaStream_t *streams, uint32_t *gpu_indexes,
compression_params.glwe_dimension,
compression_params.polynomial_size);
d_indexes_array_chunk += num_lwes;
extracted_lwe += lwe_accumulator_size;
extracted_lwe += num_lwes * lwe_accumulator_size;
current_idx = last_idx;
}

Expand Down
11 changes: 7 additions & 4 deletions tfhe/docs/fundamentals/compress.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,19 @@ The following example shows how to compress and decompress a list containing 4 m

```rust
use tfhe::prelude::*;
use tfhe::shortint::parameters::{COMP_PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_2_CARRY_2};
use tfhe::shortint::parameters::{
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
use tfhe::{
set_server_key, CompressedCiphertextList, CompressedCiphertextListBuilder, FheBool,
FheInt64, FheUint16, FheUint2, FheUint32,
};

fn main() {
let config = tfhe::ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2)
.enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2)
.build();
let config =
tfhe::ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64)
.enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64)
.build();

let ck = tfhe::ClientKey::generate(config);
let sk = tfhe::ServerKey::new(&ck);
Expand Down
11 changes: 7 additions & 4 deletions tfhe/docs/guides/run_on_gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,19 @@ The following example shows how to compress and decompress a list containing 4 m

```rust
use tfhe::prelude::*;
use tfhe::shortint::parameters::{COMP_PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_2_CARRY_2};
use tfhe::shortint::parameters::{
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
use tfhe::{
set_server_key, CompressedCiphertextList, CompressedCiphertextListBuilder, FheBool,
FheInt64, FheUint16, FheUint2, FheUint32,
};

fn main() {
let config = tfhe::ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2)
.enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2)
.build();
let config =
tfhe::ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64)
.enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64)
.build();

let ck = tfhe::ClientKey::generate(config);
let compressed_server_key = tfhe::CompressedServerKey::new(&ck);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use crate::core_crypto::prelude::*;
/// );
/// }
/// ```
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(CompressedModulusSwitchedGlweCiphertextVersions)]
pub struct CompressedModulusSwitchedGlweCiphertext<Scalar: UnsignedInteger> {
pub(crate) packed_integers: PackedIntegers<Scalar>,
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/entities/packed_integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::conformance::ParameterSetConformant;
use crate::core_crypto::backward_compatibility::entities::packed_integers::PackedIntegersVersions;
use crate::core_crypto::prelude::*;

#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(PackedIntegersVersions)]
pub struct PackedIntegers<Scalar: UnsignedInteger> {
pub(crate) packed_coeffs: Vec<Scalar>,
Expand Down
7 changes: 4 additions & 3 deletions tfhe/src/high_level_api/tests/tags_on_entities.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::prelude::*;
use crate::shortint::parameters::compact_public_key_only::p_fail_2_minus_64::ks_pbs::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::key_switching::p_fail_2_minus_64::ks_pbs::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::*;
use crate::shortint::ClassicPBSParameters;
use crate::{
Expand All @@ -20,7 +21,7 @@ fn test_tag_propagation_cpu() {
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
)),
Some(COMP_PARAM_MESSAGE_2_CARRY_2),
Some(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64),
)
}

Expand Down Expand Up @@ -139,9 +140,9 @@ fn test_tag_propagation_zk_pok() {
fn test_tag_propagation_gpu() {
test_tag_propagation(
Device::CudaGpu,
PARAM_MESSAGE_2_CARRY_2,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
None,
Some(COMP_PARAM_MESSAGE_2_CARRY_2),
Some(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64),
)
}

Expand Down
209 changes: 178 additions & 31 deletions tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl CompressedCiphertextListBuilder {
}
}

#[derive(Clone, Serialize, Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedCiphertextListVersions)]
pub struct CompressedCiphertextList {
pub(crate) packed_list: ShortintCompressedCiphertextList,
Expand Down Expand Up @@ -153,46 +153,193 @@ impl CompressedCiphertextList {
#[cfg(test)]
mod tests {
use super::*;
use crate::integer::ClientKey;
use crate::integer::{gen_keys, IntegerKeyKind};
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use itertools::Itertools;
use rand::Rng;

const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;
#[test]
fn test_heterogeneous_ciphertext_compression_ci_run_filter() {
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
fn test_ciphertext_compression() {
const NUM_BLOCKS: usize = 32;

let (cks, sks) = gen_keys(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
IntegerKeyKind::Radix,
);

let private_compression_key =
cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let (compression_key, decompression_key) =
cks.new_compression_decompression_keys(&private_compression_key);

let ct1 = cks.encrypt_radix(3_u32, 16);

let ct2 = cks.encrypt_signed_radix(-2, 16);

let ct3 = cks.encrypt_bool(true);

let compressed = CompressedCiphertextListBuilder::new()
.push(ct1)
.push(ct2)
.push(ct3)
.build(&compression_key);

let decompressed1 = compressed.get(0, &decompression_key).unwrap().unwrap();

let decrypted: u32 = cks.decrypt_radix(&decompressed1);

assert_eq!(decrypted, 3_u32);

let decompressed2 = compressed.get(1, &decompression_key).unwrap().unwrap();

let decrypted2: i32 = cks.decrypt_signed_radix(&decompressed2);

assert_eq!(decrypted2, -2);

let decompressed3 = compressed.get(2, &decompression_key).unwrap().unwrap();

assert!(cks.decrypt_bool(&decompressed3));
const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64
.lwe_per_glwe
.0
/ NUM_BLOCKS;

let mut rng = rand::thread_rng();

let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;

for _ in 0..NB_TESTS {
// Unsigned
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<u128>() % modulus)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_radix(*message, NUM_BLOCKS))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
let and_ct = sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
}

let compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: u128 = cks.decrypt_radix(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Signed
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i128>() % modulus)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_signed_radix(*message, NUM_BLOCKS))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
let and_ct = sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
}

let compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Boolean
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i64>() % 2 != 0)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_bool(*message))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
let and_ct = sks.boolean_bitand(&ct, &ct);
builder.push(and_ct);
}

let compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted = cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Hybrid
enum MessageType {
Unsigned(u128),
Signed(i128),
Boolean(bool),
}
for _ in 0..NB_OPERATOR_TESTS {
let mut builder = CompressedCiphertextListBuilder::new();

let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let mut messages = vec![];
for _ in 0..nb_messages {
let case_selector = rng.gen_range(0..3);
match case_selector {
0 => {
// Unsigned
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
let message = rng.gen::<u128>() % modulus;
let ct = cks.encrypt_radix(message, NUM_BLOCKS);
let and_ct = sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
messages.push(MessageType::Unsigned(message));
}
1 => {
// Signed
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
let message = rng.gen::<i128>() % modulus;
let ct = cks.encrypt_signed_radix(message, NUM_BLOCKS);
let and_ct = sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
messages.push(MessageType::Signed(message));
}
_ => {
// Boolean
let message = rng.gen::<i64>() % 2 != 0;
let ct = cks.encrypt_bool(message);
let and_ct = sks.boolean_bitand(&ct, &ct);
builder.push(and_ct);
messages.push(MessageType::Boolean(message));
}
}
}

let compressed = builder.build(&compression_key);

for (i, val) in messages.iter().enumerate() {
match val {
MessageType::Unsigned(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: u128 = cks.decrypt_radix(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Signed(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Boolean(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted = cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
}
}
}
}
}
Loading

0 comments on commit e376049

Please sign in to comment.