diff --git a/wtx-docs/src/web-socket/README.md b/wtx-docs/src/web-socket/README.md index 524a569a..0301139b 100644 --- a/wtx-docs/src/web-socket/README.md +++ b/wtx-docs/src/web-socket/README.md @@ -2,7 +2,7 @@ Implementation of [RFC6455](https://datatracker.ietf.org/doc/html/rfc6455) and [RFC7692](https://datatracker.ietf.org/doc/html/rfc7692). WebSocket is a communication protocol that enables full-duplex communication between a client (typically a web browser) and a server over a single TCP connection. Unlike traditional HTTP, which is request-response based, WebSocket allows real-time data exchange without the need for polling. -In-house benchmarks are available at . If you are aware of other benchmark tools, please open an discussion in the GitHub project. +In-house benchmarks are available at . If you are aware of other benchmark tools, please open a discussion in the GitHub project. To use this functionality, it is necessary to activate the `web-socket` feature. @@ -28,11 +28,11 @@ To make everything work as intended both parties, client and server, need to imp ## Client Example ```rust,edition2021,no_run -{{#rustdoc_include ../../../wtx-instances/generic-examples/web-socket-client.rs}} +{{#rustdoc_include ../../../wtx-instances/web-socket-examples/web-socket-client.rs}} ``` ## Server Example ```rust,edition2021,no_run -{{#rustdoc_include ../../../wtx-instances/generic-examples/web-socket-server.rs}} +{{#rustdoc_include ../../../wtx-instances/web-socket-examples/web-socket-server.rs}} ``` \ No newline at end of file diff --git a/wtx-instances/Cargo.toml b/wtx-instances/Cargo.toml index 868c4d90..9c8e37ea 100644 --- a/wtx-instances/Cargo.toml +++ b/wtx-instances/Cargo.toml @@ -57,7 +57,7 @@ required-features = ["grpc"] [[example]] name = "grpc-server" path = "generic-examples/grpc-server.rs" -required-features = ["grpc", "wtx/tokio-rustls", "wtx/webpki-roots"] +required-features = ["grpc", "wtx/tokio-rustls"] [[example]] name = "http-client-framework" @@ -69,16 +69,6 @@ name = "pool" path = "generic-examples/pool.rs" required-features = ["wtx/pool"] -[[example]] -name = "web-socket-client" -path = "generic-examples/web-socket-client.rs" -required-features = ["wtx/web-socket-handshake", "wtx/webpki-roots"] - -[[example]] -name = "web-socket-server" -path = "generic-examples/web-socket-server.rs" -required-features = ["tokio-rustls", "wtx/pool", "wtx/tokio-rustls", "wtx/web-socket-handshake"] - # HTTP Server Framework Examples [[example]] @@ -118,6 +108,23 @@ name = "http2-web-socket" path = "http2-examples/http2-web-socket.rs" required-features = ["wtx/http2", "wtx/tokio-rustls", "wtx/web-socket"] +# WebSocket Examples + +[[example]] +name = "web-socket-client" +path = "web-socket-examples/web-socket-client.rs" +required-features = ["wtx/web-socket-handshake"] + +[[example]] +name = "web-socket-concurrent-client" +path = "web-socket-examples/web-socket-concurrent-client.rs" +required-features = ["wtx/tokio-rustls", "wtx/web-socket-handshake", "wtx/webpki-roots"] + +[[example]] +name = "web-socket-server" +path = "web-socket-examples/web-socket-server.rs" +required-features = ["tokio-rustls", "wtx/pool", "wtx/tokio-rustls", "wtx/web-socket-handshake"] + [build-dependencies] pb-rs = { default-features = false, optional = true, version = "0.10" } diff --git a/wtx-instances/generic-examples/client-api-framework.rs b/wtx-instances/generic-examples/client-api-framework.rs index d92c10a0..c23fb9c1 100644 --- a/wtx-instances/generic-examples/client-api-framework.rs +++ b/wtx-instances/generic-examples/client-api-framework.rs @@ -80,52 +80,52 @@ mod generic_web_socket_subscription { pub type GenericWebSocketSubscriptionRes = u64; } -#[tokio::main] -async fn main() -> wtx::Result<()> { - async fn http_pair( - ) -> Pair, ClientFrameworkTokio> { - Pair::new( - PkgsAux::from_minimum( - GenericThrottlingApi { - rt: RequestThrottling::from_rl(RequestLimit::new(5, Duration::from_secs(1))), - }, - SerdeJson, - HttpParams::from_uri("ws://generic_web_socket_uri.com".into()), - ), - ClientFrameworkTokio::tokio(1).build(), - ) - } +async fn http_pair( +) -> Pair, ClientFrameworkTokio> { + Pair::new( + PkgsAux::from_minimum( + GenericThrottlingApi { + rt: RequestThrottling::from_rl(RequestLimit::new(5, Duration::from_secs(1))), + }, + SerdeJson, + HttpParams::from_uri("ws://generic_web_socket_uri.com".into()), + ), + ClientFrameworkTokio::tokio(1).build(), + ) +} - async fn web_socket_pair() -> wtx::Result< - Pair< - PkgsAux, - WebSocketClient<(), TcpStream, WebSocketBuffer>, - >, - > { - let uri = Uri::new("ws://generic_web_socket_uri.com"); - let web_socket = WebSocketClient::connect( - (), - [], - false, - Xorshift64::from(simple_seed()), - TcpStream::connect(uri.hostname_with_implied_port()).await?, - &uri, - WebSocketBuffer::default(), - |_| wtx::Result::Ok(()), - ) - .await?; - Ok(Pair::new( - PkgsAux::from_minimum( - GenericThrottlingApi { - rt: RequestThrottling::from_rl(RequestLimit::new(40, Duration::from_secs(2))), - }, - SerdeJson, - WsParams::default(), - ), - web_socket, - )) - } +async fn web_socket_pair() -> wtx::Result< + Pair< + PkgsAux, + WebSocketClient<(), TcpStream, WebSocketBuffer>, + >, +> { + let uri = Uri::new("ws://generic_web_socket_uri.com"); + let web_socket = WebSocketClient::connect( + (), + [], + false, + Xorshift64::from(simple_seed()), + TcpStream::connect(uri.hostname_with_implied_port()).await?, + &uri, + WebSocketBuffer::default(), + |_| wtx::Result::Ok(()), + ) + .await?; + Ok(Pair::new( + PkgsAux::from_minimum( + GenericThrottlingApi { + rt: RequestThrottling::from_rl(RequestLimit::new(40, Duration::from_secs(2))), + }, + SerdeJson, + WsParams::default(), + ), + web_socket, + )) +} +#[tokio::main] +async fn main() -> wtx::Result<()> { let mut hp = http_pair().await; let _http_response_tuple = hp .trans diff --git a/wtx-instances/src/bin/autobahn-client.rs b/wtx-instances/src/bin/autobahn-client.rs index 97f5ea55..c5587ec9 100644 --- a/wtx-instances/src/bin/autobahn-client.rs +++ b/wtx-instances/src/bin/autobahn-client.rs @@ -22,7 +22,7 @@ async fn main() -> wtx::Result<()> { |_| wtx::Result::Ok(()), ) .await?; - let (mut common, mut reader, mut writer) = ws.parts(); + let (mut common, mut reader, mut writer) = ws.parts_mut(); loop { let mut frame = match reader.read_frame(&mut common).await { Err(_err) => { diff --git a/wtx-instances/src/bin/autobahn-server.rs b/wtx-instances/src/bin/autobahn-server.rs index b18677ba..e628bd2d 100644 --- a/wtx-instances/src/bin/autobahn-server.rs +++ b/wtx-instances/src/bin/autobahn-server.rs @@ -27,7 +27,7 @@ async fn main() -> wtx::Result<()> { async fn handle( mut ws: WebSocketServer, TcpStream, &mut WebSocketBuffer>, ) -> wtx::Result<()> { - let (mut common, mut reader, mut writer) = ws.parts(); + let (mut common, mut reader, mut writer) = ws.parts_mut(); loop { let mut frame = reader.read_frame(&mut common).await?; match frame.op_code() { diff --git a/wtx-instances/generic-examples/web-socket-client.rs b/wtx-instances/web-socket-examples/web-socket-client.rs similarity index 100% rename from wtx-instances/generic-examples/web-socket-client.rs rename to wtx-instances/web-socket-examples/web-socket-client.rs diff --git a/wtx-instances/web-socket-examples/web-socket-concurrent-client.rs b/wtx-instances/web-socket-examples/web-socket-concurrent-client.rs new file mode 100644 index 00000000..d23940ad --- /dev/null +++ b/wtx-instances/web-socket-examples/web-socket-concurrent-client.rs @@ -0,0 +1,50 @@ +//! WebSocket client that reads and writes frames in different tasks. + +extern crate tokio; +extern crate wtx; +extern crate wtx_instances; + +use tokio::{net::TcpStream, sync::Mutex}; +use wtx::{ + misc::{simple_seed, Arc, TokioRustlsConnector, Uri, Xorshift64}, + web_socket::{Frame, OpCode, WebSocketBuffer, WebSocketClient}, +}; + +#[tokio::main] +async fn main() -> wtx::Result<()> { + let uri = Uri::new("ws://www.example.com"); + let connector = TokioRustlsConnector::from_auto()?.push_certs(wtx_instances::ROOT_CA)?; + let stream = TcpStream::connect(uri.hostname_with_implied_port()).await?; + let ws = WebSocketClient::connect( + (), + [], + false, + Xorshift64::from(simple_seed()), + connector.connect_without_client_auth(uri.hostname(), stream).await?, + &uri.to_ref(), + WebSocketBuffer::default(), + |_| wtx::Result::Ok(()), + ) + .await?; + let (mut reader, mut writer) = ws.into_parts::>, _, _>(|el| tokio::io::split(el)); + let reader_jh = tokio::spawn(async move { + loop { + let frame = reader.read_frame().await?; + match (frame.op_code(), frame.text_payload()) { + (_, Some(elem)) => println!("{elem}"), + (OpCode::Close, _) => break, + _ => {} + } + } + wtx::Result::Ok(()) + }); + let writer_jh = tokio::spawn(async move { + writer.write_frame(&mut Frame::new_fin(OpCode::Text, *b"Hi and Bye")).await?; + writer.write_frame(&mut Frame::new_fin(OpCode::Close, [])).await?; + wtx::Result::Ok(()) + }); + let (reader_rslt, writer_rslt) = tokio::join!(reader_jh, writer_jh); + reader_rslt??; + writer_rslt??; + Ok(()) +} diff --git a/wtx-instances/generic-examples/web-socket-server.rs b/wtx-instances/web-socket-examples/web-socket-server.rs similarity index 95% rename from wtx-instances/generic-examples/web-socket-server.rs rename to wtx-instances/web-socket-examples/web-socket-server.rs index 378a7f8c..096c0ec3 100644 --- a/wtx-instances/generic-examples/web-socket-server.rs +++ b/wtx-instances/web-socket-examples/web-socket-server.rs @@ -36,7 +36,7 @@ async fn main() -> wtx::Result<()> { async fn handle( mut ws: WebSocketServer<(), TlsStream, &mut WebSocketBuffer>, ) -> wtx::Result<()> { - let (mut common, mut reader, mut writer) = ws.parts(); + let (mut common, mut reader, mut writer) = ws.parts_mut(); loop { let mut frame = reader.read_frame(&mut common).await?; match frame.op_code() { diff --git a/wtx/src/http2/web_socket_over_stream.rs b/wtx/src/http2/web_socket_over_stream.rs index dc3e4f18..9b832d37 100644 --- a/wtx/src/http2/web_socket_over_stream.rs +++ b/wtx/src/http2/web_socket_over_stream.rs @@ -180,7 +180,7 @@ where Http2RecvStatus::Ongoing(data) => (data, false), }; let mut slice = data.as_slice(); - let rfi = ReadFrameInfo::from_bytes::<_, false>(&mut slice, usize::MAX, &(), no_masking)?; + let rfi = ReadFrameInfo::from_bytes::(&mut slice, usize::MAX, (true, 0), no_masking)?; let before = buffer.len(); buffer.extend_from_copyable_slice(slice)?; unmask_nb::(buffer.get_mut(before..).unwrap_or_default(), no_masking, &rfi)?; diff --git a/wtx/src/misc.rs b/wtx/src/misc.rs index aa5f24ae..a6341fbb 100644 --- a/wtx/src/misc.rs +++ b/wtx/src/misc.rs @@ -293,14 +293,14 @@ where } #[inline] -pub(crate) async fn _read_payload( +pub(crate) async fn _read_payload( (header_len, payload_len): (usize, usize), network_buffer: &mut PartitionedFilledBuffer, read: &mut usize, - stream: &mut S, + stream: &mut SR, ) -> crate::Result<()> where - S: StreamReader, + SR: StreamReader, { let frame_len = header_len.wrapping_add(payload_len); network_buffer._reserve(frame_len)?; diff --git a/wtx/src/misc/connection_state.rs b/wtx/src/misc/connection_state.rs index 5ee0e570..986b10d9 100644 --- a/wtx/src/misc/connection_state.rs +++ b/wtx/src/misc/connection_state.rs @@ -1,3 +1,5 @@ +use crate::misc::{Lease, LeaseMut}; + /// The state of a connection between two parties. #[derive(Clone, Copy, Debug)] pub enum ConnectionState { @@ -21,6 +23,20 @@ impl ConnectionState { } } +impl Lease for ConnectionState { + #[inline] + fn lease(&self) -> &ConnectionState { + self + } +} + +impl LeaseMut for ConnectionState { + #[inline] + fn lease_mut(&mut self) -> &mut ConnectionState { + self + } +} + impl From for ConnectionState { #[inline] fn from(from: bool) -> Self { diff --git a/wtx/src/misc/partitioned_filled_buffer.rs b/wtx/src/misc/partitioned_filled_buffer.rs index 72a44a8c..a466b1f0 100644 --- a/wtx/src/misc/partitioned_filled_buffer.rs +++ b/wtx/src/misc/partitioned_filled_buffer.rs @@ -1,4 +1,4 @@ -use crate::misc::{FilledBuffer, FilledBufferWriter, VectorError}; +use crate::misc::{FilledBuffer, FilledBufferWriter, Lease, LeaseMut, VectorError}; use core::ops::Range; // ``` @@ -174,6 +174,20 @@ impl PartitionedFilledBuffer { } } +impl Lease for PartitionedFilledBuffer { + #[inline] + fn lease(&self) -> &PartitionedFilledBuffer { + self + } +} + +impl LeaseMut for PartitionedFilledBuffer { + #[inline] + fn lease_mut(&mut self) -> &mut PartitionedFilledBuffer { + self + } +} + impl Default for PartitionedFilledBuffer { #[inline] fn default() -> Self { diff --git a/wtx/src/misc/rng/xorshift.rs b/wtx/src/misc/rng/xorshift.rs index 137ffc91..8cd5c285 100644 --- a/wtx/src/misc/rng/xorshift.rs +++ b/wtx/src/misc/rng/xorshift.rs @@ -1,4 +1,4 @@ -use crate::misc::{AtomicU64, Rng}; +use crate::misc::{AtomicU64, Lease, LeaseMut, Rng}; use core::sync::atomic::Ordering; /// Xorshift that deals with 64 bits numbers. @@ -29,6 +29,20 @@ impl Rng for Xorshift64 { } } +impl Lease for Xorshift64 { + #[inline] + fn lease(&self) -> &Xorshift64 { + self + } +} + +impl LeaseMut for Xorshift64 { + #[inline] + fn lease_mut(&mut self) -> &mut Xorshift64 { + self + } +} + impl From for Xorshift64 { #[inline] fn from(value: u64) -> Self { diff --git a/wtx/src/misc/stream/tokio.rs b/wtx/src/misc/stream/tokio.rs index 5aa476c3..cbd5138b 100644 --- a/wtx/src/misc/stream/tokio.rs +++ b/wtx/src/misc/stream/tokio.rs @@ -31,14 +31,6 @@ impl StreamReader for TcpStream { } } -#[cfg(unix)] -impl StreamReader for tokio::net::UnixStream { - #[inline] - async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { - Ok(::read(self, bytes).await?) - } -} - impl StreamWriter for OwnedWriteHalf { #[inline] async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { @@ -83,18 +75,3 @@ impl StreamWriter for TcpStream { Ok(()) } } - -#[cfg(unix)] -impl StreamWriter for tokio::net::UnixStream { - #[inline] - async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { - ::write_all(self, bytes).await?; - Ok(()) - } - - #[inline] - async fn write_all_vectored(&mut self, bytes: &[&[u8]]) -> crate::Result<()> { - _local_write_all_vectored!(bytes, self, |io_slices| self.write_vectored(io_slices).await); - Ok(()) - } -} diff --git a/wtx/src/web_socket.rs b/wtx/src/web_socket.rs index 004e2efd..24b98d70 100644 --- a/wtx/src/web_socket.rs +++ b/wtx/src/web_socket.rs @@ -1,6 +1,9 @@ //! A computer communications protocol, providing full-duplex communication channels over a single //! TCP connection. +#[macro_use] +mod macros; + mod close_code; pub mod compression; mod frame; @@ -18,8 +21,14 @@ pub(crate) mod web_socket_reader; pub(crate) mod web_socket_writer; use crate::{ - misc::{ConnectionState, LeaseMut, Stream, Xorshift64}, - web_socket::payload_ty::PayloadTy, + misc::{ConnectionState, LeaseMut, Lock, Stream, Xorshift64}, + web_socket::{ + compression::NegotiatedCompression, + payload_ty::PayloadTy, + web_socket_parts::web_socket_part::{ + WebSocketCommonPart, WebSocketReaderPart, WebSocketWriterPart, + }, + }, _MAX_PAYLOAD_LEN, }; pub use close_code::CloseCode; @@ -34,7 +43,12 @@ pub use op_code::OpCode; pub use read_frame_info::ReadFrameInfo; pub use web_socket_buffer::WebSocketBuffer; pub use web_socket_error::WebSocketError; -pub use web_socket_parts::{WebSocketCommonPart, WebSocketReaderPart, WebSocketWriterPart}; +pub use web_socket_parts::{ + web_socket_part_mut::{WebSocketCommonPartMut, WebSocketReaderPartMut, WebSocketWriterPartMut}, + web_socket_part_owned::{ + WebSocketCommonPartOwned, WebSocketReaderPartOwned, WebSocketWriterPartOwned, + }, +}; const FIN_MASK: u8 = 0b1000_0000; const MASK_MASK: u8 = 0b1000_0000; @@ -46,18 +60,18 @@ const RSV1_MASK: u8 = 0b0100_0000; const RSV2_MASK: u8 = 0b0010_0000; const RSV3_MASK: u8 = 0b0001_0000; -/// Always masks the payload before sending. +/// [`WebSocket`] instance for clients. pub type WebSocketClient = WebSocket; /// [`WebSocketClient`] with a mutable reference of [`WebSocketBuffer`]. -pub type WebSocketClientMut<'wsb, NC, S> = WebSocketClient; +pub type WebSocketClientMut<'wsb, NC, S> = WebSocket; /// [`WebSocketClient`] with an owned [`WebSocketBuffer`]. -pub type WebSocketClientOwned = WebSocketClient; -/// Always unmasks the payload after receiving. +pub type WebSocketClientOwned = WebSocket; +/// [`WebSocket`] instance for servers pub type WebSocketServer = WebSocket; /// [`WebSocketServer`] with a mutable reference of [`WebSocketBuffer`]. -pub type WebSocketServerMut<'wsb, NC, S> = WebSocketServer; +pub type WebSocketServerMut<'wsb, NC, S> = WebSocket; /// [`WebSocketServer`] with an owned [`WebSocketBuffer`]. -pub type WebSocketServerOwned = WebSocketServer; +pub type WebSocketServerOwned = WebSocket; /// Full-duplex communication over an asynchronous stream. /// @@ -85,7 +99,7 @@ impl WebSocket { impl WebSocket where - NC: compression::NegotiatedCompression, + NC: NegotiatedCompression, S: Stream, WSB: LeaseMut, { @@ -122,14 +136,14 @@ where } } - /// Different mutable parts that allow sending received frames using the same original instance. + /// Different mutable parts that allow sending received frames using common elements. #[inline] - pub fn parts( + pub fn parts_mut( &mut self, ) -> ( - WebSocketCommonPart<'_, NC, S, IS_CLIENT>, - WebSocketReaderPart<'_, NC, S, IS_CLIENT>, - WebSocketWriterPart<'_, NC, S, IS_CLIENT>, + WebSocketCommonPartMut<'_, NC, S, IS_CLIENT>, + WebSocketReaderPartMut<'_, NC, S, IS_CLIENT>, + WebSocketWriterPartMut<'_, NC, S, IS_CLIENT>, ) { let WebSocket { connection_state, @@ -147,17 +161,25 @@ where reader_buffer_first, reader_buffer_second, } = wsb.lease_mut(); + let nc_rsv1 = nc.rsv1(); ( - WebSocketCommonPart { connection_state, curr_payload, nc, rng, stream }, - WebSocketReaderPart { - max_payload_len: *max_payload_len, - network_buffer, - no_masking: *no_masking, + WebSocketCommonPartMut { wsc: WebSocketCommonPart { connection_state, nc, rng, stream } }, + WebSocketReaderPartMut { phantom: PhantomData, - reader_buffer_first, - reader_buffer_second, + wsrp: WebSocketReaderPart { + curr_payload, + max_payload_len: *max_payload_len, + nc_rsv1, + network_buffer, + no_masking: *no_masking, + reader_buffer_first, + reader_buffer_second, + }, + }, + WebSocketWriterPartMut { + phantom: PhantomData, + wswp: WebSocketWriterPart { no_masking: *no_masking, writer_buffer }, }, - WebSocketWriterPart { no_masking: *no_masking, phantom: PhantomData, writer_buffer }, ) } @@ -183,18 +205,25 @@ where reader_buffer_second, writer_buffer: _, } = wsb.lease_mut(); - let (frame, payload_ty) = web_socket_reader::read_frame_from_stream( - connection_state, + let nc_rsv1 = nc.rsv1(); + let (frame, payload_ty) = read_frame_from_stream!( *max_payload_len, - nc, + (NC::IS_NOOP, nc_rsv1), network_buffer, *no_masking, - reader_buffer_first, + &mut *reader_buffer_first, reader_buffer_second, - rng, stream, - ) - .await?; + ( + stream, + WebSocketCommonPart::<_, _, _, _, IS_CLIENT> { + connection_state: &mut *connection_state, + nc: &mut *nc, + rng: &mut *rng, + stream: &mut *stream + } + ) + ); *curr_payload = payload_ty; Ok(frame) } @@ -220,3 +249,64 @@ where Ok(()) } } + +impl WebSocket +where + NC: NegotiatedCompression, +{ + /// Splits the instance into owned parts that can be used in concurrent scenarios. + #[inline] + pub fn into_parts( + self, + split: impl FnOnce(S) -> (SR, SW), + ) -> ( + WebSocketReaderPartOwned, + WebSocketWriterPartOwned, + ) + where + C: Clone + Lock>, + { + let WebSocket { + connection_state, + curr_payload, + nc, + no_masking, + rng, + stream, + wsb, + max_payload_len, + } = self; + let WebSocketBuffer { + writer_buffer, + network_buffer, + reader_buffer_first, + reader_buffer_second, + } = wsb; + let (stream_reader, stream_writer) = split(stream); + let nc_rsv1 = nc.rsv1(); + let common = C::new(WebSocketCommonPartOwned { + wsc: WebSocketCommonPart { connection_state, nc, rng, stream: stream_writer }, + }); + ( + WebSocketReaderPartOwned { + common: common.clone(), + phantom: PhantomData, + stream_reader, + wsrp: WebSocketReaderPart { + curr_payload, + max_payload_len, + nc_rsv1, + network_buffer, + no_masking, + reader_buffer_first, + reader_buffer_second, + }, + }, + WebSocketWriterPartOwned { + common, + phantom: PhantomData, + wswp: WebSocketWriterPart { no_masking, writer_buffer }, + }, + ) + } +} diff --git a/wtx/src/web_socket/macros.rs b/wtx/src/web_socket/macros.rs new file mode 100644 index 00000000..e54b5004 --- /dev/null +++ b/wtx/src/web_socket/macros.rs @@ -0,0 +1,190 @@ +macro_rules! read_continuation_frames { + ( + $first_rfi:expr, + $max_payload_len:expr, + ($nc_is_noop:expr, $nc_rsv1:expr), + $network_buffer:expr, + $no_masking:expr, + $reader_buffer_first:expr, + $reader_buffer_second:expr, + $stream:expr, + ($stream_reader_expr:expr, $stream_writer_expr:expr), + ($first_text_cb:expr, $recurrent_text_cb:expr), + $reader_buffer_first_cb:expr + ) => { + 'rcf_block: { + use crate::web_socket::web_socket_reader; + web_socket_reader::copy_from_arbitrary_nb_to_rb1::( + $network_buffer, + $no_masking, + $reader_buffer_first, + $first_rfi, + )?; + let mut iuc = web_socket_reader::manage_op_code_of_first_continuation_frame( + $first_rfi.op_code, + $reader_buffer_first, + $first_text_cb, + )?; + loop { + let mut rfi = web_socket_reader::fetch_frame_from_stream::<_, IS_CLIENT>( + $max_payload_len, + ($nc_is_noop, $nc_rsv1), + $network_buffer, + $no_masking, + $stream_reader_expr, + ) + .await?; + let begin = $reader_buffer_first.len(); + rfi.should_decompress = $first_rfi.should_decompress; + web_socket_reader::copy_from_arbitrary_nb_to_rb1::( + $network_buffer, + $no_masking, + $reader_buffer_first, + &rfi, + )?; + let payload = $reader_buffer_first.get_mut(begin..).unwrap_or_default(); + let WebSocketCommonPart { connection_state, nc, rng, stream } = $stream_writer_expr; + if !web_socket_reader::manage_auto_reply::<_, _, IS_CLIENT>( + stream, + connection_state.lease_mut(), + $no_masking, + rfi.op_code, + payload, + rng.lease_mut(), + &mut web_socket_reader::write_control_frame_cb, + ) + .await? + { + $reader_buffer_first.truncate(begin); + continue; + } + if web_socket_reader::manage_op_code_of_continuation_frames( + rfi.fin, + $first_rfi.op_code, + &mut iuc, + rfi.op_code, + payload, + $recurrent_text_cb, + )? { + let cb: fn(_, _, _, _) -> crate::Result<()> = $reader_buffer_first_cb; + cb($first_rfi, nc, $reader_buffer_first, $reader_buffer_second)?; + break 'rcf_block; + } + } + } + }; +} + +macro_rules! read_frame_from_stream { + ( + $max_payload_len:expr, + ($nc_is_noop:expr, $nc_rsv1:expr), + $network_buffer:expr, + $no_masking:expr, + $reader_buffer_first:expr, + $reader_buffer_second:expr, + $stream:ident, + ($stream_reader_expr:expr, $stream_writer_expr:expr) + ) => { + 'rffs_block: { + use crate::web_socket::web_socket_reader; + let first_rfi = loop { + $reader_buffer_first.clear(); + let rfi = web_socket_reader::fetch_frame_from_stream::<_, IS_CLIENT>( + $max_payload_len, + ($nc_is_noop, $nc_rsv1), + $network_buffer, + $no_masking, + $stream_reader_expr, + ) + .await?; + if !rfi.fin { + break rfi; + } + let WebSocketCommonPart { connection_state, nc, rng, stream } = $stream_writer_expr; + let (payload, payload_ty) = if rfi.should_decompress { + web_socket_reader::copy_from_compressed_nb_to_rb1::( + nc, + $network_buffer, + $no_masking, + $reader_buffer_first, + &rfi, + )?; + ($reader_buffer_first.as_slice_mut(), PayloadTy::FirstReader) + } else { + let current_mut = $network_buffer._current_mut(); + web_socket_reader::unmask_nb::(current_mut, $no_masking, &rfi)?; + (current_mut, PayloadTy::Network) + }; + if web_socket_reader::manage_auto_reply::<_, _, IS_CLIENT>( + stream, + connection_state.lease_mut(), + $no_masking, + rfi.op_code, + payload, + rng.lease_mut(), + &mut web_socket_reader::write_control_frame_cb, + ) + .await? + { + web_socket_reader::manage_op_code_of_first_final_frame(rfi.op_code, payload)?; + // FIXME(STABLE): Use `payload` with polonius + let borrow_checker = if rfi.should_decompress { + $reader_buffer_first.as_slice_mut() + } else { + $network_buffer._current_mut() + }; + break 'rffs_block (Frame::new(true, rfi.op_code, borrow_checker, $nc_rsv1), payload_ty); + } + }; + $reader_buffer_second.clear(); + if first_rfi.should_decompress { + read_continuation_frames!( + &first_rfi, + $max_payload_len, + ($nc_is_noop, $nc_rsv1), + $network_buffer, + $no_masking, + $reader_buffer_first, + $reader_buffer_second, + $stream, + ($stream_reader_expr, $stream_writer_expr), + (|_| Ok(None), |_, _| Ok(())), + |local_first_rfi, local_nc, local_rbf, local_rbs| { + web_socket_reader::copy_from_compressed_rb1_to_rb2( + local_first_rfi, + local_nc, + local_rbf, + local_rbs, + ) + } + ); + ( + Frame::new(true, first_rfi.op_code, $reader_buffer_second, $nc_rsv1), + PayloadTy::SecondReader, + ) + } else { + read_continuation_frames!( + &first_rfi, + $max_payload_len, + ($nc_is_noop, $nc_rsv1), + $network_buffer, + $no_masking, + $reader_buffer_first, + $reader_buffer_second, + $stream, + ($stream_reader_expr, $stream_writer_expr), + ( + web_socket_reader::manage_text_of_first_continuation_frame, + web_socket_reader::manage_text_of_recurrent_continuation_frames + ), + |_, _, _, _| Ok(()) + ); + ( + Frame::new(true, first_rfi.op_code, $reader_buffer_first, $nc_rsv1), + PayloadTy::FirstReader, + ) + } + } + }; +} diff --git a/wtx/src/web_socket/payload_ty.rs b/wtx/src/web_socket/payload_ty.rs index 732476a9..53014bc4 100644 --- a/wtx/src/web_socket/payload_ty.rs +++ b/wtx/src/web_socket/payload_ty.rs @@ -1,3 +1,5 @@ +use crate::misc::{Lease, LeaseMut}; + #[derive(Debug)] pub(crate) enum PayloadTy { FirstReader, @@ -5,3 +7,17 @@ pub(crate) enum PayloadTy { None, SecondReader, } + +impl Lease for PayloadTy { + #[inline] + fn lease(&self) -> &PayloadTy { + self + } +} + +impl LeaseMut for PayloadTy { + #[inline] + fn lease_mut(&mut self) -> &mut PayloadTy { + self + } +} diff --git a/wtx/src/web_socket/read_frame_info.rs b/wtx/src/web_socket/read_frame_info.rs index ea72c98b..f5b6927e 100644 --- a/wtx/src/web_socket/read_frame_info.rs +++ b/wtx/src/web_socket/read_frame_info.rs @@ -1,7 +1,6 @@ use crate::{ - misc::{PartitionedFilledBuffer, Stream, _read_header}, + misc::{PartitionedFilledBuffer, StreamReader, _read_header}, web_socket::{ - compression::NegotiatedCompression, misc::{has_masked_frame, op_code}, OpCode, WebSocketError, FIN_MASK, MAX_CONTROL_PAYLOAD_LEN, PAYLOAD_MASK, RSV1_MASK, RSV2_MASK, RSV3_MASK, @@ -22,15 +21,12 @@ pub struct ReadFrameInfo { impl ReadFrameInfo { /// Creates a new instance based on a sequence of bytes. #[inline] - pub fn from_bytes( + pub fn from_bytes( bytes: &mut &[u8], max_payload_len: usize, - nc: &NC, + (nc_is_noop, nc_rsv1): (bool, u8), no_masking: bool, - ) -> crate::Result - where - NC: NegotiatedCompression, - { + ) -> crate::Result { let first_two = { let [a, b, rest @ ..] = bytes else { return Err(crate::Error::UnexpectedBufferState); @@ -38,7 +34,7 @@ impl ReadFrameInfo { *bytes = rest; [*a, *b] }; - let tuple = Self::manage_first_two_bytes(first_two, nc)?; + let tuple = Self::manage_first_two_bytes(first_two, (nc_is_noop, nc_rsv1))?; let (fin, length_code, masked, op_code, should_decompress) = tuple; let (mut header_len, payload_len) = match length_code { 126 => { @@ -72,37 +68,36 @@ impl ReadFrameInfo { } #[inline] - pub(crate) async fn from_stream( + pub(crate) async fn from_stream( max_payload_len: usize, - nc: &NC, + (nc_is_noop, nc_rsv1): (bool, u8), network_buffer: &mut PartitionedFilledBuffer, no_masking: bool, read: &mut usize, - stream: &mut S, + stream: &mut SR, ) -> crate::Result where - NC: NegotiatedCompression, - S: Stream, + SR: StreamReader, { let buffer = network_buffer._following_rest_mut(); - let first_two = _read_header::<0, 2, S>(buffer, read, stream).await?; - let tuple = Self::manage_first_two_bytes(first_two, nc)?; + let first_two = _read_header::<0, 2, SR>(buffer, read, stream).await?; + let tuple = Self::manage_first_two_bytes(first_two, (nc_is_noop, nc_rsv1))?; let (fin, length_code, masked, op_code, should_decompress) = tuple; let mut mask = None; let (header_len, payload_len) = match length_code { 126 => { - let payload_len = _read_header::<2, 2, S>(buffer, read, stream).await?; + let payload_len = _read_header::<2, 2, SR>(buffer, read, stream).await?; if Self::manage_mask::(masked, no_masking)? { - mask = Some(_read_header::<4, 4, S>(buffer, read, stream).await?); + mask = Some(_read_header::<4, 4, SR>(buffer, read, stream).await?); (8, u16::from_be_bytes(payload_len).into()) } else { (4, u16::from_be_bytes(payload_len).into()) } } 127 => { - let payload_len = _read_header::<2, 8, S>(buffer, read, stream).await?; + let payload_len = _read_header::<2, 8, SR>(buffer, read, stream).await?; if Self::manage_mask::(masked, no_masking)? { - mask = Some(_read_header::<10, 4, S>(buffer, read, stream).await?); + mask = Some(_read_header::<10, 4, SR>(buffer, read, stream).await?); (14, u64::from_be_bytes(payload_len).try_into()?) } else { (10, u64::from_be_bytes(payload_len).try_into()?) @@ -110,7 +105,7 @@ impl ReadFrameInfo { } _ => { if Self::manage_mask::(masked, no_masking)? { - mask = Some(_read_header::<2, 4, S>(buffer, read, stream).await?); + mask = Some(_read_header::<2, 4, SR>(buffer, read, stream).await?); (6, length_code.into()) } else { (2, length_code.into()) @@ -141,22 +136,19 @@ impl ReadFrameInfo { } #[inline] - fn manage_first_two_bytes( + fn manage_first_two_bytes( [a, b]: [u8; 2], - nc: &NC, - ) -> crate::Result<(bool, u8, bool, OpCode, bool)> - where - NC: NegotiatedCompression, - { + (nc_is_noop, nc_rsv1): (bool, u8), + ) -> crate::Result<(bool, u8, bool, OpCode, bool)> { let rsv1 = a & RSV1_MASK; let rsv2 = a & RSV2_MASK; let rsv3 = a & RSV3_MASK; if rsv2 != 0 || rsv3 != 0 { return Err(WebSocketError::InvalidCompressionHeaderParameter.into()); } - let should_decompress = if NC::IS_NOOP { + let should_decompress = if nc_is_noop { false - } else if nc.rsv1() == 0 { + } else if nc_rsv1 == 0 { if rsv1 != 0 { return Err(WebSocketError::InvalidCompressionHeaderParameter.into()); } diff --git a/wtx/src/web_socket/web_socket_parts.rs b/wtx/src/web_socket/web_socket_parts.rs index 11ff5cec..d37b1f6c 100644 --- a/wtx/src/web_socket/web_socket_parts.rs +++ b/wtx/src/web_socket/web_socket_parts.rs @@ -1,110 +1,6 @@ -use crate::{ - misc::{ConnectionState, LeaseMut, PartitionedFilledBuffer, Stream, Vector, Xorshift64}, - web_socket::{ - compression::NegotiatedCompression, payload_ty::PayloadTy, web_socket_reader, - web_socket_writer, Frame, FrameMut, - }, -}; -use core::marker::PhantomData; +// Common elements shared between pure WebSocket structures. Tunneling protocols should use +// the functions provided in `web_socket_reader` and `web_socket_writer`. -/// Auxiliary structure used by [`WebSocketReaderStub`] and [`WebSocketWriterStub`] -#[derive(Debug)] -pub struct WebSocketCommonPart<'instance, NC, S, const IS_CLIENT: bool> { - pub(crate) connection_state: &'instance mut ConnectionState, - pub(crate) curr_payload: &'instance mut PayloadTy, - pub(crate) nc: &'instance mut NC, - pub(crate) rng: &'instance mut Xorshift64, - pub(crate) stream: &'instance mut S, -} - -/// Auxiliary structure that can be used when it is necessary to write a received frame that belongs -/// to the same instance. -#[derive(Debug)] -pub struct WebSocketReaderPart<'instance, NC, S, const IS_CLIENT: bool> { - pub(crate) max_payload_len: usize, - pub(crate) network_buffer: &'instance mut PartitionedFilledBuffer, - pub(crate) no_masking: bool, - pub(crate) phantom: PhantomData<(NC, S)>, - pub(crate) reader_buffer_first: &'instance mut Vector, - pub(crate) reader_buffer_second: &'instance mut Vector, -} - -impl<'instance, NC, S, const IS_CLIENT: bool> WebSocketReaderPart<'instance, NC, S, IS_CLIENT> -where - NC: NegotiatedCompression, - S: Stream, -{ - /// Reads a frame from the stream. - /// - /// If a frame is made up of other sub-frames or continuations, then everything is collected - /// until all fragments are received. - #[inline] - pub async fn read_frame( - &mut self, - common: &mut WebSocketCommonPart<'instance, NC, S, IS_CLIENT>, - ) -> crate::Result> { - let WebSocketCommonPart { connection_state, curr_payload, nc, rng, stream } = common; - let Self { - max_payload_len, - network_buffer, - no_masking, - phantom: _, - reader_buffer_first, - reader_buffer_second, - } = self; - let (frame, payload_ty) = web_socket_reader::read_frame_from_stream( - connection_state, - *max_payload_len, - nc, - network_buffer, - *no_masking, - reader_buffer_first, - reader_buffer_second, - rng, - stream, - ) - .await?; - **curr_payload = payload_ty; - Ok(frame) - } -} - -/// Auxiliary structure that can be used when it is necessary to write a received frame that belongs -/// to the same instance. -#[derive(Debug)] -pub struct WebSocketWriterPart<'instance, NC, S, const IS_CLIENT: bool> { - pub(crate) no_masking: bool, - pub(crate) phantom: PhantomData<(NC, S)>, - pub(crate) writer_buffer: &'instance mut Vector, -} - -impl<'instance, NC, S, const IS_CLIENT: bool> WebSocketWriterPart<'instance, NC, S, IS_CLIENT> -where - NC: NegotiatedCompression, - S: Stream, -{ - /// Writes a frame to the stream. - #[inline] - pub async fn write_frame

( - &mut self, - common: &mut WebSocketCommonPart<'instance, NC, S, IS_CLIENT>, - frame: &mut Frame, - ) -> crate::Result<()> - where - P: LeaseMut<[u8]>, - { - let WebSocketCommonPart { connection_state, curr_payload: _, nc, rng, stream } = common; - let Self { no_masking, phantom: _, writer_buffer } = self; - web_socket_writer::write_frame( - connection_state, - frame, - *no_masking, - nc, - rng, - stream, - writer_buffer, - ) - .await?; - Ok(()) - } -} +pub(crate) mod web_socket_part; +pub(crate) mod web_socket_part_mut; +pub(crate) mod web_socket_part_owned; diff --git a/wtx/src/web_socket/web_socket_parts/web_socket_part.rs b/wtx/src/web_socket/web_socket_parts/web_socket_part.rs new file mode 100644 index 00000000..ca82779a --- /dev/null +++ b/wtx/src/web_socket/web_socket_parts/web_socket_part.rs @@ -0,0 +1,157 @@ +use crate::{ + misc::{ + ConnectionState, LeaseMut, Lock, PartitionedFilledBuffer, Stream, StreamReader, StreamWriter, + Vector, Xorshift64, + }, + web_socket::{ + compression::NegotiatedCompression, payload_ty::PayloadTy, + web_socket_parts::web_socket_part_owned::WebSocketCommonPartOwned, web_socket_writer, Frame, + FrameMut, + }, +}; + +#[derive(Debug)] +pub(crate) struct WebSocketCommonPart { + pub(crate) connection_state: CS, + pub(crate) nc: NC, + pub(crate) rng: RNG, + pub(crate) stream: S, +} + +#[derive(Debug)] +pub(crate) struct WebSocketReaderPart { + pub(crate) curr_payload: PT, + pub(crate) max_payload_len: usize, + pub(crate) nc_rsv1: u8, + pub(crate) network_buffer: PFB, + pub(crate) no_masking: bool, + pub(crate) reader_buffer_first: V, + pub(crate) reader_buffer_second: V, +} + +impl WebSocketReaderPart +where + PFB: LeaseMut, + PT: LeaseMut, + V: LeaseMut>, +{ + #[inline] + pub(crate) async fn read_frame_from_stream( + &mut self, + common: &mut WebSocketCommonPart, + ) -> crate::Result> + where + CS: LeaseMut, + NC: NegotiatedCompression, + RNG: LeaseMut, + S: Stream, + { + let WebSocketCommonPart { connection_state, nc, rng, stream } = common; + let Self { + curr_payload, + max_payload_len, + nc_rsv1, + network_buffer, + no_masking, + reader_buffer_first, + reader_buffer_second, + } = self; + let (frame, payload_ty) = read_frame_from_stream!( + *max_payload_len, + (NC::IS_NOOP, *nc_rsv1), + network_buffer.lease_mut(), + *no_masking, + reader_buffer_first.lease_mut(), + reader_buffer_second.lease_mut(), + stream, + ( + stream, + WebSocketCommonPart::<_, _, _, _, IS_CLIENT> { + connection_state: &mut *connection_state, + nc: &mut *nc, + rng: &mut *rng, + stream: &mut *stream + } + ) + ); + *curr_payload.lease_mut() = payload_ty; + Ok(frame) + } + + #[inline] + pub(crate) async fn read_frame_from_parts( + &mut self, + common: &mut C, + stream_reader: &mut SR, + ) -> crate::Result> + where + C: Lock>, + NC: NegotiatedCompression, + SR: StreamReader, + SW: StreamWriter, + { + let Self { + curr_payload, + max_payload_len, + network_buffer, + nc_rsv1, + no_masking, + reader_buffer_first, + reader_buffer_second, + } = self; + let parts = &mut (stream_reader, common); + let (frame, payload_ty) = read_frame_from_stream!( + *max_payload_len, + (NC::IS_NOOP, *nc_rsv1), + network_buffer.lease_mut(), + *no_masking, + reader_buffer_first.lease_mut(), + reader_buffer_second.lease_mut(), + parts, + (&mut parts.0, &mut parts.1.lock().await.wsc) + ); + *curr_payload.lease_mut() = payload_ty; + Ok(frame) + } +} + +/// Auxiliary structure that can be used when it is necessary to write a received frame that belongs +/// to the same instance. +#[derive(Debug)] +pub(crate) struct WebSocketWriterPart { + pub(crate) no_masking: bool, + pub(crate) writer_buffer: V, +} + +impl WebSocketWriterPart +where + V: LeaseMut>, +{ + #[inline] + pub(crate) async fn write_frame( + &mut self, + common: &mut WebSocketCommonPart, + frame: &mut Frame, + ) -> crate::Result<()> + where + CS: LeaseMut, + NC: NegotiatedCompression, + P: LeaseMut<[u8]>, + RNG: LeaseMut, + SW: StreamWriter, + { + let WebSocketCommonPart { connection_state, nc, rng, stream } = common; + let Self { no_masking, writer_buffer } = self; + web_socket_writer::write_frame( + connection_state.lease_mut(), + frame, + *no_masking, + nc, + rng.lease_mut(), + stream, + writer_buffer.lease_mut(), + ) + .await?; + Ok(()) + } +} diff --git a/wtx/src/web_socket/web_socket_parts/web_socket_part_mut.rs b/wtx/src/web_socket/web_socket_parts/web_socket_part_mut.rs new file mode 100644 index 00000000..0166cf7a --- /dev/null +++ b/wtx/src/web_socket/web_socket_parts/web_socket_part_mut.rs @@ -0,0 +1,82 @@ +use crate::{ + misc::{ConnectionState, LeaseMut, PartitionedFilledBuffer, Stream, Vector, Xorshift64}, + web_socket::{ + compression::NegotiatedCompression, + payload_ty::PayloadTy, + web_socket_parts::web_socket_part::{ + WebSocketCommonPart, WebSocketReaderPart, WebSocketWriterPart, + }, + Frame, FrameMut, + }, +}; +use core::marker::PhantomData; + +/// Auxiliary structure used by [`WebSocketReaderPartMut`] and [`WebSocketWriterPartMut`] +#[derive(Debug)] +pub struct WebSocketCommonPartMut<'instance, NC, S, const IS_CLIENT: bool> { + pub(crate) wsc: WebSocketCommonPart< + &'instance mut ConnectionState, + &'instance mut NC, + &'instance mut Xorshift64, + &'instance mut S, + IS_CLIENT, + >, +} + +/// Auxiliary structure that can be used when it is necessary to write a received frame that belongs +/// to the same instance. +#[derive(Debug)] +pub struct WebSocketReaderPartMut<'instance, NC, S, const IS_CLIENT: bool> { + pub(crate) phantom: PhantomData<(NC, S)>, + pub(crate) wsrp: WebSocketReaderPart< + &'instance mut PartitionedFilledBuffer, + &'instance mut PayloadTy, + &'instance mut Vector, + IS_CLIENT, + >, +} + +impl<'instance, NC, S, const IS_CLIENT: bool> WebSocketReaderPartMut<'instance, NC, S, IS_CLIENT> +where + NC: NegotiatedCompression, + S: Stream, +{ + /// Reads a frame from the stream. + /// + /// If a frame is made up of other sub-frames or continuations, then everything is collected + /// until all fragments are received. + #[inline] + pub async fn read_frame( + &mut self, + common: &mut WebSocketCommonPartMut<'instance, NC, S, IS_CLIENT>, + ) -> crate::Result> { + self.wsrp.read_frame_from_stream(&mut common.wsc).await + } +} + +/// Auxiliary structure that can be used when it is necessary to write a received frame that belongs +/// to the same instance. +#[derive(Debug)] +pub struct WebSocketWriterPartMut<'instance, NC, S, const IS_CLIENT: bool> { + pub(crate) phantom: PhantomData<(NC, S)>, + pub(crate) wswp: WebSocketWriterPart<&'instance mut Vector, IS_CLIENT>, +} + +impl<'instance, NC, S, const IS_CLIENT: bool> WebSocketWriterPartMut<'instance, NC, S, IS_CLIENT> +where + NC: NegotiatedCompression, + S: Stream, +{ + /// Writes a frame to the stream. + #[inline] + pub async fn write_frame

( + &mut self, + common: &mut WebSocketCommonPartMut<'instance, NC, S, IS_CLIENT>, + frame: &mut Frame, + ) -> crate::Result<()> + where + P: LeaseMut<[u8]>, + { + self.wswp.write_frame(&mut common.wsc, frame).await + } +} diff --git a/wtx/src/web_socket/web_socket_parts/web_socket_part_owned.rs b/wtx/src/web_socket/web_socket_parts/web_socket_part_owned.rs new file mode 100644 index 00000000..11095d72 --- /dev/null +++ b/wtx/src/web_socket/web_socket_parts/web_socket_part_owned.rs @@ -0,0 +1,71 @@ +use crate::{ + misc::{ + ConnectionState, LeaseMut, Lock, PartitionedFilledBuffer, StreamReader, StreamWriter, Vector, + Xorshift64, + }, + web_socket::{ + compression::NegotiatedCompression, + payload_ty::PayloadTy, + web_socket_parts::web_socket_part::{ + WebSocketCommonPart, WebSocketReaderPart, WebSocketWriterPart, + }, + Frame, FrameMut, + }, +}; +use core::marker::PhantomData; + +/// Auxiliary structure used by [`WebSocketReaderPartOwned`] and [`WebSocketWriterPartOwned`] +#[derive(Debug)] +pub struct WebSocketCommonPartOwned { + pub(crate) wsc: WebSocketCommonPart, +} + +/// Reader that can be used in concurrent scenarios. +#[derive(Debug)] +pub struct WebSocketReaderPartOwned { + pub(crate) common: C, + pub(crate) phantom: PhantomData<(NC, SR)>, + pub(crate) stream_reader: SR, + pub(crate) wsrp: WebSocketReaderPart, IS_CLIENT>, +} + +impl WebSocketReaderPartOwned +where + C: Lock>, + NC: NegotiatedCompression, + SR: StreamReader, + SW: StreamWriter, +{ + /// Reads a frame from the stream. + /// + /// If a frame is made up of other sub-frames or continuations, then everything is collected + /// until all fragments are received. + #[inline] + pub async fn read_frame(&mut self) -> crate::Result> { + self.wsrp.read_frame_from_parts(&mut self.common, &mut self.stream_reader).await + } +} + +/// Writer that can be used in concurrent scenarios. +#[derive(Debug)] +pub struct WebSocketWriterPartOwned { + pub(crate) common: C, + pub(crate) phantom: PhantomData<(NC, SW)>, + pub(crate) wswp: WebSocketWriterPart, IS_CLIENT>, +} + +impl WebSocketWriterPartOwned +where + C: Lock>, + NC: NegotiatedCompression, + SW: StreamWriter, +{ + /// Writes a frame to the stream. + #[inline] + pub async fn write_frame

(&mut self, frame: &mut Frame) -> crate::Result<()> + where + P: LeaseMut<[u8]>, + { + self.wswp.write_frame(&mut self.common.lock().await.wsc, frame).await + } +} diff --git a/wtx/src/web_socket/web_socket_reader.rs b/wtx/src/web_socket/web_socket_reader.rs index d09c2f27..b8bc707f 100644 --- a/wtx/src/web_socket/web_socket_reader.rs +++ b/wtx/src/web_socket/web_socket_reader.rs @@ -1,3 +1,5 @@ +// Common functions that used be used by pure WebSocket structures or tunneling protocols. +// // | Frame | With Decompression | Without Decompression | // |------------|--------------------------|-----------------------| // |Single |(NB -> RB1)¹ |(NB)¹ | @@ -6,23 +8,131 @@ use crate::{ misc::{ from_utf8_basic, from_utf8_ext, BufferMode, CompletionErr, ConnectionState, ExtUtf8Error, - FnMutFut, IncompleteUtf8Char, LeaseMut, PartitionedFilledBuffer, Rng, Stream, Vector, - _read_payload, + FnMutFut, IncompleteUtf8Char, LeaseMut, PartitionedFilledBuffer, Rng, StreamReader, + StreamWriter, Vector, _read_payload, }, web_socket::{ - compression::NegotiatedCompression, fill_with_close_code, payload_ty::PayloadTy, - read_frame_info::ReadFrameInfo, unmask::unmask, web_socket_writer::manage_normal_frame, - CloseCode, Frame, FrameMut, OpCode, WebSocketError, MAX_CONTROL_PAYLOAD_LEN, - MAX_HEADER_LEN_USIZE, + compression::NegotiatedCompression, fill_with_close_code, read_frame_info::ReadFrameInfo, + unmask::unmask, web_socket_writer::manage_normal_frame, CloseCode, Frame, OpCode, + WebSocketError, MAX_CONTROL_PAYLOAD_LEN, MAX_HEADER_LEN_USIZE, }, }; const DECOMPRESSION_SUFFIX: [u8; 4] = [0, 0, 255, 255]; -type ReadContinuationFramesCb = ( - fn(&[u8]) -> crate::Result>, - fn(&[u8], &mut Option) -> crate::Result<()>, -); +#[inline] +pub(crate) fn copy_from_arbitrary_nb_to_rb1( + network_buffer: &mut PartitionedFilledBuffer, + no_masking: bool, + reader_buffer_first: &mut Vector, + rfi: &ReadFrameInfo, +) -> crate::Result<()> { + let current_mut = network_buffer._current_mut(); + unmask_nb::(current_mut, no_masking, rfi)?; + reader_buffer_first.extend_from_copyable_slice(current_mut)?; + Ok(()) +} + +#[inline] +pub(crate) fn copy_from_compressed_nb_to_rb1( + nc: &mut NC, + network_buffer: &mut PartitionedFilledBuffer, + no_masking: bool, + reader_buffer_first: &mut Vector, + rfi: &ReadFrameInfo, +) -> crate::Result<()> +where + NC: NegotiatedCompression, +{ + unmask_nb::(network_buffer._current_mut(), no_masking, rfi)?; + network_buffer._reserve(4)?; + let curr_end_idx = network_buffer._current().len(); + let curr_end_idx_p4 = curr_end_idx.wrapping_add(4); + let has_following = network_buffer._has_following(); + let input = network_buffer._current_rest_mut().get_mut(..curr_end_idx_p4).unwrap_or_default(); + let original = if let [.., a, b, c, d] = input { + let original = [*a, *b, *c, *d]; + *a = DECOMPRESSION_SUFFIX[0]; + *b = DECOMPRESSION_SUFFIX[1]; + *c = DECOMPRESSION_SUFFIX[2]; + *d = DECOMPRESSION_SUFFIX[3]; + original + } else { + [0, 0, 0, 0] + }; + let before = reader_buffer_first.len(); + let additional = input.len().saturating_mul(2); + let payload_len_rslt = nc.decompress( + input, + reader_buffer_first, + |local_rb| expand_rb(additional, local_rb, before), + |local_rb, written| expand_rb(additional, local_rb, before.wrapping_add(written)), + ); + if has_following { + if let [.., a, b, c, d] = input { + *a = original[0]; + *b = original[1]; + *c = original[2]; + *d = original[3]; + } + } + let payload_len = payload_len_rslt?; + reader_buffer_first.truncate(before.wrapping_add(payload_len)); + Ok(()) +} + +#[inline] +pub(crate) fn copy_from_compressed_rb1_to_rb2( + first_rfi: &ReadFrameInfo, + nc: &mut NC, + reader_buffer_first: &mut Vector, + reader_buffer_second: &mut Vector, +) -> crate::Result<()> +where + NC: NegotiatedCompression, +{ + reader_buffer_first.extend_from_copyable_slice(&DECOMPRESSION_SUFFIX)?; + let additional = reader_buffer_first.len().saturating_mul(2); + let payload_len = nc.decompress( + reader_buffer_first, + reader_buffer_second, + |local_rb| expand_rb(additional, local_rb, 0), + |local_rb, written| expand_rb(additional, local_rb, written), + )?; + reader_buffer_second.truncate(payload_len); + if matches!(first_rfi.op_code, OpCode::Text) && from_utf8_basic(reader_buffer_second).is_err() { + return Err(crate::Error::InvalidUTF8); + } + Ok(()) +} + +#[inline] +pub(crate) async fn fetch_frame_from_stream( + max_payload_len: usize, + (nc_is_noop, nc_rsv1): (bool, u8), + network_buffer: &mut PartitionedFilledBuffer, + no_masking: bool, + stream: &mut SR, +) -> crate::Result +where + SR: StreamReader, +{ + network_buffer._clear_if_following_is_empty(); + network_buffer._reserve(MAX_HEADER_LEN_USIZE)?; + let mut read = network_buffer._following_len(); + let rfi = ReadFrameInfo::from_stream::<_, IS_CLIENT>( + max_payload_len, + (nc_is_noop, nc_rsv1), + network_buffer, + no_masking, + &mut read, + stream, + ) + .await?; + let header_len = rfi.header_len.into(); + _read_payload((header_len, rfi.payload_len), network_buffer, &mut read, stream).await?; + Ok(rfi) +} /// If this method returns `false`, then a `ping` frame was received and the caller should fetch /// more external data in order to get the desired frame. @@ -216,196 +326,15 @@ pub(crate) fn unmask_nb( } #[inline] -pub(crate) async fn read_frame_from_stream<'nb, 'rb, 'rslt, NC, RNG, S, const IS_CLIENT: bool>( - connection_state: &mut ConnectionState, - max_payload_len: usize, - nc: &mut NC, - network_buffer: &'nb mut PartitionedFilledBuffer, - no_masking: bool, - reader_buffer_first: &'rb mut Vector, - reader_buffer_second: &'rb mut Vector, - rng: &mut RNG, - stream: &mut S, -) -> crate::Result<(FrameMut<'rslt, IS_CLIENT>, PayloadTy)> -where - 'nb: 'rslt, - 'rb: 'rslt, - NC: NegotiatedCompression, - RNG: Rng, - S: Stream, -{ - let first_rfi = loop { - reader_buffer_first.clear(); - let rfi = fetch_frame_from_stream::<_, _, IS_CLIENT>( - max_payload_len, - nc, - network_buffer, - no_masking, - stream, - ) - .await?; - if !rfi.fin { - break rfi; - } - let (payload, payload_ty) = if rfi.should_decompress { - copy_from_compressed_nb_to_rb1::( - nc, - network_buffer, - no_masking, - reader_buffer_first, - &rfi, - )?; - (reader_buffer_first.as_slice_mut(), PayloadTy::FirstReader) - } else { - let current_mut = network_buffer._current_mut(); - unmask_nb::(current_mut, no_masking, &rfi)?; - (current_mut, PayloadTy::Network) - }; - if manage_auto_reply::<_, _, IS_CLIENT>( - stream, - connection_state, - no_masking, - rfi.op_code, - payload, - rng, - &mut write_control_frame_cb, - ) - .await? - { - manage_op_code_of_first_final_frame(rfi.op_code, payload)?; - // FIXME(STABLE): Use `payload` with polonius - let borrow_checker = if rfi.should_decompress { - reader_buffer_first.as_slice_mut() - } else { - network_buffer._current_mut() - }; - return Ok((Frame::new(true, rfi.op_code, borrow_checker, nc.rsv1()), payload_ty)); - } - }; - reader_buffer_second.clear(); - if first_rfi.should_decompress { - read_continuation_frames::<_, _, _, IS_CLIENT>( - connection_state, - &first_rfi, - max_payload_len, - nc, - network_buffer, - no_masking, - reader_buffer_first, - rng, - stream, - (|_| Ok(None), |_, _| Ok(())), - ) - .await?; - copy_from_compressed_rb1_to_rb2(&first_rfi, nc, reader_buffer_first, reader_buffer_second)?; - Ok(( - Frame::new(true, first_rfi.op_code, reader_buffer_second, nc.rsv1()), - PayloadTy::SecondReader, - )) - } else { - read_continuation_frames::<_, _, _, IS_CLIENT>( - connection_state, - &first_rfi, - max_payload_len, - nc, - network_buffer, - no_masking, - reader_buffer_first, - rng, - stream, - (manage_text_of_first_continuation_frame, manage_text_of_recurrent_continuation_frames), - ) - .await?; - Ok(( - Frame::new(true, first_rfi.op_code, reader_buffer_first, nc.rsv1()), - PayloadTy::FirstReader, - )) - } -} - -#[inline] -fn copy_from_arbitrary_nb_to_rb1( - network_buffer: &mut PartitionedFilledBuffer, - no_masking: bool, - reader_buffer_first: &mut Vector, - rfi: &ReadFrameInfo, -) -> crate::Result<()> { - let current_mut = network_buffer._current_mut(); - unmask_nb::(current_mut, no_masking, rfi)?; - reader_buffer_first.extend_from_copyable_slice(current_mut)?; - Ok(()) -} - -#[inline] -fn copy_from_compressed_nb_to_rb1( - nc: &mut NC, - network_buffer: &mut PartitionedFilledBuffer, - no_masking: bool, - reader_buffer_first: &mut Vector, - rfi: &ReadFrameInfo, -) -> crate::Result<()> -where - NC: NegotiatedCompression, -{ - unmask_nb::(network_buffer._current_mut(), no_masking, rfi)?; - network_buffer._reserve(4)?; - let curr_end_idx = network_buffer._current().len(); - let curr_end_idx_p4 = curr_end_idx.wrapping_add(4); - let has_following = network_buffer._has_following(); - let input = network_buffer._current_rest_mut().get_mut(..curr_end_idx_p4).unwrap_or_default(); - let original = if let [.., a, b, c, d] = input { - let original = [*a, *b, *c, *d]; - *a = DECOMPRESSION_SUFFIX[0]; - *b = DECOMPRESSION_SUFFIX[1]; - *c = DECOMPRESSION_SUFFIX[2]; - *d = DECOMPRESSION_SUFFIX[3]; - original - } else { - [0, 0, 0, 0] - }; - let before = reader_buffer_first.len(); - let additional = input.len().saturating_mul(2); - let payload_len_rslt = nc.decompress( - input, - reader_buffer_first, - |local_rb| expand_rb(additional, local_rb, before), - |local_rb, written| expand_rb(additional, local_rb, before.wrapping_add(written)), - ); - if has_following { - if let [.., a, b, c, d] = input { - *a = original[0]; - *b = original[1]; - *c = original[2]; - *d = original[3]; - } - } - let payload_len = payload_len_rslt?; - reader_buffer_first.truncate(before.wrapping_add(payload_len)); - Ok(()) -} - -#[inline] -fn copy_from_compressed_rb1_to_rb2( - first_rfi: &ReadFrameInfo, - nc: &mut NC, - reader_buffer_first: &mut Vector, - reader_buffer_second: &mut Vector, +pub(crate) async fn write_control_frame_cb( + stream: &mut SW, + header: &[u8], + payload: &[u8], ) -> crate::Result<()> where - NC: NegotiatedCompression, + SW: StreamWriter, { - reader_buffer_first.extend_from_copyable_slice(&DECOMPRESSION_SUFFIX)?; - let additional = reader_buffer_first.len().saturating_mul(2); - let payload_len = nc.decompress( - reader_buffer_first, - reader_buffer_second, - |local_rb| expand_rb(additional, local_rb, 0), - |local_rb, written| expand_rb(additional, local_rb, written), - )?; - reader_buffer_second.truncate(payload_len); - if matches!(first_rfi.op_code, OpCode::Text) && from_utf8_basic(reader_buffer_second).is_err() { - return Err(crate::Error::InvalidUTF8); - } + stream.write_all_vectored(&[header, payload]).await?; Ok(()) } @@ -419,109 +348,6 @@ fn expand_rb( Ok(reader_buffer_first.get_mut(written..).unwrap_or_default()) } -#[inline] -async fn fetch_frame_from_stream( - max_payload_len: usize, - nc: &NC, - network_buffer: &mut PartitionedFilledBuffer, - no_masking: bool, - stream: &mut S, -) -> crate::Result -where - NC: NegotiatedCompression, - S: Stream, -{ - network_buffer._clear_if_following_is_empty(); - network_buffer._reserve(MAX_HEADER_LEN_USIZE)?; - let mut read = network_buffer._following_len(); - let rfi = ReadFrameInfo::from_stream::<_, _, IS_CLIENT>( - max_payload_len, - nc, - network_buffer, - no_masking, - &mut read, - stream, - ) - .await?; - let header_len = rfi.header_len.into(); - _read_payload((header_len, rfi.payload_len), network_buffer, &mut read, stream).await?; - Ok(rfi) -} - -#[inline] -async fn read_continuation_frames( - connection_state: &mut ConnectionState, - first_rfi: &ReadFrameInfo, - max_payload_len: usize, - nc: &mut NC, - network_buffer: &mut PartitionedFilledBuffer, - no_masking: bool, - reader_buffer_first: &mut Vector, - rng: &mut RNG, - stream: &mut S, - (first_text_cb, recurrent_text_cb): ReadContinuationFramesCb, -) -> crate::Result<()> -where - NC: NegotiatedCompression, - RNG: Rng, - S: Stream, -{ - copy_from_arbitrary_nb_to_rb1::( - network_buffer, - no_masking, - reader_buffer_first, - first_rfi, - )?; - let mut iuc = manage_op_code_of_first_continuation_frame( - first_rfi.op_code, - reader_buffer_first, - first_text_cb, - )?; - loop { - let mut rfi = fetch_frame_from_stream::<_, _, IS_CLIENT>( - max_payload_len, - nc, - network_buffer, - no_masking, - stream, - ) - .await?; - let begin = reader_buffer_first.len(); - rfi.should_decompress = first_rfi.should_decompress; - copy_from_arbitrary_nb_to_rb1::( - network_buffer, - no_masking, - reader_buffer_first, - &rfi, - )?; - let payload = reader_buffer_first.get_mut(begin..).unwrap_or_default(); - if !manage_auto_reply::<_, _, IS_CLIENT>( - stream, - connection_state, - no_masking, - rfi.op_code, - payload, - rng, - &mut write_control_frame_cb, - ) - .await? - { - reader_buffer_first.truncate(begin); - continue; - } - if manage_op_code_of_continuation_frames( - rfi.fin, - first_rfi.op_code, - &mut iuc, - rfi.op_code, - payload, - recurrent_text_cb, - )? { - return Ok(()); - } - } -} - #[inline] async fn write_control_frame( aux: &mut A, @@ -542,16 +368,3 @@ where wsc_cb.call((aux, frame.header(), frame.payload().lease())).await?; Ok(()) } - -#[inline] -async fn write_control_frame_cb( - stream: &mut S, - header: &[u8], - payload: &[u8], -) -> crate::Result<()> -where - S: Stream, -{ - stream.write_all_vectored(&[header, payload]).await?; - Ok(()) -} diff --git a/wtx/src/web_socket/web_socket_writer.rs b/wtx/src/web_socket/web_socket_writer.rs index 65eb7c44..4f68eff2 100644 --- a/wtx/src/web_socket/web_socket_writer.rs +++ b/wtx/src/web_socket/web_socket_writer.rs @@ -1,5 +1,7 @@ +// Common functions that used be used by pure WebSocket structures or tunneling protocols. + use crate::{ - misc::{BufferMode, ConnectionState, Lease, LeaseMut, Rng, Stream, Vector, Xorshift64}, + misc::{BufferMode, ConnectionState, Lease, LeaseMut, Rng, StreamWriter, Vector, Xorshift64}, web_socket::{ compression::NegotiatedCompression, misc::has_masked_frame, unmask::unmask, Frame, FrameMut, OpCode, @@ -65,19 +67,19 @@ pub(crate) fn manage_normal_frame( } #[inline] -pub(crate) async fn write_frame( +pub(crate) async fn write_frame( connection_state: &mut ConnectionState, frame: &mut Frame, no_masking: bool, nc: &mut NC, rng: &mut Xorshift64, - stream: &mut S, + stream: &mut SW, writer_buffer: &mut Vector, ) -> crate::Result<()> where NC: NegotiatedCompression, P: LeaseMut<[u8]>, - S: Stream, + SW: StreamWriter, { if manage_compression(frame, nc) { let fr = manage_frame_compression(connection_state, nc, frame, no_masking, rng, writer_buffer)?;