Skip to content

Commit

Permalink
add sending server denial records
Browse files Browse the repository at this point in the history
  • Loading branch information
squell committed Dec 7, 2023
1 parent e1af5e6 commit d497795
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 31 deletions.
4 changes: 2 additions & 2 deletions ntp-proto/src/nts_pool_ke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ impl ClientToPoolDecoder {
}

#[cfg(feature = "nts-pool")]
NtpServerDeny { _denied } => {
state.denied_servers.push(_denied);
NtpServerDeny { denied } => {
state.denied_servers.push(denied);
Continue(state)
}

Expand Down
59 changes: 36 additions & 23 deletions ntp-proto/src/nts_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ pub enum NtsRecord {
},
#[cfg(feature = "nts-pool")]
NtpServerDeny {
_denied: String,
denied: String,
},
}

Expand Down Expand Up @@ -202,9 +202,9 @@ impl NtsRecord {
pub const BAD_REQUEST: u16 = 1;
pub const INTERNAL_SERVER_ERROR: u16 = 2;

pub fn client_key_exchange_records() -> [NtsRecord; if cfg!(feature = "ntpv5") { 4 } else { 3 }]
{
[
#[cfg_attr(not(feature = "nts-pool"), allow(unused_variables))]
pub fn client_key_exchange_records(denied_servers: &[&str]) -> Box<[NtsRecord]> {
let mut base = vec![
#[cfg(feature = "ntpv5")]
NtsRecord::DraftId {
data: crate::packet::v5::DRAFT_VERSION.as_bytes().into(),
Expand All @@ -223,8 +223,20 @@ impl NtsRecord {
.map(|algorithm| *algorithm as u16)
.collect(),
},
NtsRecord::EndOfMessage,
]
];

#[cfg(feature = "nts-pool")]
base.extend(
denied_servers
.iter()
.map(|server| NtsRecord::NtpServerDeny {
denied: server.to_string(),
}),
);

base.push(NtsRecord::EndOfMessage);

base.into_boxed_slice()
}

#[cfg(feature = "nts-pool")]
Expand Down Expand Up @@ -385,7 +397,7 @@ impl NtsRecord {
// NOTE: the string data should be ascii (not utf8) but we don't enforce that here
let str_data = read_bytes_exact(reader, record_len)?;
match String::from_utf8(str_data) {
Ok(_denied) => NtsRecord::NtpServerDeny { _denied },
Ok(denied) => NtsRecord::NtpServerDeny { denied },
Err(e) => NtsRecord::Unknown {
record_type,
critical,
Expand Down Expand Up @@ -496,7 +508,7 @@ impl NtsRecord {
writer.write_all(s2c)?;
}
#[cfg(feature = "nts-pool")]
NtsRecord::NtpServerDeny { _denied: name } => {
NtsRecord::NtpServerDeny { denied: name } => {
// NOTE: the server name should be ascii
#[cfg(not(feature = "__internal-fuzz"))]
debug_assert!(name.is_ascii());
Expand Down Expand Up @@ -1149,14 +1161,15 @@ impl KeyExchangeClient {
pub fn new(
server_name: String,
tls_config: rustls::ClientConfig,
denied_servers: &[&str],
) -> Result<Self, KeyExchangeError> {
let mut client = Self::new_without_tls_write(server_name, tls_config)?;

// Make the request immediately (note, this will only go out to the wire via the write functions above)
// We use an intermediary buffer to ensure that all records are sent at once.
// This should not be needed, but works around issues in some NTS-ke server implementations
let mut buffer = Vec::with_capacity(1024);
for record in NtsRecord::client_key_exchange_records() {
for record in NtsRecord::client_key_exchange_records(denied_servers).iter() {
record.write(&mut buffer)?;
}
client.tls_connection.writer().write_all(&buffer)?;
Expand Down Expand Up @@ -1403,7 +1416,7 @@ impl KeyExchangeServerDecoder {
Continue(state)
}
#[cfg(feature = "nts-pool")]
NtpServerDeny { _denied } => {
NtpServerDeny { denied: _ } => {
// we are not a NTS pool server, so we ignore this record
Continue(state)
}
Expand Down Expand Up @@ -1753,7 +1766,7 @@ mod test {
#[test]
fn test_client_key_exchange_records() {
let mut buffer = Vec::with_capacity(1024);
for record in NtsRecord::client_key_exchange_records() {
for record in NtsRecord::client_key_exchange_records(&[]).iter() {
record.write(&mut buffer).unwrap();
}

Expand All @@ -1777,7 +1790,7 @@ mod test {
decoder.step().unwrap().unwrap(),
decoder.step().unwrap().unwrap(),
],
NtsRecord::client_key_exchange_records()
NtsRecord::client_key_exchange_records(&[]).as_ref()
);

assert!(decoder.step().unwrap().is_none());
Expand Down Expand Up @@ -1901,7 +1914,7 @@ mod test {
let mut buffer = Vec::new();

let record = NtsRecord::NtpServerDeny {
_denied: String::from("a string"),
denied: String::from("a string"),
};

record.write(&mut buffer).unwrap();
Expand Down Expand Up @@ -2506,14 +2519,14 @@ mod test {

#[test]
fn server_decoder_finds_algorithm() {
let result = server_decode_records(&NtsRecord::client_key_exchange_records()).unwrap();
let result = server_decode_records(&NtsRecord::client_key_exchange_records(&[])).unwrap();

assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512);
}

#[test]
fn server_decoder_ignores_new_cookie() {
let mut records = NtsRecord::client_key_exchange_records().to_vec();
let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec();
records.insert(
0,
NtsRecord::NewCookie {
Expand All @@ -2527,7 +2540,7 @@ mod test {

#[test]
fn server_decoder_ignores_server_and_port_preference() {
let mut records = NtsRecord::client_key_exchange_records().to_vec();
let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec();
records.insert(
0,
NtsRecord::Server {
Expand All @@ -2550,7 +2563,7 @@ mod test {

#[test]
fn server_decoder_ignores_warn() {
let mut records = NtsRecord::client_key_exchange_records().to_vec();
let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec();
records.insert(0, NtsRecord::Warning { warningcode: 42 });

let result = server_decode_records(&records).unwrap();
Expand All @@ -2559,7 +2572,7 @@ mod test {

#[test]
fn server_decoder_ignores_unknown_not_critical() {
let mut records = NtsRecord::client_key_exchange_records().to_vec();
let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec();
records.insert(
0,
NtsRecord::Unknown {
Expand All @@ -2575,7 +2588,7 @@ mod test {

#[test]
fn server_decoder_reports_unknown_critical() {
let mut records = NtsRecord::client_key_exchange_records().to_vec();
let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec();
records.insert(
0,
NtsRecord::Unknown {
Expand All @@ -2594,7 +2607,7 @@ mod test {

#[test]
fn server_decoder_reports_error() {
let mut records = NtsRecord::client_key_exchange_records().to_vec();
let mut records = NtsRecord::client_key_exchange_records(&[]).to_vec();
records.insert(0, NtsRecord::Error { errorcode: 2 });

let error = server_decode_records(&records).unwrap_err();
Expand Down Expand Up @@ -2760,7 +2773,7 @@ mod test {
.with_no_client_auth();

let mut server = rustls::ServerConnection::new(Arc::new(serverconfig)).unwrap();
let mut client = KeyExchangeClient::new("localhost".into(), clientconfig).unwrap();
let mut client = KeyExchangeClient::new("localhost".into(), clientconfig, &[]).unwrap();

server.writer().write_all(NTS_TIME_NL_RESPONSE).unwrap();

Expand Down Expand Up @@ -2942,7 +2955,7 @@ mod test {
let (mut client, server) = client_server_pair(ClientType::Uncertified);

let mut buffer = Vec::with_capacity(1024);
for record in NtsRecord::client_key_exchange_records() {
for record in NtsRecord::client_key_exchange_records(&[]).iter() {
record.write(&mut buffer).unwrap();
}
client.tls_connection.writer().write_all(&buffer).unwrap();
Expand Down Expand Up @@ -3040,7 +3053,7 @@ mod test {
#[test]
fn test_keyexchange_invalid_input() {
let mut buffer = Vec::with_capacity(1024);
for record in NtsRecord::client_key_exchange_records() {
for record in NtsRecord::client_key_exchange_records(&[]).iter() {
record.write(&mut buffer).unwrap();
}

Expand Down
22 changes: 18 additions & 4 deletions ntpd/src/daemon/keyexchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,20 @@ pub(crate) async fn key_exchange_client(
let socket = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?;
let config = build_client_config(extra_certificates)?;

BoundKeyExchangeClient::new(socket, server_name, config)?.await
BoundKeyExchangeClient::new(socket, server_name, config, &[])?.await
}

#[cfg(feature = "unstable_nts-pool")]
pub(crate) async fn key_exchange_client_with_denied_servers(
server_name: String,
port: u16,
extra_certificates: &[Certificate],
denied_servers: &[&str],
) -> Result<KeyExchangeResult, KeyExchangeError> {
let socket = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?;
let config = build_client_config(extra_certificates)?;

BoundKeyExchangeClient::new(socket, server_name, config, denied_servers)?.await
}

pub fn spawn(
Expand Down Expand Up @@ -205,11 +218,12 @@ where
io: IO,
server_name: String,
config: rustls::ClientConfig,
denied_servers: &[&str],
) -> Result<Self, KeyExchangeError> {
Ok(Self {
inner: Some(BoundKeyExchangeClientData {
io,
client: KeyExchangeClient::new(server_name, config)?,
client: KeyExchangeClient::new(server_name, config, denied_servers)?,
need_flush: false,
}),
})
Expand Down Expand Up @@ -668,7 +682,7 @@ mod tests {

fn client_key_exchange_message_length() -> usize {
let mut buffer = Vec::with_capacity(1024);
for record in ntp_proto::NtsRecord::client_key_exchange_records() {
for record in ntp_proto::NtsRecord::client_key_exchange_records(&[]).iter() {
record.write(&mut buffer).unwrap();
}

Expand Down Expand Up @@ -868,7 +882,7 @@ mod tests {

#[tokio::test]
async fn server_expected_client_records() {
let records = NtsRecord::client_key_exchange_records().to_vec();
let records = NtsRecord::client_key_exchange_records(&[]).to_vec();
let result = send_records_to_server(records).await;

assert!(result.is_ok());
Expand Down
11 changes: 9 additions & 2 deletions ntpd/src/daemon/spawn/nts_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use thiserror::Error;
use tokio::sync::mpsc;
use tracing::warn;

use super::super::{config::NtsPoolPeerConfig, keyexchange::key_exchange_client};
use super::super::{
config::NtsPoolPeerConfig, keyexchange::key_exchange_client_with_denied_servers,
};

use super::{BasicSpawner, PeerId, PeerRemovedEvent, SpawnAction, SpawnEvent, SpawnerId};

Expand Down Expand Up @@ -90,10 +92,15 @@ impl NtsPoolSpawner {
loop {
// Try and add peers to our pool
while self.current_peers.len() < self.config.max_peers {
match key_exchange_client(
match key_exchange_client_with_denied_servers(
self.config.addr.server_name.clone(),
self.config.addr.port,
&self.config.certificate_authorities,
&self
.current_peers
.iter()
.map(|peer| peer.remote.as_str())
.collect::<Vec<_>>(),
)
.await
{
Expand Down

0 comments on commit d497795

Please sign in to comment.