Skip to content

Commit

Permalink
add nts pool spawner scaffolding
Browse files Browse the repository at this point in the history
  • Loading branch information
squell committed Dec 7, 2023
1 parent 356f8cc commit ef0981f
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ntpd/src/daemon/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ impl Config {
PeerConfig::Standard(_) => count += 1,
PeerConfig::Nts(_) => count += 1,
PeerConfig::Pool(config) => count += config.max_peers,
#[cfg(feature = "unstable_nts-pool")]
PeerConfig::NtsPool(config) => count += config.max_peers,
}
}
count
Expand Down
37 changes: 37 additions & 0 deletions ntpd/src/daemon/config/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ fn max_peers_default() -> usize {
4
}

#[cfg(feature = "unstable_nts-pool")]
#[derive(Deserialize, Debug, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
pub struct NtsPoolPeerConfig {
#[serde(rename = "address")]
pub addr: NtsKeAddress,
#[serde(
deserialize_with = "deserialize_certificate_authorities",
default = "default_certificate_authorities",
rename = "certificate-authority"
)]
pub certificate_authorities: Arc<[Certificate]>,
#[serde(rename = "count", default = "max_peers_default")]
pub max_peers: usize,
}

#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
#[serde(tag = "mode")]
pub enum PeerConfig {
Expand All @@ -73,6 +89,9 @@ pub enum PeerConfig {
#[serde(rename = "pool")]
Pool(PoolPeerConfig),
// Consul(ConsulPeerConfig),
#[cfg(feature = "unstable_nts-pool")]
#[serde(rename = "nts-pool")]
NtsPool(NtsPoolPeerConfig),
}

/// A normalized address has a host and a port part. However, the host may be
Expand Down Expand Up @@ -312,6 +331,8 @@ mod tests {
PeerConfig::Standard(c) => c.address.to_string(),
PeerConfig::Nts(c) => c.address.to_string(),
PeerConfig::Pool(c) => c.addr.to_string(),
#[cfg(feature = "unstable_nts-pool")]
PeerConfig::NtsPool(c) => c.addr.to_string(),
}
}

Expand Down Expand Up @@ -396,6 +417,22 @@ mod tests {
if let PeerConfig::Nts(config) = test.peer {
assert_eq!(config.address.to_string(), "example.com:4460");
}

#[cfg(feature = "unstable_nts-pool")]
{
let test: TestConfig = toml::from_str(
r#"
[peer]
address = "example.com"
mode = "nts-pool"
"#,
)
.unwrap();
assert!(matches!(test.peer, PeerConfig::NtsPool(_)));
if let PeerConfig::Nts(config) = test.peer {
assert_eq!(config.address.to_string(), "example.com:4460");
}
}
}

#[test]
Expand Down
2 changes: 2 additions & 0 deletions ntpd/src/daemon/spawn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use super::config::NormalizedAddress;
#[cfg(test)]
pub mod dummy;
pub mod nts;
#[cfg(feature = "unstable_nts-pool")]
pub mod nts_pool;
pub mod pool;
pub mod standard;

Expand Down
175 changes: 175 additions & 0 deletions ntpd/src/daemon/spawn/nts_pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use std::{net::SocketAddr, ops::Deref};

use thiserror::Error;
use tokio::sync::mpsc;
use tracing::warn;

use super::super::{config::NtsPoolPeerConfig, keyexchange::key_exchange_client};

use super::{BasicSpawner, PeerId, PeerRemovedEvent, SpawnAction, SpawnEvent, SpawnerId};

struct PoolPeer {
id: PeerId,
address: SocketAddr,
}

pub struct NtsPoolSpawner {
config: NtsPoolPeerConfig,
network_wait_period: std::time::Duration,
id: SpawnerId,
current_peers: Vec<PoolPeer>,
}

