Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(gpu): improve compression tests #1543

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 193 additions & 35 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,11 @@ mod tests {
use crate::integer::ClientKey;
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
use rand::Rng;

const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;

#[test]
fn test_gpu_ciphertext_compression() {
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
Expand All @@ -427,42 +430,197 @@ mod tests {
let (cuda_compression_key, cuda_decompression_key) =
radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams);

let mut rng = rand::thread_rng();

let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;

for _ in 0..NB_TESTS {
let ct1 = radix_cks.encrypt(3_u32);
let ct2 = radix_cks.encrypt_signed(-2);
let ct3 = radix_cks.encrypt_bool(true);

// Copy to GPU
let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams);
let d_ct2 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct2, &streams);
let d_ct3 = CudaBooleanBlock::from_boolean_block(&ct3, &streams);

let cuda_compressed = CudaCompressedCiphertextListBuilder::new()
.push(d_ct1, &streams)
.push(d_ct2, &streams)
.push(d_ct3, &streams)
.build(&cuda_compression_key, &streams);

let d_decompressed1 = CudaUnsignedRadixCiphertext {
ciphertext: cuda_compressed.get(0, &cuda_decompression_key, &streams),
};
let decompressed1 = d_decompressed1.to_radix_ciphertext(&streams);
let decrypted: u32 = radix_cks.decrypt(&decompressed1);
assert_eq!(decrypted, 3_u32);

let d_decompressed2 = CudaSignedRadixCiphertext {
ciphertext: cuda_compressed.get(1, &cuda_decompression_key, &streams),
};
let decompressed2 = d_decompressed2.to_signed_radix_ciphertext(&streams);
let decrypted: i32 = radix_cks.decrypt_signed(&decompressed2);
assert_eq!(decrypted, -2);

let d_decompressed3 = CudaBooleanBlock::from_cuda_radix_ciphertext(
cuda_compressed.get(2, &cuda_decompression_key, &streams),
);
let decompressed3 = d_decompressed3.to_boolean_block(&streams);
let decrypted = radix_cks.decrypt_bool(&decompressed3);
assert!(decrypted);
// Unsigned
let modulus = message_modulus.pow(num_blocks as u32);
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<u128>() % modulus)
.collect::<Vec<_>>();

let d_cts = messages
.iter()
.map(|message| {
let ct = radix_cks.encrypt(*message);
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams)
})
.collect_vec();

let mut builder = CudaCompressedCiphertextListBuilder::new();

for d_ct in d_cts {
builder.push(d_ct, &streams);
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, message) in messages.iter().enumerate() {
let d_decompressed = CudaUnsignedRadixCiphertext {
ciphertext: cuda_compressed.get(i, &cuda_decompression_key, &streams),
};
let decompressed = d_decompressed.to_radix_ciphertext(&streams);
let decrypted: u128 = radix_cks.decrypt(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Signed
let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128;
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i128>() % modulus)
.collect::<Vec<_>>();

let d_cts = messages
.iter()
.map(|message| {
let ct = radix_cks.encrypt_signed(*message);
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct, &streams)
})
.collect_vec();

let mut builder = CudaCompressedCiphertextListBuilder::new();

for d_ct in d_cts {
builder.push(d_ct, &streams);
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, message) in messages.iter().enumerate() {
let d_decompressed = CudaSignedRadixCiphertext {
ciphertext: cuda_compressed.get(i, &cuda_decompression_key, &streams),
};
let decompressed = d_decompressed.to_signed_radix_ciphertext(&streams);
let decrypted: i128 = radix_cks.decrypt_signed(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Boolean
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i64>() % 2 != 0)
.collect::<Vec<_>>();

let d_cts = messages
.iter()
.map(|message| {
let ct = radix_cks.encrypt_bool(*message);
CudaBooleanBlock::from_boolean_block(&ct, &streams)
})
.collect_vec();

let mut builder = CudaCompressedCiphertextListBuilder::new();

for d_ct in d_cts {
builder.push(d_ct, &streams);
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, message) in messages.iter().enumerate() {
let d_decompressed = CudaBooleanBlock::from_cuda_radix_ciphertext(
cuda_compressed.get(i, &cuda_decompression_key, &streams),
);
let decompressed = d_decompressed.to_boolean_block(&streams);
let decrypted = radix_cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Hybrid
enum MessageType {
UNSIGNED(u128),
SIGNED(i128),
BOOLEAN(bool),
}
for _ in 0..NB_OPERATOR_TESTS {
let mut builder = CudaCompressedCiphertextListBuilder::new();

let nb_messages = 1 + (rng.gen::<u64>() % 6);
let mut messages = vec![];
for _ in 0..nb_messages {
let case_selector = rng.gen_range(0..3);
match case_selector {
0 => {
// Unsigned
let modulus = message_modulus.pow(num_blocks as u32);
let message = rng.gen::<u128>() % modulus;
let ct = radix_cks.encrypt(message);
let d_ct =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
builder.push(d_ct, &streams);
messages.push(MessageType::UNSIGNED(message));
}
1 => {
// Signed
let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128;
let message = rng.gen::<i128>() % modulus;
let ct = radix_cks.encrypt_signed(message);
let d_ct = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(
&ct, &streams,
);
builder.push(d_ct, &streams);
messages.push(MessageType::SIGNED(message));
}
_ => {
// Boolean
let message = rng.gen::<i64>() % 2 != 0;
let ct = radix_cks.encrypt_bool(message);
let d_ct = CudaBooleanBlock::from_boolean_block(&ct, &streams);
builder.push(d_ct, &streams);
messages.push(MessageType::BOOLEAN(message));
}
}
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, val) in messages.iter().enumerate() {
match val {
MessageType::UNSIGNED(message) => {
let d_decompressed = CudaUnsignedRadixCiphertext {
ciphertext: cuda_compressed.get(
i,
&cuda_decompression_key,
&streams,
),
};
let decompressed = d_decompressed.to_radix_ciphertext(&streams);
let decrypted: u128 = radix_cks.decrypt(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::SIGNED(message) => {
let d_decompressed = CudaSignedRadixCiphertext {
ciphertext: cuda_compressed.get(
i,
&cuda_decompression_key,
&streams,
),
};
let decompressed = d_decompressed.to_signed_radix_ciphertext(&streams);
let decrypted: i128 = radix_cks.decrypt_signed(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::BOOLEAN(message) => {
let d_decompressed = CudaBooleanBlock::from_cuda_radix_ciphertext(
cuda_compressed.get(i, &cuda_decompression_key, &streams),
);
let decompressed = d_decompressed.to_boolean_block(&streams);
let decrypted = radix_cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
}
}
}
}
}
Loading