diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 7a9867c62b..513b3ae05a 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -19,7 +19,7 @@ use crate::{ cid_queue::CidQueue, coding::BufMutExt, config::{ServerConfig, TransportConfig}, - crypto::{self, HeaderKey, KeyPair, Keys, PacketKey}, + crypto::{self, KeyPair, Keys, PacketKey}, frame, frame::{Close, Datagram, FrameStruct}, packet::{Header, LongType, Packet, PartialDecode, SpaceId}, @@ -31,7 +31,7 @@ use crate::{ token::ResetToken, transport_parameters::TransportParameters, Dir, EndpointConfig, Frame, Side, StreamId, Transmit, TransportError, TransportErrorCode, - VarInt, MAX_STREAM_COUNT, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, TIMER_GRANULARITY, + VarInt, MAX_STREAM_COUNT, MIN_INITIAL_SIZE, TIMER_GRANULARITY, }; mod ack_frequency; @@ -53,6 +53,9 @@ mod pacing; mod packet_builder; use packet_builder::PacketBuilder; +mod packet_crypto; +use packet_crypto::{PrevCrypto, ZeroRttCrypto}; + mod paths; use paths::PathData; pub use paths::RttEstimator; @@ -2015,40 +2018,13 @@ impl Connection { ecn: Option, partial_decode: PartialDecode, ) { - let header_crypto = if partial_decode.is_0rtt() { - if let Some(ref crypto) = self.zero_rtt_crypto { - Some(&*crypto.header) - } else { - debug!("dropping unexpected 0-RTT packet"); - return; - } - } else if let Some(space) = partial_decode.space() { - if let Some(ref crypto) = self.spaces[space].crypto { - Some(&*crypto.header.remote) - } else { - debug!( - "discarding unexpected {:?} packet ({} bytes)", - space, - partial_decode.len(), - ); - return; - } - } else { - // Unprotected packet - None - }; - - let packet = partial_decode.data(); - let stateless_reset = packet.len() >= RESET_TOKEN_SIZE + 5 - && self.peer_params.stateless_reset_token.as_deref() - == Some(&packet[packet.len() - RESET_TOKEN_SIZE..]); - - match partial_decode.finish(header_crypto) { - Ok(packet) => self.handle_packet(now, remote, ecn, Some(packet), stateless_reset), - Err(_) if stateless_reset => self.handle_packet(now, remote, ecn, None, true), - Err(e) => { - trace!("unable to complete packet decoding: {}", e); - } + if let Some(decoded) = packet_crypto::unprotect_header( + partial_decode, + &self.spaces, + self.zero_rtt_crypto.as_ref(), + self.peer_params.stateless_reset_token, + ) { + self.handle_packet(now, remote, ecn, decoded.packet, decoded.stateless_reset); } } @@ -3252,78 +3228,33 @@ impl Connection { now: Instant, packet: &mut Packet, ) -> Result, Option> { - if !packet.header.is_protected() { - // Unprotected packets also don't have packet numbers - return Ok(None); - } - let space = packet.header.space(); - let rx_packet = self.spaces[space].rx_packet; - let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1); - let key_phase = packet.header.key_phase(); - - let mut crypto_update = false; - let crypto = if packet.header.is_0rtt() { - &self.zero_rtt_crypto.as_ref().unwrap().packet - } else if key_phase == self.key_phase || space != SpaceId::Data { - &self.spaces[space].crypto.as_mut().unwrap().packet.remote - } else if let Some(prev) = self.prev_crypto.as_ref().and_then(|crypto| { - // If this packet comes prior to acknowledgment of the key update by the peer, - if crypto.end_packet.map_or(true, |(pn, _)| number < pn) { - // use the previous keys. - Some(crypto) - } else { - // Otherwise, this must be a remotely-initiated key update, so fall through to the - // final case. - None - } - }) { - &prev.crypto.remote - } else { - // We're in the Data space with a key phase mismatch and either there is no locally - // initiated key update or the locally initiated key update was acknowledged by a - // lower-numbered packet. The key phase mismatch must therefore represent a new - // remotely-initiated key update. - crypto_update = true; - &self.next_crypto.as_ref().unwrap().remote - }; - - crypto - .decrypt(number, &packet.header_data, &mut packet.payload) - .map_err(|_| { - trace!("decryption failed with packet number {}", number); - None - })?; + let result = packet_crypto::decrypt_packet_body( + packet, + &self.spaces, + self.zero_rtt_crypto.as_ref(), + self.key_phase, + self.prev_crypto.as_ref(), + self.next_crypto.as_ref(), + )?; - if let Some(ref mut prev) = self.prev_crypto { - if prev.end_packet.is_none() && key_phase == self.key_phase { - // Outgoing key update newly acknowledged - prev.end_packet = Some((number, now)); - self.set_key_discard_timer(now, space); + let pn = result.map(|d| { + if d.outgoing_key_update_acked { + if let Some(prev) = self.prev_crypto.as_mut() { + prev.end_packet = Some((d.number, now)); + self.set_key_discard_timer(now, packet.header.space()); + } } - } - if !packet.reserved_bits_valid() { - return Err(Some(TransportError::PROTOCOL_VIOLATION( - "reserved bits set", - ))); - } - - if crypto_update { - // Validate and commit incoming key update - if number <= rx_packet - || self - .prev_crypto - .as_ref() - .map_or(false, |x| x.update_unacked) - { - return Err(Some(TransportError::KEY_UPDATE_ERROR(""))); + if d.incoming_key_update { + trace!("key update authenticated"); + self.update_keys(Some((d.number, now)), true); + self.set_key_discard_timer(now, packet.header.space()); } - trace!("key update authenticated"); - self.update_keys(Some((number, now)), true); - self.set_key_discard_timer(now, space); - } - Ok(Some(number)) + d.number + }); + + Ok(pn) } fn update_keys(&mut self, end_packet: Option<(u64, Instant)>, remote: bool) { @@ -3355,6 +3286,43 @@ impl Connection { self.peer_params.min_ack_delay.is_some() } + /// Decodes a packet, returning its decrypted payload, so it can be inspected in tests + #[cfg(test)] + pub(crate) fn decode_packet(&self, event: &ConnectionEvent) -> Option> { + if let ConnectionEventInner::Datagram { + first_decode, + remaining, + .. + } = &event.0 + { + if remaining.is_some() { + panic!("Packets should never be coalesced in tests"); + } + + let decrypted_header = packet_crypto::unprotect_header( + first_decode.clone(), + &self.spaces, + self.zero_rtt_crypto.as_ref(), + self.peer_params.stateless_reset_token, + )?; + + let mut packet = decrypted_header.packet?; + packet_crypto::decrypt_packet_body( + &mut packet, + &self.spaces, + self.zero_rtt_crypto.as_ref(), + self.key_phase, + self.prev_crypto.as_ref(), + self.next_crypto.as_ref(), + ) + .ok()?; + + return Some(packet.payload.to_vec()); + } + + None + } + /// The number of bytes of packets containing retransmittable frames that have not been /// acknowledged or declared lost. #[cfg(test)] @@ -3587,21 +3555,6 @@ mod state { } } -struct PrevCrypto { - /// The keys used for the previous key phase, temporarily retained to decrypt packets sent by - /// the peer prior to its own key update. - crypto: KeyPair>, - /// The incoming packet that ends the interval for which these keys are applicable, and the time - /// of its receipt. - /// - /// Incoming packets should be decrypted using these keys iff this is `None` or their packet - /// number is lower. `None` indicates that we have not yet received a packet using newer keys, - /// which implies that the update was locally initiated. - end_packet: Option<(u64, Instant)>, - /// Whether the following key phase is from a remotely initiated update that we haven't acked - update_unacked: bool, -} - struct InFlight { /// Sum of the sizes of all sent packets considered "in flight" by congestion control /// @@ -3679,11 +3632,6 @@ const MIN_PACKET_SPACE: usize = 40; /// that numbers around 10 are a good compromise. const MAX_TRANSMIT_SEGMENTS: usize = 10; -struct ZeroRttCrypto { - header: Box, - packet: Box, -} - #[derive(Default)] struct SentFrames { retransmits: ThinRetransmits, diff --git a/quinn-proto/src/connection/packet_crypto.rs b/quinn-proto/src/connection/packet_crypto.rs new file mode 100644 index 0000000000..7c248f1ea7 --- /dev/null +++ b/quinn-proto/src/connection/packet_crypto.rs @@ -0,0 +1,173 @@ +use std::time::Instant; +use tracing::{debug, trace}; + +use crate::connection::spaces::PacketSpace; +use crate::crypto::{HeaderKey, KeyPair, PacketKey}; +use crate::packet::{Packet, PartialDecode, SpaceId}; +use crate::token::ResetToken; +use crate::{TransportError, RESET_TOKEN_SIZE}; + +pub(super) struct UnprotectHeaderResult { + /// The packet with the now unprotected header (`None` in the case of stateless reset packets + /// that fail to be decoded) + pub(super) packet: Option, + /// Whether the packet was a stateless reset packet + pub(super) stateless_reset: bool, +} + +/// Removes header protection of a packet, or returns `None` if the packet was dropped +pub(super) fn unprotect_header( + partial_decode: PartialDecode, + spaces: &[PacketSpace; 3], + zero_rtt_crypto: Option<&ZeroRttCrypto>, + stateless_reset_token: Option, +) -> Option { + let header_crypto = if partial_decode.is_0rtt() { + if let Some(crypto) = zero_rtt_crypto { + Some(&*crypto.header) + } else { + debug!("dropping unexpected 0-RTT packet"); + return None; + } + } else if let Some(space) = partial_decode.space() { + if let Some(ref crypto) = spaces[space].crypto { + Some(&*crypto.header.remote) + } else { + debug!( + "discarding unexpected {:?} packet ({} bytes)", + space, + partial_decode.len(), + ); + return None; + } + } else { + // Unprotected packet + None + }; + + let packet = partial_decode.data(); + let stateless_reset = packet.len() >= RESET_TOKEN_SIZE + 5 + && stateless_reset_token.as_deref() == Some(&packet[packet.len() - RESET_TOKEN_SIZE..]); + + match partial_decode.finish(header_crypto) { + Ok(packet) => Some(UnprotectHeaderResult { + packet: Some(packet), + stateless_reset, + }), + Err(_) if stateless_reset => Some(UnprotectHeaderResult { + packet: None, + stateless_reset: true, + }), + Err(e) => { + trace!("unable to complete packet decoding: {}", e); + None + } + } +} + +pub(super) struct DecryptPacketResult { + /// The packet number + pub(super) number: u64, + /// Whether a locally initiated key update has been acknowledged by the peer + pub(super) outgoing_key_update_acked: bool, + /// Whether the peer has initiated a key update + pub(super) incoming_key_update: bool, +} + +/// Decrypts a packet's body in-place +pub(super) fn decrypt_packet_body( + packet: &mut Packet, + spaces: &[PacketSpace; 3], + zero_rtt_crypto: Option<&ZeroRttCrypto>, + conn_key_phase: bool, + prev_crypto: Option<&PrevCrypto>, + next_crypto: Option<&KeyPair>>, +) -> Result, Option> { + if !packet.header.is_protected() { + // Unprotected packets also don't have packet numbers + return Ok(None); + } + let space = packet.header.space(); + let rx_packet = spaces[space].rx_packet; + let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1); + let packet_key_phase = packet.header.key_phase(); + + let mut crypto_update = false; + let crypto = if packet.header.is_0rtt() { + &zero_rtt_crypto.unwrap().packet + } else if packet_key_phase == conn_key_phase || space != SpaceId::Data { + &spaces[space].crypto.as_ref().unwrap().packet.remote + } else if let Some(prev) = prev_crypto.and_then(|crypto| { + // If this packet comes prior to acknowledgment of the key update by the peer, + if crypto.end_packet.map_or(true, |(pn, _)| number < pn) { + // use the previous keys. + Some(crypto) + } else { + // Otherwise, this must be a remotely-initiated key update, so fall through to the + // final case. + None + } + }) { + &prev.crypto.remote + } else { + // We're in the Data space with a key phase mismatch and either there is no locally + // initiated key update or the locally initiated key update was acknowledged by a + // lower-numbered packet. The key phase mismatch must therefore represent a new + // remotely-initiated key update. + crypto_update = true; + &next_crypto.unwrap().remote + }; + + crypto + .decrypt(number, &packet.header_data, &mut packet.payload) + .map_err(|_| { + trace!("decryption failed with packet number {}", number); + None + })?; + + if !packet.reserved_bits_valid() { + return Err(Some(TransportError::PROTOCOL_VIOLATION( + "reserved bits set", + ))); + } + + let mut outgoing_key_update_acked = false; + if let Some(prev) = prev_crypto { + if prev.end_packet.is_none() && packet_key_phase == conn_key_phase { + outgoing_key_update_acked = true; + } + } + + if crypto_update { + // Validate incoming key update + if number <= rx_packet || prev_crypto.map_or(false, |x| x.update_unacked) { + return Err(Some(TransportError::KEY_UPDATE_ERROR(""))); + } + } + + Ok(Some(DecryptPacketResult { + number, + outgoing_key_update_acked, + incoming_key_update: crypto_update, + })) +} + +pub(super) struct PrevCrypto { + /// The keys used for the previous key phase, temporarily retained to decrypt packets sent by + /// the peer prior to its own key update. + pub(super) crypto: KeyPair>, + /// The incoming packet that ends the interval for which these keys are applicable, and the time + /// of its receipt. + /// + /// Incoming packets should be decrypted using these keys iff this is `None` or their packet + /// number is lower. `None` indicates that we have not yet received a packet using newer keys, + /// which implies that the update was locally initiated. + pub(super) end_packet: Option<(u64, Instant)>, + /// Whether the following key phase is from a remotely initiated update that we haven't acked + pub(super) update_unacked: bool, +} + +pub(super) struct ZeroRttCrypto { + pub(super) header: Box, + pub(super) packet: Box, +} diff --git a/quinn-proto/src/connection/spaces.rs b/quinn-proto/src/connection/spaces.rs index d65d1c3255..789c5ffa68 100644 --- a/quinn-proto/src/connection/spaces.rs +++ b/quinn-proto/src/connection/spaces.rs @@ -14,7 +14,7 @@ use crate::{ shared::IssuedCid, Dir, StreamId, VarInt, }; -pub(super) struct PacketSpace { +pub(crate) struct PacketSpace { pub(super) crypto: Option, pub(super) dedup: Dedup, /// Highest received packet number diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index 576f2b93ae..b6d9ecf0fb 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -19,7 +19,7 @@ use crate::{ // to inspect the version and packet type (which depends on the version). // This information allows us to fully decode and decrypt the packet. #[allow(unreachable_pub)] // fuzzing only -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct PartialDecode { plain_header: PlainHeader, buf: io::Cursor, @@ -477,7 +477,7 @@ impl PartialEncode { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) enum PlainHeader { Initial { dst_cid: ConnectionId, diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 8ac2e16c0c..55e31a7436 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -18,6 +18,7 @@ use super::*; use crate::{ cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, frame::FrameStruct, + transport_parameters::TransportParameters, }; mod util; use util::*; @@ -2236,6 +2237,7 @@ fn single_ack_eliciting_packet_triggers_ack_after_delay() { 0 ); + pair.client.capture_inbound_packets = true; pair.drive(); let stats_after_drive = pair.client_conn_mut(client_ch).stats(); assert_eq!( @@ -2246,6 +2248,19 @@ fn single_ack_eliciting_packet_triggers_ack_after_delay() { // The time is start + max_ack_delay assert_eq!(pair.time, start + Duration::from_millis(25)); + // The ACK delay is properly calculated + assert_eq!(pair.client.captured_packets.len(), 1); + let mut frames = + frame::Iter::new(pair.client.captured_packets.remove(0).into()).collect::>(); + assert_eq!(frames.len(), 1); + if let Frame::Ack(ack) = frames.remove(0) { + let ack_delay_exp = TransportParameters::default().ack_delay_exponent; + let delay = ack.delay << ack_delay_exp.into_inner(); + assert_eq!(delay, 25_000); + } else { + panic!("Expected ACK frame"); + } + // Sanity check: no loss probe was sent, because the delayed ACK was received on time assert_eq!( stats_after_drive.frame_tx.ping - stats_after_connect.frame_tx.ping, diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 00cc951a7b..965c13af48 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -271,6 +271,8 @@ pub(super) struct TestEndpoint { accepted: Option, pub(super) connections: HashMap, conn_events: HashMap>, + pub(super) captured_packets: Vec>, + pub(super) capture_inbound_packets: bool, } impl TestEndpoint { @@ -295,6 +297,8 @@ impl TestEndpoint { accepted: None, connections: HashMap::default(), conn_events: HashMap::default(), + captured_packets: Vec::new(), + capture_inbound_packets: false, } } @@ -320,6 +324,11 @@ impl TestEndpoint { self.accepted = Some(ch); } DatagramEvent::ConnectionEvent(event) => { + if self.capture_inbound_packets { + let packet = self.connections[&ch].decode_packet(&event); + self.captured_packets.extend(packet); + } + self.conn_events .entry(ch) .or_insert_with(VecDeque::new)