Skip to content

Commit

Permalink
quinn-rs#2057: Write transport parameters in random order.
Browse files Browse the repository at this point in the history
  • Loading branch information
mstyura committed Nov 27, 2024
1 parent d658b53 commit ff96fda
Showing 1 changed file with 182 additions and 70 deletions.
252 changes: 182 additions & 70 deletions quinn-proto/src/transport_parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::{
};

use bytes::{Buf, BufMut};
use rand::{Rng as _, RngCore};
use rand::{seq::SliceRandom as _, Rng as _, RngCore};
use thiserror::Error;

use crate::{
Expand Down Expand Up @@ -104,6 +104,12 @@ macro_rules! make_struct {
/// of transport parameter extensions.
/// When present, it is included during serialization but ignored during deserialization.
pub(crate) grease_transport_parameter: Option<ReservedTransportParameter>,

/// Defines the order in which transport parameters are serialized.
///
/// This field is initialized only for outgoing `TransportParameters` instances and
/// is set to `None` for `TransportParameters` received from a peer.
pub(crate) write_order: Option<[u8; TransportParameterId::SUPPORTED.len()]>,
}

// We deliberately don't implement the `Default` trait, since that would be public, and
Expand All @@ -126,6 +132,7 @@ macro_rules! make_struct {
stateless_reset_token: None,
preferred_address: None,
grease_transport_parameter: None,
write_order: None,
}
}
}
Expand Down Expand Up @@ -168,6 +175,11 @@ impl TransportParameters {
VarInt::from_u64(u64::try_from(TIMER_GRANULARITY.as_micros()).unwrap()).unwrap(),
),
grease_transport_parameter: Some(ReservedTransportParameter::random(rng)),
write_order: Some({
let mut order = std::array::from_fn(|i| i as u8);
order.shuffle(rng);
order
}),
..Self::default()
}
}
Expand Down Expand Up @@ -295,68 +307,100 @@ impl From<UnexpectedEnd> for Error {
impl TransportParameters {
/// Encode `TransportParameters` into buffer
pub fn write<W: BufMut>(&self, w: &mut W) {
macro_rules! write_params {
{$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => {
$(
if self.$name.0 != $default {
w.write_var(TransportParameterId::$id as u64);
w.write(VarInt::try_from(self.$name.size()).unwrap());
w.write(self.$name);
for idx in self
.write_order
.as_ref()
.unwrap_or(&std::array::from_fn(|i| i as u8))
{
let id = TransportParameterId::SUPPORTED[*idx as usize];
match id {
TransportParameterId::ReservedTransportParameter => {
if let Some(param) = self.grease_transport_parameter {
param.write(w);
}
)*
}
}
apply_params!(write_params);

if let Some(param) = self.grease_transport_parameter {
param.write(w);
}

if let Some(ref x) = self.stateless_reset_token {
w.write_var(0x02);
w.write_var(16);
w.put_slice(x);
}

if self.disable_active_migration {
w.write_var(0x0c);
w.write_var(0);
}

if let Some(x) = self.max_datagram_frame_size {
w.write_var(0x20);
w.write_var(x.size() as u64);
w.write(x);
}

if let Some(ref x) = self.preferred_address {
w.write_var(0x000d);
w.write_var(x.wire_size() as u64);
x.write(w);
}

for &(tag, cid) in &[
(0x00, &self.original_dst_cid),
(0x0f, &self.initial_src_cid),
(0x10, &self.retry_src_cid),
] {
if let Some(ref cid) = *cid {
w.write_var(tag);
w.write_var(cid.len() as u64);
w.put_slice(cid);
}
TransportParameterId::StatelessResetToken => {
if let Some(ref x) = self.stateless_reset_token {
w.write_var(id as u64);
w.write_var(16);
w.put_slice(x);
}
}
TransportParameterId::DisableActiveMigration => {
if self.disable_active_migration {
w.write_var(id as u64);
w.write_var(0);
}
}
TransportParameterId::MaxDatagramFrameSize => {
if let Some(x) = self.max_datagram_frame_size {
w.write_var(id as u64);
w.write_var(x.size() as u64);
w.write(x);
}
}
TransportParameterId::PreferredAddress => {
if let Some(ref x) = self.preferred_address {
w.write_var(id as u64);
w.write_var(x.wire_size() as u64);
x.write(w);
}
}
TransportParameterId::OriginalDestinationConnectionId => {
if let Some(ref cid) = self.original_dst_cid {
w.write_var(id as u64);
w.write_var(cid.len() as u64);
w.put_slice(cid);
}
}
TransportParameterId::InitialSourceConnectionId => {
if let Some(ref cid) = self.initial_src_cid {
w.write_var(id as u64);
w.write_var(cid.len() as u64);
w.put_slice(cid);
}
}
TransportParameterId::RetrySourceConnectionId => {
if let Some(ref cid) = self.retry_src_cid {
w.write_var(id as u64);
w.write_var(cid.len() as u64);
w.put_slice(cid);
}
}
TransportParameterId::GreaseQuicBit => {
if self.grease_quic_bit {
w.write_var(id as u64);
w.write_var(0);
}
}
TransportParameterId::MinAckDelayDraft07 => {
if let Some(x) = self.min_ack_delay {
w.write_var(id as u64);
w.write_var(x.size() as u64);
w.write(x);
}
}
id => {
macro_rules! write_params {
{$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => {
match id {
$(TransportParameterId::$id => {
if self.$name.0 != $default {
w.write_var(id as u64);
w.write(VarInt::try_from(self.$name.size()).unwrap());
w.write(self.$name);
}
})*,
_ => {
unimplemented!("Missing implementation of write for transport parameter with code {id:?}");
}
}
}
}
apply_params!(write_params);
}
}
}

if self.grease_quic_bit {
w.write_var(0x2ab2);
w.write_var(0);
}

if let Some(x) = self.min_ack_delay {
w.write_var(0xff04de1b);
w.write_var(x.size() as u64);
w.write(x);
}
}

/// Decode `TransportParameters` from buffer
Expand Down Expand Up @@ -385,55 +429,60 @@ impl TransportParameters {
return Err(Error::Malformed);
}
let len = len as usize;
let Ok(id) = TransportParameterId::try_from(id) else {
// unknown transport parameters are ignored
r.advance(len as usize);
continue;
};

match id {
id if TransportParameterId::OriginalDestinationConnectionId == id => {
TransportParameterId::OriginalDestinationConnectionId => {
decode_cid(len, &mut params.original_dst_cid, r)?
}
id if TransportParameterId::StatelessResetToken == id => {
TransportParameterId::StatelessResetToken => {
if len != 16 || params.stateless_reset_token.is_some() {
return Err(Error::Malformed);
}
let mut tok = [0; RESET_TOKEN_SIZE];
r.copy_to_slice(&mut tok);
params.stateless_reset_token = Some(tok.into());
}
id if TransportParameterId::DisableActiveMigration == id => {
TransportParameterId::DisableActiveMigration => {
if len != 0 || params.disable_active_migration {
return Err(Error::Malformed);
}
params.disable_active_migration = true;
}
id if TransportParameterId::PreferredAddress == id => {
TransportParameterId::PreferredAddress => {
if params.preferred_address.is_some() {
return Err(Error::Malformed);
}
params.preferred_address = Some(PreferredAddress::read(&mut r.take(len))?);
}
id if TransportParameterId::InitialSourceConnectionId == id => {
TransportParameterId::InitialSourceConnectionId => {
decode_cid(len, &mut params.initial_src_cid, r)?
}
id if TransportParameterId::RetrySourceConnectionId == id => {
TransportParameterId::RetrySourceConnectionId => {
decode_cid(len, &mut params.retry_src_cid, r)?
}
id if TransportParameterId::MaxDatagramFrameSize == id => {
TransportParameterId::MaxDatagramFrameSize => {
if len > 8 || params.max_datagram_frame_size.is_some() {
return Err(Error::Malformed);
}
params.max_datagram_frame_size = Some(r.get().unwrap());
}
id if TransportParameterId::GreaseQuicBit == id => match len {
TransportParameterId::GreaseQuicBit => match len {
0 => params.grease_quic_bit = true,
_ => return Err(Error::Malformed),
},
id if TransportParameterId::MinAckDelayDraft07 == id => {
TransportParameterId::MinAckDelayDraft07 => {
params.min_ack_delay = Some(r.get().unwrap())
}
_ => {
macro_rules! parse {
{$($(#[$doc:meta])* $name:ident ($id:ident) = $default:expr,)*} => {
match id {
$(id if TransportParameterId::$id == id => {
$(TransportParameterId::$id => {
let value = r.get::<VarInt>()?;
if len != value.size() || got.$name { return Err(Error::Malformed); }
params.$name = value.into();
Expand Down Expand Up @@ -593,12 +642,75 @@ pub(crate) enum TransportParameterId {
MinAckDelayDraft07 = 0xFF04DE1B,
}

impl TransportParameterId {
/// Array with all supported transport parameter IDs
const SUPPORTED: [Self; 21] = [
Self::MaxIdleTimeout,
Self::MaxUdpPayloadSize,
Self::InitialMaxData,
Self::InitialMaxStreamDataBidiLocal,
Self::InitialMaxStreamDataBidiRemote,
Self::InitialMaxStreamDataUni,
Self::InitialMaxStreamsBidi,
Self::InitialMaxStreamsUni,
Self::AckDelayExponent,
Self::MaxAckDelay,
Self::ActiveConnectionIdLimit,
Self::ReservedTransportParameter,
Self::StatelessResetToken,
Self::DisableActiveMigration,
Self::MaxDatagramFrameSize,
Self::PreferredAddress,
Self::OriginalDestinationConnectionId,
Self::InitialSourceConnectionId,
Self::RetrySourceConnectionId,
Self::GreaseQuicBit,
Self::MinAckDelayDraft07,
];
}

impl std::cmp::PartialEq<u64> for TransportParameterId {
fn eq(&self, other: &u64) -> bool {
*other == (*self as u64)
}
}

impl TryFrom<u64> for TransportParameterId {
type Error = ();

fn try_from(value: u64) -> Result<Self, Self::Error> {
let param = match value {
id if Self::MaxIdleTimeout == id => Self::MaxIdleTimeout,
id if Self::MaxUdpPayloadSize == id => Self::MaxUdpPayloadSize,
id if Self::InitialMaxData == id => Self::InitialMaxData,
id if Self::InitialMaxStreamDataBidiLocal == id => Self::InitialMaxStreamDataBidiLocal,
id if Self::InitialMaxStreamDataBidiRemote == id => {
Self::InitialMaxStreamDataBidiRemote
}
id if Self::InitialMaxStreamDataUni == id => Self::InitialMaxStreamDataUni,
id if Self::InitialMaxStreamsBidi == id => Self::InitialMaxStreamsBidi,
id if Self::InitialMaxStreamsUni == id => Self::InitialMaxStreamsUni,
id if Self::AckDelayExponent == id => Self::AckDelayExponent,
id if Self::MaxAckDelay == id => Self::MaxAckDelay,
id if Self::ActiveConnectionIdLimit == id => Self::ActiveConnectionIdLimit,
id if Self::ReservedTransportParameter == id => Self::ReservedTransportParameter,
id if Self::StatelessResetToken == id => Self::StatelessResetToken,
id if Self::DisableActiveMigration == id => Self::DisableActiveMigration,
id if Self::MaxDatagramFrameSize == id => Self::MaxDatagramFrameSize,
id if Self::PreferredAddress == id => Self::PreferredAddress,
id if Self::OriginalDestinationConnectionId == id => {
Self::OriginalDestinationConnectionId
}
id if Self::InitialSourceConnectionId == id => Self::InitialSourceConnectionId,
id if Self::RetrySourceConnectionId == id => Self::RetrySourceConnectionId,
id if Self::GreaseQuicBit == id => Self::GreaseQuicBit,
id if Self::MinAckDelayDraft07 == id => Self::MinAckDelayDraft07,
_ => return Err(()),
};
Ok(param)
}
}

fn decode_cid(len: usize, value: &mut Option<ConnectionId>, r: &mut impl Buf) -> Result<(), Error> {
if len > MAX_CID_SIZE || value.is_some() || r.remaining() < len {
return Err(Error::Malformed);
Expand Down

0 comments on commit ff96fda

Please sign in to comment.