diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4329dff56c..084a3b148f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,7 +27,7 @@ env: jobs: check: - name: Run checks on ${{ matrix.os }} + name: Lints and doc tests on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -69,8 +69,11 @@ jobs: - name: Perform no_std checks run: cargo check --bin nostd_check --target x86_64-unknown-none --manifest-path ci/nostd-check/Cargo.toml + - name: Run doctests + run: cargo test --doc + test: - name: Run tests on ${{ matrix.os }} + name: Unit tests on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -112,9 +115,6 @@ jobs: if: ${{ matrix.os == 'ubuntu-latest' }} run: cargo nextest run -p zenohd --no-default-features - - name: Run doctests - run: cargo test --doc - valgrind: name: Memory leak checks runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index 737cb62f75..55de3d50f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3190,9 +3190,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.2" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" dependencies = [ "log", "ring 0.17.6", @@ -4204,7 +4204,7 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" dependencies = [ - "rustls 0.22.2", + "rustls 0.22.4", "rustls-pki-types", "tokio", ] @@ -5311,16 +5311,19 @@ name = "zenoh-link-commons" version = "0.11.0-dev" dependencies = [ "async-trait", + "base64 0.21.4", "flume", "futures", - "rustls 0.22.2", + "rustls 0.22.4", "rustls-webpki 0.102.2", "serde", "tokio", "tokio-util", "tracing", + "webpki-roots", "zenoh-buffers", "zenoh-codec", + "zenoh-config", "zenoh-core", "zenoh-protocol", "zenoh-result", @@ -5338,13 +5341,15 @@ dependencies = [ "quinn", "rustls 0.21.7", "rustls-native-certs 0.7.0", - "rustls-pemfile 2.0.0", + "rustls-pemfile 1.0.3", + "rustls-pki-types", "rustls-webpki 0.102.2", "secrecy", "tokio", "tokio-rustls 0.24.1", "tokio-util", "tracing", + "webpki-roots", "zenoh-collections", "zenoh-config", "zenoh-core", @@ -5401,7 +5406,7 @@ dependencies = [ "async-trait", "base64 0.21.4", "futures", - "rustls 0.22.2", + "rustls 0.22.4", "rustls-pemfile 2.0.0", "rustls-pki-types", "rustls-webpki 0.102.2", @@ -5649,6 +5654,7 @@ dependencies = [ "ron", "serde", "tokio", + "tracing", "zenoh-collections", "zenoh-macros", "zenoh-protocol", @@ -5732,6 +5738,7 @@ dependencies = [ "zenoh-core", "zenoh-crypto", "zenoh-link", + "zenoh-link-commons", "zenoh-protocol", "zenoh-result", "zenoh-runtime", @@ -5821,6 +5828,6 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/ci/valgrind-check/src/queryable_get/bin/z_queryable_get.rs b/ci/valgrind-check/src/queryable_get/bin/z_queryable_get.rs index bc8716bb45..364617eb2a 100644 --- a/ci/valgrind-check/src/queryable_get/bin/z_queryable_get.rs +++ b/ci/valgrind-check/src/queryable_get/bin/z_queryable_get.rs @@ -20,15 +20,13 @@ use zenoh::prelude::r#async::*; async fn main() { 
zenoh_util::init_log_test(); - let _z = zenoh_runtime::ZRuntimePoolGuard; - let queryable_key_expr = KeyExpr::try_from("test/valgrind/data").unwrap(); let get_selector = Selector::try_from("test/valgrind/**").unwrap(); println!("Declaring Queryable on '{queryable_key_expr}'..."); let queryable_session = zenoh::open(Config::default()).res().await.unwrap(); let _queryable = queryable_session - .declare_queryable(queryable_key_expr) + .declare_queryable(queryable_key_expr.clone()) .callback(move |query| { println!(">> Handling query '{}'", query.selector()); let queryable_key_expr = queryable_key_expr.clone(); diff --git a/commons/zenoh-runtime/Cargo.toml b/commons/zenoh-runtime/Cargo.toml index 530140dd7a..e3a08a9de8 100644 --- a/commons/zenoh-runtime/Cargo.toml +++ b/commons/zenoh-runtime/Cargo.toml @@ -31,3 +31,4 @@ tokio = { workspace = true, features = [ "sync", "time", ] } +tracing = { workspace = true } diff --git a/commons/zenoh-runtime/src/lib.rs b/commons/zenoh-runtime/src/lib.rs index 1a9d765420..cb58cac570 100644 --- a/commons/zenoh-runtime/src/lib.rs +++ b/commons/zenoh-runtime/src/lib.rs @@ -184,17 +184,42 @@ impl ZRuntimePool { // If there are any blocking tasks spawned by ZRuntimes, the function will block until they return. impl Drop for ZRuntimePool { fn drop(&mut self) { + std::panic::set_hook(Box::new(|_| { + // To suppress the panic error caught in the following `catch_unwind`. + })); + let handles: Vec<_> = self .0 .drain() .filter_map(|(_name, mut rt)| { - rt.take() - .map(|r| std::thread::spawn(move || r.shutdown_timeout(Duration::from_secs(1)))) + rt.take().map(|r| { + // NOTE: The error of the atexit handler in DLL (static lib is fine) + // failing to spawn a new thread in `cleanup` has been identified. + std::panic::catch_unwind(|| { + std::thread::spawn(move || r.shutdown_timeout(Duration::from_secs(1))) + }) + }) }) .collect(); for hd in handles { - let _ = hd.join(); + match hd { + Ok(handle) => { + if let Err(err) = handle.join() { + tracing::error!( + "The handle failed to join during `ZRuntimePool` drop due to {err:?}" + ); + } + } + Err(err) => { + // WARN: Windows with DLL is expected to panic for the time being. + // Otherwise, report the error. + #[cfg(not(target_os = "windows"))] + tracing::error!("`ZRuntimePool` failed to drop due to {err:?}"); + #[cfg(target_os = "windows")] + tracing::trace!("`ZRuntimePool` failed to drop due to {err:?}"); + } + } } } } diff --git a/io/zenoh-link-commons/Cargo.toml b/io/zenoh-link-commons/Cargo.toml index f2e10616c1..12b70cad6d 100644 --- a/io/zenoh-link-commons/Cargo.toml +++ b/io/zenoh-link-commons/Cargo.toml @@ -12,16 +12,16 @@ # ZettaScale Zenoh Team, # [package] -rust-version = { workspace = true } -name = "zenoh-link-commons" -version = { workspace = true } -repository = { workspace = true } -homepage = { workspace = true } authors = { workspace = true } -edition = { workspace = true } -license = { workspace = true } categories = { workspace = true } description = "Internal crate for zenoh." 
+edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +name = "zenoh-link-commons" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] @@ -29,18 +29,27 @@ compression = [] [dependencies] async-trait = { workspace = true } +base64 = { workspace = true, optional = true } +flume = { workspace = true } +futures = { workspace = true } rustls = { workspace = true } rustls-webpki = { workspace = true } -flume = { workspace = true } -tracing = {workspace = true} serde = { workspace = true, features = ["default"] } +tokio = { workspace = true, features = [ + "fs", + "io-util", + "net", + "sync", + "time", +] } +tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } +webpki-roots = { workspace = true, optional = true } zenoh-buffers = { workspace = true } zenoh-codec = { workspace = true } +zenoh-config = { workspace = true } zenoh-core = { workspace = true } zenoh-protocol = { workspace = true } zenoh-result = { workspace = true } -zenoh-util = { workspace = true } zenoh-runtime = { workspace = true } -tokio = { workspace = true, features = ["io-util", "net", "fs", "sync", "time"] } -tokio-util = { workspace = true, features = ["rt"] } -futures = { workspace = true } +zenoh-util = { workspace = true } diff --git a/io/zenoh-links/zenoh-link-quic/Cargo.toml b/io/zenoh-links/zenoh-link-quic/Cargo.toml index d217b65200..63bfc1f839 100644 --- a/io/zenoh-links/zenoh-link-quic/Cargo.toml +++ b/io/zenoh-links/zenoh-link-quic/Cargo.toml @@ -12,40 +12,47 @@ # ZettaScale Zenoh Team, # [package] -rust-version = { workspace = true } -name = "zenoh-link-quic" -version = { workspace = true } -repository = { workspace = true } -homepage = { workspace = true } authors = { workspace = true } -edition = { workspace = true } -license = { workspace = true } categories = { workspace = true } description = "Internal crate for zenoh." 
+edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +name = "zenoh-link-quic" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] async-trait = { workspace = true } base64 = { workspace = true } futures = { workspace = true } -tracing = {workspace = true} quinn = { workspace = true } rustls-native-certs = { workspace = true } -rustls-pemfile = { workspace = true } +rustls-pki-types = { workspace = true } rustls-webpki = { workspace = true } -secrecy = {workspace = true } -tokio = { workspace = true, features = ["io-util", "net", "fs", "sync", "time"] } +secrecy = { workspace = true } +tokio = { workspace = true, features = [ + "fs", + "io-util", + "net", + "sync", + "time", +] } tokio-util = { workspace = true, features = ["rt"] } zenoh-collections = { workspace = true } +tracing = { workspace = true } +webpki-roots = { workspace = true } zenoh-config = { workspace = true } zenoh-core = { workspace = true } zenoh-link-commons = { workspace = true } zenoh-protocol = { workspace = true } zenoh-result = { workspace = true } +zenoh-runtime = { workspace = true } zenoh-sync = { workspace = true } zenoh-util = { workspace = true } -zenoh-runtime = { workspace = true } - # Lock due to quinn not supporting rustls 0.22 yet rustls = { version = "0.21", features = ["dangerous_configuration", "quic"] } tokio-rustls = "0.24.1" +rustls-pemfile = { version = "1" } diff --git a/io/zenoh-links/zenoh-link-quic/src/lib.rs b/io/zenoh-links/zenoh-link-quic/src/lib.rs index 0c4ae4937b..a60f84c559 100644 --- a/io/zenoh-links/zenoh-link-quic/src/lib.rs +++ b/io/zenoh-links/zenoh-link-quic/src/lib.rs @@ -18,25 +18,17 @@ //! //! 
[Click here for Zenoh's documentation](../zenoh/index.html) use async_trait::async_trait; -use config::{ - TLS_ROOT_CA_CERTIFICATE_BASE64, TLS_ROOT_CA_CERTIFICATE_FILE, TLS_SERVER_CERTIFICATE_BASE64, - TLS_SERVER_CERTIFICATE_FILE, TLS_SERVER_NAME_VERIFICATION, TLS_SERVER_PRIVATE_KEY_BASE64, - TLS_SERVER_PRIVATE_KEY_FILE, -}; -use secrecy::ExposeSecret; -use std::net::SocketAddr; -use zenoh_config::Config; + use zenoh_core::zconfigurable; -use zenoh_link_commons::{ConfigurationInspector, LocatorInspector}; -use zenoh_protocol::{ - core::{endpoint::Address, Locator, Parameters}, - transport::BatchSize, -}; -use zenoh_result::{bail, zerror, ZResult}; +use zenoh_link_commons::LocatorInspector; +use zenoh_protocol::{core::Locator, transport::BatchSize}; +use zenoh_result::ZResult; mod unicast; +mod utils; mod verify; pub use unicast::*; +pub use utils::TlsConfigurator as QuicConfigurator; // Default ALPN protocol pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; @@ -64,76 +56,6 @@ impl LocatorInspector for QuicLocatorInspector { } } -#[derive(Default, Clone, Copy, Debug)] -pub struct QuicConfigurator; - -impl ConfigurationInspector for QuicConfigurator { - fn inspect_config(&self, config: &Config) -> ZResult { - let mut ps: Vec<(&str, &str)> = vec![]; - - let c = config.transport().link().tls(); - - match (c.root_ca_certificate(), c.root_ca_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'root_ca_certificate' and 'root_ca_certificate_base64' can be present!") - } - (Some(ca_certificate), None) => { - ps.push((TLS_ROOT_CA_CERTIFICATE_FILE, ca_certificate)); - } - (None, Some(ca_certificate)) => { - ps.push(( - TLS_ROOT_CA_CERTIFICATE_BASE64, - ca_certificate.expose_secret(), - )); - } - _ => {} - } - - match (c.server_private_key(), c.server_private_key_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'server_private_key' and 'server_private_key_base64' can be present!") - } - (Some(server_private_key), None) => { - ps.push((TLS_SERVER_PRIVATE_KEY_FILE, server_private_key)); - } - (None, Some(server_private_key)) => { - ps.push(( - TLS_SERVER_PRIVATE_KEY_BASE64, - server_private_key.expose_secret(), - )); - } - _ => {} - } - - match (c.server_certificate(), c.server_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'server_certificate' and 'server_certificate_base64' can be present!") - } - (Some(server_certificate), None) => { - ps.push((TLS_SERVER_CERTIFICATE_FILE, server_certificate)); - } - (None, Some(server_certificate)) => { - ps.push(( - TLS_SERVER_CERTIFICATE_BASE64, - server_certificate.expose_secret(), - )); - } - _ => {} - } - - if let Some(server_name_verification) = c.server_name_verification() { - match server_name_verification { - true => ps.push((TLS_SERVER_NAME_VERIFICATION, "true")), - false => ps.push((TLS_SERVER_NAME_VERIFICATION, "false")), - }; - } - - let s = Parameters::from_iter(ps.drain(..)); - - Ok(s) - } -} - zconfigurable! { // Default MTU (QUIC PDU) in bytes. 
static ref QUIC_DEFAULT_MTU: BatchSize = QUIC_MAX_MTU; @@ -156,25 +78,20 @@ pub mod config { pub const TLS_SERVER_PRIVATE_KEY_RAW: &str = "server_private_key_raw"; pub const TLS_SERVER_PRIVATE_KEY_BASE64: &str = "server_private_key_base64"; - pub const TLS_SERVER_CERTIFICATE_FILE: &str = "tls_server_certificate_file"; - pub const TLS_SERVER_CERTIFICATE_RAW: &str = "tls_server_certificate_raw"; - pub const TLS_SERVER_CERTIFICATE_BASE64: &str = "tls_server_certificate_base64"; + pub const TLS_SERVER_CERTIFICATE_FILE: &str = "server_certificate_file"; + pub const TLS_SERVER_CERTIFICATE_RAW: &str = "server_certificate_raw"; + pub const TLS_SERVER_CERTIFICATE_BASE64: &str = "server_certificate_base64"; - pub const TLS_SERVER_NAME_VERIFICATION: &str = "server_name_verification"; - pub const TLS_SERVER_NAME_VERIFICATION_DEFAULT: &str = "true"; -} + pub const TLS_CLIENT_PRIVATE_KEY_FILE: &str = "client_private_key_file"; + pub const TLS_CLIENT_PRIVATE_KEY_RAW: &str = "client_private_key_raw"; + pub const TLS_CLIENT_PRIVATE_KEY_BASE64: &str = "client_private_key_base64"; -async fn get_quic_addr(address: &Address<'_>) -> ZResult { - match tokio::net::lookup_host(address.as_str()).await?.next() { - Some(addr) => Ok(addr), - None => bail!("Couldn't resolve QUIC locator address: {}", address), - } -} + pub const TLS_CLIENT_CERTIFICATE_FILE: &str = "client_certificate_file"; + pub const TLS_CLIENT_CERTIFICATE_RAW: &str = "client_certificate_raw"; + pub const TLS_CLIENT_CERTIFICATE_BASE64: &str = "client_certificate_base64"; -pub fn base64_decode(data: &str) -> ZResult> { - use base64::engine::general_purpose; - use base64::Engine; - Ok(general_purpose::STANDARD - .decode(data) - .map_err(|e| zerror!("Unable to perform base64 decoding: {e:?}"))?) + pub const TLS_CLIENT_AUTH: &str = "client_auth"; + + pub const TLS_SERVER_NAME_VERIFICATION: &str = "server_name_verification"; + pub const TLS_SERVER_NAME_VERIFICATION_DEFAULT: &str = "true"; } diff --git a/io/zenoh-links/zenoh-link-quic/src/unicast.rs b/io/zenoh-links/zenoh-link-quic/src/unicast.rs index c200a6f197..05d33dff49 100644 --- a/io/zenoh-links/zenoh-link-quic/src/unicast.rs +++ b/io/zenoh-links/zenoh-link-quic/src/unicast.rs @@ -12,16 +12,13 @@ // ZettaScale Zenoh Team, // -use crate::base64_decode; use crate::{ - config::*, get_quic_addr, verify::WebPkiVerifierAnyServerName, ALPN_QUIC_HTTP, - QUIC_ACCEPT_THROTTLE_TIME, QUIC_DEFAULT_MTU, QUIC_LOCATOR_PREFIX, + config::*, + utils::{get_quic_addr, TlsClientConfig, TlsServerConfig}, + ALPN_QUIC_HTTP, QUIC_ACCEPT_THROTTLE_TIME, QUIC_DEFAULT_MTU, QUIC_LOCATOR_PREFIX, }; use async_trait::async_trait; -use rustls::{Certificate, PrivateKey}; -use rustls_pemfile::Item; use std::fmt; -use std::io::BufReader; use std::net::IpAddr; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; @@ -35,7 +32,7 @@ use zenoh_link_commons::{ }; use zenoh_protocol::core::{EndPoint, Locator}; use zenoh_protocol::transport::BatchSize; -use zenoh_result::{bail, zerror, ZError, ZResult}; +use zenoh_result::{bail, zerror, ZResult}; pub struct LinkUnicastQuic { connection: quinn::Connection, @@ -220,55 +217,12 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { } // Initialize the QUIC connection - let mut root_cert_store = rustls::RootCertStore::empty(); - - // Read the certificates - let f = if let Some(value) = epconf.get(TLS_ROOT_CA_CERTIFICATE_RAW) { - value.as_bytes().to_vec() - } else if let Some(b64_certificate) = epconf.get(TLS_ROOT_CA_CERTIFICATE_BASE64) { - base64_decode(b64_certificate)? 
- } else if let Some(value) = epconf.get(TLS_ROOT_CA_CERTIFICATE_FILE) { - tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))? - } else { - vec![] - }; - - let certificates = if f.is_empty() { - rustls_native_certs::load_native_certs() - .map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))? - .drain(..) - .map(|x| rustls::Certificate(x.to_vec())) - .collect::>() - } else { - rustls_pemfile::certs(&mut BufReader::new(f.as_slice())) - .map(|result| { - result - .map_err(|err| zerror!("Invalid QUIC CA certificate file: {}", err)) - .map(|der| Certificate(der.to_vec())) - }) - .collect::, ZError>>()? - }; - for c in certificates.iter() { - root_cert_store.add(c).map_err(|e| zerror!("{}", e))?; - } - - let client_crypto = rustls::ClientConfig::builder().with_safe_defaults(); - - let mut client_crypto = if server_name_verification { - client_crypto - .with_root_certificates(root_cert_store) - .with_no_client_auth() - } else { - client_crypto - .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( - root_cert_store, - ))) - .with_no_client_auth() - }; + let mut client_crypto = TlsClientConfig::new(&epconf) + .await + .map_err(|e| zerror!("Cannot create a new QUIC client on {addr}: {e}"))?; - client_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); + client_crypto.client_config.alpn_protocols = + ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); let ip_addr: IpAddr = if addr.is_ipv4() { Ipv4Addr::UNSPECIFIED.into() @@ -277,7 +231,9 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { }; let mut quic_endpoint = quinn::Endpoint::client(SocketAddr::new(ip_addr, 0)) .map_err(|e| zerror!("Can not create a new QUIC link bound to {}: {}", host, e))?; - quic_endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(client_crypto))); + quic_endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new( + client_crypto.client_config, + ))); let src_addr = quic_endpoint .local_addr() @@ -315,61 +271,14 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic { let addr = get_quic_addr(&epaddr).await?; - let f = if let Some(value) = epconf.get(TLS_SERVER_CERTIFICATE_RAW) { - value.as_bytes().to_vec() - } else if let Some(b64_certificate) = epconf.get(TLS_SERVER_CERTIFICATE_BASE64) { - base64_decode(b64_certificate)? - } else if let Some(value) = epconf.get(TLS_SERVER_CERTIFICATE_FILE) { - tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))? - } else { - bail!("No QUIC CA certificate has been provided."); - }; - let certificates = rustls_pemfile::certs(&mut BufReader::new(f.as_slice())) - .map(|result| { - result - .map_err(|err| zerror!("Invalid QUIC CA certificate file: {}", err)) - .map(|der| Certificate(der.to_vec())) - }) - .collect::, ZError>>()?; - - // Private keys - let f = if let Some(value) = epconf.get(TLS_SERVER_PRIVATE_KEY_RAW) { - value.as_bytes().to_vec() - } else if let Some(b64_key) = epconf.get(TLS_SERVER_PRIVATE_KEY_BASE64) { - base64_decode(b64_key)? - } else if let Some(value) = epconf.get(TLS_SERVER_PRIVATE_KEY_FILE) { - tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))? 
- } else { - bail!("No QUIC CA private key has been provided."); - }; - let items: Vec = rustls_pemfile::read_all(&mut BufReader::new(f.as_slice())) - .collect::>() - .map_err(|err| zerror!("Invalid QUIC CA private key file: {}", err))?; - - let private_key = items - .into_iter() - .filter_map(|x| match x { - rustls_pemfile::Item::Pkcs1Key(k) => Some(k.secret_pkcs1_der().to_vec()), - rustls_pemfile::Item::Pkcs8Key(k) => Some(k.secret_pkcs8_der().to_vec()), - rustls_pemfile::Item::Sec1Key(k) => Some(k.secret_sec1_der().to_vec()), - _ => None, - }) - .take(1) - .next() - .ok_or_else(|| zerror!("No QUIC CA private key has been provided.")) - .map(PrivateKey)?; - // Server config - let mut server_crypto = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certificates, private_key)?; - server_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); - let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + let mut server_crypto = TlsServerConfig::new(&epconf) + .await + .map_err(|e| zerror!("Cannot create a new QUIC listener on {addr}: {e}"))?; + server_crypto.server_config.alpn_protocols = + ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); + let mut server_config = + quinn::ServerConfig::with_crypto(Arc::new(server_crypto.server_config)); // We do not accept unidireactional streams. Arc::get_mut(&mut server_config.transport) diff --git a/io/zenoh-links/zenoh-link-quic/src/utils.rs b/io/zenoh-links/zenoh-link-quic/src/utils.rs new file mode 100644 index 0000000000..e7537bd658 --- /dev/null +++ b/io/zenoh-links/zenoh-link-quic/src/utils.rs @@ -0,0 +1,506 @@ +// +// Copyright (c) 2024 ZettaScale Technology +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. 
+// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// +use crate::config::*; +use crate::verify::WebPkiVerifierAnyServerName; +use rustls::OwnedTrustAnchor; +use rustls::{ + server::AllowAnyAuthenticatedClient, version::TLS13, Certificate, ClientConfig, PrivateKey, + RootCertStore, ServerConfig, +}; +use rustls_pki_types::{CertificateDer, TrustAnchor}; +use secrecy::ExposeSecret; +use zenoh_link_commons::ConfigurationInspector; +// use rustls_pki_types::{CertificateDer, PrivateKeyDer, TrustAnchor}; +use std::fs::File; +use std::io; +use std::net::SocketAddr; +use std::{ + io::{BufReader, Cursor}, + sync::Arc, +}; +use webpki::anchor_from_trusted_cert; +use zenoh_config::Config as ZenohConfig; +use zenoh_protocol::core::endpoint::{Address, Config}; +use zenoh_protocol::core::Parameters; +use zenoh_result::{bail, zerror, ZError, ZResult}; + +#[derive(Default, Clone, Copy, Debug)] +pub struct TlsConfigurator; + +impl ConfigurationInspector for TlsConfigurator { + fn inspect_config(&self, config: &ZenohConfig) -> ZResult { + let mut ps: Vec<(&str, &str)> = vec![]; + + let c = config.transport().link().tls(); + + match (c.root_ca_certificate(), c.root_ca_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'root_ca_certificate' and 'root_ca_certificate_base64' can be present!") + } + (Some(ca_certificate), None) => { + ps.push((TLS_ROOT_CA_CERTIFICATE_FILE, ca_certificate)); + } + (None, Some(ca_certificate)) => { + ps.push(( + TLS_ROOT_CA_CERTIFICATE_BASE64, + ca_certificate.expose_secret(), + )); + } + _ => {} + } + + match (c.server_private_key(), c.server_private_key_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'server_private_key' and 'server_private_key_base64' can be present!") + } + (Some(server_private_key), None) => { + ps.push((TLS_SERVER_PRIVATE_KEY_FILE, server_private_key)); + } + (None, Some(server_private_key)) => { + ps.push(( + TLS_SERVER_PRIVATE_KEY_BASE64, + server_private_key.expose_secret(), + )); + } + _ => {} + } + + match (c.server_certificate(), c.server_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'server_certificate' and 'server_certificate_base64' can be present!") + } + (Some(server_certificate), None) => { + ps.push((TLS_SERVER_CERTIFICATE_FILE, server_certificate)); + } + (None, Some(server_certificate)) => { + ps.push(( + TLS_SERVER_CERTIFICATE_BASE64, + server_certificate.expose_secret(), + )); + } + _ => {} + } + + if let Some(client_auth) = c.client_auth() { + match client_auth { + true => ps.push((TLS_CLIENT_AUTH, "true")), + false => ps.push((TLS_CLIENT_AUTH, "false")), + }; + } + + match (c.client_private_key(), c.client_private_key_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'client_private_key' and 'client_private_key_base64' can be present!") + } + (Some(client_private_key), None) => { + ps.push((TLS_CLIENT_PRIVATE_KEY_FILE, client_private_key)); + } + (None, Some(client_private_key)) => { + ps.push(( + TLS_CLIENT_PRIVATE_KEY_BASE64, + client_private_key.expose_secret(), + )); + } + _ => {} + } + + match (c.client_certificate(), c.client_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'client_certificate' and 'client_certificate_base64' can be present!") + } + (Some(client_certificate), None) => { + ps.push((TLS_CLIENT_CERTIFICATE_FILE, client_certificate)); + } + (None, Some(client_certificate)) => { + ps.push(( + TLS_CLIENT_CERTIFICATE_BASE64, + 
client_certificate.expose_secret(), + )); + } + _ => {} + } + + if let Some(server_name_verification) = c.server_name_verification() { + match server_name_verification { + true => ps.push((TLS_SERVER_NAME_VERIFICATION, "true")), + false => ps.push((TLS_SERVER_NAME_VERIFICATION, "false")), + }; + } + + Ok(Parameters::from_iter(ps.drain(..))) + } +} + +pub(crate) struct TlsServerConfig { + pub(crate) server_config: ServerConfig, +} + +impl TlsServerConfig { + pub async fn new(config: &Config<'_>) -> ZResult { + let tls_server_client_auth: bool = match config.get(TLS_CLIENT_AUTH) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, + None => false, + }; + let tls_server_private_key = TlsServerConfig::load_tls_private_key(config).await?; + let tls_server_certificate = TlsServerConfig::load_tls_certificate(config).await?; + + let certs: Vec = + rustls_pemfile::certs(&mut Cursor::new(&tls_server_certificate)) + .map_err(|err| zerror!("Error processing server certificate: {err}."))? + .into_iter() + .map(Certificate) + .collect(); + + let mut keys: Vec = + rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map_err(|err| zerror!("Error processing server key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + + if keys.is_empty() { + keys = rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map_err(|err| zerror!("Error processing server key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + } + + if keys.is_empty() { + keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map_err(|err| zerror!("Error processing server key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + } + + if keys.is_empty() { + bail!("No private key found for TLS server."); + } + + let sc = if tls_server_client_auth { + let root_cert_store = load_trust_anchors(config)?.map_or_else( + || { + Err(zerror!( + "Missing root certificates while client authentication is enabled." + )) + }, + Ok, + )?; + let client_auth = AllowAnyAuthenticatedClient::new(root_cert_store); + ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&TLS13])? + .with_client_cert_verifier(Arc::new(client_auth)) + .with_single_cert(certs, keys.remove(0)) + .map_err(|e| zerror!(e))? + } else { + ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, keys.remove(0)) + .map_err(|e| zerror!(e))? 
+ }; + Ok(TlsServerConfig { server_config: sc }) + } + + async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { + load_tls_key( + config, + TLS_SERVER_PRIVATE_KEY_RAW, + TLS_SERVER_PRIVATE_KEY_FILE, + TLS_SERVER_PRIVATE_KEY_BASE64, + ) + .await + } + + async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { + load_tls_certificate( + config, + TLS_SERVER_CERTIFICATE_RAW, + TLS_SERVER_CERTIFICATE_FILE, + TLS_SERVER_CERTIFICATE_BASE64, + ) + .await + } +} + +pub(crate) struct TlsClientConfig { + pub(crate) client_config: ClientConfig, +} + +impl TlsClientConfig { + pub async fn new(config: &Config<'_>) -> ZResult { + let tls_client_server_auth: bool = match config.get(TLS_CLIENT_AUTH) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, + None => false, + }; + + let tls_server_name_verification: bool = match config.get(TLS_SERVER_NAME_VERIFICATION) { + Some(s) => { + let s: bool = s + .parse() + .map_err(|_| zerror!("Unknown server name verification argument: {}", s))?; + if s { + tracing::warn!("Skipping name verification of servers"); + } + s + } + None => false, + }; + + // Allows mixed user-generated CA and webPKI CA + tracing::debug!("Loading default Web PKI certificates."); + let mut root_cert_store = RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS + .iter() + .map(|ta| ta.to_owned()) + .map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject.to_vec(), + ta.subject_public_key_info.to_vec(), + ta.name_constraints.map(|nc| nc.to_vec()), + ) + }) + .collect(), + }; + + if let Some(custom_root_cert) = load_trust_anchors(config)? { + tracing::debug!("Loading user-generated certificates."); + root_cert_store.roots.extend(custom_root_cert.roots); + } + + let cc = if tls_client_server_auth { + tracing::debug!("Loading client authentication key and certificate..."); + let tls_client_private_key = TlsClientConfig::load_tls_private_key(config).await?; + let tls_client_certificate = TlsClientConfig::load_tls_certificate(config).await?; + + let certs: Vec = + rustls_pemfile::certs(&mut Cursor::new(&tls_client_certificate)) + .map_err(|err| zerror!("Error processing client certificate: {err}."))? + .into_iter() + .map(Certificate) + .collect(); + + let mut keys: Vec = + rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map_err(|err| zerror!("Error processing client key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + + if keys.is_empty() { + keys = + rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map_err(|err| zerror!("Error processing client key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + } + + if keys.is_empty() { + keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map_err(|err| zerror!("Error processing client key: {err}."))? + .into_iter() + .map(PrivateKey) + .collect(); + } + + if keys.is_empty() { + bail!("No private key found for TLS client."); + } + + let builder = ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&TLS13])?; + + if tls_server_name_verification { + builder + .with_root_certificates(root_cert_store) + .with_client_auth_cert(certs, keys.remove(0)) + } else { + builder + .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( + root_cert_store, + ))) + .with_client_auth_cert(certs, keys.remove(0)) + } + .map_err(|e| zerror!("Bad certificate/key: {}", e))? 
+ } else { + let builder = ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&TLS13])?; + + if tls_server_name_verification { + builder + .with_root_certificates(root_cert_store) + .with_no_client_auth() + } else { + builder + .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( + root_cert_store, + ))) + .with_no_client_auth() + } + }; + Ok(TlsClientConfig { client_config: cc }) + } + + async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { + load_tls_key( + config, + TLS_CLIENT_PRIVATE_KEY_RAW, + TLS_CLIENT_PRIVATE_KEY_FILE, + TLS_CLIENT_PRIVATE_KEY_BASE64, + ) + .await + } + + async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { + load_tls_certificate( + config, + TLS_CLIENT_CERTIFICATE_RAW, + TLS_CLIENT_CERTIFICATE_FILE, + TLS_CLIENT_CERTIFICATE_BASE64, + ) + .await + } +} + +fn process_pem(pem: &mut dyn io::BufRead) -> ZResult> { + let certs: Vec = rustls_pemfile::certs(pem) + .map_err(|err| zerror!("Error processing PEM certificates: {err}."))? + .into_iter() + .map(CertificateDer::from) + .collect(); + + let trust_anchors: Vec = certs + .into_iter() + .map(|cert| { + anchor_from_trusted_cert(&cert) + .map_err(|err| zerror!("Error processing trust anchor: {err}.")) + .map(|trust_anchor| trust_anchor.to_owned()) + }) + .collect::, ZError>>()? + .into_iter() + .map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject.to_vec(), + ta.subject_public_key_info.to_vec(), + ta.name_constraints.map(|nc| nc.to_vec()), + ) + }) + .collect(); + + Ok(trust_anchors) +} + +async fn load_tls_key( + config: &Config<'_>, + tls_private_key_raw_config_key: &str, + tls_private_key_file_config_key: &str, + tls_private_key_base64_config_key: &str, +) -> ZResult> { + if let Some(value) = config.get(tls_private_key_raw_config_key) { + return Ok(value.as_bytes().to_vec()); + } + + if let Some(b64_key) = config.get(tls_private_key_base64_config_key) { + return base64_decode(b64_key); + } + + if let Some(value) = config.get(tls_private_key_file_config_key) { + return Ok(tokio::fs::read(value) + .await + .map_err(|e| zerror!("Invalid TLS private key file: {}", e))?) 
+ .and_then(|result| { + if result.is_empty() { + Err(zerror!("Empty TLS key.").into()) + } else { + Ok(result) + } + }); + } + Err(zerror!("Missing TLS private key.").into()) +} + +async fn load_tls_certificate( + config: &Config<'_>, + tls_certificate_raw_config_key: &str, + tls_certificate_file_config_key: &str, + tls_certificate_base64_config_key: &str, +) -> ZResult> { + if let Some(value) = config.get(tls_certificate_raw_config_key) { + return Ok(value.as_bytes().to_vec()); + } + + if let Some(b64_certificate) = config.get(tls_certificate_base64_config_key) { + return base64_decode(b64_certificate); + } + + if let Some(value) = config.get(tls_certificate_file_config_key) { + return Ok(tokio::fs::read(value) + .await + .map_err(|e| zerror!("Invalid TLS certificate file: {}", e))?); + } + Err(zerror!("Missing tls certificates.").into()) +} + +fn load_trust_anchors(config: &Config<'_>) -> ZResult> { + let mut root_cert_store = RootCertStore::empty(); + if let Some(value) = config.get(TLS_ROOT_CA_CERTIFICATE_RAW) { + let mut pem = BufReader::new(value.as_bytes()); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.roots.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + + if let Some(b64_certificate) = config.get(TLS_ROOT_CA_CERTIFICATE_BASE64) { + let certificate_pem = base64_decode(b64_certificate)?; + let mut pem = BufReader::new(certificate_pem.as_slice()); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.roots.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + + if let Some(filename) = config.get(TLS_ROOT_CA_CERTIFICATE_FILE) { + let mut pem = BufReader::new(File::open(filename)?); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.roots.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + Ok(None) +} + +pub async fn get_quic_addr(address: &Address<'_>) -> ZResult { + match tokio::net::lookup_host(address.as_str()).await?.next() { + Some(addr) => Ok(addr), + None => bail!("Couldn't resolve QUIC locator address: {}", address), + } +} + +pub fn base64_decode(data: &str) -> ZResult> { + use base64::engine::general_purpose; + use base64::Engine; + Ok(general_purpose::STANDARD + .decode(data) + .map_err(|e| zerror!("Unable to perform base64 decoding: {e:?}"))?) +} diff --git a/io/zenoh-links/zenoh-link-tls/Cargo.toml b/io/zenoh-links/zenoh-link-tls/Cargo.toml index 0ea59f2753..3025e3d7d7 100644 --- a/io/zenoh-links/zenoh-link-tls/Cargo.toml +++ b/io/zenoh-links/zenoh-link-tls/Cargo.toml @@ -12,31 +12,31 @@ # ZettaScale Zenoh Team, # [package] -rust-version = { workspace = true } -name = "zenoh-link-tls" -version = { workspace = true } -repository = { workspace = true } -homepage = { workspace = true } authors = { workspace = true } -edition = { workspace = true } -license = { workspace = true } categories = { workspace = true } description = "Internal crate for zenoh." 
+edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +name = "zenoh-link-tls" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] async-trait = { workspace = true } base64 = { workspace = true } futures = { workspace = true } -tracing = {workspace = true} rustls = { workspace = true } rustls-pemfile = { workspace = true } rustls-pki-types = { workspace = true } rustls-webpki = { workspace = true } -secrecy = {workspace = true } -tokio = { workspace = true, features = ["io-util", "net", "fs", "sync"] } +secrecy = { workspace = true } +tokio = { workspace = true, features = ["fs", "io-util", "net", "sync"] } tokio-rustls = { workspace = true } tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } webpki-roots = { workspace = true } zenoh-collections = { workspace = true } zenoh-config = { workspace = true } diff --git a/io/zenoh-links/zenoh-link-tls/src/lib.rs b/io/zenoh-links/zenoh-link-tls/src/lib.rs index f6a7968326..9fe6a3ea14 100644 --- a/io/zenoh-links/zenoh-link-tls/src/lib.rs +++ b/io/zenoh-links/zenoh-link-tls/src/lib.rs @@ -18,26 +18,15 @@ //! //! [Click here for Zenoh's documentation](../zenoh/index.html) use async_trait::async_trait; -use config::{ - TLS_CLIENT_AUTH, TLS_CLIENT_CERTIFICATE_BASE64, TLS_CLIENT_CERTIFICATE_FILE, - TLS_CLIENT_PRIVATE_KEY_BASE64, TLS_CLIENT_PRIVATE_KEY_FILE, TLS_ROOT_CA_CERTIFICATE_BASE64, - TLS_ROOT_CA_CERTIFICATE_FILE, TLS_SERVER_CERTIFICATE_BASE64, TLS_SERVER_CERTIFICATE_FILE, - TLS_SERVER_NAME_VERIFICATION, TLS_SERVER_PRIVATE_KEY_BASE_64, TLS_SERVER_PRIVATE_KEY_FILE, -}; -use rustls_pki_types::ServerName; -use secrecy::ExposeSecret; -use std::{convert::TryFrom, net::SocketAddr}; -use zenoh_config::Config; use zenoh_core::zconfigurable; -use zenoh_link_commons::{ConfigurationInspector, LocatorInspector}; -use zenoh_protocol::{ - core::{endpoint::Address, Locator, Parameters}, - transport::BatchSize, -}; -use zenoh_result::{bail, zerror, ZResult}; +use zenoh_link_commons::LocatorInspector; +use zenoh_protocol::{core::Locator, transport::BatchSize}; +use zenoh_result::ZResult; mod unicast; +mod utils; pub use unicast::*; +pub use utils::TlsConfigurator; // Default MTU (TLS PDU) in bytes. 
// NOTE: Since TLS is a byte-stream oriented transport, theoretically it has @@ -60,114 +49,6 @@ impl LocatorInspector for TlsLocatorInspector { Ok(false) } } -#[derive(Default, Clone, Copy, Debug)] -pub struct TlsConfigurator; - -impl ConfigurationInspector for TlsConfigurator { - fn inspect_config(&self, config: &Config) -> ZResult { - let mut ps: Vec<(&str, &str)> = vec![]; - - let c = config.transport().link().tls(); - - match (c.root_ca_certificate(), c.root_ca_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'root_ca_certificate' and 'root_ca_certificate_base64' can be present!") - } - (Some(ca_certificate), None) => { - ps.push((TLS_ROOT_CA_CERTIFICATE_FILE, ca_certificate)); - } - (None, Some(ca_certificate)) => { - ps.push(( - TLS_ROOT_CA_CERTIFICATE_BASE64, - ca_certificate.expose_secret(), - )); - } - _ => {} - } - - match (c.server_private_key(), c.server_private_key_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'server_private_key' and 'server_private_key_base64' can be present!") - } - (Some(server_private_key), None) => { - ps.push((TLS_SERVER_PRIVATE_KEY_FILE, server_private_key)); - } - (None, Some(server_private_key)) => { - ps.push(( - TLS_SERVER_PRIVATE_KEY_BASE_64, - server_private_key.expose_secret(), - )); - } - _ => {} - } - - match (c.server_certificate(), c.server_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'server_certificate' and 'server_certificate_base64' can be present!") - } - (Some(server_certificate), None) => { - ps.push((TLS_SERVER_CERTIFICATE_FILE, server_certificate)); - } - (None, Some(server_certificate)) => { - ps.push(( - TLS_SERVER_CERTIFICATE_BASE64, - server_certificate.expose_secret(), - )); - } - _ => {} - } - - if let Some(client_auth) = c.client_auth() { - match client_auth { - true => ps.push((TLS_CLIENT_AUTH, "true")), - false => ps.push((TLS_CLIENT_AUTH, "false")), - }; - } - - match (c.client_private_key(), c.client_private_key_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'client_private_key' and 'client_private_key_base64' can be present!") - } - (Some(client_private_key), None) => { - ps.push((TLS_CLIENT_PRIVATE_KEY_FILE, client_private_key)); - } - (None, Some(client_private_key)) => { - ps.push(( - TLS_CLIENT_PRIVATE_KEY_BASE64, - client_private_key.expose_secret(), - )); - } - _ => {} - } - - match (c.client_certificate(), c.client_certificate_base64()) { - (Some(_), Some(_)) => { - bail!("Only one between 'client_certificate' and 'client_certificate_base64' can be present!") - } - (Some(client_certificate), None) => { - ps.push((TLS_CLIENT_CERTIFICATE_FILE, client_certificate)); - } - (None, Some(client_certificate)) => { - ps.push(( - TLS_CLIENT_CERTIFICATE_BASE64, - client_certificate.expose_secret(), - )); - } - _ => {} - } - - if let Some(server_name_verification) = c.server_name_verification() { - match server_name_verification { - true => ps.push((TLS_SERVER_NAME_VERIFICATION, "true")), - false => ps.push((TLS_SERVER_NAME_VERIFICATION, "false")), - }; - } - - let s = Parameters::from_iter(ps.drain(..)); - - Ok(s) - } -} zconfigurable! { // Default MTU (TLS PDU) in bytes. 
@@ -207,30 +88,3 @@ pub mod config { pub const TLS_SERVER_NAME_VERIFICATION: &str = "server_name_verification"; } - -pub async fn get_tls_addr(address: &Address<'_>) -> ZResult { - match tokio::net::lookup_host(address.as_str()).await?.next() { - Some(addr) => Ok(addr), - None => bail!("Couldn't resolve TLS locator address: {}", address), - } -} - -pub fn get_tls_host<'a>(address: &'a Address<'a>) -> ZResult<&'a str> { - address - .as_str() - .split(':') - .next() - .ok_or_else(|| zerror!("Invalid TLS address").into()) -} - -pub fn get_tls_server_name<'a>(address: &'a Address<'a>) -> ZResult> { - Ok(ServerName::try_from(get_tls_host(address)?).map_err(|e| zerror!(e))?) -} - -pub fn base64_decode(data: &str) -> ZResult> { - use base64::engine::general_purpose; - use base64::Engine; - Ok(general_purpose::STANDARD - .decode(data) - .map_err(|e| zerror!("Unable to perform base64 decoding: {e:?}"))?) -} diff --git a/io/zenoh-links/zenoh-link-tls/src/unicast.rs b/io/zenoh-links/zenoh-link-tls/src/unicast.rs index 53429ca30f..5cf686cdc5 100644 --- a/io/zenoh-links/zenoh-link-tls/src/unicast.rs +++ b/io/zenoh-links/zenoh-link-tls/src/unicast.rs @@ -12,39 +12,30 @@ // ZettaScale Zenoh Team, // use crate::{ - base64_decode, config::*, get_tls_addr, get_tls_host, get_tls_server_name, + utils::{get_tls_addr, get_tls_host, get_tls_server_name, TlsClientConfig, TlsServerConfig}, TLS_ACCEPT_THROTTLE_TIME, TLS_DEFAULT_MTU, TLS_LINGER_TIMEOUT, TLS_LOCATOR_PREFIX, }; + use async_trait::async_trait; -use rustls::{ - pki_types::{CertificateDer, PrivateKeyDer, TrustAnchor}, - server::WebPkiClientVerifier, - version::TLS13, - ClientConfig, RootCertStore, ServerConfig, -}; +use std::cell::UnsafeCell; use std::convert::TryInto; use std::fmt; -use std::fs::File; -use std::io::{BufReader, Cursor}; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use std::{cell::UnsafeCell, io}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::Mutex as AsyncMutex; use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream}; use tokio_util::sync::CancellationToken; -use webpki::anchor_from_trusted_cert; use zenoh_core::zasynclock; -use zenoh_link_commons::tls::WebPkiVerifierAnyServerName; use zenoh_link_commons::{ get_ip_interface_names, LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait, ListenersUnicastIP, NewLinkChannelSender, }; use zenoh_protocol::core::{EndPoint, Locator}; -use zenoh_protocol::{core::endpoint::Config, transport::BatchSize}; -use zenoh_result::{bail, zerror, ZError, ZResult}; +use zenoh_protocol::transport::BatchSize; +use zenoh_result::{zerror, ZResult}; pub struct LinkUnicastTls { // The underlying socket as returned from the async-rustls library @@ -418,311 +409,3 @@ async fn accept_task( Ok(()) } - -struct TlsServerConfig { - server_config: ServerConfig, -} - -impl TlsServerConfig { - pub async fn new(config: &Config<'_>) -> ZResult { - let tls_server_client_auth: bool = match config.get(TLS_CLIENT_AUTH) { - Some(s) => s - .parse() - .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, - None => false, - }; - let tls_server_private_key = TlsServerConfig::load_tls_private_key(config).await?; - let tls_server_certificate = TlsServerConfig::load_tls_certificate(config).await?; - - let certs: Vec = - rustls_pemfile::certs(&mut Cursor::new(&tls_server_certificate)) - .collect::>() - .map_err(|err| zerror!("Error processing server certificate: {err}."))?; - - let mut keys: Vec = - rustls_pemfile::rsa_private_keys(&mut 
Cursor::new(&tls_server_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing server key: {err}."))?; - - if keys.is_empty() { - keys = rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_server_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing server key: {err}."))?; - } - - if keys.is_empty() { - keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_server_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing server key: {err}."))?; - } - - if keys.is_empty() { - bail!("No private key found for TLS server."); - } - - let sc = if tls_server_client_auth { - let root_cert_store = load_trust_anchors(config)?.map_or_else( - || { - Err(zerror!( - "Missing root certificates while client authentication is enabled." - )) - }, - Ok, - )?; - let client_auth = WebPkiClientVerifier::builder(root_cert_store.into()).build()?; - ServerConfig::builder_with_protocol_versions(&[&TLS13]) - .with_client_cert_verifier(client_auth) - .with_single_cert(certs, keys.remove(0)) - .map_err(|e| zerror!(e))? - } else { - ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, keys.remove(0)) - .map_err(|e| zerror!(e))? - }; - Ok(TlsServerConfig { server_config: sc }) - } - - async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { - load_tls_key( - config, - TLS_SERVER_PRIVATE_KEY_RAW, - TLS_SERVER_PRIVATE_KEY_FILE, - TLS_SERVER_PRIVATE_KEY_BASE_64, - ) - .await - } - - async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { - load_tls_certificate( - config, - TLS_SERVER_CERTIFICATE_RAW, - TLS_SERVER_CERTIFICATE_FILE, - TLS_SERVER_CERTIFICATE_BASE64, - ) - .await - } -} - -struct TlsClientConfig { - client_config: ClientConfig, -} - -impl TlsClientConfig { - pub async fn new(config: &Config<'_>) -> ZResult { - let tls_client_server_auth: bool = match config.get(TLS_CLIENT_AUTH) { - Some(s) => s - .parse() - .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, - None => false, - }; - - let tls_server_name_verification: bool = match config.get(TLS_SERVER_NAME_VERIFICATION) { - Some(s) => { - let s: bool = s - .parse() - .map_err(|_| zerror!("Unknown server name verification argument: {}", s))?; - if s { - tracing::warn!("Skipping name verification of servers"); - } - s - } - None => false, - }; - - // Allows mixed user-generated CA and webPKI CA - tracing::debug!("Loading default Web PKI certificates."); - let mut root_cert_store = RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), - }; - - if let Some(custom_root_cert) = load_trust_anchors(config)? 
{ - tracing::debug!("Loading user-generated certificates."); - root_cert_store.extend(custom_root_cert.roots); - } - - let cc = if tls_client_server_auth { - tracing::debug!("Loading client authentication key and certificate..."); - let tls_client_private_key = TlsClientConfig::load_tls_private_key(config).await?; - let tls_client_certificate = TlsClientConfig::load_tls_certificate(config).await?; - - let certs: Vec = - rustls_pemfile::certs(&mut Cursor::new(&tls_client_certificate)) - .collect::>() - .map_err(|err| zerror!("Error processing client certificate: {err}."))?; - - let mut keys: Vec = - rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_client_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing client key: {err}."))?; - - if keys.is_empty() { - keys = - rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_client_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing client key: {err}."))?; - } - - if keys.is_empty() { - keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_client_private_key)) - .map(|x| x.map(PrivateKeyDer::from)) - .collect::>() - .map_err(|err| zerror!("Error processing client key: {err}."))?; - } - - if keys.is_empty() { - bail!("No private key found for TLS client."); - } - - let builder = ClientConfig::builder_with_protocol_versions(&[&TLS13]); - - if tls_server_name_verification { - builder - .with_root_certificates(root_cert_store) - .with_client_auth_cert(certs, keys.remove(0)) - } else { - builder - .dangerous() - .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( - root_cert_store, - ))) - .with_client_auth_cert(certs, keys.remove(0)) - } - .map_err(|e| zerror!("Bad certificate/key: {}", e))? - } else { - let builder = ClientConfig::builder(); - if tls_server_name_verification { - builder - .with_root_certificates(root_cert_store) - .with_no_client_auth() - } else { - builder - .dangerous() - .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( - root_cert_store, - ))) - .with_no_client_auth() - } - }; - Ok(TlsClientConfig { client_config: cc }) - } - - async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { - load_tls_key( - config, - TLS_CLIENT_PRIVATE_KEY_RAW, - TLS_CLIENT_PRIVATE_KEY_FILE, - TLS_CLIENT_PRIVATE_KEY_BASE64, - ) - .await - } - - async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { - load_tls_certificate( - config, - TLS_CLIENT_CERTIFICATE_RAW, - TLS_CLIENT_CERTIFICATE_FILE, - TLS_CLIENT_CERTIFICATE_BASE64, - ) - .await - } -} - -async fn load_tls_key( - config: &Config<'_>, - tls_private_key_raw_config_key: &str, - tls_private_key_file_config_key: &str, - tls_private_key_base64_config_key: &str, -) -> ZResult> { - if let Some(value) = config.get(tls_private_key_raw_config_key) { - return Ok(value.as_bytes().to_vec()); - } - - if let Some(b64_key) = config.get(tls_private_key_base64_config_key) { - return base64_decode(b64_key); - } - - if let Some(value) = config.get(tls_private_key_file_config_key) { - return Ok(tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid TLS private key file: {}", e))?) 
- .and_then(|result| { - if result.is_empty() { - Err(zerror!("Empty TLS key.").into()) - } else { - Ok(result) - } - }); - } - Err(zerror!("Missing TLS private key.").into()) -} - -async fn load_tls_certificate( - config: &Config<'_>, - tls_certificate_raw_config_key: &str, - tls_certificate_file_config_key: &str, - tls_certificate_base64_config_key: &str, -) -> ZResult> { - if let Some(value) = config.get(tls_certificate_raw_config_key) { - return Ok(value.as_bytes().to_vec()); - } - - if let Some(b64_certificate) = config.get(tls_certificate_base64_config_key) { - return base64_decode(b64_certificate); - } - - if let Some(value) = config.get(tls_certificate_file_config_key) { - return Ok(tokio::fs::read(value) - .await - .map_err(|e| zerror!("Invalid TLS certificate file: {}", e))?); - } - Err(zerror!("Missing tls certificates.").into()) -} - -fn load_trust_anchors(config: &Config<'_>) -> ZResult> { - let mut root_cert_store = RootCertStore::empty(); - if let Some(value) = config.get(TLS_ROOT_CA_CERTIFICATE_RAW) { - let mut pem = BufReader::new(value.as_bytes()); - let trust_anchors = process_pem(&mut pem)?; - root_cert_store.extend(trust_anchors); - return Ok(Some(root_cert_store)); - } - - if let Some(b64_certificate) = config.get(TLS_ROOT_CA_CERTIFICATE_BASE64) { - let certificate_pem = base64_decode(b64_certificate)?; - let mut pem = BufReader::new(certificate_pem.as_slice()); - let trust_anchors = process_pem(&mut pem)?; - root_cert_store.extend(trust_anchors); - return Ok(Some(root_cert_store)); - } - - if let Some(filename) = config.get(TLS_ROOT_CA_CERTIFICATE_FILE) { - let mut pem = BufReader::new(File::open(filename)?); - let trust_anchors = process_pem(&mut pem)?; - root_cert_store.extend(trust_anchors); - return Ok(Some(root_cert_store)); - } - Ok(None) -} - -fn process_pem(pem: &mut dyn io::BufRead) -> ZResult>> { - let certs: Vec = rustls_pemfile::certs(pem) - .map(|result| result.map_err(|err| zerror!("Error processing PEM certificates: {err}."))) - .collect::, ZError>>()?; - - let trust_anchors: Vec = certs - .into_iter() - .map(|cert| { - anchor_from_trusted_cert(&cert) - .map_err(|err| zerror!("Error processing trust anchor: {err}.")) - .map(|trust_anchor| trust_anchor.to_owned()) - }) - .collect::, ZError>>()?; - - Ok(trust_anchors) -} diff --git a/io/zenoh-links/zenoh-link-tls/src/utils.rs b/io/zenoh-links/zenoh-link-tls/src/utils.rs new file mode 100644 index 0000000000..d51a17c694 --- /dev/null +++ b/io/zenoh-links/zenoh-link-tls/src/utils.rs @@ -0,0 +1,477 @@ +// +// Copyright (c) 2024 ZettaScale Technology +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. 
+// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// +use crate::config::*; +use rustls::{ + pki_types::{CertificateDer, PrivateKeyDer, TrustAnchor}, + server::WebPkiClientVerifier, + version::TLS13, + ClientConfig, RootCertStore, ServerConfig, +}; +use rustls_pki_types::ServerName; +use secrecy::ExposeSecret; +use std::fs::File; +use std::io; +use std::{convert::TryFrom, net::SocketAddr}; +use std::{ + io::{BufReader, Cursor}, + sync::Arc, +}; +use webpki::anchor_from_trusted_cert; +use zenoh_config::Config as ZenohConfig; +use zenoh_link_commons::{tls::WebPkiVerifierAnyServerName, ConfigurationInspector}; +use zenoh_protocol::core::endpoint::{Address, Config}; +use zenoh_protocol::core::Parameters; +use zenoh_result::{bail, zerror, ZError, ZResult}; + +#[derive(Default, Clone, Copy, Debug)] +pub struct TlsConfigurator; + +impl ConfigurationInspector for TlsConfigurator { + fn inspect_config(&self, config: &ZenohConfig) -> ZResult { + let mut ps: Vec<(&str, &str)> = vec![]; + + let c = config.transport().link().tls(); + + match (c.root_ca_certificate(), c.root_ca_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'root_ca_certificate' and 'root_ca_certificate_base64' can be present!") + } + (Some(ca_certificate), None) => { + ps.push((TLS_ROOT_CA_CERTIFICATE_FILE, ca_certificate)); + } + (None, Some(ca_certificate)) => { + ps.push(( + TLS_ROOT_CA_CERTIFICATE_BASE64, + ca_certificate.expose_secret(), + )); + } + _ => {} + } + + match (c.server_private_key(), c.server_private_key_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'server_private_key' and 'server_private_key_base64' can be present!") + } + (Some(server_private_key), None) => { + ps.push((TLS_SERVER_PRIVATE_KEY_FILE, server_private_key)); + } + (None, Some(server_private_key)) => { + ps.push(( + TLS_SERVER_PRIVATE_KEY_BASE_64, + server_private_key.expose_secret(), + )); + } + _ => {} + } + + match (c.server_certificate(), c.server_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'server_certificate' and 'server_certificate_base64' can be present!") + } + (Some(server_certificate), None) => { + ps.push((TLS_SERVER_CERTIFICATE_FILE, server_certificate)); + } + (None, Some(server_certificate)) => { + ps.push(( + TLS_SERVER_CERTIFICATE_BASE64, + server_certificate.expose_secret(), + )); + } + _ => {} + } + + if let Some(client_auth) = c.client_auth() { + match client_auth { + true => ps.push((TLS_CLIENT_AUTH, "true")), + false => ps.push((TLS_CLIENT_AUTH, "false")), + }; + } + + match (c.client_private_key(), c.client_private_key_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'client_private_key' and 'client_private_key_base64' can be present!") + } + (Some(client_private_key), None) => { + ps.push((TLS_CLIENT_PRIVATE_KEY_FILE, client_private_key)); + } + (None, Some(client_private_key)) => { + ps.push(( + TLS_CLIENT_PRIVATE_KEY_BASE64, + client_private_key.expose_secret(), + )); + } + _ => {} + } + + match (c.client_certificate(), c.client_certificate_base64()) { + (Some(_), Some(_)) => { + bail!("Only one between 'client_certificate' and 'client_certificate_base64' can be present!") + } + (Some(client_certificate), None) => { + ps.push((TLS_CLIENT_CERTIFICATE_FILE, client_certificate)); + } + (None, Some(client_certificate)) => { + ps.push(( + TLS_CLIENT_CERTIFICATE_BASE64, + client_certificate.expose_secret(), + )); + } + _ => {} + } + + if let Some(server_name_verification) = 
c.server_name_verification() { + match server_name_verification { + true => ps.push((TLS_SERVER_NAME_VERIFICATION, "true")), + false => ps.push((TLS_SERVER_NAME_VERIFICATION, "false")), + }; + } + + Ok(Parameters::from_iter(ps.drain(..))) + } +} + +pub(crate) struct TlsServerConfig { + pub(crate) server_config: ServerConfig, +} + +impl TlsServerConfig { + pub async fn new(config: &Config<'_>) -> ZResult { + let tls_server_client_auth: bool = match config.get(TLS_CLIENT_AUTH) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, + None => false, + }; + let tls_server_private_key = TlsServerConfig::load_tls_private_key(config).await?; + let tls_server_certificate = TlsServerConfig::load_tls_certificate(config).await?; + + let certs: Vec = + rustls_pemfile::certs(&mut Cursor::new(&tls_server_certificate)) + .collect::>() + .map_err(|err| zerror!("Error processing server certificate: {err}."))?; + + let mut keys: Vec = + rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing server key: {err}."))?; + + if keys.is_empty() { + keys = rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing server key: {err}."))?; + } + + if keys.is_empty() { + keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_server_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing server key: {err}."))?; + } + + if keys.is_empty() { + bail!("No private key found for TLS server."); + } + + let sc = if tls_server_client_auth { + let root_cert_store = load_trust_anchors(config)?.map_or_else( + || { + Err(zerror!( + "Missing root certificates while client authentication is enabled." + )) + }, + Ok, + )?; + let client_auth = WebPkiClientVerifier::builder(root_cert_store.into()).build()?; + ServerConfig::builder_with_protocol_versions(&[&TLS13]) + .with_client_cert_verifier(client_auth) + .with_single_cert(certs, keys.remove(0)) + .map_err(|e| zerror!(e))? + } else { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, keys.remove(0)) + .map_err(|e| zerror!(e))? 
+ }; + Ok(TlsServerConfig { server_config: sc }) + } + + async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { + load_tls_key( + config, + TLS_SERVER_PRIVATE_KEY_RAW, + TLS_SERVER_PRIVATE_KEY_FILE, + TLS_SERVER_PRIVATE_KEY_BASE_64, + ) + .await + } + + async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { + load_tls_certificate( + config, + TLS_SERVER_CERTIFICATE_RAW, + TLS_SERVER_CERTIFICATE_FILE, + TLS_SERVER_CERTIFICATE_BASE64, + ) + .await + } +} + +pub(crate) struct TlsClientConfig { + pub(crate) client_config: ClientConfig, +} + +impl TlsClientConfig { + pub async fn new(config: &Config<'_>) -> ZResult { + let tls_client_server_auth: bool = match config.get(TLS_CLIENT_AUTH) { + Some(s) => s + .parse() + .map_err(|_| zerror!("Unknown client auth argument: {}", s))?, + None => false, + }; + + let tls_server_name_verification: bool = match config.get(TLS_SERVER_NAME_VERIFICATION) { + Some(s) => { + let s: bool = s + .parse() + .map_err(|_| zerror!("Unknown server name verification argument: {}", s))?; + if s { + tracing::warn!("Skipping name verification of servers"); + } + s + } + None => false, + }; + + // Allows mixed user-generated CA and webPKI CA + tracing::debug!("Loading default Web PKI certificates."); + let mut root_cert_store = RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + }; + + if let Some(custom_root_cert) = load_trust_anchors(config)? { + tracing::debug!("Loading user-generated certificates."); + root_cert_store.extend(custom_root_cert.roots); + } + + let cc = if tls_client_server_auth { + tracing::debug!("Loading client authentication key and certificate..."); + let tls_client_private_key = TlsClientConfig::load_tls_private_key(config).await?; + let tls_client_certificate = TlsClientConfig::load_tls_certificate(config).await?; + + let certs: Vec = + rustls_pemfile::certs(&mut Cursor::new(&tls_client_certificate)) + .collect::>() + .map_err(|err| zerror!("Error processing client certificate: {err}."))?; + + let mut keys: Vec = + rustls_pemfile::rsa_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing client key: {err}."))?; + + if keys.is_empty() { + keys = + rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing client key: {err}."))?; + } + + if keys.is_empty() { + keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(&tls_client_private_key)) + .map(|x| x.map(PrivateKeyDer::from)) + .collect::>() + .map_err(|err| zerror!("Error processing client key: {err}."))?; + } + + if keys.is_empty() { + bail!("No private key found for TLS client."); + } + + let builder = ClientConfig::builder_with_protocol_versions(&[&TLS13]); + + if tls_server_name_verification { + builder + .with_root_certificates(root_cert_store) + .with_client_auth_cert(certs, keys.remove(0)) + } else { + builder + .dangerous() + .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( + root_cert_store, + ))) + .with_client_auth_cert(certs, keys.remove(0)) + } + .map_err(|e| zerror!("Bad certificate/key: {}", e))? 
+ } else { + let builder = ClientConfig::builder(); + if tls_server_name_verification { + builder + .with_root_certificates(root_cert_store) + .with_no_client_auth() + } else { + builder + .dangerous() + .with_custom_certificate_verifier(Arc::new(WebPkiVerifierAnyServerName::new( + root_cert_store, + ))) + .with_no_client_auth() + } + }; + Ok(TlsClientConfig { client_config: cc }) + } + + async fn load_tls_private_key(config: &Config<'_>) -> ZResult> { + load_tls_key( + config, + TLS_CLIENT_PRIVATE_KEY_RAW, + TLS_CLIENT_PRIVATE_KEY_FILE, + TLS_CLIENT_PRIVATE_KEY_BASE64, + ) + .await + } + + async fn load_tls_certificate(config: &Config<'_>) -> ZResult> { + load_tls_certificate( + config, + TLS_CLIENT_CERTIFICATE_RAW, + TLS_CLIENT_CERTIFICATE_FILE, + TLS_CLIENT_CERTIFICATE_BASE64, + ) + .await + } +} + +fn process_pem(pem: &mut dyn io::BufRead) -> ZResult>> { + let certs: Vec = rustls_pemfile::certs(pem) + .map(|result| result.map_err(|err| zerror!("Error processing PEM certificates: {err}."))) + .collect::, ZError>>()?; + + let trust_anchors: Vec = certs + .into_iter() + .map(|cert| { + anchor_from_trusted_cert(&cert) + .map_err(|err| zerror!("Error processing trust anchor: {err}.")) + .map(|trust_anchor| trust_anchor.to_owned()) + }) + .collect::, ZError>>()?; + + Ok(trust_anchors) +} + +async fn load_tls_key( + config: &Config<'_>, + tls_private_key_raw_config_key: &str, + tls_private_key_file_config_key: &str, + tls_private_key_base64_config_key: &str, +) -> ZResult> { + if let Some(value) = config.get(tls_private_key_raw_config_key) { + return Ok(value.as_bytes().to_vec()); + } + + if let Some(b64_key) = config.get(tls_private_key_base64_config_key) { + return base64_decode(b64_key); + } + + if let Some(value) = config.get(tls_private_key_file_config_key) { + return Ok(tokio::fs::read(value) + .await + .map_err(|e| zerror!("Invalid TLS private key file: {}", e))?) 
+ .and_then(|result| { + if result.is_empty() { + Err(zerror!("Empty TLS key.").into()) + } else { + Ok(result) + } + }); + } + Err(zerror!("Missing TLS private key.").into()) +} + +async fn load_tls_certificate( + config: &Config<'_>, + tls_certificate_raw_config_key: &str, + tls_certificate_file_config_key: &str, + tls_certificate_base64_config_key: &str, +) -> ZResult> { + if let Some(value) = config.get(tls_certificate_raw_config_key) { + return Ok(value.as_bytes().to_vec()); + } + + if let Some(b64_certificate) = config.get(tls_certificate_base64_config_key) { + return base64_decode(b64_certificate); + } + + if let Some(value) = config.get(tls_certificate_file_config_key) { + return Ok(tokio::fs::read(value) + .await + .map_err(|e| zerror!("Invalid TLS certificate file: {}", e))?); + } + Err(zerror!("Missing tls certificates.").into()) +} + +fn load_trust_anchors(config: &Config<'_>) -> ZResult> { + let mut root_cert_store = RootCertStore::empty(); + if let Some(value) = config.get(TLS_ROOT_CA_CERTIFICATE_RAW) { + let mut pem = BufReader::new(value.as_bytes()); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + + if let Some(b64_certificate) = config.get(TLS_ROOT_CA_CERTIFICATE_BASE64) { + let certificate_pem = base64_decode(b64_certificate)?; + let mut pem = BufReader::new(certificate_pem.as_slice()); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + + if let Some(filename) = config.get(TLS_ROOT_CA_CERTIFICATE_FILE) { + let mut pem = BufReader::new(File::open(filename)?); + let trust_anchors = process_pem(&mut pem)?; + root_cert_store.extend(trust_anchors); + return Ok(Some(root_cert_store)); + } + Ok(None) +} + +pub fn base64_decode(data: &str) -> ZResult> { + use base64::engine::general_purpose; + use base64::Engine; + Ok(general_purpose::STANDARD + .decode(data) + .map_err(|e| zerror!("Unable to perform base64 decoding: {e:?}"))?) +} + +pub async fn get_tls_addr(address: &Address<'_>) -> ZResult { + match tokio::net::lookup_host(address.as_str()).await?.next() { + Some(addr) => Ok(addr), + None => bail!("Couldn't resolve TLS locator address: {}", address), + } +} + +pub fn get_tls_host<'a>(address: &'a Address<'a>) -> ZResult<&'a str> { + address + .as_str() + .split(':') + .next() + .ok_or_else(|| zerror!("Invalid TLS address").into()) +} + +pub fn get_tls_server_name<'a>(address: &'a Address<'a>) -> ZResult> { + Ok(ServerName::try_from(get_tls_host(address)?).map_err(|e| zerror!(e))?) +} diff --git a/io/zenoh-transport/Cargo.toml b/io/zenoh-transport/Cargo.toml index 7efaabb719..c1a2c9b8ae 100644 --- a/io/zenoh-transport/Cargo.toml +++ b/io/zenoh-transport/Cargo.toml @@ -93,3 +93,4 @@ futures-util = { workspace = true } zenoh-util = {workspace = true } zenoh-protocol = { workspace = true, features = ["test"] } futures = { workspace = true } +zenoh-link-commons = { workspace = true } diff --git a/io/zenoh-transport/tests/unicast_transport.rs b/io/zenoh-transport/tests/unicast_transport.rs index e3aca1c7b7..4b833bc5e7 100644 --- a/io/zenoh-transport/tests/unicast_transport.rs +++ b/io/zenoh-transport/tests/unicast_transport.rs @@ -69,7 +69,10 @@ use zenoh_transport::{ // the key and certificate brought in by the client. Similarly the server's certificate authority // will validate the key and certificate brought in by the server in front of the client. 
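Editor's note: the mutual-authentication tests below feed these inline PEM blocks into the loaders defined in the new `utils.rs` above. As a point of reference, here is a minimal sketch of that parsing cascade (certificates first, then RSA/PKCS#1, PKCS#8 and SEC1 keys tried in turn), assuming rustls 0.22 and the rustls-pemfile 2.x API that this patch depends on; the `parse_pem` helper itself is hypothetical and not part of the change:

```rust
use std::io::Cursor;

use rustls::pki_types::{CertificateDer, PrivateKeyDer};

// Hypothetical helper mirroring the loading cascade in TlsServerConfig::new /
// TlsClientConfig::new: collect the certificate chain, then try the private
// key as PKCS#1 (RSA), PKCS#8 and finally SEC1 (EC) until one parser yields a key.
fn parse_pem(
    cert_pem: &[u8],
    key_pem: &[u8],
) -> std::io::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
    let certs = rustls_pemfile::certs(&mut Cursor::new(cert_pem))
        .collect::<Result<Vec<_>, _>>()?;

    let mut keys: Vec<PrivateKeyDer<'static>> =
        rustls_pemfile::rsa_private_keys(&mut Cursor::new(key_pem))
            .map(|k| k.map(PrivateKeyDer::from))
            .collect::<Result<_, _>>()?;
    if keys.is_empty() {
        keys = rustls_pemfile::pkcs8_private_keys(&mut Cursor::new(key_pem))
            .map(|k| k.map(PrivateKeyDer::from))
            .collect::<Result<_, _>>()?;
    }
    if keys.is_empty() {
        keys = rustls_pemfile::ec_private_keys(&mut Cursor::new(key_pem))
            .map(|k| k.map(PrivateKeyDer::from))
            .collect::<Result<_, _>>()?;
    }

    // Take the first key found; the real loaders bail with a ZError here instead.
    let key = keys.into_iter().next().ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidData, "no private key found")
    })?;
    Ok((certs, key))
}
```

The in-tree loaders follow the same order and only report "No private key found" once all three key parsers come back empty.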
// -#[cfg(all(feature = "transport_tls", target_family = "unix"))] +#[cfg(all( + any(feature = "transport_tls", feature = "transport_quic"), + target_family = "unix" +))] const CLIENT_KEY: &str = "-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAsfqAuhElN4HnyeqLovSd4Qe+nNv5AwCjSO+HFiF30x3vQ1Hi qRA0UmyFlSqBnFH3TUHm4Jcad40QfrX8f11NKGZdpvKHsMYqYjZnYkRFGS2s4fQy @@ -98,7 +101,10 @@ tYsqC2FtWzY51VOEKNpnfH7zH5n+bjoI9nAEAW63TK9ZKkr2hRGsDhJdGzmLfQ7v F6/CuIw9EsAq6qIB8O88FXQqald+BZOx6AzB8Oedsz/WtMmIEmr/+Q== -----END RSA PRIVATE KEY-----"; -#[cfg(all(feature = "transport_tls", target_family = "unix"))] +#[cfg(all( + any(feature = "transport_tls", feature = "transport_quic"), + target_family = "unix" +))] const CLIENT_CERT: &str = "-----BEGIN CERTIFICATE----- MIIDLjCCAhagAwIBAgIIeUtmIdFQznMwDQYJKoZIhvcNAQELBQAwIDEeMBwGA1UE AxMVbWluaWNhIHJvb3QgY2EgMDc4ZGE3MCAXDTIzMDMwNjE2MDMxOFoYDzIxMjMw @@ -120,7 +126,10 @@ p5e60QweRuJsb60aUaCG8HoICevXYK2fFqCQdlb5sIqQqXyN2K6HuKAFywsjsGyJ abY= -----END CERTIFICATE-----"; -#[cfg(all(feature = "transport_tls", target_family = "unix"))] +#[cfg(all( + any(feature = "transport_tls", feature = "transport_quic"), + target_family = "unix" +))] const CLIENT_CA: &str = "-----BEGIN CERTIFICATE----- MIIDSzCCAjOgAwIBAgIIB42n1ZIkOakwDQYJKoZIhvcNAQELBQAwIDEeMBwGA1UE AxMVbWluaWNhIHJvb3QgY2EgMDc4ZGE3MCAXDTIzMDMwNjE2MDMwN1oYDzIxMjMw @@ -1289,6 +1298,221 @@ fn transport_unicast_tls_only_mutual_wrong_client_certs_failure() { assert!(result.is_err()); } +#[cfg(all(feature = "transport_quic", target_family = "unix"))] +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn transport_unicast_quic_only_mutual_success() { + use zenoh_link::quic::config::*; + + zenoh_util::try_init_log_from_env(); + + let client_auth = "true"; + + // Define the locator + let mut client_endpoint: EndPoint = ("quic/localhost:10461").parse().unwrap(); + client_endpoint + .config_mut() + .extend_from_iter( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, SERVER_CA), + (TLS_CLIENT_CERTIFICATE_RAW, CLIENT_CERT), + (TLS_CLIENT_PRIVATE_KEY_RAW, CLIENT_KEY), + (TLS_CLIENT_AUTH, client_auth), + ] + .iter() + .copied(), + ) + .unwrap(); + + // Define the locator + let mut server_endpoint: EndPoint = ("quic/localhost:10461").parse().unwrap(); + server_endpoint + .config_mut() + .extend_from_iter( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, CLIENT_CA), + (TLS_SERVER_CERTIFICATE_RAW, SERVER_CERT), + (TLS_SERVER_PRIVATE_KEY_RAW, SERVER_KEY), + (TLS_CLIENT_AUTH, client_auth), + ] + .iter() + .copied(), + ) + .unwrap(); + // Define the reliability and congestion control + let channel = [ + Channel { + priority: Priority::default(), + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::default(), + reliability: Reliability::BestEffort, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::BestEffort, + }, + ]; + // Run + let client_endpoints = vec![client_endpoint]; + let server_endpoints = vec![server_endpoint]; + run_with_universal_transport( + &client_endpoints, + &server_endpoints, + &channel, + &MSG_SIZE_ALL, + ) + .await; +} + +#[cfg(all(feature = "transport_quic", target_family = "unix"))] +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn transport_unicast_quic_only_mutual_no_client_certs_failure() { + use std::vec; + use zenoh_link::quic::config::*; + + zenoh_util::try_init_log_from_env(); + + // Define the locator + let mut client_endpoint: EndPoint = 
("quic/localhost:10462").parse().unwrap(); + client_endpoint + .config_mut() + .extend_from_iter([(TLS_ROOT_CA_CERTIFICATE_RAW, SERVER_CA)].iter().copied()) + .unwrap(); + + // Define the locator + let mut server_endpoint: EndPoint = ("quic/localhost:10462").parse().unwrap(); + server_endpoint + .config_mut() + .extend_from_iter( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, CLIENT_CA), + (TLS_SERVER_CERTIFICATE_RAW, SERVER_CERT), + (TLS_SERVER_PRIVATE_KEY_RAW, SERVER_KEY), + (TLS_CLIENT_AUTH, "true"), + ] + .iter() + .copied(), + ) + .unwrap(); + // Define the reliability and congestion control + let channel = [ + Channel { + priority: Priority::default(), + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::default(), + reliability: Reliability::BestEffort, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::BestEffort, + }, + ]; + // Run + let client_endpoints = vec![client_endpoint]; + let server_endpoints = vec![server_endpoint]; + let result = std::panic::catch_unwind(|| { + tokio::runtime::Runtime::new() + .unwrap() + .block_on(run_with_universal_transport( + &client_endpoints, + &server_endpoints, + &channel, + &MSG_SIZE_ALL, + )) + }); + assert!(result.is_err()); +} + +#[cfg(all(feature = "transport_quic", target_family = "unix"))] +#[test] +fn transport_unicast_quic_only_mutual_wrong_client_certs_failure() { + use zenoh_link::quic::config::*; + + zenoh_util::try_init_log_from_env(); + + let client_auth = "true"; + + // Define the locator + let mut client_endpoint: EndPoint = ("quic/localhost:10463").parse().unwrap(); + client_endpoint + .config_mut() + .extend_from_iter( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, SERVER_CA), + // Using the SERVER_CERT and SERVER_KEY in the client to simulate the case the client has + // wrong certificates and keys. The SERVER_CA (cetificate authority) will not recognize + // these certificates as it is expecting to receive CLIENT_CERT and CLIENT_KEY from the + // client. 
+ (TLS_CLIENT_CERTIFICATE_RAW, SERVER_CERT), + (TLS_CLIENT_PRIVATE_KEY_RAW, SERVER_KEY), + (TLS_CLIENT_AUTH, client_auth), + ] + .iter() + .copied(), + ) + .unwrap(); + + // Define the locator + let mut server_endpoint: EndPoint = ("quic/localhost:10463").parse().unwrap(); + server_endpoint + .config_mut() + .extend_from_iter( + [ + (TLS_ROOT_CA_CERTIFICATE_RAW, CLIENT_CA), + (TLS_SERVER_CERTIFICATE_RAW, SERVER_CERT), + (TLS_SERVER_PRIVATE_KEY_RAW, SERVER_KEY), + (TLS_CLIENT_AUTH, client_auth), + ] + .iter() + .copied(), + ) + .unwrap(); + // Define the reliability and congestion control + let channel = [ + Channel { + priority: Priority::default(), + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::default(), + reliability: Reliability::BestEffort, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::Reliable, + }, + Channel { + priority: Priority::RealTime, + reliability: Reliability::BestEffort, + }, + ]; + // Run + let client_endpoints = vec![client_endpoint]; + let server_endpoints = vec![server_endpoint]; + let result = std::panic::catch_unwind(|| { + tokio::runtime::Runtime::new() + .unwrap() + .block_on(run_with_universal_transport( + &client_endpoints, + &server_endpoints, + &channel, + &MSG_SIZE_ALL, + )) + }); + assert!(result.is_err()); +} + #[test] fn transport_unicast_qos_and_lowlatency_failure() { struct TestPeer; diff --git a/plugins/zenoh-plugin-example/Cargo.toml b/plugins/zenoh-plugin-example/Cargo.toml index 8e6814590f..ce12dbf18e 100644 --- a/plugins/zenoh-plugin-example/Cargo.toml +++ b/plugins/zenoh-plugin-example/Cargo.toml @@ -20,7 +20,7 @@ edition = { workspace = true } publish = false [features] -default = ["no_mangle", "zenoh/default"] +default = ["no_mangle", "zenoh/default", "zenoh/unstable", "zenoh/plugins"] no_mangle = [] [lib] diff --git a/plugins/zenoh-plugin-rest/Cargo.toml b/plugins/zenoh-plugin-rest/Cargo.toml index 05f010bdb8..fd66cdaedd 100644 --- a/plugins/zenoh-plugin-rest/Cargo.toml +++ b/plugins/zenoh-plugin-rest/Cargo.toml @@ -24,7 +24,7 @@ categories = ["network-programming", "web-programming::http-server"] description = "The zenoh REST plugin" [features] -default = ["no_mangle", "zenoh/default"] +default = ["no_mangle", "zenoh/default", "zenoh/unstable", "zenoh/plugins"] no_mangle = [] [lib] diff --git a/plugins/zenoh-plugin-storage-manager/Cargo.toml b/plugins/zenoh-plugin-storage-manager/Cargo.toml index 9486ab5367..e5a6b033f3 100644 --- a/plugins/zenoh-plugin-storage-manager/Cargo.toml +++ b/plugins/zenoh-plugin-storage-manager/Cargo.toml @@ -24,7 +24,7 @@ categories = { workspace = true } description = "The zenoh storages plugin." 
[features] -default = ["no_mangle", "zenoh/default"] +default = ["no_mangle", "zenoh/default", "zenoh/unstable", "zenoh/plugins"] no_mangle = [] [lib] diff --git a/zenoh/src/net/routing/hat/client/pubsub.rs b/zenoh/src/net/routing/hat/client/pubsub.rs index c27d8670ac..dd35cf24c8 100644 --- a/zenoh/src/net/routing/hat/client/pubsub.rs +++ b/zenoh/src/net/routing/hat/client/pubsub.rs @@ -17,11 +17,11 @@ use crate::net::routing::dispatcher::face::FaceState; use crate::net::routing::dispatcher::resource::{NodeId, Resource, SessionContext}; use crate::net::routing::dispatcher::tables::Tables; use crate::net::routing::dispatcher::tables::{Route, RoutingExpr}; -use crate::net::routing::hat::HatPubSubTrait; +use crate::net::routing::hat::{HatPubSubTrait, Sources}; use crate::net::routing::router::RoutesIndexes; use crate::net::routing::{RoutingContext, PREFIX_LIVELINESS}; use std::borrow::Cow; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::atomic::Ordering; use std::sync::Arc; use zenoh_protocol::core::key_expr::OwnedKeyExpr; @@ -276,11 +276,19 @@ impl HatPubSubTrait for HatCode { forget_client_subscription(tables, face, id) } - fn get_subscriptions(&self, tables: &Tables) -> Vec> { - let mut subs = HashSet::new(); + fn get_subscriptions(&self, tables: &Tables) -> Vec<(Arc, Sources)> { + // Compute the list of known suscriptions (keys) + let mut subs = HashMap::new(); for src_face in tables.faces.values() { for sub in face_hat!(src_face).remote_subs.values() { - subs.insert(sub.clone()); + // Insert the key in the list of known suscriptions + let srcs = subs.entry(sub.clone()).or_insert_with(Sources::empty); + // Append src_face as a suscription source in the proper list + match src_face.whatami { + WhatAmI::Router => srcs.routers.push(src_face.zid), + WhatAmI::Peer => srcs.peers.push(src_face.zid), + WhatAmI::Client => srcs.clients.push(src_face.zid), + } } } Vec::from_iter(subs) diff --git a/zenoh/src/net/routing/hat/client/queries.rs b/zenoh/src/net/routing/hat/client/queries.rs index d968a53df3..777198ed95 100644 --- a/zenoh/src/net/routing/hat/client/queries.rs +++ b/zenoh/src/net/routing/hat/client/queries.rs @@ -17,12 +17,12 @@ use crate::net::routing::dispatcher::face::FaceState; use crate::net::routing::dispatcher::resource::{NodeId, Resource, SessionContext}; use crate::net::routing::dispatcher::tables::Tables; use crate::net::routing::dispatcher::tables::{QueryTargetQabl, QueryTargetQablSet, RoutingExpr}; -use crate::net::routing::hat::HatQueriesTrait; +use crate::net::routing::hat::{HatQueriesTrait, Sources}; use crate::net::routing::router::RoutesIndexes; use crate::net::routing::{RoutingContext, PREFIX_LIVELINESS}; use ordered_float::OrderedFloat; use std::borrow::Cow; -use std::collections::HashSet; +use std::collections::HashMap; use std::sync::atomic::Ordering; use std::sync::Arc; use zenoh_buffers::ZBuf; @@ -275,11 +275,19 @@ impl HatQueriesTrait for HatCode { forget_client_queryable(tables, face, id) } - fn get_queryables(&self, tables: &Tables) -> Vec> { - let mut qabls = HashSet::new(); + fn get_queryables(&self, tables: &Tables) -> Vec<(Arc, Sources)> { + // Compute the list of known queryables (keys) + let mut qabls = HashMap::new(); for src_face in tables.faces.values() { for qabl in face_hat!(src_face).remote_qabls.values() { - qabls.insert(qabl.clone()); + // Insert the key in the list of known queryables + let srcs = qabls.entry(qabl.clone()).or_insert_with(Sources::empty); + // Append src_face as a queryable source in the 
proper list + match src_face.whatami { + WhatAmI::Router => srcs.routers.push(src_face.zid), + WhatAmI::Peer => srcs.peers.push(src_face.zid), + WhatAmI::Client => srcs.clients.push(src_face.zid), + } } } Vec::from_iter(qabls) diff --git a/zenoh/src/net/routing/hat/linkstate_peer/network.rs b/zenoh/src/net/routing/hat/linkstate_peer/network.rs index 33a00d0be4..9c8e0c8860 100644 --- a/zenoh/src/net/routing/hat/linkstate_peer/network.rs +++ b/zenoh/src/net/routing/hat/linkstate_peer/network.rs @@ -18,6 +18,7 @@ use crate::net::runtime::Runtime; use crate::net::runtime::WeakRuntime; use petgraph::graph::NodeIndex; use petgraph::visit::{VisitMap, Visitable}; +use rand::Rng; use std::convert::TryInto; use vec_map::VecMap; use zenoh_buffers::writer::{DidntWrite, HasWriter}; @@ -486,26 +487,25 @@ impl Network { ); } - if !self.autoconnect.is_empty() { + if !self.autoconnect.is_empty() && self.autoconnect.matches(whatami) { // Connect discovered peers - if zenoh_runtime::ZRuntime::Net - .block_in_place( - strong_runtime.manager().get_transport_unicast(&zid), - ) - .is_none() - && self.autoconnect.matches(whatami) - { - if let Some(locators) = locators { - let runtime = strong_runtime.clone(); - strong_runtime.spawn(async move { + if let Some(locators) = locators { + let runtime = strong_runtime.clone(); + strong_runtime.spawn(async move { + if runtime + .manager() + .get_transport_unicast(&zid) + .await + .is_none() + { // random backoff - tokio::time::sleep(std::time::Duration::from_millis( - rand::random::() % 100, - )) - .await; + let sleep_time = std::time::Duration::from_millis( + rand::thread_rng().gen_range(0..100), + ); + tokio::time::sleep(sleep_time).await; runtime.connect_peer(&zid, &locators).await; - }); - } + } + }); } } } @@ -610,22 +610,25 @@ impl Network { for (_, idx, _) in &link_states { let node = &self.graph[*idx]; if let Some(whatami) = node.whatami { - if zenoh_runtime::ZRuntime::Net - .block_in_place(strong_runtime.manager().get_transport_unicast(&node.zid)) - .is_none() - && self.autoconnect.matches(whatami) - { + if self.autoconnect.matches(whatami) { if let Some(locators) = &node.locators { let runtime = strong_runtime.clone(); let zid = node.zid; let locators = locators.clone(); strong_runtime.spawn(async move { - // random backoff - tokio::time::sleep(std::time::Duration::from_millis( - rand::random::() % 100, - )) - .await; - runtime.connect_peer(&zid, &locators).await; + if runtime + .manager() + .get_transport_unicast(&zid) + .await + .is_none() + { + // random backoff + let sleep_time = std::time::Duration::from_millis( + rand::thread_rng().gen_range(0..100), + ); + tokio::time::sleep(sleep_time).await; + runtime.connect_peer(&zid, &locators).await; + } }); } } diff --git a/zenoh/src/net/routing/hat/linkstate_peer/pubsub.rs b/zenoh/src/net/routing/hat/linkstate_peer/pubsub.rs index 3e9247e2b5..2c1cbb23e7 100644 --- a/zenoh/src/net/routing/hat/linkstate_peer/pubsub.rs +++ b/zenoh/src/net/routing/hat/linkstate_peer/pubsub.rs @@ -19,7 +19,7 @@ use crate::net::routing::dispatcher::pubsub::*; use crate::net::routing::dispatcher::resource::{NodeId, Resource, SessionContext}; use crate::net::routing::dispatcher::tables::Tables; use crate::net::routing::dispatcher::tables::{Route, RoutingExpr}; -use crate::net::routing::hat::HatPubSubTrait; +use crate::net::routing::hat::{HatPubSubTrait, Sources}; use crate::net::routing::router::RoutesIndexes; use crate::net::routing::{RoutingContext, PREFIX_LIVELINESS}; use petgraph::graph::NodeIndex; @@ -605,8 +605,31 @@ impl 
HatPubSubTrait for HatCode { } } - fn get_subscriptions(&self, tables: &Tables) -> Vec> { - hat!(tables).peer_subs.iter().cloned().collect() + fn get_subscriptions(&self, tables: &Tables) -> Vec<(Arc, Sources)> { + // Compute the list of known suscriptions (keys) + hat!(tables) + .peer_subs + .iter() + .map(|s| { + ( + s.clone(), + // Compute the list of routers, peers and clients that are known + // sources of those subscriptions + Sources { + routers: vec![], + peers: Vec::from_iter(res_hat!(s).peer_subs.iter().cloned()), + clients: s + .session_ctxs + .values() + .filter_map(|f| { + (f.face.whatami == WhatAmI::Client && f.subs.is_some()) + .then_some(f.face.zid) + }) + .collect(), + }, + ) + }) + .collect() } fn compute_data_route( diff --git a/zenoh/src/net/routing/hat/linkstate_peer/queries.rs b/zenoh/src/net/routing/hat/linkstate_peer/queries.rs index 44a153aa44..a227d845ba 100644 --- a/zenoh/src/net/routing/hat/linkstate_peer/queries.rs +++ b/zenoh/src/net/routing/hat/linkstate_peer/queries.rs @@ -19,7 +19,7 @@ use crate::net::routing::dispatcher::queries::*; use crate::net::routing::dispatcher::resource::{NodeId, Resource, SessionContext}; use crate::net::routing::dispatcher::tables::Tables; use crate::net::routing::dispatcher::tables::{QueryTargetQabl, QueryTargetQablSet, RoutingExpr}; -use crate::net::routing::hat::HatQueriesTrait; +use crate::net::routing::hat::{HatQueriesTrait, Sources}; use crate::net::routing::router::RoutesIndexes; use crate::net::routing::{RoutingContext, PREFIX_LIVELINESS}; use ordered_float::OrderedFloat; @@ -676,8 +676,31 @@ impl HatQueriesTrait for HatCode { } } - fn get_queryables(&self, tables: &Tables) -> Vec> { - hat!(tables).peer_qabls.iter().cloned().collect() + fn get_queryables(&self, tables: &Tables) -> Vec<(Arc, Sources)> { + // Compute the list of known queryables (keys) + hat!(tables) + .peer_qabls + .iter() + .map(|s| { + ( + s.clone(), + // Compute the list of routers, peers and clients that are known + // sources of those queryables + Sources { + routers: vec![], + peers: Vec::from_iter(res_hat!(s).peer_qabls.keys().cloned()), + clients: s + .session_ctxs + .values() + .filter_map(|f| { + (f.face.whatami == WhatAmI::Client && f.qabl.is_some()) + .then_some(f.face.zid) + }) + .collect(), + }, + ) + }) + .collect() } fn compute_query_route( diff --git a/zenoh/src/net/routing/hat/mod.rs b/zenoh/src/net/routing/hat/mod.rs index 3d1ae0f632..ee6557aac3 100644 --- a/zenoh/src/net/routing/hat/mod.rs +++ b/zenoh/src/net/routing/hat/mod.rs @@ -27,7 +27,7 @@ use super::{ use crate::net::runtime::Runtime; use std::{any::Any, sync::Arc}; use zenoh_buffers::ZBuf; -use zenoh_config::{unwrap_or_default, Config, WhatAmI}; +use zenoh_config::{unwrap_or_default, Config, WhatAmI, ZenohId}; use zenoh_protocol::{ core::WireExpr, network::{ @@ -50,6 +50,23 @@ zconfigurable! 
{ pub static ref TREES_COMPUTATION_DELAY_MS: u64 = 100; } +#[derive(serde::Serialize)] +pub(crate) struct Sources { + routers: Vec, + peers: Vec, + clients: Vec, +} + +impl Sources { + pub(crate) fn empty() -> Self { + Self { + routers: vec![], + peers: vec![], + clients: vec![], + } + } +} + pub(crate) trait HatTrait: HatBaseTrait + HatPubSubTrait + HatQueriesTrait {} pub(crate) trait HatBaseTrait { @@ -134,7 +151,7 @@ pub(crate) trait HatPubSubTrait { node_id: NodeId, ) -> Option>; - fn get_subscriptions(&self, tables: &Tables) -> Vec>; + fn get_subscriptions(&self, tables: &Tables) -> Vec<(Arc, Sources)>; fn compute_data_route( &self, @@ -166,7 +183,7 @@ pub(crate) trait HatQueriesTrait { node_id: NodeId, ) -> Option>; - fn get_queryables(&self, tables: &Tables) -> Vec>; + fn get_queryables(&self, tables: &Tables) -> Vec<(Arc, Sources)>; fn compute_query_route( &self, diff --git a/zenoh/src/net/routing/hat/p2p_peer/gossip.rs b/zenoh/src/net/routing/hat/p2p_peer/gossip.rs index de33f0ac54..df04b396ab 100644 --- a/zenoh/src/net/routing/hat/p2p_peer/gossip.rs +++ b/zenoh/src/net/routing/hat/p2p_peer/gossip.rs @@ -16,6 +16,7 @@ use crate::net::protocol::linkstate::{LinkState, LinkStateList}; use crate::net::runtime::Runtime; use crate::net::runtime::WeakRuntime; use petgraph::graph::NodeIndex; +use rand::Rng; use std::convert::TryInto; use vec_map::VecMap; use zenoh_buffers::writer::{DidntWrite, HasWriter}; @@ -406,24 +407,25 @@ impl Network { ); } - if !self.autoconnect.is_empty() { + if !self.autoconnect.is_empty() && self.autoconnect.matches(whatami) { // Connect discovered peers - if zenoh_runtime::ZRuntime::Acceptor - .block_in_place(strong_runtime.manager().get_transport_unicast(&zid)) - .is_none() - && self.autoconnect.matches(whatami) - { - if let Some(locators) = locators { - let runtime = strong_runtime.clone(); - strong_runtime.spawn(async move { + if let Some(locators) = locators { + let runtime = strong_runtime.clone(); + strong_runtime.spawn(async move { + if runtime + .manager() + .get_transport_unicast(&zid) + .await + .is_none() + { // random backoff - tokio::time::sleep(std::time::Duration::from_millis( - rand::random::() % 100, - )) - .await; + let sleep_time = std::time::Duration::from_millis( + rand::thread_rng().gen_range(0..100), + ); + tokio::time::sleep(sleep_time).await; runtime.connect_peer(&zid, &locators).await; - }); - } + } + }); } } } diff --git a/zenoh/src/net/routing/hat/p2p_peer/pubsub.rs b/zenoh/src/net/routing/hat/p2p_peer/pubsub.rs index d292b77f9f..d57c2ac665 100644 --- a/zenoh/src/net/routing/hat/p2p_peer/pubsub.rs +++ b/zenoh/src/net/routing/hat/p2p_peer/pubsub.rs @@ -17,11 +17,11 @@ use crate::net::routing::dispatcher::face::FaceState; use crate::net::routing::dispatcher::resource::{NodeId, Resource, SessionContext}; use crate::net::routing::dispatcher::tables::Tables; use crate::net::routing::dispatcher::tables::{Route, RoutingExpr}; -use crate::net::routing::hat::HatPubSubTrait; +use crate::net::routing::hat::{HatPubSubTrait, Sources}; use crate::net::routing::router::RoutesIndexes; use crate::net::routing::{RoutingContext, PREFIX_LIVELINESS}; use std::borrow::Cow; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::atomic::Ordering; use std::sync::Arc; use zenoh_protocol::core::key_expr::OwnedKeyExpr; @@ -276,11 +276,19 @@ impl HatPubSubTrait for HatCode { forget_client_subscription(tables, face, id) } - fn get_subscriptions(&self, tables: &Tables) -> Vec> { - let mut subs = HashSet::new(); + fn 
get_subscriptions(&self, tables: &Tables) -> Vec<(Arc, Sources)> { + // Compute the list of known suscriptions (keys) + let mut subs = HashMap::new(); for src_face in tables.faces.values() { for sub in face_hat!(src_face).remote_subs.values() { - subs.insert(sub.clone()); + // Insert the key in the list of known suscriptions + let srcs = subs.entry(sub.clone()).or_insert_with(Sources::empty); + // Append src_face as a suscription source in the proper list + match src_face.whatami { + WhatAmI::Router => srcs.routers.push(src_face.zid), + WhatAmI::Peer => srcs.peers.push(src_face.zid), + WhatAmI::Client => srcs.clients.push(src_face.zid), + } } } Vec::from_iter(subs) diff --git a/zenoh/src/net/routing/hat/p2p_peer/queries.rs b/zenoh/src/net/routing/hat/p2p_peer/queries.rs index d146886d01..25fed11842 100644 --- a/zenoh/src/net/routing/hat/p2p_peer/queries.rs +++ b/zenoh/src/net/routing/hat/p2p_peer/queries.rs @@ -17,12 +17,12 @@ use crate::net::routing::dispatcher::face::FaceState; use crate::net::routing::dispatcher::resource::{NodeId, Resource, SessionContext}; use crate::net::routing::dispatcher::tables::Tables; use crate::net::routing::dispatcher::tables::{QueryTargetQabl, QueryTargetQablSet, RoutingExpr}; -use crate::net::routing::hat::HatQueriesTrait; +use crate::net::routing::hat::{HatQueriesTrait, Sources}; use crate::net::routing::router::RoutesIndexes; use crate::net::routing::{RoutingContext, PREFIX_LIVELINESS}; use ordered_float::OrderedFloat; use std::borrow::Cow; -use std::collections::HashSet; +use std::collections::HashMap; use std::sync::atomic::Ordering; use std::sync::Arc; use zenoh_buffers::ZBuf; @@ -275,11 +275,19 @@ impl HatQueriesTrait for HatCode { forget_client_queryable(tables, face, id) } - fn get_queryables(&self, tables: &Tables) -> Vec> { - let mut qabls = HashSet::new(); + fn get_queryables(&self, tables: &Tables) -> Vec<(Arc, Sources)> { + // Compute the list of known queryables (keys) + let mut qabls = HashMap::new(); for src_face in tables.faces.values() { for qabl in face_hat!(src_face).remote_qabls.values() { - qabls.insert(qabl.clone()); + // Insert the key in the list of known queryables + let srcs = qabls.entry(qabl.clone()).or_insert_with(Sources::empty); + // Append src_face as a queryable source in the proper list + match src_face.whatami { + WhatAmI::Router => srcs.routers.push(src_face.zid), + WhatAmI::Peer => srcs.peers.push(src_face.zid), + WhatAmI::Client => srcs.clients.push(src_face.zid), + } } } Vec::from_iter(qabls) diff --git a/zenoh/src/net/routing/hat/router/network.rs b/zenoh/src/net/routing/hat/router/network.rs index 09f6e9df17..3ff59b5ede 100644 --- a/zenoh/src/net/routing/hat/router/network.rs +++ b/zenoh/src/net/routing/hat/router/network.rs @@ -17,6 +17,7 @@ use crate::net::routing::dispatcher::tables::NodeId; use crate::net::runtime::Runtime; use petgraph::graph::NodeIndex; use petgraph::visit::{IntoNodeReferences, VisitMap, Visitable}; +use rand::Rng; use std::convert::TryInto; use vec_map::VecMap; use zenoh_buffers::writer::{DidntWrite, HasWriter}; @@ -489,24 +490,25 @@ impl Network { ); } - if !self.autoconnect.is_empty() { + if !self.autoconnect.is_empty() && self.autoconnect.matches(whatami) { // Connect discovered peers - if zenoh_runtime::ZRuntime::Net - .block_in_place(self.runtime.manager().get_transport_unicast(&zid)) - .is_none() - && self.autoconnect.matches(whatami) - { - if let Some(locators) = locators { - let runtime = self.runtime.clone(); - self.runtime.spawn(async move { + if let Some(locators) = locators { + 
let runtime = self.runtime.clone(); + self.runtime.spawn(async move { + if runtime + .manager() + .get_transport_unicast(&zid) + .await + .is_none() + { // random backoff - tokio::time::sleep(std::time::Duration::from_millis( - rand::random::() % 100, - )) - .await; + let sleep_time = std::time::Duration::from_millis( + rand::thread_rng().gen_range(0..100), + ); + tokio::time::sleep(sleep_time).await; runtime.connect_peer(&zid, &locators).await; - }); - } + } + }); } } } @@ -611,22 +613,25 @@ impl Network { for (_, idx, _) in &link_states { let node = &self.graph[*idx]; if let Some(whatami) = node.whatami { - if zenoh_runtime::ZRuntime::Net - .block_in_place(self.runtime.manager().get_transport_unicast(&node.zid)) - .is_none() - && self.autoconnect.matches(whatami) - { + if self.autoconnect.matches(whatami) { if let Some(locators) = &node.locators { let runtime = self.runtime.clone(); let zid = node.zid; let locators = locators.clone(); self.runtime.spawn(async move { - // random backoff - tokio::time::sleep(std::time::Duration::from_millis( - rand::random::() % 100, - )) - .await; - runtime.connect_peer(&zid, &locators).await; + if runtime + .manager() + .get_transport_unicast(&zid) + .await + .is_none() + { + // random backoff + let sleep_time = std::time::Duration::from_millis( + rand::thread_rng().gen_range(0..100), + ); + tokio::time::sleep(sleep_time).await; + runtime.connect_peer(&zid, &locators).await; + } }); } } diff --git a/zenoh/src/net/routing/hat/router/pubsub.rs b/zenoh/src/net/routing/hat/router/pubsub.rs index 931911bfe2..99b7eb3c12 100644 --- a/zenoh/src/net/routing/hat/router/pubsub.rs +++ b/zenoh/src/net/routing/hat/router/pubsub.rs @@ -19,7 +19,7 @@ use crate::net::routing::dispatcher::pubsub::*; use crate::net::routing::dispatcher::resource::{NodeId, Resource, SessionContext}; use crate::net::routing::dispatcher::tables::Tables; use crate::net::routing::dispatcher::tables::{Route, RoutingExpr}; -use crate::net::routing::hat::HatPubSubTrait; +use crate::net::routing::hat::{HatPubSubTrait, Sources}; use crate::net::routing::router::RoutesIndexes; use crate::net::routing::{RoutingContext, PREFIX_LIVELINESS}; use petgraph::graph::NodeIndex; @@ -924,8 +924,41 @@ impl HatPubSubTrait for HatCode { } } - fn get_subscriptions(&self, tables: &Tables) -> Vec> { - hat!(tables).router_subs.iter().cloned().collect() + fn get_subscriptions(&self, tables: &Tables) -> Vec<(Arc, Sources)> { + // Compute the list of known suscriptions (keys) + hat!(tables) + .router_subs + .iter() + .map(|s| { + ( + s.clone(), + // Compute the list of routers, peers and clients that are known + // sources of those subscriptions + Sources { + routers: Vec::from_iter(res_hat!(s).router_subs.iter().cloned()), + peers: if hat!(tables).full_net(WhatAmI::Peer) { + Vec::from_iter(res_hat!(s).peer_subs.iter().cloned()) + } else { + s.session_ctxs + .values() + .filter_map(|f| { + (f.face.whatami == WhatAmI::Peer && f.subs.is_some()) + .then_some(f.face.zid) + }) + .collect() + }, + clients: s + .session_ctxs + .values() + .filter_map(|f| { + (f.face.whatami == WhatAmI::Client && f.subs.is_some()) + .then_some(f.face.zid) + }) + .collect(), + }, + ) + }) + .collect() } fn compute_data_route( diff --git a/zenoh/src/net/routing/hat/router/queries.rs b/zenoh/src/net/routing/hat/router/queries.rs index 6ff4509596..dbd7da8629 100644 --- a/zenoh/src/net/routing/hat/router/queries.rs +++ b/zenoh/src/net/routing/hat/router/queries.rs @@ -19,7 +19,7 @@ use crate::net::routing::dispatcher::queries::*; use 
crate::net::routing::dispatcher::resource::{NodeId, Resource, SessionContext}; use crate::net::routing::dispatcher::tables::Tables; use crate::net::routing::dispatcher::tables::{QueryTargetQabl, QueryTargetQablSet, RoutingExpr}; -use crate::net::routing::hat::HatQueriesTrait; +use crate::net::routing::hat::{HatQueriesTrait, Sources}; use crate::net::routing::router::RoutesIndexes; use crate::net::routing::{RoutingContext, PREFIX_LIVELINESS}; use ordered_float::OrderedFloat; @@ -1080,8 +1080,41 @@ impl HatQueriesTrait for HatCode { } } - fn get_queryables(&self, tables: &Tables) -> Vec> { - hat!(tables).router_qabls.iter().cloned().collect() + fn get_queryables(&self, tables: &Tables) -> Vec<(Arc, Sources)> { + // Compute the list of known queryables (keys) + hat!(tables) + .router_qabls + .iter() + .map(|s| { + ( + s.clone(), + // Compute the list of routers, peers and clients that are known + // sources of those queryables + Sources { + routers: Vec::from_iter(res_hat!(s).router_qabls.keys().cloned()), + peers: if hat!(tables).full_net(WhatAmI::Peer) { + Vec::from_iter(res_hat!(s).peer_qabls.keys().cloned()) + } else { + s.session_ctxs + .values() + .filter_map(|f| { + (f.face.whatami == WhatAmI::Peer && f.qabl.is_some()) + .then_some(f.face.zid) + }) + .collect() + }, + clients: s + .session_ctxs + .values() + .filter_map(|f| { + (f.face.whatami == WhatAmI::Client && f.qabl.is_some()) + .then_some(f.face.zid) + }) + .collect(), + }, + ) + }) + .collect() } fn compute_query_route( diff --git a/zenoh/src/net/runtime/adminspace.rs b/zenoh/src/net/runtime/adminspace.rs index b35d81a81a..ea084c453b 100644 --- a/zenoh/src/net/runtime/adminspace.rs +++ b/zenoh/src/net/runtime/adminspace.rs @@ -710,11 +710,17 @@ fn subscribers_data(context: &AdminContext, query: Query) { "@/{}/{}/subscriber/{}", context.runtime.state.whatami, context.runtime.state.zid, - sub.expr() + sub.0.expr() )) .unwrap(); if query.key_expr().intersects(&key) { - if let Err(e) = query.reply(key, ZBytes::empty()).res() { + let payload = + ZBytes::from(serde_json::to_string(&sub.1).unwrap_or_else(|_| "{}".to_string())); + if let Err(e) = query + .reply(key, payload) + .encoding(Encoding::APPLICATION_JSON) + .res_sync() + { tracing::error!("Error sending AdminSpace reply: {:?}", e); } } @@ -728,11 +734,17 @@ fn queryables_data(context: &AdminContext, query: Query) { "@/{}/{}/queryable/{}", context.runtime.state.whatami, context.runtime.state.zid, - qabl.expr() + qabl.0.expr() )) .unwrap(); if query.key_expr().intersects(&key) { - if let Err(e) = query.reply(key, ZBytes::empty()).res() { + let payload = + ZBytes::from(serde_json::to_string(&qabl.1).unwrap_or_else(|_| "{}".to_string())); + if let Err(e) = query + .reply(key, payload) + .encoding(Encoding::APPLICATION_JSON) + .res_sync() + { tracing::error!("Error sending AdminSpace reply: {:?}", e); } }
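Editor's note: the admin-space change at the end replaces the previously empty replies for `subscriber` and `queryable` admin keys with a JSON payload describing the known sources. Below is a small, self-contained sketch of that payload, assuming serde with the `derive` feature and serde_json. The `SourcesSummary` struct and its `String` identifiers are stand-ins (the real `Sources` holds `ZenohId`s and keeps its fields private to the routing module); only the serde derivation and the `serde_json::to_string` fallback mirror the patch:

```rust
use serde::Serialize;

// Hypothetical stand-in for the routing `Sources` summary: plain strings are
// used here only to keep the sketch self-contained.
#[derive(Serialize)]
struct SourcesSummary {
    routers: Vec<String>,
    peers: Vec<String>,
    clients: Vec<String>,
}

fn main() {
    let sources = SourcesSummary {
        routers: vec![],
        peers: vec!["a0b1c2d3e4f5".to_string()],
        clients: vec!["0123456789ab".to_string()],
    };

    // Mirrors the admin-space reply path: serialize the per-key source summary
    // to JSON, falling back to an empty object if serialization ever fails.
    let payload = serde_json::to_string(&sources).unwrap_or_else(|_| "{}".to_string());
    println!("{payload}");
    // => {"routers":[],"peers":["a0b1c2d3e4f5"],"clients":["0123456789ab"]}
}
```

The resulting string is what the patch wraps in `ZBytes` and returns with `Encoding::APPLICATION_JSON`, so querying `@/<whatami>/<zid>/subscriber/**` or `@/<whatami>/<zid>/queryable/**` now shows which routers, peers and clients back each declared key instead of an empty payload.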