Skip to content

Commit

Permalink
Rebase with main
Browse files Browse the repository at this point in the history
  • Loading branch information
van-sprundel committed Jan 9, 2025
1 parent 978b528 commit 8c44e3f
Show file tree
Hide file tree
Showing 24 changed files with 336 additions and 311 deletions.
21 changes: 11 additions & 10 deletions ntp-proto/src/algorithm/kalman/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,15 +285,15 @@ pub trait MeasurementNoiseEstimator {
fn reset(&mut self) -> Self;

// for SourceSnapshot
fn get_max_roundtrip(&self, samples: &i32) -> Option<f64>;
fn get_max_roundtrip(&self, samples: i32) -> Option<f64>;
fn get_delay_mean(&self) -> f64;
}

impl MeasurementNoiseEstimator for AveragingBuffer {
type MeasurementDelay = NtpDuration;

fn update(&mut self, delay: Self::MeasurementDelay) {
self.update(delay.to_seconds())
self.update(delay.to_seconds());
}

fn get_noise_estimate(&self) -> f64 {
Expand All @@ -312,8 +312,9 @@ impl MeasurementNoiseEstimator for AveragingBuffer {
AveragingBuffer::default()
}

fn get_max_roundtrip(&self, samples: &i32) -> Option<f64> {
self.data[..*samples as usize]
#[allow(clippy::cast_sign_loss)]
fn get_max_roundtrip(&self, samples: i32) -> Option<f64> {
self.data[..samples as usize]
.iter()
.copied()
.fold(None, |v1, v2| {
Expand Down Expand Up @@ -351,7 +352,7 @@ impl MeasurementNoiseEstimator for f64 {
*self
}

fn get_max_roundtrip(&self, _samples: &i32) -> Option<f64> {
fn get_max_roundtrip(&self, _samples: i32) -> Option<f64> {
Some(1.)
}

Expand Down Expand Up @@ -712,20 +713,20 @@ impl<D: Debug + Copy + Clone, N: MeasurementNoiseEstimator<MeasurementDelay = D>
}
}

#[allow(clippy::cast_sign_loss)]
fn snapshot<Index: Copy>(
&self,
index: Index,
config: &AlgorithmConfig,
) -> Option<SourceSnapshot<Index>> {
#[allow(clippy::cast_sign_loss)]
match &self.0 {
SourceStateInner::Initial(InitialSourceFilter {
noise_estimator,
init_offset,
last_measurement: Some(last_measurement),
samples,
}) if *samples > 0 => {
let max_roundtrip = noise_estimator.get_max_roundtrip(samples)?;
let max_roundtrip = noise_estimator.get_max_roundtrip(*samples)?;
Some(SourceSnapshot {
index,
source_uncertainty: last_measurement.root_dispersion,
Expand Down Expand Up @@ -1043,7 +1044,7 @@ mod tests {
D: Debug + Clone + Copy,
N: MeasurementNoiseEstimator<MeasurementDelay = D> + Clone,
>(
noise_estimator: N,
noise_estimator: &N,
delay: D,
) {
let base = NtpTimestamp::from_fixed_int(0);
Expand Down Expand Up @@ -1242,7 +1243,7 @@ mod tests {
#[test]
fn test_offset_steering_and_measurements_normal() {
test_offset_steering_and_measurements(
AveragingBuffer {
&AveragingBuffer {
data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6],
next_idx: 0,
},
Expand All @@ -1252,7 +1253,7 @@ mod tests {

#[test]
fn test_offset_steering_and_measurements_constant_noise_estimate() {
test_offset_steering_and_measurements(1e-9, ());
test_offset_steering_and_measurements(&1e-9, ());
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions ntp-proto/src/ipfilter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl IpFilter {
for subnet in subnets {
match subnet.addr {
IpAddr::V4(addr) => ipv4list.push((
(u32::from_be_bytes(addr.octets()) as u128) << 96,
u128::from(u32::from_be_bytes(addr.octets())) << 96,
subnet.mask,
)),
IpAddr::V6(addr) => {
Expand Down Expand Up @@ -214,7 +214,7 @@ impl IpFilter {
/// Panics if `addr` has invalid octets.
fn is_in4(&self, addr: Ipv4Addr) -> bool {
self.ipv4_filter
.lookup((u32::from_be_bytes(addr.octets()) as u128) << 96)
.lookup(u128::from(u32::from_be_bytes(addr.octets())) << 96)
}

fn is_in6(&self, addr: Ipv6Addr) -> bool {
Expand Down
26 changes: 5 additions & 21 deletions ntp-proto/src/keyset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@ impl KeySetProvider {
let mut buf = [0; 64];
reader.read_exact(&mut buf[0..20])?;

let time = Self::convert_to_system_time(&buf[0..8])?;
let id_offset = Self::convert_to_u32(&buf[8..12])?;
let primary = Self::convert_to_u32(&buf[12..16])?;
let len = Self::convert_to_u32(&buf[16..20])?;
let time = std::time::SystemTime::UNIX_EPOCH
+ std::time::Duration::from_secs(u64::from_be_bytes(buf[0..8].try_into().unwrap()));
let id_offset = u32::from_be_bytes(buf[8..12].try_into().unwrap());
let primary = u32::from_be_bytes(buf[12..16].try_into().unwrap());
let len = u32::from_be_bytes(buf[16..20].try_into().unwrap());

if primary >= len {
return Err(std::io::Error::new(
Expand Down Expand Up @@ -175,23 +176,6 @@ impl KeySetProvider {
pub fn get(&self) -> Arc<KeySet> {
self.current.clone()
}

fn convert_to_system_time(bytes: &[u8]) -> std::io::Result<std::time::SystemTime> {
let time = u64::from_be_bytes(bytes.try_into().map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Invalid buffer for SystemTime",
)
})?);
Ok(std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(time))
}

fn convert_to_u32(bytes: &[u8]) -> std::io::Result<u32> {
let value = u32::from_be_bytes(bytes.try_into().map_err(|_| {
std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid buffer for u32")
})?);
Ok(value)
}
}

pub struct KeySet {
Expand Down
7 changes: 2 additions & 5 deletions ntp-proto/src/nts_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1229,9 +1229,6 @@ impl KeyExchangeClient {
self.tls_connection.read_tls(rd)
}

/// # Errors
///
/// Returns error if getting the TLS bool `wants_write` fails.
#[must_use]
pub fn wants_write(&self) -> bool {
self.tls_connection.wants_write()
Expand Down Expand Up @@ -1843,7 +1840,7 @@ impl KeyExchangeServer {

fn decoder_done(
mut self,
data: ServerKeyExchangeData,
data: &ServerKeyExchangeData,
) -> ControlFlow<Result<tls_utils::ServerConnection, KeyExchangeError>, Self> {
let algorithm = data.algorithm;
let protocol = data.protocol;
Expand Down Expand Up @@ -3029,7 +3026,7 @@ mod test {
Certified,
}

fn client_server_pair(client_type: ClientType) -> (KeyExchangeClient, KeyExchangeServer) {
fn client_server_pair(client_type: &ClientType) -> (KeyExchangeClient, KeyExchangeServer) {
#[allow(unused)]
use tls_utils::CloneKeyShim;

Expand Down
40 changes: 22 additions & 18 deletions ntp-proto/src/packet/extension_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,13 @@ impl<'a> ExtensionField<'a> {

#[must_use]
pub fn into_owned(self) -> ExtensionField<'static> {
#[cfg(feature = "ntpv5")]
use ExtensionField::{
DraftIdentification, Padding, ReferenceIdRequest, ReferenceIdResponse,
InvalidNtsEncryptedField, NtsCookie, NtsCookiePlaceholder, UniqueIdentifier, Unknown,
};

#[cfg(feature = "ntpv5")]
use ExtensionField::{
InvalidNtsEncryptedField, NtsCookie, NtsCookiePlaceholder, UniqueIdentifier, Unknown,
DraftIdentification, Padding, ReferenceIdRequest, ReferenceIdResponse,
};

match self {
Expand Down Expand Up @@ -171,12 +172,13 @@ impl<'a> ExtensionField<'a> {
minimum_size: u16,
version: ExtensionHeaderVersion,
) -> std::io::Result<()> {
#[cfg(feature = "ntpv5")]
use ExtensionField::{
DraftIdentification, Padding, ReferenceIdRequest, ReferenceIdResponse,
InvalidNtsEncryptedField, NtsCookie, NtsCookiePlaceholder, UniqueIdentifier, Unknown,
};

#[cfg(feature = "ntpv5")]
use ExtensionField::{
InvalidNtsEncryptedField, NtsCookie, NtsCookiePlaceholder, UniqueIdentifier, Unknown,
DraftIdentification, Padding, ReferenceIdRequest, ReferenceIdResponse,
};

match self {
Expand Down Expand Up @@ -204,7 +206,9 @@ impl<'a> ExtensionField<'a> {
}
}

#[allow(clippy::missing_errors_doc)]
/// # Errors
///
/// Returns `io::Error` if serialization fails.
#[cfg(feature = "__internal-fuzz")]
pub fn serialize_pub(
&self,
Expand Down Expand Up @@ -469,7 +473,7 @@ impl<'a> ExtensionField<'a> {

/// # Errors
///
/// Returns error if writing to the sink fails.
/// Returns `io::Error` if encoding fails.
#[cfg(feature = "ntpv5")]
pub fn encode_padding_field(
mut w: impl NonBlockingWrite,
Expand Down Expand Up @@ -552,7 +556,7 @@ impl<'a> ExtensionField<'a> {
type EF<'a> = ExtensionField<'a>;
type TypeId = ExtensionFieldTypeId;

let message = raw.message_bytes;
let message = &raw.message_bytes;

match raw.type_id {
TypeId::UniqueIdentifier => EF::decode_unique_identifier(message),
Expand All @@ -563,7 +567,7 @@ impl<'a> ExtensionField<'a> {
EF::decode_draft_identification(message, extension_header_version)
}
#[cfg(feature = "ntpv5")]
TypeId::ReferenceIdRequest => Ok(ReferenceIdRequest::decode(message).into()),
TypeId::ReferenceIdRequest => Ok(ReferenceIdRequest::decode(message)?.into()),
#[cfg(feature = "ntpv5")]
TypeId::ReferenceIdResponse => Ok(ReferenceIdResponse::decode(message).into()),
type_id => Ok(EF::decode_unknown(type_id.to_type_id(), message)),
Expand Down Expand Up @@ -669,11 +673,11 @@ impl<'a> ExtensionFieldData<'a> {
RawExtensionField::V4_UNENCRYPTED_MINIMUM_SIZE,
version,
) {
let (offset, field) = field.map_err(super::error::ParsingError::generalize)?;
let (offset, field) = field.map_err(super::ParsingError::generalize)?;
size = offset + field.wire_length(version);
if field.type_id == ExtensionFieldTypeId::NtsEncryptedField {
let encrypted = RawEncryptedField::from_message_bytes(field.message_bytes)
.map_err(super::error::ParsingError::generalize)?;
.map_err(super::ParsingError::generalize)?;

let Some(cipher) = cipher.get(&efdata.untrusted) else {
efdata.untrusted.push(InvalidNtsEncryptedField);
Expand Down Expand Up @@ -711,7 +715,7 @@ impl<'a> ExtensionFieldData<'a> {
efdata.authenticated.append(&mut efdata.untrusted);
} else {
let field = ExtensionField::decode(&field, version)
.map_err(super::error::ParsingError::generalize)?;
.map_err(super::ParsingError::generalize)?;
efdata.untrusted.push(field);
}
}
Expand Down Expand Up @@ -746,23 +750,23 @@ impl<'a> RawEncryptedField<'a> {
fn from_message_bytes(
message_bytes: &'a [u8],
) -> Result<Self, ParsingError<std::convert::Infallible>> {
use ParsingError::IncorrectLength;

let [b0, b1, b2, b3, ref rest @ ..] = message_bytes[..] else {
return Err(IncorrectLength);
return Err(ParsingError::IncorrectLength);
};

let nonce_length = u16::from_be_bytes([b0, b1]) as usize;
let ciphertext_length = u16::from_be_bytes([b2, b3]) as usize;

let nonce = rest.get(..nonce_length).ok_or(IncorrectLength)?;
let nonce = rest
.get(..nonce_length)
.ok_or(ParsingError::IncorrectLength)?;

// skip the lengths and the nonce. pad to a multiple of 4
let ciphertext_start = 4 + next_multiple_of_u16(nonce_length as u16, 4) as usize;

let ciphertext = message_bytes
.get(ciphertext_start..ciphertext_start + ciphertext_length)
.ok_or(IncorrectLength)?;
.ok_or(ParsingError::IncorrectLength)?;

Ok(Self { nonce, ciphertext })
}
Expand Down
1 change: 1 addition & 0 deletions ntp-proto/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,7 @@ fn check_uid_extensionfield<'a, I: IntoIterator<Item = &'a ExtensionField<'a>>>(

#[cfg(any(test, feature = "__internal-fuzz", feature = "__internal-test"))]
impl NtpPacket<'_> {
#[must_use]
pub fn test() -> Self {
Self::default()
}
Expand Down
8 changes: 5 additions & 3 deletions ntp-proto/src/packet/v5/extension_fields.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::io::NonBlockingWrite;
use crate::packet::error::ParsingError;
use crate::packet::v5::server_reference_id::BloomFilter;
use crate::packet::ExtensionField;
use std::borrow::Cow;
use std::convert::Infallible;

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Type {
Expand Down Expand Up @@ -119,7 +121,7 @@ impl ReferenceIdRequest {
Ok(())
}

pub fn decode(msg: &[u8]) -> Self {
pub fn decode(msg: &[u8]) -> Result<Self, ParsingError<Infallible>> {
let payload_len =
u16::try_from(msg.len()).expect("NTP fields can not be longer than u16::MAX");
let offset_bytes: [u8; 2] = msg
Expand All @@ -128,10 +130,10 @@ impl ReferenceIdRequest {
.try_into()
.unwrap();

Self {
Ok(Self {
payload_len,
offset: u16::from_be_bytes(offset_bytes),
}
})
}

pub const fn offset(self) -> u16 {
Expand Down
6 changes: 4 additions & 2 deletions ntp-proto/src/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl<Controller: SourceController<MeasurementDelay = ()>> OneWaySource<Controlle
}

pub fn handle_message(&mut self, message: Controller::ControllerMessage) {
self.controller.handle_message(message)
self.controller.handle_message(message);
}
}

Expand Down Expand Up @@ -875,7 +875,9 @@ mod test {
std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH)?;

Ok(NtpTimestamp::from_seconds_nanos_since_ntp_era(
EPOCH_OFFSET.wrapping_add(cur.as_secs() as u32),
EPOCH_OFFSET.wrapping_add(
u32::try_from(cur.as_secs()).expect("Couldn't fit unix epoch inside u32"),
),
cur.subsec_nanos(),
))
}
Expand Down
Loading

0 comments on commit 8c44e3f

Please sign in to comment.