From 946da8b65f93c6906d914867b98358b98263e77e Mon Sep 17 00:00:00 2001 From: Andrew Kirillov <20803092+akirillo@users.noreply.github.com> Date: Tue, 5 Sep 2023 18:27:34 -0700 Subject: [PATCH] tests: added profiling feature flag & breakpoints to e2e tests --- src/darkpool.cairo | 70 +++++++++++++++++--------- src/darkpool/types.cairo | 2 + src/testing/tests/darkpool_tests.cairo | 17 +++++-- src/testing/tests/merkle_tests.cairo | 2 +- src/verifier.cairo | 14 ++++-- tests/src/darkpool/utils.rs | 34 +++++++++---- tests/src/utils.rs | 65 ++++++++++++++++++++++++ tests/src/verifier/utils.rs | 3 +- 8 files changed, 163 insertions(+), 44 deletions(-) diff --git a/src/darkpool.cairo b/src/darkpool.cairo index 75cf7045..79f28b23 100644 --- a/src/darkpool.cairo +++ b/src/darkpool.cairo @@ -330,14 +330,16 @@ mod Darkpool { let feature_flags = self.feature_flags.read(); - if breakpoint == Breakpoint::PreMerkleInitialize { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::PreMerkleInitialize { return; } // Initialize the Merkle tree _get_merkle_tree(@self).initialize(height, feature_flags); - if breakpoint == Breakpoint::MerkleInitialize { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::MerkleInitialize { return; } @@ -517,7 +519,8 @@ mod Darkpool { // Inject witness append_statement_commitments(@statement, ref witness_commitments); - if breakpoint == Breakpoint::AppendStatementCommitments { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::AppendStatementCommitments { return; } @@ -531,7 +534,8 @@ mod Darkpool { breakpoint, ); - if breakpoint == Breakpoint::QueueVerification { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::QueueVerification { return; } @@ -566,7 +570,8 @@ mod Darkpool { let verified = verifier .step_verification(Circuit::ValidWalletCreate(()).into(), verification_job_id); - if breakpoint == Breakpoint::StepVerification { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::StepVerification { return Option::Some(Result::Err('breakpoint reached')); } @@ -588,14 +593,16 @@ mod Darkpool { hash_input.span(), self.feature_flags.read().use_base_field_poseidon ); - if breakpoint == Breakpoint::SharesCommitment { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::SharesCommitment { return Option::Some(Result::Err('breakpoint reached')); } let merkle_tree = _get_merkle_tree(@self); let new_root = merkle_tree.insert(total_shares_commitment); - if breakpoint == Breakpoint::MerkleInsert { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::MerkleInsert { return Option::Some(Result::Err('breakpoint reached')); } @@ -644,7 +651,8 @@ mod Darkpool { // now signing a new wallet with a new root key. let statement_hash = hash_statement(@statement); - if breakpoint == Breakpoint::HashStatement { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::HashStatement { return; } @@ -658,7 +666,7 @@ mod Darkpool { 'invalid statement signature' ); - if breakpoint == Breakpoint::CheckECDSA { + if self.feature_flags.read().enable_profiling && breakpoint == Breakpoint::CheckECDSA { return; } @@ -668,7 +676,8 @@ mod Darkpool { // Inject witness append_statement_commitments(@statement, ref witness_commitments); - if breakpoint == Breakpoint::AppendStatementCommitments { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::AppendStatementCommitments { return; } @@ -682,7 +691,8 @@ mod Darkpool { breakpoint, ); - if breakpoint == Breakpoint::QueueVerification { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::QueueVerification { return; } @@ -725,7 +735,8 @@ mod Darkpool { let verified = verifier .step_verification(Circuit::ValidWalletUpdate(()).into(), verification_job_id); - if breakpoint == Breakpoint::StepVerification { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::StepVerification { return Option::Some(Result::Err('breakpoint reached')); } @@ -749,14 +760,16 @@ mod Darkpool { hash_input.span(), self.feature_flags.read().use_base_field_poseidon ); - if breakpoint == Breakpoint::SharesCommitment { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::SharesCommitment { return Option::Some(Result::Err('breakpoint reached')); } let merkle_tree = _get_merkle_tree(@self); let new_root = merkle_tree.insert(total_shares_commitment); - if breakpoint == Breakpoint::MerkleInsert { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::MerkleInsert { return Option::Some(Result::Err('breakpoint reached')); } @@ -837,7 +850,8 @@ mod Darkpool { let verifier = _get_verifier(@self); - if breakpoint == Breakpoint::PreInjectAndQueue { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::PreInjectAndQueue { return; } @@ -855,7 +869,8 @@ mod Darkpool { breakpoint, ); - if breakpoint == Breakpoint::Party0ValidCommitments { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::Party0ValidCommitments { return; } @@ -873,7 +888,8 @@ mod Darkpool { breakpoint, ); - if breakpoint == Breakpoint::Party0ValidReblind { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::Party0ValidReblind { return; } @@ -891,7 +907,8 @@ mod Darkpool { breakpoint, ); - if breakpoint == Breakpoint::Party1ValidCommitments { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::Party1ValidCommitments { return; } @@ -909,7 +926,8 @@ mod Darkpool { breakpoint, ); - if breakpoint == Breakpoint::Party1ValidReblind { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::Party1ValidReblind { return; } @@ -924,7 +942,8 @@ mod Darkpool { breakpoint, ); - if breakpoint == Breakpoint::ValidMatchMpc { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::ValidMatchMpc { return; } @@ -941,7 +960,7 @@ mod Darkpool { breakpoint, ); - if breakpoint == Breakpoint::ValidSettle { + if self.feature_flags.read().enable_profiling && breakpoint == Breakpoint::ValidSettle { return; } @@ -981,7 +1000,8 @@ mod Darkpool { ) -> Option> { let poll_result = _check_and_poll_process_match(@self, verification_job_id); - if breakpoint == Breakpoint::CheckAndPoll { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::CheckAndPoll { return Option::Some(Result::Err('breakpoint reached')); } @@ -1206,7 +1226,8 @@ mod Darkpool { party_1_hash_input.span(), use_base_field_poseidon ); - if breakpoint == Breakpoint::SharesCommitment { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::SharesCommitment { return Option::Some(Result::Err('breakpoint reached')); } @@ -1214,7 +1235,8 @@ mod Darkpool { merkle_tree.insert(party_0_total_shares_commitment); let new_root = merkle_tree.insert(party_1_total_shares_commitment); - if breakpoint == Breakpoint::MerkleInsert { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::MerkleInsert { return Option::Some(Result::Err('breakpoint reached')); } diff --git a/src/darkpool/types.cairo b/src/darkpool/types.cairo index 01ee04e0..4bfb83b7 100644 --- a/src/darkpool/types.cairo +++ b/src/darkpool/types.cairo @@ -100,6 +100,8 @@ struct FeatureFlags { use_base_field_poseidon: bool, /// Whether or not to verify proofs disable_verification: bool, + /// Whether or not to enable profiling + enable_profiling: bool, } // -------------------------- diff --git a/src/testing/tests/darkpool_tests.cairo b/src/testing/tests/darkpool_tests.cairo index 8e28d5a9..8afd7b64 100644 --- a/src/testing/tests/darkpool_tests.cairo +++ b/src/testing/tests/darkpool_tests.cairo @@ -97,7 +97,9 @@ fn test_upgrade_verifier() { let test_caller = contract_address_try_from_felt252(TEST_CALLER).unwrap(); set_contract_address(test_caller); let mut darkpool = setup_darkpool_with_flags( - FeatureFlags { use_base_field_poseidon: true, disable_verification: false } + FeatureFlags { + use_base_field_poseidon: true, disable_verification: false, enable_profiling: false + } ); darkpool.upgrade_verifier(DummyUpgradeTarget::TEST_CLASS_HASH.try_into().unwrap()); @@ -191,7 +193,9 @@ fn test_upgrade_verifier_access() { let test_caller = contract_address_try_from_felt252(TEST_CALLER).unwrap(); set_contract_address(test_caller); let mut darkpool = setup_darkpool_with_flags( - FeatureFlags { use_base_field_poseidon: true, disable_verification: false } + FeatureFlags { + use_base_field_poseidon: true, disable_verification: false, enable_profiling: false + } ); let dummy_caller = contract_address_try_from_felt252(DUMMY_CALLER).unwrap(); @@ -227,7 +231,10 @@ fn test_initialize_twice() { let mut calldata = ArrayTrait::new(); calldata.append(TEST_CALLER); Serde::::serialize( - @FeatureFlags { use_base_field_poseidon: false, disable_verification: false }, ref calldata + @FeatureFlags { + use_base_field_poseidon: false, disable_verification: false, enable_profiling: false + }, + ref calldata ); let (darkpool_address, _) = deploy_syscall( @@ -249,7 +256,9 @@ fn setup_darkpool() -> IDarkpoolDispatcher { // Default feature flags used disable the scalar field poseidon hash and the verifier, as these // are generally not what is being tested here and disabling them speeds up tests. setup_darkpool_with_flags( - FeatureFlags { use_base_field_poseidon: true, disable_verification: true } + FeatureFlags { + use_base_field_poseidon: true, disable_verification: true, enable_profiling: false + } ) } diff --git a/src/testing/tests/merkle_tests.cairo b/src/testing/tests/merkle_tests.cairo index 54b215bc..ed6bb796 100644 --- a/src/testing/tests/merkle_tests.cairo +++ b/src/testing/tests/merkle_tests.cairo @@ -55,7 +55,7 @@ fn setup_merkle() -> ContractState { merkle .initialize( TEST_MERKLE_HEIGHT, - FeatureFlags { use_base_field_poseidon: true, disable_verification: true } + FeatureFlags { use_base_field_poseidon: true, disable_verification: true, enable_profiling: false } ); merkle } diff --git a/src/verifier.cairo b/src/verifier.cairo index f2bab067..deea1f6a 100644 --- a/src/verifier.cairo +++ b/src/verifier.cairo @@ -257,14 +257,15 @@ mod MultiVerifier { let W_V = self.W_V.read(circuit_id).inner; let c = self.c.read(circuit_id).inner; - if breakpoint == Breakpoint::ReadCircuitParams { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::ReadCircuitParams { return; } // Prep `RemainingGenerators` structs for G and H generators let (G_rem, H_rem) = prep_rem_gens(n_plus); - if breakpoint == Breakpoint::PrepRemGens { + if self.feature_flags.read().enable_profiling && breakpoint == Breakpoint::PrepRemGens { return; } @@ -273,7 +274,8 @@ mod MultiVerifier { @proof, witness_commitments.span(), m, n_plus ); - if breakpoint == Breakpoint::SqueezeChallengeScalars { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::SqueezeChallengeScalars { return; } @@ -294,7 +296,8 @@ mod MultiVerifier { y_inv, z, u, x, w, r, @proof, n, n_plus, @W_L, @W_R, @c, ); - if breakpoint == Breakpoint::PrepRemScalarPolys { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::PrepRemScalarPolys { return; } @@ -304,7 +307,8 @@ mod MultiVerifier { ref proof, ref witness_commitments, pedersen_generator, pedersen_generator, ); - if breakpoint == Breakpoint::PrepRemCommitments { + if self.feature_flags.read().enable_profiling + && breakpoint == Breakpoint::PrepRemCommitments { return; } diff --git a/tests/src/darkpool/utils.rs b/tests/src/darkpool/utils.rs index 78104dd1..2bdaa0a4 100644 --- a/tests/src/darkpool/utils.rs +++ b/tests/src/darkpool/utils.rs @@ -44,11 +44,11 @@ use crate::{ call_contract, check_verification_job_status, felt_to_u128, get_circuit_params, get_contract_address_from_artifact, get_dummy_statement_scalars, get_sierra_class_hash_from_artifact, global_setup, invoke_contract, parameterize_circuit, - random_felt, scalar_to_felt, setup_sequencer, singleprover_prove, CalldataSerializable, - Circuit, DummyValidCommitments, DummyValidMatchMpc, DummyValidReblind, DummyValidSettle, - DummyValidWalletCreate, DummyValidWalletUpdate, FeatureFlags, MatchPayload, NewWalletArgs, - ProcessMatchArgs, TestConfig, UpdateWalletArgs, ARTIFACTS_PATH_ENV_VAR, DUMMY_VALUE, - SK_ROOT, + random_felt, scalar_to_felt, setup_sequencer, singleprover_prove, Breakpoint, + CalldataSerializable, Circuit, DummyValidCommitments, DummyValidMatchMpc, + DummyValidReblind, DummyValidSettle, DummyValidWalletCreate, DummyValidWalletUpdate, + FeatureFlags, MatchPayload, NewWalletArgs, ProcessMatchArgs, TestConfig, UpdateWalletArgs, + ARTIFACTS_PATH_ENV_VAR, DUMMY_VALUE, SK_ROOT, }, }; @@ -292,12 +292,13 @@ pub async fn initialize_darkpool( verifier_class_hash: FieldElement, merkle_height: FieldElement, ) -> Result<()> { - let calldata = vec![ + let mut calldata = vec![ merkle_class_hash, nullifier_set_class_hash, verifier_class_hash, merkle_height, ]; + calldata.extend(Breakpoint::None.to_calldata().into_iter()); initialize(account, darkpool_address, calldata) .await @@ -346,11 +347,14 @@ pub async fn poll_new_wallet( account: &ScriptAccount, verification_job_id: FieldElement, ) -> Result<()> { + let calldata = iter::once(verification_job_id) + .chain(Breakpoint::None.to_calldata()) + .collect(); invoke_contract( account, *DARKPOOL_ADDRESS.get().unwrap(), POLL_NEW_WALLET_FN_NAME, - vec![verification_job_id], + calldata, ) .await .map(|_| ()) @@ -406,11 +410,14 @@ pub async fn poll_update_wallet( account: &ScriptAccount, verification_job_id: FieldElement, ) -> Result<()> { + let calldata = iter::once(verification_job_id) + .chain(Breakpoint::None.to_calldata()) + .collect(); invoke_contract( account, *DARKPOOL_ADDRESS.get().unwrap(), POLL_UPDATE_WALLET_FN_NAME, - vec![verification_job_id], + calldata, ) .await .map(|_| ()) @@ -465,11 +472,14 @@ pub async fn poll_process_match( account: &ScriptAccount, verification_job_id: FieldElement, ) -> Result<()> { + let calldata = iter::once(verification_job_id) + .chain(Breakpoint::None.to_calldata()) + .collect(); invoke_contract( account, *DARKPOOL_ADDRESS.get().unwrap(), POLL_PROCESS_MATCH_FN_NAME, - vec![verification_job_id], + calldata, ) .await .map(|_| ()) @@ -560,6 +570,7 @@ pub fn get_dummy_new_wallet_args() -> Result { let (_, statement) = create_default_witness_statement(); let (_, proof) = singleprover_prove::((), statement.clone())?; let verification_job_id = random_felt(); + let breakpoint = Breakpoint::None; Ok(NewWalletArgs { wallet_blinder_share, @@ -567,6 +578,7 @@ pub fn get_dummy_new_wallet_args() -> Result { proof, witness_commitments: vec![], verification_job_id, + breakpoint, }) } @@ -590,6 +602,7 @@ pub fn get_dummy_update_wallet_args( let statement_signature = sign_scalar_message(&statement.to_scalars(), &SK_ROOT); let (_, proof) = singleprover_prove::((), statement.clone())?; let verification_job_id = random_felt(); + let breakpoint = Breakpoint::None; Ok(UpdateWalletArgs { wallet_blinder_share, @@ -598,6 +611,7 @@ pub fn get_dummy_update_wallet_args( proof, witness_commitments: vec![], verification_job_id, + breakpoint, }) } @@ -617,6 +631,7 @@ pub fn get_dummy_process_match_args( let (_, valid_settle_proof) = singleprover_prove::((), valid_settle_statement.clone())?; let verification_job_id = random_felt(); + let breakpoint = Breakpoint::None; Ok(ProcessMatchArgs { party_0_match_payload, @@ -627,5 +642,6 @@ pub fn get_dummy_process_match_args( valid_settle_witness_commitments: vec![], valid_settle_proof, verification_job_id, + breakpoint, }) } diff --git a/tests/src/utils.rs b/tests/src/utils.rs index eb018cf2..6dddf8b2 100644 --- a/tests/src/utils.rs +++ b/tests/src/utils.rs @@ -560,6 +560,7 @@ pub struct NewWalletArgs { pub proof: R1CSProof, pub witness_commitments: Vec, pub verification_job_id: FieldElement, + pub breakpoint: Breakpoint, } pub struct UpdateWalletArgs { @@ -569,6 +570,7 @@ pub struct UpdateWalletArgs { pub proof: R1CSProof, pub witness_commitments: Vec, pub verification_job_id: FieldElement, + pub breakpoint: Breakpoint, } pub struct ProcessMatchArgs { @@ -580,6 +582,7 @@ pub struct ProcessMatchArgs { pub valid_settle_witness_commitments: Vec, pub valid_settle_proof: R1CSProof, pub verification_job_id: FieldElement, + pub breakpoint: Breakpoint, } #[derive(Default)] @@ -588,6 +591,34 @@ pub struct FeatureFlags { pub use_base_field_poseidon: bool, /// Whether or not to verify proofs pub disable_verification: bool, + /// Whether or not to enable profiling + pub enable_profiling: bool, +} + +pub enum Breakpoint { + None, + ReadCircuitParams, + PrepRemGens, + SqueezeChallengeScalars, + PrepRemScalarPolys, + PrepRemCommitments, + PreMerkleInitialize, + MerkleInitialize, + AppendStatementCommitments, + QueueVerification, + StepVerification, + SharesCommitment, + MerkleInsert, + HashStatement, + CheckECDSA, + PreInjectAndQueue, + Party0ValidCommitments, + Party0ValidReblind, + Party1ValidCommitments, + Party1ValidReblind, + ValidMatchMpc, + ValidSettle, + CheckAndPoll, } pub trait CalldataSerializable { @@ -865,6 +896,7 @@ impl CalldataSerializable for NewWalletArgs { .chain(self.witness_commitments.to_calldata()) .chain(self.proof.to_calldata()) .chain(self.verification_job_id.to_calldata()) + .chain(self.breakpoint.to_calldata()) .collect() } } @@ -880,6 +912,7 @@ impl CalldataSerializable for UpdateWalletArgs { .chain(self.witness_commitments.to_calldata()) .chain(self.proof.to_calldata()) .chain(self.verification_job_id.to_calldata()) + .chain(self.breakpoint.to_calldata()) .collect() } } @@ -896,6 +929,7 @@ impl CalldataSerializable for ProcessMatchArgs { .chain(self.valid_settle_witness_commitments.to_calldata()) .chain(self.valid_settle_proof.to_calldata()) .chain(self.verification_job_id.to_calldata()) + .chain(self.breakpoint.to_calldata()) .collect() } } @@ -905,10 +939,41 @@ impl CalldataSerializable for FeatureFlags { vec![ FieldElement::from(self.use_base_field_poseidon as u8), FieldElement::from(self.disable_verification as u8), + FieldElement::from(self.enable_profiling as u8), ] } } +impl CalldataSerializable for Breakpoint { + fn to_calldata(&self) -> Vec { + vec![match self { + Breakpoint::None => FieldElement::from(0_u8), + Breakpoint::ReadCircuitParams => FieldElement::from(1_u8), + Breakpoint::PrepRemGens => FieldElement::from(2_u8), + Breakpoint::SqueezeChallengeScalars => FieldElement::from(3_u8), + Breakpoint::PrepRemScalarPolys => FieldElement::from(4_u8), + Breakpoint::PrepRemCommitments => FieldElement::from(5_u8), + Breakpoint::PreMerkleInitialize => FieldElement::from(6_u8), + Breakpoint::MerkleInitialize => FieldElement::from(7_u8), + Breakpoint::AppendStatementCommitments => FieldElement::from(8_u8), + Breakpoint::QueueVerification => FieldElement::from(9_u8), + Breakpoint::StepVerification => FieldElement::from(10_u8), + Breakpoint::SharesCommitment => FieldElement::from(11_u8), + Breakpoint::MerkleInsert => FieldElement::from(12_u8), + Breakpoint::HashStatement => FieldElement::from(13_u8), + Breakpoint::CheckECDSA => FieldElement::from(14_u8), + Breakpoint::PreInjectAndQueue => FieldElement::from(15_u8), + Breakpoint::Party0ValidCommitments => FieldElement::from(16_u8), + Breakpoint::Party0ValidReblind => FieldElement::from(17_u8), + Breakpoint::Party1ValidCommitments => FieldElement::from(18_u8), + Breakpoint::Party1ValidReblind => FieldElement::from(19_u8), + Breakpoint::ValidMatchMpc => FieldElement::from(20_u8), + Breakpoint::ValidSettle => FieldElement::from(21_u8), + Breakpoint::CheckAndPoll => FieldElement::from(22_u8), + }] + } +} + // ------------------ // | DUMMY CIRCUITS | // ------------------ diff --git a/tests/src/verifier/utils.rs b/tests/src/verifier/utils.rs index 25140c7c..eb711930 100644 --- a/tests/src/verifier/utils.rs +++ b/tests/src/verifier/utils.rs @@ -19,7 +19,7 @@ use tracing::debug; use crate::utils::{ get_contract_address_from_artifact, global_setup, invoke_contract, parameterize_circuit, - CalldataSerializable, CircuitParams, ARTIFACTS_PATH_ENV_VAR, TRANSCRIPT_SEED, + Breakpoint, CalldataSerializable, CircuitParams, ARTIFACTS_PATH_ENV_VAR, TRANSCRIPT_SEED, }; pub const FUZZ_ROUNDS: usize = 1; @@ -114,6 +114,7 @@ pub async fn queue_verification_job( .chain(proof.to_calldata()) .chain(witness_commitments.to_calldata()) .chain(verification_job_id.to_calldata()) + .chain(Breakpoint::None.to_calldata()) .collect(); invoke_contract(