diff --git a/synedrion/src/cggmp21/protocols/generic.rs b/synedrion/src/cggmp21/protocols/generic.rs index 691030c0..8710a882 100644 --- a/synedrion/src/cggmp21/protocols/generic.rs +++ b/synedrion/src/cggmp21/protocols/generic.rs @@ -1,3 +1,4 @@ +use alloc::collections::{BTreeMap, BTreeSet}; use alloc::string::String; use alloc::vec::Vec; use core::fmt::Debug; @@ -6,7 +7,7 @@ use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; use super::common::PartyIdx; -use crate::tools::collections::{HoleRange, HoleVec}; +use crate::tools::collections::{HoleRange, HoleVec, HoleVecAccum}; /// A round that sends out a broadcast. pub(crate) trait BroadcastRound: BaseRound { @@ -111,19 +112,93 @@ pub(crate) trait BaseRound { const ROUND_NUM: u8; // TODO (#78): find a way to derive it from `ROUND_NUM` const NEXT_ROUND_NUM: Option; + + fn num_parties(&self) -> usize; + fn party_idx(&self) -> PartyIdx; +} + +pub(crate) trait Round: BroadcastRound + DirectRound + BaseRound + Finalizable {} + +impl Round for R {} + +#[allow(clippy::enum_variant_names)] +pub(crate) enum FinalizationRequirement { + AllBroadcasts, + AllDms, + AllBroadcastsAndDms, } -pub(crate) trait Round: BroadcastRound + DirectRound + BaseRound {} +pub(crate) trait Finalizable: BroadcastRound + DirectRound { + fn requirement() -> FinalizationRequirement; -impl Round for R {} + fn can_finalize<'a>( + &self, + bc_payloads: impl Iterator, + dm_payloads: impl Iterator, + dm_artifacts: impl Iterator, + ) -> bool { + match Self::requirement() { + FinalizationRequirement::AllBroadcasts => { + contains_all_except(bc_payloads, self.num_parties(), self.party_idx()) + } + FinalizationRequirement::AllDms => { + contains_all_except(dm_payloads, self.num_parties(), self.party_idx()) + && contains_all_except(dm_artifacts, self.num_parties(), self.party_idx()) + } + FinalizationRequirement::AllBroadcastsAndDms => { + contains_all_except(bc_payloads, self.num_parties(), self.party_idx()) + && contains_all_except(dm_payloads, self.num_parties(), self.party_idx()) + && contains_all_except(dm_artifacts, self.num_parties(), self.party_idx()) + } + } + } + + fn missing_payloads<'a>( + &self, + bc_payloads: impl Iterator, + dm_payloads: impl Iterator, + dm_artifacts: impl Iterator, + ) -> BTreeSet { + match Self::requirement() { + FinalizationRequirement::AllBroadcasts => { + missing_payloads(bc_payloads, self.num_parties(), self.party_idx()) + } + FinalizationRequirement::AllDms => { + let mut missing = + missing_payloads(dm_payloads, self.num_parties(), self.party_idx()); + missing.append(&mut missing_payloads( + dm_artifacts, + self.num_parties(), + self.party_idx(), + )); + missing + } + FinalizationRequirement::AllBroadcastsAndDms => { + let mut missing = + missing_payloads(bc_payloads, self.num_parties(), self.party_idx()); + missing.append(&mut missing_payloads( + dm_payloads, + self.num_parties(), + self.party_idx(), + )); + missing.append(&mut missing_payloads( + dm_artifacts, + self.num_parties(), + self.party_idx(), + )); + missing + } + } + } +} pub(crate) trait FinalizableToResult: Round + BaseRound { fn finalize_to_result( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, ) -> Result<::Success, FinalizeError>; } @@ -132,9 +207,9 @@ pub(crate) trait FinalizableToNextRound: Round + BaseRound { fn finalize_to_next_round( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, ) -> Result>; } @@ -177,3 +252,45 @@ pub(crate) fn all_parties_except(num_parties: usize, party_idx: PartyIdx) -> Vec .map(PartyIdx::from_usize) .collect() } + +fn contains_all_except<'a>( + party_idxs: impl Iterator, + num_parties: usize, + party_idx: PartyIdx, +) -> bool { + let set = party_idxs.cloned().collect::>(); + for idx in HoleRange::new(num_parties, party_idx.as_usize()) { + if !set.contains(&PartyIdx::from_usize(idx)) { + return false; + } + } + true +} + +fn missing_payloads<'a>( + party_idxs: impl Iterator, + num_parties: usize, + party_idx: PartyIdx, +) -> BTreeSet { + let set = party_idxs.cloned().collect::>(); + let mut missing = BTreeSet::new(); + for idx in HoleRange::new(num_parties, party_idx.as_usize()) { + let party_idx = PartyIdx::from_usize(idx); + if !set.contains(&party_idx) { + missing.insert(party_idx); + } + } + missing +} + +pub(crate) fn try_to_holevec( + payloads: BTreeMap, + num_parties: usize, + party_idx: PartyIdx, +) -> Option> { + let mut accum = HoleVecAccum::new(num_parties, party_idx.as_usize()); + for (idx, elem) in payloads.into_iter() { + accum.insert(idx.as_usize(), elem)?; + } + accum.finalize() +} diff --git a/synedrion/src/cggmp21/protocols/interactive_signing.rs b/synedrion/src/cggmp21/protocols/interactive_signing.rs index d3be1b67..d75401b9 100644 --- a/synedrion/src/cggmp21/protocols/interactive_signing.rs +++ b/synedrion/src/cggmp21/protocols/interactive_signing.rs @@ -1,4 +1,5 @@ use alloc::boxed::Box; +use alloc::collections::BTreeMap; use core::marker::PhantomData; use rand_core::CryptoRngCore; @@ -13,7 +14,6 @@ use super::signing::{self, SigningResult}; use super::wrappers::{wrap_finalize_error, ResultWrapper, RoundWrapper}; use crate::cggmp21::params::SchemeParams; use crate::curve::{RecoverableSignature, Scalar}; -use crate::tools::collections::HoleVec; /// Possible results of the merged Presigning and Signing protocols. #[derive(Debug, Clone, Copy)] @@ -127,9 +127,9 @@ impl FinalizableToNextRound for Round1

