diff --git a/Cargo.toml b/Cargo.toml index 376a225..2cfe391 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,37 +13,38 @@ rust-version = "1.70" maintenance = { status = "passively-maintained" } [features] +default = ["runtime-tokio"] fips = ["boring/fips", "boring-sys/fips"] +runtime-tokio = ["quinn/runtime-tokio"] [dependencies] -boring = "3.0.2" -boring-sys = "3.0.2" +boring = "4" +boring-sys = "4" bytes = "1" -foreign-types-shared = "0.3.1" -lru = "0.11.0" -once_cell = "1.17" -quinn = { version = "0.10.1", default_features = false, features = ["native-certs", "runtime-tokio"] } -quinn-proto = { version = "0.10.1", default-features = false } +foreign-types-shared = "0.3" +lru = "0.12" +once_cell = "1" +quinn = { version = "0.11", default-features = false } +quinn-proto = { version = "0.11", default-features = false } rand = "0.8" tracing = "0.1" [dev-dependencies] -anyhow = "1.0.22" -assert_hex = "0.2.2" -assert_matches = "1.1" -clap = { version = "4.3", features = ["derive"] } +anyhow = "1.0" +assert_matches = "1" +clap = { version = "4", features = ["derive"] } directories-next = "2" -hex-literal = "0.4.1" -ring = "0.16.7" -rcgen = "0.11.1" -rustls-pemfile = "1.0.0" -tokio = { version = "1.0.1", features = ["rt", "rt-multi-thread", "time", "macros", "sync"] } -tracing-futures = { version = "0.2.0", default-features = false, features = ["std-future"] } -tracing-subscriber = { version = "0.3.0", default-features = false, features = ["env-filter", "fmt", "ansi", "time", "local-time"] } +hex-literal = "0.4" +rcgen = "0.13" +rustls-pemfile = "2" +tokio = { version = "1", features = ["rt", "rt-multi-thread", "time", "macros", "sync"] } +tracing-subscriber = { version = "0.3", default-features = false, features = ["env-filter", "fmt", "ansi", "time", "local-time"] } url = "2" [[example]] name = "server" +required-features = ["runtime-tokio"] [[example]] name = "client" +required-features = ["runtime-tokio"] diff --git a/examples/client.rs b/examples/client.rs index fc6789f..5be9e76 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -5,7 +5,7 @@ use std::{ fs, io::{self, Write}, - net::ToSocketAddrs, + net::{SocketAddr, ToSocketAddrs}, path::PathBuf, sync::Arc, time::{Duration, Instant}, @@ -22,20 +22,27 @@ use url::Url; #[derive(Parser, Debug)] #[clap(name = "client")] struct Opt { - #[arg(default_value = "https://localhost:4433/Cargo.toml")] + /// Perform NSS-compatible TLS key logging to the file specified in `SSLKEYLOGFILE`. + #[clap(long = "keylog")] + keylog: bool, + url: Url, /// Override hostname used for certificate verification - #[arg(long = "host")] + #[clap(long = "host")] host: Option, /// Custom certificate authority to trust, in DER format - #[arg(long = "ca")] + #[clap(long = "ca")] ca: Option, /// Simulate NAT rebinding after connecting - #[arg(long = "rebind")] + #[clap(long = "rebind")] rebind: bool, + + /// Address to bind on + #[clap(long = "bind", default_value = "[::]:0")] + bind: SocketAddr, } fn main() { @@ -48,19 +55,20 @@ fn main() { let opt = Opt::parse(); let code = { if let Err(e) = run(opt) { - eprintln!("ERROR: {}", e); + eprintln!("ERROR: {e}"); 1 } else { 0 } }; - std::process::exit(code); + ::std::process::exit(code); } #[tokio::main] async fn run(options: Opt) -> Result<()> { let url = options.url; - let remote = (url.host_str().unwrap(), url.port().unwrap_or(4433)) + let url_host = strip_ipv6_brackets(url.host_str().unwrap()); + let remote = (url_host, url.port().unwrap_or(4433)) .to_socket_addrs()? .next() .ok_or_else(|| anyhow!("couldn't resolve to an address"))?; @@ -91,19 +99,15 @@ async fn run(options: Opt) -> Result<()> { } } - let mut endpoint = quinn_boring::helpers::client_endpoint("[::]:0".parse().unwrap())?; + let mut endpoint = quinn_boring::helpers::client_endpoint(options.bind)?; endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(client_crypto))); let request = format!("GET {}\r\n", url.path()); let start = Instant::now(); let rebind = options.rebind; - let host = options - .host - .as_ref() - .map_or_else(|| url.host_str(), |x| Some(x)) - .ok_or_else(|| anyhow!("no hostname specified"))?; + let host = options.host.as_deref().unwrap_or(url_host); - eprintln!("connecting to {} at {}", host, remote); + eprintln!("connecting to {host} at {remote}"); let conn = endpoint .connect(remote, host)? .await @@ -116,16 +120,14 @@ async fn run(options: Opt) -> Result<()> { if rebind { let socket = std::net::UdpSocket::bind("[::]:0").unwrap(); let addr = socket.local_addr().unwrap(); - eprintln!("rebinding to {}", addr); + eprintln!("rebinding to {addr}"); endpoint.rebind(socket).expect("rebind failed"); } send.write_all(request.as_bytes()) .await .map_err(|e| anyhow!("failed to send request: {}", e))?; - send.finish() - .await - .map_err(|e| anyhow!("failed to shutdown stream: {}", e))?; + send.finish().unwrap(); let response_start = Instant::now(); eprintln!("request sent at {:?}", response_start - start); let resp = recv @@ -148,6 +150,16 @@ async fn run(options: Opt) -> Result<()> { Ok(()) } +fn strip_ipv6_brackets(host: &str) -> &str { + // An ipv6 url looks like eg https://[::1]:4433/Cargo.toml, wherein the host [::1] is the + // ipv6 address ::1 wrapped in brackets, per RFC 2732. This strips those. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } +} + fn duration_secs(x: &Duration) -> f32 { x.as_secs() as f32 + x.subsec_nanos() as f32 * 1e-9 } diff --git a/examples/server.rs b/examples/server.rs index ff3cd9a..a1927e3 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -15,27 +15,34 @@ use boring::pkey::PKey; use boring::x509::X509; use clap::Parser; use quinn_boring::QuicSslContext; -use tracing::{error, info, info_span}; -use tracing_futures::Instrument as _; +use tracing::{error, info, info_span, Instrument as _}; #[derive(Parser, Debug)] #[clap(name = "server")] struct Opt { + /// file to log TLS keys to for debugging + #[clap(long = "keylog")] + keylog: bool, /// directory to serve files from - #[arg(default_value = "./")] root: PathBuf, /// TLS private key in PEM format - #[arg(short = 'k', long = "key", requires = "cert")] + #[clap(short = 'k', long = "key", requires = "cert")] key: Option, /// TLS certificate in PEM format - #[arg(short = 'c', long = "cert", requires = "key")] + #[clap(short = 'c', long = "cert", requires = "key")] cert: Option, /// Enable stateless retries - #[arg(long = "stateless-retry")] + #[clap(long = "stateless-retry")] stateless_retry: bool, /// Address to listen on - #[arg(long = "listen", default_value = "127.0.0.1:4433")] + #[clap(long = "listen", default_value = "[::1]:4433")] listen: SocketAddr, + /// Client address to block + #[clap(long = "block")] + block: Option, + /// Maximum number of concurrent connections to allow + #[clap(long = "connection-limit")] + connection_limit: Option, } fn main() { @@ -48,7 +55,7 @@ fn main() { let opt = Opt::parse(); let code = { if let Err(e) = run(opt) { - eprintln!("ERROR: {}", e); + eprintln!("ERROR: {e}"); 1 } else { 0 @@ -64,31 +71,22 @@ async fn run(options: Opt) -> Result<()> { let key = if key_path.extension().map_or(false, |x| x == "der") { PKey::private_key_from_der(&key)? } else { - let pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut &*key) - .context("malformed PKCS #8 private key")?; - match pkcs8.into_iter().next() { - Some(x) => PKey::private_key_from_der(&x)?, - None => { - let rsa = rustls_pemfile::rsa_private_keys(&mut &*key) - .context("malformed PKCS #1 private key")?; - match rsa.into_iter().next() { - Some(x) => PKey::private_key_from_der(&x)?, - None => { - bail!("no private keys found"); - } - } - } - } + let key = rustls_pemfile::private_key(&mut &*key) + .context("malformed PKCS #1 private key")? + .ok_or_else(|| anyhow::Error::msg("no private keys found"))?; + PKey::private_key_from_der(key.secret_der())? }; let cert_chain = fs::read(cert_path).context("failed to read certificate chain")?; let cert_chain = if cert_path.extension().map_or(false, |x| x == "der") { vec![X509::from_der(&cert_chain)?] } else { - rustls_pemfile::certs(&mut &*cert_chain) - .context("invalid PEM-encoded certificate")? - .into_iter() - .map(|x| X509::from_der(&x).unwrap()) - .collect() + let mut certs = Vec::new(); + for cert in rustls_pemfile::certs(&mut &*cert_chain) { + let cert = cert.context("invalid PEM-encoded certificate")?; + let x509 = X509::from_der(&cert).context("invalid X509 cert")?; + certs.push(x509); + } + certs }; (cert_chain, key) @@ -100,12 +98,12 @@ async fn run(options: Opt) -> Result<()> { let key_path = path.join("key.der"); let (cert, key) = match fs::read(&cert_path).and_then(|x| Ok((x, fs::read(&key_path)?))) { - Ok(x) => x, + Ok((cert, key)) => (cert, key), Err(ref e) if e.kind() == io::ErrorKind::NotFound => { info!("generating self-signed certificate"); let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); - let key = cert.serialize_private_key_der(); - let cert = cert.serialize_der().unwrap(); + let key = cert.key_pair.serialize_der(); + let cert = cert.cert.der().to_vec(); fs::create_dir_all(path).context("failed to create certificate directory")?; fs::write(&cert_path, &cert).context("failed to write certificate")?; fs::write(&key_path, &key).context("failed to write private key")?; @@ -137,12 +135,8 @@ async fn run(options: Opt) -> Result<()> { ctx.check_private_key()?; let mut server_config = quinn_boring::helpers::server_config(Arc::new(server_crypto))?; - Arc::get_mut(&mut server_config.transport) - .unwrap() - .max_concurrent_uni_streams(0_u8.into()); - if options.stateless_retry { - server_config.use_retry(true); - } + let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); + transport_config.max_concurrent_uni_streams(0_u8.into()); let root = Arc::::from(options.root.clone()); if !root.exists() { @@ -153,19 +147,33 @@ async fn run(options: Opt) -> Result<()> { eprintln!("listening on {}", endpoint.local_addr()?); while let Some(conn) = endpoint.accept().await { - info!("connection incoming"); - let fut = handle_connection(root.clone(), conn); - tokio::spawn(async move { - if let Err(e) = fut.await { - error!("connection failed: {reason}", reason = e.to_string()) - } - }); + if options + .connection_limit + .map_or(false, |n| endpoint.open_connections() >= n) + { + info!("refusing due to open connection limit"); + conn.refuse(); + } else if Some(conn.remote_address()) == options.block { + info!("refusing blocked client IP address"); + conn.refuse(); + } else if options.stateless_retry && !conn.remote_address_validated() { + info!("requiring connection to validate its address"); + conn.retry().unwrap(); + } else { + info!("accepting connection"); + let fut = handle_connection(root.clone(), conn); + tokio::spawn(async move { + if let Err(e) = fut.await { + error!("connection failed: {reason}", reason = e.to_string()) + } + }); + } } Ok(()) } -async fn handle_connection(root: Arc, conn: quinn::Connecting) -> Result<()> { +async fn handle_connection(root: Arc, conn: quinn::Incoming) -> Result<()> { let connection = conn.await?; let span = info_span!( "connection", @@ -226,16 +234,14 @@ async fn handle_request( // Execute the request let resp = process_get(&root, &req).unwrap_or_else(|e| { error!("failed: {}", e); - format!("failed to process request: {}\n", e).into_bytes() + format!("failed to process request: {e}\n").into_bytes() }); // Write the response send.write_all(&resp) .await .map_err(|e| anyhow!("failed to send response: {}", e))?; // Gracefully terminate the stream - send.finish() - .await - .map_err(|e| anyhow!("failed to shutdown stream: {}", e))?; + send.finish().unwrap(); info!("complete"); Ok(()) } diff --git a/src/client.rs b/src/client.rs index 291910d..2a27d48 100644 --- a/src/client.rs +++ b/src/client.rs @@ -171,12 +171,12 @@ impl Session { // Configure verification for the server hostname. ssl.set_verify_hostname(server_name) - .map_err(|_| ConnectError::InvalidDnsName(server_name.into()))?; + .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?; // Set the SNI hostname. // TODO: should we validate the hostname? ssl.set_hostname(server_name) - .map_err(|_| ConnectError::InvalidDnsName(server_name.into()))?; + .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?; // Set the transport parameters. ssl.set_quic_transport_params(&encode_params(params)) diff --git a/src/key.rs b/src/key.rs index 424743b..5f93abb 100644 --- a/src/key.rs +++ b/src/key.rs @@ -439,6 +439,10 @@ impl Debug for AeadKey { unsafe impl Send for AeadKey {} +// EVP_AEAD_CTX_seal & EVP_AEAD_CTX_open allowed to be called concurrently on the same instance of EVP_AEAD_CTX +// https://github.com/google/boringssl/blob/master/include/openssl/aead.h#L278 +unsafe impl Sync for AeadKey {} + impl AeadKey { #[inline] pub(crate) fn new(suite: &'static CipherSuite, key: Key) -> Result { diff --git a/src/lib.rs b/src/lib.rs index 25858ea..fbed4de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,10 +45,7 @@ pub struct HandshakeData { pub mod helpers { use super::*; - use quinn::TokioRuntime; use quinn_proto::crypto; - use std::io; - use std::net::SocketAddr; use std::sync::Arc; /// Create a server config with the given [`crypto::ServerConfig`] @@ -78,13 +75,14 @@ pub mod helpers { /// IPv6 address on Windows will not by default be able to communicate with IPv4 /// addresses. Portable applications should bind an address that matches the family they wish to /// communicate within. - pub fn client_endpoint(addr: SocketAddr) -> io::Result { + #[cfg(feature = "runtime-tokio")] + pub fn client_endpoint(addr: std::net::SocketAddr) -> std::io::Result { let socket = std::net::UdpSocket::bind(addr)?; quinn::Endpoint::new( default_endpoint_config(), None, socket, - Arc::new(TokioRuntime), + Arc::new(quinn::TokioRuntime), ) } @@ -94,16 +92,17 @@ pub mod helpers { /// IPv6 address on Windows will not by default be able to communicate with IPv4 /// addresses. Portable applications should bind an address that matches the family they wish to /// communicate within. + #[cfg(feature = "runtime-tokio")] pub fn server_endpoint( config: quinn::ServerConfig, - addr: SocketAddr, - ) -> io::Result { + addr: std::net::SocketAddr, + ) -> std::io::Result { let socket = std::net::UdpSocket::bind(addr)?; quinn::Endpoint::new( default_endpoint_config(), Some(config), socket, - Arc::new(TokioRuntime), + Arc::new(quinn::TokioRuntime), ) } } diff --git a/src/server.rs b/src/server.rs index 0138c49..77427c4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -108,11 +108,10 @@ impl crypto::ServerConfig for Config { fn initial_keys( &self, version: u32, - dcid: &ConnectionId, - side: Side, + dst_cid: &ConnectionId, ) -> StdResult { let version = QuicVersion::parse(version)?; - let secrets = Secrets::initial(version, dcid, side).unwrap(); + let secrets = Secrets::initial(version, dst_cid, Side::Server).unwrap(); Ok(secrets.keys().unwrap().as_crypto().unwrap()) } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 860211c..2f28d17 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -116,9 +116,8 @@ async fn stateless_retry() -> Result<()> { let _guard = subscribe(); // Run the server. - let mut server_config = server_config(server_crypto())?; - server_config.use_retry(true); - let server = Server::run(server_config)?; + let server_config = server_config(server_crypto())?; + let server = Server::run_with_retry(server_config, true)?; // Connect the client. let client_config = client_config(client_crypto()); @@ -258,7 +257,8 @@ async fn zero_rtt_rejected() -> Result<()> { // Hack to allow us to sleep between creating the stream and sending the message. async fn send_ping(mut send: SendStream) -> std::result::Result<(), WriteError> { send.write_all(PING_MSG).await?; - send.finish().await?; + send.finish()?; + send.stopped().await?; Ok(()) } @@ -325,7 +325,8 @@ impl Client { async fn send_ping(&self) -> std::result::Result<(), WriteError> { let mut send = self.conn.open_uni().await?; send.write_all(PING_MSG).await?; - send.finish().await?; + send.finish()?; + send.stopped().await?; Ok(()) } @@ -344,6 +345,9 @@ struct Server { impl Server { fn run(server_config: quinn::ServerConfig) -> Result> { + Self::run_with_retry(server_config, false) + } + fn run_with_retry(server_config: quinn::ServerConfig, use_retry: bool) -> Result> { let endpoint = quinn_boring::helpers::server_endpoint(server_config, local_address())?; let addr = endpoint.local_addr()?; @@ -354,8 +358,27 @@ impl Server { let server2 = server.clone(); tokio::spawn(async move { - while let Some(conn) = endpoint.accept().await { - let server = server2.clone(); + while let Some(incoming) = endpoint.accept().await { + let server: Arc = server2.clone(); + if use_retry && !incoming.remote_address_validated() { + if let Err(e) = incoming.retry() { + error!( + "server: connection retry failed: {reason}", + reason = e.to_string() + ) + } + continue; + } + let conn = match incoming.accept() { + Ok(conn) => conn, + Err(e) => { + error!( + "server: connection accept failed: {reason}", + reason = e.to_string() + ); + continue; + } + }; tokio::spawn(async move { let fut = server.handle_connection(conn); if let Err(e) = fut.await { @@ -462,8 +485,10 @@ impl Server { .map_err(|e| anyhow!("failed to send response: {}", e))?; // Gracefully terminate the stream send.finish() - .await .map_err(|e| anyhow!("failed to shutdown stream: {}", e))?; + send.stopped() + .await + .map_err(|e| anyhow!("failed to stop stream: {}", e))?; Ok(()) } } @@ -523,24 +548,38 @@ fn server_crypto() -> ServerConfig { } /// Certificate Authority utility that can create new leaf certs. -struct Ca(rcgen::Certificate); +struct Ca(rcgen::CertifiedKey); impl Ca { /// Creates a new CA. fn new() -> Self { - let mut params = CertificateParams::new(&[] as &[String]); + let key_pair = rcgen::KeyPair::generate().expect("key pair generated"); + + let mut params = CertificateParams::default(); params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); - let cert = rcgen::Certificate::from_params(params).unwrap(); - Self(cert) + params.key_usages = vec![ + rcgen::KeyUsagePurpose::DigitalSignature, + rcgen::KeyUsagePurpose::KeyEncipherment, + rcgen::KeyUsagePurpose::ContentCommitment, + ]; + + Self(rcgen::CertifiedKey { + cert: params.self_signed(&key_pair).unwrap(), + key_pair, + }) } /// Creates a new leaf cert signed by this CA. fn new_leaf(&self, subject_alt_names: impl Into>) -> Leaf { - let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap(); - let private_key = cert.serialize_private_key_der(); - let cert = cert.serialize_der_with_signer(&self.0).unwrap(); - let ca_cert = self.0.serialize_der().unwrap(); + let key_pair = rcgen::KeyPair::generate().unwrap(); + let certificate = CertificateParams::new(subject_alt_names) + .unwrap() + .signed_by(&key_pair, &self.0.cert, &self.0.key_pair) + .unwrap(); + let private_key = key_pair.serialize_der(); + let cert = certificate.der().to_vec(); + let ca_cert = self.0.cert.der().to_vec(); Leaf { private_key, cert,