diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs
index b1c57b3cc..e654f85f7 100644
--- a/ipa-core/src/helpers/gateway/mod.rs
+++ b/ipa-core/src/helpers/gateway/mod.rs
@@ -150,12 +150,15 @@ impl Gateway {
         &self,
         channel_id: &HelperChannelId,
         total_records: TotalRecords,
+        active_work: NonZeroUsize,
     ) -> send::SendingEnd<Role, M> {
         let transport = &self.transports.mpc;
         let channel = self.inner.mpc_senders.get::<M, _>(
            channel_id,
            transport,
-            self.config,
+            // we override the active work provided in config if caller
+            // wants to use a different value.
+            self.config.set_active_work(active_work),
            self.query_id,
            total_records,
        );
@@ -280,11 +283,23 @@ impl GatewayConfig {
        // we set active to be at least 2, so unwrap is fine.
        self.active = NonZeroUsize::new(active).unwrap();
    }
+
+    /// Creates a new configuration by overriding the value of active work.
+    #[must_use]
+    pub fn set_active_work(&self, active_work: NonZeroUsize) -> Self {
+        Self {
+            active: active_work,
+            ..*self
+        }
+    }
 }
 
 #[cfg(all(test, unit_test))]
 mod tests {
-    use std::iter::{repeat, zip};
+    use std::{
+        iter::{repeat, zip},
+        num::NonZeroUsize,
+    };
 
     use futures::{
         future::{join, try_join, try_join_all},
@@ -293,12 +308,14 @@ mod tests {
 
     use crate::{
         ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions},
-        helpers::{Direction, GatewayConfig, MpcMessage, Role, SendingEnd},
+        helpers::{
+            ChannelId, Direction, GatewayConfig, MpcMessage, Role, SendingEnd, TotalRecords,
+        },
         protocol::{
             context::{Context, ShardedContext},
-            RecordId,
+            Gate, RecordId,
         },
-        secret_sharing::replicated::semi_honest::AdditiveShare,
+        secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
         sharding::ShardConfiguration,
         test_executor::run,
         test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards},
@@ -516,6 +533,42 @@ mod tests {
         });
     }
 
+    #[test]
+    fn custom_active_work() {
+        run(|| async move {
+            let world = TestWorld::new_with(TestWorldConfig {
+                gateway_config: GatewayConfig {
+                    active: 5.try_into().unwrap(),
+                    ..Default::default()
+                },
+                ..Default::default()
+            });
+            let new_active_work = NonZeroUsize::new(3).unwrap();
+            assert!(new_active_work < world.gateway(Role::H1).config().active_work());
+            let sender = world.gateway(Role::H1).get_mpc_sender::<BA3>(
+                &ChannelId::new(Role::H2, Gate::default()),
+                TotalRecords::specified(15).unwrap(),
+                new_active_work,
+            );
+            try_join_all(
+                (0..new_active_work.get())
+                    .map(|record_id| sender.send(record_id.into(), BA3::ZERO)),
+            )
+            .await
+            .unwrap();
+            let recv = world.gateway(Role::H2).get_mpc_receiver::<BA3>(&ChannelId {
+                peer: Role::H1,
+                gate: Gate::default(),
+            });
+            // this will hang if the original active work is used
+            try_join_all(
+                (0..new_active_work.get()).map(|record_id| recv.receive(record_id.into())),
+            )
+            .await
+            .unwrap();
+        });
+    }
+
     async fn shard_comms_test(test_world: &TestWorld<WithShards<2>>) {
         let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)];
 
diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs
index 49a879be4..43706f450 100644
--- a/ipa-core/src/helpers/gateway/stall_detection.rs
+++ b/ipa-core/src/helpers/gateway/stall_detection.rs
@@ -67,6 +67,7 @@ impl Observed {
 }
 
 mod gateway {
+    use std::num::NonZeroUsize;
 
     use delegate::delegate;
 
@@ -153,12 +154,13 @@
             &self,
             channel_id: &HelperChannelId,
             total_records: TotalRecords,
+            active_work: NonZeroUsize,
         ) -> SendingEnd<Role, M> {
             Observed::wrap(
                 Weak::clone(self.get_sn()),
                 self.inner()
                     .gateway
-                    .get_mpc_sender(channel_id, total_records),
+                    .get_mpc_sender(channel_id, total_records, active_work),
             )
         }
 
