From b4c20d44b5f338cbf72aac363755bcb47a3deb88 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Wed, 22 May 2024 16:35:04 -0700 Subject: [PATCH 1/9] Require CID generators to also be CID parsers --- quinn-proto/src/cid_generator.rs | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/quinn-proto/src/cid_generator.rs b/quinn-proto/src/cid_generator.rs index 6173e96f9..494ac1f09 100644 --- a/quinn-proto/src/cid_generator.rs +++ b/quinn-proto/src/cid_generator.rs @@ -1,12 +1,13 @@ use std::{hash::Hasher, time::Duration}; +use bytes::Buf; use rand::{Rng, RngCore}; use crate::shared::ConnectionId; -use crate::MAX_CID_SIZE; +use crate::{ConnectionIdParser, PacketDecodeError, MAX_CID_SIZE}; /// Generates connection IDs for incoming connections -pub trait ConnectionIdGenerator: Send + Sync { +pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser { /// Generates a new CID /// /// Connection IDs MUST NOT contain any information that can be used by @@ -73,6 +74,14 @@ impl RandomConnectionIdGenerator { } } +impl ConnectionIdParser for RandomConnectionIdGenerator { + fn parse(&self, buffer: &mut dyn Buf) -> Result { + (buffer.remaining() >= self.cid_len) + .then(|| ConnectionId::from_buf(buffer, self.cid_len)) + .ok_or(PacketDecodeError::InvalidHeader("packet too small")) + } +} + impl ConnectionIdGenerator for RandomConnectionIdGenerator { fn generate_cid(&mut self) -> ConnectionId { let mut bytes_arr = [0; MAX_CID_SIZE]; @@ -131,9 +140,17 @@ impl Default for HashedConnectionIdGenerator { } } +impl ConnectionIdParser for HashedConnectionIdGenerator { + fn parse(&self, buffer: &mut dyn Buf) -> Result { + (buffer.remaining() >= HASHED_CID_LEN) + .then(|| ConnectionId::from_buf(buffer, HASHED_CID_LEN)) + .ok_or(PacketDecodeError::InvalidHeader("packet too small")) + } +} + impl ConnectionIdGenerator for HashedConnectionIdGenerator { fn generate_cid(&mut self) -> ConnectionId { - let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN]; + let mut bytes_arr = [0; HASHED_CID_LEN]; rand::thread_rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]); let mut hasher = rustc_hash::FxHasher::default(); hasher.write_u64(self.key); @@ -155,7 +172,7 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator { } fn cid_len(&self) -> usize { - NONCE_LEN + SIGNATURE_LEN + HASHED_CID_LEN } fn cid_lifetime(&self) -> Option { @@ -165,6 +182,7 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator { const NONCE_LEN: usize = 3; // Good for more than 16 million connections const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length +const HASHED_CID_LEN: usize = NONCE_LEN + SIGNATURE_LEN; #[cfg(test)] mod tests { From b897e057f7b929e3c30b0489979c4cce4a184275 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Wed, 22 May 2024 16:42:18 -0700 Subject: [PATCH 2/9] Use CID generator as parser in endpoint --- quinn-proto/src/cid_generator.rs | 4 +++- quinn-proto/src/endpoint.rs | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/quinn-proto/src/cid_generator.rs b/quinn-proto/src/cid_generator.rs index 494ac1f09..1cb489de9 100644 --- a/quinn-proto/src/cid_generator.rs +++ b/quinn-proto/src/cid_generator.rs @@ -19,7 +19,9 @@ pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser { /// Quickly determine whether `cid` could have been generated by this generator /// - /// False positives are permitted, but increase the cost of handling invalid packets. + /// False positives are permitted, but increase the cost of handling invalid packets. The input + /// CID is guaranteed to have been obtained from a successful call to the generator's + /// implementation of [`ConnectionIdParser::parse`]. fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> { Ok(()) } diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index e4df5787d..6901e9490 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -23,8 +23,8 @@ use crate::{ crypto::{self, Keys, UnsupportedVersion}, frame, packet::{ - FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, Packet, - PacketDecodeError, PacketNumber, PartialDecode, ProtectedInitialHeader, + Header, InitialHeader, InitialPacket, Packet, PacketDecodeError, PacketNumber, + PartialDecode, ProtectedInitialHeader, }, shared::{ ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint, @@ -144,7 +144,7 @@ impl Endpoint { let datagram_len = data.len(); let (first_decode, remaining) = match PartialDecode::new( data, - &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()), + &*self.local_cid_generator, &self.config.supported_versions, self.config.grease_quic_bit, ) { From b8f425ef28b2554781cefd98b70c41cbbb98dc3d Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Wed, 22 May 2024 16:48:47 -0700 Subject: [PATCH 3/9] Allow generating CIDs from a shared CID generator Allows generators to be shared with connections, and removes an obstacle to parallelizing endpoint work in the future. --- quinn-proto/src/cid_generator.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/quinn-proto/src/cid_generator.rs b/quinn-proto/src/cid_generator.rs index 1cb489de9..fc465fc1e 100644 --- a/quinn-proto/src/cid_generator.rs +++ b/quinn-proto/src/cid_generator.rs @@ -15,7 +15,7 @@ pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser { /// issuer) to correlate them with other connection IDs for the same /// connection. They MUST have high entropy, e.g. due to encrypted data /// or cryptographic-grade random data. - fn generate_cid(&mut self) -> ConnectionId; + fn generate_cid(&self) -> ConnectionId; /// Quickly determine whether `cid` could have been generated by this generator /// @@ -85,7 +85,7 @@ impl ConnectionIdParser for RandomConnectionIdGenerator { } impl ConnectionIdGenerator for RandomConnectionIdGenerator { - fn generate_cid(&mut self) -> ConnectionId { + fn generate_cid(&self) -> ConnectionId { let mut bytes_arr = [0; MAX_CID_SIZE]; rand::thread_rng().fill_bytes(&mut bytes_arr[..self.cid_len]); @@ -151,7 +151,7 @@ impl ConnectionIdParser for HashedConnectionIdGenerator { } impl ConnectionIdGenerator for HashedConnectionIdGenerator { - fn generate_cid(&mut self) -> ConnectionId { + fn generate_cid(&self) -> ConnectionId { let mut bytes_arr = [0; HASHED_CID_LEN]; rand::thread_rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]); let mut hasher = rustc_hash::FxHasher::default(); @@ -193,7 +193,7 @@ mod tests { #[test] #[cfg(feature = "ring")] fn validate_keyed_cid() { - let mut generator = HashedConnectionIdGenerator::new(); + let generator = HashedConnectionIdGenerator::new(); let cid = generator.generate_cid(); generator.validate(&cid).unwrap(); } From 97e14bfba169da4b6292204eaa57517af902ba43 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Wed, 22 May 2024 16:55:49 -0700 Subject: [PATCH 4/9] Use CID generator as parser in connections --- quinn-proto/src/config.rs | 6 +++--- quinn-proto/src/connection/mod.rs | 10 ++++++---- quinn-proto/src/endpoint.rs | 4 ++-- quinn-proto/src/tests/mod.rs | 18 +++++++++--------- quinn/src/tests.rs | 2 +- 5 files changed, 21 insertions(+), 19 deletions(-) diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index 6d53fd590..7238388c0 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -620,7 +620,7 @@ pub struct EndpointConfig { /// /// Create a cid generator for local cid in Endpoint struct pub(crate) connection_id_generator_factory: - Arc Box + Send + Sync>, + Arc Arc + Send + Sync>, pub(crate) supported_versions: Vec, pub(crate) grease_quic_bit: bool, /// Minimum interval between outgoing stateless reset packets @@ -631,7 +631,7 @@ impl EndpointConfig { /// Create a default config with a particular `reset_key` pub fn new(reset_key: Arc) -> Self { let cid_factory = - || -> Box { Box::::default() }; + || -> Arc { Arc::::default() }; Self { reset_key, max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers @@ -650,7 +650,7 @@ impl EndpointConfig { /// information in local connection IDs, e.g. to support stateless packet-level load balancers. /// /// Defaults to [`HashedConnectionIdGenerator`]. - pub fn cid_generator Box + Send + Sync + 'static>( + pub fn cid_generator Arc + Send + Sync + 'static>( &mut self, factory: F, ) -> &mut Self { diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index a0ed44574..a5d498c57 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -23,8 +23,8 @@ use crate::{ frame, frame::{Close, Datagram, FrameStruct}, packet::{ - FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, LongType, Packet, - PacketNumber, PartialDecode, SpaceId, + Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, + SpaceId, }, range_set::ArrayRangeSet, shared::{ @@ -197,6 +197,7 @@ pub struct Connection { retry_token: Bytes, /// Identifies Data-space packet numbers to skip. Not used in earlier spaces. packet_number_filter: PacketNumberFilter, + cid_gen: Arc, // // Queued non-retransmittable 1-RTT data @@ -252,7 +253,7 @@ impl Connection { remote: SocketAddr, local_ip: Option, crypto: Box, - cid_gen: &dyn ConnectionIdGenerator, + cid_gen: Arc, now: Instant, version: u32, allow_mtud: bool, @@ -329,6 +330,7 @@ impl Connection { }, #[cfg(not(test))] packet_number_filter: PacketNumberFilter::new(&mut rng), + cid_gen, path_responses: PathResponses::default(), close: false, @@ -2101,7 +2103,7 @@ impl Connection { while let Some(data) = remaining { match PartialDecode::new( data, - &FixedLengthConnectionIdParser::new(self.local_cid_state.cid_len()), + &*self.cid_gen, &[self.version], self.endpoint_config.grease_quic_bit, ) { diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 6901e9490..1b9892f7a 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -44,7 +44,7 @@ pub struct Endpoint { rng: StdRng, index: ConnectionIndex, connections: Slab, - local_cid_generator: Box, + local_cid_generator: Arc, config: Arc, server_config: Option>, /// Whether the underlying UDP socket promises not to fragment packets @@ -833,7 +833,7 @@ impl Endpoint { addresses.remote, addresses.local_ip, tls, - self.local_cid_generator.as_ref(), + self.local_cid_generator.clone(), now, version, self.allow_mtud, diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 272bf3854..1e9c30008 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -66,8 +66,8 @@ fn version_negotiate_client() { let server_addr = "[::2]:7890".parse().unwrap(); // Configure client to use empty CIDs so we can easily hardcode a server version negotiation // packet - let cid_generator_factory: fn() -> Box = - || Box::new(RandomConnectionIdGenerator::new(0)); + let cid_generator_factory: fn() -> Arc = + || Arc::new(RandomConnectionIdGenerator::new(0)); let mut client = Endpoint::new( Arc::new(EndpointConfig { connection_id_generator_factory: Arc::new(cid_generator_factory), @@ -185,7 +185,7 @@ fn server_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(move || Box::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(move || Arc::new(HashedConnectionIdGenerator::from_key(0))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -215,7 +215,7 @@ fn client_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(move || Box::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(move || Arc::new(HashedConnectionIdGenerator::from_key(0))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -244,7 +244,7 @@ fn stateless_reset_limit() { let _guard = subscribe(); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 42); let mut endpoint_config = EndpointConfig::default(); - endpoint_config.cid_generator(move || Box::new(RandomConnectionIdGenerator::new(8))); + endpoint_config.cid_generator(move || Arc::new(RandomConnectionIdGenerator::new(8))); let endpoint_config = Arc::new(endpoint_config); let mut endpoint = Endpoint::new( endpoint_config.clone(), @@ -1470,8 +1470,8 @@ fn implicit_open() { #[test] fn zero_length_cid() { let _guard = subscribe(); - let cid_generator_factory: fn() -> Box = - || Box::new(RandomConnectionIdGenerator::new(0)); + let cid_generator_factory: fn() -> Arc = + || Arc::new(RandomConnectionIdGenerator::new(0)); let mut pair = Pair::new( Arc::new(EndpointConfig { connection_id_generator_factory: Arc::new(cid_generator_factory), @@ -1528,8 +1528,8 @@ fn cid_rotation() { let _guard = subscribe(); const CID_TIMEOUT: Duration = Duration::from_secs(2); - let cid_generator_factory: fn() -> Box = - || Box::new(*RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT)); + let cid_generator_factory: fn() -> Arc = + || Arc::new(*RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT)); // Only test cid rotation on server side to have a clear output trace let server = Endpoint::new( diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index c9c7f768f..d7cb09025 100755 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -819,7 +819,7 @@ async fn multiple_conns_with_zero_length_cids() { let mut factory = EndpointFactory::new(); factory .endpoint_config - .cid_generator(|| Box::new(RandomConnectionIdGenerator::new(0))); + .cid_generator(|| Arc::new(RandomConnectionIdGenerator::new(0))); let server = { let _guard = error_span!("server").entered(); factory.endpoint() From ea6480166e9c42ec5ead095bdbb350a01568475d Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sat, 25 May 2024 10:33:24 -0700 Subject: [PATCH 5/9] Configure connection ID generators directly Now that CID generators can have shared ownership, there's no need to indirect through a factory function in the config. --- quinn-proto/src/config.rs | 17 ++++------------- quinn-proto/src/endpoint.rs | 2 +- quinn-proto/src/tests/mod.rs | 27 ++++++++++----------------- quinn/src/tests.rs | 2 +- 4 files changed, 16 insertions(+), 32 deletions(-) diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index 7238388c0..36a281b5f 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -616,11 +616,7 @@ impl Default for MtuDiscoveryConfig { pub struct EndpointConfig { pub(crate) reset_key: Arc, pub(crate) max_udp_payload_size: VarInt, - /// CID generator factory - /// - /// Create a cid generator for local cid in Endpoint struct - pub(crate) connection_id_generator_factory: - Arc Arc + Send + Sync>, + pub(crate) connection_id_generator: Arc, pub(crate) supported_versions: Vec, pub(crate) grease_quic_bit: bool, /// Minimum interval between outgoing stateless reset packets @@ -630,12 +626,10 @@ pub struct EndpointConfig { impl EndpointConfig { /// Create a default config with a particular `reset_key` pub fn new(reset_key: Arc) -> Self { - let cid_factory = - || -> Arc { Arc::::default() }; Self { reset_key, max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers - connection_id_generator_factory: Arc::new(cid_factory), + connection_id_generator: Arc::::default(), supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(), grease_quic_bit: true, min_reset_interval: Duration::from_millis(20), @@ -650,11 +644,8 @@ impl EndpointConfig { /// information in local connection IDs, e.g. to support stateless packet-level load balancers. /// /// Defaults to [`HashedConnectionIdGenerator`]. - pub fn cid_generator Arc + Send + Sync + 'static>( - &mut self, - factory: F, - ) -> &mut Self { - self.connection_id_generator_factory = Arc::new(factory); + pub fn cid_generator(&mut self, generator: Arc) -> &mut Self { + self.connection_id_generator = generator; self } diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 1b9892f7a..6e05f91a5 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -72,7 +72,7 @@ impl Endpoint { rng: rng_seed.map_or(StdRng::from_entropy(), StdRng::from_seed), index: ConnectionIndex::default(), connections: Slab::new(), - local_cid_generator: (config.connection_id_generator_factory.as_ref())(), + local_cid_generator: config.connection_id_generator.clone(), config, server_config, allow_mtud, diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 1e9c30008..da117e730 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -20,10 +20,8 @@ use tracing::info; use super::*; use crate::{ - cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, - crypto::rustls::QuicServerConfig, - frame::FrameStruct, - transport_parameters::TransportParameters, + cid_generator::RandomConnectionIdGenerator, crypto::rustls::QuicServerConfig, + frame::FrameStruct, transport_parameters::TransportParameters, }; mod util; use util::*; @@ -66,11 +64,9 @@ fn version_negotiate_client() { let server_addr = "[::2]:7890".parse().unwrap(); // Configure client to use empty CIDs so we can easily hardcode a server version negotiation // packet - let cid_generator_factory: fn() -> Arc = - || Arc::new(RandomConnectionIdGenerator::new(0)); let mut client = Endpoint::new( Arc::new(EndpointConfig { - connection_id_generator_factory: Arc::new(cid_generator_factory), + connection_id_generator: Arc::new(RandomConnectionIdGenerator::new(0)), ..Default::default() }), None, @@ -185,7 +181,7 @@ fn server_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(move || Arc::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Arc::new(HashedConnectionIdGenerator::from_key(0))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -215,7 +211,7 @@ fn client_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(move || Arc::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Arc::new(HashedConnectionIdGenerator::from_key(0))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -244,7 +240,7 @@ fn stateless_reset_limit() { let _guard = subscribe(); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 42); let mut endpoint_config = EndpointConfig::default(); - endpoint_config.cid_generator(move || Arc::new(RandomConnectionIdGenerator::new(8))); + endpoint_config.cid_generator(Arc::new(RandomConnectionIdGenerator::new(8))); let endpoint_config = Arc::new(endpoint_config); let mut endpoint = Endpoint::new( endpoint_config.clone(), @@ -1470,11 +1466,9 @@ fn implicit_open() { #[test] fn zero_length_cid() { let _guard = subscribe(); - let cid_generator_factory: fn() -> Arc = - || Arc::new(RandomConnectionIdGenerator::new(0)); let mut pair = Pair::new( Arc::new(EndpointConfig { - connection_id_generator_factory: Arc::new(cid_generator_factory), + connection_id_generator: Arc::new(RandomConnectionIdGenerator::new(0)), ..EndpointConfig::default() }), server_config(), @@ -1528,13 +1522,12 @@ fn cid_rotation() { let _guard = subscribe(); const CID_TIMEOUT: Duration = Duration::from_secs(2); - let cid_generator_factory: fn() -> Arc = - || Arc::new(*RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT)); - // Only test cid rotation on server side to have a clear output trace let server = Endpoint::new( Arc::new(EndpointConfig { - connection_id_generator_factory: Arc::new(cid_generator_factory), + connection_id_generator: Arc::new( + *RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT), + ), ..EndpointConfig::default() }), Some(Arc::new(server_config())), diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index d7cb09025..f1fd1b49f 100755 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -819,7 +819,7 @@ async fn multiple_conns_with_zero_length_cids() { let mut factory = EndpointFactory::new(); factory .endpoint_config - .cid_generator(|| Arc::new(RandomConnectionIdGenerator::new(0))); + .cid_generator(Arc::new(RandomConnectionIdGenerator::new(0))); let server = { let _guard = error_span!("server").entered(); factory.endpoint() From 943a27119e35f8d3d9ccdde0410335fb88deb62e Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Wed, 22 May 2024 16:57:31 -0700 Subject: [PATCH 6/9] Remove superfluous freestanding connection ID parser --- fuzz/fuzz_targets/packet.rs | 4 ++-- quinn-proto/src/lib.rs | 4 ++-- quinn-proto/src/packet.rs | 24 ++---------------------- 3 files changed, 6 insertions(+), 26 deletions(-) diff --git a/fuzz/fuzz_targets/packet.rs b/fuzz/fuzz_targets/packet.rs index a8320a87a..c4879d1d2 100644 --- a/fuzz/fuzz_targets/packet.rs +++ b/fuzz/fuzz_targets/packet.rs @@ -5,7 +5,7 @@ extern crate proto; use libfuzzer_sys::fuzz_target; use proto::{ fuzzing::{PacketParams, PartialDecode}, - FixedLengthConnectionIdParser, DEFAULT_SUPPORTED_VERSIONS, + RandomConnectionIdGenerator, DEFAULT_SUPPORTED_VERSIONS, }; fuzz_target!(|data: PacketParams| { @@ -13,7 +13,7 @@ fuzz_target!(|data: PacketParams| { let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); if let Ok(decoded) = PartialDecode::new( data.buf, - &FixedLengthConnectionIdParser::new(data.local_cid_len), + &RandomConnectionIdGenerator::new(data.local_cid_len), &supported_versions, data.grease_quic_bit, ) { diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index bb39101b7..5a84bdcf5 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -66,8 +66,8 @@ pub use crate::endpoint::{ mod packet; pub use packet::{ - ConnectionIdParser, FixedLengthConnectionIdParser, LongType, PacketDecodeError, PartialDecode, - ProtectedHeader, ProtectedInitialHeader, + ConnectionIdParser, LongType, PacketDecodeError, PartialDecode, ProtectedHeader, + ProtectedInitialHeader, }; mod shared; diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index be132a74f..37ed0b06f 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -767,26 +767,6 @@ impl PacketNumber { } } -/// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length -pub struct FixedLengthConnectionIdParser { - expected_len: usize, -} - -impl FixedLengthConnectionIdParser { - /// Create a new instance of `FixedLengthConnectionIdParser` - pub fn new(expected_len: usize) -> Self { - Self { expected_len } - } -} - -impl ConnectionIdParser for FixedLengthConnectionIdParser { - fn parse(&self, buffer: &mut dyn Buf) -> Result { - (buffer.remaining() >= self.expected_len) - .then(|| ConnectionId::from_buf(buffer, self.expected_len)) - .ok_or(PacketDecodeError::InvalidHeader("packet too small")) - } -} - /// Parse connection id in short header packet pub trait ConnectionIdParser { /// Parse a connection id from given buffer @@ -928,7 +908,7 @@ mod tests { #[test] fn header_encoding() { use crate::crypto::rustls::{initial_keys, initial_suite_from_provider}; - use crate::Side; + use crate::{RandomConnectionIdGenerator, Side}; use rustls::crypto::ring::default_provider; use rustls::quic::Version; @@ -970,7 +950,7 @@ mod tests { let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); let decode = PartialDecode::new( buf.as_slice().into(), - &FixedLengthConnectionIdParser::new(0), + &RandomConnectionIdGenerator::new(0), &supported_versions, false, ) From b34657fa192e27124023809f4a5b8c00af556c3d Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sat, 25 May 2024 11:02:24 -0700 Subject: [PATCH 7/9] Remove CID exhaustion check We can't track this if the size of the CID space isn't readily known, and we already implicitly rely on `new_cid` being infallible. --- quinn-proto/src/endpoint.rs | 38 +------------------------------------ 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 6e05f91a5..cdd5888fa 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -385,9 +385,6 @@ impl Endpoint { remote: SocketAddr, server_name: &str, ) -> Result<(ConnectionHandle, Connection), ConnectError> { - if self.cids_exhausted() { - return Err(ConnectError::CidsExhausted); - } if remote.port() == 0 || remote.ip().is_unspecified() { return Err(ConnectError::InvalidRemoteAddress(remote)); } @@ -565,22 +562,6 @@ impl Endpoint { .. } = incoming.packet.header; - if self.cids_exhausted() { - debug!("refusing connection"); - self.index.remove_initial(incoming.orig_dst_cid); - return Err(AcceptError { - cause: ConnectionError::CidsExhausted, - response: Some(self.initial_close( - version, - incoming.addresses, - &incoming.crypto, - &src_cid, - TransportError::CONNECTION_REFUSED(""), - buf, - )), - }); - } - let server_config = server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone()); @@ -691,7 +672,7 @@ impl Endpoint { header: &ProtectedInitialHeader, ) -> Result<(), TransportError> { let config = &self.server_config.as_ref().unwrap(); - if self.cids_exhausted() || self.incoming_buffers.len() >= config.max_incoming { + if self.incoming_buffers.len() >= config.max_incoming { return Err(TransportError::CONNECTION_REFUSED("")); } @@ -930,18 +911,6 @@ impl Endpoint { pub(crate) fn known_cids(&self) -> usize { self.index.connection_ids.len() } - - /// Whether we've used up 3/4 of the available CID space - /// - /// We leave some space unused so that `new_cid` can be relied upon to finish quickly. We don't - /// bother to check when CID longer than 4 bytes are used because 2^40 connections is a lot. - fn cids_exhausted(&self) -> bool { - self.local_cid_generator.cid_len() <= 4 - && self.local_cid_generator.cid_len() != 0 - && (2usize.pow(self.local_cid_generator.cid_len() as u32 * 8) - - self.index.connection_ids.len()) - < 2usize.pow(self.local_cid_generator.cid_len() as u32 * 8 - 2) - } } impl fmt::Debug for Endpoint { @@ -1229,11 +1198,6 @@ pub enum ConnectError { /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled. #[error("endpoint stopping")] EndpointStopping, - /// The connection could not be created because not enough of the CID space is available - /// - /// Try using longer connection IDs - #[error("CIDs exhausted")] - CidsExhausted, /// The given server name was malformed #[error("invalid server name: {0}")] InvalidServerName(String), From 61dbea64da386ec24138a3b53fb362cfe0385ff0 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sat, 25 May 2024 11:16:08 -0700 Subject: [PATCH 8/9] Only populate local_cid_state when CIDs are in use --- quinn-proto/src/connection/cid_state.rs | 20 +----- quinn-proto/src/connection/mod.rs | 90 ++++++++++++++----------- 2 files changed, 50 insertions(+), 60 deletions(-) diff --git a/quinn-proto/src/connection/cid_state.rs b/quinn-proto/src/connection/cid_state.rs index abf577ae7..46267ddd7 100644 --- a/quinn-proto/src/connection/cid_state.rs +++ b/quinn-proto/src/connection/cid_state.rs @@ -21,19 +21,12 @@ pub(super) struct CidState { prev_retire_seq: u64, /// Sequence number to set in retire_prior_to field in NEW_CONNECTION_ID frame retire_seq: u64, - /// cid length used to decode short packet - cid_len: usize, //// cid lifetime cid_lifetime: Option, } impl CidState { - pub(crate) fn new( - cid_len: usize, - cid_lifetime: Option, - now: Instant, - issued: u64, - ) -> Self { + pub(crate) fn new(cid_lifetime: Option, now: Instant, issued: u64) -> Self { let mut active_seq = FxHashSet::default(); // Add sequence number of CIDs used in handshaking into tracking set for seq in 0..issued { @@ -45,7 +38,6 @@ impl CidState { active_seq, prev_retire_seq: 0, retire_seq: 0, - cid_len, cid_lifetime, }; // Track lifetime of CIDs used in handshaking @@ -158,11 +150,6 @@ impl CidState { sequence: u64, limit: u64, ) -> Result { - if self.cid_len == 0 { - return Err(TransportError::PROTOCOL_VIOLATION( - "RETIRE_CONNECTION_ID when CIDs aren't in use", - )); - } if sequence > self.issued { debug!( sequence, @@ -181,11 +168,6 @@ impl CidState { Ok(limit > self.active_seq.len() as u64) } - /// Length of local Connection IDs - pub(crate) fn cid_len(&self) -> usize { - self.cid_len - } - /// The value for `retire_prior_to` field in `NEW_CONNECTION_ID` frame pub(crate) fn retire_prior_to(&self) -> u64 { self.retire_seq diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index a5d498c57..aaf8bb5d3 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -231,8 +231,8 @@ pub struct Connection { streams: StreamsState, /// Surplus remote CIDs for future use on new paths rem_cids: CidQueue, - // Attributes of CIDs generated by local peer - local_cid_state: CidState, + /// Attributes of CIDs generated by local peer, if in use + local_cid_state: Option, /// State of the unreliable datagram extension datagrams: DatagramState, /// Connection level statistics @@ -281,12 +281,14 @@ impl Connection { crypto, handshake_cid: loc_cid, rem_handshake_cid: rem_cid, - local_cid_state: CidState::new( - cid_gen.cid_len(), - cid_gen.cid_lifetime(), - now, - if pref_addr_cid.is_some() { 2 } else { 1 }, - ), + local_cid_state: match cid_gen.cid_len() { + 0 => None, + _ => Some(CidState::new( + cid_gen.cid_lifetime(), + now, + if pref_addr_cid.is_some() { 2 } else { 1 }, + )), + }, path: PathData::new(remote, allow_mtud, None, now, path_validated, &config), allow_mtud, local_ip, @@ -1088,7 +1090,8 @@ impl Connection { } } NewIdentifiers(ids, now) => { - self.local_cid_state.new_cids(&ids, now); + let cid_state = self.local_cid_state.as_mut().unwrap(); + cid_state.new_cids(&ids, now); ids.into_iter().rev().for_each(|frame| { self.spaces[SpaceId::Data].pending.new_cids.push(frame); }); @@ -1098,7 +1101,9 @@ impl Connection { .get(Timer::PushNewCid) .map_or(true, |x| x <= now) { - self.reset_cid_retirement(); + if let Some(t) = cid_state.next_timeout() { + self.timers.set(Timer::PushNewCid, t); + } } } } @@ -1149,12 +1154,13 @@ impl Connection { } Timer::Pacing => trace!("pacing timer expired"), Timer::PushNewCid => { + let cid_state = self.local_cid_state.as_mut().unwrap(); // Update `retire_prior_to` field in NEW_CONNECTION_ID frame - let num_new_cid = self.local_cid_state.on_cid_timeout().into(); + let num_new_cid = cid_state.on_cid_timeout().into(); if !self.state.is_closed() { trace!( "push a new cid to peer RETIRE_PRIOR_TO field {}", - self.local_cid_state.retire_prior_to() + cid_state.retire_prior_to() ); self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, num_new_cid)); @@ -1860,12 +1866,6 @@ impl Connection { self.timers.set(Timer::KeepAlive, now + interval); } - fn reset_cid_retirement(&mut self) { - if let Some(t) = self.local_cid_state.next_timeout() { - self.timers.set(Timer::PushNewCid, t); - } - } - /// Handle the already-decrypted first packet from the client /// /// Decrypting the first packet in the `Endpoint` allows stateless packet handling to be more @@ -2756,8 +2756,12 @@ impl Connection { self.streams.received_stop_sending(id, error_code); } Frame::RetireConnectionId { sequence } => { - let allow_more_cids = self - .local_cid_state + let cid_state = self.local_cid_state.as_mut().ok_or_else(|| { + TransportError::PROTOCOL_VIOLATION( + "RETIRE_CONNECTION_ID when CIDs aren't in use", + ) + })?; + let allow_more_cids = cid_state .on_cid_retirement(sequence, self.peer_params.issue_cids_limit())?; self.endpoint_events .push_back(EndpointEventInner::RetireConnectionId( @@ -2999,7 +3003,7 @@ impl Connection { /// Issue an initial set of connection IDs to the peer upon connection fn issue_first_cids(&mut self, now: Instant) { - if self.local_cid_state.cid_len() == 0 { + if self.local_cid_state.is_none() { return; } @@ -3169,25 +3173,27 @@ impl Connection { } // NEW_CONNECTION_ID - while buf.len() + 44 < max_size { - let issued = match space.pending.new_cids.pop() { - Some(x) => x, - None => break, - }; - trace!( - sequence = issued.sequence, - id = %issued.id, - "NEW_CONNECTION_ID" - ); - frame::NewConnectionId { - sequence: issued.sequence, - retire_prior_to: self.local_cid_state.retire_prior_to(), - id: issued.id, - reset_token: issued.reset_token, + if let Some(cid_state) = self.local_cid_state.as_ref() { + while buf.len() + 44 < max_size { + let issued = match space.pending.new_cids.pop() { + Some(x) => x, + None => break, + }; + trace!( + sequence = issued.sequence, + id = %issued.id, + "NEW_CONNECTION_ID" + ); + frame::NewConnectionId { + sequence: issued.sequence, + retire_prior_to: cid_state.retire_prior_to(), + id: issued.id, + reset_token: issued.reset_token, + } + .encode(buf); + sent.retransmits.get_or_create().new_cids.push(issued); + self.stats.frame_tx.new_connection_id += 1; } - .encode(buf); - sent.retransmits.get_or_create().new_cids.push(issued); - self.stats.frame_tx.new_connection_id += 1; } // RETIRE_CONNECTION_ID @@ -3481,14 +3487,16 @@ impl Connection { #[cfg(test)] pub(crate) fn active_local_cid_seq(&self) -> (u64, u64) { - self.local_cid_state.active_seq() + self.local_cid_state + .as_ref() + .map_or((u64::MAX, u64::MIN), |state| state.active_seq()) } /// Instruct the peer to replace previously issued CIDs by sending a NEW_CONNECTION_ID frame /// with updated `retire_prior_to` field set to `v` #[cfg(test)] pub(crate) fn rotate_local_cid(&mut self, v: u64, now: Instant) { - let n = self.local_cid_state.assign_retire_seq(v); + let n = self.local_cid_state.as_mut().unwrap().assign_retire_seq(v); self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, n)); } From 615a1edd76786d1146564ba6a763bd6507a144ad Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sat, 25 May 2024 11:23:58 -0700 Subject: [PATCH 9/9] Represent zero-length CIDs by specifying no CID generator --- fuzz/fuzz_targets/packet.rs | 12 +++++-- quinn-proto/src/cid_generator.rs | 40 ++++++++++++++++------- quinn-proto/src/config.rs | 9 ++++-- quinn-proto/src/connection/mod.rs | 25 +++++++------- quinn-proto/src/endpoint.rs | 43 ++++++++++++++----------- quinn-proto/src/lib.rs | 2 +- quinn-proto/src/packet.rs | 14 ++++++-- quinn-proto/src/shared.rs | 6 ++++ quinn-proto/src/tests/mod.rs | 14 ++++---- quinn-proto/src/transport_parameters.rs | 5 ++- quinn/src/tests.rs | 6 ++-- 11 files changed, 113 insertions(+), 63 deletions(-) diff --git a/fuzz/fuzz_targets/packet.rs b/fuzz/fuzz_targets/packet.rs index c4879d1d2..ce6890574 100644 --- a/fuzz/fuzz_targets/packet.rs +++ b/fuzz/fuzz_targets/packet.rs @@ -5,15 +5,23 @@ extern crate proto; use libfuzzer_sys::fuzz_target; use proto::{ fuzzing::{PacketParams, PartialDecode}, - RandomConnectionIdGenerator, DEFAULT_SUPPORTED_VERSIONS, + ConnectionIdParser, RandomConnectionIdGenerator, ZeroLengthConnectionIdParser, + DEFAULT_SUPPORTED_VERSIONS, }; fuzz_target!(|data: PacketParams| { let len = data.buf.len(); let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); + let cid_gen; if let Ok(decoded) = PartialDecode::new( data.buf, - &RandomConnectionIdGenerator::new(data.local_cid_len), + match data.local_cid_len { + 0 => &ZeroLengthConnectionIdParser as &dyn ConnectionIdParser, + _ => { + cid_gen = RandomConnectionIdGenerator::new(data.local_cid_len); + &cid_gen as &dyn ConnectionIdParser + } + }, &supported_versions, data.grease_quic_bit, ) { diff --git a/quinn-proto/src/cid_generator.rs b/quinn-proto/src/cid_generator.rs index fc465fc1e..73fe14d40 100644 --- a/quinn-proto/src/cid_generator.rs +++ b/quinn-proto/src/cid_generator.rs @@ -26,8 +26,6 @@ pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser { Ok(()) } - /// Returns the length of a CID for connections created by this generator - fn cid_len(&self) -> usize; /// Returns the lifetime of generated Connection IDs /// /// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant. @@ -63,6 +61,10 @@ impl RandomConnectionIdGenerator { /// The given length must be less than or equal to MAX_CID_SIZE. pub fn new(cid_len: usize) -> Self { debug_assert!(cid_len <= MAX_CID_SIZE); + assert!( + cid_len > 0, + "connection ID generators must produce non-empty IDs" + ); Self { cid_len, ..Self::default() @@ -92,11 +94,6 @@ impl ConnectionIdGenerator for RandomConnectionIdGenerator { ConnectionId::new(&bytes_arr[..self.cid_len]) } - /// Provide the length of dst_cid in short header packet - fn cid_len(&self) -> usize { - self.cid_len - } - fn cid_lifetime(&self) -> Option { self.lifetime } @@ -173,10 +170,6 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator { } } - fn cid_len(&self) -> usize { - HASHED_CID_LEN - } - fn cid_lifetime(&self) -> Option { self.lifetime } @@ -186,6 +179,31 @@ const NONCE_LEN: usize = 3; // Good for more than 16 million connections const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length const HASHED_CID_LEN: usize = NONCE_LEN + SIGNATURE_LEN; +/// HACK: Replace uses with `ZeroLengthConnectionIdParser` once [trait upcasting] is stable +/// +/// CID generators should produce nonempty CIDs. We should be able to use +/// `ZeroLengthConnectionIdParser` everywhere this would be needed, but that will require +/// construction of `&dyn ConnectionIdParser` from `&dyn ConnectionIdGenerator`. +/// +/// [trait upcasting]: https://github.com/rust-lang/rust/issues/65991 +pub(crate) struct ZeroLengthConnectionIdGenerator; + +impl ConnectionIdParser for ZeroLengthConnectionIdGenerator { + fn parse(&self, _: &mut dyn Buf) -> Result { + Ok(ConnectionId::new(&[])) + } +} + +impl ConnectionIdGenerator for ZeroLengthConnectionIdGenerator { + fn generate_cid(&self) -> ConnectionId { + unreachable!() + } + + fn cid_lifetime(&self) -> Option { + None + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index 36a281b5f..650d6e2cf 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -616,7 +616,7 @@ impl Default for MtuDiscoveryConfig { pub struct EndpointConfig { pub(crate) reset_key: Arc, pub(crate) max_udp_payload_size: VarInt, - pub(crate) connection_id_generator: Arc, + pub(crate) connection_id_generator: Option>, pub(crate) supported_versions: Vec, pub(crate) grease_quic_bit: bool, /// Minimum interval between outgoing stateless reset packets @@ -629,7 +629,7 @@ impl EndpointConfig { Self { reset_key, max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers - connection_id_generator: Arc::::default(), + connection_id_generator: Some(Arc::::default()), supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(), grease_quic_bit: true, min_reset_interval: Duration::from_millis(20), @@ -644,7 +644,10 @@ impl EndpointConfig { /// information in local connection IDs, e.g. to support stateless packet-level load balancers. /// /// Defaults to [`HashedConnectionIdGenerator`]. - pub fn cid_generator(&mut self, generator: Arc) -> &mut Self { + pub fn cid_generator( + &mut self, + generator: Option>, + ) -> &mut Self { self.connection_id_generator = generator; self } diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index aaf8bb5d3..69ebbba58 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -15,13 +15,12 @@ use thiserror::Error; use tracing::{debug, error, trace, trace_span, warn}; use crate::{ - cid_generator::ConnectionIdGenerator, + cid_generator::{ConnectionIdGenerator, ZeroLengthConnectionIdGenerator}, cid_queue::CidQueue, coding::BufMutExt, config::{ServerConfig, TransportConfig}, crypto::{self, KeyPair, Keys, PacketKey}, - frame, - frame::{Close, Datagram, FrameStruct}, + frame::{self, Close, Datagram, FrameStruct}, packet::{ Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, SpaceId, @@ -197,7 +196,7 @@ pub struct Connection { retry_token: Bytes, /// Identifies Data-space packet numbers to skip. Not used in earlier spaces. packet_number_filter: PacketNumberFilter, - cid_gen: Arc, + cid_gen: Option>, // // Queued non-retransmittable 1-RTT data @@ -253,7 +252,7 @@ impl Connection { remote: SocketAddr, local_ip: Option, crypto: Box, - cid_gen: Arc, + cid_gen: Option>, now: Instant, version: u32, allow_mtud: bool, @@ -281,14 +280,13 @@ impl Connection { crypto, handshake_cid: loc_cid, rem_handshake_cid: rem_cid, - local_cid_state: match cid_gen.cid_len() { - 0 => None, - _ => Some(CidState::new( - cid_gen.cid_lifetime(), + local_cid_state: cid_gen.as_ref().map(|gen| { + CidState::new( + gen.cid_lifetime(), now, if pref_addr_cid.is_some() { 2 } else { 1 }, - )), - }, + ) + }), path: PathData::new(remote, allow_mtud, None, now, path_validated, &config), allow_mtud, local_ip, @@ -2103,7 +2101,10 @@ impl Connection { while let Some(data) = remaining { match PartialDecode::new( data, - &*self.cid_gen, + self.cid_gen.as_ref().map_or( + &ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator, + |x| &**x, + ), &[self.version], self.endpoint_config.grease_quic_bit, ) { diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index cdd5888fa..d1e483dcf 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -16,7 +16,9 @@ use thiserror::Error; use tracing::{debug, error, trace, warn}; use crate::{ - cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, + cid_generator::{ + ConnectionIdGenerator, RandomConnectionIdGenerator, ZeroLengthConnectionIdGenerator, + }, coding::BufMutExt, config::{ClientConfig, EndpointConfig, ServerConfig}, connection::{Connection, ConnectionError}, @@ -44,7 +46,7 @@ pub struct Endpoint { rng: StdRng, index: ConnectionIndex, connections: Slab, - local_cid_generator: Arc, + local_cid_generator: Option>, config: Arc, server_config: Option>, /// Whether the underlying UDP socket promises not to fragment packets @@ -144,7 +146,10 @@ impl Endpoint { let datagram_len = data.len(); let (first_decode, remaining) = match PartialDecode::new( data, - &*self.local_cid_generator, + self.local_cid_generator.as_ref().map_or( + &ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator, + |x| &**x, + ), &self.config.supported_versions, self.config.grease_quic_bit, ) { @@ -302,8 +307,8 @@ impl Endpoint { if !first_decode.is_initial() && self .local_cid_generator - .validate(first_decode.dst_cid()) - .is_err() + .as_ref() + .map_or(false, |gen| gen.validate(first_decode.dst_cid()).is_err()) { debug!("dropping packet with invalid CID"); return None; @@ -400,7 +405,7 @@ impl Endpoint { let params = TransportParameters::new( &config.transport, &self.config, - self.local_cid_generator.as_ref(), + self.local_cid_generator.is_some(), loc_cid, None, ); @@ -453,12 +458,11 @@ impl Endpoint { /// Generate a connection ID for `ch` fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId { loop { - let cid = self.local_cid_generator.generate_cid(); - if cid.len() == 0 { + let Some(cid_generator) = self.local_cid_generator.as_ref() else { // Zero-length CID; nothing to track - debug_assert_eq!(self.local_cid_generator.cid_len(), 0); - return cid; - } + return ConnectionId::EMPTY; + }; + let cid = cid_generator.generate_cid(); if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) { e.insert(ch); break cid; @@ -589,7 +593,7 @@ impl Endpoint { let mut params = TransportParameters::new( &server_config.transport, &self.config, - self.local_cid_generator.as_ref(), + self.local_cid_generator.is_some(), loc_cid, Some(&server_config), ); @@ -680,10 +684,7 @@ impl Endpoint { // bytes. If this is a Retry packet, then the length must instead match our usual CID // length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll // also need to validate CID length for those after decoding the token. - if header.dst_cid.len() < 8 - && (!header.token_pos.is_empty() - && header.dst_cid.len() != self.local_cid_generator.cid_len()) - { + if header.dst_cid.len() < 8 && !header.token_pos.is_empty() { debug!( "rejecting connection due to invalid DCID length {}", header.dst_cid.len() @@ -730,7 +731,10 @@ impl Endpoint { // with established connections. In the unlikely event that a collision occurs // between two connections in the initial phase, both will fail fast and may be // retried by the application layer. - let loc_cid = self.local_cid_generator.generate_cid(); + let loc_cid = self + .local_cid_generator + .as_ref() + .map_or(ConnectionId::EMPTY, |gen| gen.generate_cid()); let token = RetryToken { orig_dst_cid: incoming.packet.header.dst_cid, @@ -860,7 +864,10 @@ impl Endpoint { // We don't need to worry about CID collisions in initial closes because the peer // shouldn't respond, and if it does, and the CID collides, we'll just drop the // unexpected response. - let local_id = self.local_cid_generator.generate_cid(); + let local_id = self + .local_cid_generator + .as_ref() + .map_or(ConnectionId::EMPTY, |gen| gen.generate_cid()); let number = PacketNumber::U8(0); let header = Header::Initial(InitialHeader { dst_cid: *remote_id, diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 5a84bdcf5..ae5a4d2c7 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -67,7 +67,7 @@ pub use crate::endpoint::{ mod packet; pub use packet::{ ConnectionIdParser, LongType, PacketDecodeError, PartialDecode, ProtectedHeader, - ProtectedInitialHeader, + ProtectedInitialHeader, ZeroLengthConnectionIdParser, }; mod shared; diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index 37ed0b06f..328988e0e 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -773,6 +773,16 @@ pub trait ConnectionIdParser { fn parse(&self, buf: &mut dyn Buf) -> Result; } +/// Trivial parser for zero-length connection IDs +pub struct ZeroLengthConnectionIdParser; + +impl ConnectionIdParser for ZeroLengthConnectionIdParser { + #[inline] + fn parse(&self, _: &mut dyn Buf) -> Result { + Ok(ConnectionId::new(&[])) + } +} + /// Long packet type including non-uniform cases #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum LongHeaderType { @@ -908,7 +918,7 @@ mod tests { #[test] fn header_encoding() { use crate::crypto::rustls::{initial_keys, initial_suite_from_provider}; - use crate::{RandomConnectionIdGenerator, Side}; + use crate::Side; use rustls::crypto::ring::default_provider; use rustls::quic::Version; @@ -950,7 +960,7 @@ mod tests { let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); let decode = PartialDecode::new( buf.as_slice().into(), - &RandomConnectionIdGenerator::new(0), + &ZeroLengthConnectionIdParser, &supported_versions, false, ) diff --git a/quinn-proto/src/shared.rs b/quinn-proto/src/shared.rs index 05ffe8caf..f6dac3652 100644 --- a/quinn-proto/src/shared.rs +++ b/quinn-proto/src/shared.rs @@ -72,6 +72,12 @@ pub struct ConnectionId { } impl ConnectionId { + /// The zero-length connection ID + pub const EMPTY: Self = Self { + len: 0, + bytes: [0; MAX_CID_SIZE], + }; + /// Construct cid from byte array pub fn new(bytes: &[u8]) -> Self { debug_assert!(bytes.len() <= MAX_CID_SIZE); diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index da117e730..8b4cc0180 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -66,7 +66,7 @@ fn version_negotiate_client() { // packet let mut client = Endpoint::new( Arc::new(EndpointConfig { - connection_id_generator: Arc::new(RandomConnectionIdGenerator::new(0)), + connection_id_generator: None, ..Default::default() }), None, @@ -181,7 +181,7 @@ fn server_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(Arc::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Some(Arc::new(HashedConnectionIdGenerator::from_key(0)))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -211,7 +211,7 @@ fn client_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(Arc::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Some(Arc::new(HashedConnectionIdGenerator::from_key(0)))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -240,7 +240,7 @@ fn stateless_reset_limit() { let _guard = subscribe(); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 42); let mut endpoint_config = EndpointConfig::default(); - endpoint_config.cid_generator(Arc::new(RandomConnectionIdGenerator::new(8))); + endpoint_config.cid_generator(Some(Arc::new(RandomConnectionIdGenerator::new(8)))); let endpoint_config = Arc::new(endpoint_config); let mut endpoint = Endpoint::new( endpoint_config.clone(), @@ -1468,7 +1468,7 @@ fn zero_length_cid() { let _guard = subscribe(); let mut pair = Pair::new( Arc::new(EndpointConfig { - connection_id_generator: Arc::new(RandomConnectionIdGenerator::new(0)), + connection_id_generator: None, ..EndpointConfig::default() }), server_config(), @@ -1525,9 +1525,9 @@ fn cid_rotation() { // Only test cid rotation on server side to have a clear output trace let server = Endpoint::new( Arc::new(EndpointConfig { - connection_id_generator: Arc::new( + connection_id_generator: Some(Arc::new( *RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT), - ), + )), ..EndpointConfig::default() }), Some(Arc::new(server_config())), diff --git a/quinn-proto/src/transport_parameters.rs b/quinn-proto/src/transport_parameters.rs index 3a1a443a2..2dc0a916d 100644 --- a/quinn-proto/src/transport_parameters.rs +++ b/quinn-proto/src/transport_parameters.rs @@ -15,7 +15,6 @@ use bytes::{Buf, BufMut}; use thiserror::Error; use crate::{ - cid_generator::ConnectionIdGenerator, cid_queue::CidQueue, coding::{BufExt, BufMutExt, UnexpectedEnd}, config::{EndpointConfig, ServerConfig, TransportConfig}, @@ -132,7 +131,7 @@ impl TransportParameters { pub(crate) fn new( config: &TransportConfig, endpoint_config: &EndpointConfig, - cid_gen: &dyn ConnectionIdGenerator, + use_cids: bool, initial_src_cid: ConnectionId, server_config: Option<&ServerConfig>, ) -> Self { @@ -147,7 +146,7 @@ impl TransportParameters { max_udp_payload_size: endpoint_config.max_udp_payload_size, max_idle_timeout: config.max_idle_timeout.unwrap_or(VarInt(0)), disable_active_migration: server_config.map_or(false, |c| !c.migration), - active_connection_id_limit: if cid_gen.cid_len() == 0 { + active_connection_id_limit: if !use_cids { 2 // i.e. default, i.e. unsent } else { CidQueue::LEN as u32 diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index f1fd1b49f..13b84fea5 100755 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -10,7 +10,7 @@ use std::{ use crate::runtime::TokioRuntime; use bytes::Bytes; -use proto::{crypto::rustls::QuicClientConfig, RandomConnectionIdGenerator}; +use proto::crypto::rustls::QuicClientConfig; use rand::{rngs::StdRng, RngCore, SeedableRng}; use rustls::{ pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, @@ -817,9 +817,7 @@ async fn two_datagram_readers() { async fn multiple_conns_with_zero_length_cids() { let _guard = subscribe(); let mut factory = EndpointFactory::new(); - factory - .endpoint_config - .cid_generator(Arc::new(RandomConnectionIdGenerator::new(0))); + factory.endpoint_config.cid_generator(None); let server = { let _guard = error_span!("server").entered(); factory.endpoint()