Skip to content

Commit

Permalink
Port link-quic to tokio
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanYuYuan committed Jan 10, 2024
1 parent fb208f6 commit 69d20e3
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 102 deletions.
14 changes: 9 additions & 5 deletions io/zenoh-links/zenoh-link-quic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,26 @@ description = "Internal crate for zenoh."
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-rustls = { workspace = true }
async-std = { workspace = true, features = ["unstable", "tokio1"] }
async-trait = { workspace = true }
base64 = { workspace = true }
futures = { workspace = true }
log = { workspace = true }
quinn = { workspace = true }
rustls = { workspace = true }
rustls-native-certs = { workspace = true }
rustls-pemfile = { workspace = true }
rustls-webpki = { workspace = true }
secrecy = {workspace = true }
tokio = { workspace = true, features = ["io-util", "net", "fs", "sync", "time"] }
tokio-util = { workspace = true, features = ["rt"] }
zenoh-config = { workspace = true }
zenoh-core = { workspace = true }
zenoh-link-commons = { workspace = true }
zenoh-protocol = { workspace = true }
zenoh-result = { workspace = true }
zenoh-sync = { workspace = true }
zenoh-util = { workspace = true }
base64 = { workspace = true }
secrecy = {workspace = true }
zenoh-runtime = { workspace = true }

