Skip to content

Commit

Permalink
Fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
c410-f3r committed Oct 29, 2024
1 parent cae1705 commit f2d781f
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 84 deletions.
1 change: 1 addition & 0 deletions wtx-fuzz/web_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ libfuzzer_sys::fuzz_target!(|data: (OpCode, Vec<u8>)| {
Builder::new_current_thread().enable_all().build().unwrap().block_on(async move {
let Ok(mut ws) = WebSocketServerOwned::new(
(),
false,
Xorshift64::from(simple_seed()),
BytesStream::default(),
WebSocketBuffer::default(),
Expand Down
1 change: 1 addition & 0 deletions wtx-instances/http2-examples/http2-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,6 @@ async fn manual(
}
wos.write_frame(&mut Frame::new_fin(OpCode::Text, frame.payload_mut())).await?;
}
wos.close().await?;
Ok(())
}
1 change: 1 addition & 0 deletions wtx-instances/http2-examples/http2-web-socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async fn main() -> wtx::Result<()> {
}
wos.write_frame(&mut Frame::new_fin(OpCode::Text, frame.payload_mut())).await?;
}
wos.close().await?;
stream.common().clear(false).await?;
Ok(())
}
19 changes: 17 additions & 2 deletions wtx/src/http2/web_socket_over_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::{
http::{Headers, KnownHeaderName, Method, Protocol, StatusCode},
http2::{Http2Buffer, Http2Data, Http2RecvStatus, SendDataMode, ServerStream},
http2::{Http2Buffer, Http2Data, Http2ErrorCode, Http2RecvStatus, SendDataMode, ServerStream},
misc::{
ConnectionState, LeaseMut, Lock, RefCounter, SingleTypeStorage, StreamWriter, Vector,
Xorshift64,
Expand All @@ -14,7 +14,8 @@ use crate::{
manage_text_of_first_continuation_frame, manage_text_of_recurrent_continuation_frames,
unmask_nb,
},
Frame, FrameMut, ReadFrameInfo,
web_socket_writer::manage_normal_frame,
Frame, FrameMut, OpCode, ReadFrameInfo,
},
};

Expand Down Expand Up @@ -63,6 +64,14 @@ where
Ok(Self { connection_state: ConnectionState::Open, no_masking, rng, stream })
}

/// Closes the stream as well as the WebSocket connection.
#[inline]
pub async fn close(&mut self) -> crate::Result<()> {
self.write_frame(&mut Frame::new_fin(OpCode::Close, &mut [])).await?;
self.stream.lease_mut().common().send_reset(Http2ErrorCode::NoError).await;
Ok(())
}

/// Reads a frame from the stream.
///
/// If a frame is made up of other sub-frames or continuations, then everything is collected
Expand Down Expand Up @@ -141,6 +150,12 @@ where
where
P: LeaseMut<[u8]>,
{
manage_normal_frame::<_, _, false>(
&mut self.connection_state,
frame,
self.no_masking,
&mut self.rng,
);
let (header, payload) = frame.header_and_payload();
let hss = self
.stream
Expand Down
9 changes: 8 additions & 1 deletion wtx/src/web_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mod web_socket_buffer;
mod web_socket_error;
mod web_socket_parts;
pub(crate) mod web_socket_reader;
mod web_socket_writer;
pub(crate) mod web_socket_writer;

use crate::{
misc::{ConnectionState, LeaseMut, Stream, Xorshift64},
Expand All @@ -36,8 +36,15 @@ pub use web_socket_buffer::WebSocketBuffer;
pub use web_socket_error::WebSocketError;
pub use web_socket_parts::{WebSocketCommonPart, WebSocketReaderPart, WebSocketWriterPart};

const FIN_MASK: u8 = 0b1000_0000;
const MASK_MASK: u8 = 0b1000_0000;
const MAX_CONTROL_PAYLOAD_LEN: usize = 125;
const MAX_HEADER_LEN_USIZE: usize = 14;
const OP_CODE_MASK: u8 = 0b0000_1111;
const PAYLOAD_MASK: u8 = 0b0111_1111;
const RSV1_MASK: u8 = 0b0100_0000;
const RSV2_MASK: u8 = 0b0010_0000;
const RSV3_MASK: u8 = 0b0001_0000;

/// Always masks the payload before sending.
pub type WebSocketClient<NC, S, WSB> = WebSocket<NC, S, WSB, true>;
Expand Down
23 changes: 17 additions & 6 deletions wtx/src/web_socket/frame.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{
misc::{Lease, Vector},
web_socket::{
misc::fill_header_from_params, OpCode, MAX_CONTROL_PAYLOAD_LEN, MAX_HEADER_LEN_USIZE,
misc::{fill_header_from_params, has_masked_frame},
OpCode, MASK_MASK, MAX_CONTROL_PAYLOAD_LEN, MAX_HEADER_LEN_USIZE,
},
};
use core::str;
Expand Down Expand Up @@ -76,16 +77,26 @@ impl<P, const IS_CLIENT: bool> Frame<P, IS_CLIENT> {
(header, &mut self.payload)
}

#[inline]
pub(crate) fn header_mut(&mut self) -> &mut [u8] {
self.header_and_payload_mut().0
}

#[inline]
pub(crate) fn header_first_two_mut(&mut self) -> [&mut u8; 2] {
let [a, b, ..] = &mut self.header;
[a, b]
}

#[inline]
pub(crate) fn set_mask(&mut self, mask: [u8; 4]) {
if has_masked_frame(self.header[1]) {
return;
}
self.header_len = self.header_len.wrapping_add(4);
if let Some([_, a, .., b, c, d, e]) = self.header.get_mut(..self.header_len.into()) {
*a |= MASK_MASK;
*b = mask[0];
*c = mask[1];
*d = mask[2];
*e = mask[3];
}
}
}

impl<P, const IS_CLIENT: bool> Frame<P, IS_CLIENT>
Expand Down
47 changes: 23 additions & 24 deletions wtx/src/web_socket/handshake/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,35 @@ static HAS_SERVER_FINISHED: AtomicBool = AtomicBool::new(false);

#[cfg(feature = "flate2")]
#[tokio::test]
async fn client_and_server_compressed() {
async fn compressed() {
use crate::web_socket::compression::Flate2;
#[cfg(feature = "_tracing-tree")]
let _rslt = crate::misc::tracing_tree_init(None);
do_test_client_and_server_frames((), Flate2::default()).await;
do_test_client_and_server_frames(((), false), (Flate2::default(), false)).await;
tokio::time::sleep(Duration::from_millis(200)).await;
do_test_client_and_server_frames(Flate2::default(), ()).await;
do_test_client_and_server_frames((Flate2::default(), false), ((), false)).await;
tokio::time::sleep(Duration::from_millis(200)).await;
do_test_client_and_server_frames(Flate2::default(), Flate2::default()).await;
do_test_client_and_server_frames((Flate2::default(), false), (Flate2::default(), false)).await;
}

#[tokio::test]
async fn client_and_server_uncompressed() {
async fn uncompressed() {
#[cfg(feature = "_tracing-tree")]
let _rslt = crate::misc::tracing_tree_init(None);
do_test_client_and_server_frames((), ()).await;
do_test_client_and_server_frames(((), false), ((), false)).await;
}

async fn do_test_client_and_server_frames<CC, SC>(client_compression: CC, server_compression: SC)
where
#[tokio::test]
async fn uncompressed_no_masking() {
#[cfg(feature = "_tracing-tree")]
let _rslt = crate::misc::tracing_tree_init(None);
do_test_client_and_server_frames(((), true), ((), true)).await;
}

async fn do_test_client_and_server_frames<CC, SC>(
(client_compression, client_no_masking): (CC, bool),
(server_compression, server_no_masking): (SC, bool),
) where
CC: Compression<true> + Send,
CC::NegotiatedCompression: Send,
SC: Compression<false> + Send + 'static,
Expand All @@ -59,7 +68,7 @@ where
let (stream, _) = listener.accept().await.unwrap();
let mut ws = WebSocketServer::accept(
server_compression,
false,
server_no_masking,
Xorshift64::from(simple_seed()),
stream,
WebSocketBuffer::new(),
Expand All @@ -84,7 +93,7 @@ where
let mut ws = WebSocketClient::connect(
client_compression,
[],
false,
client_no_masking,
Xorshift64::from(simple_seed()),
TcpStream::connect(uri.hostname_with_implied_port()).await.unwrap(),
&uri.to_ref(),
Expand Down Expand Up @@ -151,23 +160,13 @@ where
let hello = ws.read_frame().await.unwrap();
assert_eq!(OpCode::Text, hello.op_code());
assert_eq!(b"Hello!", hello.payload());
ws.write_frame(&mut Frame::new_fin(
OpCode::Text,
&mut [b'G', b'o', b'o', b'd', b'b', b'y', b'e', b'!'],
))
.await
.unwrap();
ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"Goodbye!")).await.unwrap();
assert_eq!(OpCode::Close, ws.read_frame().await.unwrap().op_code());
}

async fn server(ws: &mut WebSocketServerOwned<NC, TcpStream>) {
ws.write_frame(&mut Frame::new_fin(OpCode::Text, &mut [b'H', b'e', b'l', b'l', b'o', b'!']))
.await
.unwrap();
assert_eq!(
ws.read_frame().await.unwrap().payload(),
&mut [b'G', b'o', b'o', b'd', b'b', b'y', b'e', b'!']
);
ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"Hello!")).await.unwrap();
assert_eq!(ws.read_frame().await.unwrap().payload(), b"Goodbye!");
ws.write_frame(&mut Frame::new_fin(OpCode::Close, &mut [])).await.unwrap();
}
}
Expand Down Expand Up @@ -205,7 +204,7 @@ where
{
async fn client(ws: &mut WebSocketClientOwned<NC, TcpStream>) {
ws.write_frame(&mut Frame::new_fin(OpCode::Ping, &mut [1, 2, 3])).await.unwrap();
ws.write_frame(&mut Frame::new_fin(OpCode::Text, &mut [b'i', b'p', b'a', b't'])).await.unwrap();
ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"ipat")).await.unwrap();
assert_eq!(OpCode::Pong, ws.read_frame().await.unwrap().op_code());
}

Expand Down
38 changes: 13 additions & 25 deletions wtx/src/web_socket/misc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::web_socket::{CloseCode, OpCode, MAX_HEADER_LEN_USIZE};
use crate::web_socket::{CloseCode, OpCode, MASK_MASK, MAX_HEADER_LEN_USIZE, OP_CODE_MASK};
use core::ops::Range;

/// The first two bytes of `payload` are filled with `code`. Does nothing if `payload` is
Expand Down Expand Up @@ -26,43 +26,26 @@ pub(crate) fn fill_header_from_params<const IS_CLIENT: bool>(
u8::from(fin) << 7 | rsv1 | u8::from(op_code)
}

#[inline]
fn manage_mask<const IS_CLIENT: bool, const N: u8>(
second_byte: &mut u8,
[a, b, c, d]: [&mut u8; 4],
) -> u8 {
if IS_CLIENT {
*second_byte &= 0b0111_1111;
*a = 0;
*b = 0;
*c = 0;
*d = 0;
N.wrapping_add(4)
} else {
N
}
}

match payload_len {
0..=125 => {
let [a, b, c, d, e, f, ..] = header;
let [a, b, ..] = header;
*a = first_header_byte(fin, op_code, rsv1);
*b = u8::try_from(payload_len).unwrap_or_default();
manage_mask::<IS_CLIENT, 2>(b, [c, d, e, f])
2
}
126..=0xFFFF => {
let [len_c, len_d] = u16::try_from(payload_len).map(u16::to_be_bytes).unwrap_or_default();
let [a, b, c, d, e, f, g, h, ..] = header;
let [a, b, c, d, ..] = header;
*a = first_header_byte(fin, op_code, rsv1);
*b = 126;
*c = len_c;
*d = len_d;
manage_mask::<IS_CLIENT, 4>(b, [e, f, g, h])
4
}
_ => {
let len = u64::try_from(payload_len).map(u64::to_be_bytes).unwrap_or_default();
let [len_c, len_d, len_e, len_f, len_g, len_h, len_i, len_j] = len;
let [a, b, c, d, e, f, g, h, i, j, k, l, m, n] = header;
let [a, b, c, d, e, f, g, h, i, j, ..] = header;
*a = first_header_byte(fin, op_code, rsv1);
*b = 127;
*c = len_c;
Expand All @@ -73,14 +56,19 @@ pub(crate) fn fill_header_from_params<const IS_CLIENT: bool>(
*h = len_h;
*i = len_i;
*j = len_j;
manage_mask::<IS_CLIENT, 10>(b, [k, l, m, n])
10
}
}
}

