-
-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
227 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
use std::{net::SocketAddr, ops::Deref}; | ||
|
||
use thiserror::Error; | ||
use tokio::sync::mpsc; | ||
use tracing::warn; | ||
|
||
use super::super::{config::NtsPoolPeerConfig, keyexchange::key_exchange_client}; | ||
|
||
use super::{BasicSpawner, PeerId, PeerRemovedEvent, SpawnAction, SpawnEvent, SpawnerId}; | ||
|
||
struct PoolPeer { | ||
id: PeerId, | ||
address: SocketAddr, | ||
} | ||
|
||
pub struct NtsPoolSpawner { | ||
config: NtsPoolPeerConfig, | ||
network_wait_period: std::time::Duration, | ||
id: SpawnerId, | ||
current_peers: Vec<PoolPeer>, | ||
} | ||
|
||
#[derive(Error, Debug)] | ||
pub enum NtsPoolSpawnError { | ||
#[error("Channel send error: {0}")] | ||
SendError(#[from] mpsc::error::SendError<SpawnEvent>), | ||
} | ||
|
||
impl NtsPoolSpawner { | ||
pub fn new( | ||
config: NtsPoolPeerConfig, | ||
network_wait_period: std::time::Duration, | ||
) -> NtsPoolSpawner { | ||
NtsPoolSpawner { | ||
config, | ||
network_wait_period, | ||
id: Default::default(), | ||
current_peers: Default::default(), | ||
//known_ips: Default::default(), | ||
} | ||
} | ||
|
||
//NOTE: this is the same code as in nts.rs, so we should introduce some code sharing | ||
async fn resolve_addr(&mut self, address: (&str, u16)) -> Option<SocketAddr> { | ||
const MAX_RETRIES: usize = 5; | ||
const BACKOFF_FACTOR: u32 = 2; | ||
|
||
let mut network_wait = self.network_wait_period; | ||
|
||
for i in 0..MAX_RETRIES { | ||
if i != 0 { | ||
// Ensure we dont spam dns | ||
tokio::time::sleep(network_wait).await; | ||
network_wait *= BACKOFF_FACTOR; | ||
} | ||
match tokio::net::lookup_host(address).await { | ||
Ok(mut addresses) => match addresses.next() { | ||
Some(address) => return Some(address), | ||
None => { | ||
warn!("received unknown domain name from NTS-ke"); | ||
return None; | ||
} | ||
}, | ||
Err(e) => { | ||
warn!(error = ?e, "error while resolving peer address, retrying"); | ||
} | ||
} | ||
} | ||
|
||
warn!("Could not resolve peer address, restarting NTS initialization"); | ||
|
||
None | ||
} | ||
|
||
pub async fn fill_pool( | ||
&mut self, | ||
action_tx: &mpsc::Sender<SpawnEvent>, | ||
) -> Result<(), NtsPoolSpawnError> { | ||
let mut wait_period = self.network_wait_period; | ||
|
||
// early return if there is nothing to do | ||
if self.current_peers.len() >= self.config.max_peers { | ||
return Ok(()); | ||
} | ||
|
||
loop { | ||
// Try and add peers to our pool | ||
while self.current_peers.len() < self.config.max_peers { | ||
match key_exchange_client( | ||
self.config.addr.server_name.clone(), | ||
self.config.addr.port, | ||
&self.config.certificate_authorities, | ||
) | ||
.await | ||
{ | ||
Ok(ke) => { | ||
if let Some(address) = | ||
self.resolve_addr((ke.remote.as_str(), ke.port)).await | ||
{ | ||
let id = PeerId::new(); | ||
self.current_peers.push(PoolPeer { id, address }); | ||
action_tx | ||
.send(SpawnEvent::new( | ||
self.id, | ||
SpawnAction::create( | ||
PeerId::new(), | ||
address, | ||
self.config.addr.deref().clone(), | ||
ke.protocol_version, | ||
Some(ke.nts), | ||
), | ||
)) | ||
.await?; | ||
return Ok(()); | ||
} | ||
} | ||
Err(e) => { | ||
warn!(error = ?e, "error while attempting key exchange"); | ||
break; | ||
} | ||
}; | ||
} | ||
|
||
let wait_period_max = if cfg!(test) { | ||
std::time::Duration::default() | ||
} else { | ||
std::time::Duration::from_secs(60) | ||
}; | ||
|
||
wait_period = Ord::min(2 * wait_period, wait_period_max); | ||
let peers_needed = self.config.max_peers - self.current_peers.len(); | ||
if peers_needed > 0 { | ||
warn!(peers_needed, "could not fully fill pool"); | ||
tokio::time::sleep(wait_period).await; | ||
} else { | ||
return Ok(()); | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[async_trait::async_trait] | ||
impl BasicSpawner for NtsPoolSpawner { | ||
type Error = NtsPoolSpawnError; | ||
|
||
async fn handle_init( | ||
&mut self, | ||
action_tx: &mpsc::Sender<SpawnEvent>, | ||
) -> Result<(), NtsPoolSpawnError> { | ||
self.fill_pool(action_tx).await?; | ||
Ok(()) | ||
} | ||
|
||
async fn handle_peer_removed( | ||
&mut self, | ||
removed_peer: PeerRemovedEvent, | ||
action_tx: &mpsc::Sender<SpawnEvent>, | ||
) -> Result<(), NtsPoolSpawnError> { | ||
self.current_peers.retain(|p| p.id != removed_peer.id); | ||
self.fill_pool(action_tx).await?; | ||
Ok(()) | ||
} | ||
|
||
fn get_id(&self) -> SpawnerId { | ||
self.id | ||
} | ||
|
||
fn get_addr_description(&self) -> String { | ||
format!("{} ({})", self.config.addr.deref(), self.config.max_peers) | ||
} | ||
|
||
fn get_description(&self) -> &str { | ||
"nts-pool" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters