diff --git a/tfhe/src/high_level_api/tests/tags_on_entities.rs b/tfhe/src/high_level_api/tests/tags_on_entities.rs index a7e45ad7d2..c963749694 100644 --- a/tfhe/src/high_level_api/tests/tags_on_entities.rs +++ b/tfhe/src/high_level_api/tests/tags_on_entities.rs @@ -8,7 +8,7 @@ use crate::{ set_server_key, ClientKey, CompactCiphertextList, CompactCiphertextListExpander, CompactPublicKey, CompressedCiphertextList, CompressedCiphertextListBuilder, CompressedFheBool, CompressedFheInt32, CompressedFheUint32, CompressedServerKey, ConfigBuilder, Device, FheBool, - FheInt32, FheInt64, FheUint32, ServerKey, + FheInt32, FheInt64, FheUint32, KeySwitchingKey, ServerKey, }; use rand::random; @@ -22,6 +22,10 @@ fn test_tag_propagation_cpu() { PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, )), Some(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64), + Some(( + PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + )), ) } @@ -40,7 +44,7 @@ fn test_tag_propagation_zk_pok() { let mut cks = ClientKey::generate(config); let tag_value = random(); cks.tag_mut().set_u64(tag_value); - let cks = serialize_then_deserialize(cks); + let cks = serialize_then_deserialize(&cks); assert_eq!(cks.tag().as_u64(), tag_value); let sks = ServerKey::new(&cks); @@ -61,7 +65,7 @@ fn test_tag_propagation_zk_pok() { .build_with_proof_packed(&crs, &metadata, crate::zk::ZkComputeLoad::Proof) .unwrap(); - let list_packed: ProvenCompactCiphertextList = serialize_then_deserialize(list_packed); + let list_packed: ProvenCompactCiphertextList = serialize_then_deserialize(&list_packed); assert_eq!(list_packed.tag(), cks.tag()); let expander = list_packed @@ -139,14 +143,18 @@ fn test_tag_propagation_gpu() { PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, None, Some(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64), + Some(( + PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, + )), ) } -fn serialize_then_deserialize(value: T) -> T +fn serialize_then_deserialize(value: &T) -> T where T: serde::Serialize + for<'a> serde::de::Deserialize<'a>, { - let serialized = bincode::serialize(&value).unwrap(); + let serialized = bincode::serialize(value).unwrap(); bincode::deserialize(&serialized).unwrap() } @@ -158,6 +166,7 @@ fn test_tag_propagation( ShortintKeySwitchingParameters, )>, comp_parameters: Option, + ks_to_params: Option<(ClassicPBSParameters, ShortintKeySwitchingParameters)>, ) { let mut builder = ConfigBuilder::with_custom_parameters(pbs_parameters); if let Some(parameters) = dedicated_compact_public_key_parameters { @@ -171,22 +180,22 @@ fn test_tag_propagation( let mut cks = ClientKey::generate(config); let tag_value = random(); cks.tag_mut().set_u64(tag_value); - let cks = serialize_then_deserialize(cks); + let cks = serialize_then_deserialize(&cks); assert_eq!(cks.tag().as_u64(), tag_value); let compressed_sks = CompressedServerKey::new(&cks); - let compressed_sks = serialize_then_deserialize(compressed_sks); + let compressed_sks = serialize_then_deserialize(&compressed_sks); assert_eq!(compressed_sks.tag(), cks.tag()); + let sks = ServerKey::new(&cks); match device { Device::Cpu => { - let sks = ServerKey::new(&cks); - let sks = serialize_then_deserialize(sks); + let sks = serialize_then_deserialize(&sks); assert_eq!(sks.tag(), cks.tag()); // Now test when the sks comes from a compressed one let sks = compressed_sks.decompress(); - let sks = serialize_then_deserialize(sks); + let sks = serialize_then_deserialize(&sks); assert_eq!(sks.tag(), cks.tag()); set_server_key(sks); @@ -207,7 +216,7 @@ fn test_tag_propagation( // Check FheUint have a tag { let ct_a = FheUint32::encrypt(8182u32, &cks); - let ct_a = serialize_then_deserialize(ct_a); + let ct_a = serialize_then_deserialize(&ct_a); assert_eq!(ct_a.tag(), cks.tag()); let ct_b = FheUint32::encrypt(8182u32, &cks); @@ -222,7 +231,7 @@ fn test_tag_propagation( // Check FheInt have a tag { let ct_a = FheInt32::encrypt(-1i32, &cks); - let ct_a = serialize_then_deserialize(ct_a); + let ct_a = serialize_then_deserialize(&ct_a); assert_eq!(ct_a.tag(), cks.tag()); let ct_b = FheInt32::encrypt(i32::MIN, &cks); @@ -237,7 +246,7 @@ fn test_tag_propagation( // Check FheBool have a tag { let ct_a = FheBool::encrypt(false, &cks); - let ct_a = serialize_then_deserialize(ct_a); + let ct_a = serialize_then_deserialize(&ct_a); assert_eq!(ct_a.tag(), cks.tag()); let ct_b = FheBool::encrypt(true, &cks); @@ -249,8 +258,7 @@ fn test_tag_propagation( compression_builder.push(ct_c); } - if device == Device::Cpu { - // Cuda do no yet support compressing + { let compressed_list = compression_builder.build().unwrap(); assert_eq!(compressed_list.tag(), cks.tag()); @@ -265,6 +273,27 @@ fn test_tag_propagation( assert_eq!(b.tag(), cks.tag()); let c: FheBool = compressed_list.get(2).unwrap().unwrap(); assert_eq!(c.tag(), cks.tag()); + + if let Some((dest_params, ks_params)) = ks_to_params { + let dest_config = ConfigBuilder::with_custom_parameters(dest_params); + let mut dest_cks = ClientKey::generate(dest_config); + dest_cks.tag_mut().set_u64(random()); + let compressed_dest_sks = CompressedServerKey::new(&dest_cks); + let dest_sks = compressed_dest_sks.decompress(); + + let ksk = KeySwitchingKey::with_parameters( + (&cks, &sks), + (&dest_cks, &dest_sks), + ks_params, + ); + + let ks_a = ksk.keyswitch(&a); + assert_eq!(ks_a.tag(), dest_cks.tag()); + let ks_b = ksk.keyswitch(&b); + assert_eq!(ks_b.tag(), dest_cks.tag()); + let ks_c = ksk.keyswitch(&c); + assert_eq!(ks_c.tag(), dest_cks.tag()); + } } } @@ -298,7 +327,7 @@ fn test_tag_propagation( // Test compact public key stuff if device == Device::Cpu { let cpk = CompactPublicKey::new(&cks); - let cpk = serialize_then_deserialize(cpk); + let cpk = serialize_then_deserialize(&cpk); assert_eq!(cpk.tag(), cks.tag()); let mut builder = CompactCiphertextList::builder(&cpk); @@ -344,14 +373,14 @@ fn test_tag_propagation( { let list = builder.build(); - let list: CompactCiphertextList = serialize_then_deserialize(list); + let list: CompactCiphertextList = serialize_then_deserialize(&list); assert_eq!(list.tag(), cks.tag()); expand_and_check_tags(list.expand().unwrap(), &cks); } { let list_packed = builder.build_packed(); - let list_packed: CompactCiphertextList = serialize_then_deserialize(list_packed); + let list_packed: CompactCiphertextList = serialize_then_deserialize(&list_packed); assert_eq!(list_packed.tag(), cks.tag()); expand_and_check_tags(list_packed.expand().unwrap(), &cks); }