diff --git a/wtx-fuzz/web_socket.rs b/wtx-fuzz/web_socket.rs index 0173ede0..8a5c8632 100644 --- a/wtx-fuzz/web_socket.rs +++ b/wtx-fuzz/web_socket.rs @@ -13,6 +13,7 @@ libfuzzer_sys::fuzz_target!(|data: (OpCode, Vec)| { Builder::new_current_thread().enable_all().build().unwrap().block_on(async move { let Ok(mut ws) = WebSocketServerOwned::new( (), + false, Xorshift64::from(simple_seed()), BytesStream::default(), WebSocketBuffer::default(), diff --git a/wtx-instances/http2-examples/http2-server.rs b/wtx-instances/http2-examples/http2-server.rs index b12c758c..f6d2b2ac 100644 --- a/wtx-instances/http2-examples/http2-server.rs +++ b/wtx-instances/http2-examples/http2-server.rs @@ -74,5 +74,6 @@ async fn manual( } wos.write_frame(&mut Frame::new_fin(OpCode::Text, frame.payload_mut())).await?; } + wos.close().await?; Ok(()) } diff --git a/wtx-instances/http2-examples/http2-web-socket.rs b/wtx-instances/http2-examples/http2-web-socket.rs index 3460f79b..4af185e0 100644 --- a/wtx-instances/http2-examples/http2-web-socket.rs +++ b/wtx-instances/http2-examples/http2-web-socket.rs @@ -56,6 +56,7 @@ async fn main() -> wtx::Result<()> { } wos.write_frame(&mut Frame::new_fin(OpCode::Text, frame.payload_mut())).await?; } + wos.close().await?; stream.common().clear(false).await?; Ok(()) } diff --git a/wtx/src/http2/web_socket_over_stream.rs b/wtx/src/http2/web_socket_over_stream.rs index 2085dde2..06773a99 100644 --- a/wtx/src/http2/web_socket_over_stream.rs +++ b/wtx/src/http2/web_socket_over_stream.rs @@ -2,7 +2,7 @@ use crate::{ http::{Headers, KnownHeaderName, Method, Protocol, StatusCode}, - http2::{Http2Buffer, Http2Data, Http2RecvStatus, SendDataMode, ServerStream}, + http2::{Http2Buffer, Http2Data, Http2ErrorCode, Http2RecvStatus, SendDataMode, ServerStream}, misc::{ ConnectionState, LeaseMut, Lock, RefCounter, SingleTypeStorage, StreamWriter, Vector, Xorshift64, @@ -14,7 +14,8 @@ use crate::{ manage_text_of_first_continuation_frame, manage_text_of_recurrent_continuation_frames, unmask_nb, }, - Frame, FrameMut, ReadFrameInfo, + web_socket_writer::manage_normal_frame, + Frame, FrameMut, OpCode, ReadFrameInfo, }, }; @@ -63,6 +64,14 @@ where Ok(Self { connection_state: ConnectionState::Open, no_masking, rng, stream }) } + /// Closes the stream as well as the WebSocket connection. + #[inline] + pub async fn close(&mut self) -> crate::Result<()> { + self.write_frame(&mut Frame::new_fin(OpCode::Close, &mut [])).await?; + self.stream.lease_mut().common().send_reset(Http2ErrorCode::NoError).await; + Ok(()) + } + /// Reads a frame from the stream. /// /// If a frame is made up of other sub-frames or continuations, then everything is collected @@ -141,6 +150,12 @@ where where P: LeaseMut<[u8]>, { + manage_normal_frame::<_, _, false>( + &mut self.connection_state, + frame, + self.no_masking, + &mut self.rng, + ); let (header, payload) = frame.header_and_payload(); let hss = self .stream diff --git a/wtx/src/web_socket.rs b/wtx/src/web_socket.rs index 24d7d176..69c896e7 100644 --- a/wtx/src/web_socket.rs +++ b/wtx/src/web_socket.rs @@ -15,7 +15,7 @@ mod web_socket_buffer; mod web_socket_error; mod web_socket_parts; pub(crate) mod web_socket_reader; -mod web_socket_writer; +pub(crate) mod web_socket_writer; use crate::{ misc::{ConnectionState, LeaseMut, Stream, Xorshift64}, @@ -36,8 +36,15 @@ pub use web_socket_buffer::WebSocketBuffer; pub use web_socket_error::WebSocketError; pub use web_socket_parts::{WebSocketCommonPart, WebSocketReaderPart, WebSocketWriterPart}; +const FIN_MASK: u8 = 0b1000_0000; +const MASK_MASK: u8 = 0b1000_0000; const MAX_CONTROL_PAYLOAD_LEN: usize = 125; const MAX_HEADER_LEN_USIZE: usize = 14; +const OP_CODE_MASK: u8 = 0b0000_1111; +const PAYLOAD_MASK: u8 = 0b0111_1111; +const RSV1_MASK: u8 = 0b0100_0000; +const RSV2_MASK: u8 = 0b0010_0000; +const RSV3_MASK: u8 = 0b0001_0000; /// Always masks the payload before sending. pub type WebSocketClient = WebSocket; diff --git a/wtx/src/web_socket/frame.rs b/wtx/src/web_socket/frame.rs index 192f0413..70c8020a 100644 --- a/wtx/src/web_socket/frame.rs +++ b/wtx/src/web_socket/frame.rs @@ -1,7 +1,8 @@ use crate::{ misc::{Lease, Vector}, web_socket::{ - misc::fill_header_from_params, OpCode, MAX_CONTROL_PAYLOAD_LEN, MAX_HEADER_LEN_USIZE, + misc::{fill_header_from_params, has_masked_frame}, + OpCode, MASK_MASK, MAX_CONTROL_PAYLOAD_LEN, MAX_HEADER_LEN_USIZE, }, }; use core::str; @@ -76,16 +77,26 @@ impl Frame { (header, &mut self.payload) } - #[inline] - pub(crate) fn header_mut(&mut self) -> &mut [u8] { - self.header_and_payload_mut().0 - } - #[inline] pub(crate) fn header_first_two_mut(&mut self) -> [&mut u8; 2] { let [a, b, ..] = &mut self.header; [a, b] } + + #[inline] + pub(crate) fn set_mask(&mut self, mask: [u8; 4]) { + if has_masked_frame(self.header[1]) { + return; + } + self.header_len = self.header_len.wrapping_add(4); + if let Some([_, a, .., b, c, d, e]) = self.header.get_mut(..self.header_len.into()) { + *a |= MASK_MASK; + *b = mask[0]; + *c = mask[1]; + *d = mask[2]; + *e = mask[3]; + } + } } impl Frame diff --git a/wtx/src/web_socket/handshake/tests.rs b/wtx/src/web_socket/handshake/tests.rs index 7a4ca2f8..184d8d0c 100644 --- a/wtx/src/web_socket/handshake/tests.rs +++ b/wtx/src/web_socket/handshake/tests.rs @@ -26,26 +26,35 @@ static HAS_SERVER_FINISHED: AtomicBool = AtomicBool::new(false); #[cfg(feature = "flate2")] #[tokio::test] -async fn client_and_server_compressed() { +async fn compressed() { use crate::web_socket::compression::Flate2; #[cfg(feature = "_tracing-tree")] let _rslt = crate::misc::tracing_tree_init(None); - do_test_client_and_server_frames((), Flate2::default()).await; + do_test_client_and_server_frames(((), false), (Flate2::default(), false)).await; tokio::time::sleep(Duration::from_millis(200)).await; - do_test_client_and_server_frames(Flate2::default(), ()).await; + do_test_client_and_server_frames((Flate2::default(), false), ((), false)).await; tokio::time::sleep(Duration::from_millis(200)).await; - do_test_client_and_server_frames(Flate2::default(), Flate2::default()).await; + do_test_client_and_server_frames((Flate2::default(), false), (Flate2::default(), false)).await; } #[tokio::test] -async fn client_and_server_uncompressed() { +async fn uncompressed() { #[cfg(feature = "_tracing-tree")] let _rslt = crate::misc::tracing_tree_init(None); - do_test_client_and_server_frames((), ()).await; + do_test_client_and_server_frames(((), false), ((), false)).await; } -async fn do_test_client_and_server_frames(client_compression: CC, server_compression: SC) -where +#[tokio::test] +async fn uncompressed_no_masking() { + #[cfg(feature = "_tracing-tree")] + let _rslt = crate::misc::tracing_tree_init(None); + do_test_client_and_server_frames(((), true), ((), true)).await; +} + +async fn do_test_client_and_server_frames( + (client_compression, client_no_masking): (CC, bool), + (server_compression, server_no_masking): (SC, bool), +) where CC: Compression + Send, CC::NegotiatedCompression: Send, SC: Compression + Send + 'static, @@ -59,7 +68,7 @@ where let (stream, _) = listener.accept().await.unwrap(); let mut ws = WebSocketServer::accept( server_compression, - false, + server_no_masking, Xorshift64::from(simple_seed()), stream, WebSocketBuffer::new(), @@ -84,7 +93,7 @@ where let mut ws = WebSocketClient::connect( client_compression, [], - false, + client_no_masking, Xorshift64::from(simple_seed()), TcpStream::connect(uri.hostname_with_implied_port()).await.unwrap(), &uri.to_ref(), @@ -151,23 +160,13 @@ where let hello = ws.read_frame().await.unwrap(); assert_eq!(OpCode::Text, hello.op_code()); assert_eq!(b"Hello!", hello.payload()); - ws.write_frame(&mut Frame::new_fin( - OpCode::Text, - &mut [b'G', b'o', b'o', b'd', b'b', b'y', b'e', b'!'], - )) - .await - .unwrap(); + ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"Goodbye!")).await.unwrap(); assert_eq!(OpCode::Close, ws.read_frame().await.unwrap().op_code()); } async fn server(ws: &mut WebSocketServerOwned) { - ws.write_frame(&mut Frame::new_fin(OpCode::Text, &mut [b'H', b'e', b'l', b'l', b'o', b'!'])) - .await - .unwrap(); - assert_eq!( - ws.read_frame().await.unwrap().payload(), - &mut [b'G', b'o', b'o', b'd', b'b', b'y', b'e', b'!'] - ); + ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"Hello!")).await.unwrap(); + assert_eq!(ws.read_frame().await.unwrap().payload(), b"Goodbye!"); ws.write_frame(&mut Frame::new_fin(OpCode::Close, &mut [])).await.unwrap(); } } @@ -205,7 +204,7 @@ where { async fn client(ws: &mut WebSocketClientOwned) { ws.write_frame(&mut Frame::new_fin(OpCode::Ping, &mut [1, 2, 3])).await.unwrap(); - ws.write_frame(&mut Frame::new_fin(OpCode::Text, &mut [b'i', b'p', b'a', b't'])).await.unwrap(); + ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"ipat")).await.unwrap(); assert_eq!(OpCode::Pong, ws.read_frame().await.unwrap().op_code()); } diff --git a/wtx/src/web_socket/misc.rs b/wtx/src/web_socket/misc.rs index ddbc5009..9babfd30 100644 --- a/wtx/src/web_socket/misc.rs +++ b/wtx/src/web_socket/misc.rs @@ -1,4 +1,4 @@ -use crate::web_socket::{CloseCode, OpCode, MAX_HEADER_LEN_USIZE}; +use crate::web_socket::{CloseCode, OpCode, MASK_MASK, MAX_HEADER_LEN_USIZE, OP_CODE_MASK}; use core::ops::Range; /// The first two bytes of `payload` are filled with `code`. Does nothing if `payload` is @@ -26,43 +26,26 @@ pub(crate) fn fill_header_from_params( u8::from(fin) << 7 | rsv1 | u8::from(op_code) } - #[inline] - fn manage_mask( - second_byte: &mut u8, - [a, b, c, d]: [&mut u8; 4], - ) -> u8 { - if IS_CLIENT { - *second_byte &= 0b0111_1111; - *a = 0; - *b = 0; - *c = 0; - *d = 0; - N.wrapping_add(4) - } else { - N - } - } - match payload_len { 0..=125 => { - let [a, b, c, d, e, f, ..] = header; + let [a, b, ..] = header; *a = first_header_byte(fin, op_code, rsv1); *b = u8::try_from(payload_len).unwrap_or_default(); - manage_mask::(b, [c, d, e, f]) + 2 } 126..=0xFFFF => { let [len_c, len_d] = u16::try_from(payload_len).map(u16::to_be_bytes).unwrap_or_default(); - let [a, b, c, d, e, f, g, h, ..] = header; + let [a, b, c, d, ..] = header; *a = first_header_byte(fin, op_code, rsv1); *b = 126; *c = len_c; *d = len_d; - manage_mask::(b, [e, f, g, h]) + 4 } _ => { let len = u64::try_from(payload_len).map(u64::to_be_bytes).unwrap_or_default(); let [len_c, len_d, len_e, len_f, len_g, len_h, len_i, len_j] = len; - let [a, b, c, d, e, f, g, h, i, j, k, l, m, n] = header; + let [a, b, c, d, e, f, g, h, i, j, ..] = header; *a = first_header_byte(fin, op_code, rsv1); *b = 127; *c = len_c; @@ -73,14 +56,19 @@ pub(crate) fn fill_header_from_params( *h = len_h; *i = len_i; *j = len_j; - manage_mask::(b, [k, l, m, n]) + 10 } } } +#[inline] +pub(crate) const fn has_masked_frame(second_header_byte: u8) -> bool { + second_header_byte & MASK_MASK != 0 +} + #[inline] pub(crate) fn op_code(first_header_byte: u8) -> crate::Result { - OpCode::try_from(first_header_byte & 0b0000_1111) + OpCode::try_from(first_header_byte & OP_CODE_MASK) } #[inline] diff --git a/wtx/src/web_socket/read_frame_info.rs b/wtx/src/web_socket/read_frame_info.rs index 333834ca..c5283adf 100644 --- a/wtx/src/web_socket/read_frame_info.rs +++ b/wtx/src/web_socket/read_frame_info.rs @@ -1,8 +1,10 @@ use crate::{ misc::{PartitionedFilledBuffer, Stream, _read_until}, web_socket::{ - compression::NegotiatedCompression, misc::op_code, OpCode, WebSocketError, - MAX_CONTROL_PAYLOAD_LEN, + compression::NegotiatedCompression, + misc::{has_masked_frame, op_code}, + OpCode, WebSocketError, FIN_MASK, MAX_CONTROL_PAYLOAD_LEN, PAYLOAD_MASK, RSV1_MASK, RSV2_MASK, + RSV3_MASK, }, }; @@ -135,9 +137,9 @@ impl ReadFrameInfo { where NC: NegotiatedCompression, { - let rsv1 = a & 0b0100_0000; - let rsv2 = a & 0b0010_0000; - let rsv3 = a & 0b0001_0000; + let rsv1 = a & RSV1_MASK; + let rsv2 = a & RSV2_MASK; + let rsv3 = a & RSV3_MASK; if rsv2 != 0 || rsv3 != 0 { return Err(WebSocketError::InvalidCompressionHeaderParameter.into()); } @@ -151,9 +153,9 @@ impl ReadFrameInfo { } else { rsv1 != 0 }; - let fin = a & 0b1000_0000 != 0; - let length_code = b & 0b0111_1111; - let masked = b & 0b1000_0000 != 0; + let fin = a & FIN_MASK != 0; + let length_code = b & PAYLOAD_MASK; + let masked = has_masked_frame(b); let op_code = op_code(a)?; Ok((fin, length_code, masked, op_code, should_decompress)) } diff --git a/wtx/src/web_socket/web_socket_writer.rs b/wtx/src/web_socket/web_socket_writer.rs index b12a1e10..65eb7c44 100644 --- a/wtx/src/web_socket/web_socket_writer.rs +++ b/wtx/src/web_socket/web_socket_writer.rs @@ -1,6 +1,9 @@ use crate::{ misc::{BufferMode, ConnectionState, Lease, LeaseMut, Rng, Stream, Vector, Xorshift64}, - web_socket::{compression::NegotiatedCompression, unmask::unmask, Frame, FrameMut, OpCode}, + web_socket::{ + compression::NegotiatedCompression, misc::has_masked_frame, unmask::unmask, Frame, FrameMut, + OpCode, + }, }; #[inline] @@ -122,11 +125,6 @@ where )) } -#[inline] -const fn has_masked_frame(second_header_byte: u8) -> bool { - second_header_byte & 0b1000_0000 != 0 -} - #[inline] fn mask_frame( frame: &mut Frame, @@ -136,17 +134,9 @@ fn mask_frame( P: LeaseMut<[u8]>, RNG: Rng, { - if IS_CLIENT && !no_masking { - if let [_, second_byte, .., a, b, c, d] = frame.header_mut() { - if !has_masked_frame(*second_byte) { - *second_byte |= 0b1000_0000; - let mask = rng.u8_4(); - *a = mask[0]; - *b = mask[1]; - *c = mask[2]; - *d = mask[3]; - unmask(frame.payload_mut().lease_mut(), mask); - } - } + if IS_CLIENT && !no_masking && !has_masked_frame(*frame.header_first_two_mut()[1]) { + let mask: [u8; 4] = rng.u8_4(); + frame.set_mask(mask); + unmask(frame.payload_mut().lease_mut(), mask); } }