#[derive(Error, Debug)]
pub enum NtsPoolSpawnError {
#[error("Channel send error: {0}")]
SendError(#[from] mpsc::error::SendError<SpawnEvent>),
}

impl NtsPoolSpawner {
pub fn new(
config: NtsPoolPeerConfig,
network_wait_period: std::time::Duration,
) -> NtsPoolSpawner {
NtsPoolSpawner {
config,
network_wait_period,
id: Default::default(),
current_peers: Default::default(),
//known_ips: Default::default(),
}
}

//NOTE: this is the same code as in nts.rs, so we should introduce some code sharing
async fn resolve_addr(&mut self, address: (&str, u16)) -> Option<SocketAddr> {
const MAX_RETRIES: usize = 5;
const BACKOFF_FACTOR: u32 = 2;

let mut network_wait = self.network_wait_period;

for i in 0..MAX_RETRIES {
if i != 0 {
// Ensure we dont spam dns
tokio::time::sleep(network_wait).await;
network_wait *= BACKOFF_FACTOR;
}
match tokio::net::lookup_host(address).await {
Ok(mut addresses) => match addresses.next() {
Some(address) => return Some(address),
None => {
warn!("received unknown domain name from NTS-ke");
return None;
}
},
Err(e) => {
warn!(error = ?e, "error while resolving peer address, retrying");
}
}
}

warn!("Could not resolve peer address, restarting NTS initialization");

None
}

pub async fn fill_pool(
&mut self,
action_tx: &mpsc::Sender<SpawnEvent>,
) -> Result<(), NtsPoolSpawnError> {
let mut wait_period = self.network_wait_period;

// early return if there is nothing to do
if self.current_peers.len() >= self.config.max_peers {
return Ok(());
}

loop {
// Try and add peers to our pool
while self.current_peers.len() < self.config.max_peers {
match key_exchange_client(
self.config.addr.server_name.clone(),
self.config.addr.port,
&self.config.certificate_authorities,
)
.await
{
Ok(ke) => {
if let Some(address) =
self.resolve_addr((ke.remote.as_str(), ke.port)).await
{
let id = PeerId::new();
self.current_peers.push(PoolPeer { id, address });
action_tx
.send(SpawnEvent::new(
self.id,
SpawnAction::create(
PeerId::new(),
address,
self.config.addr.deref().clone(),
ke.protocol_version,
Some(ke.nts),
),
))
.await?;
return Ok(());
}
}
Err(e) => {
warn!(error = ?e, "error while attempting key exchange");
break;
}
};
}

let wait_period_max = if cfg!(test) {
std::time::Duration::default()
} else {
std::time::Duration::from_secs(60)
};

wait_period = Ord::min(2 * wait_period, wait_period_max);
let peers_needed = self.config.max_peers - self.current_peers.len();
if peers_needed > 0 {
warn!(peers_needed, "could not fully fill pool");
tokio::time::sleep(wait_period).await;
} else {
return Ok(());
}
}
}
}

#[async_trait::async_trait]
impl BasicSpawner for NtsPoolSpawner {
type Error = NtsPoolSpawnError;

async fn handle_init(
&mut self,
action_tx: &mpsc::Sender<SpawnEvent>,
) -> Result<(), NtsPoolSpawnError> {
self.fill_pool(action_tx).await?;
Ok(())
}

async fn handle_peer_removed(
&mut self,
removed_peer: PeerRemovedEvent,
action_tx: &mpsc::Sender<SpawnEvent>,
) -> Result<(), NtsPoolSpawnError> {
self.current_peers.retain(|p| p.id != removed_peer.id);
self.fill_pool(action_tx).await?;
Ok(())
}

fn get_id(&self) -> SpawnerId {
self.id
}

fn get_addr_description(&self) -> String {
format!("{} ({})", self.config.addr.deref(), self.config.max_peers)
}

fn get_description(&self) -> &str {
"nts-pool"
}
}
11 changes: 11 additions & 0 deletions ntpd/src/daemon/system.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "unstable_nts-pool")]
use super::spawn::nts_pool::NtsPoolSpawner;
use super::{
config::{ClockConfig, NormalizedAddress, PeerConfig, ServerConfig},
peer::{MsgForSystem, PeerChannels},
Expand Down Expand Up @@ -127,6 +129,15 @@ pub async fn spawn(
std::io::Error::new(std::io::ErrorKind::Other, e)
})?;
}
#[cfg(feature = "unstable_nts-pool")]
PeerConfig::NtsPool(cfg) => {
system
.add_spawner(NtsPoolSpawner::new(cfg.clone(), NETWORK_WAIT_PERIOD))
.map_err(|e| {
tracing::error!("Could not spawn peer: {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})?;
}
}
}

Expand Down

0 comments on commit ef0981f

Please sign in to comment.