Skip to content

Commit

Permalink
chore(gpu): improve compression tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves committed Sep 16, 2024
1 parent 1f01c2d commit 361ff87
Showing 1 changed file with 107 additions and 35 deletions.
142 changes: 107 additions & 35 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,11 @@ mod tests {
use crate::integer::ClientKey;
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
use rand::Rng;

const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;

#[test]
fn test_gpu_ciphertext_compression() {
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
Expand All @@ -427,42 +430,111 @@ mod tests {
let (cuda_compression_key, cuda_decompression_key) =
radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams);

let mut rng = rand::thread_rng();

let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;

for _ in 0..NB_TESTS {
let ct1 = radix_cks.encrypt(3_u32);
let ct2 = radix_cks.encrypt_signed(-2);
let ct3 = radix_cks.encrypt_bool(true);

// Copy to GPU
let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams);
let d_ct2 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct2, &streams);
let d_ct3 = CudaBooleanBlock::from_boolean_block(&ct3, &streams);

let cuda_compressed = CudaCompressedCiphertextListBuilder::new()
.push(d_ct1, &streams)
.push(d_ct2, &streams)
.push(d_ct3, &streams)
.build(&cuda_compression_key, &streams);

let d_decompressed1 = CudaUnsignedRadixCiphertext {
ciphertext: cuda_compressed.get(0, &cuda_decompression_key, &streams),
};
let decompressed1 = d_decompressed1.to_radix_ciphertext(&streams);
let decrypted: u32 = radix_cks.decrypt(&decompressed1);
assert_eq!(decrypted, 3_u32);

let d_decompressed2 = CudaSignedRadixCiphertext {
ciphertext: cuda_compressed.get(1, &cuda_decompression_key, &streams),
};
let decompressed2 = d_decompressed2.to_signed_radix_ciphertext(&streams);
let decrypted: i32 = radix_cks.decrypt_signed(&decompressed2);
assert_eq!(decrypted, -2);

let d_decompressed3 = CudaBooleanBlock::from_cuda_radix_ciphertext(
cuda_compressed.get(2, &cuda_decompression_key, &streams),
);
let decompressed3 = d_decompressed3.to_boolean_block(&streams);
let decrypted = radix_cks.decrypt_bool(&decompressed3);
assert!(decrypted);
// Unsigned
let modulus = message_modulus.pow(num_blocks as u32) as u128;
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<u128>() % modulus)
.collect::<Vec<_>>();

let d_cts = messages
.iter()
.map(|message| {
let ct = radix_cks.encrypt(*message);
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams)
})
.collect_vec();

let mut builder = CudaCompressedCiphertextListBuilder::new();

for d_ct in d_cts {
builder.push(d_ct, &streams);
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, message) in messages.iter().enumerate() {
let d_decompressed = CudaUnsignedRadixCiphertext {
ciphertext: cuda_compressed.get(i, &cuda_decompression_key, &streams),
};
let decompressed = d_decompressed.to_radix_ciphertext(&streams);
let decrypted: u128 = radix_cks.decrypt(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Signed
let modulus = (message_modulus.pow(num_blocks as u32) / 2) as i128;
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i128>() % modulus)
.collect::<Vec<_>>();

let d_cts = messages
.iter()
.map(|message| {
let ct = radix_cks.encrypt_signed(*message);
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct, &streams)
})
.collect_vec();

let mut builder = CudaCompressedCiphertextListBuilder::new();

for d_ct in d_cts {
builder.push(d_ct, &streams);
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, message) in messages.iter().enumerate() {
let d_decompressed = CudaSignedRadixCiphertext {
ciphertext: cuda_compressed.get(i, &cuda_decompression_key, &streams),
};
let decompressed = d_decompressed.to_signed_radix_ciphertext(&streams);
let decrypted: i128 = radix_cks.decrypt_signed(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Boolean
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i64>() % 2 != 0)
.collect::<Vec<_>>();

let d_cts = messages
.iter()
.map(|message| {
let ct = radix_cks.encrypt_bool(*message);
CudaBooleanBlock::from_boolean_block(&ct, &streams)
})
.collect_vec();

let mut builder = CudaCompressedCiphertextListBuilder::new();

for d_ct in d_cts {
builder.push(d_ct, &streams);
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, message) in messages.iter().enumerate() {
let d_decompressed = CudaBooleanBlock::from_cuda_radix_ciphertext(
cuda_compressed.get(i, &cuda_decompression_key, &streams),
);
let decompressed = d_decompressed.to_boolean_block(&streams);
let decrypted = radix_cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
}
}
}

0 comments on commit 361ff87

Please sign in to comment.