Skip to content

Commit

Permalink
Merge pull request #27 from mstyura/quinn-0.11
Browse files Browse the repository at this point in the history
Upgrade crate to support quinn 0.11.
  • Loading branch information
nmittler authored Sep 18, 2024
2 parents 0cd56ae + 354a57a commit 8aeaa43
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 115 deletions.
37 changes: 19 additions & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,38 @@ rust-version = "1.70"
maintenance = { status = "passively-maintained" }

[features]
default = ["runtime-tokio"]
fips = ["boring/fips", "boring-sys/fips"]
runtime-tokio = ["quinn/runtime-tokio"]

[dependencies]
boring = "3.0.2"
boring-sys = "3.0.2"
boring = "4"
boring-sys = "4"
bytes = "1"
foreign-types-shared = "0.3.1"
lru = "0.11.0"
once_cell = "1.17"
quinn = { version = "0.10.1", default_features = false, features = ["native-certs", "runtime-tokio"] }
quinn-proto = { version = "0.10.1", default-features = false }
foreign-types-shared = "0.3"
lru = "0.12"
once_cell = "1"
quinn = { version = "0.11", default-features = false }
quinn-proto = { version = "0.11", default-features = false }
rand = "0.8"
tracing = "0.1"

[dev-dependencies]
anyhow = "1.0.22"
assert_hex = "0.2.2"
assert_matches = "1.1"
clap = { version = "4.3", features = ["derive"] }
anyhow = "1.0"
assert_matches = "1"
clap = { version = "4", features = ["derive"] }
directories-next = "2"
hex-literal = "0.4.1"
ring = "0.16.7"
rcgen = "0.11.1"
rustls-pemfile = "1.0.0"
tokio = { version = "1.0.1", features = ["rt", "rt-multi-thread", "time", "macros", "sync"] }
tracing-futures = { version = "0.2.0", default-features = false, features = ["std-future"] }
tracing-subscriber = { version = "0.3.0", default-features = false, features = ["env-filter", "fmt", "ansi", "time", "local-time"] }
hex-literal = "0.4"
rcgen = "0.13"
rustls-pemfile = "2"
tokio = { version = "1", features = ["rt", "rt-multi-thread", "time", "macros", "sync"] }
tracing-subscriber = { version = "0.3", default-features = false, features = ["env-filter", "fmt", "ansi", "time", "local-time"] }
url = "2"

[[example]]
name = "server"
required-features = ["runtime-tokio"]

[[example]]
name = "client"
required-features = ["runtime-tokio"]
50 changes: 31 additions & 19 deletions examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use std::{
fs,
io::{self, Write},
net::ToSocketAddrs,
net::{SocketAddr, ToSocketAddrs},
path::PathBuf,
sync::Arc,
time::{Duration, Instant},
Expand All @@ -22,20 +22,27 @@ use url::Url;
#[derive(Parser, Debug)]
#[clap(name = "client")]
struct Opt {
#[arg(default_value = "https://localhost:4433/Cargo.toml")]
/// Perform NSS-compatible TLS key logging to the file specified in `SSLKEYLOGFILE`.
#[clap(long = "keylog")]
keylog: bool,

url: Url,

/// Override hostname used for certificate verification
#[arg(long = "host")]
#[clap(long = "host")]
host: Option<String>,

/// Custom certificate authority to trust, in DER format
#[arg(long = "ca")]
#[clap(long = "ca")]
ca: Option<PathBuf>,

/// Simulate NAT rebinding after connecting
#[arg(long = "rebind")]
#[clap(long = "rebind")]
rebind: bool,

/// Address to bind on
#[clap(long = "bind", default_value = "[::]:0")]
bind: SocketAddr,
}

fn main() {
Expand All @@ -48,19 +55,20 @@ fn main() {
let opt = Opt::parse();
let code = {
if let Err(e) = run(opt) {
eprintln!("ERROR: {}", e);
eprintln!("ERROR: {e}");
1
} else {
0
}
};
std::process::exit(code);
::std::process::exit(code);
}

#[tokio::main]
async fn run(options: Opt) -> Result<()> {
let url = options.url;
let remote = (url.host_str().unwrap(), url.port().unwrap_or(4433))
let url_host = strip_ipv6_brackets(url.host_str().unwrap());
let remote = (url_host, url.port().unwrap_or(4433))
.to_socket_addrs()?
.next()
.ok_or_else(|| anyhow!("couldn't resolve to an address"))?;
Expand Down Expand Up @@ -91,19 +99,15 @@ async fn run(options: Opt) -> Result<()> {
}
}

let mut endpoint = quinn_boring::helpers::client_endpoint("[::]:0".parse().unwrap())?;
let mut endpoint = quinn_boring::helpers::client_endpoint(options.bind)?;
endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(client_crypto)));