# Lock due to quinn not supporting rustls 0.22 yet
rustls = { version = "0.21", features = ["dangerous_configuration", "quic"] }
tokio-rustls = "0.24.1"
3 changes: 1 addition & 2 deletions io/zenoh-links/zenoh-link-quic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
//! This crate is intended for Zenoh's internal use.
//!
//! [Click here for Zenoh's documentation](../zenoh/index.html)
use async_std::net::ToSocketAddrs;
use async_trait::async_trait;
use config::{
TLS_ROOT_CA_CERTIFICATE_BASE64, TLS_ROOT_CA_CERTIFICATE_FILE, TLS_SERVER_CERTIFICATE_BASE64,
Expand Down Expand Up @@ -167,7 +166,7 @@ pub mod config {
}

async fn get_quic_addr(address: &Address<'_>) -> ZResult<SocketAddr> {
match address.as_str().to_socket_addrs().await?.next() {
match tokio::net::lookup_host(address.as_str()).await?.next() {
Some(addr) => Ok(addr),
None => bail!("Couldn't resolve QUIC locator address: {}", address),
}
Expand Down
168 changes: 78 additions & 90 deletions io/zenoh-links/zenoh-link-quic/src/unicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,24 @@ use crate::{
config::*, get_quic_addr, verify::WebPkiVerifierAnyServerName, ALPN_QUIC_HTTP,
QUIC_ACCEPT_THROTTLE_TIME, QUIC_DEFAULT_MTU, QUIC_LOCATOR_PREFIX,
};
use async_std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use async_std::prelude::FutureExt;
use async_std::sync::Mutex as AsyncMutex;
use async_std::task;
use async_std::task::JoinHandle;
use async_trait::async_trait;
use rustls::{Certificate, PrivateKey};
use rustls_pemfile::Item;
use std::collections::HashMap;
use std::fmt;
use std::io::BufReader;
use std::net::IpAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::Mutex as AsyncMutex;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use zenoh_core::{zasynclock, zread, zwrite};
use zenoh_link_commons::{
LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait, NewLinkChannelSender,
};
use zenoh_protocol::core::{EndPoint, Locator};
use zenoh_result::{bail, zerror, ZError, ZResult};
use zenoh_sync::Signal;

pub struct LinkUnicastQuic {
connection: quinn::Connection,
Expand Down Expand Up @@ -186,25 +182,28 @@ impl fmt::Debug for LinkUnicastQuic {
/*************************************/
struct ListenerUnicastQuic {
endpoint: EndPoint,
active: Arc<AtomicBool>,
signal: Signal,
handle: JoinHandle<ZResult<()>>,
token: CancellationToken,
tracker: TaskTracker,
}

impl ListenerUnicastQuic {
fn new(
endpoint: EndPoint,
active: Arc<AtomicBool>,
signal: Signal,
handle: JoinHandle<ZResult<()>>,
token: CancellationToken,
tracker: TaskTracker,
) -> ListenerUnicastQuic {
ListenerUnicastQuic {
endpoint,
active,
signal,
handle,
token,
tracker,
}
}

async fn stop(&self) {
self.token.cancel();
self.tracker.close();
self.tracker.wait().await;
}
}

pub struct LinkManagerUnicastQuic {
Expand Down Expand Up @@ -252,7 +251,7 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic {
} else if let Some(b64_certificate) = epconf.get(TLS_ROOT_CA_CERTIFICATE_BASE64) {
base64_decode(b64_certificate)?
} else if let Some(value) = epconf.get(TLS_ROOT_CA_CERTIFICATE_FILE) {
async_std::fs::read(value)
tokio::fs::read(value)
.await
.map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))?
} else {
Expand Down Expand Up @@ -344,7 +343,7 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic {
} else if let Some(b64_certificate) = epconf.get(TLS_SERVER_CERTIFICATE_BASE64) {
base64_decode(b64_certificate)?
} else if let Some(value) = epconf.get(TLS_SERVER_CERTIFICATE_FILE) {
async_std::fs::read(value)
tokio::fs::read(value)
.await
.map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))?
} else {
Expand All @@ -364,17 +363,15 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic {
} else if let Some(b64_key) = epconf.get(TLS_SERVER_PRIVATE_KEY_BASE64) {
base64_decode(b64_key)?
} else if let Some(value) = epconf.get(TLS_SERVER_PRIVATE_KEY_FILE) {
async_std::fs::read(value)
tokio::fs::read(value)
.await
.map_err(|e| zerror!("Invalid QUIC CA certificate file: {}", e))?
} else {
bail!("No QUIC CA private key has been provided.");
};
let items: Vec<Item> = rustls_pemfile::read_all(&mut BufReader::new(f.as_slice()))
.map(|result| {
result.map_err(|err| zerror!("Invalid QUIC CA private key file: {}", err))
})
.collect::<Result<Vec<Item>, ZError>>()?;
.collect::<Result<_, _>>()
.map_err(|err| zerror!("Invalid QUIC CA private key file: {}", err))?;

let private_key = items
.into_iter()
Expand Down Expand Up @@ -423,25 +420,26 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic {
)?;

// Spawn the accept loop for the listener
let active = Arc::new(AtomicBool::new(true));
let signal = Signal::new();
let token = CancellationToken::new();
let c_token = token.clone();
let mut listeners = zwrite!(self.listeners);

let c_active = active.clone();
let c_signal = signal.clone();
let c_manager = self.manager.clone();
let c_listeners = self.listeners.clone();
let c_addr = local_addr;
let handle = task::spawn(async move {

let tracker = TaskTracker::new();
let task = async move {
// Wait for the accept loop to terminate
let res = accept_task(quic_endpoint, c_active, c_signal, c_manager).await;
let res = accept_task(quic_endpoint, c_token, c_manager).await;
zwrite!(c_listeners).remove(&c_addr);
res
});
};
tracker.spawn_on(task, &zenoh_runtime::ZRuntime::TX);

// Initialize the QuicAcceptor
let locator = endpoint.to_locator();
let listener = ListenerUnicastQuic::new(endpoint, active, signal, handle);
let listener = ListenerUnicastQuic::new(endpoint, token, tracker);
// Update the list of active listeners on the manager
listeners.insert(local_addr, listener);

Expand All @@ -464,9 +462,8 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic {
})?;

// Send the stop signal
listener.active.store(false, Ordering::Release);
listener.signal.trigger();
listener.handle.await
listener.stop().await;
Ok(())
}

fn get_listeners(&self) -> Vec<EndPoint> {
Expand Down Expand Up @@ -509,16 +506,10 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastQuic {

async fn accept_task(
endpoint: quinn::Endpoint,
active: Arc<AtomicBool>,
signal: Signal,
token: CancellationToken,
manager: NewLinkChannelSender,
) -> ZResult<()> {
enum Action {
Accept(quinn::Connection),
Stop,
}

async fn accept(acceptor: quinn::Accept<'_>) -> ZResult<Action> {
async fn accept(acceptor: quinn::Accept<'_>) -> ZResult<quinn::Connection> {
let qc = acceptor
.await
.ok_or_else(|| zerror!("Can not accept QUIC connections: acceptor closed"))?;
Expand All @@ -529,12 +520,7 @@ async fn accept_task(
e
})?;

Ok(Action::Accept(conn))
}

async fn stop(signal: Signal) -> ZResult<Action> {
signal.wait().await;
Ok(Action::Stop)
Ok(conn)
}

let src_addr = endpoint
Expand All @@ -543,51 +529,53 @@ async fn accept_task(

// The accept future
log::trace!("Ready to accept QUIC connections on: {:?}", src_addr);
while active.load(Ordering::Acquire) {
// Wait for incoming connections
let quic_conn = match accept(endpoint.accept()).race(stop(signal.clone())).await {
Ok(action) => match action {
Action::Accept(qc) => qc,
Action::Stop => break,
},
Err(e) => {
log::warn!("{} Hint: increase the system open file limit.", e);
// Throttle the accept loop upon an error
// NOTE: This might be due to various factors. However, the most common case is that
// the process has reached the maximum number of open files in the system. On
// Linux systems this limit can be changed by using the "ulimit" command line
// tool. In case of systemd-based systems, this can be changed by using the
// "sysctl" command line tool.
task::sleep(Duration::from_micros(*QUIC_ACCEPT_THROTTLE_TIME)).await;
continue;
}
};

// Get the bideractional streams. Note that we don't allow unidirectional streams.
let (send, recv) = match quic_conn.accept_bi().await {
Ok(stream) => stream,
Err(e) => {
log::warn!("QUIC connection has no streams: {:?}", e);
continue;
loop {
tokio::select! {
_ = token.cancelled() => break,

res = accept(endpoint.accept()) => {
match res {
Ok(quic_conn) => {
// Get the bideractional streams. Note that we don't allow unidirectional streams.
let (send, recv) = match quic_conn.accept_bi().await {
Ok(stream) => stream,
Err(e) => {
log::warn!("QUIC connection has no streams: {:?}", e);
continue;
}
};

let dst_addr = quic_conn.remote_address();
log::debug!("Accepted QUIC connection on {:?}: {:?}", src_addr, dst_addr);
// Create the new link object
let link = Arc::new(LinkUnicastQuic::new(
quic_conn,
src_addr,
Locator::new(QUIC_LOCATOR_PREFIX, dst_addr.to_string(), "")?,
send,
recv,
));

// Communicate the new link to the initial transport manager
if let Err(e) = manager.send_async(LinkUnicast(link)).await {
log::error!("{}-{}: {}", file!(), line!(), e)
}

}
Err(e) => {
log::warn!("{} Hint: increase the system open file limit.", e);
// Throttle the accept loop upon an error
// NOTE: This might be due to various factors. However, the most common case is that
// the process has reached the maximum number of open files in the system. On
// Linux systems this limit can be changed by using the "ulimit" command line
// tool. In case of systemd-based systems, this can be changed by using the
// "sysctl" command line tool.
tokio::time::sleep(Duration::from_micros(*QUIC_ACCEPT_THROTTLE_TIME)).await;
}
}
}
};

let dst_addr = quic_conn.remote_address();
log::debug!("Accepted QUIC connection on {:?}: {:?}", src_addr, dst_addr);
// Create the new link object
let link = Arc::new(LinkUnicastQuic::new(
quic_conn,
src_addr,
Locator::new(QUIC_LOCATOR_PREFIX, dst_addr.to_string(), "")?,
send,
recv,
));

// Communicate the new link to the initial transport manager
if let Err(e) = manager.send_async(LinkUnicast(link)).await {
log::error!("{}-{}: {}", file!(), line!(), e)
}
}

Ok(())
}
10 changes: 5 additions & 5 deletions io/zenoh-links/zenoh-link-quic/src/verify.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use async_rustls::rustls::{
client::{ServerCertVerified, ServerCertVerifier},
Certificate, RootCertStore, ServerName,
};
use rustls::client::verify_server_cert_signed_by_trust_anchor;
use rustls::server::ParsedCertificate;
use std::time::SystemTime;
use tokio_rustls::rustls::{
client::{ServerCertVerified, ServerCertVerifier},
Certificate, RootCertStore, ServerName,
};

impl ServerCertVerifier for WebPkiVerifierAnyServerName {
/// Will verify the certificate is valid in the following ways:
Expand All @@ -18,7 +18,7 @@ impl ServerCertVerifier for WebPkiVerifierAnyServerName {
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
now: SystemTime,
) -> Result<ServerCertVerified, async_rustls::rustls::Error> {
) -> Result<ServerCertVerified, tokio_rustls::rustls::Error> {
let cert = ParsedCertificate::try_from(end_entity)?;
verify_server_cert_signed_by_trust_anchor(&cert, &self.roots, intermediates, now)?;
Ok(ServerCertVerified::assertion())
Expand Down

0 comments on commit 69d20e3

Please sign in to comment.