diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 8f4ab14dbe..3be913f722 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -54,7 +54,7 @@ mod packet_builder; use packet_builder::PacketBuilder; mod packet_crypto; -use packet_crypto::ZeroRttCrypto; +use packet_crypto::{PrevCrypto, ZeroRttCrypto}; mod paths; use paths::PathData; @@ -3230,78 +3230,33 @@ impl Connection { now: Instant, packet: &mut Packet, ) -> Result, Option> { - if !packet.header.is_protected() { - // Unprotected packets also don't have packet numbers - return Ok(None); - } - let space = packet.header.space(); - let rx_packet = self.spaces[space].rx_packet; - let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1); - let key_phase = packet.header.key_phase(); - - let mut crypto_update = false; - let crypto = if packet.header.is_0rtt() { - &self.zero_rtt_crypto.as_ref().unwrap().packet - } else if key_phase == self.key_phase || space != SpaceId::Data { - &self.spaces[space].crypto.as_mut().unwrap().packet.remote - } else if let Some(prev) = self.prev_crypto.as_ref().and_then(|crypto| { - // If this packet comes prior to acknowledgment of the key update by the peer, - if crypto.end_packet.map_or(true, |(pn, _)| number < pn) { - // use the previous keys. - Some(crypto) - } else { - // Otherwise, this must be a remotely-initiated key update, so fall through to the - // final case. - None - } - }) { - &prev.crypto.remote - } else { - // We're in the Data space with a key phase mismatch and either there is no locally - // initiated key update or the locally initiated key update was acknowledged by a - // lower-numbered packet. The key phase mismatch must therefore represent a new - // remotely-initiated key update. - crypto_update = true; - &self.next_crypto.as_ref().unwrap().remote - }; - - crypto - .decrypt(number, &packet.header_data, &mut packet.payload) - .map_err(|_| { - trace!("decryption failed with packet number {}", number); - None - })?; + let result = packet_crypto::decrypt_packet_body( + packet, + &self.spaces, + self.zero_rtt_crypto.as_ref(), + self.key_phase, + self.prev_crypto.as_ref(), + self.next_crypto.as_ref(), + )?; - if let Some(ref mut prev) = self.prev_crypto { - if prev.end_packet.is_none() && key_phase == self.key_phase { - // Outgoing key update newly acknowledged - prev.end_packet = Some((number, now)); - self.set_key_discard_timer(now, space); + let pn = result.map(|d| { + if d.outgoing_key_update_acked { + if let Some(prev) = self.prev_crypto.as_mut() { + prev.end_packet = Some((d.number, now)); + self.set_key_discard_timer(now, packet.header.space()); + } } - } - if !packet.reserved_bits_valid() { - return Err(Some(TransportError::PROTOCOL_VIOLATION( - "reserved bits set", - ))); - } - - if crypto_update { - // Validate and commit incoming key update - if number <= rx_packet - || self - .prev_crypto - .as_ref() - .map_or(false, |x| x.update_unacked) - { - return Err(Some(TransportError::KEY_UPDATE_ERROR(""))); + if d.incoming_key_update { + trace!("key update authenticated"); + self.update_keys(Some((d.number, now)), true); + self.set_key_discard_timer(now, packet.header.space()); } - trace!("key update authenticated"); - self.update_keys(Some((number, now)), true); - self.set_key_discard_timer(now, space); - } - Ok(Some(number)) + d.number + }); + + Ok(pn) } fn update_keys(&mut self, end_packet: Option<(u64, Instant)>, remote: bool) { @@ -3565,21 +3520,6 @@ mod state { } } -struct PrevCrypto { - /// The keys used for the previous key phase, temporarily retained to decrypt packets sent by - /// the peer prior to its own key update. - crypto: KeyPair>, - /// The incoming packet that ends the interval for which these keys are applicable, and the time - /// of its receipt. - /// - /// Incoming packets should be decrypted using these keys iff this is `None` or their packet - /// number is lower. `None` indicates that we have not yet received a packet using newer keys, - /// which implies that the update was locally initiated. - end_packet: Option<(u64, Instant)>, - /// Whether the following key phase is from a remotely initiated update that we haven't acked - update_unacked: bool, -} - struct InFlight { /// Sum of the sizes of all sent packets considered "in flight" by congestion control /// diff --git a/quinn-proto/src/connection/packet_crypto.rs b/quinn-proto/src/connection/packet_crypto.rs index bb8ea384d0..7c248f1ea7 100644 --- a/quinn-proto/src/connection/packet_crypto.rs +++ b/quinn-proto/src/connection/packet_crypto.rs @@ -1,10 +1,11 @@ +use std::time::Instant; use tracing::{debug, trace}; use crate::connection::spaces::PacketSpace; -use crate::crypto::{HeaderKey, PacketKey}; -use crate::packet::{Packet, PartialDecode}; +use crate::crypto::{HeaderKey, KeyPair, PacketKey}; +use crate::packet::{Packet, PartialDecode, SpaceId}; use crate::token::ResetToken; -use crate::RESET_TOKEN_SIZE; +use crate::{TransportError, RESET_TOKEN_SIZE}; pub(super) struct UnprotectHeaderResult { /// The packet with the now unprotected header (`None` in the case of stateless reset packets @@ -64,6 +65,108 @@ pub(super) fn unprotect_header( } } +pub(super) struct DecryptPacketResult { + /// The packet number + pub(super) number: u64, + /// Whether a locally initiated key update has been acknowledged by the peer + pub(super) outgoing_key_update_acked: bool, + /// Whether the peer has initiated a key update + pub(super) incoming_key_update: bool, +} + +/// Decrypts a packet's body in-place +pub(super) fn decrypt_packet_body( + packet: &mut Packet, + spaces: &[PacketSpace; 3], + zero_rtt_crypto: Option<&ZeroRttCrypto>, + conn_key_phase: bool, + prev_crypto: Option<&PrevCrypto>, + next_crypto: Option<&KeyPair>>, +) -> Result, Option> { + if !packet.header.is_protected() { + // Unprotected packets also don't have packet numbers + return Ok(None); + } + let space = packet.header.space(); + let rx_packet = spaces[space].rx_packet; + let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1); + let packet_key_phase = packet.header.key_phase(); + + let mut crypto_update = false; + let crypto = if packet.header.is_0rtt() { + &zero_rtt_crypto.unwrap().packet + } else if packet_key_phase == conn_key_phase || space != SpaceId::Data { + &spaces[space].crypto.as_ref().unwrap().packet.remote + } else if let Some(prev) = prev_crypto.and_then(|crypto| { + // If this packet comes prior to acknowledgment of the key update by the peer, + if crypto.end_packet.map_or(true, |(pn, _)| number < pn) { + // use the previous keys. + Some(crypto) + } else { + // Otherwise, this must be a remotely-initiated key update, so fall through to the + // final case. + None + } + }) { + &prev.crypto.remote + } else { + // We're in the Data space with a key phase mismatch and either there is no locally + // initiated key update or the locally initiated key update was acknowledged by a + // lower-numbered packet. The key phase mismatch must therefore represent a new + // remotely-initiated key update. + crypto_update = true; + &next_crypto.unwrap().remote + }; + + crypto + .decrypt(number, &packet.header_data, &mut packet.payload) + .map_err(|_| { + trace!("decryption failed with packet number {}", number); + None + })?; + + if !packet.reserved_bits_valid() { + return Err(Some(TransportError::PROTOCOL_VIOLATION( + "reserved bits set", + ))); + } + + let mut outgoing_key_update_acked = false; + if let Some(prev) = prev_crypto { + if prev.end_packet.is_none() && packet_key_phase == conn_key_phase { + outgoing_key_update_acked = true; + } + } + + if crypto_update { + // Validate incoming key update + if number <= rx_packet || prev_crypto.map_or(false, |x| x.update_unacked) { + return Err(Some(TransportError::KEY_UPDATE_ERROR(""))); + } + } + + Ok(Some(DecryptPacketResult { + number, + outgoing_key_update_acked, + incoming_key_update: crypto_update, + })) +} + +pub(super) struct PrevCrypto { + /// The keys used for the previous key phase, temporarily retained to decrypt packets sent by + /// the peer prior to its own key update. + pub(super) crypto: KeyPair>, + /// The incoming packet that ends the interval for which these keys are applicable, and the time + /// of its receipt. + /// + /// Incoming packets should be decrypted using these keys iff this is `None` or their packet + /// number is lower. `None` indicates that we have not yet received a packet using newer keys, + /// which implies that the update was locally initiated. + pub(super) end_packet: Option<(u64, Instant)>, + /// Whether the following key phase is from a remotely initiated update that we haven't acked + pub(super) update_unacked: bool, +} + pub(super) struct ZeroRttCrypto { pub(super) header: Box, pub(super) packet: Box,