let request = format!("GET {}\r\n", url.path());
let start = Instant::now();
let rebind = options.rebind;
let host = options
.host
.as_ref()
.map_or_else(|| url.host_str(), |x| Some(x))
.ok_or_else(|| anyhow!("no hostname specified"))?;
let host = options.host.as_deref().unwrap_or(url_host);

eprintln!("connecting to {} at {}", host, remote);
eprintln!("connecting to {host} at {remote}");
let conn = endpoint
.connect(remote, host)?
.await
Expand All @@ -116,16 +120,14 @@ async fn run(options: Opt) -> Result<()> {
if rebind {
let socket = std::net::UdpSocket::bind("[::]:0").unwrap();
let addr = socket.local_addr().unwrap();
eprintln!("rebinding to {}", addr);
eprintln!("rebinding to {addr}");
endpoint.rebind(socket).expect("rebind failed");
}

send.write_all(request.as_bytes())
.await
.map_err(|e| anyhow!("failed to send request: {}", e))?;
send.finish()
.await
.map_err(|e| anyhow!("failed to shutdown stream: {}", e))?;
send.finish().unwrap();
let response_start = Instant::now();
eprintln!("request sent at {:?}", response_start - start);
let resp = recv
Expand All @@ -148,6 +150,16 @@ async fn run(options: Opt) -> Result<()> {
Ok(())
}

fn strip_ipv6_brackets(host: &str) -> &str {
// An ipv6 url looks like eg https://[::1]:4433/Cargo.toml, wherein the host [::1] is the
// ipv6 address ::1 wrapped in brackets, per RFC 2732. This strips those.
if host.starts_with('[') && host.ends_with(']') {
&host[1..host.len() - 1]
} else {
host
}
}

fn duration_secs(x: &Duration) -> f32 {
x.as_secs() as f32 + x.subsec_nanos() as f32 * 1e-9
}
104 changes: 55 additions & 49 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,34 @@ use boring::pkey::PKey;
use boring::x509::X509;
use clap::Parser;
use quinn_boring::QuicSslContext;
use tracing::{error, info, info_span};
use tracing_futures::Instrument as _;
use tracing::{error, info, info_span, Instrument as _};

#[derive(Parser, Debug)]
#[clap(name = "server")]
struct Opt {
/// file to log TLS keys to for debugging
#[clap(long = "keylog")]
keylog: bool,
/// directory to serve files from
#[arg(default_value = "./")]
root: PathBuf,
/// TLS private key in PEM format
#[arg(short = 'k', long = "key", requires = "cert")]
#[clap(short = 'k', long = "key", requires = "cert")]
key: Option<PathBuf>,
/// TLS certificate in PEM format
#[arg(short = 'c', long = "cert", requires = "key")]
#[clap(short = 'c', long = "cert", requires = "key")]
cert: Option<PathBuf>,
/// Enable stateless retries
#[arg(long = "stateless-retry")]
#[clap(long = "stateless-retry")]
stateless_retry: bool,
/// Address to listen on
#[arg(long = "listen", default_value = "127.0.0.1:4433")]
#[clap(long = "listen", default_value = "[::1]:4433")]
listen: SocketAddr,
/// Client address to block
#[clap(long = "block")]
block: Option<SocketAddr>,
/// Maximum number of concurrent connections to allow
#[clap(long = "connection-limit")]
connection_limit: Option<usize>,
}

fn main() {
Expand All @@ -48,7 +55,7 @@ fn main() {
let opt = Opt::parse();
let code = {
if let Err(e) = run(opt) {
eprintln!("ERROR: {}", e);
eprintln!("ERROR: {e}");
1
} else {
0
Expand All @@ -64,31 +71,22 @@ async fn run(options: Opt) -> Result<()> {
let key = if key_path.extension().map_or(false, |x| x == "der") {
PKey::private_key_from_der(&key)?
} else {
let pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut &*key)
.context("malformed PKCS #8 private key")?;
match pkcs8.into_iter().next() {
Some(x) => PKey::private_key_from_der(&x)?,
None => {
let rsa = rustls_pemfile::rsa_private_keys(&mut &*key)
.context("malformed PKCS #1 private key")?;
match rsa.into_iter().next() {
Some(x) => PKey::private_key_from_der(&x)?,
None => {
bail!("no private keys found");
}
}
}
}
let key = rustls_pemfile::private_key(&mut &*key)
.context("malformed PKCS #1 private key")?
.ok_or_else(|| anyhow::Error::msg("no private keys found"))?;
PKey::private_key_from_der(key.secret_der())?
};
let cert_chain = fs::read(cert_path).context("failed to read certificate chain")?;
let cert_chain = if cert_path.extension().map_or(false, |x| x == "der") {
vec![X509::from_der(&cert_chain)?]
} else {
rustls_pemfile::certs(&mut &*cert_chain)
.context("invalid PEM-encoded certificate")?
.into_iter()
.map(|x| X509::from_der(&x).unwrap())
.collect()
let mut certs = Vec::new();
for cert in rustls_pemfile::certs(&mut &*cert_chain) {
let cert = cert.context("invalid PEM-encoded certificate")?;
let x509 = X509::from_der(&cert).context("invalid X509 cert")?;
certs.push(x509);
}
certs
};

(cert_chain, key)
Expand All @@ -100,12 +98,12 @@ async fn run(options: Opt) -> Result<()> {
let key_path = path.join("key.der");

let (cert, key) = match fs::read(&cert_path).and_then(|x| Ok((x, fs::read(&key_path)?))) {
Ok(x) => x,
Ok((cert, key)) => (cert, key),
Err(ref e) if e.kind() == io::ErrorKind::NotFound => {
info!("generating self-signed certificate");
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
let key = cert.serialize_private_key_der();
let cert = cert.serialize_der().unwrap();
let key = cert.key_pair.serialize_der();
let cert = cert.cert.der().to_vec();
fs::create_dir_all(path).context("failed to create certificate directory")?;
fs::write(&cert_path, &cert).context("failed to write certificate")?;
fs::write(&key_path, &key).context("failed to write private key")?;
Expand Down Expand Up @@ -137,12 +135,8 @@ async fn run(options: Opt) -> Result<()> {
ctx.check_private_key()?;

let mut server_config = quinn_boring::helpers::server_config(Arc::new(server_crypto))?;
Arc::get_mut(&mut server_config.transport)
.unwrap()
.max_concurrent_uni_streams(0_u8.into());
if options.stateless_retry {
server_config.use_retry(true);
}
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
transport_config.max_concurrent_uni_streams(0_u8.into());

let root = Arc::<Path>::from(options.root.clone());
if !root.exists() {
Expand All @@ -153,19 +147,33 @@ async fn run(options: Opt) -> Result<()> {
eprintln!("listening on {}", endpoint.local_addr()?);

while let Some(conn) = endpoint.accept().await {
info!("connection incoming");
let fut = handle_connection(root.clone(), conn);
tokio::spawn(async move {
if let Err(e) = fut.await {
error!("connection failed: {reason}", reason = e.to_string())
}
});
if options
.connection_limit
.map_or(false, |n| endpoint.open_connections() >= n)
{
info!("refusing due to open connection limit");
conn.refuse();
} else if Some(conn.remote_address()) == options.block {
info!("refusing blocked client IP address");
conn.refuse();
} else if options.stateless_retry && !conn.remote_address_validated() {
info!("requiring connection to validate its address");
conn.retry().unwrap();
} else {
info!("accepting connection");
let fut = handle_connection(root.clone(), conn);
tokio::spawn(async move {
if let Err(e) = fut.await {
error!("connection failed: {reason}", reason = e.to_string())
}
});
}
}

Ok(())
}

async fn handle_connection(root: Arc<Path>, conn: quinn::Connecting) -> Result<()> {
async fn handle_connection(root: Arc<Path>, conn: quinn::Incoming) -> Result<()> {
let connection = conn.await?;
let span = info_span!(
"connection",
Expand Down Expand Up @@ -226,16 +234,14 @@ async fn handle_request(
// Execute the request
let resp = process_get(&root, &req).unwrap_or_else(|e| {
error!("failed: {}", e);
format!("failed to process request: {}\n", e).into_bytes()
format!("failed to process request: {e}\n").into_bytes()
});
// Write the response
send.write_all(&resp)
.await
.map_err(|e| anyhow!("failed to send response: {}", e))?;
// Gracefully terminate the stream
send.finish()
.await
.map_err(|e| anyhow!("failed to shutdown stream: {}", e))?;
send.finish().unwrap();
info!("complete");
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ impl Session {

// Configure verification for the server hostname.
ssl.set_verify_hostname(server_name)
.map_err(|_| ConnectError::InvalidDnsName(server_name.into()))?;
.map_err(|_| ConnectError::InvalidServerName(server_name.into()))?;

// Set the SNI hostname.
// TODO: should we validate the hostname?
ssl.set_hostname(server_name)
.map_err(|_| ConnectError::InvalidDnsName(server_name.into()))?;
.map_err(|_| ConnectError::InvalidServerName(server_name.into()))?;

// Set the transport parameters.
ssl.set_quic_transport_params(&encode_params(params))
Expand Down
4 changes: 4 additions & 0 deletions src/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,10 @@ impl Debug for AeadKey {

unsafe impl Send for AeadKey {}

// EVP_AEAD_CTX_seal & EVP_AEAD_CTX_open allowed to be called concurrently on the same instance of EVP_AEAD_CTX
// https://github.com/google/boringssl/blob/master/include/openssl/aead.h#L278
unsafe impl Sync for AeadKey {}

impl AeadKey {
#[inline]
pub(crate) fn new(suite: &'static CipherSuite, key: Key) -> Result<Self> {
Expand Down
Loading

0 comments on commit 8aeaa43

Please sign in to comment.