diff --git a/ntp-proto/src/nts_pool_ke.rs b/ntp-proto/src/nts_pool_ke.rs index f4526ac93..67eed9798 100644 --- a/ntp-proto/src/nts_pool_ke.rs +++ b/ntp-proto/src/nts_pool_ke.rs @@ -199,8 +199,8 @@ impl ClientToPoolDecoder { } #[cfg(feature = "nts-pool")] - NtpServerDeny { _denied } => { - state.denied_servers.push(_denied); + NtpServerDeny { denied } => { + state.denied_servers.push(denied); Continue(state) } diff --git a/ntp-proto/src/nts_record.rs b/ntp-proto/src/nts_record.rs index 930cf3e2d..0b0742f27 100644 --- a/ntp-proto/src/nts_record.rs +++ b/ntp-proto/src/nts_record.rs @@ -168,7 +168,7 @@ pub enum NtsRecord { }, #[cfg(feature = "nts-pool")] NtpServerDeny { - _denied: String, + denied: String, }, } @@ -202,9 +202,11 @@ impl NtsRecord { pub const BAD_REQUEST: u16 = 1; pub const INTERNAL_SERVER_ERROR: u16 = 2; - pub fn client_key_exchange_records() -> [NtsRecord; if cfg!(feature = "ntpv5") { 4 } else { 3 }] - { - [ + #[cfg_attr(not(feature = "nts-pool"), allow(unused_variables))] + pub fn client_key_exchange_records( + denied_servers: impl IntoIterator, + ) -> Box<[NtsRecord]> { + let mut base = vec![ #[cfg(feature = "ntpv5")] NtsRecord::DraftId { data: crate::packet::v5::DRAFT_VERSION.as_bytes().into(), @@ -223,8 +225,18 @@ impl NtsRecord { .map(|algorithm| *algorithm as u16) .collect(), }, - NtsRecord::EndOfMessage, - ] + ]; + + #[cfg(feature = "nts-pool")] + base.extend( + denied_servers + .into_iter() + .map(|server| NtsRecord::NtpServerDeny { denied: server }), + ); + + base.push(NtsRecord::EndOfMessage); + + base.into_boxed_slice() } #[cfg(feature = "nts-pool")] @@ -385,7 +397,7 @@ impl NtsRecord { // NOTE: the string data should be ascii (not utf8) but we don't enforce that here let str_data = read_bytes_exact(reader, record_len)?; match String::from_utf8(str_data) { - Ok(_denied) => NtsRecord::NtpServerDeny { _denied }, + Ok(denied) => NtsRecord::NtpServerDeny { denied }, Err(e) => NtsRecord::Unknown { record_type, critical, @@ -496,7 +508,7 @@ impl NtsRecord { writer.write_all(s2c)?; } #[cfg(feature = "nts-pool")] - NtsRecord::NtpServerDeny { _denied: name } => { + NtsRecord::NtpServerDeny { denied: name } => { // NOTE: the server name should be ascii #[cfg(not(feature = "__internal-fuzz"))] debug_assert!(name.is_ascii()); @@ -1149,6 +1161,7 @@ impl KeyExchangeClient { pub fn new( server_name: String, tls_config: rustls::ClientConfig, + denied_servers: impl IntoIterator, ) -> Result { let mut client = Self::new_without_tls_write(server_name, tls_config)?; @@ -1156,7 +1169,7 @@ impl KeyExchangeClient { // We use an intermediary buffer to ensure that all records are sent at once. // This should not be needed, but works around issues in some NTS-ke server implementations let mut buffer = Vec::with_capacity(1024); - for record in NtsRecord::client_key_exchange_records() { + for record in NtsRecord::client_key_exchange_records(denied_servers).iter() { record.write(&mut buffer)?; } client.tls_connection.writer().write_all(&buffer)?; @@ -1403,7 +1416,7 @@ impl KeyExchangeServerDecoder { Continue(state) } #[cfg(feature = "nts-pool")] - NtpServerDeny { _denied } => { + NtpServerDeny { denied: _ } => { // we are not a NTS pool server, so we ignore this record Continue(state) } @@ -1753,7 +1766,7 @@ mod test { #[test] fn test_client_key_exchange_records() { let mut buffer = Vec::with_capacity(1024); - for record in NtsRecord::client_key_exchange_records() { + for record in NtsRecord::client_key_exchange_records([]).iter() { record.write(&mut buffer).unwrap(); } @@ -1777,7 +1790,7 @@ mod test { decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), ], - NtsRecord::client_key_exchange_records() + NtsRecord::client_key_exchange_records(vec![]).as_ref() ); assert!(decoder.step().unwrap().is_none()); @@ -1901,7 +1914,7 @@ mod test { let mut buffer = Vec::new(); let record = NtsRecord::NtpServerDeny { - _denied: String::from("a string"), + denied: String::from("a string"), }; record.write(&mut buffer).unwrap(); @@ -2506,14 +2519,15 @@ mod test { #[test] fn server_decoder_finds_algorithm() { - let result = server_decode_records(&NtsRecord::client_key_exchange_records()).unwrap(); + let result = + server_decode_records(&NtsRecord::client_key_exchange_records(vec![])).unwrap(); assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512); } #[test] fn server_decoder_ignores_new_cookie() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(vec![]).to_vec(); records.insert( 0, NtsRecord::NewCookie { @@ -2527,7 +2541,7 @@ mod test { #[test] fn server_decoder_ignores_server_and_port_preference() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(vec![]).to_vec(); records.insert( 0, NtsRecord::Server { @@ -2550,7 +2564,7 @@ mod test { #[test] fn server_decoder_ignores_warn() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(vec![]).to_vec(); records.insert(0, NtsRecord::Warning { warningcode: 42 }); let result = server_decode_records(&records).unwrap(); @@ -2559,7 +2573,7 @@ mod test { #[test] fn server_decoder_ignores_unknown_not_critical() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(vec![]).to_vec(); records.insert( 0, NtsRecord::Unknown { @@ -2575,7 +2589,7 @@ mod test { #[test] fn server_decoder_reports_unknown_critical() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(vec![]).to_vec(); records.insert( 0, NtsRecord::Unknown { @@ -2594,7 +2608,7 @@ mod test { #[test] fn server_decoder_reports_error() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(vec![]).to_vec(); records.insert(0, NtsRecord::Error { errorcode: 2 }); let error = server_decode_records(&records).unwrap_err(); @@ -2760,7 +2774,7 @@ mod test { .with_no_client_auth(); let mut server = rustls::ServerConnection::new(Arc::new(serverconfig)).unwrap(); - let mut client = KeyExchangeClient::new("localhost".into(), clientconfig).unwrap(); + let mut client = KeyExchangeClient::new("localhost".into(), clientconfig, vec![]).unwrap(); server.writer().write_all(NTS_TIME_NL_RESPONSE).unwrap(); @@ -2942,7 +2956,7 @@ mod test { let (mut client, server) = client_server_pair(ClientType::Uncertified); let mut buffer = Vec::with_capacity(1024); - for record in NtsRecord::client_key_exchange_records() { + for record in NtsRecord::client_key_exchange_records([]).iter() { record.write(&mut buffer).unwrap(); } client.tls_connection.writer().write_all(&buffer).unwrap(); @@ -3040,7 +3054,7 @@ mod test { #[test] fn test_keyexchange_invalid_input() { let mut buffer = Vec::with_capacity(1024); - for record in NtsRecord::client_key_exchange_records() { + for record in NtsRecord::client_key_exchange_records([]).iter() { record.write(&mut buffer).unwrap(); } diff --git a/ntpd/src/daemon/config/mod.rs b/ntpd/src/daemon/config/mod.rs index b174a2f97..0cfd866be 100644 --- a/ntpd/src/daemon/config/mod.rs +++ b/ntpd/src/daemon/config/mod.rs @@ -372,6 +372,8 @@ impl Config { PeerConfig::Standard(_) => count += 1, PeerConfig::Nts(_) => count += 1, PeerConfig::Pool(config) => count += config.max_peers, + #[cfg(feature = "unstable_nts-pool")] + PeerConfig::NtsPool(config) => count += config.max_peers, } } count diff --git a/ntpd/src/daemon/config/peer.rs b/ntpd/src/daemon/config/peer.rs index f2cfc2a88..1a6564ad8 100644 --- a/ntpd/src/daemon/config/peer.rs +++ b/ntpd/src/daemon/config/peer.rs @@ -63,6 +63,22 @@ fn max_peers_default() -> usize { 4 } +#[cfg(feature = "unstable_nts-pool")] +#[derive(Deserialize, Debug, PartialEq, Eq, Clone)] +#[serde(deny_unknown_fields)] +pub struct NtsPoolPeerConfig { + #[serde(rename = "address")] + pub addr: NtsKeAddress, + #[serde( + deserialize_with = "deserialize_certificate_authorities", + default = "default_certificate_authorities", + rename = "certificate-authority" + )] + pub certificate_authorities: Arc<[Certificate]>, + #[serde(rename = "count", default = "max_peers_default")] + pub max_peers: usize, +} + #[derive(Debug, Deserialize, PartialEq, Eq, Clone)] #[serde(tag = "mode")] pub enum PeerConfig { @@ -73,6 +89,9 @@ pub enum PeerConfig { #[serde(rename = "pool")] Pool(PoolPeerConfig), // Consul(ConsulPeerConfig), + #[cfg(feature = "unstable_nts-pool")] + #[serde(rename = "nts-pool")] + NtsPool(NtsPoolPeerConfig), } /// A normalized address has a host and a port part. However, the host may be @@ -312,6 +331,8 @@ mod tests { PeerConfig::Standard(c) => c.address.to_string(), PeerConfig::Nts(c) => c.address.to_string(), PeerConfig::Pool(c) => c.addr.to_string(), + #[cfg(feature = "unstable_nts-pool")] + PeerConfig::NtsPool(c) => c.addr.to_string(), } } @@ -396,6 +417,22 @@ mod tests { if let PeerConfig::Nts(config) = test.peer { assert_eq!(config.address.to_string(), "example.com:4460"); } + + #[cfg(feature = "unstable_nts-pool")] + { + let test: TestConfig = toml::from_str( + r#" + [peer] + address = "example.com" + mode = "nts-pool" + "#, + ) + .unwrap(); + assert!(matches!(test.peer, PeerConfig::NtsPool(_))); + if let PeerConfig::Nts(config) = test.peer { + assert_eq!(config.address.to_string(), "example.com:4460"); + } + } } #[test] diff --git a/ntpd/src/daemon/keyexchange.rs b/ntpd/src/daemon/keyexchange.rs index 3029defaa..c63c7a599 100644 --- a/ntpd/src/daemon/keyexchange.rs +++ b/ntpd/src/daemon/keyexchange.rs @@ -48,7 +48,20 @@ pub(crate) async fn key_exchange_client( let socket = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?; let config = build_client_config(extra_certificates)?; - BoundKeyExchangeClient::new(socket, server_name, config)?.await + BoundKeyExchangeClient::new(socket, server_name, config, Vec::new())?.await +} + +#[cfg(feature = "unstable_nts-pool")] +pub(crate) async fn key_exchange_client_with_denied_servers( + server_name: String, + port: u16, + extra_certificates: &[Certificate], + denied_servers: impl IntoIterator, +) -> Result { + let socket = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?; + let config = build_client_config(extra_certificates)?; + + BoundKeyExchangeClient::new(socket, server_name, config, denied_servers)?.await } pub fn spawn( @@ -205,11 +218,12 @@ where io: IO, server_name: String, config: rustls::ClientConfig, + denied_servers: impl IntoIterator, ) -> Result { Ok(Self { inner: Some(BoundKeyExchangeClientData { io, - client: KeyExchangeClient::new(server_name, config)?, + client: KeyExchangeClient::new(server_name, config, denied_servers)?, need_flush: false, }), }) @@ -668,7 +682,7 @@ mod tests { fn client_key_exchange_message_length() -> usize { let mut buffer = Vec::with_capacity(1024); - for record in ntp_proto::NtsRecord::client_key_exchange_records() { + for record in ntp_proto::NtsRecord::client_key_exchange_records(vec![]).iter() { record.write(&mut buffer).unwrap(); } @@ -868,7 +882,7 @@ mod tests { #[tokio::test] async fn server_expected_client_records() { - let records = NtsRecord::client_key_exchange_records().to_vec(); + let records = NtsRecord::client_key_exchange_records(vec![]).to_vec(); let result = send_records_to_server(records).await; assert!(result.is_ok()); diff --git a/ntpd/src/daemon/spawn/mod.rs b/ntpd/src/daemon/spawn/mod.rs index afb398c5e..46238e4d1 100644 --- a/ntpd/src/daemon/spawn/mod.rs +++ b/ntpd/src/daemon/spawn/mod.rs @@ -9,6 +9,8 @@ use super::config::NormalizedAddress; #[cfg(test)] pub mod dummy; pub mod nts; +#[cfg(feature = "unstable_nts-pool")] +pub mod nts_pool; pub mod pool; pub mod standard; diff --git a/ntpd/src/daemon/spawn/nts_pool.rs b/ntpd/src/daemon/spawn/nts_pool.rs new file mode 100644 index 000000000..89c6b1284 --- /dev/null +++ b/ntpd/src/daemon/spawn/nts_pool.rs @@ -0,0 +1,188 @@ +use std::{net::SocketAddr, ops::Deref}; + +use thiserror::Error; +use tokio::sync::mpsc; +use tracing::warn; + +use super::super::{ + config::NtsPoolPeerConfig, keyexchange::key_exchange_client_with_denied_servers, +}; + +use super::{BasicSpawner, PeerId, PeerRemovedEvent, SpawnAction, SpawnEvent, SpawnerId}; + +struct PoolPeer { + id: PeerId, + remote: String, +} + +pub struct NtsPoolSpawner { + config: NtsPoolPeerConfig, + network_wait_period: std::time::Duration, + id: SpawnerId, + current_peers: Vec, +} + +#[derive(Error, Debug)] +pub enum NtsPoolSpawnError { + #[error("Channel send error: {0}")] + SendError(#[from] mpsc::error::SendError), +} + +impl NtsPoolSpawner { + pub fn new( + config: NtsPoolPeerConfig, + network_wait_period: std::time::Duration, + ) -> NtsPoolSpawner { + NtsPoolSpawner { + config, + network_wait_period, + id: Default::default(), + current_peers: Default::default(), + //known_ips: Default::default(), + } + } + + //NOTE: this is the same code as in nts.rs, so we should introduce some code sharing + async fn resolve_addr(&mut self, address: (&str, u16)) -> Option { + const MAX_RETRIES: usize = 5; + const BACKOFF_FACTOR: u32 = 2; + + let mut network_wait = self.network_wait_period; + + for i in 0..MAX_RETRIES { + if i != 0 { + // Ensure we dont spam dns + tokio::time::sleep(network_wait).await; + network_wait *= BACKOFF_FACTOR; + } + match tokio::net::lookup_host(address).await { + Ok(mut addresses) => match addresses.next() { + Some(address) => return Some(address), + None => { + warn!("received unknown domain name from NTS-ke"); + return None; + } + }, + Err(e) => { + warn!(error = ?e, "error while resolving peer address, retrying"); + } + } + } + + warn!("Could not resolve peer address, restarting NTS initialization"); + + None + } + + fn contains_peer(&self, domain: &str) -> bool { + self.current_peers.iter().any(|peer| peer.remote == domain) + } + + pub async fn fill_pool( + &mut self, + action_tx: &mpsc::Sender, + ) -> Result<(), NtsPoolSpawnError> { + let mut wait_period = self.network_wait_period; + + // early return if there is nothing to do + if self.current_peers.len() >= self.config.max_peers { + return Ok(()); + } + + loop { + // Try and add peers to our pool + while self.current_peers.len() < self.config.max_peers { + match key_exchange_client_with_denied_servers( + self.config.addr.server_name.clone(), + self.config.addr.port, + &self.config.certificate_authorities, + self.current_peers.iter().map(|peer| peer.remote.clone()), + ) + .await + { + Ok(ke) if !self.contains_peer(&ke.remote) => { + if let Some(address) = + self.resolve_addr((ke.remote.as_str(), ke.port)).await + { + let id = PeerId::new(); + self.current_peers.push(PoolPeer { + id, + remote: ke.remote, + }); + action_tx + .send(SpawnEvent::new( + self.id, + SpawnAction::create( + PeerId::new(), + address, + self.config.addr.deref().clone(), + ke.protocol_version, + Some(ke.nts), + ), + )) + .await?; + } + } + Ok(_) => { + warn!("received an address from pool-ke that we already had, ignoring"); + break; + } + Err(e) => { + warn!(error = ?e, "error while attempting key exchange"); + break; + } + }; + } + + let wait_period_max = if cfg!(test) { + std::time::Duration::default() + } else { + std::time::Duration::from_secs(60) + }; + + wait_period = Ord::min(2 * wait_period, wait_period_max); + let peers_needed = self.config.max_peers - self.current_peers.len(); + if peers_needed > 0 { + warn!(peers_needed, "could not fully fill pool"); + tokio::time::sleep(wait_period).await; + } else { + return Ok(()); + } + } + } +} + +#[async_trait::async_trait] +impl BasicSpawner for NtsPoolSpawner { + type Error = NtsPoolSpawnError; + + async fn handle_init( + &mut self, + action_tx: &mpsc::Sender, + ) -> Result<(), NtsPoolSpawnError> { + self.fill_pool(action_tx).await?; + Ok(()) + } + + async fn handle_peer_removed( + &mut self, + removed_peer: PeerRemovedEvent, + action_tx: &mpsc::Sender, + ) -> Result<(), NtsPoolSpawnError> { + self.current_peers.retain(|p| p.id != removed_peer.id); + self.fill_pool(action_tx).await?; + Ok(()) + } + + fn get_id(&self) -> SpawnerId { + self.id + } + + fn get_addr_description(&self) -> String { + format!("{} ({})", self.config.addr.deref(), self.config.max_peers) + } + + fn get_description(&self) -> &str { + "nts-pool" + } +} diff --git a/ntpd/src/daemon/system.rs b/ntpd/src/daemon/system.rs index dc437cd36..255b3bece 100644 --- a/ntpd/src/daemon/system.rs +++ b/ntpd/src/daemon/system.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "unstable_nts-pool")] +use super::spawn::nts_pool::NtsPoolSpawner; use super::{ config::{ClockConfig, NormalizedAddress, PeerConfig, ServerConfig}, peer::{MsgForSystem, PeerChannels}, @@ -127,6 +129,15 @@ pub async fn spawn( std::io::Error::new(std::io::ErrorKind::Other, e) })?; } + #[cfg(feature = "unstable_nts-pool")] + PeerConfig::NtsPool(cfg) => { + system + .add_spawner(NtsPoolSpawner::new(cfg.clone(), NETWORK_WAIT_PERIOD)) + .map_err(|e| { + tracing::error!("Could not spawn peer: {}", e); + std::io::Error::new(std::io::ErrorKind::Other, e) + })?; + } } }