#[inline]
pub(crate) const fn has_masked_frame(second_header_byte: u8) -> bool {
second_header_byte & MASK_MASK != 0
}

#[inline]
pub(crate) fn op_code(first_header_byte: u8) -> crate::Result<OpCode> {
OpCode::try_from(first_header_byte & 0b0000_1111)
OpCode::try_from(first_header_byte & OP_CODE_MASK)
}

#[inline]
Expand Down
18 changes: 10 additions & 8 deletions wtx/src/web_socket/read_frame_info.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::{
misc::{PartitionedFilledBuffer, Stream, _read_until},
web_socket::{
compression::NegotiatedCompression, misc::op_code, OpCode, WebSocketError,
MAX_CONTROL_PAYLOAD_LEN,
compression::NegotiatedCompression,
misc::{has_masked_frame, op_code},
OpCode, WebSocketError, FIN_MASK, MAX_CONTROL_PAYLOAD_LEN, PAYLOAD_MASK, RSV1_MASK, RSV2_MASK,
RSV3_MASK,
},
};

Expand Down Expand Up @@ -135,9 +137,9 @@ impl ReadFrameInfo {
where
NC: NegotiatedCompression,
{
let rsv1 = a & 0b0100_0000;
let rsv2 = a & 0b0010_0000;
let rsv3 = a & 0b0001_0000;
let rsv1 = a & RSV1_MASK;
let rsv2 = a & RSV2_MASK;
let rsv3 = a & RSV3_MASK;
if rsv2 != 0 || rsv3 != 0 {
return Err(WebSocketError::InvalidCompressionHeaderParameter.into());
}
Expand All @@ -151,9 +153,9 @@ impl ReadFrameInfo {
} else {
rsv1 != 0
};
let fin = a & 0b1000_0000 != 0;
let length_code = b & 0b0111_1111;
let masked = b & 0b1000_0000 != 0;
let fin = a & FIN_MASK != 0;
let length_code = b & PAYLOAD_MASK;
let masked = has_masked_frame(b);
let op_code = op_code(a)?;
Ok((fin, length_code, masked, op_code, should_decompress))
}
Expand Down
26 changes: 8 additions & 18 deletions wtx/src/web_socket/web_socket_writer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::{
misc::{BufferMode, ConnectionState, Lease, LeaseMut, Rng, Stream, Vector, Xorshift64},
web_socket::{compression::NegotiatedCompression, unmask::unmask, Frame, FrameMut, OpCode},
web_socket::{
compression::NegotiatedCompression, misc::has_masked_frame, unmask::unmask, Frame, FrameMut,
OpCode,
},
};

#[inline]
Expand Down Expand Up @@ -122,11 +125,6 @@ where
))
}

#[inline]
const fn has_masked_frame(second_header_byte: u8) -> bool {
second_header_byte & 0b1000_0000 != 0
}

#[inline]
fn mask_frame<P, RNG, const IS_CLIENT: bool>(
frame: &mut Frame<P, IS_CLIENT>,
Expand All @@ -136,17 +134,9 @@ fn mask_frame<P, RNG, const IS_CLIENT: bool>(
P: LeaseMut<[u8]>,
RNG: Rng,
{
if IS_CLIENT && !no_masking {
if let [_, second_byte, .., a, b, c, d] = frame.header_mut() {
if !has_masked_frame(*second_byte) {
*second_byte |= 0b1000_0000;
let mask = rng.u8_4();
*a = mask[0];
*b = mask[1];
*c = mask[2];
*d = mask[3];
unmask(frame.payload_mut().lease_mut(), mask);
}
}
if IS_CLIENT && !no_masking && !has_masked_frame(*frame.header_first_two_mut()[1]) {
let mask: [u8; 4] = rng.u8_4();
frame.set_mask(mask);
unmask(frame.payload_mut().lease_mut(), mask);
}
}

0 comments on commit f2d781f

Please sign in to comment.