diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 98cbeb37d5..d272b4103e 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -3278,6 +3278,43 @@ impl Connection { self.spaces[self.highest_space].immediate_ack_pending = true; } + /// Decodes a packet, returning its decrypted payload, so it can be inspected in tests + #[cfg(test)] + pub(crate) fn decode_packet(&self, event: &ConnectionEvent) -> Option> { + let (first_decode, remaining) = match &event.0 { + ConnectionEventInner::Datagram { + first_decode, + remaining, + .. + } => (first_decode, remaining), + _ => return None, + }; + + if remaining.is_some() { + panic!("Packets should never be coalesced in tests"); + } + + let decrypted_header = packet_crypto::unprotect_header( + first_decode.clone(), + &self.spaces, + self.zero_rtt_crypto.as_ref(), + self.peer_params.stateless_reset_token, + )?; + + let mut packet = decrypted_header.packet?; + packet_crypto::decrypt_packet_body( + &mut packet, + &self.spaces, + self.zero_rtt_crypto.as_ref(), + self.key_phase, + self.prev_crypto.as_ref(), + self.next_crypto.as_ref(), + ) + .ok()?; + + Some(packet.payload.to_vec()) + } + /// The number of bytes of packets containing retransmittable frames that have not been /// acknowledged or declared lost. #[cfg(test)] diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index f501e8593c..b70eb98f78 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -19,6 +19,7 @@ use crate::{ // to inspect the version and packet type (which depends on the version). // This information allows us to fully decode and decrypt the packet. #[allow(unreachable_pub)] // fuzzing only +#[cfg_attr(test, derive(Clone))] #[derive(Debug)] pub struct PartialDecode { plain_header: PlainHeader, @@ -234,7 +235,8 @@ impl Packet { } } -#[derive(Debug, Clone)] +#[cfg_attr(test, derive(Clone))] +#[derive(Debug)] pub(crate) enum Header { Initial { dst_cid: ConnectionId, @@ -477,7 +479,7 @@ impl PartialEncode { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) enum PlainHeader { Initial { dst_cid: ConnectionId, diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 040bdf61e9..c10a526f30 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -18,6 +18,7 @@ use super::*; use crate::{ cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, frame::FrameStruct, + transport_parameters::TransportParameters, }; mod util; use util::*; @@ -2236,6 +2237,7 @@ fn single_ack_eliciting_packet_triggers_ack_after_delay() { 0 ); + pair.client.capture_inbound_packets = true; pair.drive(); let stats_after_drive = pair.client_conn_mut(client_ch).stats(); assert_eq!( @@ -2244,7 +2246,24 @@ fn single_ack_eliciting_packet_triggers_ack_after_delay() { ); // The time is start + max_ack_delay - assert_eq!(pair.time, start + Duration::from_millis(25)); + let default_max_ack_delay_ms = TransportParameters::default().max_ack_delay.into_inner(); + assert_eq!( + pair.time, + start + Duration::from_millis(default_max_ack_delay_ms) + ); + + // The ACK delay is properly calculated + assert_eq!(pair.client.captured_packets.len(), 1); + let mut frames = + frame::Iter::new(pair.client.captured_packets.remove(0).into()).collect::>(); + assert_eq!(frames.len(), 1); + if let Frame::Ack(ack) = frames.remove(0) { + let ack_delay_exp = TransportParameters::default().ack_delay_exponent; + let delay = ack.delay << ack_delay_exp.into_inner(); + assert_eq!(delay, default_max_ack_delay_ms * 1_000); + } else { + panic!("Expected ACK frame"); + } // Sanity check: no loss probe was sent, because the delayed ACK was received on time assert_eq!( diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 9bf824417a..0cc77d119e 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -276,6 +276,8 @@ pub(super) struct TestEndpoint { accepted: Option, pub(super) connections: HashMap, conn_events: HashMap>, + pub(super) captured_packets: Vec>, + pub(super) capture_inbound_packets: bool, } impl TestEndpoint { @@ -300,6 +302,8 @@ impl TestEndpoint { accepted: None, connections: HashMap::default(), conn_events: HashMap::default(), + captured_packets: Vec::new(), + capture_inbound_packets: false, } } @@ -322,6 +326,11 @@ impl TestEndpoint { self.accepted = Some(ch); } DatagramEvent::ConnectionEvent(ch, event) => { + if self.capture_inbound_packets { + let packet = self.connections[&ch].decode_packet(&event); + self.captured_packets.extend(packet); + } + self.conn_events .entry(ch) .or_insert_with(VecDeque::new)