diff --git a/ipa-core/src/helpers/prss_protocol.rs b/ipa-core/src/helpers/prss_protocol.rs
index 8171ca019..f9284f9eb 100644
--- a/ipa-core/src/helpers/prss_protocol.rs
+++ b/ipa-core/src/helpers/prss_protocol.rs
@@ -21,8 +21,16 @@ pub async fn negotiate(
     let left_channel = ChannelId::new(gateway.role().peer(Direction::Left), gate.clone());
     let right_channel = ChannelId::new(gateway.role().peer(Direction::Right), gate.clone());
 
-    let left_sender = gateway.get_mpc_sender::(&left_channel, TotalRecords::ONE);
-    let right_sender = gateway.get_mpc_sender::(&right_channel, TotalRecords::ONE);
+    let left_sender = gateway.get_mpc_sender::(
+        &left_channel,
+        TotalRecords::ONE,
+        gateway.config().active_work(),
+    );
+    let right_sender = gateway.get_mpc_sender::(
+        &right_channel,
+        TotalRecords::ONE,
+        gateway.config().active_work(),
+    );
     let left_receiver = gateway.get_mpc_receiver::(&left_channel);
     let right_receiver = gateway.get_mpc_receiver::(&right_channel);
 
diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs
index 70dd3d2af..9f28239ba 100644
--- a/ipa-core/src/protocol/context/dzkp_malicious.rs
+++ b/ipa-core/src/protocol/context/dzkp_malicious.rs
@@ -29,7 +29,6 @@ use crate::{
 pub struct DZKPUpgraded<'a> {
     validator_inner: Weak<MaliciousDZKPValidatorInner<'a>>,
     base_ctx: MaliciousContext<'a>,
-    active_work: NonZeroUsize,
 }
 
 impl<'a> DZKPUpgraded<'a> {
@@ -59,8 +58,11 @@
         };
         Self {
             validator_inner: Arc::downgrade(validator_inner),
-            base_ctx,
-            active_work,
+            // This overrides the active work for this context and all children
+            // created from it by using narrow, clone, etc.
+            // This allows all steps participating in malicious validation
+            // to use the same active work window and prevent deadlocks
+            base_ctx: base_ctx.set_active_work(active_work),
         }
     }
 
@@ -152,7 +154,7 @@ impl<'a> super::Context for DZKPUpgraded<'a> {
 
 impl<'a> SeqJoin for DZKPUpgraded<'a> {
     fn active_work(&self) -> NonZeroUsize {
-        self.active_work
+        self.base_ctx.active_work()
     }
 }
 
diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs
index 18b9b8e29..8c287b1f2 100644
--- a/ipa-core/src/protocol/context/malicious.rs
+++ b/ipa-core/src/protocol/context/malicious.rs
@@ -78,6 +78,13 @@ impl<'a> Context<'a> {
             ..self.inner
         }
     }
+
+    #[must_use]
+    pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self {
+        Self {
+            inner: self.inner.set_active_work(new_active_work),
+        }
+    }
 }
 
 impl<'a> super::Context for Context<'a> {
diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs
index 4a090bae3..eead81a16 100644
--- a/ipa-core/src/protocol/context/mod.rs
+++ b/ipa-core/src/protocol/context/mod.rs
@@ -162,6 +162,7 @@ pub struct Base<'a, B: ShardBinding = NotSharded> {
     inner: Inner<'a>,
     gate: Gate,
     total_records: TotalRecords,
+    active_work: NonZeroUsize,
     /// This indicates whether the system uses sharding or no. It's not ideal that we keep it here
     /// because it gets cloned often, a potential solution to that, if this shows up on flame graph,
     /// would be to move it to [`Inner`] struct.
@@ -180,9 +181,18 @@ impl<'a, B: ShardBinding> Base<'a, B> {
             inner: Inner::new(participant, gateway),
             gate,
             total_records,
+            active_work: gateway.config().active_work(),
             sharding,
         }
     }
+
+    #[must_use]
+    pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self {
+        Self {
+            active_work: new_active_work,
+            ..self.clone()
+        }
+    }
 }
 
 impl ShardedContext for Base<'_, Sharded> {
@@ -217,6 +227,7 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> {
             inner: self.inner.clone(),
             gate: self.gate.narrow(step),
             total_records: self.total_records,
+            active_work: self.active_work,
             sharding: self.sharding.clone(),
         }
     }
@@ -226,6 +237,7 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> {
             inner: self.inner.clone(),
             gate: self.gate.clone(),
             total_records: self.total_records.overwrite(total_records),
+            active_work: self.active_work,
             sharding: self.sharding.clone(),
         }
     }
@@ -254,9 +266,11 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> {
     }
 
     fn send_channel<M: MpcMessage>(&self, role: Role) -> SendingEnd<Role, M> {
-        self.inner
-            .gateway
-            .get_mpc_sender(&ChannelId::new(role, self.gate.clone()), self.total_records)
+        self.inner.gateway.get_mpc_sender(
+            &ChannelId::new(role, self.gate.clone()),
+            self.total_records,
+            self.active_work,
+        )
     }
 
     fn recv_channel<M: MpcMessage>(&self, role: Role) -> MpcReceivingEnd<M> {
@@ -322,7 +336,7 @@ impl ShardConfiguration for Base<'_, Sharded> {
 
 impl<'a, B: ShardBinding> SeqJoin for Base<'a, B> {
     fn active_work(&self) -> NonZeroUsize {
-        self.inner.gateway.config().active_work()
+        self.active_work
     }
 }