Skip to content

Commit

Permalink
fix: resolve issues around leader election (#18)
Browse files Browse the repository at this point in the history
* fix: resolve DefaultLeaderElector::elect_leader panicking for party_weights summing to 1

* fix: functions operating on party_weights are safe for sum of weights > u64::MAX

* fix: optimize leader election removing binary search

* fix: malicious party test - setting non-leader id

* fix: using rand_chacha crate directly

* fix: saturating subtraction for weights
  • Loading branch information
NikitaMasych authored Oct 17, 2024
1 parent 4adb7dd commit e2203ba
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ bincode = "^1.3.3"
rkyv = { version = "^0.7.44", features = ["validation"] }
tokio = { version = "^1.39.2", features = ["full"] }
rand = "^0.9.0-alpha.2"
seeded-random = "^0.6.0"
thiserror = "^1.0.63"
rand_chacha = "0.3.1"

[features]
default = ["tokio/full", "rkyv/validation"]
Expand Down
73 changes: 38 additions & 35 deletions src/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
use crate::party::Party;
use crate::value::{Value, ValueSelector};
use seeded_random::{Random, Seed};
use std::cmp::Ordering;
use rand_chacha::rand_core::{RngCore, SeedableRng};
use rand_chacha::ChaCha20Rng;
use std::hash::{DefaultHasher, Hash, Hasher};
use thiserror::Error;

Expand Down Expand Up @@ -70,7 +70,7 @@ impl DefaultLeaderElector {

/// Hashes the seed to a value within a specified range.
///
/// This method uses the computed seed to generate a value within the range [0, range).
/// This method uses the computed seed to generate a value within the range [0, range].
/// The algorithm ensures uniform distribution of the resulting value, which is crucial
/// for fair leader election.
///
Expand All @@ -79,19 +79,19 @@ impl DefaultLeaderElector {
/// - `range`: The upper limit for the random value generation, typically the sum of party weights.
///
/// # Returns
/// A `u64` value within the specified range.
fn hash_to_range(seed: u64, range: u64) -> u64 {
/// A `u128` value within the specified range.
fn hash_to_range(seed: u64, range: u128) -> u128 {
// Determine the number of bits required to represent the range
let mut k = 64;
while 1u64 << (k - 1) >= range {
let mut k = 128;
while 1u128 << (k - 1) > range {
k -= 1;
}

// Use a seeded random generator to produce a value within the desired range
let rng = Random::from_seed(Seed::unsafe_new(seed));
let mut rng = ChaCha20Rng::seed_from_u64(seed);
loop {
let mut raw_res: u64 = rng.gen();
raw_res >>= 64 - k;
let mut raw_res = ((rng.next_u64() as u128) << 64) | (rng.next_u64() as u128);
raw_res >>= 128 - k;

if raw_res < range {
return raw_res;
Expand Down Expand Up @@ -124,31 +124,23 @@ impl<V: Value, VS: ValueSelector<V>> LeaderElector<V, VS> for DefaultLeaderElect
fn elect_leader(&self, party: &Party<V, VS>) -> Result<u64, Box<dyn std::error::Error>> {
let seed = DefaultLeaderElector::compute_seed(party);

let total_weight: u64 = party.cfg.party_weights.iter().sum();
let total_weight: u128 = party.cfg.party_weights.iter().map(|&x| x as u128).sum();
if total_weight == 0 {
return Err(DefaultLeaderElectorError::ZeroWeightSum.into());
}

// Generate a random number in the range [0, total_weight)
// Generate a random number in the range [0, total_weight]
let random_value = DefaultLeaderElector::hash_to_range(seed, total_weight);

// Use binary search to find the corresponding participant based on the cumulative weight
let mut cumulative_weights = vec![0; party.cfg.party_weights.len()];
cumulative_weights[0] = party.cfg.party_weights[0];

for i in 1..party.cfg.party_weights.len() {
cumulative_weights[i] = cumulative_weights[i - 1] + party.cfg.party_weights[i];
}

match cumulative_weights.binary_search_by(|&weight| {
if random_value < weight {
Ordering::Greater
} else {
Ordering::Less
let mut cumulative_sum = 0u128;
for (index, &weight) in party.cfg.party_weights.iter().enumerate() {
cumulative_sum += weight as u128;
if random_value <= cumulative_sum {
return Ok(index as u64);
}
}) {
Ok(index) | Err(index) => Ok(index as u64),
}

unreachable!("Index is guaranteed to be returned in a loop.")
}
}

Expand All @@ -161,6 +153,17 @@ mod tests {
use std::thread;
use std::time::Duration;

#[test]
fn test_default_leader_elector_weight_one() {
let mut party = MockParty::default();
party.cfg.party_weights = vec![0, 1, 0, 0];

let elector = DefaultLeaderElector::new();

let leader = elector.elect_leader(&party).unwrap();
println!("leader: {}", leader);
}

#[test]
fn test_default_leader_elector_determinism() {
let party = MockParty::default();
Expand Down Expand Up @@ -209,11 +212,11 @@ mod tests {
k -= 1;
}

let rng = Random::from_seed(Seed::unsafe_new(seed));
let mut rng = ChaCha20Rng::seed_from_u64(seed);

let mut iteration = 1u64;
loop {
let mut raw_res: u64 = rng.gen();
let mut raw_res: u64 = rng.next_u64();
raw_res >>= 64 - k;

if raw_res < range {
Expand Down Expand Up @@ -260,15 +263,15 @@ mod tests {

#[test]
fn test_rng() {
let rng1 = Random::from_seed(Seed::unsafe_new(123456));
let rng2 = Random::from_seed(Seed::unsafe_new(123456));
let mut rng1 = ChaCha20Rng::seed_from_u64(123456);
let mut rng2 = ChaCha20Rng::seed_from_u64(123456);

println!("{}", rng1.gen::<u64>());
println!("{}", rng2.gen::<u64>());
println!("{}", rng1.next_u64());
println!("{}", rng2.next_u64());

thread::sleep(Duration::from_secs(2));

println!("{}", rng1.gen::<u64>());
println!("{}", rng2.gen::<u64>());
println!("{}", rng1.next_u64());
println!("{}", rng2.next_u64());
}
}
10 changes: 7 additions & 3 deletions src/party.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ impl<V: Value, VS: ValueSelector<V>> Party<V, VS> {
self.cfg.party_weights[routing.sender as usize] as u128;

let self_weight = self.cfg.party_weights[self.id as usize] as u128;
if self.messages_1b_weight >= self.cfg.threshold - self_weight {
if self.messages_1b_weight >= self.cfg.threshold.saturating_sub(self_weight) {
self.status = PartyStatus::Passed1b;
}
}
Expand Down Expand Up @@ -576,7 +576,9 @@ impl<V: Value, VS: ValueSelector<V>> Party<V, VS> {
);

let self_weight = self.cfg.party_weights[self.id as usize] as u128;
if self.messages_2av_state.get_weight() >= self.cfg.threshold - self_weight {
if self.messages_2av_state.get_weight()
>= self.cfg.threshold.saturating_sub(self_weight)
{
self.status = PartyStatus::Passed2av;
}
}
Expand Down Expand Up @@ -609,7 +611,9 @@ impl<V: Value, VS: ValueSelector<V>> Party<V, VS> {
);

let self_weight = self.cfg.party_weights[self.id as usize] as u128;
if self.messages_2b_state.get_weight() >= self.cfg.threshold - self_weight {
if self.messages_2b_state.get_weight()
>= self.cfg.threshold.saturating_sub(self_weight)
{
self.status = PartyStatus::Passed2b;
}
}
Expand Down
32 changes: 31 additions & 1 deletion tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ async fn test_ballot_malicious_party() {

let elector = DefaultLeaderElector::new();
let leader = elector.elect_leader(&parties[0]).unwrap();
const MALICIOUS_PARTY_ID: u64 = 1;
const MALICIOUS_PARTY_ID: u64 = 2;

assert_ne!(
MALICIOUS_PARTY_ID, leader,
Expand Down Expand Up @@ -285,3 +285,33 @@ async fn test_ballot_many_parties() {

analyze_ballot(results);
}

#[tokio::test]
async fn test_ballot_max_weight() {
let weights = vec![u64::MAX, 1];
let threshold = BPConConfig::compute_bft_threshold(weights.clone());
let cfg = BPConConfig::with_default_timeouts(weights, threshold);

let (parties, receivers, senders) = create_parties(cfg);
let ballot_tasks = launch_parties(parties);
let p2p_task = propagate_p2p(receivers, senders);
let results = await_results(ballot_tasks).await;
p2p_task.abort();

analyze_ballot(results);
}

#[tokio::test]
async fn test_ballot_weights_underflow() {
let weights = vec![100, 1, 2, 3, 4];
let threshold = BPConConfig::compute_bft_threshold(weights.clone());
let cfg = BPConConfig::with_default_timeouts(weights, threshold);

let (parties, receivers, senders) = create_parties(cfg);
let ballot_tasks = launch_parties(parties);
let p2p_task = propagate_p2p(receivers, senders);
let results = await_results(ballot_tasks).await;
p2p_task.abort();

analyze_ballot(results);
}

0 comments on commit e2203ba

Please sign in to comment.