diff --git a/Cargo.lock b/Cargo.lock index bb21721f7..2da97d482 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -402,11 +402,14 @@ dependencies = [ name = "nts-pool-ke" version = "1.1.0-alpha.20231123" dependencies = [ + "ntp-proto", "rustls", + "rustls-native-certs", "rustls-pemfile", "serde", "thiserror", "tokio", + "tokio-rustls", "toml", "tracing", "tracing-subscriber", diff --git a/ntp-proto/src/lib.rs b/ntp-proto/src/lib.rs index 45b050410..b5678b7a0 100644 --- a/ntp-proto/src/lib.rs +++ b/ntp-proto/src/lib.rs @@ -20,6 +20,8 @@ mod peer; mod system; mod time_types; +#[cfg(feature = "nts-pool")] +mod nts_pool_ke; #[cfg(feature = "nts-pool")] pub mod tls_utils; @@ -77,6 +79,12 @@ mod exports { pub mod v5 { pub use crate::packet::v5::server_reference_id::{BloomFilter, ServerId}; } + + #[cfg(feature = "nts-pool")] + pub use super::nts_pool_ke::{ + ClientToPoolData, ClientToPoolDecoder, PoolToServerData, PoolToServerDecoder, + SupportedAlgorithmsDecoder, + }; } #[cfg(feature = "__internal-api")] diff --git a/ntp-proto/src/nts_pool_ke.rs b/ntp-proto/src/nts_pool_ke.rs new file mode 100644 index 000000000..f4526ac93 --- /dev/null +++ b/ntp-proto/src/nts_pool_ke.rs @@ -0,0 +1,351 @@ +use std::ops::ControlFlow; + +use crate::{ + nts_record::{AeadAlgorithm, NtsKeys, ProtocolId}, + KeyExchangeError, NtsRecord, NtsRecordDecoder, +}; + +/// Pool KE decoding records reserved from an NTS KE +#[derive(Debug, Default)] +pub struct SupportedAlgorithmsDecoder { + decoder: NtsRecordDecoder, + supported_algorithms: Vec<(u16, u16)>, +} + +impl SupportedAlgorithmsDecoder { + pub fn step_with_slice( + mut self, + bytes: &[u8], + ) -> ControlFlow, KeyExchangeError>, Self> { + self.decoder.extend(bytes.iter().copied()); + + loop { + match self.decoder.step() { + Err(e) => return ControlFlow::Break(Err(e.into())), + Ok(Some(record)) => self = self.step_with_record(record)?, + Ok(None) => return ControlFlow::Continue(self), + } + } + } + + #[inline(always)] + fn step_with_record( + self, + record: NtsRecord, + ) -> ControlFlow, KeyExchangeError>, Self> { + use ControlFlow::{Break, Continue}; + use NtsRecord::*; + + let mut state = self; + + match record { + EndOfMessage => Break(Ok(state.supported_algorithms)), + Error { errorcode } => Break(Err(KeyExchangeError::from_error_code(errorcode))), + Warning { warningcode } => { + tracing::warn!(warningcode, "Received key exchange warning code"); + + Continue(state) + } + #[cfg(feature = "nts-pool")] + SupportedAlgorithmList { + supported_algorithms, + } => { + state.supported_algorithms = supported_algorithms; + + Continue(state) + } + + _ => Continue(state), + } + } +} + +/// Pool KE decoding records fron the client +#[derive(Debug, Default)] +pub struct ClientToPoolDecoder { + decoder: NtsRecordDecoder, + /// AEAD algorithm that the client is able to use and that we support + /// it may be that the server and client supported algorithms have no + /// intersection! + algorithm: AeadAlgorithm, + /// Protocol (NTP version) that is supported by both client and server + protocol: ProtocolId, + + records: Vec, + denied_servers: Vec, + + #[cfg(feature = "ntpv5")] + allow_v5: bool, +} + +#[derive(Debug)] +pub struct ClientToPoolData { + pub algorithm: AeadAlgorithm, + pub protocol: ProtocolId, + pub records: Vec, + pub denied_servers: Vec, +} + +impl ClientToPoolData { + pub fn extract_nts_keys( + &self, + stream: &rustls::ConnectionCommon, + ) -> Result { + self.algorithm + .extract_nts_keys(self.protocol, stream) + .map_err(KeyExchangeError::Tls) + } +} + +impl ClientToPoolDecoder { + pub fn step_with_slice( + mut self, + bytes: &[u8], + ) -> ControlFlow, Self> { + self.decoder.extend(bytes.iter().copied()); + + loop { + match self.decoder.step() { + Err(e) => return ControlFlow::Break(Err(e.into())), + Ok(Some(record)) => self = self.step_with_record(record)?, + Ok(None) => return ControlFlow::Continue(self), + } + } + } + + #[inline(always)] + fn step_with_record( + self, + record: NtsRecord, + ) -> ControlFlow, Self> { + use self::AeadAlgorithm as Algorithm; + use ControlFlow::{Break, Continue}; + use KeyExchangeError::*; + use NtsRecord::*; + + let mut state = self; + + match record { + EndOfMessage => { + // NOTE EndOfMessage not pushed onto the vector + + let result = ClientToPoolData { + algorithm: state.algorithm, + protocol: state.protocol, + records: state.records, + denied_servers: state.denied_servers, + }; + + Break(Ok(result)) + } + Error { errorcode } => { + // + Break(Err(KeyExchangeError::from_error_code(errorcode))) + } + Warning { warningcode } => { + tracing::debug!(warningcode, "Received key exchange warning code"); + + state.records.push(record); + Continue(state) + } + #[cfg(feature = "ntpv5")] + DraftId { data } => { + if data == crate::packet::v5::DRAFT_VERSION.as_bytes() { + state.allow_v5 = true; + } + Continue(state) + } + NextProtocol { protocol_ids } => { + #[cfg(feature = "ntpv5")] + let selected = if state.allow_v5 { + protocol_ids + .iter() + .copied() + .find_map(ProtocolId::try_deserialize_v5) + } else { + protocol_ids + .iter() + .copied() + .find_map(ProtocolId::try_deserialize) + }; + + #[cfg(not(feature = "ntpv5"))] + let selected = protocol_ids + .iter() + .copied() + .find_map(ProtocolId::try_deserialize); + + match selected { + None => Break(Err(NoValidProtocol)), + Some(protocol) => { + state.protocol = protocol; + Continue(state) + } + } + } + AeadAlgorithm { algorithm_ids, .. } => { + let selected = algorithm_ids + .iter() + .copied() + .find_map(Algorithm::try_deserialize); + + match selected { + None => Break(Err(NoValidAlgorithm)), + Some(algorithm) => { + state.algorithm = algorithm; + Continue(state) + } + } + } + + #[cfg(feature = "nts-pool")] + NtpServerDeny { _denied } => { + state.denied_servers.push(_denied); + Continue(state) + } + + other => { + // just forward other records blindly + state.records.push(other); + Continue(state) + } + } + } +} + +/// Pool KE decoding records from the NTS KE +#[derive(Debug, Default)] +pub struct PoolToServerDecoder { + decoder: NtsRecordDecoder, + /// AEAD algorithm that the client is able to use and that we support + /// it may be that the server and client supported algorithms have no + /// intersection! + algorithm: AeadAlgorithm, + /// Protocol (NTP version) that is supported by both client and server + protocol: ProtocolId, + + records: Vec, + + #[cfg(feature = "ntpv5")] + allow_v5: bool, +} + +#[derive(Debug)] +pub struct PoolToServerData { + pub algorithm: AeadAlgorithm, + pub protocol: ProtocolId, + pub records: Vec, +} + +impl PoolToServerDecoder { + pub fn step_with_slice( + mut self, + bytes: &[u8], + ) -> ControlFlow, Self> { + self.decoder.extend(bytes.iter().copied()); + + loop { + match self.decoder.step() { + Err(e) => return ControlFlow::Break(Err(e.into())), + Ok(Some(record)) => self = self.step_with_record(record)?, + Ok(None) => return ControlFlow::Continue(self), + } + } + } + + #[inline(always)] + fn step_with_record( + self, + record: NtsRecord, + ) -> ControlFlow, Self> { + use self::AeadAlgorithm as Algorithm; + use ControlFlow::{Break, Continue}; + use KeyExchangeError::*; + use NtsRecord::*; + + let mut state = self; + + match &record { + EndOfMessage => { + state.records.push(EndOfMessage); + + let result = PoolToServerData { + algorithm: state.algorithm, + protocol: state.protocol, + records: state.records, + }; + + Break(Ok(result)) + } + Error { errorcode } => { + // + Break(Err(KeyExchangeError::from_error_code(*errorcode))) + } + Warning { warningcode } => { + tracing::debug!(warningcode, "Received key exchange warning code"); + + state.records.push(record); + Continue(state) + } + #[cfg(feature = "ntpv5")] + DraftId { data } => { + if data == crate::packet::v5::DRAFT_VERSION.as_bytes() { + state.allow_v5 = true; + } + Continue(state) + } + NextProtocol { protocol_ids } => { + #[cfg(feature = "ntpv5")] + let selected = if state.allow_v5 { + protocol_ids + .iter() + .copied() + .find_map(ProtocolId::try_deserialize_v5) + } else { + protocol_ids + .iter() + .copied() + .find_map(ProtocolId::try_deserialize) + }; + + #[cfg(not(feature = "ntpv5"))] + let selected = protocol_ids + .iter() + .copied() + .find_map(ProtocolId::try_deserialize); + + state.records.push(record); + + match selected { + None => Break(Err(NoValidProtocol)), + Some(protocol) => { + state.protocol = protocol; + Continue(state) + } + } + } + AeadAlgorithm { algorithm_ids, .. } => { + let selected = algorithm_ids + .iter() + .copied() + .find_map(Algorithm::try_deserialize); + + state.records.push(record); + + match selected { + None => Break(Err(NoValidAlgorithm)), + Some(algorithm) => { + state.algorithm = algorithm; + Continue(state) + } + } + } + + _other => { + // just forward other records blindly + state.records.push(record); + Continue(state) + } + } + } +} diff --git a/ntp-proto/src/nts_record.rs b/ntp-proto/src/nts_record.rs index 9dd2514a7..930cf3e2d 100644 --- a/ntp-proto/src/nts_record.rs +++ b/ntp-proto/src/nts_record.rs @@ -642,7 +642,7 @@ pub enum KeyExchangeError { } impl KeyExchangeError { - fn from_error_code(error_code: u16) -> Self { + pub(crate) fn from_error_code(error_code: u16) -> Self { match error_code { 0 => Self::UnrecognizedCriticalRecord, 1 => Self::BadRequest, @@ -732,7 +732,7 @@ impl AeadAlgorithm { const IN_ORDER_OF_PREFERENCE: &'static [Self] = &[Self::AeadAesSivCmac512, Self::AeadAesSivCmac256]; - fn extract_nts_keys( + pub(crate) fn extract_nts_keys( &self, protocol: ProtocolId, tls_connection: &rustls::ConnectionCommon, @@ -760,17 +760,17 @@ impl AeadAlgorithm { } #[cfg(feature = "nts-pool")] - fn try_into_nts_keys(&self, RequestedKeys { c2s, s2c }: RequestedKeys) -> Option { + fn try_into_nts_keys(&self, RequestedKeys { c2s, s2c }: &RequestedKeys) -> Option { match self { AeadAlgorithm::AeadAesSivCmac256 => { - let c2s = Box::new(AesSivCmac256::from_key_bytes(&c2s).ok()?); - let s2c = Box::new(AesSivCmac256::from_key_bytes(&s2c).ok()?); + let c2s = Box::new(AesSivCmac256::from_key_bytes(c2s).ok()?); + let s2c = Box::new(AesSivCmac256::from_key_bytes(s2c).ok()?); Some(NtsKeys { c2s, s2c }) } AeadAlgorithm::AeadAesSivCmac512 => { - let c2s = Box::new(AesSivCmac512::from_key_bytes(&c2s).ok()?); - let s2c = Box::new(AesSivCmac512::from_key_bytes(&s2c).ok()?); + let c2s = Box::new(AesSivCmac512::from_key_bytes(c2s).ok()?); + let s2c = Box::new(AesSivCmac512::from_key_bytes(s2c).ok()?); Some(NtsKeys { c2s, s2c }) } @@ -791,6 +791,25 @@ pub struct NtsKeys { s2c: Box, } +impl NtsKeys { + #[cfg(feature = "nts-pool")] + pub fn as_fixed_key_request(&self) -> NtsRecord { + NtsRecord::FixedKeyRequest { + c2s: self.c2s.key_bytes().to_vec(), + s2c: self.s2c.key_bytes().to_vec(), + } + } +} + +impl std::fmt::Debug for NtsKeys { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NtsKeys") + .field("c2s", &"") + .field("s2c", &"") + .finish() + } +} + fn extract_nts_key, ConnectionData>( tls_connection: &rustls::ConnectionCommon, context: [u8; 5], @@ -806,7 +825,7 @@ fn extract_nts_key, ConnectionData>( } #[derive(Debug, PartialEq, Eq)] -struct PartialKeyExchangeData { +pub struct PartialKeyExchangeData { remote: Option, port: Option, protocol: ProtocolId, @@ -817,7 +836,7 @@ struct PartialKeyExchangeData { } #[derive(Debug, Default)] -struct KeyExchangeResultDecoder { +pub struct KeyExchangeResultDecoder { decoder: NtsRecordDecoder, remote: Option, port: Option, @@ -1056,8 +1075,8 @@ impl KeyExchangeClient { if let Err(e) = self.tls_connection.process_new_packets() { return ControlFlow::Break(Err(e.into())); } - let read_result = self.tls_connection.reader().read(&mut buf); - match read_result { + + match self.tls_connection.reader().read(&mut buf) { Ok(0) => return ControlFlow::Break(Err(KeyExchangeError::IncompleteResponse)), Ok(n) => { self.decoder = match self.decoder.step_with_slice(&buf[..n]) { @@ -1491,7 +1510,7 @@ impl KeyExchangeServer { return ControlFlow::Break(Err(e.into())); } - let mut buf = [0; 128]; + let mut buf = [0; 512]; match self.tls_connection.reader().read(&mut buf) { Ok(0) => { // the connection was closed cleanly by the client @@ -1504,6 +1523,16 @@ impl KeyExchangeServer { ControlFlow::Continue(decoder) => { // more bytes are needed self.state = State::Active { decoder }; + + // recursively invoke the progress function. This is very unlikely! + // + // Normally, all records are written with a single write call, and + // received as one unit. Using many write calls does not really make + // sense for a client. + // + // So then, the other reason we could end up here is if the buffer is + // full. But 512 bytes is a lot of space for this interaction, and + // should be sufficient in most cases. ControlFlow::Continue(self) } ControlFlow::Break(Ok(data)) => { @@ -1570,8 +1599,8 @@ impl KeyExchangeServer { } #[cfg(feature = "nts-pool")] - fn extract_nts_keys(&self, data: ServerKeyExchangeData) -> Result { - if let Some(keys) = data.fixed_keys { + fn extract_nts_keys(&self, data: &ServerKeyExchangeData) -> Result { + if let Some(keys) = &data.fixed_keys { if self.privileged_connection() { tracing::debug!("using fixed keys for AEAD algorithm"); data.algorithm @@ -1587,13 +1616,13 @@ impl KeyExchangeServer { } #[cfg(not(feature = "nts-pool"))] - fn extract_nts_keys(&self, data: ServerKeyExchangeData) -> Result { + fn extract_nts_keys(&self, data: &ServerKeyExchangeData) -> Result { self.extract_nts_keys_tls(data) } fn extract_nts_keys_tls( &self, - data: ServerKeyExchangeData, + data: &ServerKeyExchangeData, ) -> Result { tracing::debug!("using AEAD keys extracted from TLS connection"); @@ -1614,7 +1643,7 @@ impl KeyExchangeServer { tracing::debug!(?protocol, ?algorithm, "selected AEAD algorithm"); - match self.extract_nts_keys(data) { + match self.extract_nts_keys(&data) { Ok(keys) => { let records = NtsRecord::server_key_exchange_records( protocol, diff --git a/ntp-proto/src/packet/extension_fields.rs b/ntp-proto/src/packet/extension_fields.rs index 8bd5ef717..8c2e2bf53 100644 --- a/ntp-proto/src/packet/extension_fields.rs +++ b/ntp-proto/src/packet/extension_fields.rs @@ -772,7 +772,7 @@ impl<'a> RawEncryptedField<'a> { Err(_) => { return Err(ParsingError::DecryptError( ExtensionField::InvalidNtsEncryptedField, - )) + )); } }; diff --git a/ntpd/Cargo.toml b/ntpd/Cargo.toml index 6ea94975e..31ecf6a54 100644 --- a/ntpd/Cargo.toml +++ b/ntpd/Cargo.toml @@ -37,6 +37,7 @@ ntp-proto = { workspace = true, features = ["__internal-test",] } tokio-rustls.workspace = true [features] +default = [] hardware-timestamping = [] __internal-fuzz = [] unstable_ntpv5 = ["ntp-proto/ntpv5"] diff --git a/nts-pool-ke/Cargo.toml b/nts-pool-ke/Cargo.toml index 18db28d46..a8775b554 100644 --- a/nts-pool-ke/Cargo.toml +++ b/nts-pool-ke/Cargo.toml @@ -17,8 +17,11 @@ tracing.workspace = true tracing-subscriber = { version = "0.3.0", default-features = false, features = ["std", "fmt", "ansi"] } rustls.workspace = true rustls-pemfile.workspace = true +rustls-native-certs.workspace = true serde.workspace = true thiserror.workspace = true +ntp-proto = { workspace = true, features = ["nts-pool"] } +tokio-rustls = "0.24.1" [[bin]] name = "nts-pool-ke" diff --git a/nts-pool-ke/src/config.rs b/nts-pool-ke/src/config.rs index b752df07f..53d343210 100644 --- a/nts-pool-ke/src/config.rs +++ b/nts-pool-ke/src/config.rs @@ -62,6 +62,7 @@ pub struct ObservabilityConfig { #[derive(Debug, PartialEq, Eq, Clone, Deserialize)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] pub struct NtsPoolKeConfig { + pub certificate_authority_path: PathBuf, pub certificate_chain_path: PathBuf, pub private_key_path: PathBuf, #[serde(default = "default_nts_ke_timeout")] @@ -83,18 +84,22 @@ mod tests { r#" [nts-pool-ke-server] listen = "0.0.0.0:4460" + certificate-authority-path = "/foo/bar/ca.pem" certificate-chain-path = "/foo/bar/baz.pem" private-key-path = "spam.der" "#, ) .unwrap(); - let pem = PathBuf::from("/foo/bar/baz.pem"); - assert_eq!(test.nts_pool_ke_server.certificate_chain_path, pem); - assert_eq!( - test.nts_pool_ke_server.private_key_path, - PathBuf::from("spam.der") - ); + let ca = PathBuf::from("/foo/bar/ca.pem"); + assert_eq!(test.nts_pool_ke_server.certificate_authority_path, ca); + + let chain = PathBuf::from("/foo/bar/baz.pem"); + assert_eq!(test.nts_pool_ke_server.certificate_chain_path, chain); + + let private_key = PathBuf::from("spam.der"); + assert_eq!(test.nts_pool_ke_server.private_key_path, private_key); + assert_eq!(test.nts_pool_ke_server.key_exchange_timeout_ms, 1000,); assert_eq!( test.nts_pool_ke_server.listen, diff --git a/nts-pool-ke/src/lib.rs b/nts-pool-ke/src/lib.rs index 7db6b5c81..4b9eee22e 100644 --- a/nts-pool-ke/src/lib.rs +++ b/nts-pool-ke/src/lib.rs @@ -3,11 +3,20 @@ mod config; mod tracing; -use std::{io::BufRead, path::PathBuf, sync::Arc}; +use std::{io::BufRead, ops::ControlFlow, path::PathBuf, sync::Arc}; +use ::tracing::info; use cli::NtsPoolKeOptions; use config::{Config, NtsPoolKeConfig}; -use tokio::net::{TcpListener, ToSocketAddrs}; +use ntp_proto::{ + ClientToPoolData, KeyExchangeError, NtsRecord, PoolToServerData, PoolToServerDecoder, + SupportedAlgorithmsDecoder, +}; +use rustls::Certificate; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, ToSocketAddrs}, +}; use crate::tracing as daemon_tracing; use daemon_tracing::LogLevel; @@ -109,6 +118,14 @@ async fn run(options: NtsPoolKeOptions) -> Result<(), Box } async fn run_nts_pool_ke(nts_pool_ke_config: NtsPoolKeConfig) -> std::io::Result<()> { + let certificate_authority_file = + std::fs::File::open(&nts_pool_ke_config.certificate_authority_path).map_err(|e| { + io_error(&format!( + "error reading certificate_authority_path at `{:?}`: {:?}", + nts_pool_ke_config.certificate_authority_path, e + )) + })?; + let certificate_chain_file = std::fs::File::open(&nts_pool_ke_config.certificate_chain_path) .map_err(|e| { io_error(&format!( @@ -125,7 +142,13 @@ async fn run_nts_pool_ke(nts_pool_ke_config: NtsPoolKeConfig) -> std::io::Result )) })?; - let cert_chain: Vec = + let certificate_authority: Arc<[rustls::Certificate]> = + rustls_pemfile::certs(&mut std::io::BufReader::new(certificate_authority_file))? + .into_iter() + .map(rustls::Certificate) + .collect(); + + let certificate_chain: Vec = rustls_pemfile::certs(&mut std::io::BufReader::new(certificate_chain_file))? .into_iter() .map(rustls::Certificate) @@ -136,7 +159,8 @@ async fn run_nts_pool_ke(nts_pool_ke_config: NtsPoolKeConfig) -> std::io::Result pool_key_exchange_server( nts_pool_ke_config.listen, - cert_chain, + certificate_authority, + certificate_chain, private_key, nts_pool_ke_config.key_exchange_timeout_ms, ) @@ -149,6 +173,7 @@ fn io_error(msg: &str) -> std::io::Error { async fn pool_key_exchange_server( address: impl ToSocketAddrs, + certificate_authority: Arc<[rustls::Certificate]>, certificate_chain: Vec, private_key: rustls::PrivateKey, timeout_ms: u64, @@ -167,19 +192,11 @@ async fn pool_key_exchange_server( let config = Arc::new(config); loop { - let (stream, peer_address) = listener.accept().await?; - - let config = config.clone(); + let (client_stream, peer_address) = listener.accept().await?; + let client_to_pool_config = config.clone(); - let fut = async move { - // BoundKeyExchangeServer::run(stream, config) - // .await - // .map_err(|ke_error| std::io::Error::new(std::io::ErrorKind::Other, ke_error)) - let _ = stream; - let _ = config; - - std::io::Result::Ok(()) - }; + let certificate_authority = certificate_authority.clone(); + let fut = handle_client(client_stream, client_to_pool_config, certificate_authority); tokio::spawn(async move { let timeout = std::time::Duration::from_millis(timeout_ms); @@ -192,6 +209,231 @@ async fn pool_key_exchange_server( } } +async fn handle_client( + client_stream: tokio::net::TcpStream, + config: Arc, + certificate_authority: Arc<[rustls::Certificate]>, +) -> Result<(), KeyExchangeError> { + // handle the initial client to pool + let acceptor = tokio_rustls::TlsAcceptor::from(config); + let mut client_stream = acceptor.accept(client_stream).await?; + + // read all records from the client + let client_data = client_to_pool_request(&mut client_stream).await?; + + info!("received records from the client",); + + // next we should pick a server that satisfies the algorithm used and is not denied by the + // client. But this server hardcoded for now. + let server_name = String::from("localhost"); + let port = 8080; + let domain = rustls::ServerName::try_from(server_name.as_str()) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid dnsname"))?; + + let connector = pool_to_server_connector(&certificate_authority)?; + let server_stream = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?; + let mut server_stream = connector.connect(domain.clone(), server_stream).await?; + + info!("established connection to the server"); + + let supported_algorithms = supported_algorithms_request(&mut server_stream).await?; + + info!("received supported algorithms from the NTS KE server"); + + if !supported_algorithms + .iter() + .any(|(algorithm_id, _)| *algorithm_id == client_data.algorithm as u16) + { + // for now, just send back to the client that its algorithms were invalid + // AeadAlgorithm::AeadAesSivCmac256 should always be supported by servers and clients + + let records = [ + NtsRecord::NextProtocol { + protocol_ids: vec![0], + }, + NtsRecord::Error { + errorcode: NtsRecord::BAD_REQUEST, + }, + NtsRecord::EndOfMessage, + ]; + + // now we just forward the response + let mut buffer = Vec::with_capacity(1024); + for record in records { + record.write(&mut buffer)?; + } + + client_stream.write_all(&buffer).await?; + client_stream.shutdown().await?; + + info!("wrote NoValidAlgorithm to client"); + + return Ok(()); + } + + // this is inefficient of course, but spec-compliant: the TLS connection is closed when the server + // receives a EndOfMessage record, so we have to establish a new one. re-using the TCP + // connection runs into issues (seems to leave the server in an invalid state). + let server_stream = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?; + let server_stream = connector.connect(domain, server_stream).await?; + + // get the cookies from the NTS KE server + let records_for_server = prepare_records_for_server(&client_stream, client_data)?; + let records_for_client = cookie_request(server_stream, &records_for_server).await?; + + info!("received cookies from the NTS KE server"); + + // now we just forward the response + let mut buffer = Vec::with_capacity(1024); + for record in records_for_client { + record.write(&mut buffer)?; + } + + client_stream.write_all(&buffer).await?; + client_stream.shutdown().await?; + + info!("wrote records for client"); + + Ok(()) +} + +fn prepare_records_for_server( + client_stream: &tokio_rustls::server::TlsStream, + client_data: ClientToPoolData, +) -> Result, KeyExchangeError> { + let nts_keys = client_data.extract_nts_keys(client_stream.get_ref().1)?; + + let mut records_for_server = client_data.records; + records_for_server.extend([ + NtsRecord::NextProtocol { + protocol_ids: vec![0], + }, + NtsRecord::AeadAlgorithm { + critical: false, + algorithm_ids: vec![client_data.algorithm as u16], + }, + nts_keys.as_fixed_key_request(), + NtsRecord::EndOfMessage, + ]); + + Ok(records_for_server) +} + +fn pool_to_server_connector( + extra_certificates: &[Certificate], +) -> Result { + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs()? { + let cert = rustls::Certificate(cert.0); + roots.add(&cert).map_err(KeyExchangeError::Certificate)?; + } + + for cert in extra_certificates { + roots.add(cert).map_err(KeyExchangeError::Certificate)?; + } + + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth(); + + // already has the FixedKeyRequest record + Ok(tokio_rustls::TlsConnector::from(Arc::new(config))) +} + +async fn client_to_pool_request( + stream: &mut tokio_rustls::server::TlsStream, +) -> Result { + let mut decoder = ntp_proto::ClientToPoolDecoder::default(); + + let mut buf = [0; 1024]; + + loop { + let n = stream.read(&mut buf).await?; + + if n == 0 { + break Err(KeyExchangeError::IncompleteResponse); + } + + decoder = match decoder.step_with_slice(&buf[..n]) { + ControlFlow::Continue(decoder) => decoder, + ControlFlow::Break(done) => break done, + }; + } +} + +async fn cookie_request( + mut stream: tokio_rustls::client::TlsStream, + nts_records: &[NtsRecord], +) -> Result, KeyExchangeError> { + // now we just forward the response + let mut buf = Vec::with_capacity(1024); + for record in nts_records { + record.write(&mut buf)?; + } + + stream.write_all(&buf).await?; + + let mut buf = [0; 1024]; + let mut decoder = PoolToServerDecoder::default(); + + loop { + let n = stream.read(&mut buf).await?; + + if n == 0 { + break Err(KeyExchangeError::IncompleteResponse); + } + + decoder = match decoder.step_with_slice(&buf[..n]) { + ControlFlow::Continue(decoder) => decoder, + ControlFlow::Break(Ok(PoolToServerData { + records, + algorithm: _, + protocol: _, + })) => { + stream.shutdown().await?; + break Ok(records); + } + ControlFlow::Break(Err(error)) => break Err(error), + }; + } +} + +async fn supported_algorithms_request( + stream: &mut tokio_rustls::client::TlsStream, +) -> Result, KeyExchangeError> { + let nts_records = [ + NtsRecord::SupportedAlgorithmList { + supported_algorithms: vec![], + }, + NtsRecord::EndOfMessage, + ]; + + // now we just forward the response + let mut buf = Vec::with_capacity(1024); + for record in nts_records { + record.write(&mut buf)?; + } + + stream.write_all(&buf).await?; + + let mut buf = [0; 1024]; + let mut decoder = SupportedAlgorithmsDecoder::default(); + + loop { + let n = stream.read(&mut buf).await?; + + if n == 0 { + break Err(KeyExchangeError::IncompleteResponse); + } + + decoder = match decoder.step_with_slice(&buf[..n]) { + ControlFlow::Continue(decoder) => decoder, + ControlFlow::Break(result) => break result, + }; + } +} + fn private_key_from_bufread( mut reader: impl BufRead, ) -> std::io::Result> { diff --git a/nts-pool-ke/unsafe.nts.client.toml b/nts-pool-ke/unsafe.nts.client.toml new file mode 100644 index 000000000..fd01fdfd0 --- /dev/null +++ b/nts-pool-ke/unsafe.nts.client.toml @@ -0,0 +1,19 @@ +# part of the test setup for the NTS pool KE. Do not use in production! +# (the private key of the certificate is public!) + +[observability] +# Other values include trace, debug, warn and error +log-level = "info" +observation-path = "/var/run/ntpd-rs/observe" + +# See https://docs.ntpd-rs.pendulum-project.org/man/ntp.toml.5/ on how to set up certificates +[[source]] +mode = "nts" +address = "localhost:4460" +certificate-authority = "test-keys/testca.pem" + +# System parameters used in filtering and steering the clock: +[synchronization] +minimum-agreeing-sources = 1 +single-step-panic-threshold = 10 +startup-step-panic-threshold = { forward = "inf", backward = 86400 } diff --git a/nts-pool-ke/unsafe.nts.server.toml b/nts-pool-ke/unsafe.nts.server.toml new file mode 100644 index 000000000..1c7e6456a --- /dev/null +++ b/nts-pool-ke/unsafe.nts.server.toml @@ -0,0 +1,31 @@ +# part of the test setup for the NTS pool KE. Do not use in production! +# (the private key of the certificate is public!) + +[observability] +# Other values include trace, debug, warn and error +log-level = "info" +observation-path = "/var/run/ntpd-rs/observe" + +# the server will get its time from the NTP pool +[[source]] +mode = "pool" +address = "pool.ntp.org" +count = 4 + +[[server]] +listen = "0.0.0.0:123" + +# System parameters used in filtering and steering the clock: +[synchronization] +minimum-agreeing-sources = 1 +single-step-panic-threshold = 10 +startup-step-panic-threshold = { forward = 0, backward = 86400 } + +# to function as an NTS server, we must also provide key exchange +# uses an unsecure certificate chain! +[[nts-ke-server]] +listen = "0.0.0.0:8080" +certificate-chain-path = "test-keys/end.fullchain.pem" +private-key-path = "test-keys/end.key" +authorized-pool-server-certificates = ["test-keys/testca.pem"] +key-exchange-timeout-ms = 1000 diff --git a/nts-pool-ke/unsafe.pool.toml b/nts-pool-ke/unsafe.pool.toml new file mode 100644 index 000000000..4ec7a1f71 --- /dev/null +++ b/nts-pool-ke/unsafe.pool.toml @@ -0,0 +1,7 @@ +# part of the test setup for the NTS pool KE. Do not use in production! +# (the private key of the certificate is public!) +[nts-pool-ke-server] +listen = "0.0.0.0:4460" +certificate-authority-path = "../test-keys/testca.pem" +certificate-chain-path = "../test-keys/end.fullchain.pem" +private-key-path = "../test-keys/end.key"