diff --git a/ntp-proto/src/nts_pool_ke.rs b/ntp-proto/src/nts_pool_ke.rs index f4526ac93..67eed9798 100644 --- a/ntp-proto/src/nts_pool_ke.rs +++ b/ntp-proto/src/nts_pool_ke.rs @@ -199,8 +199,8 @@ impl ClientToPoolDecoder { } #[cfg(feature = "nts-pool")] - NtpServerDeny { _denied } => { - state.denied_servers.push(_denied); + NtpServerDeny { denied } => { + state.denied_servers.push(denied); Continue(state) } diff --git a/ntp-proto/src/nts_record.rs b/ntp-proto/src/nts_record.rs index 930cf3e2d..f918d1c7a 100644 --- a/ntp-proto/src/nts_record.rs +++ b/ntp-proto/src/nts_record.rs @@ -168,7 +168,7 @@ pub enum NtsRecord { }, #[cfg(feature = "nts-pool")] NtpServerDeny { - _denied: String, + denied: String, }, } @@ -202,9 +202,9 @@ impl NtsRecord { pub const BAD_REQUEST: u16 = 1; pub const INTERNAL_SERVER_ERROR: u16 = 2; - pub fn client_key_exchange_records() -> [NtsRecord; if cfg!(feature = "ntpv5") { 4 } else { 3 }] - { - [ + #[cfg_attr(not(feature = "nts-pool"), allow(unused_variables))] + pub fn client_key_exchange_records(denied_servers: &[&str]) -> Box<[NtsRecord]> { + let mut base = vec![ #[cfg(feature = "ntpv5")] NtsRecord::DraftId { data: crate::packet::v5::DRAFT_VERSION.as_bytes().into(), @@ -223,8 +223,20 @@ impl NtsRecord { .map(|algorithm| *algorithm as u16) .collect(), }, - NtsRecord::EndOfMessage, - ] + ]; + + #[cfg(feature = "nts-pool")] + base.extend( + denied_servers + .iter() + .map(|server| NtsRecord::NtpServerDeny { + denied: server.to_string(), + }), + ); + + base.push(NtsRecord::EndOfMessage); + + base.into_boxed_slice() } #[cfg(feature = "nts-pool")] @@ -385,7 +397,7 @@ impl NtsRecord { // NOTE: the string data should be ascii (not utf8) but we don't enforce that here let str_data = read_bytes_exact(reader, record_len)?; match String::from_utf8(str_data) { - Ok(_denied) => NtsRecord::NtpServerDeny { _denied }, + Ok(denied) => NtsRecord::NtpServerDeny { denied }, Err(e) => NtsRecord::Unknown { record_type, critical, @@ -496,7 +508,7 @@ impl NtsRecord { writer.write_all(s2c)?; } #[cfg(feature = "nts-pool")] - NtsRecord::NtpServerDeny { _denied: name } => { + NtsRecord::NtpServerDeny { denied: name } => { // NOTE: the server name should be ascii #[cfg(not(feature = "__internal-fuzz"))] debug_assert!(name.is_ascii()); @@ -1149,6 +1161,7 @@ impl KeyExchangeClient { pub fn new( server_name: String, tls_config: rustls::ClientConfig, + denied_servers: &[&str], ) -> Result { let mut client = Self::new_without_tls_write(server_name, tls_config)?; @@ -1156,7 +1169,7 @@ impl KeyExchangeClient { // We use an intermediary buffer to ensure that all records are sent at once. // This should not be needed, but works around issues in some NTS-ke server implementations let mut buffer = Vec::with_capacity(1024); - for record in NtsRecord::client_key_exchange_records() { + for record in NtsRecord::client_key_exchange_records(denied_servers).iter() { record.write(&mut buffer)?; } client.tls_connection.writer().write_all(&buffer)?; @@ -1403,7 +1416,7 @@ impl KeyExchangeServerDecoder { Continue(state) } #[cfg(feature = "nts-pool")] - NtpServerDeny { _denied } => { + NtpServerDeny { denied: _ } => { // we are not a NTS pool server, so we ignore this record Continue(state) } @@ -1753,7 +1766,7 @@ mod test { #[test] fn test_client_key_exchange_records() { let mut buffer = Vec::with_capacity(1024); - for record in NtsRecord::client_key_exchange_records() { + for record in NtsRecord::client_key_exchange_records(&[]).iter() { record.write(&mut buffer).unwrap(); } @@ -1777,7 +1790,7 @@ mod test { decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), ], - NtsRecord::client_key_exchange_records() + NtsRecord::client_key_exchange_records(&[]).as_ref() ); assert!(decoder.step().unwrap().is_none()); @@ -1901,7 +1914,7 @@ mod test { let mut buffer = Vec::new(); let record = NtsRecord::NtpServerDeny { - _denied: String::from("a string"), + denied: String::from("a string"), }; record.write(&mut buffer).unwrap(); @@ -2506,14 +2519,14 @@ mod test { #[test] fn server_decoder_finds_algorithm() { - let result = server_decode_records(&NtsRecord::client_key_exchange_records()).unwrap(); + let result = server_decode_records(&NtsRecord::client_key_exchange_records(&[])).unwrap(); assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512); } #[test] fn server_decoder_ignores_new_cookie() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec(); records.insert( 0, NtsRecord::NewCookie { @@ -2527,7 +2540,7 @@ mod test { #[test] fn server_decoder_ignores_server_and_port_preference() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec(); records.insert( 0, NtsRecord::Server { @@ -2550,7 +2563,7 @@ mod test { #[test] fn server_decoder_ignores_warn() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec(); records.insert(0, NtsRecord::Warning { warningcode: 42 }); let result = server_decode_records(&records).unwrap(); @@ -2559,7 +2572,7 @@ mod test { #[test] fn server_decoder_ignores_unknown_not_critical() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec(); records.insert( 0, NtsRecord::Unknown { @@ -2575,7 +2588,7 @@ mod test { #[test] fn server_decoder_reports_unknown_critical() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec(); records.insert( 0, NtsRecord::Unknown { @@ -2594,7 +2607,7 @@ mod test { #[test] fn server_decoder_reports_error() { - let mut records = NtsRecord::client_key_exchange_records().to_vec(); + let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec(); records.insert(0, NtsRecord::Error { errorcode: 2 }); let error = server_decode_records(&records).unwrap_err(); @@ -2760,7 +2773,7 @@ mod test { .with_no_client_auth(); let mut server = rustls::ServerConnection::new(Arc::new(serverconfig)).unwrap(); - let mut client = KeyExchangeClient::new("localhost".into(), clientconfig).unwrap(); + let mut client = KeyExchangeClient::new("localhost".into(), clientconfig, &[]).unwrap(); server.writer().write_all(NTS_TIME_NL_RESPONSE).unwrap(); @@ -2942,7 +2955,7 @@ mod test { let (mut client, server) = client_server_pair(ClientType::Uncertified); let mut buffer = Vec::with_capacity(1024); - for record in NtsRecord::client_key_exchange_records() { + for record in NtsRecord::client_key_exchange_records(&[]).iter() { record.write(&mut buffer).unwrap(); } client.tls_connection.writer().write_all(&buffer).unwrap(); @@ -3040,7 +3053,7 @@ mod test { #[test] fn test_keyexchange_invalid_input() { let mut buffer = Vec::with_capacity(1024); - for record in NtsRecord::client_key_exchange_records() { + for record in NtsRecord::client_key_exchange_records(&[]).iter() { record.write(&mut buffer).unwrap(); } diff --git a/ntpd/src/daemon/keyexchange.rs b/ntpd/src/daemon/keyexchange.rs index 3029defaa..d5b2dc04d 100644 --- a/ntpd/src/daemon/keyexchange.rs +++ b/ntpd/src/daemon/keyexchange.rs @@ -48,7 +48,20 @@ pub(crate) async fn key_exchange_client( let socket = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?; let config = build_client_config(extra_certificates)?; - BoundKeyExchangeClient::new(socket, server_name, config)?.await + BoundKeyExchangeClient::new(socket, server_name, config, &[])?.await +} + +#[cfg(feature = "unstable_nts-pool")] +pub(crate) async fn key_exchange_client_with_denied_servers( + server_name: String, + port: u16, + extra_certificates: &[Certificate], + denied_servers: &[&str], +) -> Result { + let socket = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?; + let config = build_client_config(extra_certificates)?; + + BoundKeyExchangeClient::new(socket, server_name, config, denied_servers)?.await } pub fn spawn( @@ -205,11 +218,12 @@ where io: IO, server_name: String, config: rustls::ClientConfig, + denied_servers: &[&str], ) -> Result { Ok(Self { inner: Some(BoundKeyExchangeClientData { io, - client: KeyExchangeClient::new(server_name, config)?, + client: KeyExchangeClient::new(server_name, config, denied_servers)?, need_flush: false, }), }) @@ -668,7 +682,7 @@ mod tests { fn client_key_exchange_message_length() -> usize { let mut buffer = Vec::with_capacity(1024); - for record in ntp_proto::NtsRecord::client_key_exchange_records() { + for record in ntp_proto::NtsRecord::client_key_exchange_records(&[]).iter() { record.write(&mut buffer).unwrap(); } @@ -868,7 +882,7 @@ mod tests { #[tokio::test] async fn server_expected_client_records() { - let records = NtsRecord::client_key_exchange_records().to_vec(); + let records = NtsRecord::client_key_exchange_records(&[]).to_vec(); let result = send_records_to_server(records).await; assert!(result.is_ok()); diff --git a/ntpd/src/daemon/spawn/nts_pool.rs b/ntpd/src/daemon/spawn/nts_pool.rs index aed21723d..d58a9eb8c 100644 --- a/ntpd/src/daemon/spawn/nts_pool.rs +++ b/ntpd/src/daemon/spawn/nts_pool.rs @@ -4,7 +4,9 @@ use thiserror::Error; use tokio::sync::mpsc; use tracing::warn; -use super::super::{config::NtsPoolPeerConfig, keyexchange::key_exchange_client}; +use super::super::{ + config::NtsPoolPeerConfig, keyexchange::key_exchange_client_with_denied_servers, +}; use super::{BasicSpawner, PeerId, PeerRemovedEvent, SpawnAction, SpawnEvent, SpawnerId}; @@ -90,10 +92,15 @@ impl NtsPoolSpawner { loop { // Try and add peers to our pool while self.current_peers.len() < self.config.max_peers { - match key_exchange_client( + match key_exchange_client_with_denied_servers( self.config.addr.server_name.clone(), self.config.addr.port, &self.config.certificate_authorities, + &self + .current_peers + .iter() + .map(|peer| peer.remote.as_str()) + .collect::>(), ) .await {