diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index 78a4e198a3..7a10cf92aa 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -133,7 +133,7 @@ impl GroupElements { } /// Allows to compute proof with bad inputs for tests -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] enum ProofSanityCheckMode { Panic, #[cfg(test)] diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index ec766f5ddb..4382142cb0 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -1262,6 +1262,8 @@ mod tests { use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; + type Curve = curve_api::Bls12_446; + /// Compact key params used with pkev1 pub(super) const PKEV1_TEST_PARAMS: PkeTestParameters = PkeTestParameters { d: 1024, @@ -1310,8 +1312,6 @@ mod tests { let mut fake_metadata = [255u8; METADATA_LEN]; fake_metadata.fill_with(|| rng.gen::()); - type Curve = curve_api::Bls12_446; - // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -1397,9 +1397,9 @@ mod tests { } } - fn prove_and_verify( + fn prove_and_verify( testcase: &PkeTestcase, - crs: &PublicParams, + crs: &PublicParams, load: ComputeLoad, rng: &mut StdRng, ) -> VerificationResult { @@ -1434,10 +1434,10 @@ mod tests { } } - fn assert_prove_and_verify( + fn assert_prove_and_verify( testcase: &PkeTestcase, testcase_name: &str, - crs: &PublicParams, + crs: &PublicParams, rng: &mut StdRng, expected_result: VerificationResult, ) { @@ -1466,8 +1466,6 @@ mod tests { let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS); - type Curve = curve_api::Bls12_446; - // A CRS where the number of slots = the number of messages to encrypt let crs = crs_gen::(d, k, B, q, t, msbs_zero_padding_bit_count, rng); @@ -1630,7 +1628,6 @@ mod tests { }; let ct = testcase.encrypt(PKEV1_TEST_PARAMS); - type Curve = curve_api::Bls12_446; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); diff --git a/tfhe-zk-pok/src/proofs/pke_v2.rs b/tfhe-zk-pok/src/proofs/pke_v2.rs index 9db0ee29fa..72150adec8 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2.rs @@ -2412,6 +2412,8 @@ mod tests { use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; + type Curve = curve_api::Bls12_446; + /// Compact key params used with pkev2 pub(super) const PKEV2_TEST_PARAMS: PkeTestParameters = PkeTestParameters { d: 2048, @@ -2459,8 +2461,6 @@ mod tests { let mut fake_metadata = [255u8; METADATA_LEN]; fake_metadata.fill_with(|| rng.gen::()); - type Curve = curve_api::Bls12_446; - // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2546,14 +2546,14 @@ mod tests { } } - fn prove_and_verify( + fn prove_and_verify( testcase: &PkeTestcase, - crs: &PublicParams, + ct: &PkeTestCiphertext, + crs: &PublicParams, load: ComputeLoad, + sanity_check_mode: ProofSanityCheckMode, rng: &mut StdRng, ) -> VerificationResult { - let ct = testcase.encrypt_unchecked(PKEV2_TEST_PARAMS); - let (public_commit, private_commit) = commit( testcase.a.clone(), testcase.b.clone(), @@ -2573,7 +2573,7 @@ mod tests { &testcase.metadata, load, rng, - ProofSanityCheckMode::Ignore, + sanity_check_mode, ); if verify(&proof, (crs, &public_commit), &testcase.metadata).is_ok() { @@ -2583,16 +2583,18 @@ mod tests { } } - fn assert_prove_and_verify( + fn assert_prove_and_verify( testcase: &PkeTestcase, + ct: &PkeTestCiphertext, testcase_name: &str, - crs: &PublicParams, - rng: &mut StdRng, + crs: &PublicParams, + sanity_check_mode: ProofSanityCheckMode, expected_result: VerificationResult, + rng: &mut StdRng, ) { for load in [ComputeLoad::Proof, ComputeLoad::Verify] { assert_eq!( - prove_and_verify(testcase, crs, load, rng), + prove_and_verify(testcase, ct, crs, load, sanity_check_mode, rng), expected_result, "Testcase {testcase_name} failed" ) @@ -2785,8 +2787,6 @@ mod tests { let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); - type Curve = curve_api::Bls12_446; - let crs = crs_gen::(d, k, B, q, t, msbs_zero_padding_bit_count, rng); let crs_max_k = crs_gen::(d, d, B, q, t, msbs_zero_padding_bit_count, rng); @@ -2848,19 +2848,24 @@ mod tests { expected_result, } in testcases { + let ct = testcase.encrypt_unchecked(PKEV2_TEST_PARAMS); assert_prove_and_verify( &testcase, + &ct, &format!("{name}_crs"), &crs, - rng, + ProofSanityCheckMode::Ignore, expected_result, + rng, ); assert_prove_and_verify( &testcase, + &ct, &format!("{name}_crs_max_k"), &crs_max_k, - rng, + ProofSanityCheckMode::Ignore, expected_result, + rng, ); } } @@ -2926,8 +2931,6 @@ mod tests { let ct = testcase.encrypt(PKEV2_TEST_PARAMS); - type Curve = curve_api::Bls12_446; - // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2938,37 +2941,23 @@ mod tests { let public_param_that_was_not_compressed = serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); - for public_param in [ - original_public_param, - public_param_that_was_compressed, - public_param_that_was_not_compressed, + for (public_param, test_name) in [ + (original_public_param, "original_params"), + ( + public_param_that_was_compressed, + "serialized_compressed_params", + ), + (public_param_that_was_not_compressed, "serialize_params"), ] { - let (public_commit, private_commit) = commit( - testcase.a.clone(), - testcase.b.clone(), - ct.c1.clone(), - ct.c2.clone(), - testcase.r.clone(), - testcase.e1.clone(), - testcase.m.clone(), - testcase.e2.clone(), + assert_prove_and_verify( + &testcase, + &ct, + test_name, &public_param, + ProofSanityCheckMode::Panic, + VerificationResult::Reject, rng, ); - - for load in [ComputeLoad::Proof, ComputeLoad::Verify] { - let proof = prove( - (&public_param, &public_commit), - &private_commit, - &testcase.metadata, - load, - rng, - ); - - assert!( - verify(&proof, (&public_param, &public_commit), &testcase.metadata).is_err() - ); - } } } @@ -2989,8 +2978,6 @@ mod tests { let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); let ct = testcase.encrypt(PKEV2_TEST_PARAMS); - type Curve = curve_api::Bls12_446; - let crs_k = k + 1 + (rng.gen::() % (d - k)); let public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); @@ -3042,8 +3029,6 @@ mod tests { let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); let ct = testcase.encrypt(PKEV2_TEST_PARAMS); - type Curve = curve_api::Bls12_446; - let crs_k = k + 1 + (rng.gen::() % (d - k)); let public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);