Skip to content

Commit

Permalink
test: add tag check for parameter keyswitch in HL API
Browse files Browse the repository at this point in the history
  • Loading branch information
IceTDrinker committed Nov 8, 2024
1 parent 6ef22e8 commit 9ee18dd
Showing 1 changed file with 47 additions and 18 deletions.
65 changes: 47 additions & 18 deletions tfhe/src/high_level_api/tests/tags_on_entities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
set_server_key, ClientKey, CompactCiphertextList, CompactCiphertextListExpander,
CompactPublicKey, CompressedCiphertextList, CompressedCiphertextListBuilder, CompressedFheBool,
CompressedFheInt32, CompressedFheUint32, CompressedServerKey, ConfigBuilder, Device, FheBool,
FheInt32, FheInt64, FheUint32, ServerKey,
FheInt32, FheInt64, FheUint32, KeySwitchingKey, ServerKey,
};
use rand::random;

Expand All @@ -22,6 +22,10 @@ fn test_tag_propagation_cpu() {
PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
)),
Some(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64),
Some((
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
)),
)
}

Expand All @@ -40,7 +44,7 @@ fn test_tag_propagation_zk_pok() {
let mut cks = ClientKey::generate(config);
let tag_value = random();
cks.tag_mut().set_u64(tag_value);
let cks = serialize_then_deserialize(cks);
let cks = serialize_then_deserialize(&cks);
assert_eq!(cks.tag().as_u64(), tag_value);

let sks = ServerKey::new(&cks);
Expand All @@ -61,7 +65,7 @@ fn test_tag_propagation_zk_pok() {
.build_with_proof_packed(&crs, &metadata, crate::zk::ZkComputeLoad::Proof)
.unwrap();

let list_packed: ProvenCompactCiphertextList = serialize_then_deserialize(list_packed);
let list_packed: ProvenCompactCiphertextList = serialize_then_deserialize(&list_packed);
assert_eq!(list_packed.tag(), cks.tag());

let expander = list_packed
Expand Down Expand Up @@ -139,14 +143,18 @@ fn test_tag_propagation_gpu() {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
None,
Some(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64),
Some((
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
)),
)
}

fn serialize_then_deserialize<T>(value: T) -> T
fn serialize_then_deserialize<T>(value: &T) -> T
where
T: serde::Serialize + for<'a> serde::de::Deserialize<'a>,
{
let serialized = bincode::serialize(&value).unwrap();
let serialized = bincode::serialize(value).unwrap();
bincode::deserialize(&serialized).unwrap()
}

Expand All @@ -158,6 +166,7 @@ fn test_tag_propagation(
ShortintKeySwitchingParameters,
)>,
comp_parameters: Option<CompressionParameters>,
ks_to_params: Option<(ClassicPBSParameters, ShortintKeySwitchingParameters)>,
) {
let mut builder = ConfigBuilder::with_custom_parameters(pbs_parameters);
if let Some(parameters) = dedicated_compact_public_key_parameters {
Expand All @@ -171,22 +180,22 @@ fn test_tag_propagation(
let mut cks = ClientKey::generate(config);
let tag_value = random();
cks.tag_mut().set_u64(tag_value);
let cks = serialize_then_deserialize(cks);
let cks = serialize_then_deserialize(&cks);
assert_eq!(cks.tag().as_u64(), tag_value);

let compressed_sks = CompressedServerKey::new(&cks);
let compressed_sks = serialize_then_deserialize(compressed_sks);
let compressed_sks = serialize_then_deserialize(&compressed_sks);
assert_eq!(compressed_sks.tag(), cks.tag());
let sks = ServerKey::new(&cks);

match device {
Device::Cpu => {
let sks = ServerKey::new(&cks);
let sks = serialize_then_deserialize(sks);
let sks = serialize_then_deserialize(&sks);
assert_eq!(sks.tag(), cks.tag());

// Now test when the sks comes from a compressed one
let sks = compressed_sks.decompress();
let sks = serialize_then_deserialize(sks);
let sks = serialize_then_deserialize(&sks);
assert_eq!(sks.tag(), cks.tag());

set_server_key(sks);
Expand All @@ -207,7 +216,7 @@ fn test_tag_propagation(
// Check FheUint have a tag
{
let ct_a = FheUint32::encrypt(8182u32, &cks);
let ct_a = serialize_then_deserialize(ct_a);
let ct_a = serialize_then_deserialize(&ct_a);
assert_eq!(ct_a.tag(), cks.tag());

let ct_b = FheUint32::encrypt(8182u32, &cks);
Expand All @@ -222,7 +231,7 @@ fn test_tag_propagation(
// Check FheInt have a tag
{
let ct_a = FheInt32::encrypt(-1i32, &cks);
let ct_a = serialize_then_deserialize(ct_a);
let ct_a = serialize_then_deserialize(&ct_a);
assert_eq!(ct_a.tag(), cks.tag());

let ct_b = FheInt32::encrypt(i32::MIN, &cks);
Expand All @@ -237,7 +246,7 @@ fn test_tag_propagation(
// Check FheBool have a tag
{
let ct_a = FheBool::encrypt(false, &cks);
let ct_a = serialize_then_deserialize(ct_a);
let ct_a = serialize_then_deserialize(&ct_a);
assert_eq!(ct_a.tag(), cks.tag());

let ct_b = FheBool::encrypt(true, &cks);
Expand All @@ -249,8 +258,7 @@ fn test_tag_propagation(
compression_builder.push(ct_c);
}

if device == Device::Cpu {
// Cuda do no yet support compressing
{
let compressed_list = compression_builder.build().unwrap();
assert_eq!(compressed_list.tag(), cks.tag());

Expand All @@ -265,6 +273,27 @@ fn test_tag_propagation(
assert_eq!(b.tag(), cks.tag());
let c: FheBool = compressed_list.get(2).unwrap().unwrap();
assert_eq!(c.tag(), cks.tag());

if let Some((dest_params, ks_params)) = ks_to_params {
let dest_config = ConfigBuilder::with_custom_parameters(dest_params);
let mut dest_cks = ClientKey::generate(dest_config);
dest_cks.tag_mut().set_u64(random());
let compressed_dest_sks = CompressedServerKey::new(&dest_cks);
let dest_sks = compressed_dest_sks.decompress();

let ksk = KeySwitchingKey::with_parameters(
(&cks, &sks),
(&dest_cks, &dest_sks),
ks_params,
);

let ks_a = ksk.keyswitch(&a);
assert_eq!(ks_a.tag(), dest_cks.tag());
let ks_b = ksk.keyswitch(&b);
assert_eq!(ks_b.tag(), dest_cks.tag());
let ks_c = ksk.keyswitch(&c);
assert_eq!(ks_c.tag(), dest_cks.tag());
}
}
}

Expand Down Expand Up @@ -298,7 +327,7 @@ fn test_tag_propagation(
// Test compact public key stuff
if device == Device::Cpu {
let cpk = CompactPublicKey::new(&cks);
let cpk = serialize_then_deserialize(cpk);
let cpk = serialize_then_deserialize(&cpk);
assert_eq!(cpk.tag(), cks.tag());

let mut builder = CompactCiphertextList::builder(&cpk);
Expand Down Expand Up @@ -344,14 +373,14 @@ fn test_tag_propagation(

{
let list = builder.build();
let list: CompactCiphertextList = serialize_then_deserialize(list);
let list: CompactCiphertextList = serialize_then_deserialize(&list);
assert_eq!(list.tag(), cks.tag());
expand_and_check_tags(list.expand().unwrap(), &cks);
}

{
let list_packed = builder.build_packed();
let list_packed: CompactCiphertextList = serialize_then_deserialize(list_packed);
let list_packed: CompactCiphertextList = serialize_then_deserialize(&list_packed);
assert_eq!(list_packed.tag(), cks.tag());
expand_and_check_tags(list_packed.expand().unwrap(), &cks);
}
Expand Down

0 comments on commit 9ee18dd

Please sign in to comment.