diff --git a/quinn-proto/src/transport_parameters.rs b/quinn-proto/src/transport_parameters.rs index e4a5eabc6..53d579ba0 100644 --- a/quinn-proto/src/transport_parameters.rs +++ b/quinn-proto/src/transport_parameters.rs @@ -12,7 +12,7 @@ use std::{ }; use bytes::{Buf, BufMut}; -use rand::{Rng as _, RngCore}; +use rand::{seq::SliceRandom as _, Rng as _, RngCore}; use thiserror::Error; use crate::{ @@ -104,6 +104,12 @@ macro_rules! make_struct { /// of transport parameter extensions. /// When present, it is included during serialization but ignored during deserialization. pub(crate) grease_transport_parameter: Option, + + /// Defines the order in which transport parameters are serialized. + /// + /// This field is initialized only for outgoing `TransportParameters` instances and + /// is set to `None` for `TransportParameters` received from a peer. + pub(crate) write_order: Option<[u8; TransportParameterId::SUPPORTED.len()]>, } // We deliberately don't implement the `Default` trait, since that would be public, and @@ -126,6 +132,7 @@ macro_rules! make_struct { stateless_reset_token: None, preferred_address: None, grease_transport_parameter: None, + write_order: None, } } } @@ -168,6 +175,11 @@ impl TransportParameters { VarInt::from_u64(u64::try_from(TIMER_GRANULARITY.as_micros()).unwrap()).unwrap(), ), grease_transport_parameter: Some(ReservedTransportParameter::random(rng)), + write_order: Some({ + let mut order = std::array::from_fn(|i| i as u8); + order.shuffle(rng); + order + }), ..Self::default() } } @@ -295,68 +307,100 @@ impl From for Error { impl TransportParameters { /// Encode `TransportParameters` into buffer pub fn write(&self, w: &mut W) { - macro_rules! write_params { - {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => { - $( - if self.$name.0 != $default { - w.write_var(TransportParameterId::$id as u64); - w.write(VarInt::try_from(self.$name.size()).unwrap()); - w.write(self.$name); + for idx in self + .write_order + .as_ref() + .unwrap_or(&std::array::from_fn(|i| i as u8)) + { + let id = TransportParameterId::SUPPORTED[*idx as usize]; + match id { + TransportParameterId::ReservedTransportParameter => { + if let Some(param) = self.grease_transport_parameter { + param.write(w); } - )* - } - } - apply_params!(write_params); - - if let Some(param) = self.grease_transport_parameter { - param.write(w); - } - - if let Some(ref x) = self.stateless_reset_token { - w.write_var(0x02); - w.write_var(16); - w.put_slice(x); - } - - if self.disable_active_migration { - w.write_var(0x0c); - w.write_var(0); - } - - if let Some(x) = self.max_datagram_frame_size { - w.write_var(0x20); - w.write_var(x.size() as u64); - w.write(x); - } - - if let Some(ref x) = self.preferred_address { - w.write_var(0x000d); - w.write_var(x.wire_size() as u64); - x.write(w); - } - - for &(tag, cid) in &[ - (0x00, &self.original_dst_cid), - (0x0f, &self.initial_src_cid), - (0x10, &self.retry_src_cid), - ] { - if let Some(ref cid) = *cid { - w.write_var(tag); - w.write_var(cid.len() as u64); - w.put_slice(cid); + } + TransportParameterId::StatelessResetToken => { + if let Some(ref x) = self.stateless_reset_token { + w.write_var(id as u64); + w.write_var(16); + w.put_slice(x); + } + } + TransportParameterId::DisableActiveMigration => { + if self.disable_active_migration { + w.write_var(id as u64); + w.write_var(0); + } + } + TransportParameterId::MaxDatagramFrameSize => { + if let Some(x) = self.max_datagram_frame_size { + w.write_var(id as u64); + w.write_var(x.size() as u64); + w.write(x); + } + } + TransportParameterId::PreferredAddress => { + if let Some(ref x) = self.preferred_address { + w.write_var(id as u64); + w.write_var(x.wire_size() as u64); + x.write(w); + } + } + TransportParameterId::OriginalDestinationConnectionId => { + if let Some(ref cid) = self.original_dst_cid { + w.write_var(id as u64); + w.write_var(cid.len() as u64); + w.put_slice(cid); + } + } + TransportParameterId::InitialSourceConnectionId => { + if let Some(ref cid) = self.initial_src_cid { + w.write_var(id as u64); + w.write_var(cid.len() as u64); + w.put_slice(cid); + } + } + TransportParameterId::RetrySourceConnectionId => { + if let Some(ref cid) = self.retry_src_cid { + w.write_var(id as u64); + w.write_var(cid.len() as u64); + w.put_slice(cid); + } + } + TransportParameterId::GreaseQuicBit => { + if self.grease_quic_bit { + w.write_var(id as u64); + w.write_var(0); + } + } + TransportParameterId::MinAckDelayDraft07 => { + if let Some(x) = self.min_ack_delay { + w.write_var(id as u64); + w.write_var(x.size() as u64); + w.write(x); + } + } + id => { + macro_rules! write_params { + {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => { + match id { + $(TransportParameterId::$id => { + if self.$name.0 != $default { + w.write_var(id as u64); + w.write(VarInt::try_from(self.$name.size()).unwrap()); + w.write(self.$name); + } + })*, + _ => { + unimplemented!("Missing implementation of write for transport parameter with code {id:?}"); + } + } + } + } + apply_params!(write_params); + } } } - - if self.grease_quic_bit { - w.write_var(0x2ab2); - w.write_var(0); - } - - if let Some(x) = self.min_ack_delay { - w.write_var(0xff04de1b); - w.write_var(x.size() as u64); - w.write(x); - } } /// Decode `TransportParameters` from buffer @@ -385,12 +429,17 @@ impl TransportParameters { return Err(Error::Malformed); } let len = len as usize; + let Ok(id) = TransportParameterId::try_from(id) else { + // unknown transport parameters are ignored + r.advance(len as usize); + continue; + }; match id { - id if TransportParameterId::OriginalDestinationConnectionId == id => { + TransportParameterId::OriginalDestinationConnectionId => { decode_cid(len, &mut params.original_dst_cid, r)? } - id if TransportParameterId::StatelessResetToken == id => { + TransportParameterId::StatelessResetToken => { if len != 16 || params.stateless_reset_token.is_some() { return Err(Error::Malformed); } @@ -398,42 +447,42 @@ impl TransportParameters { r.copy_to_slice(&mut tok); params.stateless_reset_token = Some(tok.into()); } - id if TransportParameterId::DisableActiveMigration == id => { + TransportParameterId::DisableActiveMigration => { if len != 0 || params.disable_active_migration { return Err(Error::Malformed); } params.disable_active_migration = true; } - id if TransportParameterId::PreferredAddress == id => { + TransportParameterId::PreferredAddress => { if params.preferred_address.is_some() { return Err(Error::Malformed); } params.preferred_address = Some(PreferredAddress::read(&mut r.take(len))?); } - id if TransportParameterId::InitialSourceConnectionId == id => { + TransportParameterId::InitialSourceConnectionId => { decode_cid(len, &mut params.initial_src_cid, r)? } - id if TransportParameterId::RetrySourceConnectionId == id => { + TransportParameterId::RetrySourceConnectionId => { decode_cid(len, &mut params.retry_src_cid, r)? } - id if TransportParameterId::MaxDatagramFrameSize == id => { + TransportParameterId::MaxDatagramFrameSize => { if len > 8 || params.max_datagram_frame_size.is_some() { return Err(Error::Malformed); } params.max_datagram_frame_size = Some(r.get().unwrap()); } - id if TransportParameterId::GreaseQuicBit == id => match len { + TransportParameterId::GreaseQuicBit => match len { 0 => params.grease_quic_bit = true, _ => return Err(Error::Malformed), }, - id if TransportParameterId::MinAckDelayDraft07 == id => { + TransportParameterId::MinAckDelayDraft07 => { params.min_ack_delay = Some(r.get().unwrap()) } _ => { macro_rules! parse { {$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => { match id { - $(id if TransportParameterId::$id == id => { + $(TransportParameterId::$id => { let value = r.get::()?; if len != value.size() || got.$name { return Err(Error::Malformed); } params.$name = value.into(); @@ -593,12 +642,75 @@ pub(crate) enum TransportParameterId { MinAckDelayDraft07 = 0xFF04DE1B, } +impl TransportParameterId { + /// Array with all supported transport parameter IDs + const SUPPORTED: [Self; 21] = [ + Self::MaxIdleTimeout, + Self::MaxUdpPayloadSize, + Self::InitialMaxData, + Self::InitialMaxStreamDataBidiLocal, + Self::InitialMaxStreamDataBidiRemote, + Self::InitialMaxStreamDataUni, + Self::InitialMaxStreamsBidi, + Self::InitialMaxStreamsUni, + Self::AckDelayExponent, + Self::MaxAckDelay, + Self::ActiveConnectionIdLimit, + Self::ReservedTransportParameter, + Self::StatelessResetToken, + Self::DisableActiveMigration, + Self::MaxDatagramFrameSize, + Self::PreferredAddress, + Self::OriginalDestinationConnectionId, + Self::InitialSourceConnectionId, + Self::RetrySourceConnectionId, + Self::GreaseQuicBit, + Self::MinAckDelayDraft07, + ]; +} + impl std::cmp::PartialEq for TransportParameterId { fn eq(&self, other: &u64) -> bool { *other == (*self as u64) } } +impl TryFrom for TransportParameterId { + type Error = (); + + fn try_from(value: u64) -> Result { + let param = match value { + id if Self::MaxIdleTimeout == id => Self::MaxIdleTimeout, + id if Self::MaxUdpPayloadSize == id => Self::MaxUdpPayloadSize, + id if Self::InitialMaxData == id => Self::InitialMaxData, + id if Self::InitialMaxStreamDataBidiLocal == id => Self::InitialMaxStreamDataBidiLocal, + id if Self::InitialMaxStreamDataBidiRemote == id => { + Self::InitialMaxStreamDataBidiRemote + } + id if Self::InitialMaxStreamDataUni == id => Self::InitialMaxStreamDataUni, + id if Self::InitialMaxStreamsBidi == id => Self::InitialMaxStreamsBidi, + id if Self::InitialMaxStreamsUni == id => Self::InitialMaxStreamsUni, + id if Self::AckDelayExponent == id => Self::AckDelayExponent, + id if Self::MaxAckDelay == id => Self::MaxAckDelay, + id if Self::ActiveConnectionIdLimit == id => Self::ActiveConnectionIdLimit, + id if Self::ReservedTransportParameter == id => Self::ReservedTransportParameter, + id if Self::StatelessResetToken == id => Self::StatelessResetToken, + id if Self::DisableActiveMigration == id => Self::DisableActiveMigration, + id if Self::MaxDatagramFrameSize == id => Self::MaxDatagramFrameSize, + id if Self::PreferredAddress == id => Self::PreferredAddress, + id if Self::OriginalDestinationConnectionId == id => { + Self::OriginalDestinationConnectionId + } + id if Self::InitialSourceConnectionId == id => Self::InitialSourceConnectionId, + id if Self::RetrySourceConnectionId == id => Self::RetrySourceConnectionId, + id if Self::GreaseQuicBit == id => Self::GreaseQuicBit, + id if Self::MinAckDelayDraft07 == id => Self::MinAckDelayDraft07, + _ => return Err(()), + }; + Ok(param) + } +} + fn decode_cid(len: usize, value: &mut Option, r: &mut impl Buf) -> Result<(), Error> { if len > MAX_CID_SIZE || value.is_some() || r.remaining() < len { return Err(Error::Malformed);