From e1e1a0000a3c112a118a3235f3c216eb9e741530 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 20 Sep 2023 16:32:30 +0200 Subject: [PATCH] More principled error handling for invalid frames --- quinn-proto/src/connection/mod.rs | 23 ++++++++++------------ quinn-proto/src/connection/stats.rs | 1 - quinn-proto/src/frame.rs | 30 +++++++++++++++++++++-------- quinn-proto/src/tests/mod.rs | 5 +++-- 4 files changed, 35 insertions(+), 24 deletions(-) diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index c42cbebbd8..d785eb4661 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -2192,7 +2192,12 @@ impl Connection { return Ok(()); } State::Closed(_) => { - for frame in frame::Iter::new(packet.payload.freeze()) { + for result in frame::Iter::new(packet.payload.freeze()) { + let Ok(frame) = result else { + trace!("invalid frame"); + continue; + }; + if let Frame::Padding = frame { continue; }; @@ -2433,7 +2438,8 @@ impl Connection { debug_assert_ne!(packet.header.space(), SpaceId::Data); let payload_len = packet.payload.len(); let mut ack_eliciting = false; - for frame in frame::Iter::new(packet.payload.freeze()) { + for result in frame::Iter::new(packet.payload.freeze()) { + let frame = result?; let span = match frame { Frame::Padding => continue, _ => Some(trace_span!("frame", ty = %frame.ty())), @@ -2458,11 +2464,6 @@ impl Connection { self.state = State::Draining; return Ok(()); } - Frame::Invalid { ty, reason } => { - let mut err = TransportError::FRAME_ENCODING_ERROR(reason); - err.frame = Some(ty); - return Err(err); - } _ => { let mut err = TransportError::PROTOCOL_VIOLATION("illegal frame type in handshake"); @@ -2495,7 +2496,8 @@ impl Connection { let mut close = None; let payload_len = payload.len(); let mut ack_eliciting = false; - for frame in frame::Iter::new(payload) { + for result in frame::Iter::new(payload) { + let frame = result?; let span = match frame { Frame::Padding => continue, _ => Some(trace_span!("frame", ty = %frame.ty())), @@ -2543,11 +2545,6 @@ impl Connection { } } match frame { - Frame::Invalid { ty, reason } => { - let mut err = TransportError::FRAME_ENCODING_ERROR(reason); - err.frame = Some(ty); - return Err(err); - } Frame::Crypto(frame) => { self.read_crypto(SpaceId::Data, &frame, payload_len)?; } diff --git a/quinn-proto/src/connection/stats.rs b/quinn-proto/src/connection/stats.rs index c04b7dbe65..8ec23e80e4 100644 --- a/quinn-proto/src/connection/stats.rs +++ b/quinn-proto/src/connection/stats.rs @@ -86,7 +86,6 @@ impl FrameStats { Frame::AckFrequency(_) => self.ack_frequency += 1, Frame::ImmediateAck => self.immediate_ack += 1, Frame::HandshakeDone => self.handshake_done += 1, - Frame::Invalid { .. } => {} } } } diff --git a/quinn-proto/src/frame.rs b/quinn-proto/src/frame.rs index 8c742032f2..cb8d6ac1ee 100644 --- a/quinn-proto/src/frame.rs +++ b/quinn-proto/src/frame.rs @@ -162,7 +162,6 @@ pub(crate) enum Frame { Datagram(Datagram), AckFrequency(AckFrequency), ImmediateAck, - Invalid { ty: Type, reason: &'static str }, HandshakeDone, } @@ -204,7 +203,6 @@ impl Frame { Datagram(_) => Type(*DATAGRAM_TYS.start()), AckFrequency(_) => Type::ACK_FREQUENCY, ImmediateAck => Type::IMMEDIATE_ACK, - Invalid { ty, .. } => ty, HandshakeDone => Type::HANDSHAKE_DONE, } } @@ -734,25 +732,39 @@ impl Iter { } impl Iterator for Iter { - type Item = Frame; + type Item = Result; fn next(&mut self) -> Option { if !self.bytes.has_remaining() { return None; } match self.try_next() { - Ok(x) => Some(x), + Ok(x) => Some(Ok(x)), Err(e) => { // Corrupt frame, skip it and everything that follows self.bytes = io::Cursor::new(Bytes::new()); - Some(Frame::Invalid { - ty: self.last_ty.unwrap(), + Some(Err(InvalidFrame { + ty: self.last_ty, reason: e.reason(), - }) + })) } } } } +#[derive(Debug)] +pub(crate) struct InvalidFrame { + pub(crate) ty: Option, + pub(crate) reason: &'static str, +} + +impl From for TransportError { + fn from(err: InvalidFrame) -> Self { + let mut te = TransportError::FRAME_ENCODING_ERROR(err.reason); + te.frame = err.ty; + te + } +} + fn scan_ack_blocks(buf: &mut io::Cursor, largest: u64, n: usize) -> Result<(), IterErr> { let first_block = buf.get_var()?; let mut smallest = largest.checked_sub(first_block).ok_or(IterErr::Malformed)?; @@ -910,7 +922,9 @@ mod test { use assert_matches::assert_matches; fn frames(buf: Vec) -> Vec { - Iter::new(Bytes::from(buf)).collect::>() + Iter::new(Bytes::from(buf)) + .collect::, _>>() + .unwrap() } #[test] diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index ee963d633a..6f69e3303c 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -2256,8 +2256,9 @@ fn single_ack_eliciting_packet_triggers_ack_after_delay() { // The ACK delay is properly calculated assert_eq!(pair.client.captured_packets.len(), 1); - let mut frames = - frame::Iter::new(pair.client.captured_packets.remove(0).into()).collect::>(); + let mut frames = frame::Iter::new(pair.client.captured_packets.remove(0).into()) + .collect::, _>>() + .unwrap(); assert_eq!(frames.len(), 1); if let Frame::Ack(ack) = frames.remove(0) { let ack_delay_exp = TransportParameters::default().ack_delay_exponent;