From c6247c5d47743c7089ff17deed1521138b00e04b Mon Sep 17 00:00:00 2001 From: Caio Date: Mon, 28 Oct 2024 22:55:27 -0300 Subject: [PATCH] Add the "no-masking" parameter --- CONTRIBUTING.md | 10 + Cargo.lock | 42 ++-- wtx-docs/src/web-socket/README.md | 4 + wtx-fuzz/web_socket.rs | 1 + .../generic-examples/client-api-framework.rs | 1 + .../generic-examples/web-socket-client.rs | 1 + .../http-server-framework-session.rs | 44 ++-- wtx-instances/http2-examples/http2-server.rs | 37 ++-- .../http2-examples/http2-web-socket.rs | 4 +- wtx-instances/src/bin/autobahn-client.rs | 3 + wtx-instances/src/bin/h2load.rs | 25 +-- wtx-instances/src/bin/h2spec-high-server.rs | 25 +-- wtx-ui/src/web_socket.rs | 2 + wtx/Cargo.toml | 2 +- wtx/src/database/client/postgres/tys.rs | 102 +++++++++ wtx/src/error.rs | 3 + wtx/src/grpc/grpc_manager.rs | 4 +- wtx/src/http.rs | 2 + wtx/src/http/header_name.rs | 6 +- wtx/src/http/headers.rs | 9 +- wtx/src/http/optioned_server/tokio_http2.rs | 204 ++++++++---------- .../http/optioned_server/tokio_web_socket.rs | 68 +++--- wtx/src/http/server_framework.rs | 34 ++- .../param_wrappers/serde_json.rs | 24 ++- wtx/src/http/server_framework/router.rs | 22 +- .../server_framework_builder.rs | 58 ++--- .../{req_aux.rs => stream_aux.rs} | 4 +- wtx/src/http/server_framework/tokio.rs | 56 ++--- wtx/src/http/stream_mode.rs | 45 ++++ wtx/src/http2.rs | 3 - wtx/src/http2/web_socket_over_stream.rs | 45 +++- wtx/src/misc.rs | 27 +++ wtx/src/misc/tuple_impls.rs | 6 +- wtx/src/pool/resource_manager.rs | 9 +- wtx/src/pool/simple_pool.rs | 2 +- wtx/src/web_socket.rs | 58 ++++- wtx/src/web_socket/compression/flate2.rs | 12 +- wtx/src/web_socket/frame.rs | 23 +- wtx/src/web_socket/handshake.rs | 176 +++++++++------ wtx/src/web_socket/handshake/tests.rs | 45 ++-- wtx/src/web_socket/misc.rs | 63 +----- wtx/src/web_socket/read_frame_info.rs | 66 ++++-- wtx/src/web_socket/web_socket_error.rs | 7 +- wtx/src/web_socket/web_socket_parts.rs | 17 +- wtx/src/web_socket/web_socket_reader.rs | 69 ++++-- wtx/src/web_socket/web_socket_writer.rs | 46 ++-- 46 files changed, 933 insertions(+), 583 deletions(-) create mode 100644 CONTRIBUTING.md rename wtx/src/http/server_framework/{req_aux.rs => stream_aux.rs} (75%) create mode 100644 wtx/src/http/stream_mode.rs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..6e6dbc0f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,10 @@ +# Contributing + +Before submitting a PR, you should probably run `./scripts/internal-tests-all.sh` and/or `./scripts/intergration-tests.sh` to make sure everything is fine. + +Integration tests interact with external programs like `podman` or require an internet connection, therefore, they usually aren't good candidates for offline development. On the other hand, internal tests are composed by unit tests, code formatting, `clippy` lints and fuzzing targets. + +## Building + +Taking aside common Rust tools that can be installed with `rustup` (https://rustup.rs/), at the current time it is only necessary to only have an C compiler to build the project. For example, you can use your favorite system package manager to install `gcc`. + diff --git a/Cargo.lock b/Cargo.lock index 1a9de3c1..75d9d5ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -363,15 +363,6 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" -[[package]] -name = "cmake" -version = "0.1.51" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" -dependencies = [ - "cc", -] - [[package]] name = "const-oid" version = "0.9.6" @@ -478,7 +469,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" dependencies = [ "crc32fast", - "libz-ng-sys", + "libz-rs-sys", "miniz_oxide", ] @@ -680,18 +671,17 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.8" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "a00419de735aac21d53b0de5ce2c03bd3627277cf471300f27ebc89f7d828047" [[package]] -name = "libz-ng-sys" -version = "1.1.20" +name = "libz-rs-sys" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f0f7295a34685977acb2e8cc8b08ee4a8dffd6cf278eeccddbe1ed55ba815d5" +checksum = "009b9249eef9fd7f6bbc96969f38de54a10f6be687f6d0a2ed98c4e4dcdc566f" dependencies = [ - "cmake", - "libc", + "zlib-rs", ] [[package]] @@ -1118,9 +1108,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustls" -version = "0.23.15" +version = "0.23.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" +checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" dependencies = [ "once_cell", "ring", @@ -1170,18 +1160,18 @@ checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" [[package]] name = "serde" -version = "1.0.213" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.213" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", @@ -1969,3 +1959,9 @@ dependencies = [ "quote", "syn 2.0.85", ] + +[[package]] +name = "zlib-rs" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b947c9af34afbf71a8ad64bedb8f3c26b562b1dad562218b265edd6f095731a" diff --git a/wtx-docs/src/web-socket/README.md b/wtx-docs/src/web-socket/README.md index 9787cbcd..99ccde66 100644 --- a/wtx-docs/src/web-socket/README.md +++ b/wtx-docs/src/web-socket/README.md @@ -17,6 +17,10 @@ The "permessage-deflate" extension is the only supported compression format and To get the most performance possible, try compiling your program with `RUSTFLAGS='-C target-cpu=native'` to allow `zlib-rs` to use more efficient SIMD instructions. +## No masking + +Although not officially endorsed, the `no-masking` parameter described at https://datatracker.ietf.org/doc/html/draft-damjanovic-websockets-nomasking-02 is supported to increase performance. If such a feature is not desirable, please make sure to check the handshake parameters to avoid accidental scenarios. + ## Client Example ```rust,edition2021,no_run diff --git a/wtx-fuzz/web_socket.rs b/wtx-fuzz/web_socket.rs index 0173ede0..8a5c8632 100644 --- a/wtx-fuzz/web_socket.rs +++ b/wtx-fuzz/web_socket.rs @@ -13,6 +13,7 @@ libfuzzer_sys::fuzz_target!(|data: (OpCode, Vec)| { Builder::new_current_thread().enable_all().build().unwrap().block_on(async move { let Ok(mut ws) = WebSocketServerOwned::new( (), + false, Xorshift64::from(simple_seed()), BytesStream::default(), WebSocketBuffer::default(), diff --git a/wtx-instances/generic-examples/client-api-framework.rs b/wtx-instances/generic-examples/client-api-framework.rs index d4f8c572..d92c10a0 100644 --- a/wtx-instances/generic-examples/client-api-framework.rs +++ b/wtx-instances/generic-examples/client-api-framework.rs @@ -106,6 +106,7 @@ async fn main() -> wtx::Result<()> { let web_socket = WebSocketClient::connect( (), [], + false, Xorshift64::from(simple_seed()), TcpStream::connect(uri.hostname_with_implied_port()).await?, &uri, diff --git a/wtx-instances/generic-examples/web-socket-client.rs b/wtx-instances/generic-examples/web-socket-client.rs index b8952118..31a55376 100644 --- a/wtx-instances/generic-examples/web-socket-client.rs +++ b/wtx-instances/generic-examples/web-socket-client.rs @@ -22,6 +22,7 @@ async fn main() -> wtx::Result<()> { let mut ws = WebSocketClient::connect( (), [], + false, Xorshift64::from(simple_seed()), TcpStream::connect(uri.hostname_with_implied_port()).await?, &uri.to_ref(), diff --git a/wtx-instances/http-server-framework-examples/http-server-framework-session.rs b/wtx-instances/http-server-framework-examples/http-server-framework-session.rs index 1e516099..a34f46dc 100644 --- a/wtx-instances/http-server-framework-examples/http-server-framework-session.rs +++ b/wtx-instances/http-server-framework-examples/http-server-framework-session.rs @@ -10,17 +10,16 @@ //! password BYTEA NOT NULL, //! salt BYTEA NOT NULL //! ); +//! ALTER TABLE "user" ADD CONSTRAINT user__email__uq UNIQUE (email); //! //! CREATE TABLE session ( //! id BYTEA NOT NULL PRIMARY KEY, //! user_id INT NOT NULL, //! expires_at TIMESTAMPTZ NOT NULL //! ); -//! //! ALTER TABLE session ADD CONSTRAINT session__user__fk FOREIGN KEY (user_id) REFERENCES "user" (id); //! ``` -use argon2::{Algorithm, Argon2, Block, Params, Version}; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; use tokio::net::TcpStream; use wtx::{ @@ -29,21 +28,10 @@ use wtx::{ server_framework::{get, post, Router, ServerFrameworkBuilder, State, StateClean}, ReqResBuffer, ReqResData, SessionDecoder, SessionEnforcer, SessionTokio, StatusCode, }, + misc::argon2_pwd, pool::{PostgresRM, SimplePoolTokio}, }; -const ARGON2_OUTPUT_LEN: usize = 32; -const ARGON2_PARAMS: Params = { - let Ok(elem) = Params::new( - Params::DEFAULT_M_COST, - Params::DEFAULT_T_COST, - Params::DEFAULT_P_COST, - Some(ARGON2_OUTPUT_LEN), - ) else { - panic!(); - }; - elem -}; type ConnAux = (Session, ChaCha20Rng); type Pool = SimplePoolTokio>; type Session = SessionTokio; @@ -71,12 +59,6 @@ async fn main() -> wtx::Result<()> { Ok(()) } -#[derive(Debug, serde::Deserialize)] -struct User<'req> { - email: &'req str, - password: &'req str, -} - #[inline] async fn login(state: State<'_, ConnAux, (), ReqResBuffer>) -> wtx::Result { let (session, rng) = state.ca; @@ -84,7 +66,7 @@ async fn login(state: State<'_, ConnAux, (), ReqResBuffer>) -> wtx::Result = serde_json::from_slice(state.req.rrd.body())?; + let user: UserLoginReq<'_> = serde_json::from_slice(state.req.rrd.body())?; let mut executor_guard = session.store.get().await?; let record = executor_guard .fetch_with_stmt("SELECT id,password,salt FROM user WHERE email = $1", (user.email,)) @@ -92,19 +74,14 @@ async fn login(state: State<'_, ConnAux, (), ReqResBuffer>) -> wtx::Result(0)?; let password_db = record.decode::<_, &[u8]>(1)?; let salt = record.decode::<_, &[u8]>(2)?; - let mut password_req = [0; ARGON2_OUTPUT_LEN]; - Argon2::new(Algorithm::Argon2id, Version::V0x13, ARGON2_PARAMS).hash_password_into_with_memory( - user.password.as_bytes(), - salt, - &mut password_req, - &mut [Block::new(); ARGON2_PARAMS.block_count()], - )?; + let password_req = argon2_pwd(user.password.as_bytes(), salt)?; state.req.rrd.clear(); if password_db != &password_req { return Ok(StatusCode::Unauthorized); } drop(executor_guard); session.set_session_cookie(id, rng, &mut state.req.rrd).await?; + serde_json::to_writer(&mut state.req.rrd.body, &UserLoginRes { id })?; Ok(StatusCode::Ok) } @@ -113,3 +90,14 @@ async fn logout(state: StateClean<'_, ConnAux, (), ReqResBuffer>) -> wtx::Result state.ca.0.delete_session_cookie(&mut state.req.rrd).await?; Ok(StatusCode::Ok) } + +#[derive(Debug, serde::Deserialize)] +struct UserLoginReq<'req> { + email: &'req str, + password: &'req str, +} + +#[derive(Debug, serde::Serialize)] +struct UserLoginRes { + id: u32, +} diff --git a/wtx-instances/http2-examples/http2-server.rs b/wtx-instances/http2-examples/http2-server.rs index e2375e5e..f6d2b2ac 100644 --- a/wtx-instances/http2-examples/http2-server.rs +++ b/wtx-instances/http2-examples/http2-server.rs @@ -1,4 +1,4 @@ -//! Serves requests using low-level HTTP/2 resources along side self-made certificates. +//! HTTP/2 server that uses optioned parameters. extern crate tokio; extern crate tokio_rustls; @@ -8,8 +8,11 @@ extern crate wtx_instances; use tokio::{io::WriteHalf, net::TcpStream}; use tokio_rustls::server::TlsStream; use wtx::{ - http::{Headers, OptionedServer, ReqResBuffer, Request, Response, StatusCode}, - http2::{Http2Buffer, Http2Params, ServerStreamTokio, WebSocketOverStream}, + http::{ + AutoStream, ManualServerStreamTokio, OptionedServer, ReqResBuffer, Response, StatusCode, + StreamMode, + }, + http2::{is_web_socket_handshake, Http2Buffer, Http2Params, WebSocketOverStream}, misc::{simple_seed, TokioRustlsAcceptor, Vector, Xorshift64}, web_socket::{Frame, OpCode}, }; @@ -31,6 +34,13 @@ async fn main() -> wtx::Result<()> { |error| eprintln!("{error}"), manual, || Ok((Vector::new(), ReqResBuffer::empty())), + |headers, method, protocol| { + Ok(if is_web_socket_handshake(headers, method, protocol) { + StreamMode::Manual + } else { + StreamMode::Auto + }) + }, ( || { TokioRustlsAcceptor::without_client_auth() @@ -44,25 +54,19 @@ async fn main() -> wtx::Result<()> { .await } -async fn auto( - _: (), - _: Vector, - mut req: Request, -) -> Result, wtx::Error> { - req.rrd.clear(); - Ok(req.into_response(StatusCode::Ok)) +async fn auto(mut ha: AutoStream<(), Vector>) -> Result, wtx::Error> { + ha.req.rrd.clear(); + Ok(ha.req.into_response(StatusCode::Ok)) } async fn manual( - _: (), - mut buffer: Vector, - _: Headers, - stream: ServerStreamTokio>, false>, + mut hm: ManualServerStreamTokio<(), Vector, Http2Buffer, WriteHalf>>, ) -> Result<(), wtx::Error> { let rng = Xorshift64::from(simple_seed()); - let mut wos = WebSocketOverStream::new(&Headers::new(), rng, stream).await?; + hm.headers.clear(); + let mut wos = WebSocketOverStream::new(&hm.headers, false, rng, hm.stream).await?; loop { - let mut frame = wos.read_frame(&mut buffer).await?; + let mut frame = wos.read_frame(&mut hm.sa).await?; match (frame.op_code(), frame.text_payload()) { (_, Some(elem)) => println!("{elem}"), (OpCode::Close, _) => break, @@ -70,5 +74,6 @@ async fn manual( } wos.write_frame(&mut Frame::new_fin(OpCode::Text, frame.payload_mut())).await?; } + wos.close().await?; Ok(()) } diff --git a/wtx-instances/http2-examples/http2-web-socket.rs b/wtx-instances/http2-examples/http2-web-socket.rs index a8342100..4af185e0 100644 --- a/wtx-instances/http2-examples/http2-web-socket.rs +++ b/wtx-instances/http2-examples/http2-web-socket.rs @@ -8,7 +8,6 @@ extern crate wtx; extern crate wtx_instances; use core::mem; - use tokio::net::TcpListener; use wtx::{ http::{Headers, ReqResBuffer}, @@ -47,7 +46,7 @@ async fn main() -> wtx::Result<()> { return Ok(()); }; let mut buffer = Vector::new(); - let mut wos = WebSocketOverStream::new(&Headers::new(), rng, &mut stream).await?; + let mut wos = WebSocketOverStream::new(&Headers::new(), false, rng, &mut stream).await?; loop { let mut frame = wos.read_frame(&mut buffer).await?; match (frame.op_code(), frame.text_payload()) { @@ -57,6 +56,7 @@ async fn main() -> wtx::Result<()> { } wos.write_frame(&mut Frame::new_fin(OpCode::Text, frame.payload_mut())).await?; } + wos.close().await?; stream.common().clear(false).await?; Ok(()) } diff --git a/wtx-instances/src/bin/autobahn-client.rs b/wtx-instances/src/bin/autobahn-client.rs index ef2f1ec3..97f5ea55 100644 --- a/wtx-instances/src/bin/autobahn-client.rs +++ b/wtx-instances/src/bin/autobahn-client.rs @@ -14,6 +14,7 @@ async fn main() -> wtx::Result<()> { let mut ws = WebSocketClient::connect( Flate2::default(), [], + false, Xorshift64::from(simple_seed()), TcpStream::connect(host).await?, &UriRef::new(&format!("http://{host}/runCase?case={case}&agent=wtx")), @@ -40,6 +41,7 @@ async fn main() -> wtx::Result<()> { WebSocketClient::connect( (), [], + false, Xorshift64::from(simple_seed()), TcpStream::connect(host).await?, &UriRef::new(&format!("http://{host}/updateReports?agent=wtx")), @@ -55,6 +57,7 @@ async fn get_case_count(host: &str, wsb: &mut WebSocketBuffer) -> wtx::Result wtx::Result<()> { |error| eprintln!("{error}"), manual, || Ok(((), ReqResBuffer::empty())), + |_, _, _| Ok(StreamMode::Auto), (|| Ok(()), |_| {}, |_, stream| async move { Ok(stream.into_split()) }), ) .await } -async fn auto( - _: (), - _: (), - mut req: Request, -) -> Result, wtx::Error> { - req.rrd.clear(); - Ok(req.into_response(StatusCode::Ok)) +async fn auto(mut ha: AutoStream<(), ()>) -> Result, wtx::Error> { + ha.req.rrd.clear(); + Ok(ha.req.into_response(StatusCode::Ok)) } async fn manual( - _: (), - _: (), - _: Headers, - _: ServerStreamTokio, + _: ManualServerStreamTokio<(), (), Http2Buffer, OwnedWriteHalf>, ) -> Result<(), wtx::Error> { Ok(()) } diff --git a/wtx-instances/src/bin/h2spec-high-server.rs b/wtx-instances/src/bin/h2spec-high-server.rs index ce8a4a86..e469dc04 100644 --- a/wtx-instances/src/bin/h2spec-high-server.rs +++ b/wtx-instances/src/bin/h2spec-high-server.rs @@ -4,8 +4,11 @@ use tokio::net::tcp::OwnedWriteHalf; use wtx::{ - http::{Headers, OptionedServer, ReqResBuffer, Request, Response, StatusCode}, - http2::{Http2Buffer, Http2Params, ServerStreamTokio}, + http::{ + AutoStream, ManualServerStreamTokio, OptionedServer, ReqResBuffer, Response, StatusCode, + StreamMode, + }, + http2::{Http2Buffer, Http2Params}, misc::{simple_seed, Xorshift64}, }; @@ -18,26 +21,20 @@ async fn main() -> wtx::Result<()> { |error| eprintln!("{error}"), manual, || Ok(((), ReqResBuffer::empty())), + |_, _, _| Ok(StreamMode::Auto), (|| Ok(()), |_| {}, |_, stream| async move { Ok(stream.into_split()) }), ) .await } -async fn auto( - _: (), - _: (), - mut req: Request, -) -> Result, wtx::Error> { - req.rrd.clear(); - req.rrd.body.extend_from_copyable_slice(b"Hello")?; - Ok(req.into_response(StatusCode::Ok)) +async fn auto(mut ha: AutoStream<(), ()>) -> Result, wtx::Error> { + ha.req.rrd.clear(); + ha.req.rrd.body.extend_from_copyable_slice(b"Hello")?; + Ok(ha.req.into_response(StatusCode::Ok)) } async fn manual( - _: (), - _: (), - _: Headers, - _: ServerStreamTokio, + _: ManualServerStreamTokio<(), (), Http2Buffer, OwnedWriteHalf>, ) -> Result<(), wtx::Error> { Ok(()) } diff --git a/wtx-ui/src/web_socket.rs b/wtx-ui/src/web_socket.rs index b31374ae..efb70791 100644 --- a/wtx-ui/src/web_socket.rs +++ b/wtx-ui/src/web_socket.rs @@ -13,6 +13,7 @@ pub(crate) async fn connect(uri: &str, cb: impl Fn(&str)) -> wtx::Result<()> { let mut ws = WebSocketClient::connect( (), [], + false, Xorshift64::from(simple_seed()), TcpStream::connect(uri.hostname_with_implied_port()).await?, &uri, @@ -55,6 +56,7 @@ pub(crate) async fn serve( let fun = async move { let mut ws = WebSocketServer::accept( (), + false, Xorshift64::from(simple_seed()), stream, WebSocketBuffer::default(), diff --git a/wtx/Cargo.toml b/wtx/Cargo.toml index a3b9615b..565c5588 100644 --- a/wtx/Cargo.toml +++ b/wtx/Cargo.toml @@ -9,7 +9,7 @@ cl-aux = { default-features = false, optional = true, features = ["alloc"], vers crypto-common = { default-features = false, optional = true, version = "0.1" } digest = { default-features = false, features = ["mac"], optional = true, version = "0.10" } fastrand = { default-features = false, optional = true, version = "2.0" } -flate2 = { default-features = false, features = ["zlib-ng"], optional = true, version = "1.0" } +flate2 = { default-features = false, features = ["zlib-rs"], optional = true, version = "1.0" } foldhash = { default-features = false, optional = true, version = "0.1" } hashbrown = { default-features = false, features = ["inline-more"], optional = true, version = "0.15" } hmac = { default-features = false, optional = true, version = "0.12" } diff --git a/wtx/src/database/client/postgres/tys.rs b/wtx/src/database/client/postgres/tys.rs index cd773348..84650eed 100644 --- a/wtx/src/database/client/postgres/tys.rs +++ b/wtx/src/database/client/postgres/tys.rs @@ -318,6 +318,108 @@ mod collections { proptest!(string, String); } +mod ip { + use crate::database::{ + client::postgres::{DecodeValue, EncodeValue, Postgres, Ty}, + Decode, Encode, Typed, + }; + use core::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + impl<'exec, E> Decode<'exec, Postgres> for IpAddr + where + E: From, + { + #[inline] + fn decode(dv: &DecodeValue<'exec>) -> Result { + Ok(match dv.bytes() { + [2, ..] => IpAddr::V4(Ipv4Addr::decode(dv)?), + [3, ..] => IpAddr::V6(Ipv6Addr::decode(dv)?), + _ => panic!(), + }) + } + } + impl Encode> for IpAddr + where + E: From, + { + #[inline] + fn encode(&self, ev: &mut EncodeValue<'_, '_>) -> Result<(), E> { + match self { + IpAddr::V4(ipv4_addr) => ipv4_addr.encode(ev), + IpAddr::V6(ipv6_addr) => ipv6_addr.encode(ev), + } + } + } + impl Typed> for IpAddr + where + E: From, + { + const TY: Ty = Ty::Inet; + } + test!(ipaddr_v4, IpAddr, IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))); + test!(ipaddr_v6, IpAddr, IpAddr::V6(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8))); + + impl<'exec, E> Decode<'exec, Postgres> for Ipv4Addr + where + E: From, + { + #[inline] + fn decode(dv: &DecodeValue<'exec>) -> Result { + let [2, 32, 0, 4, e, f, g, h] = dv.bytes() else { + panic!(); + }; + Ok(Ipv4Addr::from([*e, *f, *g, *h])) + } + } + impl Encode> for Ipv4Addr + where + E: From, + { + #[inline] + fn encode(&self, ev: &mut EncodeValue<'_, '_>) -> Result<(), E> { + ev.fbw()._extend_from_slices([&[2, 32, 0, 4][..], &self.octets()]).map_err(Into::into)?; + Ok(()) + } + } + impl Typed> for Ipv4Addr + where + E: From, + { + const TY: Ty = Ty::Inet; + } + test!(ipv4, Ipv4Addr, Ipv4Addr::new(1, 2, 3, 4)); + + impl<'exec, E> Decode<'exec, Postgres> for Ipv6Addr + where + E: From, + { + #[inline] + fn decode(dv: &DecodeValue<'exec>) -> Result { + let [3, 128, 0, 16, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t] = dv.bytes() else { + panic!(); + }; + Ok(Ipv6Addr::from([*e, *f, *g, *h, *i, *j, *k, *l, *m, *n, *o, *p, *q, *r, *s, *t])) + } + } + impl Encode> for Ipv6Addr + where + E: From, + { + #[inline] + fn encode(&self, ev: &mut EncodeValue<'_, '_>) -> Result<(), E> { + ev.fbw()._extend_from_slices([&[3, 128, 0, 16][..], &self.octets()]).map_err(Into::into)?; + Ok(()) + } + } + impl Typed> for Ipv6Addr + where + E: From, + { + const TY: Ty = Ty::Inet; + } + test!(ipv6, Ipv6Addr, Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8)); +} + mod pg_numeric { use crate::{ database::{ diff --git a/wtx/src/error.rs b/wtx/src/error.rs index 6935c653..d4d99e22 100644 --- a/wtx/src/error.rs +++ b/wtx/src/error.rs @@ -102,6 +102,9 @@ pub enum Error { NoInnerValue(&'static str), /// A set of arithmetic operations resulted in an overflow, underflow or division by zero OutOfBoundsArithmetic, + /// An error that shouldn't exist. If this variant is raised, then it is very likely that the + /// involved code was not built the way it should be. + ProgrammingError, /// Unexpected Unsigned integer UnboundedNumber { expected: RangeInclusive, diff --git a/wtx/src/grpc/grpc_manager.rs b/wtx/src/grpc/grpc_manager.rs index 2e52a7a5..488c833e 100644 --- a/wtx/src/grpc/grpc_manager.rs +++ b/wtx/src/grpc/grpc_manager.rs @@ -4,7 +4,7 @@ use crate::{ format::{VerbatimRequest, VerbatimResponse}, }, grpc::{serialize, GrpcStatusCode}, - http::{server_framework::ReqAux, ReqResBuffer}, + http::{server_framework::StreamAux, ReqResBuffer}, misc::Vector, }; @@ -50,7 +50,7 @@ impl GrpcManager { } } -impl ReqAux for GrpcManager +impl StreamAux for GrpcManager where DRSR: Default, { diff --git a/wtx/src/http.rs b/wtx/src/http.rs index d7493837..413b74fd 100644 --- a/wtx/src/http.rs +++ b/wtx/src/http.rs @@ -29,6 +29,7 @@ pub mod server_framework; #[cfg(feature = "http-session")] mod session; mod status_code; +mod stream_mode; mod version; #[cfg(feature = "http-session")] @@ -54,6 +55,7 @@ pub use response::Response; #[cfg(feature = "http-session")] pub use session::*; pub use status_code::StatusCode; +pub use stream_mode::*; pub use version::Version; pub(crate) const _MAX_AUTHORITY_LEN: usize = 64; diff --git a/wtx/src/http/header_name.rs b/wtx/src/http/header_name.rs index 868ad5ba..b36141b6 100644 --- a/wtx/src/http/header_name.rs +++ b/wtx/src/http/header_name.rs @@ -8,7 +8,7 @@ macro_rules! create_statics { )* ) => { /// A statically known set of header names - #[derive(Debug, Eq, PartialEq)] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum KnownHeaderName { $( $(#[$mac])* @@ -209,8 +209,10 @@ create_statics! { Referer = "referer"; Refresh = "refresh"; RetryAfter = "retry-after"; - SecWebsocketVersion = "sec-websocket-version"; + SecWebsocketAccept = "sec-websocket-accept"; + SecWebsocketExtensions = "sec-websocket-extensions"; SecWebsocketKey = "sec-websocket-key"; + SecWebsocketVersion = "sec-websocket-version"; Server = "server"; ServerTiming = "server-timing"; SetCookie = "set-cookie"; diff --git a/wtx/src/http/headers.rs b/wtx/src/http/headers.rs index 0b3a7784..71fb3c20 100644 --- a/wtx/src/http/headers.rs +++ b/wtx/src/http/headers.rs @@ -101,9 +101,12 @@ impl Headers { names: [&[u8]; N], ) -> [Option>; N] { let mut rslt = [None; N]; - for (header, value) in self.iter().zip(&mut rslt) { - if names.iter().any(|name| *name == header.name) { - *value = Some(header); + for header in self.iter() { + for (name, opt) in names.into_iter().zip(&mut rslt) { + if name == header.name { + *opt = Some(header); + break; + } } } rslt diff --git a/wtx/src/http/optioned_server/tokio_http2.rs b/wtx/src/http/optioned_server/tokio_http2.rs index 683fc2bd..c974f9b4 100644 --- a/wtx/src/http/optioned_server/tokio_http2.rs +++ b/wtx/src/http/optioned_server/tokio_http2.rs @@ -1,30 +1,38 @@ use crate::{ - http::{optioned_server::OptionedServer, Headers, ReqResBuffer, Request, Response}, - http2::{Http2Buffer, Http2ErrorCode, Http2Params, Http2Tokio, ServerStreamTokio}, + http::{ + optioned_server::OptionedServer, AutoStream, Headers, ManualServerStreamTokio, ManualStream, + Method, Protocol, ReqResBuffer, Response, StreamMode, + }, + http2::{Http2Buffer, Http2ErrorCode, Http2Params, Http2Tokio}, misc::{Either, FnFut, StreamReader, StreamWriter}, }; -use core::future::Future; +use core::{future::Future, mem}; use tokio::net::{TcpListener, TcpStream}; impl OptionedServer { /// Optioned HTTP/2 server using tokio. #[inline] - pub async fn tokio_high_http2( + pub async fn tokio_high_http2( addr: &str, auto_cb: A, conn_cb: impl Clone + Fn() -> crate::Result<(CA, Http2Buffer, Http2Params)> + Send + 'static, err_cb: impl Clone + Fn(E) + Send + 'static, manual_cb: M, stream_cb: impl Clone + Fn() -> crate::Result<(SA, ReqResBuffer)> + Send + 'static, - (acceptor_cb, local_acceptor_cb, net_cb): ( + stream_mode_cb: impl Clone + + Fn(&mut Headers, Method, Option) -> Result + + Send + + Sync + + 'static, + (acceptor_cb, conn_acceptor_cb, net_cb): ( impl FnOnce() -> crate::Result + Send + 'static, impl Clone + Fn(&ACPT) -> ACPT + Send + 'static, - impl Clone + Fn(ACPT, TcpStream) -> SF + Send + 'static, + impl Clone + Fn(ACPT, TcpStream) -> N + Send + 'static, ), ) -> crate::Result<()> where A: Clone - + FnFut<(CA, SA, Request), Result = Result, E>> + + FnFut<(AutoStream,), Result = Result, E>> + Send + 'static, A::Future: Send, @@ -32,12 +40,12 @@ impl OptionedServer { ACPT: Send + 'static, E: From + Send + 'static, M: Clone - + FnFut<(CA, SA, Headers, ServerStreamTokio), Result = Result<(), E>> + + FnFut<(ManualServerStreamTokio,), Result = Result<(), E>> + Send + 'static, M::Future: Send, + N: Future> + Send, SA: Send + 'static, - SF: Send + Future>, SR: Send + StreamReader + Unpin + 'static, SW: Send + StreamWriter + Unpin + 'static, for<'handle> &'handle A: Send, @@ -46,112 +54,84 @@ impl OptionedServer { let listener = TcpListener::bind(addr).await?; let acceptor = acceptor_cb()?; loop { - let (tcp_stream, _) = listener.accept().await?; - let local_acceptor = local_acceptor_cb(&acceptor); - let local_auto_cb = auto_cb.clone(); - let local_conn_cb = conn_cb.clone(); - let local_err_cb = err_cb.clone(); - let local_manual_cb = manual_cb.clone(); - let local_net_cb = net_cb.clone(); - let local_stream_cb = stream_cb.clone(); + let tcp_stream = listener.accept().await?.0; + let peer = tcp_stream.peer_addr()?.ip(); + let conn_acceptor = conn_acceptor_cb(&acceptor); + let conn_auto_cb = auto_cb.clone(); + let conn_conn_cb = conn_cb.clone(); + let conn_err_cb = err_cb.clone(); + let conn_manual_cb = manual_cb.clone(); + let conn_net_cb = net_cb.clone(); + let conn_stream_cb = stream_cb.clone(); + let conn_stream_mode_cb = stream_mode_cb.clone(); let _conn_jh = tokio::spawn(async move { - let local_local_err_cb = local_err_cb.clone(); - let fut = manage_conn( - local_acceptor, - local_auto_cb, - local_conn_cb, - local_err_cb, - local_manual_cb, - local_net_cb, - local_stream_cb, - tcp_stream, - ); - if let Err(err) = fut.await { - local_local_err_cb(E::from(err)); + let another_conn_err_cb = conn_err_cb.clone(); + let conn_fun = async move { + let (conn_ca, http2_buffer, http2_params) = conn_conn_cb()?; + let (frame_reader, mut http2) = Http2Tokio::accept( + http2_buffer, + http2_params, + conn_net_cb(conn_acceptor, tcp_stream).await?, + ) + .await?; + let _frame_reader_jh = tokio::spawn(frame_reader); + loop { + let (sa, rrb) = conn_stream_cb()?; + let (mut stream, headers_opt) = match http2 + .stream(rrb, |headers, method, protocol| { + Ok::<_, E>(match conn_stream_mode_cb(headers, method, protocol)? { + StreamMode::Auto => None, + StreamMode::Manual => Some(mem::take(headers)), + }) + }) + .await? + { + Either::Left(_) => return Ok(()), + Either::Right(elem) => elem, + }; + let stream_auto_cb = conn_auto_cb.clone(); + let stream_ca = conn_ca.clone(); + let stream_err_cb = conn_err_cb.clone(); + let stream_manual_cb = conn_manual_cb.clone(); + let _stream_jh = tokio::spawn(async move { + let stream_fun = async { + if let Some(headers) = headers_opt? { + stream_manual_cb + .call((ManualStream { + ca: stream_ca, + headers, + peer, + sa, + stream: stream.clone(), + },)) + .await?; + return Ok(()); + } + let (hrs, local_rrb) = stream.recv_req().await?; + if hrs.is_closed() { + return Ok(()); + } + let req = local_rrb.into_http2_request(stream.method()); + let _as = AutoStream { ca: stream_ca, peer, req, sa }; + let res = stream_auto_cb.call((_as,)).await?; + if stream.send_res(res).await?.is_closed() { + return Ok(()); + } + Ok::<_, E>(()) + }; + let stream_fun_rslt = stream_fun.await; + let _rslt = stream.common().clear(true).await; + if let Err(err) = stream_fun_rslt { + stream.common().send_go_away(Http2ErrorCode::InternalError).await; + stream_err_cb(err); + } + }); + } + }; + if let Err(err) = conn_fun.await { + another_conn_err_cb(E::from(err)); } }); } } } - -async fn manage_conn( - acceptor: ACPT, - auto_cb: A, - conn_cb: impl Clone + Fn() -> crate::Result<(CA, Http2Buffer, Http2Params)> + Send + 'static, - err_cb: impl Clone + Fn(E) + Send + 'static, - manual_cb: M, - net_cb: impl Clone + Fn(ACPT, TcpStream) -> SF + Send + 'static, - stream_cb: impl Clone + Fn() -> crate::Result<(SA, ReqResBuffer)> + Send + 'static, - tcp_stream: TcpStream, -) -> crate::Result<()> -where - A: Clone - + FnFut<(CA, SA, Request), Result = Result, E>> - + Send - + 'static, - A::Future: Send, - CA: Clone + Send + 'static, - E: From + Send + 'static, - M: Clone - + FnFut<(CA, SA, Headers, ServerStreamTokio), Result = Result<(), E>> - + Send - + 'static, - M::Future: Send, - SA: Send + 'static, - SF: Send + Future>, - SR: Send + StreamReader + Unpin + 'static, - SW: Send + StreamWriter + Unpin + 'static, - for<'handle> &'handle A: Send, - for<'handle> &'handle M: Send, -{ - let (ca, http2_buffer, http2_params) = conn_cb()?; - let net_tuple = net_cb(acceptor, tcp_stream).await?; - let accept_tuple = Http2Tokio::accept(http2_buffer, http2_params, net_tuple).await?; - let (frame_reader, mut http2) = accept_tuple; - let _jh = tokio::spawn(frame_reader); - loop { - let (ra, rrb) = stream_cb()?; - let (mut http2_stream, _headers_opt) = match http2 - .stream(rrb, |_headers, _method, _protocol| { - #[cfg(feature = "web-socket")] - { - let is_ws = crate::http2::is_web_socket_handshake(_headers, _method, _protocol); - is_ws.then(|| core::mem::take(_headers)) - } - }) - .await? - { - Either::Left(_) => return Ok(()), - Either::Right(elem) => elem, - }; - let local_auto_cb = auto_cb.clone(); - let local_ca = ca.clone(); - let local_err_cb = err_cb.clone(); - let _local_manual_cb = manual_cb.clone(); - let _stream_jh = tokio::spawn(async move { - let fun = async { - #[cfg(feature = "web-socket")] - if let Some(headers) = _headers_opt { - _local_manual_cb.call((local_ca, ra, headers, http2_stream.clone())).await?; - return Ok(()); - } - let (hrs, local_rrb) = http2_stream.recv_req().await?; - if hrs.is_closed() { - return Ok(()); - } - let req = local_rrb.into_http2_request(http2_stream.method()); - let res = local_auto_cb.call((local_ca, ra, req)).await?; - if http2_stream.send_res(res).await?.is_closed() { - return Ok(()); - } - Ok::<_, E>(()) - }; - let rslt = fun.await; - let _rslt = http2_stream.common().clear(true).await; - if let Err(err) = rslt { - http2_stream.common().send_go_away(Http2ErrorCode::InternalError).await; - local_err_cb(err); - } - }); - } -} diff --git a/wtx/src/http/optioned_server/tokio_web_socket.rs b/wtx/src/http/optioned_server/tokio_web_socket.rs index feea489d..f6cf83d2 100644 --- a/wtx/src/http/optioned_server/tokio_web_socket.rs +++ b/wtx/src/http/optioned_server/tokio_web_socket.rs @@ -1,29 +1,28 @@ use crate::{ http::OptionedServer, misc::{FnFut, Stream, Xorshift64, _number_or_available_parallelism, simple_seed}, - pool::{SimplePoolGetElem, SimplePoolResource, SimplePoolTokio, WebSocketRM}, + pool::{SimplePoolTokio, WebSocketRM}, web_socket::{Compression, WebSocketBuffer, WebSocketServer}, }; use core::{fmt::Debug, future::Future}; use std::sync::OnceLock; -use tokio::{ - net::{TcpListener, TcpStream}, - sync::MutexGuard, -}; +use tokio::net::{TcpListener, TcpStream}; + +static POOL: OnceLock> = OnceLock::new(); impl OptionedServer { /// Optioned WebSocket server using tokio. #[inline] - pub async fn tokio_web_socket( + pub async fn tokio_web_socket( addr: &str, buffers_len_opt: Option, compression_cb: impl Clone + Fn() -> C + Send + 'static, err_cb: impl Clone + Fn(E) + Send + 'static, - handle_cb: F, - (acceptor_cb, local_acceptor_cb, stream_cb): ( + handle_cb: H, + (acceptor_cb, conn_acceptor_cb, net_cb): ( impl FnOnce() -> crate::Result + Send + 'static, impl Clone + Fn(&ACPT) -> ACPT + Send + 'static, - impl Clone + Fn(ACPT, TcpStream) -> SF + Send + 'static, + impl Clone + Fn(ACPT, TcpStream) -> N + Send + 'static, ), ) -> crate::Result<()> where @@ -31,40 +30,45 @@ impl OptionedServer { C: Compression + Send + 'static, C::NegotiatedCompression: Send, E: Debug + From + Send + 'static, - for<'wsb> F: Clone + for<'wsb> H: Clone + FnFut< (WebSocketServer,), Result = Result<(), E>, > + Send + 'static, + N: Send + Future>, S: Stream + Send, - SF: Send + Future>, - for<'wsb> , )>>::Future: Send, - for<'handle> &'handle F: Send, + for<'handle> &'handle H: Send, { let buffers_len = _number_or_available_parallelism(buffers_len_opt)?; let listener = TcpListener::bind(addr).await?; let acceptor = acceptor_cb()?; loop { - let (tcp_stream, _) = listener.accept().await?; - let local_acceptor = local_acceptor_cb(&acceptor); - - let mut conn_buffer_guard = conn_buffer(buffers_len).await?; - let local_compression_cb = compression_cb.clone(); - let local_conn_err = err_cb.clone(); - let local_handle_cb = handle_cb.clone(); - let local_stream_cb = stream_cb.clone(); + let conn_acceptor = conn_acceptor_cb(&acceptor); + let conn_compression_cb = compression_cb.clone(); + let conn_conn_err = err_cb.clone(); + let conn_handle_cb = handle_cb.clone(); + let conn_net_cb = net_cb.clone(); + let tcp_stream = listener.accept().await?.0; + let mut conn_buffer = POOL + .get_or_init(|| { + SimplePoolTokio::new(buffers_len, WebSocketRM::new(|| Ok(Default::default()))) + }) + .get() + .await?; let _jh = tokio::spawn(async move { - let wsb = &mut ***conn_buffer_guard; + let wsb = &mut ***conn_buffer; let fun = async move { - let stream = local_stream_cb(local_acceptor, tcp_stream).await?; - local_handle_cb + let net = conn_net_cb(conn_acceptor, tcp_stream).await?; + conn_handle_cb .call((WebSocketServer::accept( - local_compression_cb(), + conn_compression_cb(), + true, Xorshift64::from(simple_seed()), - stream, + net, wsb, |_| crate::Result::Ok(()), ) @@ -73,19 +77,9 @@ impl OptionedServer { Ok::<_, E>(()) }; if let Err(err) = fun.await { - local_conn_err(err); + conn_conn_err(err); } }); } } } - -async fn conn_buffer( - len: usize, -) -> crate::Result>>> { - static POOL: OnceLock> = OnceLock::new(); - POOL - .get_or_init(|| SimplePoolTokio::new(len, WebSocketRM::new(|| Ok(Default::default())))) - .get() - .await -} diff --git a/wtx/src/http/server_framework.rs b/wtx/src/http/server_framework.rs index b288d7c8..ea579aca 100644 --- a/wtx/src/http/server_framework.rs +++ b/wtx/src/http/server_framework.rs @@ -11,16 +11,16 @@ mod param_wrappers; mod path_management; mod path_params; mod redirect; -mod req_aux; mod res_finalizer; mod route_wrappers; mod router; mod server_framework_builder; mod state; +mod stream_aux; #[cfg(feature = "nightly")] mod tokio; -use crate::http::{conn_params::ConnParams, ReqResBuffer, Request, Response}; +use crate::http::{conn_params::ConnParams, AutoStream, ReqResBuffer, Response}; use alloc::sync::Arc; pub use conn_aux::ConnAux; pub use cors_middleware::CorsMiddleware; @@ -30,42 +30,40 @@ pub use param_wrappers::*; pub use path_management::PathManagement; pub use path_params::PathParams; pub use redirect::Redirect; -pub use req_aux::ReqAux; pub use res_finalizer::ResFinalizer; pub use route_wrappers::{get, json, post, Get, Json, Post}; pub use router::Router; pub use server_framework_builder::ServerFrameworkBuilder; pub use state::{State, StateClean, StateGeneric}; +pub use stream_aux::StreamAux; /// Server #[derive(Debug)] -pub struct ServerFramework { +pub struct ServerFramework { _ca_cb: CAC, _cp: ConnParams, - _ra_cb: RAC, - _router: Arc>, + _sa_cb: SAC, + _router: Arc>, } -impl ServerFramework +impl ServerFramework where E: From, - P: PathManagement, - RA: ReqAux, - REQM: ReqMiddleware, - RESM: ResMiddleware, + P: PathManagement, + REQM: ReqMiddleware, + RESM: ResMiddleware, + SA: StreamAux, { async fn _auto( - mut ca: CA, - (ra_cb, router): (impl Fn() -> RA::Init, Arc>), - mut req: Request, + mut _as: AutoStream SA::Init, Arc>)>, ) -> Result, E> { - let mut ra = RA::req_aux(ra_cb(), &mut req)?; + let mut sa = SA::req_aux(_as.sa.0(), &mut _as.req)?; #[cfg(feature = "matchit")] - let num = router.router.at(req.rrd.uri.path()).map_err(From::from)?.value; + let num = _as.sa.1.router.at(_as.req.rrd.uri.path()).map_err(From::from)?.value; #[cfg(not(feature = "matchit"))] let num = &[]; - let status_code = router.manage_path(&mut ca, (0, num), &mut ra, &mut req).await?; - Ok(Response { rrd: req.rrd, status_code, version: req.version }) + let status_code = _as.sa.1.manage_path(&mut _as.ca, (0, num), &mut sa, &mut _as.req).await?; + Ok(Response { rrd: _as.req.rrd, status_code, version: _as.req.version }) } } diff --git a/wtx/src/http/server_framework/param_wrappers/serde_json.rs b/wtx/src/http/server_framework/param_wrappers/serde_json.rs index 87ef3073..7738e089 100644 --- a/wtx/src/http/server_framework/param_wrappers/serde_json.rs +++ b/wtx/src/http/server_framework/param_wrappers/serde_json.rs @@ -1,6 +1,6 @@ use crate::{ http::{ - server_framework::{Endpoint, ResFinalizer}, + server_framework::{Endpoint, ResFinalizer, StateGeneric}, Header, KnownHeaderName, Mime, ReqResBuffer, Request, StatusCode, }, misc::{serde_collect_seq_rslt, FnFut, FnFutWrapper, IterWrapper, LeaseMut}, @@ -35,6 +35,28 @@ where } } +impl Endpoint + for FnFutWrapper<(StateGeneric<'_, CA, RA, ReqResBuffer, CLEAN>, SerdeJson), F> +where + E: From, + F: for<'any> FnFut<(StateGeneric<'any, CA, RA, ReqResBuffer, CLEAN>, SerdeJson), Result = RES>, + RES: ResFinalizer, + T: DeserializeOwned, +{ + #[inline] + async fn call( + &self, + ca: &mut CA, + _: (u8, &[(&'static str, u8)]), + ra: &mut RA, + req: &mut Request, + ) -> Result { + let elem = serde_json::from_slice(&req.rrd.lease_mut().body).map_err(crate::Error::from)?; + req.rrd.lease_mut().clear(); + self.0.call((StateGeneric::new(ca, ra, req), SerdeJson(elem))).await.finalize_response(req) + } +} + impl ResFinalizer for SerdeJson where E: From, diff --git a/wtx/src/http/server_framework/router.rs b/wtx/src/http/server_framework/router.rs index db54e0c6..adc350a5 100644 --- a/wtx/src/http/server_framework/router.rs +++ b/wtx/src/http/server_framework/router.rs @@ -9,19 +9,19 @@ use core::marker::PhantomData; /// Redirects requests to specific asynchronous functions based on the set of inner URIs. #[derive(Debug)] -pub struct Router { +pub struct Router { pub(crate) paths: P, - pub(crate) phantom: PhantomData<(CA, E, RA)>, + pub(crate) phantom: PhantomData<(CA, E, SA)>, pub(crate) req_middlewares: REQM, pub(crate) res_middlewares: RESM, #[cfg(feature = "matchit")] pub(crate) router: matchit::Router>, } -impl Router +impl Router where E: From, - P: PathManagement, + P: PathManagement, { /// Creates a new instance with paths and middlewares. #[inline] @@ -54,10 +54,10 @@ where } } -impl Router +impl Router where E: From, - P: PathManagement, + P: PathManagement, { /// Creates a new instance of empty middlewares. #[inline] @@ -75,12 +75,12 @@ where } } -impl PathManagement for Router +impl PathManagement for Router where E: From, - P: PathManagement, - REQM: ReqMiddleware, - RESM: ResMiddleware, + P: PathManagement, + REQM: ReqMiddleware, + RESM: ResMiddleware, { const IS_ROUTER: bool = true; @@ -89,7 +89,7 @@ where &self, ca: &mut CA, path_defs: (u8, &[(&'static str, u8)]), - ra: &mut RA, + ra: &mut SA, req: &mut Request, ) -> Result { self.req_middlewares.apply_req_middleware(ca, ra, req).await?; diff --git a/wtx/src/http/server_framework/server_framework_builder.rs b/wtx/src/http/server_framework/server_framework_builder.rs index 877a1b18..3a740ca1 100644 --- a/wtx/src/http/server_framework/server_framework_builder.rs +++ b/wtx/src/http/server_framework/server_framework_builder.rs @@ -1,49 +1,49 @@ use crate::http::{ conn_params::ConnParams, - server_framework::{ConnAux, ReqAux, Router, ServerFramework}, + server_framework::{ConnAux, Router, ServerFramework, StreamAux}, }; use alloc::sync::Arc; /// Server #[derive(Debug)] -pub struct ServerFrameworkBuilder { +pub struct ServerFrameworkBuilder { cp: ConnParams, - router: Arc>, + router: Arc>, } -impl ServerFrameworkBuilder +impl ServerFrameworkBuilder where CA: ConnAux, - RA: ReqAux, + SA: StreamAux, { /// New instance with default connection values. #[inline] - pub fn new(router: Router) -> Self { + pub fn new(router: Router) -> Self { Self { cp: ConnParams::default(), router: Arc::new(router) } } - /// Sets the initialization structures for both `CA` and `RA`. + /// Sets the initialization structures for both `CA` and `SA`. #[inline] - pub fn with_aux( + pub fn with_aux( self, ca_cb: CAC, - ra_cb: RAC, - ) -> ServerFramework + ra_cb: SAC, + ) -> ServerFramework where CAC: Fn() -> CA::Init, - RAC: Fn() -> RA::Init, + SAC: Fn() -> SA::Init, { - ServerFramework { _ca_cb: ca_cb, _cp: self.cp, _ra_cb: ra_cb, _router: self.router } + ServerFramework { _ca_cb: ca_cb, _cp: self.cp, _sa_cb: ra_cb, _router: self.router } } /// Fills the initialization structures for all auxiliaries with default values. #[inline] pub fn with_dflt_aux( self, - ) -> ServerFramework CA::Init, E, P, RA, fn() -> RA::Init, REQM, RESM> + ) -> ServerFramework CA::Init, E, P, REQM, RESM, SA, fn() -> SA::Init> where CA::Init: Default, - RA::Init: Default, + SA::Init: Default, { fn fun() -> T where @@ -51,19 +51,19 @@ where { T::default() } - ServerFramework { _ca_cb: fun, _cp: self.cp, _ra_cb: fun, _router: self.router } + ServerFramework { _ca_cb: fun, _cp: self.cp, _sa_cb: fun, _router: self.router } } } -impl ServerFrameworkBuilder<(), E, P, (), REQM, RESM> { +impl ServerFrameworkBuilder<(), E, P, REQM, RESM, ()> { /// Build without state #[inline] - pub fn without_aux(self) -> ServerFramework<(), fn() -> (), E, P, (), fn() -> (), REQM, RESM> { - ServerFramework { _ca_cb: nothing, _cp: self.cp, _ra_cb: nothing, _router: self.router } + pub fn without_aux(self) -> ServerFramework<(), fn() -> (), E, P, REQM, RESM, (), fn() -> ()> { + ServerFramework { _ca_cb: nothing, _cp: self.cp, _sa_cb: nothing, _router: self.router } } } -impl ServerFrameworkBuilder +impl ServerFrameworkBuilder where CA: ConnAux, { @@ -72,28 +72,28 @@ where pub fn with_conn_aux( self, ca_cb: CAC, - ) -> ServerFramework (), REQM, RESM> + ) -> ServerFramework ()> where CAC: Fn() -> CA::Init, { - ServerFramework { _ca_cb: ca_cb, _cp: self.cp, _ra_cb: nothing, _router: self.router } + ServerFramework { _ca_cb: ca_cb, _cp: self.cp, _sa_cb: nothing, _router: self.router } } } -impl ServerFrameworkBuilder<(), E, P, RA, REQM, RESM> +impl ServerFrameworkBuilder<(), E, P, REQM, RESM, SA> where - RA: ReqAux, + SA: StreamAux, { - /// Sets the initializing strut for `RA` and sets the connection auxiliary to `()`. + /// Sets the initializing strut for `SA` and sets the connection auxiliary to `()`. #[inline] - pub fn with_req_aux( + pub fn with_req_aux( self, - ra_cb: RAC, - ) -> ServerFramework<(), fn() -> (), E, P, RA, RAC, REQM, RESM> + ra_cb: SAC, + ) -> ServerFramework<(), fn() -> (), E, P, REQM, RESM, SA, SAC> where - RAC: Fn() -> RA::Init, + SAC: Fn() -> SA::Init, { - ServerFramework { _ca_cb: nothing, _cp: self.cp, _ra_cb: ra_cb, _router: self.router } + ServerFramework { _ca_cb: nothing, _cp: self.cp, _sa_cb: ra_cb, _router: self.router } } } diff --git a/wtx/src/http/server_framework/req_aux.rs b/wtx/src/http/server_framework/stream_aux.rs similarity index 75% rename from wtx/src/http/server_framework/req_aux.rs rename to wtx/src/http/server_framework/stream_aux.rs index fd1a6e7a..3dfa910d 100644 --- a/wtx/src/http/server_framework/req_aux.rs +++ b/wtx/src/http/server_framework/stream_aux.rs @@ -1,7 +1,7 @@ use crate::http::{ReqResBuffer, Request}; -/// Auxiliary structure for requests -pub trait ReqAux: Sized { +/// Auxiliary structures for streams or requests. +pub trait StreamAux: Sized { /// Initialization type Init; diff --git a/wtx/src/http/server_framework/tokio.rs b/wtx/src/http/server_framework/tokio.rs index 6df08bbe..1e8795aa 100644 --- a/wtx/src/http/server_framework/tokio.rs +++ b/wtx/src/http/server_framework/tokio.rs @@ -1,30 +1,30 @@ use crate::{ http::{ server_framework::{ - ConnAux, PathManagement, ReqAux, ReqMiddleware, ResMiddleware, Router, ServerFramework, + ConnAux, PathManagement, ReqMiddleware, ResMiddleware, Router, ServerFramework, StreamAux, }, - Headers, OptionedServer, ReqResBuffer, + ManualServerStreamTokio, OptionedServer, ReqResBuffer, StreamMode, }, - http2::{Http2Buffer, ServerStreamTokio}, + http2::Http2Buffer, misc::Rng, }; use std::sync::Arc; -use tokio::net::{tcp::OwnedWriteHalf, TcpStream}; +use tokio::net::tcp::OwnedWriteHalf; -impl ServerFramework +impl ServerFramework where CA: Clone + ConnAux + Send + 'static, CAC: Clone + Fn() -> CA::Init + Send + 'static, E: From + Send + 'static, - P: PathManagement + Send + 'static, - RA: ReqAux + Send + 'static, - RAC: Clone + Fn() -> RA::Init + Send + 'static, - REQM: ReqMiddleware + Send + 'static, - RESM: ResMiddleware + Send + 'static, - Arc>: Send, - Router: Send, - for<'any> &'any Arc>: Send, - for<'any> &'any Router: Send, + P: PathManagement + Send + 'static, + REQM: ReqMiddleware + Send + 'static, + RESM: ResMiddleware + Send + 'static, + SA: StreamAux + Send + 'static, + SAC: Clone + Fn() -> SA::Init + Send + 'static, + Arc>: Send, + Router: Send, + for<'any> &'any Arc>: Send, + for<'any> &'any Router: Send, { /// Starts listening to incoming requests based on the given `host`. #[inline] @@ -37,14 +37,15 @@ where where RNG: Clone + Rng + Send + 'static, { - let Self { _ca_cb: ca_cb, _cp: cp, _ra_cb: ra_cb, _router: router } = self; + let Self { _ca_cb: ca_cb, _cp: cp, _sa_cb: sa_cb, _router: router } = self; OptionedServer::tokio_high_http2( host, Self::_auto, move || Ok((CA::conn_aux(ca_cb())?, Http2Buffer::new(rng.clone()), cp._to_hp())), err_cb, Self::manual_tokio, - move || Ok(((ra_cb.clone(), Arc::clone(&router)), ReqResBuffer::empty())), + move || Ok(((sa_cb.clone(), Arc::clone(&router)), ReqResBuffer::empty())), + |_, _, _| Ok(StreamMode::Auto), (|| Ok(()), |_| {}, |_, stream| async move { Ok(stream.into_split()) }), ) .await @@ -63,7 +64,7 @@ where where RNG: Clone + Rng + Send + 'static, { - let Self { _ca_cb: ca_cb, _cp: cp, _ra_cb: ra_cb, _router: router } = self; + let Self { _ca_cb: ca_cb, _cp: cp, _sa_cb: ra_cb, _router: router } = self; OptionedServer::tokio_high_http2( host, Self::_auto, @@ -71,6 +72,7 @@ where err_cb, Self::manual_tokio_rustls, move || Ok(((ra_cb.clone(), Arc::clone(&router)), ReqResBuffer::empty())), + |_, _, _| Ok(StreamMode::Auto), ( || { crate::misc::TokioRustlsAcceptor::without_client_auth() @@ -86,10 +88,12 @@ where #[inline] async fn manual_tokio( - _: CA, - _: (impl Fn() -> RA::Init, Arc>), - _: Headers, - _: ServerStreamTokio, + _: ManualServerStreamTokio< + CA, + (impl Fn() -> SA::Init, Arc>), + Http2Buffer, + OwnedWriteHalf, + >, ) -> Result<(), E> { Err(E::from(crate::Error::ClosedConnection)) } @@ -97,13 +101,11 @@ where #[cfg(feature = "tokio-rustls")] #[inline] async fn manual_tokio_rustls( - _: CA, - _: (impl Fn() -> RA::Init, Arc>), - _: Headers, - _: ServerStreamTokio< + _: ManualServerStreamTokio< + CA, + (impl Fn() -> SA::Init, Arc>), Http2Buffer, - tokio::io::WriteHalf>, - false, + tokio::io::WriteHalf>, >, ) -> Result<(), E> { Err(E::from(crate::Error::ClosedConnection)) diff --git a/wtx/src/http/stream_mode.rs b/wtx/src/http/stream_mode.rs new file mode 100644 index 00000000..ed79a3d5 --- /dev/null +++ b/wtx/src/http/stream_mode.rs @@ -0,0 +1,45 @@ +use crate::http::{Headers, ReqResBuffer, Request}; +use core::net::IpAddr; + +#[cfg(all(feature = "http2", feature = "tokio"))] +/// Manual server stream backed by tokio structures. +pub type ManualServerStreamTokio = + ManualStream>>; + +/// Tells how an HTTP stream should be handled. +#[derive(Debug)] +pub enum StreamMode { + /// Automatic + Auto, + /// Manual + Manual, +} + +/// HTTP stream that is automatically managed by the system. In other words, all frames +/// are gathered until an end-of-stream flag is received and only then a response is sent. +#[derive(Debug)] +pub struct AutoStream { + /// Connection auxiliary + pub ca: CA, + /// Remote peer address + pub peer: IpAddr, + /// Request + pub req: Request, + /// Stream auxiliary + pub sa: SA, +} + +/// HTTP stream that is manually managed by the user. For example, WebSockets over streams. +#[derive(Debug)] +pub struct ManualStream { + /// Connection auxiliary + pub ca: CA, + /// Headers + pub headers: Headers, + /// Remote peer address + pub peer: IpAddr, + /// Stream auxiliary + pub sa: SA, + /// Stream + pub stream: S, +} diff --git a/wtx/src/http2.rs b/wtx/src/http2.rs index b8094a5a..5b290f09 100644 --- a/wtx/src/http2.rs +++ b/wtx/src/http2.rs @@ -106,9 +106,6 @@ pub type Http2Tokio = pub type Http2DataTokio = Arc>>; /// [`ServerStream`] instance using the mutex from tokio; -#[cfg(feature = "tokio")] -pub type ServerStreamTokio = - ServerStream>; pub(crate) type Scrp = HashMap; pub(crate) type Sorp = HashMap; diff --git a/wtx/src/http2/web_socket_over_stream.rs b/wtx/src/http2/web_socket_over_stream.rs index d0cd7a73..06773a99 100644 --- a/wtx/src/http2/web_socket_over_stream.rs +++ b/wtx/src/http2/web_socket_over_stream.rs @@ -1,8 +1,8 @@ //! Tools to manage WebSocket connections in HTTP/2 streams use crate::{ - http::{Headers, Method, Protocol, StatusCode}, - http2::{Http2Buffer, Http2Data, Http2RecvStatus, SendDataMode, ServerStream}, + http::{Headers, KnownHeaderName, Method, Protocol, StatusCode}, + http2::{Http2Buffer, Http2Data, Http2ErrorCode, Http2RecvStatus, SendDataMode, ServerStream}, misc::{ ConnectionState, LeaseMut, Lock, RefCounter, SingleTypeStorage, StreamWriter, Vector, Xorshift64, @@ -14,7 +14,8 @@ use crate::{ manage_text_of_first_continuation_frame, manage_text_of_recurrent_continuation_frames, unmask_nb, }, - Frame, FrameMut, ReadFrameInfo, + web_socket_writer::manage_normal_frame, + Frame, FrameMut, OpCode, ReadFrameInfo, }, }; @@ -25,15 +26,17 @@ pub fn is_web_socket_handshake( method: Method, protocol: Option, ) -> bool { + let bytes = KnownHeaderName::SecWebsocketVersion.into(); method == Method::Connect && protocol == Some(Protocol::WebSocket) - && headers.get_by_name(b"sec-websocket-version").map(|el| el.value) == Some(b"13") + && headers.get_by_name(bytes).map(|el| el.value) == Some(b"13") } /// WebSocket tunneling #[derive(Debug)] pub struct WebSocketOverStream { connection_state: ConnectionState, + no_masking: bool, rng: Xorshift64, stream: S, } @@ -48,12 +51,25 @@ where { /// Creates a new instance sending an `Ok` status codes that confirms the WebSocket handshake. #[inline] - pub async fn new(headers: &Headers, rng: Xorshift64, mut stream: S) -> crate::Result { + pub async fn new( + headers: &Headers, + no_masking: bool, + rng: Xorshift64, + mut stream: S, + ) -> crate::Result { let hss = stream.lease_mut().common().send_headers(headers, false, StatusCode::Ok).await?; if hss.is_closed() { return Err(crate::Error::ClosedConnection); } - Ok(Self { connection_state: ConnectionState::Open, rng, stream }) + Ok(Self { connection_state: ConnectionState::Open, no_masking, rng, stream }) + } + + /// Closes the stream as well as the WebSocket connection. + #[inline] + pub async fn close(&mut self) -> crate::Result<()> { + self.write_frame(&mut Frame::new_fin(OpCode::Close, &mut [])).await?; + self.stream.lease_mut().common().send_reset(Http2ErrorCode::NoError).await; + Ok(()) } /// Reads a frame from the stream. @@ -67,7 +83,7 @@ where ) -> crate::Result> { buffer.clear(); let first_rfi = loop { - let (rfi, is_eos) = recv_data(buffer, self.stream.lease_mut()).await?; + let (rfi, is_eos) = recv_data(buffer, self.no_masking, self.stream.lease_mut()).await?; if !rfi.fin { if is_eos { return Err(crate::Error::ClosedConnection); @@ -77,6 +93,7 @@ where if manage_auto_reply::<_, _, false>( self.stream.lease_mut(), &mut self.connection_state, + self.no_masking, rfi.op_code, buffer, &mut self.rng, @@ -89,7 +106,7 @@ where } }; loop { - let (rfi, is_eos) = recv_data(buffer, self.stream.lease_mut()).await?; + let (rfi, is_eos) = recv_data(buffer, self.no_masking, self.stream.lease_mut()).await?; if !rfi.fin && is_eos { return Err(crate::Error::ClosedConnection); } @@ -103,6 +120,7 @@ where if !manage_auto_reply::<_, _, false>( self.stream.lease_mut(), &mut self.connection_state, + self.no_masking, rfi.op_code, payload, &mut self.rng, @@ -132,6 +150,12 @@ where where P: LeaseMut<[u8]>, { + manage_normal_frame::<_, _, false>( + &mut self.connection_state, + frame, + self.no_masking, + &mut self.rng, + ); let (header, payload) = frame.header_and_payload(); let hss = self .stream @@ -149,6 +173,7 @@ where #[inline] async fn recv_data<'buffer, HB, HD, SW>( buffer: &'buffer mut Vector, + no_masking: bool, stream: &mut ServerStream, ) -> crate::Result<(ReadFrameInfo, bool)> where @@ -168,10 +193,10 @@ where Http2RecvStatus::Ongoing(data) => (data, false), }; let mut slice = data.as_slice(); - let rfi = ReadFrameInfo::from_bytes::<_, false>(&mut slice, usize::MAX, &())?; + let rfi = ReadFrameInfo::from_bytes::<_, false>(&mut slice, usize::MAX, &(), no_masking)?; let before = buffer.len(); buffer.extend_from_copyable_slice(slice)?; - unmask_nb::(buffer.get_mut(before..).unwrap_or_default(), &rfi)?; + unmask_nb::(buffer.get_mut(before..).unwrap_or_default(), no_masking, &rfi)?; Ok((rfi, is_eos)) } diff --git a/wtx/src/misc.rs b/wtx/src/misc.rs index bc6b5e9c..7db9a1c6 100644 --- a/wtx/src/misc.rs +++ b/wtx/src/misc.rs @@ -82,6 +82,33 @@ pub(crate) use { uri::_EMPTY_URI_STRING, }; +/// Hashes a password using the `argon2` algorithm. +#[cfg(feature = "argon2")] +#[inline] +pub fn argon2_pwd(pwd: &[u8], salt: &[u8]) -> crate::Result<[u8; 32]> { + use argon2::{Algorithm, Argon2, Params, Version}; + const OUT_LEN: usize = 32; + const PARAMS: Params = { + let Ok(elem) = Params::new( + Params::DEFAULT_M_COST, + Params::DEFAULT_T_COST, + Params::DEFAULT_P_COST, + Some(OUT_LEN), + ) else { + panic!(); + }; + elem + }; + let mut out = [0; OUT_LEN]; + Argon2::new(Algorithm::Argon2id, Version::V0x13, PARAMS).hash_password_into_with_memory( + pwd, + salt, + &mut out, + &mut [argon2::Block::new(); PARAMS.block_count()], + )?; + Ok(out) +} + /// Useful when a request returns an optional field but the actual usage is within a /// [`core::result::Result`] context. #[inline] diff --git a/wtx/src/misc/tuple_impls.rs b/wtx/src/misc/tuple_impls.rs index 11c4c10e..63f6808b 100644 --- a/wtx/src/misc/tuple_impls.rs +++ b/wtx/src/misc/tuple_impls.rs @@ -47,7 +47,7 @@ macro_rules! impl_0_16 { use crate::{ http::{ HttpError, Request, ReqResBuffer, Response, StatusCode, - server_framework::{ConnAux, ReqAux, ReqMiddleware, ResMiddleware, PathManagement, PathParams} + server_framework::{ConnAux, StreamAux, ReqMiddleware, ResMiddleware, PathManagement, PathParams} }, misc::{ArrayVector, Vector} }; @@ -65,9 +65,9 @@ macro_rules! impl_0_16 { } } - impl<$($T,)*> ReqAux for ($($T,)*) + impl<$($T,)*> StreamAux for ($($T,)*) where - $($T: ReqAux,)* + $($T: StreamAux,)* { type Init = ($($T::Init,)*); diff --git a/wtx/src/pool/resource_manager.rs b/wtx/src/pool/resource_manager.rs index db078ed4..615744c3 100644 --- a/wtx/src/pool/resource_manager.rs +++ b/wtx/src/pool/resource_manager.rs @@ -180,9 +180,12 @@ pub(crate) mod database { let mut buffer = ExecutorBuffer::_empty(); mem::swap(&mut buffer, &mut resource.eb); *resource = executor!(&self.uri, |config, uri| { - let stream = - TcpStream::connect(uri.hostname_with_implied_port()).await.map_err(Into::into)?; - Executor::connect(&config, buffer, &mut &self.rng, stream) + Executor::connect( + &config, + buffer, + &mut &self.rng, + TcpStream::connect(uri.hostname_with_implied_port()).await.map_err(Into::into)?, + ) })?; Ok(()) } diff --git a/wtx/src/pool/simple_pool.rs b/wtx/src/pool/simple_pool.rs index d74a1cf0..d8f1a568 100644 --- a/wtx/src/pool/simple_pool.rs +++ b/wtx/src/pool/simple_pool.rs @@ -90,7 +90,7 @@ where } #[cfg(feature = "http-server-framework")] -impl crate::http::server_framework::ReqAux for SimplePool { +impl crate::http::server_framework::StreamAux for SimplePool { type Init = Self; #[inline] diff --git a/wtx/src/web_socket.rs b/wtx/src/web_socket.rs index 7fc4cd4d..69c896e7 100644 --- a/wtx/src/web_socket.rs +++ b/wtx/src/web_socket.rs @@ -15,7 +15,7 @@ mod web_socket_buffer; mod web_socket_error; mod web_socket_parts; pub(crate) mod web_socket_reader; -mod web_socket_writer; +pub(crate) mod web_socket_writer; use crate::{ misc::{ConnectionState, LeaseMut, Stream, Xorshift64}, @@ -36,8 +36,15 @@ pub use web_socket_buffer::WebSocketBuffer; pub use web_socket_error::WebSocketError; pub use web_socket_parts::{WebSocketCommonPart, WebSocketReaderPart, WebSocketWriterPart}; +const FIN_MASK: u8 = 0b1000_0000; +const MASK_MASK: u8 = 0b1000_0000; const MAX_CONTROL_PAYLOAD_LEN: usize = 125; const MAX_HEADER_LEN_USIZE: usize = 14; +const OP_CODE_MASK: u8 = 0b0000_1111; +const PAYLOAD_MASK: u8 = 0b0111_1111; +const RSV1_MASK: u8 = 0b0100_0000; +const RSV2_MASK: u8 = 0b0010_0000; +const RSV3_MASK: u8 = 0b0001_0000; /// Always masks the payload before sending. pub type WebSocketClient = WebSocket; @@ -61,6 +68,7 @@ pub struct WebSocket { curr_payload: PayloadTy, max_payload_len: usize, nc: NC, + no_masking: bool, rng: Xorshift64, stream: S, wsb: WSB, @@ -83,7 +91,13 @@ where { /// Creates a new instance from a stream that supposedly has already completed the handshake. #[inline] - pub fn new(nc: NC, rng: Xorshift64, stream: S, mut wsb: WSB) -> crate::Result { + pub fn new( + nc: NC, + no_masking: bool, + rng: Xorshift64, + stream: S, + mut wsb: WSB, + ) -> crate::Result { wsb.lease_mut().network_buffer._clear_if_following_is_empty(); wsb.lease_mut().network_buffer._reserve(MAX_HEADER_LEN_USIZE)?; Ok(Self { @@ -91,6 +105,7 @@ where curr_payload: PayloadTy::None, max_payload_len: _MAX_PAYLOAD_LEN, nc, + no_masking, rng, stream, wsb, @@ -118,7 +133,16 @@ where WebSocketReaderPart<'_, NC, S, IS_CLIENT>, WebSocketWriterPart<'_, NC, S, IS_CLIENT>, ) { - let WebSocket { connection_state, curr_payload, nc, rng, stream, wsb, max_payload_len } = self; + let WebSocket { + connection_state, + curr_payload, + nc, + no_masking, + rng, + stream, + wsb, + max_payload_len, + } = self; let WebSocketBuffer { writer_buffer, network_buffer, @@ -130,11 +154,12 @@ where WebSocketReaderPart { max_payload_len: *max_payload_len, network_buffer, + no_masking: *no_masking, phantom: PhantomData, reader_buffer_first, reader_buffer_second, }, - WebSocketWriterPart { phantom: PhantomData, writer_buffer }, + WebSocketWriterPart { no_masking: *no_masking, phantom: PhantomData, writer_buffer }, ) } @@ -144,7 +169,16 @@ where /// until all fragments are received. #[inline] pub async fn read_frame(&mut self) -> crate::Result> { - let WebSocket { connection_state, curr_payload, max_payload_len, nc, rng, stream, wsb } = self; + let WebSocket { + connection_state, + curr_payload, + max_payload_len, + nc, + no_masking, + rng, + stream, + wsb, + } = self; let WebSocketBuffer { network_buffer, reader_buffer_first, @@ -156,6 +190,7 @@ where *max_payload_len, nc, network_buffer, + *no_masking, reader_buffer_first, reader_buffer_second, rng, @@ -172,9 +207,18 @@ where where P: LeaseMut<[u8]>, { - let WebSocket { connection_state, nc, rng, stream, wsb, .. } = self; + let WebSocket { connection_state, nc, no_masking, rng, stream, wsb, .. } = self; let WebSocketBuffer { writer_buffer, .. } = wsb.lease_mut(); - web_socket_writer::write_frame(connection_state, frame, nc, rng, stream, writer_buffer).await?; + web_socket_writer::write_frame( + connection_state, + frame, + *no_masking, + nc, + rng, + stream, + writer_buffer, + ) + .await?; Ok(()) } } diff --git a/wtx/src/web_socket/compression/flate2.rs b/wtx/src/web_socket/compression/flate2.rs index 896c31a4..7abd51b9 100644 --- a/wtx/src/web_socket/compression/flate2.rs +++ b/wtx/src/web_socket/compression/flate2.rs @@ -1,10 +1,7 @@ use crate::{ - http::GenericHeader, + http::{GenericHeader, KnownHeaderName}, misc::{bytes_split1, FilledBufferWriter, FromRadix10, VectorError}, - web_socket::{ - compression::NegotiatedCompression, misc::_trim_bytes, Compression, DeflateConfig, - WebSocketError, - }, + web_socket::{compression::NegotiatedCompression, Compression, DeflateConfig, WebSocketError}, }; use flate2::{Compress, Decompress, FlushCompress, FlushDecompress}; @@ -37,7 +34,8 @@ impl Compression for Flate2 { let mut has_extension = false; - for swe in headers.filter(|el| el.name().eq_ignore_ascii_case(b"sec-websocket-extensions")) { + let swe_bytes = KnownHeaderName::SecWebsocketExtensions.into(); + for swe in headers.filter(|el| el.name().eq_ignore_ascii_case(swe_bytes)) { for permessage_deflate_option in bytes_split1(swe.value(), b',') { dc = DeflateConfig { client_max_window_bits: self.dc.client_max_window_bits, @@ -47,7 +45,7 @@ impl Compression for Flate2 { let mut client_max_window_bits_flag = false; let mut permessage_deflate_flag = false; let mut server_max_window_bits_flag = false; - for param in bytes_split1(permessage_deflate_option, b';').map(_trim_bytes) { + for param in bytes_split1(permessage_deflate_option, b';').map(<[u8]>::trim_ascii) { if param == b"client_no_context_takeover" || param == b"server_no_context_takeover" { } else if param == b"permessage-deflate" { _manage_header_uniqueness(&mut permessage_deflate_flag, || Ok(()))? diff --git a/wtx/src/web_socket/frame.rs b/wtx/src/web_socket/frame.rs index 192f0413..70c8020a 100644 --- a/wtx/src/web_socket/frame.rs +++ b/wtx/src/web_socket/frame.rs @@ -1,7 +1,8 @@ use crate::{ misc::{Lease, Vector}, web_socket::{ - misc::fill_header_from_params, OpCode, MAX_CONTROL_PAYLOAD_LEN, MAX_HEADER_LEN_USIZE, + misc::{fill_header_from_params, has_masked_frame}, + OpCode, MASK_MASK, MAX_CONTROL_PAYLOAD_LEN, MAX_HEADER_LEN_USIZE, }, }; use core::str; @@ -76,16 +77,26 @@ impl Frame { (header, &mut self.payload) } - #[inline] - pub(crate) fn header_mut(&mut self) -> &mut [u8] { - self.header_and_payload_mut().0 - } - #[inline] pub(crate) fn header_first_two_mut(&mut self) -> [&mut u8; 2] { let [a, b, ..] = &mut self.header; [a, b] } + + #[inline] + pub(crate) fn set_mask(&mut self, mask: [u8; 4]) { + if has_masked_frame(self.header[1]) { + return; + } + self.header_len = self.header_len.wrapping_add(4); + if let Some([_, a, .., b, c, d, e]) = self.header.get_mut(..self.header_len.into()) { + *a |= MASK_MASK; + *b = mask[0]; + *c = mask[1]; + *d = mask[2]; + *e = mask[3]; + } + } } impl Frame diff --git a/wtx/src/web_socket/handshake.rs b/wtx/src/web_socket/handshake.rs index 696344bd..51f7a0b6 100644 --- a/wtx/src/web_socket/handshake.rs +++ b/wtx/src/web_socket/handshake.rs @@ -1,24 +1,43 @@ #[cfg(all(feature = "_async-tests", test))] mod tests; +macro_rules! check_headers { + ($headers:expr, $($header:expr),*) => {{ + let rslt = check_headers( + [ + (KnownHeaderName::Connection, Some(b"upgrade")), + (KnownHeaderName::Upgrade, Some(b"websocket")), + $($header,)* + ], + $headers + )?; + drop(check_header_value(rslt[0])); + drop(check_header_value(rslt[1])); + rslt + }}; +} + use crate::{ http::{GenericHeader as _, GenericRequest as _, HttpError, KnownHeaderName, Method}, misc::{ bytes_split1, FilledBufferWriter, LeaseMut, Rng, Stream, UriRef, VectorError, Xorshift64, }, web_socket::{ - compression::NegotiatedCompression, misc::_trim_bytes, Compression, WebSocketBuffer, - WebSocketClient, WebSocketError, WebSocketServer, + compression::NegotiatedCompression, Compression, WebSocket, WebSocketBuffer, WebSocketError, }, }; use base64::{engine::general_purpose::STANDARD, Engine}; use httparse::{Header, Request, Response, Status, EMPTY_HEADER}; use sha1::{Digest, Sha1}; -const MAX_READ_LEN: usize = 2 * 1024; const MAX_READ_HEADER_LEN: usize = 64; +const MAX_READ_LEN: usize = 2 * 1024; +const NO_MASKING: &str = "no-masking"; +const UPGRADE: &str = "Upgrade"; +const VERSION: &str = "13"; +const WEBSOCKET: &str = "websocket"; -impl WebSocketServer +impl WebSocket where NC: NegotiatedCompression, S: Stream, @@ -28,6 +47,7 @@ where #[inline] pub async fn accept( compression: C, + mut no_masking: bool, rng: Xorshift64, mut stream: S, mut wsb: WSB, @@ -53,40 +73,37 @@ where match req.parse(nb._following()).map_err(From::from)? { Status::Complete(_) => { req_cb(&req)?; - if !_trim_bytes(req.method()).eq_ignore_ascii_case(b"get") { + if !req.method().trim_ascii().eq_ignore_ascii_case(b"get") { return Err( crate::Error::from(HttpError::UnexpectedHttpMethod { expected: Method::Get }).into(), ); } - verify_common_header(req.headers)?; - if !has_header_key_and_value(req.headers, b"sec-websocket-version", b"13") { - let expected = KnownHeaderName::SecWebsocketVersion; - return Err(crate::Error::from(HttpError::MissingHeader(expected)).into()); - }; - let Some(key) = req.headers.iter().find_map(|el| { - (el.name().eq_ignore_ascii_case(b"sec-websocket-key")).then_some(el.value()) - }) else { - return Err( - crate::Error::from(HttpError::MissingHeader(KnownHeaderName::SecWebsocketKey)).into(), - ); - }; - let compression = compression.negotiate(req.headers.iter())?; let mut key_buffer = [0; 30]; + let [_, _, c, d, e] = check_headers!( + req.headers, + (KnownHeaderName::SecWebsocketExtensions, None), + (KnownHeaderName::SecWebsocketKey, None), + (KnownHeaderName::SecWebsocketVersion, Some(VERSION.as_bytes())) + ); + no_masking &= check_header_value(c).map_or(false, has_no_masking); + let key = check_header_value(d)?; + let _ = check_header_value(e)?; + let nc = compression.negotiate(req.headers.iter())?; let swa = derived_key(&mut key_buffer, key); let mut headers_buffer = [EMPTY_HEADER; 3]; - headers_buffer[0] = Header { name: "Connection", value: b"Upgrade" }; + headers_buffer[0] = Header { name: "Connection", value: UPGRADE.as_bytes() }; headers_buffer[1] = Header { name: "Sec-WebSocket-Accept", value: swa }; - headers_buffer[2] = Header { name: "Upgrade", value: b"websocket" }; + headers_buffer[2] = Header { name: "Upgrade", value: WEBSOCKET.as_bytes() }; let mut res = Response::new(&mut headers_buffer); res.code = Some(101); res.version = Some(req.version().into()); { let mut fbw = nb.into(); - build_res(&compression, &mut fbw, res.headers).map_err(From::from)?; + build_res(&mut fbw, res.headers, &nc, no_masking).map_err(From::from)?; stream.write_all(fbw._curr_bytes()).await?; } nb._clear(); - return WebSocketServer::new(compression, rng, stream, wsb).map_err(From::from); + return WebSocket::new(nc, no_masking, rng, stream, wsb).map_err(From::from); } Status::Partial => {} } @@ -94,7 +111,7 @@ where } } -impl WebSocketClient +impl WebSocket where NC: NegotiatedCompression, S: Stream, @@ -105,12 +122,13 @@ where pub async fn connect<'headers, C, E>( compression: C, headers: impl IntoIterator, + mut no_masking: bool, mut rng: Xorshift64, mut stream: S, uri: &UriRef<'_>, mut wsb: WSB, res_cb: impl FnOnce(&Response<'_, '_>) -> Result<(), E>, - ) -> Result, E> + ) -> Result, E> where C: Compression, E: From, @@ -122,14 +140,14 @@ where nb._reserve(MAX_READ_LEN).map_err(From::from)?; { let fbw = &mut nb.into(); - let key = - build_req(&compression, fbw, headers, key_buffer, &mut rng, uri).map_err(From::from)?; + let key_rslt = build_req(&compression, fbw, headers, key_buffer, no_masking, &mut rng, uri); + let key = key_rslt.map_err(From::from)?; stream.write_all(fbw._curr_bytes()).await?; key } }; let mut read = 0; - let (compression, len) = loop { + let (nc, len) = loop { let nb = &mut wsb.lease_mut().network_buffer; let local_read = stream.read(nb._buffer_mut().get_mut(read..).unwrap_or_default()).await?; if local_read == 0 { @@ -146,23 +164,21 @@ where return Err(crate::Error::from(WebSocketError::MissingSwitchingProtocols).into()); } res_cb(&res)?; - verify_common_header(res.headers)?; - if !has_header_key_and_value( + let [_, _, c, d] = check_headers!( res.headers, - b"sec-websocket-accept", - derived_key(&mut [0; 30], key), - ) { - return Err( - crate::Error::from(HttpError::MissingHeader(KnownHeaderName::SecWebsocketKey)).into(), - ); - } + (KnownHeaderName::SecWebsocketAccept, Some(derived_key(&mut [0; 30], key))), + (KnownHeaderName::SecWebsocketExtensions, None) + ); + drop(check_header_value(c)); + no_masking &= check_header_value(d).map_or(false, has_no_masking); break (compression.negotiate(res.headers.iter())?, len); }; wsb.lease_mut().network_buffer._set_indices(0, len, read.wrapping_sub(len))?; - Ok(WebSocketClient::new(compression, rng, stream, wsb)?) + Ok(WebSocket::new(nc, no_masking, rng, stream, wsb)?) } } +#[inline] fn base64_from_array<'output, const I: usize, const O: usize>( input: &[u8; I], output: &'output mut [u8; O], @@ -176,11 +192,13 @@ fn base64_from_array<'output, const I: usize, const O: usize>( } /// Client request +#[inline] fn build_req<'headers, 'kb, C>( compression: &C, fbw: &mut FilledBufferWriter<'_>, headers: impl IntoIterator, key_buffer: &'kb mut [u8; 26], + no_masking: bool, rng: &mut impl Rng, uri: &UriRef<'_>, ) -> Result<&'kb [u8], VectorError> @@ -203,6 +221,9 @@ where } _ => fbw._extend_from_slices_group_rn(&[b"Host: ", uri.host().as_bytes()])?, } + if no_masking { + fbw._extend_from_slice_rn(b"Sec-WebSocket-Extensions: no-masking")?; + } fbw._extend_from_slices_group_rn(&[b"Sec-WebSocket-Key: ", key])?; fbw._extend_from_slice_rn(b"Sec-WebSocket-Version: 13")?; fbw._extend_from_slice_rn(b"Upgrade: websocket")?; @@ -212,23 +233,72 @@ where } /// Server response -fn build_res( - compression: &C, +#[inline] +fn build_res( fbw: &mut FilledBufferWriter<'_>, headers: &[Header<'_>], + nc: &NC, + no_masking: bool, ) -> Result<(), VectorError> where - C: NegotiatedCompression, + NC: NegotiatedCompression, { fbw._extend_from_slice_rn(b"HTTP/1.1 101 Switching Protocols")?; for header in headers { fbw._extend_from_slices_group_rn(&[header.name(), b": ", header.value()])?; } - compression.write_res_headers(fbw)?; + if no_masking { + fbw._extend_from_slices_group_rn(&[ + KnownHeaderName::SecWebsocketExtensions.into(), + b": ", + NO_MASKING.as_bytes(), + ])?; + } + nc.write_res_headers(fbw)?; fbw._extend_from_slice_rn(b"")?; Ok(()) } +#[inline] +fn check_header_value((name, value): (KnownHeaderName, Option<&[u8]>)) -> crate::Result<&[u8]> { + let Some(elem) = value else { + return Err(crate::Error::from(HttpError::MissingHeader(name)).into()); + }; + Ok(elem) +} + +#[inline] +fn check_headers<'header, 'headers, const N: usize>( + array: [(KnownHeaderName, Option<&[u8]>); N], + headers: &'headers [Header<'header>], +) -> crate::Result<[(KnownHeaderName, Option<&'headers [u8]>); N]> { + let mut rslt = [(KnownHeaderName::Accept, None); N]; + for header in headers { + let trimmed_name = header.name().trim_ascii(); + let trimmed_value = header.value().trim_ascii(); + for ((name, value_opt), rslt_elem) in array.into_iter().zip(&mut rslt) { + let has_name = rslt_elem.1.is_none() && trimmed_name.eq_ignore_ascii_case(name.into()); + if has_name { + if let Some(value) = value_opt { + for sub_value in bytes_split1(trimmed_value, b',') { + if sub_value.trim_ascii().eq_ignore_ascii_case(value) { + *rslt_elem = (name, Some(sub_value)); + break; + } + } + if rslt_elem.1.is_some() { + break; + } + } else { + *rslt_elem = (name, Some(trimmed_value)); + } + } + } + } + Ok(rslt) +} + +#[inline] fn derived_key<'buffer>(buffer: &'buffer mut [u8; 30], key: &[u8]) -> &'buffer [u8] { let mut sha1 = Sha1::new(); sha1.update(key); @@ -236,28 +306,12 @@ fn derived_key<'buffer>(buffer: &'buffer mut [u8; 30], key: &[u8]) -> &'buffer [ base64_from_array(&sha1.finalize().into(), buffer) } +#[inline] fn gen_key<'buffer>(buffer: &'buffer mut [u8; 26], rng: &mut impl Rng) -> &'buffer [u8] { base64_from_array(&rng.u8_16(), buffer) } -fn has_header_key_and_value(headers: &[Header<'_>], key: &[u8], value: &[u8]) -> bool { - headers - .iter() - .find_map(|h| { - let has_key = _trim_bytes(h.name()).eq_ignore_ascii_case(key); - let has_value = - bytes_split1(h.value(), b',').any(|el| _trim_bytes(el).eq_ignore_ascii_case(value)); - (has_key && has_value).then_some(true) - }) - .unwrap_or(false) -} - -fn verify_common_header(buffer: &[Header<'_>]) -> crate::Result<()> { - if !has_header_key_and_value(buffer, b"connection", b"upgrade") { - return Err(HttpError::MissingHeader(KnownHeaderName::Connection).into()); - } - if !has_header_key_and_value(buffer, b"upgrade", b"websocket") { - return Err(HttpError::MissingHeader(KnownHeaderName::Upgrade).into()); - } - Ok(()) +#[inline] +fn has_no_masking(el: &[u8]) -> bool { + el.eq_ignore_ascii_case(NO_MASKING.as_bytes()) } diff --git a/wtx/src/web_socket/handshake/tests.rs b/wtx/src/web_socket/handshake/tests.rs index 339bcc3a..184d8d0c 100644 --- a/wtx/src/web_socket/handshake/tests.rs +++ b/wtx/src/web_socket/handshake/tests.rs @@ -26,26 +26,35 @@ static HAS_SERVER_FINISHED: AtomicBool = AtomicBool::new(false); #[cfg(feature = "flate2")] #[tokio::test] -async fn client_and_server_compressed() { +async fn compressed() { use crate::web_socket::compression::Flate2; #[cfg(feature = "_tracing-tree")] let _rslt = crate::misc::tracing_tree_init(None); - do_test_client_and_server_frames((), Flate2::default()).await; + do_test_client_and_server_frames(((), false), (Flate2::default(), false)).await; tokio::time::sleep(Duration::from_millis(200)).await; - do_test_client_and_server_frames(Flate2::default(), ()).await; + do_test_client_and_server_frames((Flate2::default(), false), ((), false)).await; tokio::time::sleep(Duration::from_millis(200)).await; - do_test_client_and_server_frames(Flate2::default(), Flate2::default()).await; + do_test_client_and_server_frames((Flate2::default(), false), (Flate2::default(), false)).await; } #[tokio::test] -async fn client_and_server_uncompressed() { +async fn uncompressed() { #[cfg(feature = "_tracing-tree")] let _rslt = crate::misc::tracing_tree_init(None); - do_test_client_and_server_frames((), ()).await; + do_test_client_and_server_frames(((), false), ((), false)).await; } -async fn do_test_client_and_server_frames(client_compression: CC, server_compression: SC) -where +#[tokio::test] +async fn uncompressed_no_masking() { + #[cfg(feature = "_tracing-tree")] + let _rslt = crate::misc::tracing_tree_init(None); + do_test_client_and_server_frames(((), true), ((), true)).await; +} + +async fn do_test_client_and_server_frames( + (client_compression, client_no_masking): (CC, bool), + (server_compression, server_no_masking): (SC, bool), +) where CC: Compression + Send, CC::NegotiatedCompression: Send, SC: Compression + Send + 'static, @@ -59,6 +68,7 @@ where let (stream, _) = listener.accept().await.unwrap(); let mut ws = WebSocketServer::accept( server_compression, + server_no_masking, Xorshift64::from(simple_seed()), stream, WebSocketBuffer::new(), @@ -83,6 +93,7 @@ where let mut ws = WebSocketClient::connect( client_compression, [], + client_no_masking, Xorshift64::from(simple_seed()), TcpStream::connect(uri.hostname_with_implied_port()).await.unwrap(), &uri.to_ref(), @@ -149,23 +160,13 @@ where let hello = ws.read_frame().await.unwrap(); assert_eq!(OpCode::Text, hello.op_code()); assert_eq!(b"Hello!", hello.payload()); - ws.write_frame(&mut Frame::new_fin( - OpCode::Text, - &mut [b'G', b'o', b'o', b'd', b'b', b'y', b'e', b'!'], - )) - .await - .unwrap(); + ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"Goodbye!")).await.unwrap(); assert_eq!(OpCode::Close, ws.read_frame().await.unwrap().op_code()); } async fn server(ws: &mut WebSocketServerOwned) { - ws.write_frame(&mut Frame::new_fin(OpCode::Text, &mut [b'H', b'e', b'l', b'l', b'o', b'!'])) - .await - .unwrap(); - assert_eq!( - ws.read_frame().await.unwrap().payload(), - &mut [b'G', b'o', b'o', b'd', b'b', b'y', b'e', b'!'] - ); + ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"Hello!")).await.unwrap(); + assert_eq!(ws.read_frame().await.unwrap().payload(), b"Goodbye!"); ws.write_frame(&mut Frame::new_fin(OpCode::Close, &mut [])).await.unwrap(); } } @@ -203,7 +204,7 @@ where { async fn client(ws: &mut WebSocketClientOwned) { ws.write_frame(&mut Frame::new_fin(OpCode::Ping, &mut [1, 2, 3])).await.unwrap(); - ws.write_frame(&mut Frame::new_fin(OpCode::Text, &mut [b'i', b'p', b'a', b't'])).await.unwrap(); + ws.write_frame(&mut Frame::new_fin(OpCode::Text, *b"ipat")).await.unwrap(); assert_eq!(OpCode::Pong, ws.read_frame().await.unwrap().op_code()); } diff --git a/wtx/src/web_socket/misc.rs b/wtx/src/web_socket/misc.rs index b3065eeb..9babfd30 100644 --- a/wtx/src/web_socket/misc.rs +++ b/wtx/src/web_socket/misc.rs @@ -1,4 +1,4 @@ -use crate::web_socket::{CloseCode, OpCode, MAX_HEADER_LEN_USIZE}; +use crate::web_socket::{CloseCode, OpCode, MASK_MASK, MAX_HEADER_LEN_USIZE, OP_CODE_MASK}; use core::ops::Range; /// The first two bytes of `payload` are filled with `code`. Does nothing if `payload` is @@ -26,43 +26,26 @@ pub(crate) fn fill_header_from_params( u8::from(fin) << 7 | rsv1 | u8::from(op_code) } - #[inline] - fn manage_mask( - second_byte: &mut u8, - [a, b, c, d]: [&mut u8; 4], - ) -> u8 { - if IS_CLIENT { - *second_byte &= 0b0111_1111; - *a = 0; - *b = 0; - *c = 0; - *d = 0; - N.wrapping_add(4) - } else { - N - } - } - match payload_len { 0..=125 => { - let [a, b, c, d, e, f, ..] = header; + let [a, b, ..] = header; *a = first_header_byte(fin, op_code, rsv1); *b = u8::try_from(payload_len).unwrap_or_default(); - manage_mask::(b, [c, d, e, f]) + 2 } 126..=0xFFFF => { let [len_c, len_d] = u16::try_from(payload_len).map(u16::to_be_bytes).unwrap_or_default(); - let [a, b, c, d, e, f, g, h, ..] = header; + let [a, b, c, d, ..] = header; *a = first_header_byte(fin, op_code, rsv1); *b = 126; *c = len_c; *d = len_d; - manage_mask::(b, [e, f, g, h]) + 4 } _ => { let len = u64::try_from(payload_len).map(u64::to_be_bytes).unwrap_or_default(); let [len_c, len_d, len_e, len_f, len_g, len_h, len_i, len_j] = len; - let [a, b, c, d, e, f, g, h, i, j, k, l, m, n] = header; + let [a, b, c, d, e, f, g, h, i, j, ..] = header; *a = first_header_byte(fin, op_code, rsv1); *b = 127; *c = len_c; @@ -73,19 +56,19 @@ pub(crate) fn fill_header_from_params( *h = len_h; *i = len_i; *j = len_j; - manage_mask::(b, [k, l, m, n]) + 10 } } } #[inline] -pub(crate) fn op_code(first_header_byte: u8) -> crate::Result { - OpCode::try_from(first_header_byte & 0b0000_1111) +pub(crate) const fn has_masked_frame(second_header_byte: u8) -> bool { + second_header_byte & MASK_MASK != 0 } #[inline] -pub(crate) fn _trim_bytes(bytes: &[u8]) -> &[u8] { - _trim_bytes_end(_trim_bytes_begin(bytes)) +pub(crate) fn op_code(first_header_byte: u8) -> crate::Result { + OpCode::try_from(first_header_byte & OP_CODE_MASK) } #[inline] @@ -94,27 +77,3 @@ pub(crate) fn _truncated_slice(slice: &[T], range: Range) -> &[T] { let end = range.end.min(slice.len()); slice.get(start..end).unwrap_or_default() } - -#[inline] -fn _trim_bytes_begin(mut bytes: &[u8]) -> &[u8] { - while let [first, rest @ ..] = bytes { - if first.is_ascii_whitespace() { - bytes = rest; - } else { - break; - } - } - bytes -} - -#[inline] -fn _trim_bytes_end(mut bytes: &[u8]) -> &[u8] { - while let [rest @ .., last] = bytes { - if last.is_ascii_whitespace() { - bytes = rest; - } else { - break; - } - } - bytes -} diff --git a/wtx/src/web_socket/read_frame_info.rs b/wtx/src/web_socket/read_frame_info.rs index e52d1f6b..c5283adf 100644 --- a/wtx/src/web_socket/read_frame_info.rs +++ b/wtx/src/web_socket/read_frame_info.rs @@ -1,8 +1,10 @@ use crate::{ misc::{PartitionedFilledBuffer, Stream, _read_until}, web_socket::{ - compression::NegotiatedCompression, misc::op_code, OpCode, WebSocketError, - MAX_CONTROL_PAYLOAD_LEN, + compression::NegotiatedCompression, + misc::{has_masked_frame, op_code}, + OpCode, WebSocketError, FIN_MASK, MAX_CONTROL_PAYLOAD_LEN, PAYLOAD_MASK, RSV1_MASK, RSV2_MASK, + RSV3_MASK, }, }; @@ -24,6 +26,7 @@ impl ReadFrameInfo { bytes: &mut &[u8], max_payload_len: usize, nc: &NC, + no_masking: bool, ) -> crate::Result where NC: NegotiatedCompression, @@ -36,7 +39,7 @@ impl ReadFrameInfo { [*a, *b] }; let tuple = Self::manage_first_two_bytes(first_two, nc)?; - let (fin, length_code, op_code, should_decompress) = tuple; + let (fin, length_code, masked, op_code, should_decompress) = tuple; let (mut header_len, payload_len) = match length_code { 126 => { let [a, b, rest @ ..] = bytes else { @@ -54,15 +57,16 @@ impl ReadFrameInfo { } _ => (2, length_code.into()), }; - let mut mask = None; - if !IS_CLIENT { + let mask = if Self::manage_mask::(masked, no_masking)? { let [a, b, c, d, rest @ ..] = bytes else { return Err(crate::Error::UnexpectedBufferState); }; *bytes = rest; - mask = Some([*a, *b, *c, *d]); header_len = header_len.wrapping_add(4); - } + Some([*a, *b, *c, *d]) + } else { + None + }; Self::manage_final_params(fin, op_code, max_payload_len, payload_len)?; Ok(ReadFrameInfo { fin, header_len, mask, op_code, payload_len, should_decompress }) } @@ -72,6 +76,7 @@ impl ReadFrameInfo { max_payload_len: usize, nc: &NC, network_buffer: &mut PartitionedFilledBuffer, + no_masking: bool, read: &mut usize, stream: &mut S, ) -> crate::Result @@ -82,7 +87,7 @@ impl ReadFrameInfo { let buffer = network_buffer._following_rest_mut(); let first_two = _read_until::<2, S>(buffer, read, 0, stream).await?; let tuple = Self::manage_first_two_bytes(first_two, nc)?; - let (fin, length_code, op_code, should_decompress) = tuple; + let (fin, length_code, masked, op_code, should_decompress) = tuple; let (mut header_len, payload_len) = match length_code { 126 => { let payload_len = _read_until::<2, S>(buffer, read, 2, stream).await?; @@ -94,11 +99,13 @@ impl ReadFrameInfo { } _ => (2, length_code.into()), }; - let mut mask = None; - if !IS_CLIENT { - mask = Some(_read_until::<4, S>(buffer, read, header_len.into(), stream).await?); + let mask = if Self::manage_mask::(masked, no_masking)? { + let rslt = _read_until::<4, S>(buffer, read, header_len.into(), stream).await?; header_len = header_len.wrapping_add(4); - } + Some(rslt) + } else { + None + }; Self::manage_final_params(fin, op_code, max_payload_len, payload_len)?; Ok(ReadFrameInfo { fin, header_len, mask, op_code, payload_len, should_decompress }) } @@ -123,13 +130,16 @@ impl ReadFrameInfo { } #[inline] - fn manage_first_two_bytes([a, b]: [u8; 2], nc: &NC) -> crate::Result<(bool, u8, OpCode, bool)> + fn manage_first_two_bytes( + [a, b]: [u8; 2], + nc: &NC, + ) -> crate::Result<(bool, u8, bool, OpCode, bool)> where NC: NegotiatedCompression, { - let rsv1 = a & 0b0100_0000; - let rsv2 = a & 0b0010_0000; - let rsv3 = a & 0b0001_0000; + let rsv1 = a & RSV1_MASK; + let rsv2 = a & RSV2_MASK; + let rsv3 = a & RSV3_MASK; if rsv2 != 0 || rsv3 != 0 { return Err(WebSocketError::InvalidCompressionHeaderParameter.into()); } @@ -143,9 +153,27 @@ impl ReadFrameInfo { } else { rsv1 != 0 }; - let fin = a & 0b1000_0000 != 0; - let length_code = b & 0b0111_1111; + let fin = a & FIN_MASK != 0; + let length_code = b & PAYLOAD_MASK; + let masked = has_masked_frame(b); let op_code = op_code(a)?; - Ok((fin, length_code, op_code, should_decompress)) + Ok((fin, length_code, masked, op_code, should_decompress)) + } + + #[inline] + fn manage_mask(masked: bool, no_masking: bool) -> crate::Result { + Ok(if IS_CLIENT { + false + } else if no_masking { + if masked { + return Err(WebSocketError::InvalidMaskBit.into()); + } + false + } else { + if !masked { + return Err(WebSocketError::InvalidMaskBit.into()); + } + true + }) } } diff --git a/wtx/src/web_socket/web_socket_error.rs b/wtx/src/web_socket/web_socket_error.rs index 1468b676..968154c6 100644 --- a/wtx/src/web_socket/web_socket_error.rs +++ b/wtx/src/web_socket/web_socket_error.rs @@ -11,6 +11,8 @@ pub enum WebSocketError { InvalidCompressionHeaderParameter, /// Header indices are out-of-bounds or the number of bytes are too small. InvalidFrameHeaderBounds, + /// The client sent an invalid mask bit. + InvalidMaskBit, /// Payload indices are out-of-bounds or the number of bytes are too small. InvalidPayloadBounds, /// Server received a frame without a mask. @@ -25,9 +27,8 @@ pub enum WebSocketError { ReservedBitsAreNotZero, /// Received control frame wasn't supposed to be fragmented. UnexpectedFragmentedControlFrame, - /// The first frame of a message is a continuation or the following frames are not a - /// continuation. - UnexpectedMessageFrame, + /// For example, the first frame of a message is a continuation. + UnexpectedFrame, /// Control frames have a maximum allowed size. VeryLargeControlFrame, /// Frame payload exceeds the defined threshold. diff --git a/wtx/src/web_socket/web_socket_parts.rs b/wtx/src/web_socket/web_socket_parts.rs index 29193045..11ff5cec 100644 --- a/wtx/src/web_socket/web_socket_parts.rs +++ b/wtx/src/web_socket/web_socket_parts.rs @@ -23,6 +23,7 @@ pub struct WebSocketCommonPart<'instance, NC, S, const IS_CLIENT: bool> { pub struct WebSocketReaderPart<'instance, NC, S, const IS_CLIENT: bool> { pub(crate) max_payload_len: usize, pub(crate) network_buffer: &'instance mut PartitionedFilledBuffer, + pub(crate) no_masking: bool, pub(crate) phantom: PhantomData<(NC, S)>, pub(crate) reader_buffer_first: &'instance mut Vector, pub(crate) reader_buffer_second: &'instance mut Vector, @@ -46,6 +47,7 @@ where let Self { max_payload_len, network_buffer, + no_masking, phantom: _, reader_buffer_first, reader_buffer_second, @@ -55,6 +57,7 @@ where *max_payload_len, nc, network_buffer, + *no_masking, reader_buffer_first, reader_buffer_second, rng, @@ -70,6 +73,7 @@ where /// to the same instance. #[derive(Debug)] pub struct WebSocketWriterPart<'instance, NC, S, const IS_CLIENT: bool> { + pub(crate) no_masking: bool, pub(crate) phantom: PhantomData<(NC, S)>, pub(crate) writer_buffer: &'instance mut Vector, } @@ -90,8 +94,17 @@ where P: LeaseMut<[u8]>, { let WebSocketCommonPart { connection_state, curr_payload: _, nc, rng, stream } = common; - let Self { phantom: _, writer_buffer } = self; - web_socket_writer::write_frame(connection_state, frame, nc, rng, stream, writer_buffer).await?; + let Self { no_masking, phantom: _, writer_buffer } = self; + web_socket_writer::write_frame( + connection_state, + frame, + *no_masking, + nc, + rng, + stream, + writer_buffer, + ) + .await?; Ok(()) } } diff --git a/wtx/src/web_socket/web_socket_reader.rs b/wtx/src/web_socket/web_socket_reader.rs index 32e07c13..05c6fc99 100644 --- a/wtx/src/web_socket/web_socket_reader.rs +++ b/wtx/src/web_socket/web_socket_reader.rs @@ -28,6 +28,7 @@ type ReadContinuationFramesCb = ( pub(crate) async fn manage_auto_reply( aux: &mut A, connection_state: &mut ConnectionState, + no_masking: bool, op_code: OpCode, payload: &mut [u8], rng: &mut RNG, @@ -58,6 +59,7 @@ where aux, connection_state, &mut Frame::new_fin(OpCode::Close, payload_ret), + no_masking, rng, write_control_frame_cb, ) @@ -70,6 +72,7 @@ where aux, connection_state, &mut Frame::new_fin(OpCode::Close, payload), + no_masking, rng, write_control_frame_cb, ) @@ -81,6 +84,7 @@ where aux, connection_state, &mut Frame::new_fin(OpCode::Pong, payload), + no_masking, rng, write_control_frame_cb, ) @@ -111,7 +115,7 @@ pub(crate) fn manage_op_code_of_continuation_frames( } } OpCode::Binary | OpCode::Close | OpCode::Ping | OpCode::Pong | OpCode::Text => { - return Err(WebSocketError::UnexpectedMessageFrame.into()); + return Err(WebSocketError::UnexpectedFrame.into()); } } Ok(false) @@ -127,7 +131,7 @@ pub(crate) fn manage_op_code_of_first_continuation_frame( OpCode::Binary => Ok(None), OpCode::Text => cb(payload), OpCode::Close | OpCode::Continuation | OpCode::Ping | OpCode::Pong => { - Err(WebSocketError::UnexpectedMessageFrame.into()) + Err(WebSocketError::UnexpectedFrame.into()) } } } @@ -142,7 +146,7 @@ pub(crate) fn manage_op_code_of_first_final_frame( return Ok(()); } OpCode::Continuation => { - return Err(WebSocketError::UnexpectedMessageFrame.into()); + return Err(WebSocketError::UnexpectedFrame.into()); } OpCode::Text => { let _str_validation = from_utf8_basic(payload)?; @@ -200,9 +204,10 @@ pub(crate) fn manage_text_of_recurrent_continuation_frames( #[inline] pub(crate) fn unmask_nb( network_buffer: &mut [u8], + no_masking: bool, rfi: &ReadFrameInfo, ) -> crate::Result<()> { - if !IS_CLIENT { + if !IS_CLIENT && !no_masking { unmask(network_buffer, rfi.mask.ok_or(WebSocketError::MissingFrameMask)?); } Ok(()) @@ -214,6 +219,7 @@ pub(crate) async fn read_frame_from_stream<'nb, 'rb, 'rslt, NC, RNG, S, const IS max_payload_len: usize, nc: &mut NC, network_buffer: &'nb mut PartitionedFilledBuffer, + no_masking: bool, reader_buffer_first: &'rb mut Vector, reader_buffer_second: &'rb mut Vector, rng: &mut RNG, @@ -230,9 +236,14 @@ where let first_rfi = loop { network_buffer._clear_if_following_is_empty(); reader_buffer_first.clear(); - let rfi = - fetch_frame_from_stream::<_, _, IS_CLIENT>(max_payload_len, nc, network_buffer, stream) - .await?; + let rfi = fetch_frame_from_stream::<_, _, IS_CLIENT>( + max_payload_len, + nc, + network_buffer, + no_masking, + stream, + ) + .await?; if !rfi.fin { break rfi; } @@ -240,18 +251,20 @@ where copy_from_compressed_nb_to_rb1::( nc, network_buffer, + no_masking, reader_buffer_first, &rfi, )?; (reader_buffer_first.as_slice_mut(), PayloadTy::FirstReader) } else { let current_mut = network_buffer._current_mut(); - unmask_nb::(current_mut, &rfi)?; + unmask_nb::(current_mut, no_masking, &rfi)?; (current_mut, PayloadTy::Network) }; if manage_auto_reply::<_, _, IS_CLIENT>( stream, connection_state, + no_masking, rfi.op_code, payload, rng, @@ -276,6 +289,7 @@ where max_payload_len, nc, network_buffer, + no_masking, reader_buffer_first, rng, stream, @@ -294,6 +308,7 @@ where max_payload_len, nc, network_buffer, + no_masking, reader_buffer_first, rng, stream, @@ -310,11 +325,12 @@ where #[inline] fn copy_from_arbitrary_nb_to_rb1( network_buffer: &mut PartitionedFilledBuffer, + no_masking: bool, reader_buffer_first: &mut Vector, rfi: &ReadFrameInfo, ) -> crate::Result<()> { let current_mut = network_buffer._current_mut(); - unmask_nb::(current_mut, rfi)?; + unmask_nb::(current_mut, no_masking, rfi)?; reader_buffer_first.extend_from_copyable_slice(current_mut)?; network_buffer._clear_if_following_is_empty(); Ok(()) @@ -324,13 +340,14 @@ fn copy_from_arbitrary_nb_to_rb1( fn copy_from_compressed_nb_to_rb1( nc: &mut NC, network_buffer: &mut PartitionedFilledBuffer, + no_masking: bool, reader_buffer_first: &mut Vector, rfi: &ReadFrameInfo, ) -> crate::Result<()> where NC: NegotiatedCompression, { - unmask_nb::(network_buffer._current_mut(), rfi)?; + unmask_nb::(network_buffer._current_mut(), no_masking, rfi)?; network_buffer._reserve(4)?; let curr_end_idx = network_buffer._current().len(); let curr_end_idx_p4 = curr_end_idx.wrapping_add(4); @@ -407,6 +424,7 @@ async fn fetch_frame_from_stream( max_payload_len: usize, nc: &NC, network_buffer: &mut PartitionedFilledBuffer, + no_masking: bool, stream: &mut S, ) -> crate::Result where @@ -418,6 +436,7 @@ where max_payload_len, nc, network_buffer, + no_masking, &mut read, stream, ) @@ -463,6 +482,7 @@ async fn read_continuation_frames( max_payload_len: usize, nc: &mut NC, network_buffer: &mut PartitionedFilledBuffer, + no_masking: bool, reader_buffer_first: &mut Vector, rng: &mut RNG, stream: &mut S, @@ -473,23 +493,39 @@ where RNG: Rng, S: Stream, { - copy_from_arbitrary_nb_to_rb1::(network_buffer, reader_buffer_first, first_rfi)?; + copy_from_arbitrary_nb_to_rb1::( + network_buffer, + no_masking, + reader_buffer_first, + first_rfi, + )?; let mut iuc = manage_op_code_of_first_continuation_frame( first_rfi.op_code, reader_buffer_first, first_text_cb, )?; loop { - let mut rfi = - fetch_frame_from_stream::<_, _, IS_CLIENT>(max_payload_len, nc, network_buffer, stream) - .await?; + let mut rfi = fetch_frame_from_stream::<_, _, IS_CLIENT>( + max_payload_len, + nc, + network_buffer, + no_masking, + stream, + ) + .await?; let begin = reader_buffer_first.len(); rfi.should_decompress = first_rfi.should_decompress; - copy_from_arbitrary_nb_to_rb1::(network_buffer, reader_buffer_first, &rfi)?; + copy_from_arbitrary_nb_to_rb1::( + network_buffer, + no_masking, + reader_buffer_first, + &rfi, + )?; let payload = reader_buffer_first.get_mut(begin..).unwrap_or_default(); if !manage_auto_reply::<_, _, IS_CLIENT>( stream, connection_state, + no_masking, rfi.op_code, payload, rng, @@ -518,6 +554,7 @@ async fn write_control_frame( aux: &mut A, connection_state: &mut ConnectionState, frame: &mut Frame, + no_masking: bool, rng: &mut RNG, wsc_cb: &mut impl for<'any> FnMutFut< (&'any mut A, &'any [u8], &'any [u8]), @@ -528,7 +565,7 @@ where P: LeaseMut<[u8]>, RNG: Rng, { - manage_normal_frame(connection_state, frame, rng); + manage_normal_frame(connection_state, frame, no_masking, rng); wsc_cb.call((aux, frame.header(), frame.payload().lease())).await?; Ok(()) } diff --git a/wtx/src/web_socket/web_socket_writer.rs b/wtx/src/web_socket/web_socket_writer.rs index 2aecb986..65eb7c44 100644 --- a/wtx/src/web_socket/web_socket_writer.rs +++ b/wtx/src/web_socket/web_socket_writer.rs @@ -1,6 +1,9 @@ use crate::{ misc::{BufferMode, ConnectionState, Lease, LeaseMut, Rng, Stream, Vector, Xorshift64}, - web_socket::{compression::NegotiatedCompression, unmask::unmask, Frame, FrameMut, OpCode}, + web_socket::{ + compression::NegotiatedCompression, misc::has_masked_frame, unmask::unmask, Frame, FrameMut, + OpCode, + }, }; #[inline] @@ -29,6 +32,7 @@ pub(crate) fn manage_frame_compression<'cb, P, NC, const IS_CLIENT: bool>( connection_state: &mut ConnectionState, nc: &mut NC, frame: &mut Frame, + no_masking: bool, rng: &mut Xorshift64, writer_buffer: &'cb mut Vector, ) -> crate::Result> @@ -40,7 +44,7 @@ where *connection_state = ConnectionState::Closed; } let mut compressed_frame = compress_frame(frame, nc, writer_buffer)?; - mask_frame(&mut compressed_frame, rng); + mask_frame(&mut compressed_frame, no_masking, rng); Ok(compressed_frame) } @@ -48,6 +52,7 @@ where pub(crate) fn manage_normal_frame( connection_state: &mut ConnectionState, frame: &mut Frame, + no_masking: bool, rng: &mut RNG, ) where P: LeaseMut<[u8]>, @@ -56,13 +61,14 @@ pub(crate) fn manage_normal_frame( if frame.op_code() == OpCode::Close { *connection_state = ConnectionState::Closed; } - mask_frame(frame, rng); + mask_frame(frame, no_masking, rng); } #[inline] pub(crate) async fn write_frame( connection_state: &mut ConnectionState, frame: &mut Frame, + no_masking: bool, nc: &mut NC, rng: &mut Xorshift64, stream: &mut S, @@ -74,10 +80,10 @@ where S: Stream, { if manage_compression(frame, nc) { - let nframe = manage_frame_compression(connection_state, nc, frame, rng, writer_buffer)?; - stream.write_all_vectored(&[nframe.header(), nframe.payload()]).await?; + let fr = manage_frame_compression(connection_state, nc, frame, no_masking, rng, writer_buffer)?; + stream.write_all_vectored(&[fr.header(), fr.payload()]).await?; } else { - manage_normal_frame::<_, _, IS_CLIENT>(connection_state, frame, rng); + manage_normal_frame::<_, _, IS_CLIENT>(connection_state, frame, no_masking, rng); let (header, payload) = frame.header_and_payload_mut(); stream.write_all_vectored(&[header, payload.lease()]).await?; } @@ -120,27 +126,17 @@ where } #[inline] -const fn has_masked_frame(second_header_byte: u8) -> bool { - second_header_byte & 0b1000_0000 != 0 -} - -#[inline] -fn mask_frame(frame: &mut Frame, rng: &mut RNG) -where +fn mask_frame( + frame: &mut Frame, + no_masking: bool, + rng: &mut RNG, +) where P: LeaseMut<[u8]>, RNG: Rng, { - if IS_CLIENT { - if let [_, second_byte, .., a, b, c, d] = frame.header_mut() { - if !has_masked_frame(*second_byte) { - *second_byte |= 0b1000_0000; - let mask = rng.u8_4(); - *a = mask[0]; - *b = mask[1]; - *c = mask[2]; - *d = mask[3]; - unmask(frame.payload_mut().lease_mut(), mask); - } - } + if IS_CLIENT && !no_masking && !has_masked_frame(*frame.header_first_two_mut()[1]) { + let mask: [u8; 4] = rng.u8_4(); + frame.set_mask(mask); + unmask(frame.payload_mut().lease_mut(), mask); } }