Merge pull request #1318 from akoshelev/active-work-per-context
Allow active work to be overridden by contexts
akoshelev authored Oct 1, 2024
2 parents 404fa5b + e8ad98f commit 11339c5
Showing 6 changed files with 102 additions and 16 deletions.
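In short: `Gateway::get_mpc_sender` gains an explicit third argument for the active-work window instead of always reading it from `GatewayConfig`, and the malicious DZKP context uses this to pin every channel it opens to the validator's window. A minimal sketch of the new call shape (`channel_id` and `total_records` are placeholders, not values from this diff):

let sender = gateway.get_mpc_sender::<BA3>(
    &channel_id,
    total_records,
    gateway.config().active_work(), // default, or any per-context override
);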
63 changes: 58 additions & 5 deletions ipa-core/src/helpers/gateway/mod.rs
@@ -150,12 +150,15 @@ impl Gateway {
&self,
channel_id: &HelperChannelId,
total_records: TotalRecords,
active_work: NonZeroUsize,
) -> send::SendingEnd<Role, M> {
let transport = &self.transports.mpc;
let channel = self.inner.mpc_senders.get::<M, _>(
channel_id,
transport,
self.config,
// Override the active work from the config if the caller
// wants to use a different value.
self.config.set_active_work(active_work),
self.query_id,
total_records,
);
@@ -280,11 +283,23 @@ impl GatewayConfig {
// we set active to be at least 2, so unwrap is fine.
self.active = NonZeroUsize::new(active).unwrap();
}

/// Creates a new configuration by overriding the value of active work.
#[must_use]
pub fn set_active_work(&self, active_work: NonZeroUsize) -> Self {
Self {
active: active_work,
..*self
}
}
}
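Note that `set_active_work` is non-mutating: it takes `&self` and returns a copy with only `active` replaced, which is what the `get_mpc_sender` call above relies on. A usage sketch, assuming a `GatewayConfig` value named `config` is in scope:

use std::num::NonZeroUsize;

let narrow = config.set_active_work(NonZeroUsize::new(4).unwrap());
// `config` itself is unchanged; `narrow` carries the new window.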

#[cfg(all(test, unit_test))]
mod tests {
use std::iter::{repeat, zip};
use std::{
iter::{repeat, zip},
num::NonZeroUsize,
};

use futures::{
future::{join, try_join, try_join_all},
@@ -293,12 +308,14 @@ mod tests {

use crate::{
ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions},
helpers::{Direction, GatewayConfig, MpcMessage, Role, SendingEnd},
helpers::{
ChannelId, Direction, GatewayConfig, MpcMessage, Role, SendingEnd, TotalRecords,
},
protocol::{
context::{Context, ShardedContext},
RecordId,
Gate, RecordId,
},
secret_sharing::replicated::semi_honest::AdditiveShare,
secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
sharding::ShardConfiguration,
test_executor::run,
test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards},
@@ -516,6 +533,42 @@ mod tests {
});
}

#[test]
fn custom_active_work() {
run(|| async move {
let world = TestWorld::new_with(TestWorldConfig {
gateway_config: GatewayConfig {
active: 5.try_into().unwrap(),
..Default::default()
},
..Default::default()
});
let new_active_work = NonZeroUsize::new(3).unwrap();
assert!(new_active_work < world.gateway(Role::H1).config().active_work());
let sender = world.gateway(Role::H1).get_mpc_sender::<BA3>(
&ChannelId::new(Role::H2, Gate::default()),
TotalRecords::specified(15).unwrap(),
new_active_work,
);
try_join_all(
(0..new_active_work.get())
.map(|record_id| sender.send(record_id.into(), BA3::ZERO)),
)
.await
.unwrap();
let recv = world.gateway(Role::H2).get_mpc_receiver::<BA3>(&ChannelId {
peer: Role::H1,
gate: Gate::default(),
});
// This would hang if the configured active work (5) were used
// instead of the override (3).
try_join_all(
(0..new_active_work.get()).map(|record_id| recv.receive(record_id.into())),
)
.await
.unwrap();
});
}
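This test pins down the failure mode that motivated the change: the channel declares 15 total records but only 3 are ever sent, and (per the hang comment above) with the configured window of 5 the sender would keep waiting to fill its batch before releasing anything, leaving the receiver blocked forever. Overriding the window to 3 lets those 3 records flush. The precise flushing rules live in the sender implementation; this is the behavior the test relies on.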

async fn shard_comms_test(test_world: &TestWorld<WithShards<2>>) {
let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)];

4 changes: 3 additions & 1 deletion ipa-core/src/helpers/gateway/stall_detection.rs
@@ -67,6 +67,7 @@ impl<T: ObserveState> Observed<T> {
}

mod gateway {
use std::num::NonZeroUsize;

use delegate::delegate;

@@ -153,12 +154,13 @@ mod gateway {
&self,
channel_id: &HelperChannelId,
total_records: TotalRecords,
active_work: NonZeroUsize,
) -> SendingEnd<Role, M> {
Observed::wrap(
Weak::clone(self.get_sn()),
self.inner()
.gateway
.get_mpc_sender(channel_id, total_records),
.get_mpc_sender(channel_id, total_records, active_work),
)
}
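The change here is pure plumbing: `Observed` is a transparent wrapper used for stall detection, so a new parameter on the inner `Gateway` method has to be threaded through it. A self-contained toy of the pattern, with stand-in types that are not the real ipa-core ones:

use std::num::NonZeroUsize;

// Stand-ins for illustration only.
struct Gateway;
impl Gateway {
    fn get_mpc_sender(&self, total_records: usize, active_work: NonZeroUsize) -> String {
        format!("sender(total={total_records}, window={active_work})")
    }
}

struct Observed<T> { inner: T }
impl Observed<Gateway> {
    // The wrapper mirrors the inner signature, so the new `active_work`
    // parameter must be forwarded here as well.
    fn get_mpc_sender(&self, total_records: usize, active_work: NonZeroUsize) -> String {
        self.inner.get_mpc_sender(total_records, active_work)
    }
}

fn main() {
    let observed = Observed { inner: Gateway };
    println!("{}", observed.get_mpc_sender(10, NonZeroUsize::new(3).unwrap()));
}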

12 changes: 10 additions & 2 deletions ipa-core/src/helpers/prss_protocol.rs
@@ -21,8 +21,16 @@ pub async fn negotiate<R: RngCore + CryptoRng>(
let left_channel = ChannelId::new(gateway.role().peer(Direction::Left), gate.clone());
let right_channel = ChannelId::new(gateway.role().peer(Direction::Right), gate.clone());

let left_sender = gateway.get_mpc_sender::<PublicKey>(&left_channel, TotalRecords::ONE);
let right_sender = gateway.get_mpc_sender::<PublicKey>(&right_channel, TotalRecords::ONE);
let left_sender = gateway.get_mpc_sender::<PublicKey>(
&left_channel,
TotalRecords::ONE,
gateway.config().active_work(),
);
let right_sender = gateway.get_mpc_sender::<PublicKey>(
&right_channel,
TotalRecords::ONE,
gateway.config().active_work(),
);
let left_receiver = gateway.get_mpc_receiver::<PublicKey>(&left_channel);
let right_receiver = gateway.get_mpc_receiver::<PublicKey>(&right_channel);
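Both PRSS call sites forward the configured default, `gateway.config().active_work()`, so adding the parameter leaves PRSS negotiation behavior unchanged.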

10 changes: 6 additions & 4 deletions ipa-core/src/protocol/context/dzkp_malicious.rs
@@ -29,7 +29,6 @@ use crate::{
pub struct DZKPUpgraded<'a> {
validator_inner: Weak<MaliciousDZKPValidatorInner<'a>>,
base_ctx: MaliciousContext<'a>,
active_work: NonZeroUsize,
}

impl<'a> DZKPUpgraded<'a> {
@@ -59,8 +58,11 @@ impl<'a> DZKPUpgraded<'a> {
};
Self {
validator_inner: Arc::downgrade(validator_inner),
base_ctx,
active_work,
// This overrides the active work for this context and all children
// created from it by using narrow, clone, etc.
// This allows all steps participating in malicious validation
// to use the same active work window, preventing deadlocks.
base_ctx: base_ctx.set_active_work(active_work),
}
}

@@ -152,7 +154,7 @@ impl<'a> super::Context for DZKPUpgraded<'a> {

impl<'a> SeqJoin for DZKPUpgraded<'a> {
fn active_work(&self) -> NonZeroUsize {
self.active_work
self.base_ctx.active_work()
}
}
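Storing the override inside `base_ctx` rather than in a separate field (as before) means it survives `narrow`, `clone`, and any other way a child context is derived, which is exactly the property the constructor comment describes. A self-contained toy model of that propagation, using simplified stand-in types rather than the real contexts:

use std::num::NonZeroUsize;

#[derive(Clone)]
struct Ctx {
    gate: String,
    active_work: NonZeroUsize,
}

impl Ctx {
    fn set_active_work(self, w: NonZeroUsize) -> Self {
        Self { active_work: w, ..self }
    }
    // `narrow` copies the window, so children inherit the override.
    fn narrow(&self, step: &str) -> Self {
        Self {
            gate: format!("{}/{}", self.gate, step),
            active_work: self.active_work,
        }
    }
}

fn main() {
    let base = Ctx { gate: "root".into(), active_work: NonZeroUsize::new(16).unwrap() };
    let validated = base.set_active_work(NonZeroUsize::new(4).unwrap());
    let child = validated.narrow("multiply");
    assert_eq!(child.active_work, validated.active_work); // shared window
}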

7 changes: 7 additions & 0 deletions ipa-core/src/protocol/context/malicious.rs
@@ -78,6 +78,13 @@ impl<'a> Context<'a> {
..self.inner
}
}

#[must_use]
pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self {
Self {
inner: self.inner.set_active_work(new_active_work),
}
}
}

impl<'a> super::Context for Context<'a> {
22 changes: 18 additions & 4 deletions ipa-core/src/protocol/context/mod.rs
@@ -162,6 +162,7 @@ pub struct Base<'a, B: ShardBinding = NotSharded> {
inner: Inner<'a>,
gate: Gate,
total_records: TotalRecords,
active_work: NonZeroUsize,
/// This indicates whether the system uses sharding or not. It's not ideal to keep it here
/// because this struct gets cloned often; if that shows up on a flame graph,
/// a potential fix would be to move it into the [`Inner`] struct.
@@ -180,9 +181,18 @@ impl<'a, B: ShardBinding> Base<'a, B> {
inner: Inner::new(participant, gateway),
gate,
total_records,
active_work: gateway.config().active_work(),
sharding,
}
}

#[must_use]
pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self {
Self {
active_work: new_active_work,
..self.clone()
}
}
}

impl ShardedContext for Base<'_, Sharded> {
@@ -217,6 +227,7 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> {
inner: self.inner.clone(),
gate: self.gate.narrow(step),
total_records: self.total_records,
active_work: self.active_work,
sharding: self.sharding.clone(),
}
}
@@ -226,6 +237,7 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> {
inner: self.inner.clone(),
gate: self.gate.clone(),
total_records: self.total_records.overwrite(total_records),
active_work: self.active_work,
sharding: self.sharding.clone(),
}
}
@@ -254,9 +266,11 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> {
}

fn send_channel<M: MpcMessage>(&self, role: Role) -> SendingEnd<Role, M> {
self.inner
.gateway
.get_mpc_sender(&ChannelId::new(role, self.gate.clone()), self.total_records)
self.inner.gateway.get_mpc_sender(
&ChannelId::new(role, self.gate.clone()),
self.total_records,
self.active_work,
)
}

fn recv_channel<M: MpcMessage>(&self, role: Role) -> MpcReceivingEnd<M> {
@@ -322,7 +336,7 @@ impl ShardConfiguration for Base<'_, Sharded> {

impl<'a, B: ShardBinding> SeqJoin for Base<'a, B> {
fn active_work(&self) -> NonZeroUsize {
self.inner.gateway.config().active_work()
self.active_work
}
}
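`active_work()` is the value `SeqJoin` uses to cap how many records are processed concurrently, so reading the per-context field here is where an override actually takes effect. A toy illustration of a bounded in-flight window (assumes the `futures` and `tokio` crates; the real bound lives in ipa-core's seq_join machinery):

use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    let window = 3; // plays the role of active_work
    let doubled: Vec<u32> = stream::iter(0..10u32)
        .map(|i| async move { i * 2 })
        .buffered(window) // at most `window` futures in flight, results in order
        .collect()
        .await;
    assert_eq!(doubled, (0..10u32).map(|i| i * 2).collect::<Vec<_>>());
}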

Expand Down
