diff --git a/quinn-proto/src/connection/streams/mod.rs b/quinn-proto/src/connection/streams/mod.rs index 69d6b8abc..3cc898ecd 100644 --- a/quinn-proto/src/connection/streams/mod.rs +++ b/quinn-proto/src/connection/streams/mod.rs @@ -1,6 +1,6 @@ use std::{ cell::RefCell, - collections::{hash_map, BinaryHeap, VecDeque}, + collections::{BinaryHeap, VecDeque}, }; use bytes::Bytes; @@ -129,11 +129,10 @@ impl<'a> RecvStream<'a> { /// Discards unread data and notifies the peer to stop transmitting. Once stopped, further /// attempts to operate on a stream will yield `UnknownStream` errors. pub fn stop(&mut self, error_code: VarInt) -> Result<(), UnknownStream> { - let mut entry = match self.state.recv.entry(self.id) { - hash_map::Entry::Occupied(s) => s, - hash_map::Entry::Vacant(_) => return Err(UnknownStream { _private: () }), + let stream = match self.state.recv.get_mut(&self.id).and_then(|s| s.as_mut()) { + Some(s) => s, + None => return Err(UnknownStream { _private: () }), }; - let stream = entry.get_mut(); let (read_credits, stop_sending) = stream.stop()?; if stop_sending.should_transmit() { @@ -147,7 +146,7 @@ impl<'a> RecvStream<'a> { // connection-level flow control to account for discarded data. Otherwise, we can discard // state immediately. if !stream.receiving_unknown_size() { - entry.remove(); + self.state.recv.remove(&self.id); self.state.stream_freed(self.id, StreamHalf::Recv); } @@ -211,6 +210,7 @@ impl<'a> SendStream<'a> { .state .send .get_mut(&self.id) + .and_then(|s| s.as_mut()) .ok_or(WriteError::UnknownStream)?; if limit == 0 { trace!( @@ -237,7 +237,7 @@ impl<'a> SendStream<'a> { /// Check if this stream was stopped, get the reason if it was pub fn stopped(&mut self) -> Result, UnknownStream> { - match self.state.send.get(&self.id) { + match self.state.send.get(&self.id).and_then(|s| s.as_ref()) { Some(s) => Ok(s.stop_reason), None => Err(UnknownStream { _private: () }), } @@ -253,6 +253,7 @@ impl<'a> SendStream<'a> { .state .send .get_mut(&self.id) + .and_then(|s| s.as_mut()) .ok_or(FinishError::UnknownStream)?; let was_pending = stream.is_pending(); @@ -273,6 +274,7 @@ impl<'a> SendStream<'a> { .state .send .get_mut(&self.id) + .and_then(|s| s.as_mut()) .ok_or(UnknownStream { _private: () })?; if matches!(stream.state, SendState::ResetSent) { @@ -300,6 +302,7 @@ impl<'a> SendStream<'a> { .state .send .get_mut(&self.id) + .and_then(|s| s.as_mut()) .ok_or(UnknownStream { _private: () })?; stream.priority = priority; @@ -315,6 +318,7 @@ impl<'a> SendStream<'a> { .state .send .get(&self.id) + .and_then(|s| s.as_ref()) .ok_or(UnknownStream { _private: () })?; Ok(stream.priority) diff --git a/quinn-proto/src/connection/streams/recv.rs b/quinn-proto/src/connection/streams/recv.rs index 627666b83..0ab114de4 100644 --- a/quinn-proto/src/connection/streams/recv.rs +++ b/quinn-proto/src/connection/streams/recv.rs @@ -220,9 +220,9 @@ impl<'a> Chunks<'a> { Entry::Vacant(_) => return Err(ReadableError::UnknownStream), }; - let mut recv = match entry.get().stopped { + let mut recv = match entry.get().as_ref().map(|s| s.stopped).unwrap_or(true) { true => return Err(ReadableError::UnknownStream), - false => entry.remove(), + false => entry.remove().unwrap(), // this can't fail at this point }; recv.assembler.ensure_ordering(ordered)?; @@ -313,7 +313,7 @@ impl<'a> Chunks<'a> { self.pending.max_stream_data.insert(self.id); } // Return the stream to storage for future use - self.streams.recv.insert(self.id, rs); + self.streams.recv.insert(self.id, Some(rs)); } // Issue connection-level flow control credit for any data we read regardless of state @@ -331,7 +331,7 @@ impl<'a> Drop for Chunks<'a> { } enum ChunksState { - Readable(Recv), + Readable(Box), Reset(VarInt), Finished, Finalized, diff --git a/quinn-proto/src/connection/streams/state.rs b/quinn-proto/src/connection/streams/state.rs index 207778f00..c6790e8fc 100644 --- a/quinn-proto/src/connection/streams/state.rs +++ b/quinn-proto/src/connection/streams/state.rs @@ -24,8 +24,8 @@ use crate::{ pub struct StreamsState { pub(super) side: Side, // Set of streams that are currently open, or could be immediately opened by the peer - pub(super) send: FxHashMap, - pub(super) recv: FxHashMap, + pub(super) send: FxHashMap>>, + pub(super) recv: FxHashMap>>, pub(super) next: [u64; 2], /// Maximum number of locally-initiated streams that may be opened over the lifetime of the /// connection so far, per direction @@ -152,8 +152,9 @@ impl StreamsState { self.received_max_data(params.initial_max_data); for i in 0..self.max_remote[Dir::Bi as usize] { let id = StreamId::new(!self.side, Dir::Bi, i); - self.send.get_mut(&id).unwrap().max_data = - params.initial_max_stream_data_bidi_local.into(); + if let Some(s) = self.send.get_mut(&id).and_then(|s| s.as_mut()) { + s.max_data = params.initial_max_stream_data_bidi_local.into(); + } } } @@ -205,13 +206,17 @@ impl StreamsState { frame: frame::Stream, payload_len: usize, ) -> Result { - let stream = frame.id; - self.validate_receive_id(stream).map_err(|e| { + let id = frame.id; + self.validate_receive_id(id).map_err(|e| { debug!("received illegal STREAM frame"); e })?; - let rs = match self.recv.get_mut(&stream) { + let rs = match self + .recv + .get_mut(&id) + .map(|s| s.get_or_insert_with(|| Box::new(Recv::new(self.stream_receive_window)))) + { Some(rs) => rs, None => { trace!("dropping frame for closed stream"); @@ -229,14 +234,14 @@ impl StreamsState { self.data_recvd = self.data_recvd.saturating_add(new_bytes); if !rs.stopped { - self.on_stream_frame(true, stream); + self.on_stream_frame(true, id); return Ok(ShouldTransmit(false)); } // Stopped streams become closed instantly on FIN, so check whether we need to clean up if closed { - self.recv.remove(&stream); - self.stream_freed(stream, StreamHalf::Recv); + self.recv.remove(&id); + self.stream_freed(id, StreamHalf::Recv); } // We don't buffer data on stopped streams, so issue flow control credit immediately @@ -261,7 +266,11 @@ impl StreamsState { e })?; - let rs = match self.recv.get_mut(&id) { + let rs = match self + .recv + .get_mut(&id) + .map(|s| s.get_or_insert_with(|| Box::new(Recv::new(self.stream_receive_window)))) + { Some(stream) => stream, None => { trace!("received RESET_STREAM on closed stream"); @@ -304,7 +313,7 @@ impl StreamsState { /// Process incoming `STOP_SENDING` frame #[allow(unreachable_pub)] // fuzzing only pub fn received_stop_sending(&mut self, id: StreamId, error_code: VarInt) { - let stream = match self.send.get_mut(&id) { + let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(ss) => ss, None => return, }; @@ -320,7 +329,7 @@ impl StreamsState { match self.send.entry(id) { hash_map::Entry::Vacant(_) => {} hash_map::Entry::Occupied(e) => { - if let SendState::ResetSent = e.get().state { + if let Some(SendState::ResetSent) = e.get().as_ref().map(|s| s.state) { e.remove_entry(); self.stream_freed(id, StreamHalf::Send); } @@ -332,11 +341,12 @@ impl StreamsState { pub(crate) fn can_send_stream_data(&self) -> bool { // Reset streams may linger in the pending stream list, but will never produce stream frames self.pending.iter().any(|level| { - level - .queue - .borrow() - .iter() - .any(|id| self.send.get(id).map_or(false, |s| !s.is_reset())) + level.queue.borrow().iter().any(|id| { + self.send + .get(id) + .and_then(|s| s.as_ref()) + .map_or(false, |s| !s.is_reset()) + }) }) } @@ -344,6 +354,7 @@ impl StreamsState { pub(crate) fn can_send_flow_control(&self, id: StreamId) -> bool { self.recv .get(&id) + .and_then(|s| s.as_ref()) .map_or(false, |s| s.receiving_unknown_size()) } @@ -361,7 +372,7 @@ impl StreamsState { Some(x) => x, None => break, }; - let stream = match self.send.get_mut(&id) { + let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(x) => x, None => continue, }; @@ -428,7 +439,7 @@ impl StreamsState { None => break, }; pending.max_stream_data.remove(&id); - let rs = match self.recv.get_mut(&id) { + let rs = match self.recv.get_mut(&id).and_then(|s| s.as_mut()) { Some(x) => x, None => continue, }; @@ -507,7 +518,7 @@ impl StreamsState { break; } }; - let stream = match self.send.get_mut(&id) { + let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(s) => s, // Stream was reset with pending data and the reset was acknowledged None => continue, @@ -589,11 +600,11 @@ impl StreamsState { } pub(crate) fn received_ack_of(&mut self, frame: frame::StreamMeta) { - let mut entry = match self.send.entry(frame.id) { - hash_map::Entry::Vacant(_) => return, - hash_map::Entry::Occupied(e) => e, + let stream = match self.send.get_mut(&frame.id).and_then(|s| s.as_mut()) { + None => return, + Some(s) => s, }; - let stream = entry.get_mut(); + if stream.is_reset() { // We account for outstanding data on reset streams at time of reset return; @@ -605,13 +616,13 @@ impl StreamsState { return; } - entry.remove_entry(); + self.send.remove(&id); self.stream_freed(id, StreamHalf::Send); self.events.push_back(StreamEvent::Finished { id }); } pub(crate) fn retransmit(&mut self, frame: frame::StreamMeta) { - let stream = match self.send.get_mut(&frame.id) { + let stream = match self.send.get_mut(&frame.id).and_then(|s| s.as_mut()) { // Loss of data on a closed stream is a noop None => return, Some(x) => x, @@ -627,7 +638,7 @@ impl StreamsState { for dir in Dir::iter() { for index in 0..self.next[dir as usize] { let id = StreamId::new(Side::Client, dir, index); - let stream = self.send.get_mut(&id).unwrap(); + let stream = self.send.get_mut(&id).and_then(|s| s.as_mut()).unwrap(); if stream.pending.is_fully_acked() && !stream.fin_pending { // Stream data can't be acked in 0-RTT, so we must not have sent anything on // this stream @@ -679,7 +690,7 @@ impl StreamsState { } let write_limit = self.write_limit(); - if let Some(ss) = self.send.get_mut(&id) { + if let Some(ss) = self.send.get_mut(&id).and_then(|s| s.as_mut()) { if ss.increase_max_data(offset) { if write_limit > 0 { self.events.push_back(StreamEvent::Writable { id }); @@ -716,7 +727,7 @@ impl StreamsState { if self.write_limit() > 0 { while let Some(id) = self.connection_blocked.pop() { - let stream = match self.send.get_mut(&id) { + let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { None => continue, Some(s) => s, }; @@ -799,22 +810,32 @@ impl StreamsState { pub(super) fn insert(&mut self, remote: bool, id: StreamId) { let bi = id.dir() == Dir::Bi; + // bidirectional OR (unidirectional AND NOT remote) if bi || !remote { - let max_data = match id.dir() { - Dir::Uni => self.initial_max_stream_data_uni, - // Remote/local appear reversed here because the transport parameters are named from - // the perspective of the peer. - Dir::Bi if remote => self.initial_max_stream_data_bidi_local, - Dir::Bi => self.initial_max_stream_data_bidi_remote, - }; - let stream = Send::new(max_data); - assert!(self.send.insert(id, stream).is_none()); + if remote { + assert!(self.send.insert(id, None).is_none()); + } else { + let max_data = match id.dir() { + Dir::Uni => self.initial_max_stream_data_uni, + // Remote/local appear reversed here because the transport parameters are named from + // the perspective of the peer. + Dir::Bi if remote => self.initial_max_stream_data_bidi_local, + Dir::Bi => self.initial_max_stream_data_bidi_remote, + }; + let stream = Send::new(max_data); + assert!(self.send.insert(id, Some(Box::new(stream))).is_none()); + } } + // bidirectional OR (unidirectional AND remote) if bi || remote { - assert!(self - .recv - .insert(id, Recv::new(self.stream_receive_window)) - .is_none()); + if remote { + assert!(self.recv.insert(id, None).is_none()); + } else { + assert!(self + .recv + .insert(id, Some(Box::new(Recv::new(self.stream_receive_window)))) + .is_none()); + } } }