{ fn finalize_to_next_round( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, ) -> Result> { let round = self .round @@ -163,9 +163,9 @@ impl FinalizableToNextRound for Round2

{ fn finalize_to_next_round( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, ) -> Result> { let round = self .round @@ -199,9 +199,9 @@ impl FinalizableToNextRound for Round3

{ fn finalize_to_next_round( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, ) -> Result> { let presigning_data = self .round @@ -250,9 +250,9 @@ impl FinalizableToResult for Round4

{ fn finalize_to_result( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, ) -> Result<::Success, FinalizeError> { self.round .finalize_to_result(rng, bc_payloads, dm_payloads, dm_artifacts) diff --git a/synedrion/src/cggmp21/protocols/key_gen.rs b/synedrion/src/cggmp21/protocols/key_gen.rs index 021ef53c..ff9da25e 100644 --- a/synedrion/src/cggmp21/protocols/key_gen.rs +++ b/synedrion/src/cggmp21/protocols/key_gen.rs @@ -1,6 +1,7 @@ //! Merged KeyInit and KeyRefresh protocols, to generate a full key share in one go. //! Since both take three rounds and are independent, we can execute them in parallel. +use alloc::collections::BTreeMap; use alloc::string::String; use alloc::vec::Vec; use core::marker::PhantomData; @@ -10,14 +11,14 @@ use serde::{Deserialize, Serialize}; use super::common::{KeyShare, PartyIdx}; use super::generic::{ - BaseRound, BroadcastRound, DirectRound, FinalizableToNextRound, FinalizableToResult, - FinalizeError, FirstRound, InitError, ProtocolResult, ReceiveError, ToNextRound, ToResult, + BaseRound, BroadcastRound, DirectRound, Finalizable, FinalizableToNextRound, + FinalizableToResult, FinalizationRequirement, FinalizeError, FirstRound, InitError, + ProtocolResult, ReceiveError, ToNextRound, ToResult, }; use super::key_init::{self, KeyInitResult}; use super::key_refresh::{self, KeyRefreshResult}; use super::wrappers::{wrap_finalize_error, wrap_receive_error, ResultWrapper}; use crate::cggmp21::SchemeParams; -use crate::tools::collections::HoleVec; /// Possible results of the merged KeyGen and KeyRefresh protocols. #[derive(Debug, Clone, Copy)] @@ -115,6 +116,14 @@ impl BaseRound for Round1

{ type Result = KeyGenResult

; const ROUND_NUM: u8 = 1; const NEXT_ROUND_NUM: Option = Some(2); + + fn num_parties(&self) -> usize { + self.key_init_round.num_parties() + } + + fn party_idx(&self) -> PartyIdx { + self.key_init_round.party_idx() + } } impl BroadcastRound for Round1

{ @@ -162,28 +171,40 @@ impl DirectRound for Round1

{ type Artifact = (); } +impl Finalizable for Round1

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcasts + } +} + impl FinalizableToNextRound for Round1

{ type NextRound = Round2

; fn finalize_to_next_round( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + _dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result> { - assert!(dm_payloads.is_none()); - assert!(dm_artifacts.is_none()); let (key_init_bc_payloads, key_refresh_bc_payloads) = bc_payloads - .map(|payloads| payloads.unzip()) - .map_or((None, None), |(x, y)| (Some(x), Some(y))); + .into_iter() + .map(|(idx, (init_payload, refresh_payload))| { + ((idx, init_payload), (idx, refresh_payload)) + }) + .unzip(); let key_init_round = self .key_init_round - .finalize_to_next_round(rng, key_init_bc_payloads, None, None) + .finalize_to_next_round(rng, key_init_bc_payloads, BTreeMap::new(), BTreeMap::new()) .map_err(wrap_finalize_error)?; let key_refresh_round = self .key_refresh_round - .finalize_to_next_round(rng, key_refresh_bc_payloads, None, None) + .finalize_to_next_round( + rng, + key_refresh_bc_payloads, + BTreeMap::new(), + BTreeMap::new(), + ) .map_err(wrap_finalize_error)?; Ok(Round2 { key_init_round, @@ -216,6 +237,14 @@ impl BaseRound for Round2

{ type Result = KeyGenResult

; const ROUND_NUM: u8 = 2; const NEXT_ROUND_NUM: Option = Some(3); + + fn num_parties(&self) -> usize { + self.key_init_round.num_parties() + } + + fn party_idx(&self) -> PartyIdx { + self.key_init_round.party_idx() + } } impl BroadcastRound for Round2

{ @@ -265,28 +294,40 @@ impl DirectRound for Round2

{ type Artifact = (); } +impl Finalizable for Round2

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcasts + } +} + impl FinalizableToNextRound for Round2

{ type NextRound = Round3

; fn finalize_to_next_round( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + _dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result> { - assert!(dm_payloads.is_none()); - assert!(dm_artifacts.is_none()); let (key_init_bc_payloads, key_refresh_bc_payloads) = bc_payloads - .map(|payloads| payloads.unzip()) - .map_or((None, None), |(x, y)| (Some(x), Some(y))); + .into_iter() + .map(|(idx, (init_payload, refresh_payload))| { + ((idx, init_payload), (idx, refresh_payload)) + }) + .unzip(); let key_init_round = self .key_init_round - .finalize_to_next_round(rng, key_init_bc_payloads, None, None) + .finalize_to_next_round(rng, key_init_bc_payloads, BTreeMap::new(), BTreeMap::new()) .map_err(wrap_finalize_error)?; let key_refresh_round = self .key_refresh_round - .finalize_to_next_round(rng, key_refresh_bc_payloads, None, None) + .finalize_to_next_round( + rng, + key_refresh_bc_payloads, + BTreeMap::new(), + BTreeMap::new(), + ) .map_err(wrap_finalize_error)?; Ok(Round3 { key_init_round, @@ -305,6 +346,14 @@ impl BaseRound for Round3

{ type Result = KeyGenResult

; const ROUND_NUM: u8 = 3; const NEXT_ROUND_NUM: Option = None; + + fn num_parties(&self) -> usize { + self.key_init_round.num_parties() + } + + fn party_idx(&self) -> PartyIdx { + self.key_init_round.party_idx() + } } impl BroadcastRound for Round3

{ @@ -357,21 +406,27 @@ impl DirectRound for Round3

{ } } +impl Finalizable for Round3

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcastsAndDms + } +} + impl FinalizableToResult for Round3

{ fn finalize_to_result( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, ) -> Result<::Success, FinalizeError> { let keyshare_seed = self .key_init_round - .finalize_to_result(rng, bc_payloads, None, None) + .finalize_to_result(rng, bc_payloads, BTreeMap::new(), BTreeMap::new()) .map_err(wrap_finalize_error)?; let keyshare_change = self .key_refresh_round - .finalize_to_result(rng, None, dm_payloads, dm_artifacts) + .finalize_to_result(rng, BTreeMap::new(), dm_payloads, dm_artifacts) .map_err(wrap_finalize_error)?; Ok(KeyShare::new(keyshare_seed, keyshare_change)) } diff --git a/synedrion/src/cggmp21/protocols/key_init.rs b/synedrion/src/cggmp21/protocols/key_init.rs index 9d076b78..b57aaa15 100644 --- a/synedrion/src/cggmp21/protocols/key_init.rs +++ b/synedrion/src/cggmp21/protocols/key_init.rs @@ -3,6 +3,7 @@ //! auxiliary parameters need to be generated as well (during the KeyRefresh protocol). use alloc::boxed::Box; +use alloc::collections::BTreeMap; use alloc::string::String; use alloc::vec::Vec; use core::marker::PhantomData; @@ -12,9 +13,9 @@ use serde::{Deserialize, Serialize}; use super::common::{KeyShareSeed, PartyIdx}; use super::generic::{ - all_parties_except, BaseRound, BroadcastRound, DirectRound, FinalizableToNextRound, - FinalizableToResult, FinalizeError, FirstRound, InitError, ProtocolResult, ReceiveError, - ToNextRound, ToResult, + all_parties_except, try_to_holevec, BaseRound, BroadcastRound, DirectRound, Finalizable, + FinalizableToNextRound, FinalizableToResult, FinalizationRequirement, FinalizeError, + FirstRound, InitError, ProtocolResult, ReceiveError, ToNextRound, ToResult, }; use crate::cggmp21::{ sigma::{SchCommitment, SchProof, SchSecret}, @@ -142,6 +143,14 @@ impl BaseRound for Round1

{ type Result = KeyInitResult; const ROUND_NUM: u8 = 1; const NEXT_ROUND_NUM: Option = Some(2); + + fn num_parties(&self) -> usize { + self.context.num_parties + } + + fn party_idx(&self) -> PartyIdx { + self.context.party_idx + } } impl BroadcastRound for Round1

{ @@ -150,16 +159,13 @@ impl BroadcastRound for Round1

{ type Payload = HashOutput; fn broadcast_destinations(&self) -> Option> { - Some(all_parties_except( - self.context.num_parties, - self.context.party_idx, - )) + Some(all_parties_except(self.num_parties(), self.party_idx())) } fn make_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { let hash = self .context .data - .hash(&self.context.shared_randomness, self.context.party_idx); + .hash(&self.context.shared_randomness, self.party_idx()); Ok(Round1Bcast { hash }) } fn verify_broadcast( @@ -177,19 +183,28 @@ impl DirectRound for Round1

{ type Artifact = (); } +impl Finalizable for Round1

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcasts + } +} + impl FinalizableToNextRound for Round1

{ type NextRound = Round2

; fn finalize_to_next_round( self, _rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + _dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result> { - assert!(dm_payloads.is_none()); - assert!(dm_artifacts.is_none()); Ok(Round2 { - hashes: bc_payloads.unwrap(), + hashes: try_to_holevec( + bc_payloads, + self.context.num_parties, + self.context.party_idx, + ) + .unwrap(), context: self.context, phantom: PhantomData, }) @@ -212,6 +227,14 @@ impl BaseRound for Round2

{ type Result = KeyInitResult; const ROUND_NUM: u8 = 2; const NEXT_ROUND_NUM: Option = Some(3); + + fn num_parties(&self) -> usize { + self.context.num_parties + } + + fn party_idx(&self) -> PartyIdx { + self.context.party_idx + } } impl BroadcastRound for Round2

{ @@ -220,10 +243,7 @@ impl BroadcastRound for Round2

{ type Payload = FullData; fn broadcast_destinations(&self) -> Option> { - Some(all_parties_except( - self.context.num_parties, - self.context.party_idx, - )) + Some(all_parties_except(self.num_parties(), self.party_idx())) } fn make_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { Ok(Round2Bcast { @@ -251,18 +271,27 @@ impl DirectRound for Round2

{ type Artifact = (); } +impl Finalizable for Round2

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcasts + } +} + impl FinalizableToNextRound for Round2

{ type NextRound = Round3

; fn finalize_to_next_round( self, _rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + _dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result> { - assert!(dm_payloads.is_none()); - assert!(dm_artifacts.is_none()); - let bc_payloads = bc_payloads.unwrap(); + let bc_payloads = try_to_holevec( + bc_payloads, + self.context.num_parties, + self.context.party_idx, + ) + .unwrap(); // XOR the vectors together // TODO (#61): is there a better way? let mut rid = self.context.data.rid.clone(); @@ -298,6 +327,14 @@ impl BaseRound for Round3

{ type Result = KeyInitResult; const ROUND_NUM: u8 = 3; const NEXT_ROUND_NUM: Option = None; + + fn num_parties(&self) -> usize { + self.context.num_parties + } + + fn party_idx(&self) -> PartyIdx { + self.context.party_idx + } } impl BroadcastRound for Round3

{ @@ -306,16 +343,13 @@ impl BroadcastRound for Round3

{ type Payload = (); fn broadcast_destinations(&self) -> Option> { - Some(all_parties_except( - self.context.num_parties, - self.context.party_idx, - )) + Some(all_parties_except(self.num_parties(), self.party_idx())) } fn make_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { let aux = ( &self.context.shared_randomness, - &self.context.party_idx, + &self.party_idx(), &self.rid, ); let proof = SchProof::new( @@ -352,16 +386,20 @@ impl DirectRound for Round3

{ type Artifact = (); } +impl Finalizable for Round3

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcasts + } +} + impl FinalizableToResult for Round3

{ fn finalize_to_result( self, _rng: &mut impl CryptoRngCore, - _bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + _bc_payloads: BTreeMap::Payload>, + _dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result<::Success, FinalizeError> { - assert!(dm_payloads.is_none()); - assert!(dm_artifacts.is_none()); let datas = self.datas.into_vec(self.context.data); let public_keys = datas.into_iter().map(|data| data.public).collect(); Ok(KeyShareSeed { diff --git a/synedrion/src/cggmp21/protocols/key_refresh.rs b/synedrion/src/cggmp21/protocols/key_refresh.rs index 4981ea77..65021465 100644 --- a/synedrion/src/cggmp21/protocols/key_refresh.rs +++ b/synedrion/src/cggmp21/protocols/key_refresh.rs @@ -3,6 +3,7 @@ //! for ZK proofs (e.g. Paillier keys). use alloc::boxed::Box; +use alloc::collections::BTreeMap; use alloc::string::String; use alloc::vec::Vec; use core::marker::PhantomData; @@ -12,9 +13,9 @@ use serde::{Deserialize, Serialize}; use super::common::{KeyShareChange, PartyIdx, PublicAuxInfo, SecretAuxInfo}; use super::generic::{ - all_parties_except, BaseRound, BroadcastRound, DirectRound, FinalizableToNextRound, - FinalizableToResult, FinalizeError, FirstRound, InitError, ProtocolResult, ReceiveError, - ToNextRound, ToResult, + all_parties_except, try_to_holevec, BaseRound, BroadcastRound, DirectRound, Finalizable, + FinalizableToNextRound, FinalizableToResult, FinalizationRequirement, FinalizeError, + FirstRound, InitError, ProtocolResult, ReceiveError, ToNextRound, ToResult, }; use crate::cggmp21::{ sigma::{FacProof, ModProof, PrmProof, SchCommitment, SchProof, SchSecret}, @@ -218,6 +219,14 @@ impl BaseRound for Round1

{ type Result = KeyRefreshResult

; const ROUND_NUM: u8 = 1; const NEXT_ROUND_NUM: Option = Some(2); + + fn num_parties(&self) -> usize { + self.context.num_parties + } + + fn party_idx(&self) -> PartyIdx { + self.context.party_idx + } } impl DirectRound for Round1

{ @@ -257,20 +266,26 @@ impl BroadcastRound for Round1

{ } } +impl Finalizable for Round1

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcasts + } +} + impl FinalizableToNextRound for Round1

{ type NextRound = Round2

; fn finalize_to_next_round( self, _rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + _dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result> { - assert!(dm_payloads.is_none()); - assert!(dm_artifacts.is_none()); + let num_parties = self.num_parties(); + let party_idx = self.party_idx(); Ok(Round2 { context: self.context, - hashes: bc_payloads.unwrap(), + hashes: try_to_holevec(bc_payloads, num_parties, party_idx).unwrap(), }) } } @@ -292,6 +307,14 @@ impl BaseRound for Round2

{ type Result = KeyRefreshResult

; const ROUND_NUM: u8 = 2; const NEXT_ROUND_NUM: Option = Some(3); + + fn num_parties(&self) -> usize { + self.context.num_parties + } + + fn party_idx(&self) -> PartyIdx { + self.context.party_idx + } } impl DirectRound for Round2

{ @@ -363,18 +386,22 @@ impl BroadcastRound for Round2

{ } } +impl Finalizable for Round2

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcasts + } +} + impl FinalizableToNextRound for Round2

{ type NextRound = Round3

; fn finalize_to_next_round( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + _dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result> { - assert!(dm_payloads.is_none()); - assert!(dm_artifacts.is_none()); - let messages = bc_payloads.unwrap(); + let messages = try_to_holevec(bc_payloads, self.num_parties(), self.party_idx()).unwrap(); // XOR the vectors together // TODO (#61): is there a better way? let mut rho = self.context.data_precomp.data.rho_bits.clone(); @@ -451,6 +478,14 @@ impl BaseRound for Round3

{ type Result = KeyRefreshResult

; const ROUND_NUM: u8 = 3; const NEXT_ROUND_NUM: Option = None; + + fn num_parties(&self) -> usize { + self.context.num_parties + } + + fn party_idx(&self) -> PartyIdx { + self.context.party_idx + } } impl BroadcastRound for Round3

{ @@ -585,16 +620,21 @@ impl DirectRound for Round3

{ } } +impl Finalizable for Round3

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllDms + } +} + impl FinalizableToResult for Round3

{ fn finalize_to_result( self, _rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - _dm_artifacts: Option::Artifact>>, + _bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result<::Success, FinalizeError> { - assert!(bc_payloads.is_none()); - let secrets = dm_payloads + let secrets = try_to_holevec(dm_payloads, self.num_parties(), self.party_idx()) .unwrap() .into_vec(self.context.xs_secret[self.context.party_idx.as_usize()]); let secret_share_change = secrets.iter().sum(); diff --git a/synedrion/src/cggmp21/protocols/presigning.rs b/synedrion/src/cggmp21/protocols/presigning.rs index d222cf3a..85739431 100644 --- a/synedrion/src/cggmp21/protocols/presigning.rs +++ b/synedrion/src/cggmp21/protocols/presigning.rs @@ -1,6 +1,7 @@ //! Presigning protocol, in the paper ECDSA Pre-Signing (Fig. 7). use alloc::boxed::Box; +use alloc::collections::BTreeMap; use alloc::string::String; use alloc::vec::Vec; use core::marker::PhantomData; @@ -10,9 +11,9 @@ use serde::{Deserialize, Serialize}; use super::common::{KeyShare, KeySharePrecomputed, PartyIdx, PresigningData}; use super::generic::{ - all_parties_except, BaseRound, BroadcastRound, DirectRound, FinalizableToNextRound, - FinalizableToResult, FinalizeError, FirstRound, InitError, ProtocolResult, ReceiveError, - ToNextRound, ToResult, + all_parties_except, try_to_holevec, BaseRound, BroadcastRound, DirectRound, Finalizable, + FinalizableToNextRound, FinalizableToResult, FinalizationRequirement, FinalizeError, + FirstRound, InitError, ProtocolResult, ReceiveError, ToNextRound, ToResult, }; use crate::cggmp21::{ sigma::{AffGProof, DecProof, EncProof, LogStarProof, MulProof}, @@ -115,6 +116,14 @@ impl BaseRound for Round1

{ type Result = PresigningResult

; const ROUND_NUM: u8 = 1; const NEXT_ROUND_NUM: Option = Some(2); + + fn num_parties(&self) -> usize { + self.context.key_share.num_parties() + } + + fn party_idx(&self) -> PartyIdx { + self.context.key_share.party_index() + } } #[derive(Clone, Serialize, Deserialize)] @@ -202,17 +211,33 @@ impl DirectRound for Round1

{ } } +impl Finalizable for Round1

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcastsAndDms + } +} + impl FinalizableToNextRound for Round1

{ type NextRound = Round2

; fn finalize_to_next_round( self, _rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - _dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result> { - let ciphertexts = bc_payloads.unwrap(); - let proofs = dm_payloads.unwrap(); + let ciphertexts = try_to_holevec( + bc_payloads, + self.context.key_share.num_parties(), + self.context.key_share.party_index(), + ) + .unwrap(); + let proofs = try_to_holevec( + dm_payloads, + self.context.key_share.num_parties(), + self.context.key_share.party_index(), + ) + .unwrap(); let aux = ( &self.context.shared_randomness, @@ -298,6 +323,14 @@ impl BaseRound for Round2

{ type Result = PresigningResult

; const ROUND_NUM: u8 = 2; const NEXT_ROUND_NUM: Option = Some(3); + + fn num_parties(&self) -> usize { + self.context.key_share.num_parties() + } + + fn party_idx(&self) -> PartyIdx { + self.context.key_share.party_index() + } } impl BroadcastRound for Round2

{ @@ -508,17 +541,33 @@ impl DirectRound for Round2

{ } } +impl Finalizable for Round2

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllDms + } +} + impl FinalizableToNextRound for Round2

{ type NextRound = Round3

; fn finalize_to_next_round( self, _rng: &mut impl CryptoRngCore, - _bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + _bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, ) -> Result> { - let dm_payloads = dm_payloads.unwrap(); - let dm_artifacts = dm_artifacts.unwrap(); + let dm_payloads = try_to_holevec( + dm_payloads, + self.context.key_share.num_parties(), + self.context.key_share.party_index(), + ) + .unwrap(); + let dm_artifacts = try_to_holevec( + dm_artifacts, + self.context.key_share.num_parties(), + self.context.key_share.party_index(), + ) + .unwrap(); let gamma: Point = dm_payloads.iter().map(|payload| payload.gamma).sum(); let gamma = gamma + self.context.gamma.mul_by_generator(); @@ -583,6 +632,14 @@ impl BaseRound for Round3

{ type Result = PresigningResult

; const ROUND_NUM: u8 = 3; const NEXT_ROUND_NUM: Option = None; + + fn num_parties(&self) -> usize { + self.context.key_share.num_parties() + } + + fn party_idx(&self) -> PartyIdx { + self.context.key_share.party_index() + } } pub struct Round3Payload { @@ -680,15 +737,26 @@ pub struct PresigningProof { dec_proofs: Vec<(PartyIdx, DecProof

)>, } +impl Finalizable for Round3

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllDms + } +} + impl FinalizableToResult for Round3

{ fn finalize_to_result( self, rng: &mut impl CryptoRngCore, - _bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - _dm_artifacts: Option::Artifact>>, + _bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result<::Success, FinalizeError> { - let dm_payloads = dm_payloads.unwrap(); + let dm_payloads = try_to_holevec( + dm_payloads, + self.context.key_share.num_parties(), + self.context.key_share.party_index(), + ) + .unwrap(); let (deltas, big_deltas) = dm_payloads .map(|payload| (payload.delta, payload.big_delta)) .unzip(); diff --git a/synedrion/src/cggmp21/protocols/signing.rs b/synedrion/src/cggmp21/protocols/signing.rs index 045f9526..d66fe2de 100644 --- a/synedrion/src/cggmp21/protocols/signing.rs +++ b/synedrion/src/cggmp21/protocols/signing.rs @@ -1,6 +1,7 @@ //! Signing using previously calculated presigning data, in the paper ECDSA Signing (Fig. 8). use alloc::boxed::Box; +use alloc::collections::BTreeMap; use alloc::string::String; use alloc::vec::Vec; use core::marker::PhantomData; @@ -10,8 +11,9 @@ use serde::{Deserialize, Serialize}; use super::common::{KeySharePrecomputed, PartyIdx, PresigningData}; use super::generic::{ - all_parties_except, BaseRound, BroadcastRound, DirectRound, FinalizableToResult, FinalizeError, - FirstRound, InitError, ProtocolResult, ReceiveError, ToResult, + all_parties_except, try_to_holevec, BaseRound, BroadcastRound, DirectRound, Finalizable, + FinalizableToResult, FinalizationRequirement, FinalizeError, FirstRound, InitError, + ProtocolResult, ReceiveError, ToResult, }; use crate::cggmp21::{ sigma::{AffGProof, DecProof, MulStarProof}, @@ -19,7 +21,7 @@ use crate::cggmp21::{ }; use crate::curve::{RecoverableSignature, Scalar}; use crate::paillier::RandomizerMod; -use crate::tools::collections::{HoleRange, HoleVec}; +use crate::tools::collections::HoleRange; use crate::uint::{Bounded, FromScalar, Signed}; /// Possible results of the Signing protocol. @@ -85,6 +87,14 @@ impl BaseRound for Round1

{ type Result = SigningResult

; const ROUND_NUM: u8 = 1; const NEXT_ROUND_NUM: Option = None; + + fn num_parties(&self) -> usize { + self.num_parties + } + + fn party_idx(&self) -> PartyIdx { + self.party_idx + } } #[derive(Clone, Serialize, Deserialize)] @@ -97,7 +107,7 @@ impl BroadcastRound for Round1

{ type Message = Round1Bcast; type Payload = Scalar; fn broadcast_destinations(&self) -> Option> { - Some(all_parties_except(self.num_parties, self.party_idx)) + Some(all_parties_except(self.num_parties(), self.party_idx())) } fn make_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { Ok(Round1Bcast { @@ -120,15 +130,21 @@ impl DirectRound for Round1

{ type Artifact = (); } +impl Finalizable for Round1

{ + fn requirement() -> FinalizationRequirement { + FinalizationRequirement::AllBroadcasts + } +} + impl FinalizableToResult for Round1

{ fn finalize_to_result( self, rng: &mut impl CryptoRngCore, - bc_payloads: Option::Payload>>, - _dm_payloads: Option::Payload>>, - _dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + _dm_payloads: BTreeMap::Payload>, + _dm_artifacts: BTreeMap::Artifact>, ) -> Result<::Success, FinalizeError> { - let shares = bc_payloads.unwrap(); + let shares = try_to_holevec(bc_payloads, self.num_parties, self.party_idx).unwrap(); let s: Scalar = shares.iter().sum(); let s = s + self.s_part; diff --git a/synedrion/src/cggmp21/protocols/test_utils.rs b/synedrion/src/cggmp21/protocols/test_utils.rs index 5d6b2136..eb29d3ca 100644 --- a/synedrion/src/cggmp21/protocols/test_utils.rs +++ b/synedrion/src/cggmp21/protocols/test_utils.rs @@ -1,3 +1,4 @@ +use alloc::collections::BTreeMap; use alloc::format; use alloc::string::String; use alloc::vec::Vec; @@ -9,7 +10,6 @@ use super::generic::{ BroadcastRound, DirectRound, FinalizableToNextRound, FinalizableToResult, ProtocolResult, Round, }; use super::{FinalizeError, PartyIdx}; -use crate::tools::collections::{HoleVec, HoleVecAccum}; #[derive(Debug)] pub(crate) enum StepError { @@ -19,9 +19,9 @@ pub(crate) enum StepError { pub(crate) struct AssembledRound { round: R, - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, } pub(crate) fn step_round( @@ -35,7 +35,7 @@ where // Collect outgoing messages let mut dm_artifact_accums = (0..rounds.len()) - .map(|idx| HoleVecAccum::<::Artifact>::new(rounds.len(), idx)) + .map(|_| BTreeMap::new()) .collect::>(); // `to, from, message` @@ -49,9 +49,9 @@ where for idx_to in destinations { let (message, artifact) = round.make_direct_message(rng, idx_to).unwrap(); direct_messages.push((idx_to, idx_from, message)); - dm_artifact_accums[idx_from.as_usize()] - .insert(idx_to.as_usize(), artifact) - .unwrap(); + assert!(dm_artifact_accums[idx_from.as_usize()] + .insert(idx_to, artifact) + .is_none()); } } @@ -66,67 +66,42 @@ where // Deliver direct messages let mut dm_payload_accums = (0..rounds.len()) - .map(|idx| HoleVecAccum::<::Payload>::new(rounds.len(), idx)) + .map(|_| BTreeMap::new()) .collect::>(); for (idx_to, idx_from, message) in direct_messages.into_iter() { let round = &rounds[idx_to.as_usize()]; let payload = round .verify_direct_message(idx_from, message) .map_err(|err| StepError::Receive(format!("{:?}", err)))?; - dm_payload_accums[idx_to.as_usize()].insert(idx_from.as_usize(), payload); + dm_payload_accums[idx_to.as_usize()].insert(idx_from, payload); } // Deliver broadcasts let mut bc_payload_accums = (0..rounds.len()) - .map(|idx| HoleVecAccum::<::Payload>::new(rounds.len(), idx)) + .map(|_| BTreeMap::new()) .collect::>(); for (idx_to, idx_from, message) in broadcasts.into_iter() { let round = &rounds[idx_to.as_usize()]; let payload = round .verify_broadcast(idx_from, message) .map_err(|err| StepError::Receive(format!("{:?}", err)))?; - bc_payload_accums[idx_to.as_usize()].insert(idx_from.as_usize(), payload); - } - - // Finalize accumulators - - let mut dm_payloads = Vec::new(); - for accum in dm_payload_accums.into_iter() { - let payloads = if accum.is_empty() { - None - } else { - Some(accum.finalize().ok_or(StepError::AccumFinalize)?) - }; - dm_payloads.push(payloads); - } - - let mut dm_artifacts = Vec::new(); - for accum in dm_artifact_accums.into_iter() { - let artifacts = if accum.is_empty() { - None - } else { - Some(accum.finalize().ok_or(StepError::AccumFinalize)?) - }; - dm_artifacts.push(artifacts); - } - - let mut bc_payloads = Vec::new(); - for accum in bc_payload_accums.into_iter() { - let payloads = if accum.is_empty() { - None - } else { - Some(accum.finalize().ok_or(StepError::AccumFinalize)?) - }; - bc_payloads.push(payloads); + bc_payload_accums[idx_to.as_usize()].insert(idx_from, payload); } // Assemble let mut assembled = Vec::new(); - for (round, bc_payloads, dm_payloads, dm_artifacts) in - izip!(rounds, bc_payloads, dm_payloads, dm_artifacts) - { + for (round, bc_payloads, dm_payloads, dm_artifacts) in izip!( + rounds, + bc_payload_accums, + dm_payload_accums, + dm_artifact_accums + ) { + if !round.can_finalize(bc_payloads.keys(), dm_payloads.keys(), dm_artifacts.keys()) { + return Err(StepError::AccumFinalize); + }; + assembled.push(AssembledRound { round, bc_payloads, diff --git a/synedrion/src/cggmp21/protocols/wrappers.rs b/synedrion/src/cggmp21/protocols/wrappers.rs index 90af9835..accb2d1c 100644 --- a/synedrion/src/cggmp21/protocols/wrappers.rs +++ b/synedrion/src/cggmp21/protocols/wrappers.rs @@ -5,8 +5,8 @@ use rand_core::CryptoRngCore; use super::common::PartyIdx; use super::generic::{ - BaseRound, BroadcastRound, DirectRound, FinalizableType, FinalizeError, ProtocolResult, - ReceiveError, Round, + BaseRound, BroadcastRound, DirectRound, Finalizable, FinalizableType, FinalizationRequirement, + FinalizeError, ProtocolResult, ReceiveError, Round, }; pub(crate) trait ResultWrapper: ProtocolResult { @@ -50,6 +50,13 @@ impl BaseRound for T { type Result = T::Result; const ROUND_NUM: u8 = T::ROUND_NUM; const NEXT_ROUND_NUM: Option = T::NEXT_ROUND_NUM; + + fn num_parties(&self) -> usize { + self.inner_round().num_parties() + } + fn party_idx(&self) -> PartyIdx { + self.inner_round().party_idx() + } } impl BroadcastRound for T { @@ -97,3 +104,9 @@ impl DirectRound for T { .map_err(wrap_receive_error) } } + +impl Finalizable for T { + fn requirement() -> FinalizationRequirement { + T::InnerRound::requirement() + } +} diff --git a/synedrion/src/sessions/states.rs b/synedrion/src/sessions/states.rs index e11aab41..44d9fc0a 100644 --- a/synedrion/src/sessions/states.rs +++ b/synedrion/src/sessions/states.rs @@ -194,8 +194,6 @@ where RoundAccumulator::new( self.context.verifiers.len(), self.context.party_idx, - self.broadcast_destinations().is_some(), - self.direct_message_destinations().is_some(), self.is_broadcast_consensus_round(), ) } @@ -203,7 +201,7 @@ where /// Returns `true` if the round can be finalized. pub fn can_finalize(&self, accum: &RoundAccumulator) -> Result { match &self.tp { - SessionType::Normal(_) => Ok(accum.processed.can_finalize()), + SessionType::Normal(round) => Ok(round.can_finalize(&accum.processed)), SessionType::Bc { .. } => Ok(accum .bc_accum .as_ref() @@ -216,12 +214,29 @@ where } /// Returns a list of parties whose messages for this round have not been received yet. - pub fn missing_messages(&self, accum: &RoundAccumulator) -> Vec { - accum - .missing_messages() - .into_iter() - .map(|idx| self.context.verifiers[idx.as_usize()].clone()) - .collect() + pub fn missing_messages( + &self, + accum: &RoundAccumulator, + ) -> Result, LocalError> { + let missing = match &self.tp { + SessionType::Normal(round) => Ok(round + .missing_payloads(&accum.processed) + .into_iter() + .collect()), + SessionType::Bc { .. } => { + let bc_accum = accum.bc_accum.as_ref().ok_or(LocalError( + "This is a BC consensus round, but the accumulator is in an invalid state" + .into(), + ))?; + Ok(bc_accum.missing_messages()) + } + }; + + missing.map(|set| { + set.into_iter() + .map(|idx| self.context.verifiers[idx.as_usize()].clone()) + .collect() + }) } fn is_broadcast_consensus_round(&self) -> bool { @@ -583,18 +598,12 @@ pub struct RoundAccumulator { } impl RoundAccumulator { - fn new( - num_parties: usize, - party_idx: PartyIdx, - is_bc_round: bool, - is_dm_round: bool, - is_bc_consensus_round: bool, - ) -> Self { + fn new(num_parties: usize, party_idx: PartyIdx, is_bc_consensus_round: bool) -> Self { // TODO (#68): can return an error if party_idx is out of bounds Self { received_direct_messages: Vec::new(), received_broadcasts: Vec::new(), - processed: DynRoundAccum::new(num_parties, party_idx, is_bc_round, is_dm_round), + processed: DynRoundAccum::new(), cached_messages: Vec::new(), cached_message_count: vec![0; num_parties], bc_accum: if is_bc_consensus_round { @@ -605,14 +614,6 @@ impl RoundAccumulator { } } - fn missing_messages(&self) -> Vec { - if let Some(accum) = &self.bc_accum { - accum.missing_messages() - } else { - self.processed.missing_messages() - } - } - /// Save an artifact produced by [`Session::make_direct_message`]. pub fn add_artifact( &mut self, @@ -625,9 +626,6 @@ impl RoundAccumulator { "Artifact for the destination {:?} was already added", artifact.destination )), - AccumAddError::NoAccumulator => { - LocalError("This round does not send out direct messages".into()) - } }) } @@ -638,37 +636,25 @@ impl RoundAccumulator { ) -> Result>, LocalError> { match pm.message { ProcessedMessageEnum::BcPayload { payload, message } => { - match self.processed.add_bc_payload(pm.from_idx, payload) { - Err(AccumAddError::SlotTaken) => { - return Ok(Err(RemoteError { - party: pm.from, - error: RemoteErrorEnum::DuplicateMessage, - })) - } - Err(AccumAddError::NoAccumulator) => { - return Err(LocalError( - "This round does not send out broadcast messages".into(), - )) - } - Ok(()) => {} - }; + if let Err(AccumAddError::SlotTaken) = + self.processed.add_bc_payload(pm.from_idx, payload) + { + return Ok(Err(RemoteError { + party: pm.from, + error: RemoteErrorEnum::DuplicateMessage, + })); + } self.received_broadcasts.push((pm.from_idx, message)); } ProcessedMessageEnum::DmPayload { payload, message } => { - match self.processed.add_dm_payload(pm.from_idx, payload) { - Err(AccumAddError::SlotTaken) => { - return Ok(Err(RemoteError { - party: pm.from, - error: RemoteErrorEnum::DuplicateMessage, - })) - } - Err(AccumAddError::NoAccumulator) => { - return Err(LocalError( - "This round does not send out direct messages".into(), - )) - } - Ok(()) => {} - }; + if let Err(AccumAddError::SlotTaken) = + self.processed.add_dm_payload(pm.from_idx, payload) + { + return Ok(Err(RemoteError { + party: pm.from, + error: RemoteErrorEnum::DuplicateMessage, + })); + } self.received_direct_messages.push((pm.from_idx, message)); } ProcessedMessageEnum::Bc => match &mut self.bc_accum { diff --git a/synedrion/src/sessions/type_erased.rs b/synedrion/src/sessions/type_erased.rs index d044da30..93bd043a 100644 --- a/synedrion/src/sessions/type_erased.rs +++ b/synedrion/src/sessions/type_erased.rs @@ -4,7 +4,7 @@ This way they can be used in a state machine loop without code repetition. */ use alloc::boxed::Box; -use alloc::collections::BTreeSet; +use alloc::collections::{BTreeMap, BTreeSet}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec::Vec; @@ -18,7 +18,6 @@ use crate::cggmp21::{ self, BroadcastRound, DirectRound, FinalizableToNextRound, FinalizableToResult, PartyIdx, ProtocolResult, Round, ToNextRound, ToResult, }; -use crate::tools::collections::{HoleVec, HoleVecAccum}; pub(crate) fn serialize_message(message: &impl Serialize) -> Result, LocalError> { bincode::serialize(message) @@ -41,13 +40,10 @@ pub(crate) enum FinalizeOutcome { pub enum AccumAddError { /// An item with the given origin has already been added to the accumulator. SlotTaken, - /// Trying to add an item to an accumulator that was not initialized on construction. - NoAccumulator, } #[derive(Debug, Clone)] pub enum AccumFinalizeError { - NotEnoughMessages, Downcast(String), } @@ -120,6 +116,8 @@ pub(crate) trait DynRound: Send { from: PartyIdx, message: &[u8], ) -> Result>; + fn can_finalize(&self, accum: &DynRoundAccum) -> bool; + fn missing_payloads(&self, accum: &DynRoundAccum) -> BTreeSet; } impl DynRound for R @@ -203,85 +201,63 @@ where let message = serialize_message(&typed_message)?; Ok((message, DynDmArtifact(Box::new(typed_artifact)))) } + + fn can_finalize(&self, accum: &DynRoundAccum) -> bool { + self.can_finalize( + accum.bc_payloads.keys(), + accum.dm_payloads.keys(), + accum.dm_artifacts.keys(), + ) + } + + fn missing_payloads(&self, accum: &DynRoundAccum) -> BTreeSet { + self.missing_payloads( + accum.bc_payloads.keys(), + accum.dm_payloads.keys(), + accum.dm_artifacts.keys(), + ) + } } pub(crate) struct DynRoundAccum { - bc_payloads: Option>, - dm_payloads: Option>, - dm_artifacts: Option>, + bc_payloads: BTreeMap, + dm_payloads: BTreeMap, + dm_artifacts: BTreeMap, } struct RoundAccum { - bc_payloads: Option::Payload>>, - dm_payloads: Option::Payload>>, - dm_artifacts: Option::Artifact>>, + bc_payloads: BTreeMap::Payload>, + dm_payloads: BTreeMap::Payload>, + dm_artifacts: BTreeMap::Artifact>, } impl DynRoundAccum { - pub fn new(num_parties: usize, idx: PartyIdx, is_bc_round: bool, is_dm_round: bool) -> Self { + pub fn new() -> Self { Self { - bc_payloads: if is_bc_round { - Some(HoleVecAccum::new(num_parties, idx.as_usize())) - } else { - None - }, - dm_payloads: if is_dm_round { - Some(HoleVecAccum::new(num_parties, idx.as_usize())) - } else { - None - }, - dm_artifacts: if is_dm_round { - Some(HoleVecAccum::new(num_parties, idx.as_usize())) - } else { - None - }, + bc_payloads: BTreeMap::new(), + dm_payloads: BTreeMap::new(), + dm_artifacts: BTreeMap::new(), } } pub fn contains(&self, from: PartyIdx, broadcast: bool) -> bool { if broadcast { - return self - .bc_payloads - .as_ref() - .unwrap() - .contains(from.as_usize()) - .unwrap(); + self.bc_payloads.contains_key(&from) } else { - return self - .dm_payloads - .as_ref() - .unwrap() - .contains(from.as_usize()) - .unwrap(); + self.dm_payloads.contains_key(&from) } } - pub fn missing_messages(&self) -> Vec { - let mut idxs = BTreeSet::new(); - if let Some(payloads) = &self.bc_payloads { - for idx in payloads.missing() { - idxs.insert(idx); - } - } - if let Some(payloads) = &self.dm_payloads { - for idx in payloads.missing() { - idxs.insert(idx); - } - } - idxs.into_iter().map(PartyIdx::from_usize).collect() - } - pub fn add_bc_payload( &mut self, from: PartyIdx, payload: DynBcPayload, ) -> Result<(), AccumAddError> { - match &mut self.bc_payloads { - Some(payloads) => payloads - .insert(from.as_usize(), payload) - .ok_or(AccumAddError::SlotTaken), - None => Err(AccumAddError::NoAccumulator), + if self.bc_payloads.contains_key(&from) { + return Err(AccumAddError::SlotTaken); } + self.bc_payloads.insert(from, payload); + Ok(()) } pub fn add_dm_payload( @@ -289,12 +265,11 @@ impl DynRoundAccum { from: PartyIdx, payload: DynDmPayload, ) -> Result<(), AccumAddError> { - match &mut self.dm_payloads { - Some(payloads) => payloads - .insert(from.as_usize(), payload) - .ok_or(AccumAddError::SlotTaken), - None => Err(AccumAddError::NoAccumulator), + if self.dm_payloads.contains_key(&from) { + return Err(AccumAddError::SlotTaken); } + self.dm_payloads.insert(from, payload); + Ok(()) } pub fn add_dm_artifact( @@ -302,27 +277,11 @@ impl DynRoundAccum { destination: PartyIdx, artifact: DynDmArtifact, ) -> Result<(), AccumAddError> { - match &mut self.dm_artifacts { - Some(artifacts) => artifacts - .insert(destination.as_usize(), artifact) - .ok_or(AccumAddError::SlotTaken), - None => Err(AccumAddError::NoAccumulator), + if self.dm_artifacts.contains_key(&destination) { + return Err(AccumAddError::SlotTaken); } - } - - pub fn can_finalize(&self) -> bool { - // TODO (#85): should this be the job of the round itself? - self.bc_payloads - .as_ref() - .map_or(true, |accum| accum.can_finalize()) - && self - .dm_payloads - .as_ref() - .map_or(true, |accum| accum.can_finalize()) - && self - .dm_artifacts - .as_ref() - .map_or(true, |accum| accum.can_finalize()) + self.dm_artifacts.insert(destination, artifact); + Ok(()) } fn finalize(self) -> Result, AccumFinalizeError> @@ -331,33 +290,27 @@ impl DynRoundAccum { ::Payload: 'static, ::Artifact: 'static, { - let bc_payloads = match self.bc_payloads { - Some(accum) => { - let hvec = accum - .finalize() - .ok_or(AccumFinalizeError::NotEnoughMessages)?; - Some(hvec.map_fallible(|elem| downcast::<::Payload>(elem.0))?) - } - None => None, - }; - let dm_payloads = match self.dm_payloads { - Some(accum) => { - let hvec = accum - .finalize() - .ok_or(AccumFinalizeError::NotEnoughMessages)?; - Some(hvec.map_fallible(|elem| downcast::<::Payload>(elem.0))?) - } - None => None, - }; - let dm_artifacts = match self.dm_artifacts { - Some(accum) => { - let hvec = accum - .finalize() - .ok_or(AccumFinalizeError::NotEnoughMessages)?; - Some(hvec.map_fallible(|elem| downcast::<::Artifact>(elem.0))?) - } - None => None, - }; + let bc_payloads = self + .bc_payloads + .into_iter() + .map(|(idx, elem)| { + downcast::<::Payload>(elem.0).map(|elem| (idx, elem)) + }) + .collect::, _>>()?; + let dm_payloads = self + .dm_payloads + .into_iter() + .map(|(idx, elem)| { + downcast::<::Payload>(elem.0).map(|elem| (idx, elem)) + }) + .collect::, _>>()?; + let dm_artifacts = self + .dm_artifacts + .into_iter() + .map(|(idx, elem)| { + downcast::<::Artifact>(elem.0).map(|elem| (idx, elem)) + }) + .collect::, _>>()?; Ok(RoundAccum { bc_payloads, dm_payloads, diff --git a/synedrion/src/tools/collections.rs b/synedrion/src/tools/collections.rs index 415e8545..1c93ac61 100644 --- a/synedrion/src/tools/collections.rs +++ b/synedrion/src/tools/collections.rs @@ -57,11 +57,6 @@ impl HoleVecAccum { self.elems.iter().all(|elem| elem.is_some()) } - #[cfg(any(test, feature = "bench-internals"))] - pub fn is_empty(&self) -> bool { - self.elems.iter().all(|elem| elem.is_none()) - } - fn len(&self) -> usize { self.elems.len() + 1 } @@ -194,20 +189,6 @@ impl HoleVec { hole_at: self.hole_at, } } - - pub fn map_fallible(self, f: F) -> Result, E> - where - F: FnMut(T) -> Result, - { - Ok(HoleVec { - elems: self - .elems - .into_iter() - .map(f) - .collect::, E>>()?, - hole_at: self.hole_at, - }) - } } impl HoleVec<(T, V)> { diff --git a/synedrion/tests/sessions.rs b/synedrion/tests/sessions.rs index 832bc65b..460df6d3 100644 --- a/synedrion/tests/sessions.rs +++ b/synedrion/tests/sessions.rs @@ -93,7 +93,7 @@ async fn run_session( while !session.can_finalize(&accum).unwrap() { // This can be checked if a timeout expired, to see which nodes have not responded yet. - let unresponsive_parties = session.missing_messages(&accum); + let unresponsive_parties = session.missing_messages(&accum).unwrap(); assert!(!unresponsive_parties.is_empty()); println!("{key_str}: waiting for a message");