diff --git a/fuzz/fuzz_targets/packet.rs b/fuzz/fuzz_targets/packet.rs index c4879d1d2..ce6890574 100644 --- a/fuzz/fuzz_targets/packet.rs +++ b/fuzz/fuzz_targets/packet.rs @@ -5,15 +5,23 @@ extern crate proto; use libfuzzer_sys::fuzz_target; use proto::{ fuzzing::{PacketParams, PartialDecode}, - RandomConnectionIdGenerator, DEFAULT_SUPPORTED_VERSIONS, + ConnectionIdParser, RandomConnectionIdGenerator, ZeroLengthConnectionIdParser, + DEFAULT_SUPPORTED_VERSIONS, }; fuzz_target!(|data: PacketParams| { let len = data.buf.len(); let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); + let cid_gen; if let Ok(decoded) = PartialDecode::new( data.buf, - &RandomConnectionIdGenerator::new(data.local_cid_len), + match data.local_cid_len { + 0 => &ZeroLengthConnectionIdParser as &dyn ConnectionIdParser, + _ => { + cid_gen = RandomConnectionIdGenerator::new(data.local_cid_len); + &cid_gen as &dyn ConnectionIdParser + } + }, &supported_versions, data.grease_quic_bit, ) { diff --git a/quinn-proto/src/cid_generator.rs b/quinn-proto/src/cid_generator.rs index fc465fc1e..73fe14d40 100644 --- a/quinn-proto/src/cid_generator.rs +++ b/quinn-proto/src/cid_generator.rs @@ -26,8 +26,6 @@ pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser { Ok(()) } - /// Returns the length of a CID for connections created by this generator - fn cid_len(&self) -> usize; /// Returns the lifetime of generated Connection IDs /// /// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant. @@ -63,6 +61,10 @@ impl RandomConnectionIdGenerator { /// The given length must be less than or equal to MAX_CID_SIZE. pub fn new(cid_len: usize) -> Self { debug_assert!(cid_len <= MAX_CID_SIZE); + assert!( + cid_len > 0, + "connection ID generators must produce non-empty IDs" + ); Self { cid_len, ..Self::default() @@ -92,11 +94,6 @@ impl ConnectionIdGenerator for RandomConnectionIdGenerator { ConnectionId::new(&bytes_arr[..self.cid_len]) } - /// Provide the length of dst_cid in short header packet - fn cid_len(&self) -> usize { - self.cid_len - } - fn cid_lifetime(&self) -> Option { self.lifetime } @@ -173,10 +170,6 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator { } } - fn cid_len(&self) -> usize { - HASHED_CID_LEN - } - fn cid_lifetime(&self) -> Option { self.lifetime } @@ -186,6 +179,31 @@ const NONCE_LEN: usize = 3; // Good for more than 16 million connections const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length const HASHED_CID_LEN: usize = NONCE_LEN + SIGNATURE_LEN; +/// HACK: Replace uses with `ZeroLengthConnectionIdParser` once [trait upcasting] is stable +/// +/// CID generators should produce nonempty CIDs. We should be able to use +/// `ZeroLengthConnectionIdParser` everywhere this would be needed, but that will require +/// construction of `&dyn ConnectionIdParser` from `&dyn ConnectionIdGenerator`. +/// +/// [trait upcasting]: https://github.com/rust-lang/rust/issues/65991 +pub(crate) struct ZeroLengthConnectionIdGenerator; + +impl ConnectionIdParser for ZeroLengthConnectionIdGenerator { + fn parse(&self, _: &mut dyn Buf) -> Result { + Ok(ConnectionId::new(&[])) + } +} + +impl ConnectionIdGenerator for ZeroLengthConnectionIdGenerator { + fn generate_cid(&self) -> ConnectionId { + unreachable!() + } + + fn cid_lifetime(&self) -> Option { + None + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index 36a281b5f..650d6e2cf 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -616,7 +616,7 @@ impl Default for MtuDiscoveryConfig { pub struct EndpointConfig { pub(crate) reset_key: Arc, pub(crate) max_udp_payload_size: VarInt, - pub(crate) connection_id_generator: Arc, + pub(crate) connection_id_generator: Option>, pub(crate) supported_versions: Vec, pub(crate) grease_quic_bit: bool, /// Minimum interval between outgoing stateless reset packets @@ -629,7 +629,7 @@ impl EndpointConfig { Self { reset_key, max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers - connection_id_generator: Arc::::default(), + connection_id_generator: Some(Arc::::default()), supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(), grease_quic_bit: true, min_reset_interval: Duration::from_millis(20), @@ -644,7 +644,10 @@ impl EndpointConfig { /// information in local connection IDs, e.g. to support stateless packet-level load balancers. /// /// Defaults to [`HashedConnectionIdGenerator`]. - pub fn cid_generator(&mut self, generator: Arc) -> &mut Self { + pub fn cid_generator( + &mut self, + generator: Option>, + ) -> &mut Self { self.connection_id_generator = generator; self } diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index aaf8bb5d3..69ebbba58 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -15,13 +15,12 @@ use thiserror::Error; use tracing::{debug, error, trace, trace_span, warn}; use crate::{ - cid_generator::ConnectionIdGenerator, + cid_generator::{ConnectionIdGenerator, ZeroLengthConnectionIdGenerator}, cid_queue::CidQueue, coding::BufMutExt, config::{ServerConfig, TransportConfig}, crypto::{self, KeyPair, Keys, PacketKey}, - frame, - frame::{Close, Datagram, FrameStruct}, + frame::{self, Close, Datagram, FrameStruct}, packet::{ Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, SpaceId, @@ -197,7 +196,7 @@ pub struct Connection { retry_token: Bytes, /// Identifies Data-space packet numbers to skip. Not used in earlier spaces. packet_number_filter: PacketNumberFilter, - cid_gen: Arc, + cid_gen: Option>, // // Queued non-retransmittable 1-RTT data @@ -253,7 +252,7 @@ impl Connection { remote: SocketAddr, local_ip: Option, crypto: Box, - cid_gen: Arc, + cid_gen: Option>, now: Instant, version: u32, allow_mtud: bool, @@ -281,14 +280,13 @@ impl Connection { crypto, handshake_cid: loc_cid, rem_handshake_cid: rem_cid, - local_cid_state: match cid_gen.cid_len() { - 0 => None, - _ => Some(CidState::new( - cid_gen.cid_lifetime(), + local_cid_state: cid_gen.as_ref().map(|gen| { + CidState::new( + gen.cid_lifetime(), now, if pref_addr_cid.is_some() { 2 } else { 1 }, - )), - }, + ) + }), path: PathData::new(remote, allow_mtud, None, now, path_validated, &config), allow_mtud, local_ip, @@ -2103,7 +2101,10 @@ impl Connection { while let Some(data) = remaining { match PartialDecode::new( data, - &*self.cid_gen, + self.cid_gen.as_ref().map_or( + &ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator, + |x| &**x, + ), &[self.version], self.endpoint_config.grease_quic_bit, ) { diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index cdd5888fa..d1e483dcf 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -16,7 +16,9 @@ use thiserror::Error; use tracing::{debug, error, trace, warn}; use crate::{ - cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, + cid_generator::{ + ConnectionIdGenerator, RandomConnectionIdGenerator, ZeroLengthConnectionIdGenerator, + }, coding::BufMutExt, config::{ClientConfig, EndpointConfig, ServerConfig}, connection::{Connection, ConnectionError}, @@ -44,7 +46,7 @@ pub struct Endpoint { rng: StdRng, index: ConnectionIndex, connections: Slab, - local_cid_generator: Arc, + local_cid_generator: Option>, config: Arc, server_config: Option>, /// Whether the underlying UDP socket promises not to fragment packets @@ -144,7 +146,10 @@ impl Endpoint { let datagram_len = data.len(); let (first_decode, remaining) = match PartialDecode::new( data, - &*self.local_cid_generator, + self.local_cid_generator.as_ref().map_or( + &ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator, + |x| &**x, + ), &self.config.supported_versions, self.config.grease_quic_bit, ) { @@ -302,8 +307,8 @@ impl Endpoint { if !first_decode.is_initial() && self .local_cid_generator - .validate(first_decode.dst_cid()) - .is_err() + .as_ref() + .map_or(false, |gen| gen.validate(first_decode.dst_cid()).is_err()) { debug!("dropping packet with invalid CID"); return None; @@ -400,7 +405,7 @@ impl Endpoint { let params = TransportParameters::new( &config.transport, &self.config, - self.local_cid_generator.as_ref(), + self.local_cid_generator.is_some(), loc_cid, None, ); @@ -453,12 +458,11 @@ impl Endpoint { /// Generate a connection ID for `ch` fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId { loop { - let cid = self.local_cid_generator.generate_cid(); - if cid.len() == 0 { + let Some(cid_generator) = self.local_cid_generator.as_ref() else { // Zero-length CID; nothing to track - debug_assert_eq!(self.local_cid_generator.cid_len(), 0); - return cid; - } + return ConnectionId::EMPTY; + }; + let cid = cid_generator.generate_cid(); if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) { e.insert(ch); break cid; @@ -589,7 +593,7 @@ impl Endpoint { let mut params = TransportParameters::new( &server_config.transport, &self.config, - self.local_cid_generator.as_ref(), + self.local_cid_generator.is_some(), loc_cid, Some(&server_config), ); @@ -680,10 +684,7 @@ impl Endpoint { // bytes. If this is a Retry packet, then the length must instead match our usual CID // length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll // also need to validate CID length for those after decoding the token. - if header.dst_cid.len() < 8 - && (!header.token_pos.is_empty() - && header.dst_cid.len() != self.local_cid_generator.cid_len()) - { + if header.dst_cid.len() < 8 && !header.token_pos.is_empty() { debug!( "rejecting connection due to invalid DCID length {}", header.dst_cid.len() @@ -730,7 +731,10 @@ impl Endpoint { // with established connections. In the unlikely event that a collision occurs // between two connections in the initial phase, both will fail fast and may be // retried by the application layer. - let loc_cid = self.local_cid_generator.generate_cid(); + let loc_cid = self + .local_cid_generator + .as_ref() + .map_or(ConnectionId::EMPTY, |gen| gen.generate_cid()); let token = RetryToken { orig_dst_cid: incoming.packet.header.dst_cid, @@ -860,7 +864,10 @@ impl Endpoint { // We don't need to worry about CID collisions in initial closes because the peer // shouldn't respond, and if it does, and the CID collides, we'll just drop the // unexpected response. - let local_id = self.local_cid_generator.generate_cid(); + let local_id = self + .local_cid_generator + .as_ref() + .map_or(ConnectionId::EMPTY, |gen| gen.generate_cid()); let number = PacketNumber::U8(0); let header = Header::Initial(InitialHeader { dst_cid: *remote_id, diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 5a84bdcf5..ae5a4d2c7 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -67,7 +67,7 @@ pub use crate::endpoint::{ mod packet; pub use packet::{ ConnectionIdParser, LongType, PacketDecodeError, PartialDecode, ProtectedHeader, - ProtectedInitialHeader, + ProtectedInitialHeader, ZeroLengthConnectionIdParser, }; mod shared; diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index 37ed0b06f..328988e0e 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -773,6 +773,16 @@ pub trait ConnectionIdParser { fn parse(&self, buf: &mut dyn Buf) -> Result; } +/// Trivial parser for zero-length connection IDs +pub struct ZeroLengthConnectionIdParser; + +impl ConnectionIdParser for ZeroLengthConnectionIdParser { + #[inline] + fn parse(&self, _: &mut dyn Buf) -> Result { + Ok(ConnectionId::new(&[])) + } +} + /// Long packet type including non-uniform cases #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum LongHeaderType { @@ -908,7 +918,7 @@ mod tests { #[test] fn header_encoding() { use crate::crypto::rustls::{initial_keys, initial_suite_from_provider}; - use crate::{RandomConnectionIdGenerator, Side}; + use crate::Side; use rustls::crypto::ring::default_provider; use rustls::quic::Version; @@ -950,7 +960,7 @@ mod tests { let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); let decode = PartialDecode::new( buf.as_slice().into(), - &RandomConnectionIdGenerator::new(0), + &ZeroLengthConnectionIdParser, &supported_versions, false, ) diff --git a/quinn-proto/src/shared.rs b/quinn-proto/src/shared.rs index 05ffe8caf..f6dac3652 100644 --- a/quinn-proto/src/shared.rs +++ b/quinn-proto/src/shared.rs @@ -72,6 +72,12 @@ pub struct ConnectionId { } impl ConnectionId { + /// The zero-length connection ID + pub const EMPTY: Self = Self { + len: 0, + bytes: [0; MAX_CID_SIZE], + }; + /// Construct cid from byte array pub fn new(bytes: &[u8]) -> Self { debug_assert!(bytes.len() <= MAX_CID_SIZE); diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index da117e730..8b4cc0180 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -66,7 +66,7 @@ fn version_negotiate_client() { // packet let mut client = Endpoint::new( Arc::new(EndpointConfig { - connection_id_generator: Arc::new(RandomConnectionIdGenerator::new(0)), + connection_id_generator: None, ..Default::default() }), None, @@ -181,7 +181,7 @@ fn server_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(Arc::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Some(Arc::new(HashedConnectionIdGenerator::from_key(0)))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -211,7 +211,7 @@ fn client_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(Arc::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Some(Arc::new(HashedConnectionIdGenerator::from_key(0)))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -240,7 +240,7 @@ fn stateless_reset_limit() { let _guard = subscribe(); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 42); let mut endpoint_config = EndpointConfig::default(); - endpoint_config.cid_generator(Arc::new(RandomConnectionIdGenerator::new(8))); + endpoint_config.cid_generator(Some(Arc::new(RandomConnectionIdGenerator::new(8)))); let endpoint_config = Arc::new(endpoint_config); let mut endpoint = Endpoint::new( endpoint_config.clone(), @@ -1468,7 +1468,7 @@ fn zero_length_cid() { let _guard = subscribe(); let mut pair = Pair::new( Arc::new(EndpointConfig { - connection_id_generator: Arc::new(RandomConnectionIdGenerator::new(0)), + connection_id_generator: None, ..EndpointConfig::default() }), server_config(), @@ -1525,9 +1525,9 @@ fn cid_rotation() { // Only test cid rotation on server side to have a clear output trace let server = Endpoint::new( Arc::new(EndpointConfig { - connection_id_generator: Arc::new( + connection_id_generator: Some(Arc::new( *RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT), - ), + )), ..EndpointConfig::default() }), Some(Arc::new(server_config())), diff --git a/quinn-proto/src/transport_parameters.rs b/quinn-proto/src/transport_parameters.rs index 3a1a443a2..2dc0a916d 100644 --- a/quinn-proto/src/transport_parameters.rs +++ b/quinn-proto/src/transport_parameters.rs @@ -15,7 +15,6 @@ use bytes::{Buf, BufMut}; use thiserror::Error; use crate::{ - cid_generator::ConnectionIdGenerator, cid_queue::CidQueue, coding::{BufExt, BufMutExt, UnexpectedEnd}, config::{EndpointConfig, ServerConfig, TransportConfig}, @@ -132,7 +131,7 @@ impl TransportParameters { pub(crate) fn new( config: &TransportConfig, endpoint_config: &EndpointConfig, - cid_gen: &dyn ConnectionIdGenerator, + use_cids: bool, initial_src_cid: ConnectionId, server_config: Option<&ServerConfig>, ) -> Self { @@ -147,7 +146,7 @@ impl TransportParameters { max_udp_payload_size: endpoint_config.max_udp_payload_size, max_idle_timeout: config.max_idle_timeout.unwrap_or(VarInt(0)), disable_active_migration: server_config.map_or(false, |c| !c.migration), - active_connection_id_limit: if cid_gen.cid_len() == 0 { + active_connection_id_limit: if !use_cids { 2 // i.e. default, i.e. unsent } else { CidQueue::LEN as u32 diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index f1fd1b49f..13b84fea5 100755 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -10,7 +10,7 @@ use std::{ use crate::runtime::TokioRuntime; use bytes::Bytes; -use proto::{crypto::rustls::QuicClientConfig, RandomConnectionIdGenerator}; +use proto::crypto::rustls::QuicClientConfig; use rand::{rngs::StdRng, RngCore, SeedableRng}; use rustls::{ pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, @@ -817,9 +817,7 @@ async fn two_datagram_readers() { async fn multiple_conns_with_zero_length_cids() { let _guard = subscribe(); let mut factory = EndpointFactory::new(); - factory - .endpoint_config - .cid_generator(Arc::new(RandomConnectionIdGenerator::new(0))); + factory.endpoint_config.cid_generator(None); let server = { let _guard = error_span!("server").entered(); factory.endpoint()