Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: don't fail when finish is called on a stopped stream #1699

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions quinn-proto/src/connection/streams/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,36 @@ impl Send {
matches!(self.state, SendState::ResetSent { .. })
}

/// Initiate a graceful close of the stream, ensuring all sent data has been acknowledged by the remote.
pub(super) fn finish(&mut self) -> Result<(), FinishError> {
if let Some(error_code) = self.stop_reason {
Err(FinishError::Stopped(error_code))
} else if self.state == SendState::Ready {
tracing::debug!(%error_code, "Stream is already stopped");

if !self.pending.is_fully_acked() {
// Remote stopped the stream before ack-ing all sent data.
// Return error to indicate that a graceful close failed.
return Err(FinishError::Stopped(error_code));
}

// Remote stopped the stream but acked all sent data before doing so.
// We pretend that the remote acked our `FIN`.
// Actually trying to send it would fail because the remote has stopped the stream already.

self.state = SendState::DataSent { finish_acked: true };

return Ok(());
}

if self.state == SendState::Ready {
self.state = SendState::DataSent {
finish_acked: false,
};
self.fin_pending = true;
Ok(())
} else {
Err(FinishError::UnknownStream)

return Ok(());
}

Err(FinishError::UnknownStream)
}

pub(super) fn write<S: BytesSource>(
Expand Down
21 changes: 14 additions & 7 deletions quinn-proto/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ fn stop_stream() {
);
assert_matches!(
pair.client_send(client_ch, s).finish(),
Err(FinishError::Stopped(ERROR))
Ok(()) // No unacknowledged data, hence `finish` succeeds.
);
}

Expand Down Expand Up @@ -1519,18 +1519,25 @@ fn stop_before_finish() {
let (client_ch, server_ch) = pair.connect();

let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap();
const MSG: &[u8] = b"hello";
pair.client_send(client_ch, s).write(MSG).unwrap();
pair.drive();

info!("stopping stream");
const ERROR: VarInt = VarInt(42);

// 1. Stop the stream (but don't send anything yet)
pair.server_recv(server_ch, s).stop(ERROR).unwrap();
pair.drive();

// 2. Queue data to be sent (client has no idea server has stopped the stream already)
pair.client_send(client_ch, s).write(b"hello1").unwrap();

// 3. Transmit `STOP_SENDING` to client
pair.drive_server();

// 4. Actually send data to server, now waiting for ACKs
pair.drive_client();

// 5. Finish the stream
assert_matches!(
pair.client_send(client_ch, s).finish(),
Err(FinishError::Stopped(ERROR))
Err(FinishError::Stopped(ERROR)) // Fails because we have unacknowledged data.
);
}

Expand Down
47 changes: 46 additions & 1 deletion quinn/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ async fn echo((mut send, mut recv): (SendStream, RecvStream)) {
loop {
// These are 32 buffers, for reading approximately 32kB at once
#[rustfmt::skip]
let mut bufs = [
let mut bufs = [
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Expand Down Expand Up @@ -744,3 +744,48 @@ async fn two_datagram_readers() {
assert!(*a == *b"one" || *b == *b"one");
assert!(*a == *b"two" || *b == *b"two");
}

#[tokio::test]
async fn finish_finished_stream_no_error() {
let _guard = subscribe();
let endpoint = endpoint();

let (client, server) = tokio::join!(
async {
endpoint
.connect(endpoint.local_addr().unwrap(), "localhost")
.unwrap()
.await
.unwrap()
},
async { endpoint.accept().await.unwrap().await.unwrap() }
);

let client = async {
let (mut send_stream, mut recv_stream) = client.open_bi().await.unwrap();
send_stream.write_all(b"request").await.unwrap();

let mut buf = [0u8; 8];
recv_stream.read_exact(&mut buf).await.unwrap();

assert_eq!(&buf, b"response");

tokio::time::sleep(Duration::from_millis(100)).await; // Simulate some more processing of the response.

send_stream.finish().await.unwrap(); // Be a good citizen and close stream instead of dropping.
};

let server = async {
let (mut send_stream, mut recv_stream) = server.accept_bi().await.unwrap();

let mut buf = [0u8; 7];
recv_stream.read_exact(&mut buf).await.unwrap();

assert_eq!(&buf, b"request");

send_stream.write_all(b"response").await.unwrap();
send_stream.finish().await.unwrap();
};

tokio::join!(client, server);
}
Loading