From 896cf3eb57c53fd099fa796f9cb4167d863c3a30 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 2 Dec 2024 17:32:11 -0800 Subject: [PATCH] Fixes for timeouts in CI --- .../src/protocol/context/dzkp_validator.rs | 55 ++++++++++--------- ipa-core/src/protocol/dp/mod.rs | 10 +++- ipa-core/src/protocol/hybrid/oprf.rs | 6 +- .../ipa_prf/aggregation/breakdown_reveal.rs | 5 +- .../src/protocol/ipa_prf/aggregation/mod.rs | 2 +- ipa-core/src/test_fixture/world.rs | 14 +++-- 6 files changed, 56 insertions(+), 36 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 4f0b47096..31687808b 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -973,7 +973,7 @@ mod tests { seq_join::{seq_join, SeqJoin}, sharding::NotSharded, test_executor::run_random, - test_fixture::{join3v, Reconstruct, Runner, TestWorld}, + test_fixture::{join3v, Reconstruct, Runner, TestWorld, TestWorldConfig}, }; async fn test_select_semi_honest() @@ -1162,30 +1162,35 @@ mod tests { let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); - let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() - .malicious( - zip(bit.clone(), zip(a.clone(), b.clone())), - |ctx, inputs| async move { - let v = ctx - .set_total_records(count) - .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); - let m_ctx = v.context(); - - v.validated_seq_join(stream::iter(inputs).enumerate().map( - |(i, (bit_share, (a_share, b_share)))| { - let m_ctx = m_ctx.clone(); - async move { - select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) - .await - } - }, - )) - .try_collect() - .await - }, - ) - .await - .map(Result::unwrap); + // Timeout is 10 seconds plus count * (3 ms). + let config = TestWorldConfig::default() + .with_timeout_secs(10 + 3 * u64::try_from(count).unwrap() / 1000); + + let [ab0, ab1, ab2]: [Vec>; 3] = + TestWorld::::with_config(&config) + .malicious( + zip(bit.clone(), zip(a.clone(), b.clone())), + |ctx, inputs| async move { + let v = ctx + .set_total_records(count) + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); + let m_ctx = v.context(); + + v.validated_seq_join(stream::iter(inputs).enumerate().map( + |(i, (bit_share, (a_share, b_share)))| { + let m_ctx = m_ctx.clone(); + async move { + select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) + .await + } + }, + )) + .try_collect() + .await + }, + ) + .await + .map(Result::unwrap); let ab: Vec = [ab0, ab1, ab2].reconstruct(); diff --git a/ipa-core/src/protocol/dp/mod.rs b/ipa-core/src/protocol/dp/mod.rs index 6ec703c0f..fbd8263f4 100644 --- a/ipa-core/src/protocol/dp/mod.rs +++ b/ipa-core/src/protocol/dp/mod.rs @@ -619,6 +619,7 @@ mod test { replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, BitDecomposed, SharedValue, TransposeFrom, }, + sharding::NotSharded, telemetry::metrics::BYTES_SENT, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, }; @@ -863,7 +864,8 @@ mod test { if std::env::var("EXEC_SLOW_TESTS").is_err() { return; } - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let result: [Vec>; 3] = world .dzkp_semi_honest((), |ctx, ()| async move { Vec::transposed_from( @@ -898,7 +900,8 @@ mod test { type OutputValue = BA16; const NUM_BREAKDOWNS: u32 = 32; let num_bernoulli: u32 = 2000; - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let result: [Vec>; 3] = world .dzkp_semi_honest((), |ctx, ()| async move { Vec::transposed_from( @@ -933,7 +936,8 @@ mod test { type OutputValue = BA16; const NUM_BREAKDOWNS: u32 = 256; let num_bernoulli: u32 = 1000; - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let result: [Vec>; 3] = world .dzkp_semi_honest((), |ctx, ()| async move { Vec::transposed_from( diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index da2bf903f..652c6ac1d 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -200,7 +200,10 @@ where #[cfg(all(test, unit_test, feature = "in-memory-infra"))] mod test { - use std::collections::{HashMap, HashSet}; + use std::{ + collections::{HashMap, HashSet}, + time::Duration, + }; use ipa_step::StepNarrow; @@ -218,6 +221,7 @@ mod test { const SHARDS: usize = 2; let world: TestWorld> = TestWorld::with_shards(TestWorldConfig { initial_gate: Some(Gate::default().narrow(&ProtocolStep::Hybrid)), + timeout: Duration::from_secs(60), ..Default::default() }); diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 198de6be9..8b799d0ac 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -392,9 +392,12 @@ pub mod tests { #[test] #[cfg(not(feature = "shuttle"))] // too slow fn malicious_happy_path() { + use crate::{sharding::NotSharded, test_fixture::TestWorldConfig}; + type HV = BA16; run(|| async { - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let mut rng = world.rng(); let mut expectation = Vec::new(); for _ in 0..32 { diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index 6a9adb345..e8d7631b7 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -536,7 +536,7 @@ pub mod tests { proptest! { #[test] - fn aggregate_proptest( + fn aggregate_values_proptest( input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN), seed in any::(), ) { diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 6042b20eb..d16a1a804 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -387,11 +387,15 @@ impl TestWorld { } async fn with_timeout(&self, fut: F) -> F::Output { - let Ok(output) = tokio::time::timeout(self.timeout, fut).await else { - tracing::error!("timed out after {:?}", self.timeout); - panic!("timed out after {:?}", self.timeout); - }; - output + if cfg!(feature = "shuttle") { + fut.await + } else { + let Ok(output) = tokio::time::timeout(self.timeout, fut).await else { + tracing::error!("timed out after {:?}", self.timeout); + panic!("timed out after {:?}", self.timeout); + }; + output + } } }