Skip to content

Commit

Permalink
Flush writers before potentially expecting a response
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdedude authored and cpu committed Oct 24, 2023
1 parent 53adb9d commit ecc6cde
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 26 deletions.
2 changes: 2 additions & 0 deletions rustls/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ impl<Data> ConnectionCommon<Data> {
while self.wants_write() {
wrlen += self.write_tls(io)?;
}
io.flush()?;

if !until_handshaked && wrlen > 0 {
return Ok((rdlen, wrlen));
Expand Down Expand Up @@ -411,6 +412,7 @@ impl<Data> ConnectionCommon<Data> {
// try a last-gasp write -- but don't predate the primary
// error.
let _ignored = self.write_tls(io);
let _ignored = io.flush();

return Err(io::Error::new(io::ErrorKind::InvalidData, e));
}
Expand Down
121 changes: 95 additions & 26 deletions rustls/tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,8 @@ where
fail_ok: bool,
pub short_writes: bool,
pub last_error: Option<rustls::Error>,
pub buffered: bool,
buffer: Vec<Vec<u8>>,
}

impl<'a, C, S> OtherSession<'a, C, S>
Expand All @@ -1339,41 +1341,24 @@ where
fail_ok: false,
short_writes: false,
last_error: None,
buffered: false,
buffer: vec![],
}
}

fn new_fails(sess: &'a mut C) -> OtherSession<'a, C, S> {
fn new_buffered(sess: &'a mut C) -> OtherSession<'a, C, S> {
let mut os = OtherSession::new(sess);
os.fail_ok = true;
os.buffered = true;
os
}
}

impl<'a, C, S> io::Read for OtherSession<'a, C, S>
where
C: DerefMut + Deref<Target = ConnectionCommon<S>>,
S: SideData,
{
fn read(&mut self, mut b: &mut [u8]) -> io::Result<usize> {
self.reads += 1;
self.sess.write_tls(b.by_ref())
}
}

impl<'a, C, S> io::Write for OtherSession<'a, C, S>
where
C: DerefMut + Deref<Target = ConnectionCommon<S>>,
S: SideData,
{
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
unreachable!()
}

fn flush(&mut self) -> io::Result<()> {
Ok(())
fn new_fails(sess: &'a mut C) -> OtherSession<'a, C, S> {
let mut os = OtherSession::new(sess);
os.fail_ok = true;
os
}

fn write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result<usize> {
fn flush_vectored(&mut self, b: &[io::IoSlice<'_>]) -> io::Result<usize> {
let mut total = 0;
let mut lengths = vec![];
for bytes in b {
Expand Down Expand Up @@ -1409,6 +1394,48 @@ where
}
}

impl<'a, C, S> io::Read for OtherSession<'a, C, S>
where
C: DerefMut + Deref<Target = ConnectionCommon<S>>,
S: SideData,
{
fn read(&mut self, mut b: &mut [u8]) -> io::Result<usize> {
self.reads += 1;
self.sess.write_tls(b.by_ref())
}
}

impl<'a, C, S> io::Write for OtherSession<'a, C, S>
where
C: DerefMut + Deref<Target = ConnectionCommon<S>>,
S: SideData,
{
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
unreachable!()
}

fn flush(&mut self) -> io::Result<()> {
if !self.buffer.is_empty() {
let buffer = mem::take(&mut self.buffer);
let slices = buffer
.iter()
.map(|b| io::IoSlice::new(b))
.collect::<Vec<_>>();
self.flush_vectored(&slices)?;
}
Ok(())
}

fn write_vectored(&mut self, b: &[io::IoSlice<'_>]) -> io::Result<usize> {
if self.buffered {
self.buffer
.extend(b.iter().map(|s| s.to_vec()));
return Ok(b.iter().map(|s| s.len()).sum());
}
self.flush_vectored(b)
}
}

#[test]
fn server_read_returns_wouldblock_when_no_data() {
let (_, mut server) = make_pair(KeyType::Rsa);
Expand Down Expand Up @@ -1456,6 +1483,19 @@ fn client_complete_io_for_handshake() {
assert!(!client.wants_write());
}

#[test]
fn buffered_client_complete_io_for_handshake() {
let (mut client, mut server) = make_pair(KeyType::Rsa);

assert!(client.is_handshaking());
let (rdlen, wrlen) = client
.complete_io(&mut OtherSession::new_buffered(&mut server))
.unwrap();
assert!(rdlen > 0 && wrlen > 0);
assert!(!client.is_handshaking());
assert!(!client.wants_write());
}

#[test]
fn client_complete_io_for_handshake_eof() {
let (mut client, _) = make_pair(KeyType::Rsa);
Expand Down Expand Up @@ -1497,6 +1537,35 @@ fn client_complete_io_for_write() {
}
}

#[test]
fn buffered_client_complete_io_for_write() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, mut server) = make_pair(*kt);

do_handshake(&mut client, &mut server);

client
.writer()
.write_all(b"01234567890123456789")
.unwrap();
client
.writer()
.write_all(b"01234567890123456789")
.unwrap();
{
let mut pipe = OtherSession::new_buffered(&mut server);
let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap();
assert!(rdlen == 0 && wrlen > 0);
println!("{:?}", pipe.writevs);
assert_eq!(pipe.writevs, vec![vec![42, 42]]);
}
check_read(
&mut server.reader(),
b"0123456789012345678901234567890123456789",
);
}
}

#[test]
fn client_complete_io_for_read() {
for kt in ALL_KEY_TYPES.iter() {
Expand Down

0 comments on commit ecc6cde

Please sign in to comment.