diff --git a/ntp-proto/src/algorithm/kalman/source.rs b/ntp-proto/src/algorithm/kalman/source.rs index b89537bbb..d5f137e62 100644 --- a/ntp-proto/src/algorithm/kalman/source.rs +++ b/ntp-proto/src/algorithm/kalman/source.rs @@ -285,7 +285,7 @@ pub trait MeasurementNoiseEstimator { fn reset(&mut self) -> Self; // for SourceSnapshot - fn get_max_roundtrip(&self, samples: &i32) -> Option; + fn get_max_roundtrip(&self, samples: i32) -> Option; fn get_delay_mean(&self) -> f64; } @@ -293,7 +293,7 @@ impl MeasurementNoiseEstimator for AveragingBuffer { type MeasurementDelay = NtpDuration; fn update(&mut self, delay: Self::MeasurementDelay) { - self.update(delay.to_seconds()) + self.update(delay.to_seconds()); } fn get_noise_estimate(&self) -> f64 { @@ -312,8 +312,9 @@ impl MeasurementNoiseEstimator for AveragingBuffer { AveragingBuffer::default() } - fn get_max_roundtrip(&self, samples: &i32) -> Option { - self.data[..*samples as usize] + #[allow(clippy::cast_sign_loss)] + fn get_max_roundtrip(&self, samples: i32) -> Option { + self.data[..samples as usize] .iter() .copied() .fold(None, |v1, v2| { @@ -351,7 +352,7 @@ impl MeasurementNoiseEstimator for f64 { *self } - fn get_max_roundtrip(&self, _samples: &i32) -> Option { + fn get_max_roundtrip(&self, _samples: i32) -> Option { Some(1.) } @@ -712,12 +713,12 @@ impl } } - #[allow(clippy::cast_sign_loss)] fn snapshot( &self, index: Index, config: &AlgorithmConfig, ) -> Option> { + #[allow(clippy::cast_sign_loss)] match &self.0 { SourceStateInner::Initial(InitialSourceFilter { noise_estimator, @@ -725,7 +726,7 @@ impl last_measurement: Some(last_measurement), samples, }) if *samples > 0 => { - let max_roundtrip = noise_estimator.get_max_roundtrip(samples)?; + let max_roundtrip = noise_estimator.get_max_roundtrip(*samples)?; Some(SourceSnapshot { index, source_uncertainty: last_measurement.root_dispersion, @@ -1043,7 +1044,7 @@ mod tests { D: Debug + Clone + Copy, N: MeasurementNoiseEstimator + Clone, >( - noise_estimator: N, + noise_estimator: &N, delay: D, ) { let base = NtpTimestamp::from_fixed_int(0); @@ -1242,7 +1243,7 @@ mod tests { #[test] fn test_offset_steering_and_measurements_normal() { test_offset_steering_and_measurements( - AveragingBuffer { + &AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }, @@ -1252,7 +1253,7 @@ mod tests { #[test] fn test_offset_steering_and_measurements_constant_noise_estimate() { - test_offset_steering_and_measurements(1e-9, ()); + test_offset_steering_and_measurements(&1e-9, ()); } #[test] diff --git a/ntp-proto/src/ipfilter.rs b/ntp-proto/src/ipfilter.rs index 673b435d1..673703cdd 100644 --- a/ntp-proto/src/ipfilter.rs +++ b/ntp-proto/src/ipfilter.rs @@ -185,7 +185,7 @@ impl IpFilter { for subnet in subnets { match subnet.addr { IpAddr::V4(addr) => ipv4list.push(( - (u32::from_be_bytes(addr.octets()) as u128) << 96, + u128::from(u32::from_be_bytes(addr.octets())) << 96, subnet.mask, )), IpAddr::V6(addr) => { @@ -214,7 +214,7 @@ impl IpFilter { /// Panics if `addr` has invalid octets. fn is_in4(&self, addr: Ipv4Addr) -> bool { self.ipv4_filter - .lookup((u32::from_be_bytes(addr.octets()) as u128) << 96) + .lookup(u128::from(u32::from_be_bytes(addr.octets())) << 96) } fn is_in6(&self, addr: Ipv6Addr) -> bool { diff --git a/ntp-proto/src/keyset.rs b/ntp-proto/src/keyset.rs index afdaf30e9..a608e0511 100644 --- a/ntp-proto/src/keyset.rs +++ b/ntp-proto/src/keyset.rs @@ -118,10 +118,11 @@ impl KeySetProvider { let mut buf = [0; 64]; reader.read_exact(&mut buf[0..20])?; - let time = Self::convert_to_system_time(&buf[0..8])?; - let id_offset = Self::convert_to_u32(&buf[8..12])?; - let primary = Self::convert_to_u32(&buf[12..16])?; - let len = Self::convert_to_u32(&buf[16..20])?; + let time = std::time::SystemTime::UNIX_EPOCH + + std::time::Duration::from_secs(u64::from_be_bytes(buf[0..8].try_into().unwrap())); + let id_offset = u32::from_be_bytes(buf[8..12].try_into().unwrap()); + let primary = u32::from_be_bytes(buf[12..16].try_into().unwrap()); + let len = u32::from_be_bytes(buf[16..20].try_into().unwrap()); if primary >= len { return Err(std::io::Error::new( @@ -175,23 +176,6 @@ impl KeySetProvider { pub fn get(&self) -> Arc { self.current.clone() } - - fn convert_to_system_time(bytes: &[u8]) -> std::io::Result { - let time = u64::from_be_bytes(bytes.try_into().map_err(|_| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - "Invalid buffer for SystemTime", - ) - })?); - Ok(std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(time)) - } - - fn convert_to_u32(bytes: &[u8]) -> std::io::Result { - let value = u32::from_be_bytes(bytes.try_into().map_err(|_| { - std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid buffer for u32") - })?); - Ok(value) - } } pub struct KeySet { diff --git a/ntp-proto/src/nts_record.rs b/ntp-proto/src/nts_record.rs index b3aaf883d..000806cfe 100644 --- a/ntp-proto/src/nts_record.rs +++ b/ntp-proto/src/nts_record.rs @@ -1229,9 +1229,6 @@ impl KeyExchangeClient { self.tls_connection.read_tls(rd) } - /// # Errors - /// - /// Returns error if getting the TLS bool `wants_write` fails. #[must_use] pub fn wants_write(&self) -> bool { self.tls_connection.wants_write() @@ -1843,7 +1840,7 @@ impl KeyExchangeServer { fn decoder_done( mut self, - data: ServerKeyExchangeData, + data: &ServerKeyExchangeData, ) -> ControlFlow, Self> { let algorithm = data.algorithm; let protocol = data.protocol; @@ -3029,7 +3026,7 @@ mod test { Certified, } - fn client_server_pair(client_type: ClientType) -> (KeyExchangeClient, KeyExchangeServer) { + fn client_server_pair(client_type: &ClientType) -> (KeyExchangeClient, KeyExchangeServer) { #[allow(unused)] use tls_utils::CloneKeyShim; diff --git a/ntp-proto/src/packet/extension_fields.rs b/ntp-proto/src/packet/extension_fields.rs index 2636caee0..9e0c738c8 100644 --- a/ntp-proto/src/packet/extension_fields.rs +++ b/ntp-proto/src/packet/extension_fields.rs @@ -130,12 +130,13 @@ impl<'a> ExtensionField<'a> { #[must_use] pub fn into_owned(self) -> ExtensionField<'static> { - #[cfg(feature = "ntpv5")] use ExtensionField::{ - DraftIdentification, Padding, ReferenceIdRequest, ReferenceIdResponse, + InvalidNtsEncryptedField, NtsCookie, NtsCookiePlaceholder, UniqueIdentifier, Unknown, }; + + #[cfg(feature = "ntpv5")] use ExtensionField::{ - InvalidNtsEncryptedField, NtsCookie, NtsCookiePlaceholder, UniqueIdentifier, Unknown, + DraftIdentification, Padding, ReferenceIdRequest, ReferenceIdResponse, }; match self { @@ -171,12 +172,13 @@ impl<'a> ExtensionField<'a> { minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { - #[cfg(feature = "ntpv5")] use ExtensionField::{ - DraftIdentification, Padding, ReferenceIdRequest, ReferenceIdResponse, + InvalidNtsEncryptedField, NtsCookie, NtsCookiePlaceholder, UniqueIdentifier, Unknown, }; + + #[cfg(feature = "ntpv5")] use ExtensionField::{ - InvalidNtsEncryptedField, NtsCookie, NtsCookiePlaceholder, UniqueIdentifier, Unknown, + DraftIdentification, Padding, ReferenceIdRequest, ReferenceIdResponse, }; match self { @@ -204,7 +206,9 @@ impl<'a> ExtensionField<'a> { } } - #[allow(clippy::missing_errors_doc)] + /// # Errors + /// + /// Returns `io::Error` if serialization fails. #[cfg(feature = "__internal-fuzz")] pub fn serialize_pub( &self, @@ -469,7 +473,7 @@ impl<'a> ExtensionField<'a> { /// # Errors /// - /// Returns error if writing to the sink fails. + /// Returns `io::Error` if encoding fails. #[cfg(feature = "ntpv5")] pub fn encode_padding_field( mut w: impl NonBlockingWrite, @@ -552,7 +556,7 @@ impl<'a> ExtensionField<'a> { type EF<'a> = ExtensionField<'a>; type TypeId = ExtensionFieldTypeId; - let message = raw.message_bytes; + let message = &raw.message_bytes; match raw.type_id { TypeId::UniqueIdentifier => EF::decode_unique_identifier(message), @@ -563,7 +567,7 @@ impl<'a> ExtensionField<'a> { EF::decode_draft_identification(message, extension_header_version) } #[cfg(feature = "ntpv5")] - TypeId::ReferenceIdRequest => Ok(ReferenceIdRequest::decode(message).into()), + TypeId::ReferenceIdRequest => Ok(ReferenceIdRequest::decode(message)?.into()), #[cfg(feature = "ntpv5")] TypeId::ReferenceIdResponse => Ok(ReferenceIdResponse::decode(message).into()), type_id => Ok(EF::decode_unknown(type_id.to_type_id(), message)), @@ -669,11 +673,11 @@ impl<'a> ExtensionFieldData<'a> { RawExtensionField::V4_UNENCRYPTED_MINIMUM_SIZE, version, ) { - let (offset, field) = field.map_err(super::error::ParsingError::generalize)?; + let (offset, field) = field.map_err(super::ParsingError::generalize)?; size = offset + field.wire_length(version); if field.type_id == ExtensionFieldTypeId::NtsEncryptedField { let encrypted = RawEncryptedField::from_message_bytes(field.message_bytes) - .map_err(super::error::ParsingError::generalize)?; + .map_err(super::ParsingError::generalize)?; let Some(cipher) = cipher.get(&efdata.untrusted) else { efdata.untrusted.push(InvalidNtsEncryptedField); @@ -711,7 +715,7 @@ impl<'a> ExtensionFieldData<'a> { efdata.authenticated.append(&mut efdata.untrusted); } else { let field = ExtensionField::decode(&field, version) - .map_err(super::error::ParsingError::generalize)?; + .map_err(super::ParsingError::generalize)?; efdata.untrusted.push(field); } } @@ -746,23 +750,23 @@ impl<'a> RawEncryptedField<'a> { fn from_message_bytes( message_bytes: &'a [u8], ) -> Result> { - use ParsingError::IncorrectLength; - let [b0, b1, b2, b3, ref rest @ ..] = message_bytes[..] else { - return Err(IncorrectLength); + return Err(ParsingError::IncorrectLength); }; let nonce_length = u16::from_be_bytes([b0, b1]) as usize; let ciphertext_length = u16::from_be_bytes([b2, b3]) as usize; - let nonce = rest.get(..nonce_length).ok_or(IncorrectLength)?; + let nonce = rest + .get(..nonce_length) + .ok_or(ParsingError::IncorrectLength)?; // skip the lengths and the nonce. pad to a multiple of 4 let ciphertext_start = 4 + next_multiple_of_u16(nonce_length as u16, 4) as usize; let ciphertext = message_bytes .get(ciphertext_start..ciphertext_start + ciphertext_length) - .ok_or(IncorrectLength)?; + .ok_or(ParsingError::IncorrectLength)?; Ok(Self { nonce, ciphertext }) } diff --git a/ntp-proto/src/packet/mod.rs b/ntp-proto/src/packet/mod.rs index 2e9cf1be1..92b8ce350 100644 --- a/ntp-proto/src/packet/mod.rs +++ b/ntp-proto/src/packet/mod.rs @@ -1352,6 +1352,7 @@ fn check_uid_extensionfield<'a, I: IntoIterator>>( #[cfg(any(test, feature = "__internal-fuzz", feature = "__internal-test"))] impl NtpPacket<'_> { + #[must_use] pub fn test() -> Self { Self::default() } diff --git a/ntp-proto/src/packet/v5/extension_fields.rs b/ntp-proto/src/packet/v5/extension_fields.rs index 34ae794d0..fb653b395 100644 --- a/ntp-proto/src/packet/v5/extension_fields.rs +++ b/ntp-proto/src/packet/v5/extension_fields.rs @@ -1,7 +1,9 @@ use crate::io::NonBlockingWrite; +use crate::packet::error::ParsingError; use crate::packet::v5::server_reference_id::BloomFilter; use crate::packet::ExtensionField; use std::borrow::Cow; +use std::convert::Infallible; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Type { @@ -119,7 +121,7 @@ impl ReferenceIdRequest { Ok(()) } - pub fn decode(msg: &[u8]) -> Self { + pub fn decode(msg: &[u8]) -> Result> { let payload_len = u16::try_from(msg.len()).expect("NTP fields can not be longer than u16::MAX"); let offset_bytes: [u8; 2] = msg @@ -128,10 +130,10 @@ impl ReferenceIdRequest { .try_into() .unwrap(); - Self { + Ok(Self { payload_len, offset: u16::from_be_bytes(offset_bytes), - } + }) } pub const fn offset(self) -> u16 { diff --git a/ntp-proto/src/source.rs b/ntp-proto/src/source.rs index cff4c121d..b38c1e0c8 100644 --- a/ntp-proto/src/source.rs +++ b/ntp-proto/src/source.rs @@ -110,7 +110,7 @@ impl> OneWaySource Self { AllowAnyAnonymousOrCertificateBearingClient { supported_algs: provider.signature_verification_algorithms, @@ -93,6 +94,7 @@ mod rustls23_shim { pub use rustls_pemfile2::pkcs8_private_keys; pub use rustls_pemfile2::private_key; + #[must_use] pub fn rootstore_ref_shim(cert: &super::Certificate) -> super::Certificate { cert.clone() } @@ -100,22 +102,26 @@ mod rustls23_shim { pub trait CloneKeyShim {} + #[must_use] pub fn client_config_builder( ) -> rustls23::ConfigBuilder { ClientConfig::builder() } + #[must_use] pub fn client_config_builder_with_protocol_versions( versions: &[&'static rustls23::SupportedProtocolVersion], ) -> rustls23::ConfigBuilder { ClientConfig::builder_with_protocol_versions(versions) } + #[must_use] pub fn server_config_builder( ) -> rustls23::ConfigBuilder { ServerConfig::builder() } + #[must_use] pub fn server_config_builder_with_protocol_versions( versions: &[&'static rustls23::SupportedProtocolVersion], ) -> rustls23::ConfigBuilder { @@ -276,11 +282,11 @@ mod rustls21_shim { ) -> Result, std::io::Error> { for item in std::iter::from_fn(|| rustls_pemfile1::read_one(rd).transpose()) { match item { - Ok(rustls_pemfile1::Item::RSAKey(key)) - | Ok(rustls_pemfile1::Item::PKCS8Key(key)) - | Ok(rustls_pemfile1::Item::ECKey(key)) => { - return Ok(Some(super::PrivateKey(key))) - } + Ok( + rustls_pemfile1::Item::RSAKey(key) + | rustls_pemfile1::Item::PKCS8Key(key) + | rustls_pemfile1::Item::ECKey(key), + ) => return Ok(Some(super::PrivateKey(key))), Err(e) => return Err(e), _ => {} } diff --git a/ntpd/src/ctl.rs b/ntpd/src/ctl.rs index 6840fcd55..8d2f361f1 100644 --- a/ntpd/src/ctl.rs +++ b/ntpd/src/ctl.rs @@ -124,27 +124,30 @@ impl NtpCtlOptions { } } -fn validate(config: Option) -> std::io::Result { +fn validate(config: Option<&PathBuf>) -> ExitCode { // Late completion not needed, so ignore result. crate::daemon::tracing::tracing_init(LogLevel::Info, true).init(); - match Config::from_args(config, vec![], vec![]) { + match Config::from_args(config.as_ref(), vec![], vec![]) { Ok(config) => { if config.check() { eprintln!("Config looks good"); - Ok(ExitCode::SUCCESS) + ExitCode::SUCCESS } else { - Ok(ExitCode::FAILURE) + ExitCode::FAILURE } } Err(e) => { eprintln!("Error: Could not load configuration: {e}"); - Ok(ExitCode::FAILURE) + ExitCode::FAILURE } } } const VERSION: &str = env!("CARGO_PKG_VERSION"); +/// # Errors +/// +/// Returns 'Error' if arguments to program are invalid. pub fn main() -> std::io::Result { let options = match NtpCtlOptions::try_parse_from(std::env::args()) { Ok(options) => options, @@ -160,10 +163,10 @@ pub fn main() -> std::io::Result { eprintln!("ntp-ctl {VERSION}"); Ok(ExitCode::SUCCESS) } - NtpCtlAction::Validate => validate(options.config), - NtpCtlAction::ForceSync => force_sync::force_sync(options.config), + NtpCtlAction::Validate => Ok(validate(options.config.as_ref())), + NtpCtlAction::ForceSync => force_sync::force_sync(options.config.as_ref()), NtpCtlAction::Status => { - let config = Config::from_args(options.config, vec![], vec![]); + let config = Config::from_args(options.config.as_ref(), vec![], vec![]); if let Err(ref e) = config { println!("Warning: Unable to load configuration file: {e}"); @@ -292,9 +295,12 @@ mod tests { use std::os::unix::prelude::PermissionsExt; use std::path::Path; + use ntp_proto::SystemSnapshot; + use crate::{ daemon::{ config::ObservabilityConfig, + observer::ProgramData, sockets::{create_unix_socket_with_permissions, write_json}, }, test::alloc_port, @@ -333,8 +339,8 @@ mod tests { #[tokio::test] async fn test_control_socket_source() -> std::io::Result<()> { let value = ObservableState { - program: Default::default(), - system: Default::default(), + program: ProgramData::default(), + system: SystemSnapshot::default(), sources: vec![], servers: vec![], }; @@ -351,8 +357,8 @@ mod tests { #[tokio::test] async fn test_control_socket_prometheus() -> std::io::Result<()> { let value = ObservableState { - program: Default::default(), - system: Default::default(), + program: ProgramData::default(), + system: SystemSnapshot::default(), sources: vec![], servers: vec![], }; diff --git a/ntpd/src/daemon/config/mod.rs b/ntpd/src/daemon/config/mod.rs index d9916a3f3..9ad21e264 100644 --- a/ntpd/src/daemon/config/mod.rs +++ b/ntpd/src/daemon/config/mod.rs @@ -323,9 +323,9 @@ pub struct ObservabilityConfig { impl Default for ObservabilityConfig { fn default() -> Self { Self { - log_level: Default::default(), + log_level: Option::default(), ansi_colors: default_ansi_colors(), - observation_path: Default::default(), + observation_path: Option::default(), observation_permissions: default_observation_permissions(), metrics_exporter_listen: default_metrics_exporter_listen(), } @@ -416,7 +416,7 @@ impl Config { } pub fn from_args( - file: Option>, + file: Option<&impl AsRef>, sources: Vec, servers: Vec, ) -> Result { @@ -532,6 +532,7 @@ mod tests { use super::*; #[test] + #[allow(clippy::too_many_lines)] fn test_config() { let config: Config = toml::from_str("[[source]]\nmode = \"server\"\naddress = \"example.com\"").unwrap(); diff --git a/ntpd/src/daemon/config/ntp_source.rs b/ntpd/src/daemon/config/ntp_source.rs index 4492ed3d4..135c08a69 100644 --- a/ntpd/src/daemon/config/ntp_source.rs +++ b/ntpd/src/daemon/config/ntp_source.rs @@ -375,7 +375,7 @@ mod tests { NtpSourceConfig::Pool(c) => c.addr.to_string(), #[cfg(feature = "unstable_nts-pool")] NtpSourceConfig::NtsPool(c) => c.addr.to_string(), - NtpSourceConfig::Sock(_c) => "".to_string(), + NtpSourceConfig::Sock(_c) => String::new(), } } diff --git a/ntpd/src/daemon/config/server.rs b/ntpd/src/daemon/config/server.rs index 2804d4b34..3c4d43afa 100644 --- a/ntpd/src/daemon/config/server.rs +++ b/ntpd/src/daemon/config/server.rs @@ -144,7 +144,7 @@ impl From for ServerConfig { denylist: default_denylist(), allowlist: default_allowlist(), rate_limiting_cache_size: Default::default(), - rate_limiting_cutoff: Default::default(), + rate_limiting_cutoff: Duration::default(), require_nts: None, } } diff --git a/ntpd/src/daemon/mod.rs b/ntpd/src/daemon/mod.rs index 510783378..666e6e639 100644 --- a/ntpd/src/daemon/mod.rs +++ b/ntpd/src/daemon/mod.rs @@ -29,6 +29,9 @@ use self::tracing::LogLevel; const VERSION: &str = env!("CARGO_PKG_VERSION"); +/// # Errors +/// +/// Returns 'Error' if arguments to program are invalid. pub fn main() -> Result<(), Box> { let options = NtpDaemonOptions::try_parse_from(std::env::args())?; @@ -39,7 +42,7 @@ pub fn main() -> Result<(), Box> { config::NtpDaemonAction::Version => { eprintln!("ntp-daemon {VERSION}"); } - config::NtpDaemonAction::Run => run(options)?, + config::NtpDaemonAction::Run => run(&options)?, } Ok(()) @@ -49,13 +52,13 @@ pub fn main() -> Result<(), Box> { // log level based on the config if required. pub(crate) fn initialize_logging_parse_config( initial_log_level: Option, - config_path: Option, + config_path: Option<&PathBuf>, ) -> Config { let mut log_level = initial_log_level.unwrap_or_default(); let config_tracing = crate::daemon::tracing::tracing_init(log_level, true); let config = ::tracing::subscriber::with_default(config_tracing, || { - match Config::from_args(config_path, vec![], vec![]) { + match Config::from_args(config_path.as_ref(), vec![], vec![]) { Ok(c) => c, Err(e) => { // print to stderr because tracing is not yet setup @@ -78,8 +81,8 @@ pub(crate) fn initialize_logging_parse_config( config } -fn run(options: NtpDaemonOptions) -> Result<(), Box> { - let config = initialize_logging_parse_config(options.log_level, options.config); +fn run(options: &NtpDaemonOptions) -> Result<(), Box> { + let config = initialize_logging_parse_config(options.log_level, options.config.as_ref()); let runtime = if config.servers.is_empty() && config.nts_ke.is_empty() { Builder::new_current_thread().enable_all().build()? diff --git a/ntpd/src/daemon/ntp_source.rs b/ntpd/src/daemon/ntp_source.rs index 902d89209..0690e71cc 100644 --- a/ntpd/src/daemon/ntp_source.rs +++ b/ntpd/src/daemon/ntp_source.rs @@ -84,6 +84,18 @@ enum SocketResult { Abort, } +#[allow(clippy::large_enum_variant)] +enum SelectResult { + Timer, + Recv(Result, std::io::Error>), + SystemUpdate( + Result< + SystemSourceUpdate, + tokio::sync::broadcast::error::RecvError, + >, + ), +} + impl, T> SourceTask where @@ -132,22 +144,60 @@ where }, }; - let actions = self.handle_selected_actions(selected, &buf).await; - for action in actions { - if self.process_action(action, &mut poll_wait).await { - return; - } + if let Some(actions) = self.handle_selected(&buf, selected).await { + self.handle_actions(&mut poll_wait, actions).await; } } } - async fn handle_selected_actions( + async fn handle_selected( &mut self, - selected: SelectResult, buf: &[u8], - ) -> NtpSourceActionIterator { + selected: SelectResult, + ) -> Option::SourceMessage>> { match selected { - SelectResult::Recv(result) => self.handle_recv(result, buf).await, + SelectResult::Recv(result) => { + tracing::debug!("accept packet"); + match accept_packet(result, &buf, &self.clock) { + AcceptResult::Accept(packet, recv_timestamp) => { + let Some(send_timestamp) = self.last_send_timestamp else { + debug!("we received a message without having sent one; discarding"); + return None; + }; + let actions = self.source.handle_incoming( + packet, + NtpInstant::now(), + send_timestamp, + recv_timestamp, + ); + self.channels + .source_snapshots + .write() + .expect("Unexpected poisoned mutex") + .insert( + self.index, + self.source.observe(self.name.clone(), self.index), + ); + return Some(actions); + } + AcceptResult::NetworkGone => { + self.channels + .msg_for_system_sender + .send(MsgForSystem::NetworkIssue(self.index)) + .await + .ok(); + self.channels + .source_snapshots + .write() + .expect("Unexpected poisoned mutex") + .remove(&self.index); + return None; + } + AcceptResult::Ignore => { + return Some(NtpSourceActionIterator::default()); + } + } + } SelectResult::Timer => { tracing::debug!("wait completed"); let actions = self.source.handle_timer(); @@ -159,7 +209,7 @@ where self.index, self.source.observe(self.name.clone(), self.index), ); - actions + return Some(actions); } SelectResult::SystemUpdate(result) => match result { Ok(update) => { @@ -172,69 +222,96 @@ where self.index, self.source.observe(self.name.clone(), self.index), ); - actions + return Some(actions); } - Err(_) => NtpSourceActionIterator::default(), + Err(_) => return Some(NtpSourceActionIterator::default()), }, - } - } - - async fn handle_recv( - &mut self, - result: Result, std::io::Error>, - buf: &[u8], - ) -> NtpSourceActionIterator<::SourceMessage> { - tracing::debug!("accept packet"); - match accept_packet(result, buf, &self.clock) { - AcceptResult::Accept(packet, recv_timestamp) => { - let Some(send_timestamp) = self.last_send_timestamp else { - debug!("we received a message without having sent one; discarding"); - return NtpSourceActionIterator::default(); - }; - let actions = self.source.handle_incoming( - packet, - NtpInstant::now(), - send_timestamp, - recv_timestamp, - ); - self.channels - .source_snapshots - .write() - .expect("Unexpected poisoned mutex") - .insert( - self.index, - self.source.observe(self.name.clone(), self.index), - ); - actions - } - AcceptResult::NetworkGone => { - self.channels - .msg_for_system_sender - .send(MsgForSystem::NetworkIssue(self.index)) - .await - .ok(); - self.channels - .source_snapshots - .write() - .expect("Unexpected poisoned mutex") - .remove(&self.index); - NtpSourceActionIterator::default() - } - AcceptResult::Ignore => NtpSourceActionIterator::default(), - } + }; } - async fn process_action( + async fn handle_actions( &mut self, - action: ntp_proto::NtpSourceAction<::SourceMessage>, poll_wait: &mut Pin<&mut T>, - ) -> bool { - match action { - ntp_proto::NtpSourceAction::Send(packet) => { - if matches!(self.setup_socket(), SocketResult::Abort) { + actions: NtpSourceActionIterator<::SourceMessage>, + ) { + for action in actions { + match action { + ntp_proto::NtpSourceAction::Send(packet) => { + if matches!(self.setup_socket(), SocketResult::Abort) { + self.channels + .msg_for_system_sender + .send(MsgForSystem::NetworkIssue(self.index)) + .await + .ok(); + self.channels + .source_snapshots + .write() + .expect("Unexpected poisoned mutex") + .remove(&self.index); + return; + } + + match self.clock.now() { + Err(e) => { + // we cannot determine the origin_timestamp + error!(error = ?e, "There was an error retrieving the current time"); + + // report as no permissions, since this seems the most likely + std::process::exit(exitcode::NOPERM); + } + Ok(ts) => { + self.last_send_timestamp = Some(ts); + } + } + + match self.socket.as_mut().unwrap().send(&packet).await { + Err(error) => { + warn!(?error, "poll message could not be sent"); + + match error.raw_os_error() { + Some( + libc::EHOSTDOWN + | libc::EHOSTUNREACH + | libc::ENETDOWN + | libc::ENETUNREACH, + ) => { + self.channels + .msg_for_system_sender + .send(MsgForSystem::NetworkIssue(self.index)) + .await + .ok(); + self.channels + .source_snapshots + .write() + .expect("Unexpected poisoned mutex") + .remove(&self.index); + return; + } + _ => {} + } + } + Ok(opt_send_timestamp) => { + // update the last_send_timestamp with the one given by the kernel, if available + self.last_send_timestamp = opt_send_timestamp + .map(convert_net_timestamp) + .or(self.last_send_timestamp); + } + } + } + ntp_proto::NtpSourceAction::UpdateSystem(update) => { + self.channels + .msg_for_system_sender + .send(MsgForSystem::SourceUpdate(self.index, update)) + .await + .ok(); + } + ntp_proto::NtpSourceAction::SetTimer(timeout) => { + poll_wait.as_mut().reset(Instant::now() + timeout); + } + ntp_proto::NtpSourceAction::Reset => { self.channels .msg_for_system_sender - .send(MsgForSystem::NetworkIssue(self.index)) + .send(MsgForSystem::Unreachable(self.index)) .await .ok(); self.channels @@ -242,93 +319,23 @@ where .write() .expect("Unexpected poisoned mutex") .remove(&self.index); - return true; - } - - match self.clock.now() { - Err(e) => { - // we cannot determine the origin_timestamp - error!(error = ?e, "There was an error retrieving the current time"); - - // report as no permissions, since this seems the most likely - std::process::exit(exitcode::NOPERM); - } - Ok(ts) => { - self.last_send_timestamp = Some(ts); - } + return; } - - match self.socket.as_mut().unwrap().send(&packet).await { - Err(error) => { - warn!(?error, "poll message could not be sent"); - - if let Some( - libc::EHOSTDOWN - | libc::EHOSTUNREACH - | libc::ENETDOWN - | libc::ENETUNREACH, - ) = error.raw_os_error() - { - self.channels - .msg_for_system_sender - .send(MsgForSystem::NetworkIssue(self.index)) - .await - .ok(); - self.channels - .source_snapshots - .write() - .expect("Unexpected poisoned mutex") - .remove(&self.index); - return true; - } - } - Ok(opt_send_timestamp) => { - // update the last_send_timestamp with the one given by the kernel, if available - self.last_send_timestamp = opt_send_timestamp - .map(convert_net_timestamp) - .or(self.last_send_timestamp); - } + ntp_proto::NtpSourceAction::Demobilize => { + self.channels + .msg_for_system_sender + .send(MsgForSystem::MustDemobilize(self.index)) + .await + .ok(); + self.channels + .source_snapshots + .write() + .expect("Unexpected poisoned mutex") + .remove(&self.index); + return; } } - ntp_proto::NtpSourceAction::UpdateSystem(update) => { - self.channels - .msg_for_system_sender - .send(MsgForSystem::SourceUpdate(self.index, update)) - .await - .ok(); - } - ntp_proto::NtpSourceAction::SetTimer(timeout) => { - poll_wait.as_mut().reset(Instant::now() + timeout); - } - ntp_proto::NtpSourceAction::Reset => { - self.channels - .msg_for_system_sender - .send(MsgForSystem::Unreachable(self.index)) - .await - .ok(); - self.channels - .source_snapshots - .write() - .expect("Unexpected poisoned mutex") - .remove(&self.index); - return true; - } - ntp_proto::NtpSourceAction::Demobilize => { - self.channels - .msg_for_system_sender - .send(MsgForSystem::MustDemobilize(self.index)) - .await - .ok(); - self.channels - .source_snapshots - .write() - .expect("Unexpected poisoned mutex") - .remove(&self.index); - return true; - } } - - false } } @@ -551,9 +558,10 @@ mod tests { let cur = std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH)?; - #[allow(clippy::cast_possible_truncation)] Ok(NtpTimestamp::from_seconds_nanos_since_ntp_era( - EPOCH_OFFSET.wrapping_add(cur.as_secs() as u32), + EPOCH_OFFSET.wrapping_add( + u32::try_from(cur.as_secs()).expect("Couldn't fit unix epoch inside u32"), + ), cur.subsec_nanos(), )) } @@ -590,7 +598,7 @@ mod tests { } } - async fn test_startup() -> ( + fn test_startup() -> ( SourceTask, T>, Socket, mpsc::Receiver>>, @@ -654,7 +662,7 @@ mod tests { #[tokio::test] async fn test_poll_sends_state_update_and_packet() { // Note: Ports must be unique among tests to deal with parallelism - let (mut process, socket, _, _system_update_sender) = test_startup().await; + let (mut process, socket, _, _system_update_sender) = test_startup(); let (poll_wait, poll_send) = TestWait::new(); @@ -685,7 +693,7 @@ mod tests { #[tokio::test] async fn test_timeroundtrip() { // Note: Ports must be unique among tests to deal with parallelism - let (mut process, mut socket, mut msg_recv, _system_update_sender) = test_startup().await; + let (mut process, mut socket, mut msg_recv, _system_update_sender) = test_startup(); let system = SystemSnapshot { time_snapshot: TimeSnapshot { @@ -734,7 +742,7 @@ mod tests { #[tokio::test] async fn test_deny_stops_poll() { // Note: Ports must be unique among tests to deal with parallelism - let (mut process, mut socket, mut msg_recv, _system_update_sender) = test_startup().await; + let (mut process, mut socket, mut msg_recv, _system_update_sender) = test_startup(); let (poll_wait, poll_send) = TestWait::new(); diff --git a/ntpd/src/daemon/server.rs b/ntpd/src/daemon/server.rs index 72e62c438..bf42b383b 100644 --- a/ntpd/src/daemon/server.rs +++ b/ntpd/src/daemon/server.rs @@ -147,32 +147,31 @@ impl ServerTask { let mut cur_socket = None; loop { // open socket if it is not already open - let socket = match &mut cur_socket { - Some(socket) => socket, - None => { - let new_socket = loop { - let socket_res = open_ip( - self.config.listen, - timestamped_socket::socket::GeneralTimestampMode::SoftwareRecv, - ); - - match socket_res { - Ok(socket) => break socket, - Err(error) => { - warn!(?error, ?self.config.listen, "Could not open server socket"); - tokio::time::sleep(self.network_wait_period).await; - } + let socket = if let Some(socket) = &mut cur_socket { + socket + } else { + let new_socket = loop { + let socket_res = open_ip( + self.config.listen, + timestamped_socket::socket::GeneralTimestampMode::SoftwareRecv, + ); + + match socket_res { + Ok(socket) => break socket, + Err(error) => { + warn!(?error, ?self.config.listen, "Could not open server socket"); + tokio::time::sleep(self.network_wait_period).await; } - }; + } + }; - // system and keyset may now be wildly out of date, ensure they are always updated. - self.server - .update_system(*self.system_receiver.borrow_and_update()); - self.server - .update_keyset(self.keyset.borrow_and_update().clone()); + // system and keyset may now be wildly out of date, ensure they are always updated. + self.server + .update_system(*self.system_receiver.borrow_and_update()); + self.server + .update_keyset(self.keyset.borrow_and_update().clone()); - cur_socket.insert(new_socket) - } + cur_socket.insert(new_socket) }; let mut buf = [0_u8; MAX_PACKET_SIZE]; diff --git a/ntpd/src/daemon/sock_source.rs b/ntpd/src/daemon/sock_source.rs index 1eeb225e3..074096aa5 100644 --- a/ntpd/src/daemon/sock_source.rs +++ b/ntpd/src/daemon/sock_source.rs @@ -25,7 +25,7 @@ struct SockSample { magic: i32, } -const SOCK_MAGIC: i32 = 0x534f434b; +const SOCK_MAGIC: i32 = 0x534f_434b; const SOCK_SAMPLE_SIZE: usize = 40; #[derive(Debug)] @@ -101,6 +101,16 @@ fn create_socket>(path: T) -> std::io::Result { Ok(socket) } +enum SelectResult { + SockRecv(Result), + SystemUpdate( + Result< + SystemSourceUpdate, + tokio::sync::broadcast::error::RecvError, + >, + ), +} + impl> SockSourceTask where C: 'static + NtpClock + Send + Sync, @@ -109,16 +119,6 @@ where loop { let mut buf = [0; SOCK_SAMPLE_SIZE]; - enum SelectResult { - SockRecv(Result), - SystemUpdate( - Result< - SystemSourceUpdate, - tokio::sync::broadcast::error::RecvError, - >, - ), - } - let selected: SelectResult = tokio::select! { result = self.socket.recv(&mut buf) => { SelectResult::SockRecv(result) @@ -181,12 +181,8 @@ where } }, SelectResult::SystemUpdate(result) => match result { - Ok(update) => { - self.source.handle_message(update.message); - } - Err(e) => { - error!("Error receiving system update: {:?}", e) - } + Ok(update) => self.source.handle_message(update.message), + Err(e) => error!("Error receiving system update: {:?}", e), }, }; } @@ -256,7 +252,9 @@ mod tests { std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH)?; Ok(NtpTimestamp::from_seconds_nanos_since_ntp_era( - EPOCH_OFFSET.wrapping_add(cur.as_secs() as u32), + EPOCH_OFFSET.wrapping_add( + u32::try_from(cur.as_secs()).expect("Couldn't fit unix epoch inside u32"), + ), cur.subsec_nanos(), )) } @@ -350,6 +348,7 @@ mod tests { } #[test] + #[allow(clippy::float_cmp)] fn test_deserialize_sample() { // Example sock sample let buf = [ @@ -357,7 +356,8 @@ mod tests { 119, 19, 65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 75, 67, 79, 83, ]; let sample = deserialize_sample(Ok(buf.len()), buf).unwrap(); - assert_eq!(sample.offset, 318975.704798661); + + assert_eq!(sample.offset, 318_975.704_798_661); assert_eq!(sample.pulse, 0); assert_eq!(sample.leap, 0); assert_eq!(sample.magic, SOCK_MAGIC); diff --git a/ntpd/src/daemon/spawn/sock.rs b/ntpd/src/daemon/spawn/sock.rs index fa68a2691..59d1f70bb 100644 --- a/ntpd/src/daemon/spawn/sock.rs +++ b/ntpd/src/daemon/spawn/sock.rs @@ -17,7 +17,7 @@ impl SockSpawner { pub fn new(config: SockSourceConfig) -> SockSpawner { SockSpawner { config, - id: Default::default(), + id: SpawnerId::default(), has_spawned: false, } } diff --git a/ntpd/src/daemon/system.rs b/ntpd/src/daemon/system.rs index 5406ffa66..1154150b2 100644 --- a/ntpd/src/daemon/system.rs +++ b/ntpd/src/daemon/system.rs @@ -127,17 +127,12 @@ pub async fn spawn { - system - .add_spawner(SockSpawner::new(cfg.clone())) - .map_err(|e| { - tracing::error!("Could not spawn source: {}", e); - std::io::Error::new(std::io::ErrorKind::Other, e) - })?; + system.add_spawner(SockSpawner::new(cfg.clone())); } } } - for server_config in server_configs { + for server_config in server_configs.iter() { system.add_server(server_config.to_owned()); } diff --git a/ntpd/src/force_sync/algorithm.rs b/ntpd/src/force_sync/algorithm.rs index cc461f396..a1ee5bc4a 100644 --- a/ntpd/src/force_sync/algorithm.rs +++ b/ntpd/src/force_sync/algorithm.rs @@ -110,7 +110,7 @@ impl SingleShotController { for source in self.sources.values() { if source.get_offset().abs_diff(peak_offset) < Self::ASSUMED_UNCERTAINTY { count += 1; - sum += source.get_offset().to_seconds() + sum += source.get_offset().to_seconds(); } } diff --git a/ntpd/src/force_sync/mod.rs b/ntpd/src/force_sync/mod.rs index 2b205c6c0..c780d6828 100644 --- a/ntpd/src/force_sync/mod.rs +++ b/ntpd/src/force_sync/mod.rs @@ -112,7 +112,7 @@ impl SingleShotController { } } -pub(crate) fn force_sync(config: Option) -> std::io::Result { +pub(crate) fn force_sync(config: Option<&PathBuf>) -> std::io::Result { let config = initialize_logging_parse_config(Some(LogLevel::Warn), config); // Warn/error if the config is unreasonable. We do this after finishing @@ -138,11 +138,11 @@ pub(crate) fn force_sync(config: Option) -> std::io::Result { | config::NtpSourceConfig::Nts(_) | config::NtpSourceConfig::Sock(_) => total_sources += 1, config::NtpSourceConfig::Pool(PoolSourceConfig { count, .. }) => { - total_sources += count + total_sources += count; } #[cfg(feature = "unstable_nts-pool")] config::NtpSourceConfig::NtsPool(NtsPoolSourceConfig { count, .. }) => { - total_sources += count + total_sources += count; } } } diff --git a/ntpd/src/metrics/exporter.rs b/ntpd/src/metrics/exporter.rs index fe3684bfb..09deeec96 100644 --- a/ntpd/src/metrics/exporter.rs +++ b/ntpd/src/metrics/exporter.rs @@ -110,6 +110,9 @@ impl NtpMetricsExporterOptions { } } +/// # Errors +/// +/// Returns 'Error' if arguments to program are invalid. pub fn main() -> Result<(), Box> { let options = NtpMetricsExporterOptions::try_parse_from(std::env::args())?; match options.action { @@ -121,22 +124,21 @@ pub fn main() -> Result<(), Box> { eprintln!("ntp-metrics-exporter {VERSION}"); Ok(()) } - MetricsAction::Run => run(options), + MetricsAction::Run => run(&options), } } -fn run(options: NtpMetricsExporterOptions) -> Result<(), Box> { - let config = initialize_logging_parse_config(None, options.config); +fn run(options: &NtpMetricsExporterOptions) -> Result<(), Box> { + let config = initialize_logging_parse_config(None, options.config.as_ref()); Builder::new_current_thread().enable_all().build()?.block_on(async { let timeout = std::time::Duration::from_millis(1000); - let observation_socket_path = match config.observability.observation_path { - Some(path) => Arc::new(path), - None => { - eprintln!("An observation socket path must be configured using the observation-path option in the [observability] section of the configuration"); - std::process::exit(1); - } + let observation_socket_path = if let Some(path) = config.observability.observation_path { + Arc::new(path) + } else { + eprintln!("An observation socket path must be configured using the observation-path option in the [observability] section of the configuration"); + std::process::exit(1); }; println!( @@ -185,7 +187,7 @@ fn run(options: NtpMetricsExporterOptions) -> Result<(), Box { error!("Not enough resources available to accept incoming connection: {e}"); diff --git a/ntpd/src/metrics/mod.rs b/ntpd/src/metrics/mod.rs index 559cbe10c..35e557a52 100644 --- a/ntpd/src/metrics/mod.rs +++ b/ntpd/src/metrics/mod.rs @@ -254,7 +254,7 @@ pub fn format_state(w: &mut impl std::fmt::Write, state: &ObservableState) -> st w, "ntp_source_nts_cookies_available", "Number of unused cookies available for nts-enabled ntp exchanges", - MetricType::Gauge, + &MetricType::Gauge, None, collect_some_sources!(state, |p| p.nts_cookies), )?;