Skip to content

Commit

Permalink
Represent zero-length CIDs by specifying no CID generator
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralith committed Jun 2, 2024
1 parent 61dbea6 commit 615a1ed
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 63 deletions.
12 changes: 10 additions & 2 deletions fuzz/fuzz_targets/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,23 @@ extern crate proto;
use libfuzzer_sys::fuzz_target;
use proto::{
fuzzing::{PacketParams, PartialDecode},
RandomConnectionIdGenerator, DEFAULT_SUPPORTED_VERSIONS,
ConnectionIdParser, RandomConnectionIdGenerator, ZeroLengthConnectionIdParser,
DEFAULT_SUPPORTED_VERSIONS,
};

fuzz_target!(|data: PacketParams| {
let len = data.buf.len();
let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec();
let cid_gen;
if let Ok(decoded) = PartialDecode::new(
data.buf,
&RandomConnectionIdGenerator::new(data.local_cid_len),
match data.local_cid_len {
0 => &ZeroLengthConnectionIdParser as &dyn ConnectionIdParser,
_ => {
cid_gen = RandomConnectionIdGenerator::new(data.local_cid_len);
&cid_gen as &dyn ConnectionIdParser
}
},
&supported_versions,
data.grease_quic_bit,
) {
Expand Down
40 changes: 29 additions & 11 deletions quinn-proto/src/cid_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser {
Ok(())
}

/// Returns the length of a CID for connections created by this generator
fn cid_len(&self) -> usize;
/// Returns the lifetime of generated Connection IDs
///
/// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant.
Expand Down Expand Up @@ -63,6 +61,10 @@ impl RandomConnectionIdGenerator {
/// The given length must be less than or equal to MAX_CID_SIZE.
pub fn new(cid_len: usize) -> Self {
debug_assert!(cid_len <= MAX_CID_SIZE);
assert!(
cid_len > 0,
"connection ID generators must produce non-empty IDs"
);
Self {
cid_len,
..Self::default()
Expand Down Expand Up @@ -92,11 +94,6 @@ impl ConnectionIdGenerator for RandomConnectionIdGenerator {
ConnectionId::new(&bytes_arr[..self.cid_len])
}

/// Provide the length of dst_cid in short header packet
fn cid_len(&self) -> usize {
self.cid_len
}

fn cid_lifetime(&self) -> Option<Duration> {
self.lifetime
}
Expand Down Expand Up @@ -173,10 +170,6 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator {
}
}

fn cid_len(&self) -> usize {
HASHED_CID_LEN
}

fn cid_lifetime(&self) -> Option<Duration> {
self.lifetime
}
Expand All @@ -186,6 +179,31 @@ const NONCE_LEN: usize = 3; // Good for more than 16 million connections
const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length
const HASHED_CID_LEN: usize = NONCE_LEN + SIGNATURE_LEN;

/// HACK: Replace uses with `ZeroLengthConnectionIdParser` once [trait upcasting] is stable
///
/// CID generators should produce nonempty CIDs. We should be able to use
/// `ZeroLengthConnectionIdParser` everywhere this would be needed, but that will require
/// construction of `&dyn ConnectionIdParser` from `&dyn ConnectionIdGenerator`.
///
/// [trait upcasting]: https://github.com/rust-lang/rust/issues/65991
pub(crate) struct ZeroLengthConnectionIdGenerator;

impl ConnectionIdParser for ZeroLengthConnectionIdGenerator {
fn parse(&self, _: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
Ok(ConnectionId::new(&[]))
}
}

impl ConnectionIdGenerator for ZeroLengthConnectionIdGenerator {
fn generate_cid(&self) -> ConnectionId {
unreachable!()
}

fn cid_lifetime(&self) -> Option<Duration> {
None
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
9 changes: 6 additions & 3 deletions quinn-proto/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ impl Default for MtuDiscoveryConfig {
pub struct EndpointConfig {
pub(crate) reset_key: Arc<dyn HmacKey>,
pub(crate) max_udp_payload_size: VarInt,
pub(crate) connection_id_generator: Arc<dyn ConnectionIdGenerator>,
pub(crate) connection_id_generator: Option<Arc<dyn ConnectionIdGenerator>>,
pub(crate) supported_versions: Vec<u32>,
pub(crate) grease_quic_bit: bool,
/// Minimum interval between outgoing stateless reset packets
Expand All @@ -629,7 +629,7 @@ impl EndpointConfig {
Self {
reset_key,
max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers
connection_id_generator: Arc::<HashedConnectionIdGenerator>::default(),
connection_id_generator: Some(Arc::<HashedConnectionIdGenerator>::default()),
supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(),
grease_quic_bit: true,
min_reset_interval: Duration::from_millis(20),
Expand All @@ -644,7 +644,10 @@ impl EndpointConfig {
/// information in local connection IDs, e.g. to support stateless packet-level load balancers.
///
/// Defaults to [`HashedConnectionIdGenerator`].
pub fn cid_generator(&mut self, generator: Arc<dyn ConnectionIdGenerator>) -> &mut Self {
pub fn cid_generator(
&mut self,
generator: Option<Arc<dyn ConnectionIdGenerator>>,
) -> &mut Self {
self.connection_id_generator = generator;
self
}
Expand Down
25 changes: 13 additions & 12 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ use thiserror::Error;
use tracing::{debug, error, trace, trace_span, warn};

use crate::{
cid_generator::ConnectionIdGenerator,
cid_generator::{ConnectionIdGenerator, ZeroLengthConnectionIdGenerator},
cid_queue::CidQueue,
coding::BufMutExt,
config::{ServerConfig, TransportConfig},
crypto::{self, KeyPair, Keys, PacketKey},
frame,
frame::{Close, Datagram, FrameStruct},
frame::{self, Close, Datagram, FrameStruct},
packet::{
Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode,
SpaceId,
Expand Down Expand Up @@ -197,7 +196,7 @@ pub struct Connection {
retry_token: Bytes,
/// Identifies Data-space packet numbers to skip. Not used in earlier spaces.
packet_number_filter: PacketNumberFilter,
cid_gen: Arc<dyn ConnectionIdGenerator>,
cid_gen: Option<Arc<dyn ConnectionIdGenerator>>,

//
// Queued non-retransmittable 1-RTT data
Expand Down Expand Up @@ -253,7 +252,7 @@ impl Connection {
remote: SocketAddr,
local_ip: Option<IpAddr>,
crypto: Box<dyn crypto::Session>,
cid_gen: Arc<dyn ConnectionIdGenerator>,
cid_gen: Option<Arc<dyn ConnectionIdGenerator>>,
now: Instant,
version: u32,
allow_mtud: bool,
Expand Down Expand Up @@ -281,14 +280,13 @@ impl Connection {
crypto,
handshake_cid: loc_cid,
rem_handshake_cid: rem_cid,
local_cid_state: match cid_gen.cid_len() {
0 => None,
_ => Some(CidState::new(
cid_gen.cid_lifetime(),
local_cid_state: cid_gen.as_ref().map(|gen| {
CidState::new(
gen.cid_lifetime(),
now,
if pref_addr_cid.is_some() { 2 } else { 1 },
)),
},
)
}),
path: PathData::new(remote, allow_mtud, None, now, path_validated, &config),
allow_mtud,
local_ip,
Expand Down Expand Up @@ -2103,7 +2101,10 @@ impl Connection {
while let Some(data) = remaining {
match PartialDecode::new(
data,
&*self.cid_gen,
self.cid_gen.as_ref().map_or(
&ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator,
|x| &**x,
),
&[self.version],
self.endpoint_config.grease_quic_bit,
) {
Expand Down
43 changes: 25 additions & 18 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ use thiserror::Error;
use tracing::{debug, error, trace, warn};

use crate::{
cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator},
cid_generator::{
ConnectionIdGenerator, RandomConnectionIdGenerator, ZeroLengthConnectionIdGenerator,
},
coding::BufMutExt,
config::{ClientConfig, EndpointConfig, ServerConfig},
connection::{Connection, ConnectionError},
Expand Down Expand Up @@ -44,7 +46,7 @@ pub struct Endpoint {
rng: StdRng,
index: ConnectionIndex,
connections: Slab<ConnectionMeta>,
local_cid_generator: Arc<dyn ConnectionIdGenerator>,
local_cid_generator: Option<Arc<dyn ConnectionIdGenerator>>,
config: Arc<EndpointConfig>,
server_config: Option<Arc<ServerConfig>>,
/// Whether the underlying UDP socket promises not to fragment packets
Expand Down Expand Up @@ -144,7 +146,10 @@ impl Endpoint {
let datagram_len = data.len();
let (first_decode, remaining) = match PartialDecode::new(
data,
&*self.local_cid_generator,
self.local_cid_generator.as_ref().map_or(
&ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator,
|x| &**x,
),
&self.config.supported_versions,
self.config.grease_quic_bit,
) {
Expand Down Expand Up @@ -302,8 +307,8 @@ impl Endpoint {
if !first_decode.is_initial()
&& self
.local_cid_generator
.validate(first_decode.dst_cid())
.is_err()
.as_ref()
.map_or(false, |gen| gen.validate(first_decode.dst_cid()).is_err())
{
debug!("dropping packet with invalid CID");
return None;
Expand Down Expand Up @@ -400,7 +405,7 @@ impl Endpoint {
let params = TransportParameters::new(
&config.transport,
&self.config,
self.local_cid_generator.as_ref(),
self.local_cid_generator.is_some(),
loc_cid,
None,
);
Expand Down Expand Up @@ -453,12 +458,11 @@ impl Endpoint {
/// Generate a connection ID for `ch`
fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
loop {
let cid = self.local_cid_generator.generate_cid();
if cid.len() == 0 {
let Some(cid_generator) = self.local_cid_generator.as_ref() else {
// Zero-length CID; nothing to track
debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
return cid;
}
return ConnectionId::EMPTY;
};
let cid = cid_generator.generate_cid();
if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
e.insert(ch);
break cid;
Expand Down Expand Up @@ -589,7 +593,7 @@ impl Endpoint {
let mut params = TransportParameters::new(
&server_config.transport,
&self.config,
self.local_cid_generator.as_ref(),
self.local_cid_generator.is_some(),
loc_cid,
Some(&server_config),
);
Expand Down Expand Up @@ -680,10 +684,7 @@ impl Endpoint {
// bytes. If this is a Retry packet, then the length must instead match our usual CID
// length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll
// also need to validate CID length for those after decoding the token.
if header.dst_cid.len() < 8
&& (!header.token_pos.is_empty()
&& header.dst_cid.len() != self.local_cid_generator.cid_len())
{
if header.dst_cid.len() < 8 && !header.token_pos.is_empty() {
debug!(
"rejecting connection due to invalid DCID length {}",
header.dst_cid.len()
Expand Down Expand Up @@ -730,7 +731,10 @@ impl Endpoint {
// with established connections. In the unlikely event that a collision occurs
// between two connections in the initial phase, both will fail fast and may be
// retried by the application layer.
let loc_cid = self.local_cid_generator.generate_cid();
let loc_cid = self
.local_cid_generator
.as_ref()
.map_or(ConnectionId::EMPTY, |gen| gen.generate_cid());

let token = RetryToken {
orig_dst_cid: incoming.packet.header.dst_cid,
Expand Down Expand Up @@ -860,7 +864,10 @@ impl Endpoint {
// We don't need to worry about CID collisions in initial closes because the peer
// shouldn't respond, and if it does, and the CID collides, we'll just drop the
// unexpected response.
let local_id = self.local_cid_generator.generate_cid();
let local_id = self
.local_cid_generator
.as_ref()
.map_or(ConnectionId::EMPTY, |gen| gen.generate_cid());
let number = PacketNumber::U8(0);
let header = Header::Initial(InitialHeader {
dst_cid: *remote_id,
Expand Down
2 changes: 1 addition & 1 deletion quinn-proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub use crate::endpoint::{
mod packet;
pub use packet::{
ConnectionIdParser, LongType, PacketDecodeError, PartialDecode, ProtectedHeader,
ProtectedInitialHeader,
ProtectedInitialHeader, ZeroLengthConnectionIdParser,
};

mod shared;
Expand Down
14 changes: 12 additions & 2 deletions quinn-proto/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,16 @@ pub trait ConnectionIdParser {
fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
}

/// Trivial parser for zero-length connection IDs
pub struct ZeroLengthConnectionIdParser;

impl ConnectionIdParser for ZeroLengthConnectionIdParser {
#[inline]
fn parse(&self, _: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
Ok(ConnectionId::new(&[]))
}
}

/// Long packet type including non-uniform cases
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LongHeaderType {
Expand Down Expand Up @@ -908,7 +918,7 @@ mod tests {
#[test]
fn header_encoding() {
use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
use crate::{RandomConnectionIdGenerator, Side};
use crate::Side;
use rustls::crypto::ring::default_provider;
use rustls::quic::Version;

Expand Down Expand Up @@ -950,7 +960,7 @@ mod tests {
let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec();
let decode = PartialDecode::new(
buf.as_slice().into(),
&RandomConnectionIdGenerator::new(0),
&ZeroLengthConnectionIdParser,
&supported_versions,
false,
)
Expand Down
6 changes: 6 additions & 0 deletions quinn-proto/src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ pub struct ConnectionId {
}

impl ConnectionId {
/// The zero-length connection ID
pub const EMPTY: Self = Self {
len: 0,
bytes: [0; MAX_CID_SIZE],
};

/// Construct cid from byte array
pub fn new(bytes: &[u8]) -> Self {
debug_assert!(bytes.len() <= MAX_CID_SIZE);
Expand Down
Loading

0 comments on commit 615a1ed

Please sign in to comment.