From 9a9832a407300763af6e30652ac33bcaab2c94e4 Mon Sep 17 00:00:00 2001 From: Gabriele Baldoni Date: Wed, 24 Apr 2024 14:23:49 +0000 Subject: [PATCH] fix(771): Adding mTLS support in QUIC (#899) * refactor(tls-quic): moving shared code into zenoh-link-commons::tls Signed-off-by: gabrik * fix(mtls-quic): adding support for mTLS in QUIC [no ci] - broken Signed-off-by: gabrik * fix(mtls-quic): using current release of quinn at the cost of some duplicated code Signed-off-by: gabrik * test(quic-mlts): added tests for QUIC with mTLS, using rustls 0.22 to workaround the default CryptoProvider panic Signed-off-by: gabrik * chore: addressing comments Signed-off-by: gabrik * Apply suggestions from code review --------- Signed-off-by: gabrik Co-authored-by: Luca Cominardi --- Cargo.lock | 22 +- io/zenoh-link-commons/Cargo.toml | 35 +- io/zenoh-links/zenoh-link-quic/Cargo.toml | 33 +- io/zenoh-links/zenoh-link-quic/src/lib.rs | 122 +---- io/zenoh-links/zenoh-link-quic/src/unicast.rs | 129 +---- io/zenoh-links/zenoh-link-quic/src/utils.rs | 509 ++++++++++++++++++ io/zenoh-links/zenoh-link-tls/Cargo.toml | 20 +- io/zenoh-links/zenoh-link-tls/src/lib.rs | 157 +----- io/zenoh-links/zenoh-link-tls/src/unicast.rs | 326 +---------- io/zenoh-links/zenoh-link-tls/src/utils.rs | 480 +++++++++++++++++ io/zenoh-transport/Cargo.toml | 1 + io/zenoh-transport/tests/unicast_transport.rs | 234 +++++++- 12 files changed, 1334 insertions(+), 734 deletions(-) create mode 100644 io/zenoh-links/zenoh-link-quic/src/utils.rs create mode 100644 io/zenoh-links/zenoh-link-tls/src/utils.rs diff --git a/Cargo.lock b/Cargo.lock index 16f7b4d1a0..36078d0238 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3118,9 +3118,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.2" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" dependencies = [ "log", "ring 0.17.6", @@ -4041,7 +4041,7 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" dependencies = [ - "rustls 0.22.2", + "rustls 0.22.4", "rustls-pki-types", "tokio", ] @@ -5109,16 +5109,19 @@ name = "zenoh-link-commons" version = "0.11.0-dev" dependencies = [ "async-trait", + "base64 0.21.4", "flume", "futures", - "rustls 0.22.2", + "rustls 0.22.4", "rustls-webpki 0.102.2", "serde", "tokio", "tokio-util", "tracing", + "webpki-roots", "zenoh-buffers", "zenoh-codec", + "zenoh-config", "zenoh-core", "zenoh-protocol", "zenoh-result", @@ -5136,13 +5139,15 @@ dependencies = [ "quinn", "rustls 0.21.7", "rustls-native-certs 0.7.0", - "rustls-pemfile 2.0.0", + "rustls-pemfile 1.0.3", + "rustls-pki-types", "rustls-webpki 0.102.2", "secrecy", "tokio", "tokio-rustls 0.24.1", "tokio-util", "tracing", + "webpki-roots", "zenoh-config", "zenoh-core", "zenoh-link-commons", @@ -5198,7 +5203,7 @@ dependencies = [ "async-trait", "base64 0.21.4", "futures", - "rustls 0.22.2", + "rustls 0.22.4", "rustls-pemfile 2.0.0", "rustls-pki-types", "rustls-webpki 0.102.2", @@ -5516,6 +5521,7 @@ dependencies = [ "zenoh-core", "zenoh-crypto", "zenoh-link", + "zenoh-link-commons", "zenoh-protocol", "zenoh-result", "zenoh-runtime", @@ -5605,6 +5611,6 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/io/zenoh-link-commons/Cargo.toml b/io/zenoh-link-commons/Cargo.toml index f2e10616c1..12b70cad6d 100644 --- a/io/zenoh-link-commons/Cargo.toml +++ b/io/zenoh-link-commons/Cargo.toml @@ -12,16 +12,16 @@ # ZettaScale Zenoh Team, # [package] -rust-version = { workspace = true } -name = "zenoh-link-commons" -version = { workspace = true } -repository = { workspace = true } -homepage = { workspace = true } authors = { workspace = true } -edition = { workspace = true } -license = { workspace = true } categories = { workspace = true } description = "Internal crate for zenoh." +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +name = "zenoh-link-commons" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] @@ -29,18 +29,27 @@ compression = [] [dependencies] async-trait = { workspace = true } +base64 = { workspace = true, optional = true } +flume = { workspace = true } +futures = { workspace = true } rustls = { workspace = true } rustls-webpki = { workspace = true } -flume = { workspace = true } -tracing = {workspace = true} serde = { workspace = true, features = ["default"] } +tokio = { workspace = true, features = [ + "fs", + "io-util", + "net", + "sync", + "time", +] } +tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } +webpki-roots = { workspace = true, optional = true } zenoh-buffers = { workspace = true } zenoh-codec = { workspace = true } +zenoh-config = { workspace = true } zenoh-core = { workspace = true } zenoh-protocol = { workspace = true } zenoh-result = { workspace = true } -zenoh-util = { workspace = true } zenoh-runtime = { workspace = true } -tokio = { workspace = true, features = ["io-util", "net", "fs", "sync", "time"] } -tokio-util = { workspace = true, features = ["rt"] } -futures = { workspace = true } +zenoh-util = { workspace = true } diff --git a/io/zenoh-links/zenoh-link-quic/Cargo.toml b/io/zenoh-links/zenoh-link-quic/Cargo.toml index a10e18fd43..0e1c720d78 100644 --- a/io/zenoh-links/zenoh-link-quic/Cargo.toml +++ b/io/zenoh-links/zenoh-link-quic/Cargo.toml @@ -12,39 +12,46 @@ # ZettaScale Zenoh Team, # [package] -rust-version = { workspace = true } -name = "zenoh-link-quic" -version = { workspace = true } -repository = { workspace = true } -homepage = { workspace = true } authors = { workspace = true } -edition = { workspace = true } -license = { workspace = true } categories = { workspace = true } description = "Internal crate for zenoh." +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +name = "zenoh-link-quic" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] async-trait = { workspace = true } base64 = { workspace = true } futures = { workspace = true } -tracing = {workspace = true} quinn = { workspace = true } rustls-native-certs = { workspace = true } -rustls-pemfile = { workspace = true } +rustls-pki-types = { workspace = true } rustls-webpki = { workspace = true } -secrecy = {workspace = true } -tokio = { workspace = true, features = ["io-util", "net", "fs", "sync", "time"] } +secrecy = { workspace = true } +tokio = { workspace = true, features = [ + "fs", + "io-util", + "net", + "sync", + "time", +] } tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } +webpki-roots = { workspace = true } zenoh-config = { workspace = true } zenoh-core = { workspace = true } zenoh-link-commons = { workspace = true } zenoh-protocol = { workspace = true } zenoh-result = { workspace = true } +zenoh-runtime = { workspace = true } zenoh-sync = { workspace = true } zenoh-util = { workspace = true } -zenoh-runtime = { workspace = true } - # Lock due to quinn not supporting rustls 0.22 yet rustls = { version = "0.21", features = ["dangerous_configuration", "quic"] } tokio-rustls = "0.24.1" +rustls-pemfile = { version = "1" } diff --git a/io/zenoh-links/zenoh-link-quic/src/lib.rs b/io/zenoh-links/zenoh-link-quic/src/lib.rs index c6d7e16087..0c9bc7365e 100644 --- a/io/zenoh-links/zenoh-link-quic/src/lib.rs +++ b/io/zenoh-links/zenoh-link-quic/src/lib.rs @@ -18,25 +18,17 @@ //! //! [Click here for Zenoh's documentation](../zenoh/index.html) use async_trait::async_trait; -use config::{ - TLS_ROOT_CA_CERTIFICATE_BASE64, TLS_ROOT_CA_CERTIFICATE_FILE, TLS_SERVER_CERTIFICATE_BASE64, - TLS_SERVER_CERTIFICATE_FILE, TLS_SERVER_NAME_VERIFICATION, TLS_SERVER_PRIVATE_KEY_BASE64, - TLS_SERVER_PRIVATE_KEY_FILE, -}; -use secrecy::ExposeSecret; -use std::net::SocketAddr; -use zenoh_config::Config; + use zenoh_core::zconfigurable; -use zenoh_link_commons::{ConfigurationInspector, LocatorInspector}; -use zenoh_protocol::core::{ - endpoint::{Address, Parameters}, - Locator, -}; -use zenoh_result::{bail, zerror, ZResult}; +use zenoh_link_commons::LocatorInspector; +use zenoh_protocol::core::Locator; +use zenoh_result::ZResult; mod unicast; +mod utils; mod verify; pub use unicast::*; +pub use utils::TlsConfigurator as QuicConfigurator; // Default ALPN protocol pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; @@ -64,77 +56,6 @@ impl LocatorInspector for QuicLocatorInspector { } } -#[derive(Default, Clone, Copy, Debug)] -pub struct QuicConfigurator; - -impl ConfigurationInspector for QuicConfigurator { - fn inspect_config(&self, config: &Config) -> ZResult { - let mut ps: Vec<(&str, &str)> = vec![]; - - let c = config.transport().link().tls(); - - match (c.root_ca_certificate(), c.root_ca_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'root_ca_certificate' and 'root_ca_certificate_base64' can be present!") - } - (Some(ca_certificate), None) => { - ps.push((TLS_ROOT_CA_CERTIFICATE_FILE, ca_certificate)); - } - (None, Some(ca_certificate)) => { - ps.push(( - TLS_ROOT_CA_CERTIFICATE_BASE64, - ca_certificate.expose_secret(), - )); - } - _ => {} - } - - match (c.server_private_key(), c.server_private_key_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'server_private_key' and 'server_private_key_base64' can be present!") - } - (Some(server_private_key), None) => { - ps.push((TLS_SERVER_PRIVATE_KEY_FILE, server_private_key)); - } - (None, Some(server_private_key)) => { - ps.push(( - TLS_SERVER_PRIVATE_KEY_BASE64, - server_private_key.expose_secret(), - )); - } - _ => {} - } - - match (c.server_certificate(), c.server_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'server_certificate' and 'server_certificate_base64' can be present!") - } - (Some(server_certificate), None) => { - ps.push((TLS_SERVER_CERTIFICATE_FILE, server_certificate)); - } - (None, Some(server_certificate)) => { - ps.push(( - TLS_SERVER_CERTIFICATE_BASE64, - server_certificate.expose_secret(), - )); - } - _ => {} - } - - if let Some(server_name_verification) = c.server_name_verification() { - match server_name_verification { - true => ps.push((TLS_SERVER_NAME_VERIFICATION, "true")), - false => ps.push((TLS_SERVER_NAME_VERIFICATION, "false")), - }; - } - - let mut s = String::new(); - Parameters::extend(ps.drain(..), &mut s); - - Ok(s) - } -} - zconfigurable! { // Default MTU (QUIC PDU) in bytes. static ref QUIC_DEFAULT_MTU: u16 = QUIC_MAX_MTU; @@ -157,25 +78,20 @@ pub mod config { pub const TLS_SERVER_PRIVATE_KEY_RAW: &str = "server_private_key_raw"; pub const TLS_SERVER_PRIVATE_KEY_BASE64: &str = "server_private_key_base64"; - pub const TLS_SERVER_CERTIFICATE_FILE: &str = "tls_server_certificate_file"; - pub const TLS_SERVER_CERTIFICATE_RAW: &str = "tls_server_certificate_raw"; - pub const TLS_SERVER_CERTIFICATE_BASE64: &str = "tls_server_certificate_base64"; + pub const TLS_SERVER_CERTIFICATE_FILE: &str = "server_certificate_file"; + pub const TLS_SERVER_CERTIFICATE_RAW: &str = "server_certificate_raw"; + pub const TLS_SERVER_CERTIFICATE_BASE64: &str = "server_certificate_base64"; - pub const TLS_SERVER_NAME_VERIFICATION: &str = "server_name_verification"; - pub const TLS_SERVER_NAME_VERIFICATION_DEFAULT: &str = "true"; -} + pub const TLS_CLIENT_PRIVATE_KEY_FILE: &str = "client_private_key_file"; + pub const TLS_CLIENT_PRIVATE_KEY_RAW: &str = "client_private_key_raw"; + pub const TLS_CLIENT_PRIVATE_KEY_BASE64: &str = "client_private_key_base64"; -async fn get_quic_addr(address: &Address<'_>) -> ZResult { - match tokio::net::lookup_host(address.as_str()).await?.next() { - Some(addr) => Ok(addr), - None => bail!("Couldn't resolve QUIC locator address: {}", address), - } -} + pub const TLS_CLIENT_CERTIFICATE_FILE: &str = "client_certificate_file"; + pub const TLS_CLIENT_CERTIFICATE_RAW: &str = "client_certificate_raw"; + pub const TLS_CLIENT_CERTIFICATE_BASE64: &str = "client_certificate_base64"; -pub fn base64_decode(data: &str) -> ZResult> { - use base64::engine::general_purpose; - use base64::Engine; - Ok(general_purpose::STANDARD - .decode(data) - .map_err(|e| zerror!("Unable to perform base64 decoding: {e:?}"))?) + pub const TLS_CLIENT_AUTH: &str = "client_auth"; + + pub const TLS_SERVER_NAME_VERIFICATION: &str = "server_name_verification"; + pub const TLS_SERVER_NAME_VERIFICATION_DEFAULT: &str = "true"; } diff --git a/io/zenoh-links/zenoh-link-quic/src/unicast.rs b/io/zenoh-links/zenoh-link-quic/src/unicast.rs index 8fd7777137..452fd8a122 100644 --- a/io/zenoh-links/zenoh-link-quic/src/unicast.rs +++ b/io/zenoh-links/zenoh-link-quic/src/unicast.rs @@ -12,16 +12,13 @@ // ZettaScale Zenoh Team, // -use crate::base64_decode; use crate::{ - config::*, get_quic_addr, verify::WebPkiVerifierAnyServerName, ALPN_QUIC_HTTP, - QUIC_ACCEPT_THROTTLE_TIME, QUIC_DEFAULT_MTU, QUIC_LOCATOR_PREFIX, + config::*, + utils::{get_quic_addr, TlsClientConfig, TlsServerConfig}, + ALPN_QUIC_HTTP, QUIC_ACCEPT_THROTTLE_TIME, QUIC_DEFAULT_MTU, QUIC_LOCATOR_PREFIX, }; use async_trait::async_trait; -use rustls::{Certificate, PrivateKey}; -use rustls_pemfile::Item; use std::fmt; -use std::io::BufReader; use std::net::IpAddr; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; @@ -34,7 +31,7 @@ use zenoh_link_commons::{ ListenersUnicastIP, NewLinkChannelSender, }; use zenoh_protocol::core::{EndPoint, Locator}; -use zenoh_result::{bail, zerror, ZError, ZResult}; +use zenoh_result::{bail, zerror, ZResult}; pub struct LinkUnicastQuic { connection: quinn::Connection, @@ -219,55 +216,12 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { } // Initialize the QUIC connection - let mut root_cert_store = rustls::RootCertStore::empty(); - - // Read the certificates - let f = if let Some(value) = epconf.get(TLS_ROOT_CA_CERTIFICATE_RAW) { - value.as_bytes().to_vec() - } else if let Some(b64_certificate) = epconf.get(TLS_ROOT_CA_CERTIFICATE_BASE64) { - base64_decode(b64_certificate)? - } else if let Some(value) = epconf.get(TLS_ROOT_CA_CERTIFICATE_FILE) { - tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))? - } else { - vec![] - }; - - let certificates = if f.is_empty() { - rustls_native_certs::load_native_certs() - .map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))? - .drain(..) - .map(|x| rustls::Certificate(x.to_vec())) - .collect::>() - } else { - rustls_pemfile::certs(&mut BufReader::new(f.as_slice())) - .map(|result| { - result - .map_err(|err| zerror!("Invalid QUIC CA certificate file: {}", err)) - .map(|der| Certificate(der.to_vec())) - }) - .collect::, ZError>>()? - }; - for c in certificates.iter() { - root_cert_store.add(c).map_err(|e| zerror!("{}", e))?; - } - - let client_crypto = rustls::ClientConfig::builder().with_safe_defaults(); - - let mut client_crypto = if server_name_verification { - client_crypto - .with_root_certificates(root_cert_store) - .with_no_client_auth() - } else { - client_crypto - .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( - root_cert_store, - ))) - .with_no_client_auth() - }; + let mut client_crypto = TlsClientConfig::new(&epconf) + .await + .map_err(|e| zerror!("Cannot create a new QUIC client on {addr}: {e}"))?; - client_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); + client_crypto.client_config.alpn_protocols = + ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); let ip_addr: IpAddr = if addr.is_ipv4() { Ipv4Addr::UNSPECIFIED.into() @@ -276,7 +230,9 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { }; let mut quic_endpoint = quinn::Endpoint::client(SocketAddr::new(ip_addr, 0)) .map_err(|e| zerror!("Can not create a new QUIC link bound to {}: {}", host, e))?; - quic_endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(client_crypto))); + quic_endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new( + client_crypto.client_config, + ))); let src_addr = quic_endpoint .local_addr() @@ -314,61 +270,14 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { let addr = get_quic_addr(&epaddr).await?; - let f = if let Some(value) = epconf.get(TLS_SERVER_CERTIFICATE_RAW) { - value.as_bytes().to_vec() - } else if let Some(b64_certificate) = epconf.get(TLS_SERVER_CERTIFICATE_BASE64) { - base64_decode(b64_certificate)? - } else if let Some(value) = epconf.get(TLS_SERVER_CERTIFICATE_FILE) { - tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))? - } else { - bail!("No QUIC CA certificate has been provided."); - }; - let certificates = rustls_pemfile::certs(&mut BufReader::new(f.as_slice())) - .map(|result| { - result - .map_err(|err| zerror!("Invalid QUIC CA certificate file: {}", err)) - .map(|der| Certificate(der.to_vec())) - }) - .collect::, ZError>>()?; - - // Private keys - let f = if let Some(value) = epconf.get(TLS_SERVER_PRIVATE_KEY_RAW) { - value.as_bytes().to_vec() - } else if let Some(b64_key) = epconf.get(TLS_SERVER_PRIVATE_KEY_BASE64) { - base64_decode(b64_key)? - } else if let Some(value) = epconf.get(TLS_SERVER_PRIVATE_KEY_FILE) { - tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))? - } else { - bail!("No QUIC CA private key has been provided."); - }; - let items: Vec = rustls_pemfile::read_all(&mut BufReader::new(f.as_slice())) - .collect::>() - .map_err(|err| zerror!("Invalid QUIC CA private key file: {}", err))?; - - let private_key = items - .into_iter() - .filter_map(|x| match x { - rustls_pemfile::Item::Pkcs1Key(k) => Some(k.secret_pkcs1_der().to_vec()), - rustls_pemfile::Item::Pkcs8Key(k) => Some(k.secret_pkcs8_der().to_vec()), - rustls_pemfile::Item::Sec1Key(k) => Some(k.secret_sec1_der().to_vec()), - _ => None, - }) - .take(1) - .next() - .ok_or_else(|| zerror!("No QUIC CA private key has been provided.")) - .map(PrivateKey)?; - // Server config - let mut server_crypto = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certificates, private_key)?; - server_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); - let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + let mut server_crypto = TlsServerConfig::new(&epconf) + .await + .map_err(|e| zerror!("Cannot create a new QUIC listener on {addr}: {e}"))?; + server_crypto.server_config.alpn_protocols = + ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); + let mut server_config = + quinn::ServerConfig::with_crypto(Arc::new(server_crypto.server_config)); // We do not accept unidireactional streams. Arc::get_mut(&mut server_config.transport) diff --git a/io/zenoh-links/zenoh-link-quic/src/utils.rs b/io/zenoh-links/zenoh-link-quic/src/utils.rs new file mode 100644 index 0000000000..40367599cb --- /dev/null +++ b/io/zenoh-links/zenoh-link-quic/src/utils.rs @@ -0,0 +1,509 @@ +// +// Copyright (c) 2024 ZettaScale Technology +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// +use crate::config::*; +use crate::verify::WebPkiVerifierAnyServerName; +use rustls::OwnedTrustAnchor; +use rustls::{ + server::AllowAnyAuthenticatedClient, version::TLS13, Certificate, ClientConfig, PrivateKey, + RootCertStore, ServerConfig, +}; +use rustls_pki_types::{CertificateDer, TrustAnchor}; +use secrecy::ExposeSecret; +use zenoh_link_commons::ConfigurationInspector; +// use rustls_pki_types::{CertificateDer, PrivateKeyDer, TrustAnchor}; +use std::fs::File; +use std::io; +use std::net::SocketAddr; +use std::{ + io::{BufReader, Cursor}, + sync::Arc, +}; +use webpki::anchor_from_trusted_cert; +use zenoh_config::Config as ZenohConfig; +use zenoh_protocol::core::endpoint::Config; +use zenoh_protocol::core::endpoint::{self, Address}; +use zenoh_result::{bail, zerror, ZError, ZResult}; + +#[derive(Default, Clone, Copy, Debug)] +pub struct TlsConfigurator; + +impl ConfigurationInspector for TlsConfigurator { + fn inspect_config(&self, config: &ZenohConfig) -> ZResult { + let mut ps: Vec<(&str, &str)> = vec![]; + + let c = config.transport().link().tls(); + + match (c.root_ca_certificate(), c.root_ca_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'root_ca_certificate' and 'root_ca_certificate_base64' can be present!") + } + (Some(ca_certificate), None) => { + ps.push((TLS_ROOT_CA_CERTIFICATE_FILE, ca_certificate)); + } + (None, Some(ca_certificate)) => { + ps.push(( + TLS_ROOT_CA_CERTIFICATE_BASE64, + ca_certificate.expose_secret(), + )); + } + _ => {} + } + + match (c.server_private_key(), c.server_private_key_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'server_private_key' and 'server_private_key_base64' can be present!") + } + (Some(server_private_key), None) => { + ps.push((TLS_SERVER_PRIVATE_KEY_FILE, server_private_key)); + } + (None, Some(server_private_key)) => { + ps.push(( + TLS_SERVER_PRIVATE_KEY_BASE64, + server_private_key.expose_secret(), + )); + } + _ => {} + } + + match (c.server_certificate(), c.server_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'server_certificate' and 'server_certificate_base64' can be present!") + } + (Some(server_certificate), None) => { + ps.push((TLS_SERVER_CERTIFICATE_FILE, server_certificate)); + } + (None, Some(server_certificate)) => { + ps.push(( + TLS_SERVER_CERTIFICATE_BASE64, + server_certificate.expose_secret(), + )); + } + _ => {} + } + + if let Some(client_auth) = c.client_auth() { + match client_auth { + true => ps.push((TLS_CLIENT_AUTH, "true")), + false => ps.push((TLS_CLIENT_AUTH, "false")), + }; + } + + match (c.client_private_key(), c.client_private_key_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'client_private_key' and 'client_private_key_base64' can be present!") + } + (Some(client_private_key), None) => { + ps.push((TLS_CLIENT_PRIVATE_KEY_FILE, client_private_key)); + } + (None, Some(client_private_key)) => { + ps.push(( + TLS_CLIENT_PRIVATE_KEY_BASE64, + client_private_key.expose_secret(), + )); + } + _ => {} + } + + match (c.client_certificate(), c.client_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'client_certificate' and 'client_certificate_base64' can be present!") + } + (Some(client_certificate), None) => { + ps.push((TLS_CLIENT_CERTIFICATE_FILE, client_certificate)); + } + (None, Some(client_certificate)) => { + ps.push(( + TLS_CLIENT_CERTIFICATE_BASE64, + client_certificate.expose_secret(), + )); + } + _ => {} + } + + if let Some(server_name_verification) = c.server_name_verification() { + match server_name_verification { + true => ps.push((TLS_SERVER_NAME_VERIFICATION, "true")), + false => ps.push((TLS_SERVER_NAME_VERIFICATION, "false")), + }; + } + + let mut s = String::new(); + endpoint::Parameters::extend(ps.drain(..), &mut s); + + Ok(s) + } +} + +pub(crate) struct TlsServerConfig { + pub(crate) server_config: ServerConfig, +} + +impl TlsServerConfig { + pub async fn new(config: &Config<'_>) -> ZResult { + let tls_server_client_auth: bool = match config.get(TLS_CLIENT_AUTH) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, + None => false, + }; + let tls_server_private_key = TlsServerConfig::load_tls_private_key(config).await?; + let tls_server_certificate = TlsServerConfig::load_tls_certificate(config).await?; + + let certs: Vec = + rustls_pemfile::certs(&mut Cursor::new(&tls_server_certificate)) + .map_err(|err| zerror!("Error processing server certificate: {err}."))? + .into_iter() + .map(Certificate) + .collect(); + + let mut keys: Vec = + rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map_err(|err| zerror!("Error processing server key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + + if keys.is_empty() { + keys = rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map_err(|err| zerror!("Error processing server key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + } + + if keys.is_empty() { + keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map_err(|err| zerror!("Error processing server key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + } + + if keys.is_empty() { + bail!("No private key found for TLS server."); + } + + let sc = if tls_server_client_auth { + let root_cert_store = load_trust_anchors(config)?.map_or_else( + || { + Err(zerror!( + "Missing root certificates while client authentication is enabled." + )) + }, + Ok, + )?; + let client_auth = AllowAnyAuthenticatedClient::new(root_cert_store); + ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&TLS13])? + .with_client_cert_verifier(Arc::new(client_auth)) + .with_single_cert(certs, keys.remove(0)) + .map_err(|e| zerror!(e))? + } else { + ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, keys.remove(0)) + .map_err(|e| zerror!(e))? + }; + Ok(TlsServerConfig { server_config: sc }) + } + + async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { + load_tls_key( + config, + TLS_SERVER_PRIVATE_KEY_RAW, + TLS_SERVER_PRIVATE_KEY_FILE, + TLS_SERVER_PRIVATE_KEY_BASE64, + ) + .await + } + + async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { + load_tls_certificate( + config, + TLS_SERVER_CERTIFICATE_RAW, + TLS_SERVER_CERTIFICATE_FILE, + TLS_SERVER_CERTIFICATE_BASE64, + ) + .await + } +} + +pub(crate) struct TlsClientConfig { + pub(crate) client_config: ClientConfig, +} + +impl TlsClientConfig { + pub async fn new(config: &Config<'_>) -> ZResult { + let tls_client_server_auth: bool = match config.get(TLS_CLIENT_AUTH) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, + None => false, + }; + + let tls_server_name_verification: bool = match config.get(TLS_SERVER_NAME_VERIFICATION) { + Some(s) => { + let s: bool = s + .parse() + .map_err(|_| zerror!("Unknown server name verification argument: {}", s))?; + if s { + tracing::warn!("Skipping name verification of servers"); + } + s + } + None => false, + }; + + // Allows mixed user-generated CA and webPKI CA + tracing::debug!("Loading default Web PKI certificates."); + let mut root_cert_store = RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS + .iter() + .map(|ta| ta.to_owned()) + .map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject.to_vec(), + ta.subject_public_key_info.to_vec(), + ta.name_constraints.map(|nc| nc.to_vec()), + ) + }) + .collect(), + }; + + if let Some(custom_root_cert) = load_trust_anchors(config)? { + tracing::debug!("Loading user-generated certificates."); + root_cert_store.roots.extend(custom_root_cert.roots); + } + + let cc = if tls_client_server_auth { + tracing::debug!("Loading client authentication key and certificate..."); + let tls_client_private_key = TlsClientConfig::load_tls_private_key(config).await?; + let tls_client_certificate = TlsClientConfig::load_tls_certificate(config).await?; + + let certs: Vec = + rustls_pemfile::certs(&mut Cursor::new(&tls_client_certificate)) + .map_err(|err| zerror!("Error processing client certificate: {err}."))? + .into_iter() + .map(Certificate) + .collect(); + + let mut keys: Vec = + rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map_err(|err| zerror!("Error processing client key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + + if keys.is_empty() { + keys = + rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map_err(|err| zerror!("Error processing client key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + } + + if keys.is_empty() { + keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map_err(|err| zerror!("Error processing client key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + } + + if keys.is_empty() { + bail!("No private key found for TLS client."); + } + + let builder = ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&TLS13])?; + + if tls_server_name_verification { + builder + .with_root_certificates(root_cert_store) + .with_client_auth_cert(certs, keys.remove(0)) + } else { + builder + .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( + root_cert_store, + ))) + .with_client_auth_cert(certs, keys.remove(0)) + } + .map_err(|e| zerror!("Bad certificate/key: {}", e))? + } else { + let builder = ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&TLS13])?; + + if tls_server_name_verification { + builder + .with_root_certificates(root_cert_store) + .with_no_client_auth() + } else { + builder + .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( + root_cert_store, + ))) + .with_no_client_auth() + } + }; + Ok(TlsClientConfig { client_config: cc }) + } + + async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { + load_tls_key( + config, + TLS_CLIENT_PRIVATE_KEY_RAW, + TLS_CLIENT_PRIVATE_KEY_FILE, + TLS_CLIENT_PRIVATE_KEY_BASE64, + ) + .await + } + + async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { + load_tls_certificate( + config, + TLS_CLIENT_CERTIFICATE_RAW, + TLS_CLIENT_CERTIFICATE_FILE, + TLS_CLIENT_CERTIFICATE_BASE64, + ) + .await + } +} + +fn process_pem(pem: &mut dyn io::BufRead) -> ZResult> { + let certs: Vec = rustls_pemfile::certs(pem) + .map_err(|err| zerror!("Error processing PEM certificates: {err}."))? + .into_iter() + .map(CertificateDer::from) + .collect(); + + let trust_anchors: Vec = certs + .into_iter() + .map(|cert| { + anchor_from_trusted_cert(&cert) + .map_err(|err| zerror!("Error processing trust anchor: {err}.")) + .map(|trust_anchor| trust_anchor.to_owned()) + }) + .collect::, ZError>>()? + .into_iter() + .map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject.to_vec(), + ta.subject_public_key_info.to_vec(), + ta.name_constraints.map(|nc| nc.to_vec()), + ) + }) + .collect(); + + Ok(trust_anchors) +} + +async fn load_tls_key( + config: &Config<'_>, + tls_private_key_raw_config_key: &str, + tls_private_key_file_config_key: &str, + tls_private_key_base64_config_key: &str, +) -> ZResult> { + if let Some(value) = config.get(tls_private_key_raw_config_key) { + return Ok(value.as_bytes().to_vec()); + } + + if let Some(b64_key) = config.get(tls_private_key_base64_config_key) { + return base64_decode(b64_key); + } + + if let Some(value) = config.get(tls_private_key_file_config_key) { + return Ok(tokio::fs::read(value) + .await + .map_err(|e| zerror!("Invalid TLS private key file: {}", e))?) + .and_then(|result| { + if result.is_empty() { + Err(zerror!("Empty TLS key.").into()) + } else { + Ok(result) + } + }); + } + Err(zerror!("Missing TLS private key.").into()) +} + +async fn load_tls_certificate( + config: &Config<'_>, + tls_certificate_raw_config_key: &str, + tls_certificate_file_config_key: &str, + tls_certificate_base64_config_key: &str, +) -> ZResult> { + if let Some(value) = config.get(tls_certificate_raw_config_key) { + return Ok(value.as_bytes().to_vec()); + } + + if let Some(b64_certificate) = config.get(tls_certificate_base64_config_key) { + return base64_decode(b64_certificate); + } + + if let Some(value) = config.get(tls_certificate_file_config_key) { + return Ok(tokio::fs::read(value) + .await + .map_err(|e| zerror!("Invalid TLS certificate file: {}", e))?); + } + Err(zerror!("Missing tls certificates.").into()) +} + +fn load_trust_anchors(config: &Config<'_>) -> ZResult> { + let mut root_cert_store = RootCertStore::empty(); + if let Some(value) = config.get(TLS_ROOT_CA_CERTIFICATE_RAW) { + let mut pem = BufReader::new(value.as_bytes()); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.roots.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + + if let Some(b64_certificate) = config.get(TLS_ROOT_CA_CERTIFICATE_BASE64) { + let certificate_pem = base64_decode(b64_certificate)?; + let mut pem = BufReader::new(certificate_pem.as_slice()); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.roots.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + + if let Some(filename) = config.get(TLS_ROOT_CA_CERTIFICATE_FILE) { + let mut pem = BufReader::new(File::open(filename)?); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.roots.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + Ok(None) +} + +pub async fn get_quic_addr(address: &Address<'_>) -> ZResult { + match tokio::net::lookup_host(address.as_str()).await?.next() { + Some(addr) => Ok(addr), + None => bail!("Couldn't resolve QUIC locator address: {}", address), + } +} + +pub fn base64_decode(data: &str) -> ZResult> { + use base64::engine::general_purpose; + use base64::Engine; + Ok(general_purpose::STANDARD + .decode(data) + .map_err(|e| zerror!("Unable to perform base64 decoding: {e:?}"))?) +} diff --git a/io/zenoh-links/zenoh-link-tls/Cargo.toml b/io/zenoh-links/zenoh-link-tls/Cargo.toml index 11d00d96d8..91fb72787e 100644 --- a/io/zenoh-links/zenoh-link-tls/Cargo.toml +++ b/io/zenoh-links/zenoh-link-tls/Cargo.toml @@ -12,31 +12,31 @@ # ZettaScale Zenoh Team, # [package] -rust-version = { workspace = true } -name = "zenoh-link-tls" -version = { workspace = true } -repository = { workspace = true } -homepage = { workspace = true } authors = { workspace = true } -edition = { workspace = true } -license = { workspace = true } categories = { workspace = true } description = "Internal crate for zenoh." +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +name = "zenoh-link-tls" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] async-trait = { workspace = true } base64 = { workspace = true } futures = { workspace = true } -tracing = {workspace = true} rustls = { workspace = true } rustls-pemfile = { workspace = true } rustls-pki-types = { workspace = true } rustls-webpki = { workspace = true } -secrecy = {workspace = true } -tokio = { workspace = true, features = ["io-util", "net", "fs", "sync"] } +secrecy = { workspace = true } +tokio = { workspace = true, features = ["fs", "io-util", "net", "sync"] } tokio-rustls = { workspace = true } tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } webpki-roots = { workspace = true } zenoh-config = { workspace = true } zenoh-core = { workspace = true } diff --git a/io/zenoh-links/zenoh-link-tls/src/lib.rs b/io/zenoh-links/zenoh-link-tls/src/lib.rs index 95d59104b4..b9002cc397 100644 --- a/io/zenoh-links/zenoh-link-tls/src/lib.rs +++ b/io/zenoh-links/zenoh-link-tls/src/lib.rs @@ -18,26 +18,15 @@ //! //! [Click here for Zenoh's documentation](../zenoh/index.html) use async_trait::async_trait; -use config::{ - TLS_CLIENT_AUTH, TLS_CLIENT_CERTIFICATE_BASE64, TLS_CLIENT_CERTIFICATE_FILE, - TLS_CLIENT_PRIVATE_KEY_BASE64, TLS_CLIENT_PRIVATE_KEY_FILE, TLS_ROOT_CA_CERTIFICATE_BASE64, - TLS_ROOT_CA_CERTIFICATE_FILE, TLS_SERVER_CERTIFICATE_BASE64, TLS_SERVER_CERTIFICATE_FILE, - TLS_SERVER_NAME_VERIFICATION, TLS_SERVER_PRIVATE_KEY_BASE_64, TLS_SERVER_PRIVATE_KEY_FILE, -}; -use rustls_pki_types::ServerName; -use secrecy::ExposeSecret; -use std::{convert::TryFrom, net::SocketAddr}; -use zenoh_config::Config; use zenoh_core::zconfigurable; -use zenoh_link_commons::{ConfigurationInspector, LocatorInspector}; -use zenoh_protocol::core::{ - endpoint::{self, Address}, - Locator, -}; -use zenoh_result::{bail, zerror, ZResult}; +use zenoh_link_commons::LocatorInspector; +use zenoh_protocol::core::Locator; +use zenoh_result::ZResult; mod unicast; +mod utils; pub use unicast::*; +pub use utils::TlsConfigurator; // Default MTU (TLS PDU) in bytes. // NOTE: Since TLS is a byte-stream oriented transport, theoretically it has @@ -60,115 +49,6 @@ impl LocatorInspector for TlsLocatorInspector { Ok(false) } } -#[derive(Default, Clone, Copy, Debug)] -pub struct TlsConfigurator; - -impl ConfigurationInspector for TlsConfigurator { - fn inspect_config(&self, config: &Config) -> ZResult { - let mut ps: Vec<(&str, &str)> = vec![]; - - let c = config.transport().link().tls(); - - match (c.root_ca_certificate(), c.root_ca_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'root_ca_certificate' and 'root_ca_certificate_base64' can be present!") - } - (Some(ca_certificate), None) => { - ps.push((TLS_ROOT_CA_CERTIFICATE_FILE, ca_certificate)); - } - (None, Some(ca_certificate)) => { - ps.push(( - TLS_ROOT_CA_CERTIFICATE_BASE64, - ca_certificate.expose_secret(), - )); - } - _ => {} - } - - match (c.server_private_key(), c.server_private_key_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'server_private_key' and 'server_private_key_base64' can be present!") - } - (Some(server_private_key), None) => { - ps.push((TLS_SERVER_PRIVATE_KEY_FILE, server_private_key)); - } - (None, Some(server_private_key)) => { - ps.push(( - TLS_SERVER_PRIVATE_KEY_BASE_64, - server_private_key.expose_secret(), - )); - } - _ => {} - } - - match (c.server_certificate(), c.server_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'server_certificate' and 'server_certificate_base64' can be present!") - } - (Some(server_certificate), None) => { - ps.push((TLS_SERVER_CERTIFICATE_FILE, server_certificate)); - } - (None, Some(server_certificate)) => { - ps.push(( - TLS_SERVER_CERTIFICATE_BASE64, - server_certificate.expose_secret(), - )); - } - _ => {} - } - - if let Some(client_auth) = c.client_auth() { - match client_auth { - true => ps.push((TLS_CLIENT_AUTH, "true")), - false => ps.push((TLS_CLIENT_AUTH, "false")), - }; - } - - match (c.client_private_key(), c.client_private_key_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'client_private_key' and 'client_private_key_base64' can be present!") - } - (Some(client_private_key), None) => { - ps.push((TLS_CLIENT_PRIVATE_KEY_FILE, client_private_key)); - } - (None, Some(client_private_key)) => { - ps.push(( - TLS_CLIENT_PRIVATE_KEY_BASE64, - client_private_key.expose_secret(), - )); - } - _ => {} - } - - match (c.client_certificate(), c.client_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'client_certificate' and 'client_certificate_base64' can be present!") - } - (Some(client_certificate), None) => { - ps.push((TLS_CLIENT_CERTIFICATE_FILE, client_certificate)); - } - (None, Some(client_certificate)) => { - ps.push(( - TLS_CLIENT_CERTIFICATE_BASE64, - client_certificate.expose_secret(), - )); - } - _ => {} - } - - if let Some(server_name_verification) = c.server_name_verification() { - match server_name_verification { - true => ps.push((TLS_SERVER_NAME_VERIFICATION, "true")), - false => ps.push((TLS_SERVER_NAME_VERIFICATION, "false")), - }; - } - - let mut s = String::new(); - endpoint::Parameters::extend(ps.drain(..), &mut s); - - Ok(s) - } -} zconfigurable! { // Default MTU (TLS PDU) in bytes. @@ -208,30 +88,3 @@ pub mod config { pub const TLS_SERVER_NAME_VERIFICATION: &str = "server_name_verification"; } - -pub async fn get_tls_addr(address: &Address<'_>) -> ZResult { - match tokio::net::lookup_host(address.as_str()).await?.next() { - Some(addr) => Ok(addr), - None => bail!("Couldn't resolve TLS locator address: {}", address), - } -} - -pub fn get_tls_host<'a>(address: &'a Address<'a>) -> ZResult<&'a str> { - address - .as_str() - .split(':') - .next() - .ok_or_else(|| zerror!("Invalid TLS address").into()) -} - -pub fn get_tls_server_name<'a>(address: &'a Address<'a>) -> ZResult> { - Ok(ServerName::try_from(get_tls_host(address)?).map_err(|e| zerror!(e))?) -} - -pub fn base64_decode(data: &str) -> ZResult> { - use base64::engine::general_purpose; - use base64::Engine; - Ok(general_purpose::STANDARD - .decode(data) - .map_err(|e| zerror!("Unable to perform base64 decoding: {e:?}"))?) -} diff --git a/io/zenoh-links/zenoh-link-tls/src/unicast.rs b/io/zenoh-links/zenoh-link-tls/src/unicast.rs index 9eec2feb2a..b12608354e 100644 --- a/io/zenoh-links/zenoh-link-tls/src/unicast.rs +++ b/io/zenoh-links/zenoh-link-tls/src/unicast.rs @@ -12,39 +12,29 @@ // ZettaScale Zenoh Team, // use crate::{ - base64_decode, config::*, get_tls_addr, get_tls_host, get_tls_server_name, + utils::{get_tls_addr, get_tls_host, get_tls_server_name, TlsClientConfig, TlsServerConfig}, TLS_ACCEPT_THROTTLE_TIME, TLS_DEFAULT_MTU, TLS_LINGER_TIMEOUT, TLS_LOCATOR_PREFIX, }; + use async_trait::async_trait; -use rustls::{ - pki_types::{CertificateDer, PrivateKeyDer, TrustAnchor}, - server::WebPkiClientVerifier, - version::TLS13, - ClientConfig, RootCertStore, ServerConfig, -}; +use std::cell::UnsafeCell; use std::convert::TryInto; use std::fmt; -use std::fs::File; -use std::io::{BufReader, Cursor}; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use std::{cell::UnsafeCell, io}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::Mutex as AsyncMutex; use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream}; use tokio_util::sync::CancellationToken; -use webpki::anchor_from_trusted_cert; use zenoh_core::zasynclock; -use zenoh_link_commons::tls::WebPkiVerifierAnyServerName; use zenoh_link_commons::{ get_ip_interface_names, LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait, ListenersUnicastIP, NewLinkChannelSender, }; -use zenoh_protocol::core::endpoint::Config; use zenoh_protocol::core::{EndPoint, Locator}; -use zenoh_result::{bail, zerror, ZError, ZResult}; +use zenoh_result::{zerror, ZResult}; pub struct LinkUnicastTls { // The underlying socket as returned from the async-rustls library @@ -418,311 +408,3 @@ async fn accept_task( Ok(()) } - -struct TlsServerConfig { - server_config: ServerConfig, -} - -impl TlsServerConfig { - pub async fn new(config: &Config<'_>) -> ZResult { - let tls_server_client_auth: bool = match config.get(TLS_CLIENT_AUTH) { - Some(s) => s - .parse() - .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, - None => false, - }; - let tls_server_private_key = TlsServerConfig::load_tls_private_key(config).await?; - let tls_server_certificate = TlsServerConfig::load_tls_certificate(config).await?; - - let certs: Vec = - rustls_pemfile::certs(&mut Cursor::new(&tls_server_certificate)) - .collect::>() - .map_err(|err| zerror!("Error processing server certificate: {err}."))?; - - let mut keys: Vec = - rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_server_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing server key: {err}."))?; - - if keys.is_empty() { - keys = rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_server_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing server key: {err}."))?; - } - - if keys.is_empty() { - keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_server_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing server key: {err}."))?; - } - - if keys.is_empty() { - bail!("No private key found for TLS server."); - } - - let sc = if tls_server_client_auth { - let root_cert_store = load_trust_anchors(config)?.map_or_else( - || { - Err(zerror!( - "Missing root certificates while client authentication is enabled." - )) - }, - Ok, - )?; - let client_auth = WebPkiClientVerifier::builder(root_cert_store.into()).build()?; - ServerConfig::builder_with_protocol_versions(&[&TLS13]) - .with_client_cert_verifier(client_auth) - .with_single_cert(certs, keys.remove(0)) - .map_err(|e| zerror!(e))? - } else { - ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, keys.remove(0)) - .map_err(|e| zerror!(e))? - }; - Ok(TlsServerConfig { server_config: sc }) - } - - async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { - load_tls_key( - config, - TLS_SERVER_PRIVATE_KEY_RAW, - TLS_SERVER_PRIVATE_KEY_FILE, - TLS_SERVER_PRIVATE_KEY_BASE_64, - ) - .await - } - - async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { - load_tls_certificate( - config, - TLS_SERVER_CERTIFICATE_RAW, - TLS_SERVER_CERTIFICATE_FILE, - TLS_SERVER_CERTIFICATE_BASE64, - ) - .await - } -} - -struct TlsClientConfig { - client_config: ClientConfig, -} - -impl TlsClientConfig { - pub async fn new(config: &Config<'_>) -> ZResult { - let tls_client_server_auth: bool = match config.get(TLS_CLIENT_AUTH) { - Some(s) => s - .parse() - .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, - None => false, - }; - - let tls_server_name_verification: bool = match config.get(TLS_SERVER_NAME_VERIFICATION) { - Some(s) => { - let s: bool = s - .parse() - .map_err(|_| zerror!("Unknown server name verification argument: {}", s))?; - if s { - tracing::warn!("Skipping name verification of servers"); - } - s - } - None => false, - }; - - // Allows mixed user-generated CA and webPKI CA - tracing::debug!("Loading default Web PKI certificates."); - let mut root_cert_store = RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), - }; - - if let Some(custom_root_cert) = load_trust_anchors(config)? { - tracing::debug!("Loading user-generated certificates."); - root_cert_store.extend(custom_root_cert.roots); - } - - let cc = if tls_client_server_auth { - tracing::debug!("Loading client authentication key and certificate..."); - let tls_client_private_key = TlsClientConfig::load_tls_private_key(config).await?; - let tls_client_certificate = TlsClientConfig::load_tls_certificate(config).await?; - - let certs: Vec = - rustls_pemfile::certs(&mut Cursor::new(&tls_client_certificate)) - .collect::>() - .map_err(|err| zerror!("Error processing client certificate: {err}."))?; - - let mut keys: Vec = - rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_client_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing client key: {err}."))?; - - if keys.is_empty() { - keys = - rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_client_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing client key: {err}."))?; - } - - if keys.is_empty() { - keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_client_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing client key: {err}."))?; - } - - if keys.is_empty() { - bail!("No private key found for TLS client."); - } - - let builder = ClientConfig::builder_with_protocol_versions(&[&TLS13]); - - if tls_server_name_verification { - builder - .with_root_certificates(root_cert_store) - .with_client_auth_cert(certs, keys.remove(0)) - } else { - builder - .dangerous() - .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( - root_cert_store, - ))) - .with_client_auth_cert(certs, keys.remove(0)) - } - .map_err(|e| zerror!("Bad certificate/key: {}", e))? - } else { - let builder = ClientConfig::builder(); - if tls_server_name_verification { - builder - .with_root_certificates(root_cert_store) - .with_no_client_auth() - } else { - builder - .dangerous() - .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( - root_cert_store, - ))) - .with_no_client_auth() - } - }; - Ok(TlsClientConfig { client_config: cc }) - } - - async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { - load_tls_key( - config, - TLS_CLIENT_PRIVATE_KEY_RAW, - TLS_CLIENT_PRIVATE_KEY_FILE, - TLS_CLIENT_PRIVATE_KEY_BASE64, - ) - .await - } - - async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { - load_tls_certificate( - config, - TLS_CLIENT_CERTIFICATE_RAW, - TLS_CLIENT_CERTIFICATE_FILE, - TLS_CLIENT_CERTIFICATE_BASE64, - ) - .await - } -} - -async fn load_tls_key( - config: &Config<'_>, - tls_private_key_raw_config_key: &str, - tls_private_key_file_config_key: &str, - tls_private_key_base64_config_key: &str, -) -> ZResult> { - if let Some(value) = config.get(tls_private_key_raw_config_key) { - return Ok(value.as_bytes().to_vec()); - } - - if let Some(b64_key) = config.get(tls_private_key_base64_config_key) { - return base64_decode(b64_key); - } - - if let Some(value) = config.get(tls_private_key_file_config_key) { - return Ok(tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid TLS private key file: {}", e))?) - .and_then(|result| { - if result.is_empty() { - Err(zerror!("Empty TLS key.").into()) - } else { - Ok(result) - } - }); - } - Err(zerror!("Missing TLS private key.").into()) -} - -async fn load_tls_certificate( - config: &Config<'_>, - tls_certificate_raw_config_key: &str, - tls_certificate_file_config_key: &str, - tls_certificate_base64_config_key: &str, -) -> ZResult> { - if let Some(value) = config.get(tls_certificate_raw_config_key) { - return Ok(value.as_bytes().to_vec()); - } - - if let Some(b64_certificate) = config.get(tls_certificate_base64_config_key) { - return base64_decode(b64_certificate); - } - - if let Some(value) = config.get(tls_certificate_file_config_key) { - return Ok(tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid TLS certificate file: {}", e))?); - } - Err(zerror!("Missing tls certificates.").into()) -} - -fn load_trust_anchors(config: &Config<'_>) -> ZResult> { - let mut root_cert_store = RootCertStore::empty(); - if let Some(value) = config.get(TLS_ROOT_CA_CERTIFICATE_RAW) { - let mut pem = BufReader::new(value.as_bytes()); - let trust_anchors = process_pem(&mut pem)?; - root_cert_store.extend(trust_anchors); - return Ok(Some(root_cert_store)); - } - - if let Some(b64_certificate) = config.get(TLS_ROOT_CA_CERTIFICATE_BASE64) { - let certificate_pem = base64_decode(b64_certificate)?; - let mut pem = BufReader::new(certificate_pem.as_slice()); - let trust_anchors = process_pem(&mut pem)?; - root_cert_store.extend(trust_anchors); - return Ok(Some(root_cert_store)); - } - - if let Some(filename) = config.get(TLS_ROOT_CA_CERTIFICATE_FILE) { - let mut pem = BufReader::new(File::open(filename)?); - let trust_anchors = process_pem(&mut pem)?; - root_cert_store.extend(trust_anchors); - return Ok(Some(root_cert_store)); - } - Ok(None) -} - -fn process_pem(pem: &mut dyn io::BufRead) -> ZResult>> { - let certs: Vec = rustls_pemfile::certs(pem) - .map(|result| result.map_err(|err| zerror!("Error processing PEM certificates: {err}."))) - .collect::, ZError>>()?; - - let trust_anchors: Vec = certs - .into_iter() - .map(|cert| { - anchor_from_trusted_cert(&cert) - .map_err(|err| zerror!("Error processing trust anchor: {err}.")) - .map(|trust_anchor| trust_anchor.to_owned()) - }) - .collect::, ZError>>()?; - - Ok(trust_anchors) -} diff --git a/io/zenoh-links/zenoh-link-tls/src/utils.rs b/io/zenoh-links/zenoh-link-tls/src/utils.rs new file mode 100644 index 0000000000..f62757523c --- /dev/null +++ b/io/zenoh-links/zenoh-link-tls/src/utils.rs @@ -0,0 +1,480 @@ +// +// Copyright (c) 2024 ZettaScale Technology +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// +use crate::config::*; +use rustls::{ + pki_types::{CertificateDer, PrivateKeyDer, TrustAnchor}, + server::WebPkiClientVerifier, + version::TLS13, + ClientConfig, RootCertStore, ServerConfig, +}; +use rustls_pki_types::ServerName; +use secrecy::ExposeSecret; +use std::fs::File; +use std::io; +use std::{convert::TryFrom, net::SocketAddr}; +use std::{ + io::{BufReader, Cursor}, + sync::Arc, +}; +use webpki::anchor_from_trusted_cert; +use zenoh_config::Config as ZenohConfig; +use zenoh_link_commons::{tls::WebPkiVerifierAnyServerName, ConfigurationInspector}; +use zenoh_protocol::core::endpoint::Config; +use zenoh_protocol::core::endpoint::{self, Address}; +use zenoh_result::{bail, zerror, ZError, ZResult}; + +#[derive(Default, Clone, Copy, Debug)] +pub struct TlsConfigurator; + +impl ConfigurationInspector for TlsConfigurator { + fn inspect_config(&self, config: &ZenohConfig) -> ZResult { + let mut ps: Vec<(&str, &str)> = vec![]; + + let c = config.transport().link().tls(); + + match (c.root_ca_certificate(), c.root_ca_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'root_ca_certificate' and 'root_ca_certificate_base64' can be present!") + } + (Some(ca_certificate), None) => { + ps.push((TLS_ROOT_CA_CERTIFICATE_FILE, ca_certificate)); + } + (None, Some(ca_certificate)) => { + ps.push(( + TLS_ROOT_CA_CERTIFICATE_BASE64, + ca_certificate.expose_secret(), + )); + } + _ => {} + } + + match (c.server_private_key(), c.server_private_key_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'server_private_key' and 'server_private_key_base64' can be present!") + } + (Some(server_private_key), None) => { + ps.push((TLS_SERVER_PRIVATE_KEY_FILE, server_private_key)); + } + (None, Some(server_private_key)) => { + ps.push(( + TLS_SERVER_PRIVATE_KEY_BASE_64, + server_private_key.expose_secret(), + )); + } + _ => {} + } + + match (c.server_certificate(), c.server_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'server_certificate' and 'server_certificate_base64' can be present!") + } + (Some(server_certificate), None) => { + ps.push((TLS_SERVER_CERTIFICATE_FILE, server_certificate)); + } + (None, Some(server_certificate)) => { + ps.push(( + TLS_SERVER_CERTIFICATE_BASE64, + server_certificate.expose_secret(), + )); + } + _ => {} + } + + if let Some(client_auth) = c.client_auth() { + match client_auth { + true => ps.push((TLS_CLIENT_AUTH, "true")), + false => ps.push((TLS_CLIENT_AUTH, "false")), + }; + } + + match (c.client_private_key(), c.client_private_key_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'client_private_key' and 'client_private_key_base64' can be present!") + } + (Some(client_private_key), None) => { + ps.push((TLS_CLIENT_PRIVATE_KEY_FILE, client_private_key)); + } + (None, Some(client_private_key)) => { + ps.push(( + TLS_CLIENT_PRIVATE_KEY_BASE64, + client_private_key.expose_secret(), + )); + } + _ => {} + } + + match (c.client_certificate(), c.client_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'client_certificate' and 'client_certificate_base64' can be present!") + } + (Some(client_certificate), None) => { + ps.push((TLS_CLIENT_CERTIFICATE_FILE, client_certificate)); + } + (None, Some(client_certificate)) => { + ps.push(( + TLS_CLIENT_CERTIFICATE_BASE64, + client_certificate.expose_secret(), + )); + } + _ => {} + } + + if let Some(server_name_verification) = c.server_name_verification() { + match server_name_verification { + true => ps.push((TLS_SERVER_NAME_VERIFICATION, "true")), + false => ps.push((TLS_SERVER_NAME_VERIFICATION, "false")), + }; + } + + let mut s = String::new(); + endpoint::Parameters::extend(ps.drain(..), &mut s); + + Ok(s) + } +} + +pub(crate) struct TlsServerConfig { + pub(crate) server_config: ServerConfig, +} + +impl TlsServerConfig { + pub async fn new(config: &Config<'_>) -> ZResult { + let tls_server_client_auth: bool = match config.get(TLS_CLIENT_AUTH) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, + None => false, + }; + let tls_server_private_key = TlsServerConfig::load_tls_private_key(config).await?; + let tls_server_certificate = TlsServerConfig::load_tls_certificate(config).await?; + + let certs: Vec = + rustls_pemfile::certs(&mut Cursor::new(&tls_server_certificate)) + .collect::>() + .map_err(|err| zerror!("Error processing server certificate: {err}."))?; + + let mut keys: Vec = + rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing server key: {err}."))?; + + if keys.is_empty() { + keys = rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing server key: {err}."))?; + } + + if keys.is_empty() { + keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing server key: {err}."))?; + } + + if keys.is_empty() { + bail!("No private key found for TLS server."); + } + + let sc = if tls_server_client_auth { + let root_cert_store = load_trust_anchors(config)?.map_or_else( + || { + Err(zerror!( + "Missing root certificates while client authentication is enabled." + )) + }, + Ok, + )?; + let client_auth = WebPkiClientVerifier::builder(root_cert_store.into()).build()?; + ServerConfig::builder_with_protocol_versions(&[&TLS13]) + .with_client_cert_verifier(client_auth) + .with_single_cert(certs, keys.remove(0)) + .map_err(|e| zerror!(e))? + } else { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, keys.remove(0)) + .map_err(|e| zerror!(e))? + }; + Ok(TlsServerConfig { server_config: sc }) + } + + async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { + load_tls_key( + config, + TLS_SERVER_PRIVATE_KEY_RAW, + TLS_SERVER_PRIVATE_KEY_FILE, + TLS_SERVER_PRIVATE_KEY_BASE_64, + ) + .await + } + + async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { + load_tls_certificate( + config, + TLS_SERVER_CERTIFICATE_RAW, + TLS_SERVER_CERTIFICATE_FILE, + TLS_SERVER_CERTIFICATE_BASE64, + ) + .await + } +} + +pub(crate) struct TlsClientConfig { + pub(crate) client_config: ClientConfig, +} + +impl TlsClientConfig { + pub async fn new(config: &Config<'_>) -> ZResult { + let tls_client_server_auth: bool = match config.get(TLS_CLIENT_AUTH) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, + None => false, + }; + + let tls_server_name_verification: bool = match config.get(TLS_SERVER_NAME_VERIFICATION) { + Some(s) => { + let s: bool = s + .parse() + .map_err(|_| zerror!("Unknown server name verification argument: {}", s))?; + if s { + tracing::warn!("Skipping name verification of servers"); + } + s + } + None => false, + }; + + // Allows mixed user-generated CA and webPKI CA + tracing::debug!("Loading default Web PKI certificates."); + let mut root_cert_store = RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + }; + + if let Some(custom_root_cert) = load_trust_anchors(config)? { + tracing::debug!("Loading user-generated certificates."); + root_cert_store.extend(custom_root_cert.roots); + } + + let cc = if tls_client_server_auth { + tracing::debug!("Loading client authentication key and certificate..."); + let tls_client_private_key = TlsClientConfig::load_tls_private_key(config).await?; + let tls_client_certificate = TlsClientConfig::load_tls_certificate(config).await?; + + let certs: Vec = + rustls_pemfile::certs(&mut Cursor::new(&tls_client_certificate)) + .collect::>() + .map_err(|err| zerror!("Error processing client certificate: {err}."))?; + + let mut keys: Vec = + rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing client key: {err}."))?; + + if keys.is_empty() { + keys = + rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing client key: {err}."))?; + } + + if keys.is_empty() { + keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing client key: {err}."))?; + } + + if keys.is_empty() { + bail!("No private key found for TLS client."); + } + + let builder = ClientConfig::builder_with_protocol_versions(&[&TLS13]); + + if tls_server_name_verification { + builder + .with_root_certificates(root_cert_store) + .with_client_auth_cert(certs, keys.remove(0)) + } else { + builder + .dangerous() + .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( + root_cert_store, + ))) + .with_client_auth_cert(certs, keys.remove(0)) + } + .map_err(|e| zerror!("Bad certificate/key: {}", e))? + } else { + let builder = ClientConfig::builder(); + if tls_server_name_verification { + builder + .with_root_certificates(root_cert_store) + .with_no_client_auth() + } else { + builder + .dangerous() + .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( + root_cert_store, + ))) + .with_no_client_auth() + } + }; + Ok(TlsClientConfig { client_config: cc }) + } + + async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { + load_tls_key( + config, + TLS_CLIENT_PRIVATE_KEY_RAW, + TLS_CLIENT_PRIVATE_KEY_FILE, + TLS_CLIENT_PRIVATE_KEY_BASE64, + ) + .await + } + + async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { + load_tls_certificate( + config, + TLS_CLIENT_CERTIFICATE_RAW, + TLS_CLIENT_CERTIFICATE_FILE, + TLS_CLIENT_CERTIFICATE_BASE64, + ) + .await + } +} + +fn process_pem(pem: &mut dyn io::BufRead) -> ZResult>> { + let certs: Vec = rustls_pemfile::certs(pem) + .map(|result| result.map_err(|err| zerror!("Error processing PEM certificates: {err}."))) + .collect::, ZError>>()?; + + let trust_anchors: Vec = certs + .into_iter() + .map(|cert| { + anchor_from_trusted_cert(&cert) + .map_err(|err| zerror!("Error processing trust anchor: {err}.")) + .map(|trust_anchor| trust_anchor.to_owned()) + }) + .collect::, ZError>>()?; + + Ok(trust_anchors) +} + +async fn load_tls_key( + config: &Config<'_>, + tls_private_key_raw_config_key: &str, + tls_private_key_file_config_key: &str, + tls_private_key_base64_config_key: &str, +) -> ZResult> { + if let Some(value) = config.get(tls_private_key_raw_config_key) { + return Ok(value.as_bytes().to_vec()); + } + + if let Some(b64_key) = config.get(tls_private_key_base64_config_key) { + return base64_decode(b64_key); + } + + if let Some(value) = config.get(tls_private_key_file_config_key) { + return Ok(tokio::fs::read(value) + .await + .map_err(|e| zerror!("Invalid TLS private key file: {}", e))?) + .and_then(|result| { + if result.is_empty() { + Err(zerror!("Empty TLS key.").into()) + } else { + Ok(result) + } + }); + } + Err(zerror!("Missing TLS private key.").into()) +} + +async fn load_tls_certificate( + config: &Config<'_>, + tls_certificate_raw_config_key: &str, + tls_certificate_file_config_key: &str, + tls_certificate_base64_config_key: &str, +) -> ZResult> { + if let Some(value) = config.get(tls_certificate_raw_config_key) { + return Ok(value.as_bytes().to_vec()); + } + + if let Some(b64_certificate) = config.get(tls_certificate_base64_config_key) { + return base64_decode(b64_certificate); + } + + if let Some(value) = config.get(tls_certificate_file_config_key) { + return Ok(tokio::fs::read(value) + .await + .map_err(|e| zerror!("Invalid TLS certificate file: {}", e))?); + } + Err(zerror!("Missing tls certificates.").into()) +} + +fn load_trust_anchors(config: &Config<'_>) -> ZResult> { + let mut root_cert_store = RootCertStore::empty(); + if let Some(value) = config.get(TLS_ROOT_CA_CERTIFICATE_RAW) { + let mut pem = BufReader::new(value.as_bytes()); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + + if let Some(b64_certificate) = config.get(TLS_ROOT_CA_CERTIFICATE_BASE64) { + let certificate_pem = base64_decode(b64_certificate)?; + let mut pem = BufReader::new(certificate_pem.as_slice()); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + + if let Some(filename) = config.get(TLS_ROOT_CA_CERTIFICATE_FILE) { + let mut pem = BufReader::new(File::open(filename)?); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + Ok(None) +} + +pub fn base64_decode(data: &str) -> ZResult> { + use base64::engine::general_purpose; + use base64::Engine; + Ok(general_purpose::STANDARD + .decode(data) + .map_err(|e| zerror!("Unable to perform base64 decoding: {e:?}"))?) +} + +pub async fn get_tls_addr(address: &Address<'_>) -> ZResult { + match tokio::net::lookup_host(address.as_str()).await?.next() { + Some(addr) => Ok(addr), + None => bail!("Couldn't resolve TLS locator address: {}", address), + } +} + +pub fn get_tls_host<'a>(address: &'a Address<'a>) -> ZResult<&'a str> { + address + .as_str() + .split(':') + .next() + .ok_or_else(|| zerror!("Invalid TLS address").into()) +} + +pub fn get_tls_server_name<'a>(address: &'a Address<'a>) -> ZResult> { + Ok(ServerName::try_from(get_tls_host(address)?).map_err(|e| zerror!(e))?) +} diff --git a/io/zenoh-transport/Cargo.toml b/io/zenoh-transport/Cargo.toml index b3a299e8be..9f6594761e 100644 --- a/io/zenoh-transport/Cargo.toml +++ b/io/zenoh-transport/Cargo.toml @@ -92,3 +92,4 @@ futures-util = { workspace = true } zenoh-util = {workspace = true } zenoh-protocol = { workspace = true, features = ["test"] } futures = { workspace = true } +zenoh-link-commons = { workspace = true } diff --git a/io/zenoh-transport/tests/unicast_transport.rs b/io/zenoh-transport/tests/unicast_transport.rs index af1dedfbce..33cfbceb17 100644 --- a/io/zenoh-transport/tests/unicast_transport.rs +++ b/io/zenoh-transport/tests/unicast_transport.rs @@ -69,7 +69,10 @@ use zenoh_transport::{ // the key and certificate brought in by the client. Similarly the server's certificate authority // will validate the key and certificate brought in by the server in front of the client. // -#[cfg(all(feature = "transport_tls", target_family = "unix"))] +#[cfg(all( + any(feature = "transport_tls", feature = "transport_quic"), + target_family = "unix" +))] const CLIENT_KEY: &str = "-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAsfqAuhElN4HnyeqLovSd4Qe+nNv5AwCjSO+HFiF30x3vQ1Hi qRA0UmyFlSqBnFH3TUHm4Jcad40QfrX8f11NKGZdpvKHsMYqYjZnYkRFGS2s4fQy @@ -98,7 +101,10 @@ tYsqC2FtWzY51VOEKNpnfH7zH5n+bjoI9nAEAW63TK9ZKkr2hRGsDhJdGzmLfQ7v F6/CuIw9EsAq6qIB8O88FXQqald+BZOx6AzB8Oedsz/WtMmIEmr/+Q== -----END RSA PRIVATE KEY-----"; -#[cfg(all(feature = "transport_tls", target_family = "unix"))] +#[cfg(all( + any(feature = "transport_tls", feature = "transport_quic"), + target_family = "unix" +))] const CLIENT_CERT: &str = "-----BEGIN CERTIFICATE----- MIIDLjCCAhagAwIBAgIIeUtmIdFQznMwDQYJKoZIhvcNAQELBQAwIDEeMBwGA1UE AxMVbWluaWNhIHJvb3QgY2EgMDc4ZGE3MCAXDTIzMDMwNjE2MDMxOFoYDzIxMjMw @@ -120,7 +126,10 @@ p5e60QweRuJsb60aUaCG8HoICevXYK2fFqCQdlb5sIqQqXyN2K6HuKAFywsjsGyJ abY= -----END CERTIFICATE-----"; -#[cfg(all(feature = "transport_tls", target_family = "unix"))] +#[cfg(all( + any(feature = "transport_tls", feature = "transport_quic"), + target_family = "unix" +))] const CLIENT_CA: &str = "-----BEGIN CERTIFICATE----- MIIDSzCCAjOgAwIBAgIIB42n1ZIkOakwDQYJKoZIhvcNAQELBQAwIDEeMBwGA1UE AxMVbWluaWNhIHJvb3QgY2EgMDc4ZGE3MCAXDTIzMDMwNjE2MDMwN1oYDzIxMjMw @@ -1298,6 +1307,225 @@ fn transport_unicast_tls_only_mutual_wrong_client_certs_failure() { assert!(result.is_err()); } +#[cfg(all(feature = "transport_quic", target_family = "unix"))] +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn transport_unicast_quic_only_mutual_success() { + use zenoh_link::quic::config::*; + + zenoh_util::try_init_log_from_env(); + + let client_auth = "true"; + + // Define the locator + let mut client_endpoint: EndPoint = ("quic/localhost:10461").parse().unwrap(); + client_endpoint + .config_mut() + .extend( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, SERVER_CA), + (TLS_CLIENT_CERTIFICATE_RAW, CLIENT_CERT), + (TLS_CLIENT_PRIVATE_KEY_RAW, CLIENT_KEY), + (TLS_CLIENT_AUTH, client_auth), + ] + .iter() + .map(|(k, v)| ((*k).to_owned(), (*v).to_owned())), + ) + .unwrap(); + + // Define the locator + let mut server_endpoint: EndPoint = ("quic/localhost:10461").parse().unwrap(); + server_endpoint + .config_mut() + .extend( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, CLIENT_CA), + (TLS_SERVER_CERTIFICATE_RAW, SERVER_CERT), + (TLS_SERVER_PRIVATE_KEY_RAW, SERVER_KEY), + (TLS_CLIENT_AUTH, client_auth), + ] + .iter() + .map(|(k, v)| ((*k).to_owned(), (*v).to_owned())), + ) + .unwrap(); + // Define the reliability and congestion control + let channel = [ + Channel { + priority: Priority::default(), + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::default(), + reliability: Reliability::BestEffort, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::BestEffort, + }, + ]; + // Run + let client_endpoints = vec![client_endpoint]; + let server_endpoints = vec![server_endpoint]; + run_with_universal_transport( + &client_endpoints, + &server_endpoints, + &channel, + &MSG_SIZE_ALL, + ) + .await; +} + +#[cfg(all(feature = "transport_quic", target_family = "unix"))] +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn transport_unicast_quic_only_mutual_no_client_certs_failure() { + use std::vec; + use zenoh_link::quic::config::*; + + zenoh_util::try_init_log_from_env(); + + // Define the locator + let mut client_endpoint: EndPoint = ("quic/localhost:10462").parse().unwrap(); + client_endpoint + .config_mut() + .extend( + [(TLS_ROOT_CA_CERTIFICATE_RAW, SERVER_CA)] + .iter() + .map(|(k, v)| ((*k).to_owned(), (*v).to_owned())), + ) + .unwrap(); + + // Define the locator + let mut server_endpoint: EndPoint = ("quic/localhost:10462").parse().unwrap(); + server_endpoint + .config_mut() + .extend( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, CLIENT_CA), + (TLS_SERVER_CERTIFICATE_RAW, SERVER_CERT), + (TLS_SERVER_PRIVATE_KEY_RAW, SERVER_KEY), + (TLS_CLIENT_AUTH, "true"), + ] + .iter() + .map(|(k, v)| ((*k).to_owned(), (*v).to_owned())), + ) + .unwrap(); + // Define the reliability and congestion control + let channel = [ + Channel { + priority: Priority::default(), + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::default(), + reliability: Reliability::BestEffort, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::BestEffort, + }, + ]; + // Run + let client_endpoints = vec![client_endpoint]; + let server_endpoints = vec![server_endpoint]; + let result = std::panic::catch_unwind(|| { + tokio::runtime::Runtime::new() + .unwrap() + .block_on(run_with_universal_transport( + &client_endpoints, + &server_endpoints, + &channel, + &MSG_SIZE_ALL, + )) + }); + assert!(result.is_err()); +} + +#[cfg(all(feature = "transport_quic", target_family = "unix"))] +#[test] +fn transport_unicast_quic_only_mutual_wrong_client_certs_failure() { + use zenoh_link::quic::config::*; + + zenoh_util::try_init_log_from_env(); + + let client_auth = "true"; + + // Define the locator + let mut client_endpoint: EndPoint = ("quic/localhost:10463").parse().unwrap(); + client_endpoint + .config_mut() + .extend( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, SERVER_CA), + // Using the SERVER_CERT and SERVER_KEY in the client to simulate the case the client has + // wrong certificates and keys. The SERVER_CA (cetificate authority) will not recognize + // these certificates as it is expecting to receive CLIENT_CERT and CLIENT_KEY from the + // client. + (TLS_CLIENT_CERTIFICATE_RAW, SERVER_CERT), + (TLS_CLIENT_PRIVATE_KEY_RAW, SERVER_KEY), + (TLS_CLIENT_AUTH, client_auth), + ] + .iter() + .map(|(k, v)| ((*k).to_owned(), (*v).to_owned())), + ) + .unwrap(); + + // Define the locator + let mut server_endpoint: EndPoint = ("quic/localhost:10463").parse().unwrap(); + server_endpoint + .config_mut() + .extend( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, CLIENT_CA), + (TLS_SERVER_CERTIFICATE_RAW, SERVER_CERT), + (TLS_SERVER_PRIVATE_KEY_RAW, SERVER_KEY), + (TLS_CLIENT_AUTH, client_auth), + ] + .iter() + .map(|(k, v)| ((*k).to_owned(), (*v).to_owned())), + ) + .unwrap(); + // Define the reliability and congestion control + let channel = [ + Channel { + priority: Priority::default(), + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::default(), + reliability: Reliability::BestEffort, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::BestEffort, + }, + ]; + // Run + let client_endpoints = vec![client_endpoint]; + let server_endpoints = vec![server_endpoint]; + let result = std::panic::catch_unwind(|| { + tokio::runtime::Runtime::new() + .unwrap() + .block_on(run_with_universal_transport( + &client_endpoints, + &server_endpoints, + &channel, + &MSG_SIZE_ALL, + )) + }); + assert!(result.is_err()); +} + #[test] fn transport_unicast_qos_and_lowlatency_failure() { struct TestPeer;