Skip to content

Commit

Permalink
chore: add test for fast pks with gemm with appropriate parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Dec 11, 2024
1 parent 7a36836 commit a000de7
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 6 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,8 @@ test_integer_compression: install_rs_build_toolchain
.PHONY: test_integer_compression_gpu
test_integer_compression_gpu: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer,gpu -p $(TFHE_SPEC) -- integer::gpu::ciphertext::compressed_ciphertext_list::tests:: --nocapture
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer,gpu -p $(TFHE_SPEC) -- integer::gpu::ciphertext::compressed_ciphertext_list::tests::test_gpu_ciphertext_compression_fast_path --nocapture
# RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer,gpu -p $(TFHE_SPEC) -- integer::gpu::ciphertext::compress --nocapture

.PHONY: test_integer_gpu_ci # Run the tests for integer ci on gpu backend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,6 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(

// Optimization of packing keyswitch when packing many LWEs

printf("USING FAST PKS\n");

cudaSetDevice(gpu_index);
check_cuda_error(cudaGetLastError());

Expand Down
7 changes: 5 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
if (can_use_pks_fast_path(input_lwe_dimension, num_lwes,
output_polynomial_size, level_count,
output_glwe_dimension)) {

::fprintf(stderr, "USING FAST PKS");
abort();
host_fast_packing_keyswitch_lwe_list_to_glwe<uint64_t, ulonglong4>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
Expand All @@ -85,7 +86,9 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
base_log, level_count, num_lwes);
} else
printf("USING CLASSICAL PKS\n");
::fprintf(stderr, "USING CLASSICAL PKS");
abort();

host_packing_keyswitch_lwe_list_to_glwe<uint64_t>(
static_cast<cudaStream_t>(stream), gpu_index,
static_cast<uint64_t *>(glwe_array_out),
Expand Down
152 changes: 152 additions & 0 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,17 @@ mod tests {
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use rand::Rng;

use crate::core_crypto::prelude::*;
use crate::shortint::ciphertext::MaxNoiseLevel;
use crate::shortint::parameters::{CarryModulus, ClassicPBSParameters, MessageModulus};

use crate::core_crypto::prelude::{CiphertextModulusLog, LweCiphertextCount};
use crate::shortint::parameters::list_compression::CompressionParameters;
use crate::shortint::parameters::{
DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension,
PolynomialSize,
};

const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;

Expand Down Expand Up @@ -717,4 +728,145 @@ mod tests {
}
}
}

#[test]
fn test_gpu_ciphertext_compression_fast_path() {
// these parameters are insecure
const PARAM_CUSTOM_FAST_PATH: ClassicPBSParameters = ClassicPBSParameters {
lwe_dimension: LweDimension(2048),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0,
)),
pbs_base_log: DecompositionBaseLog(22),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(3),
ks_level: DecompositionLevelCount(1),
message_modulus: MessageModulus(4),
carry_modulus: CarryModulus(4),
max_noise_level: MaxNoiseLevel::new(5),
log2_p_fail: -64.138,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
};

const COMP_PARAM_CUSTOM_FAST_PATH: CompressionParameters = CompressionParameters {
br_level: DecompositionLevelCount(1),
br_base_log: DecompositionBaseLog(23),
packing_ks_level: DecompositionLevelCount(1),
packing_ks_base_log: DecompositionBaseLog(21),
packing_ks_polynomial_size: PolynomialSize(2048),
packing_ks_glwe_dimension: GlweDimension(1),
lwe_per_glwe: LweCiphertextCount(2048),
storage_log_modulus: CiphertextModulusLog(19),
packing_ks_key_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(
StandardDev(0.0),
),
};

const NUM_BLOCKS: usize = 32;

let streams = CudaStreams::new_multi_gpu();

let (radix_cks, sks) = gen_keys_radix_gpu(PARAM_CUSTOM_FAST_PATH, NUM_BLOCKS, &streams);
let cks = radix_cks.as_ref();

let private_compression_key = cks.new_compression_private_key(COMP_PARAM_CUSTOM_FAST_PATH);

let (cuda_compression_key, cuda_decompression_key) =
radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams);

const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_CUSTOM_FAST_PATH.lwe_per_glwe.0 / NUM_BLOCKS;

let mut rng = rand::thread_rng();

let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;

// Hybrid
enum MessageType {
Unsigned(u128),
Signed(i128),
Boolean(bool),
}
for _ in 0..NB_OPERATOR_TESTS {
let mut builder = CudaCompressedCiphertextListBuilder::new();

let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let mut messages = vec![];
for _ in 0..nb_messages {
let case_selector = rng.gen_range(0..3);
match case_selector {
0 => {
// Unsigned
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
let message = rng.gen::<u128>() % modulus;
let ct = radix_cks.encrypt(message);
let d_ct =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
builder.push(d_ct, &streams);
messages.push(MessageType::Unsigned(message));
}
1 => {
// Signed
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
let message = rng.gen::<i128>() % modulus;
let ct = radix_cks.encrypt_signed(message);
let d_ct =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct, &streams);
builder.push(d_ct, &streams);
messages.push(MessageType::Signed(message));
}
_ => {
// Boolean
let message = rng.gen::<i64>() % 2 != 0;
let ct = radix_cks.encrypt_bool(message);
let d_boolean_ct = CudaBooleanBlock::from_boolean_block(&ct, &streams);
let d_ct = d_boolean_ct.0;
let d_and_boolean_ct =
CudaBooleanBlock::from_cuda_radix_ciphertext(d_ct.ciphertext);
builder.push(d_and_boolean_ct, &streams);
messages.push(MessageType::Boolean(message));
}
}
}

let cuda_compressed = builder.build(&cuda_compression_key, &streams);

for (i, val) in messages.iter().enumerate() {
match val {
MessageType::Unsigned(message) => {
let d_decompressed: CudaUnsignedRadixCiphertext = cuda_compressed
.get(i, &cuda_decompression_key, &streams)
.unwrap()
.unwrap();
let decompressed = d_decompressed.to_radix_ciphertext(&streams);
let decrypted: u128 = radix_cks.decrypt(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Signed(message) => {
let d_decompressed: CudaSignedRadixCiphertext = cuda_compressed
.get(i, &cuda_decompression_key, &streams)
.unwrap()
.unwrap();
let decompressed = d_decompressed.to_signed_radix_ciphertext(&streams);
let decrypted: i128 = radix_cks.decrypt_signed(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Boolean(message) => {
let d_decompressed: CudaBooleanBlock = cuda_compressed
.get(i, &cuda_decompression_key, &streams)
.unwrap()
.unwrap();
let decompressed = d_decompressed.to_boolean_block(&streams);
let decrypted = radix_cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
}
}
}
}

0 comments on commit a000de7

Please sign in to comment.