From 94a98c1eb4d90ca13533ea8a1c5698e8cbbbe929 Mon Sep 17 00:00:00 2001 From: Caio Date: Mon, 28 Aug 2023 07:49:24 -0300 Subject: [PATCH] First commit --- .editorconfig | 9 + .github/dependabot.yml | 12 + .github/workflows/ci.yaml | 50 ++ .gitignore | 8 + .scripts/autobahn.sh | 47 ++ .scripts/autobahn/fuzzingclient.json | 12 + .scripts/autobahn/fuzzingserver.json | 7 + .scripts/common.sh | 9 + .scripts/fuzz.sh | 6 + .scripts/wtx-bench.sh | 60 ++ .scripts/wtx.sh | 30 + Cargo.toml | 19 + rust-toolchain | 4 + wtx-bench/Cargo.toml | 13 + wtx-bench/README.md | 7 + wtx-bench/src/main.rs | 231 ++++++ wtx-fuzz/Cargo.toml | 21 + wtx-fuzz/parse_frame.rs | 25 + wtx-fuzz/unmask.rs | 8 + wtx/Cargo.toml | 80 ++ wtx/LICENSE | 1 + wtx/README.md | 51 ++ wtx/benches/simple.rs | 18 + wtx/examples/common/mod.rs | 94 +++ wtx/examples/localhost.crt | 28 + wtx/examples/localhost.key | 52 ++ .../web_socket_client_autobahn_raw_tokio.rs | 95 +++ .../web_socket_client_cli_raw_tokio_rustls.rs | 62 ++ wtx/examples/web_socket_server_echo_hyper.rs | 48 ++ .../web_socket_server_echo_raw_async_std.rs | 25 + .../web_socket_server_echo_raw_glommio.rs | 47 ++ .../web_socket_server_echo_raw_tokio.rs | 24 + ...web_socket_server_echo_raw_tokio_rustls.rs | 52 ++ wtx/profiling/web_socket.rs | 62 ++ wtx/src/cache.rs | 49 ++ wtx/src/error.rs | 174 +++++ wtx/src/expected_header.rs | 13 + wtx/src/lib.rs | 34 + wtx/src/misc.rs | 65 ++ wtx/src/misc/incomplete_utf8_char.rs | 76 ++ wtx/src/misc/rng.rs | 26 + wtx/src/misc/traits.rs | 84 ++ wtx/src/misc/uri_parts.rs | 36 + wtx/src/misc/utf8_errors.rs | 13 + wtx/src/read_buffer.rs | 103 +++ wtx/src/request.rs | 75 ++ wtx/src/response.rs | 82 ++ wtx/src/stream.rs | 260 +++++++ wtx/src/web_socket.rs | 723 ++++++++++++++++++ wtx/src/web_socket/close_code.rs | 101 +++ wtx/src/web_socket/frame.rs | 194 +++++ wtx/src/web_socket/frame_buffer.rs | 269 +++++++ wtx/src/web_socket/handshake.rs | 70 ++ wtx/src/web_socket/handshake/hyper.rs | 191 +++++ wtx/src/web_socket/handshake/misc.rs | 62 ++ wtx/src/web_socket/handshake/raw.rs | 259 +++++++ wtx/src/web_socket/handshake/tests.rs | 336 ++++++++ wtx/src/web_socket/mask.rs | 100 +++ wtx/src/web_socket/op_code.rs | 68 ++ wtx/src/web_socket/web_socket_error.rs | 35 + 60 files changed, 4815 insertions(+) create mode 100644 .editorconfig create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/ci.yaml create mode 100644 .gitignore create mode 100755 .scripts/autobahn.sh create mode 100644 .scripts/autobahn/fuzzingclient.json create mode 100644 .scripts/autobahn/fuzzingserver.json create mode 100644 .scripts/common.sh create mode 100755 .scripts/fuzz.sh create mode 100755 .scripts/wtx-bench.sh create mode 100755 .scripts/wtx.sh create mode 100644 Cargo.toml create mode 100644 rust-toolchain create mode 100644 wtx-bench/Cargo.toml create mode 100644 wtx-bench/README.md create mode 100644 wtx-bench/src/main.rs create mode 100644 wtx-fuzz/Cargo.toml create mode 100644 wtx-fuzz/parse_frame.rs create mode 100644 wtx-fuzz/unmask.rs create mode 100644 wtx/Cargo.toml create mode 120000 wtx/LICENSE create mode 100644 wtx/README.md create mode 100644 wtx/benches/simple.rs create mode 100644 wtx/examples/common/mod.rs create mode 100644 wtx/examples/localhost.crt create mode 100644 wtx/examples/localhost.key create mode 100644 wtx/examples/web_socket_client_autobahn_raw_tokio.rs create mode 100644 wtx/examples/web_socket_client_cli_raw_tokio_rustls.rs create mode 100644 wtx/examples/web_socket_server_echo_hyper.rs create mode 100644 wtx/examples/web_socket_server_echo_raw_async_std.rs create mode 100644 wtx/examples/web_socket_server_echo_raw_glommio.rs create mode 100644 wtx/examples/web_socket_server_echo_raw_tokio.rs create mode 100644 wtx/examples/web_socket_server_echo_raw_tokio_rustls.rs create mode 100644 wtx/profiling/web_socket.rs create mode 100644 wtx/src/cache.rs create mode 100644 wtx/src/error.rs create mode 100644 wtx/src/expected_header.rs create mode 100644 wtx/src/lib.rs create mode 100644 wtx/src/misc.rs create mode 100644 wtx/src/misc/incomplete_utf8_char.rs create mode 100644 wtx/src/misc/rng.rs create mode 100644 wtx/src/misc/traits.rs create mode 100644 wtx/src/misc/uri_parts.rs create mode 100644 wtx/src/misc/utf8_errors.rs create mode 100644 wtx/src/read_buffer.rs create mode 100644 wtx/src/request.rs create mode 100644 wtx/src/response.rs create mode 100644 wtx/src/stream.rs create mode 100644 wtx/src/web_socket.rs create mode 100644 wtx/src/web_socket/close_code.rs create mode 100644 wtx/src/web_socket/frame.rs create mode 100644 wtx/src/web_socket/frame_buffer.rs create mode 100644 wtx/src/web_socket/handshake.rs create mode 100644 wtx/src/web_socket/handshake/hyper.rs create mode 100644 wtx/src/web_socket/handshake/misc.rs create mode 100644 wtx/src/web_socket/handshake/raw.rs create mode 100644 wtx/src/web_socket/handshake/tests.rs create mode 100644 wtx/src/web_socket/mask.rs create mode 100644 wtx/src/web_socket/op_code.rs create mode 100644 wtx/src/web_socket/web_socket_error.rs diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..37d38bbb --- /dev/null +++ b/.editorconfig @@ -0,0 +1,9 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 2 +indent_style = space +insert_final_newline = true +trim_trailing_whitespace = true \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..0def0a03 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 + +updates: +- package-ecosystem: "cargo" + directory: "/" + schedule: + interval: daily + +- package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 00000000..6387d3a8 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,50 @@ +name: CI +on: + pull_request: + push: + branches: + - main + +jobs: + autobahn: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + override: true + profile: minimal + toolchain: nightly-2023-08-01 + - uses: Swatinem/rust-cache@v2 + + - run: .scripts/autobahn.sh ci + +# fuzz: +# runs-on: ubuntu-latest +# steps: +# - uses: actions/checkout@v3 +# - uses: actions-rs/toolchain@v1 +# with: +# override: true +# profile: minimal +# toolchain: nightly-2023-08-01 +# - uses: actions-rs/install@v0.1 +# with: +# crate: cargo-fuzz +# use-tool-cache: true +# +# - run: .scripts/fuzz.sh + + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + components: clippy, rustfmt + override: true + profile: minimal + toolchain: nightly-2023-08-01 + - uses: Swatinem/rust-cache@v2 + + - run: .scripts/wtx.sh \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f6e49f12 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.scripts/autobahn/reports +.vscode +**/*.rs.bk +**/artifacts +**/Cargo.lock +**/corpus +**/target +**/target \ No newline at end of file diff --git a/.scripts/autobahn.sh b/.scripts/autobahn.sh new file mode 100755 index 00000000..30b6b75a --- /dev/null +++ b/.scripts/autobahn.sh @@ -0,0 +1,47 @@ +set -euxo pipefail + +ARG=${1:-""} +if [ "$ARG" != "ci" ]; then + trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT +fi; + +# fuzzingclient + +cargo build --example web_socket_server_echo_raw_tokio --features tokio,web-socket-handshake --release +cargo run --example web_socket_server_echo_raw_tokio --features tokio,web-socket-handshake --release & cargo_pid=$! +mkdir -p .scripts/autobahn/reports/fuzzingclient +podman run \ + -v .scripts/autobahn/fuzzingclient.json:/fuzzingclient.json:ro \ + -v .scripts/autobahn:/autobahn \ + --name fuzzingclient \ + --net=host \ + --rm \ + docker.io/crossbario/autobahn-testsuite:0.8.2 wstest -m fuzzingclient -s fuzzingclient.json +podman rm --force --ignore fuzzingclient +kill -9 $cargo_pid + +if [ $(grep -ci "failed" .scripts/autobahn/reports/fuzzingclient/index.json) -gt 0 ] +then + exit 1 +fi + +## fuzzingserver + +cargo build --example web_socket_client_autobahn_raw_tokio --features tokio,web-socket-handshake --release +mkdir -p .scripts/autobahn/reports/fuzzingserver +podman run \ + -d \ + -p 9080:9080 \ + -v .scripts/autobahn/fuzzingserver.json:/fuzzingserver.json:ro \ + -v .scripts/autobahn:/autobahn \ + --name fuzzingserver \ + --net=host \ + docker.io/crossbario/autobahn-testsuite:0.8.2 wstest -m fuzzingserver -s fuzzingserver.json +sleep 5 +cargo run --example web_socket_client_autobahn_raw_tokio --features tokio,web-socket-handshake --release -- 127.0.0.1:9080 +podman rm --force --ignore fuzzingserver + +if [ $(grep -ci "failed" .scripts/autobahn/reports/fuzzingserver/index.json) -gt 0 ] +then + exit 1 +fi diff --git a/.scripts/autobahn/fuzzingclient.json b/.scripts/autobahn/fuzzingclient.json new file mode 100644 index 00000000..f2567006 --- /dev/null +++ b/.scripts/autobahn/fuzzingclient.json @@ -0,0 +1,12 @@ +{ + "cases": ["1.*", "2.*", "3.*", "4.*", "5.*", "6.*", "7.*", "9.*", "10.*"], + "exclude-agent-cases": {}, + "exclude-cases": [], + "outdir": "/autobahn/reports/fuzzingclient", + "servers": [ + { + "agent": "wtx", + "url": "ws://127.0.0.1:8080" + } + ] +} diff --git a/.scripts/autobahn/fuzzingserver.json b/.scripts/autobahn/fuzzingserver.json new file mode 100644 index 00000000..96fde560 --- /dev/null +++ b/.scripts/autobahn/fuzzingserver.json @@ -0,0 +1,7 @@ +{ + "cases": ["1.*", "2.*", "3.*", "4.*", "5.*", "6.*", "7.*", "9.*", "10.*"], + "exclude-agent-cases": {}, + "exclude-cases": [], + "outdir": "/autobahn/reports/fuzzingserver", + "url": "ws://127.0.0.1:9080" +} \ No newline at end of file diff --git a/.scripts/common.sh b/.scripts/common.sh new file mode 100644 index 00000000..6b263c69 --- /dev/null +++ b/.scripts/common.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +export rt='cargo run --bin rust-tools -- --template you-rust' + +export CARGO_TARGET_DIR="$($rt target-dir)" +export RUST_BACKTRACE=1 +export RUSTFLAGS="$($rt rust-flags)" diff --git a/.scripts/fuzz.sh b/.scripts/fuzz.sh new file mode 100755 index 00000000..781efd83 --- /dev/null +++ b/.scripts/fuzz.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +cargo fuzz run --fuzz-dir wtx-fuzz unmask -- -runs=100000 +cargo fuzz run --fuzz-dir wtx-fuzz parse-frame -- -runs=100000 diff --git a/.scripts/wtx-bench.sh b/.scripts/wtx-bench.sh new file mode 100755 index 00000000..04b74fbb --- /dev/null +++ b/.scripts/wtx-bench.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT + +pushd /tmp +git clone https://github.com/c410-f3r/fastwebsockets || true +cd fastwebsockets +git checkout -t origin/bench || true +RUSTFLAGS='-C target-cpu=native' cargo build --example echo_server --features simd,upgrade --release +RUSTFLAGS='-C target-cpu=native' cargo run --example echo_server --features simd,upgrade --release 127.0.0.1:8080 & + +cd /tmp +git clone https://github.com/c410-f3r/websocket || true +cd websocket/examples/echo +git checkout -t origin/bench || true +go run server.go 127.0.0.1:8081 & + +cd /tmp +git clone https://github.com/c410-f3r/tokio-tungstenite || true +cd tokio-tungstenite +git checkout -t origin/bench || true +RUSTFLAGS='-C target-cpu=native' cargo build --example echo-server --release +RUSTFLAGS='-C target-cpu=native' cargo run --example echo-server --release 127.0.0.1:8082 & + +cd /tmp +git clone --recursive https://github.com/c410-f3r/uWebSockets.git || true +cd uWebSockets +git checkout -t origin/bench || true +if [ ! -e ./EchoServer ] +then + make examples +fi +./EchoServer 8083 & +popd + +RUSTFLAGS='-C target-cpu=native' cargo build --example web_socket_server_echo_hyper --features simdutf8,web-socket-hyper --release +RUSTFLAGS='-C target-cpu=native' cargo run --example web_socket_server_echo_hyper --features simdutf8,web-socket-hyper --release 127.0.0.1:8084 & + +RUSTFLAGS='-C target-cpu=native' cargo build --example web_socket_server_echo_raw_async_std --features async-std,simdutf8,web-socket-handshake --release +RUSTFLAGS='-C target-cpu=native' cargo run --example web_socket_server_echo_raw_async_std --features async-std,simdutf8,web-socket-handshake --release 127.0.0.1:8085 & + +RUSTFLAGS='-C target-cpu=native' cargo build --example web_socket_server_echo_raw_glommio --features glommio,simdutf8,web-socket-handshake --release +RUSTFLAGS='-C target-cpu=native' cargo run --example web_socket_server_echo_raw_glommio --features glommio,simdutf8,web-socket-handshake --release 127.0.0.1:8086 & + +RUSTFLAGS='-C target-cpu=native' cargo build --example web_socket_server_echo_raw_tokio --features simdutf8,tokio,web-socket-handshake --release +RUSTFLAGS='-C target-cpu=native' cargo run --example web_socket_server_echo_raw_tokio --features simdutf8,tokio,web-socket-handshake --release 127.0.0.1:8087 & + +sleep 1 + +RUSTFLAGS='-C target-cpu=native' cargo run --bin wtx-bench --release -- \ + http://127.0.0.1:8080/fastwebsockets \ + http://127.0.0.1:8081/gorilla-websocket \ + http://127.0.0.1:8082/tokio-tungstenite \ + http://127.0.0.1:8083/uWebSockets \ + http://127.0.0.1:8084/wtx-hyper \ + http://127.0.0.1:8085/wtx-raw-async-std \ + http://127.0.0.1:8086/wtx-raw-glommio \ + http://127.0.0.1:8087/wtx-raw-tokio diff --git a/.scripts/wtx.sh b/.scripts/wtx.sh new file mode 100755 index 00000000..81229800 --- /dev/null +++ b/.scripts/wtx.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +cargo install rust-tools --git https://github.com/c410-f3r/regular-crates + +rt='rust-tools --template you-rust' + +CARGO_TARGET_DIR="$($rt target-dir)" +RUST_BACKTRACE=1 +RUSTFLAGS="$($rt rust-flags)" + +$rt rustfmt +$rt clippy + +$rt test-generic wtx +$rt test-with-features wtx async-std +$rt test-with-features wtx async-trait +$rt test-with-features wtx base64 +$rt test-with-features wtx futures-lite +$rt test-with-features wtx glommio +$rt test-with-features wtx http +$rt test-with-features wtx httparse +$rt test-with-features wtx hyper +$rt test-with-features wtx sha1 +$rt test-with-features wtx simdutf8 +$rt test-with-features wtx std +$rt test-with-features wtx tokio +$rt test-with-features wtx web-socket-handshake +$rt test-with-features wtx web-socket-hyper \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..e1a89c36 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,19 @@ +[profile.profiling] +inherits = "release" +debug = true + +[profile.release] +codegen-units = 1 +debug = false +debug-assertions = false +incremental = false +lto = true +opt-level = 3 +overflow-checks = false +panic = 'abort' +rpath = false +strip = "debuginfo" + +[workspace] +members = ["wtx-bench", "wtx", "wtx-fuzz"] +resolver = "2" diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 00000000..9d0dd1d0 --- /dev/null +++ b/rust-toolchain @@ -0,0 +1,4 @@ +[toolchain] +channel = "nightly-2023-08-01" +components = ["clippy", "rustfmt"] +profile = "minimal" diff --git a/wtx-bench/Cargo.toml b/wtx-bench/Cargo.toml new file mode 100644 index 00000000..98f9ddb4 --- /dev/null +++ b/wtx-bench/Cargo.toml @@ -0,0 +1,13 @@ +[dependencies] +hyper = { default-features = false, features = ["client", "http1", "server", "tcp"], version = "0.14" } +plotters = { default-features = false, features = ["histogram", "svg_backend"], version = "0.3" } +tokio = { default-features = false, features = ["macros", "rt-multi-thread"], version = "1.0" } +wtx = { features = ["tokio", "web-socket-handshake"], path = "../wtx" } + +[package] +description = "Benchmarks" +edition = "2021" +license = "Apache-2.0" +name = "wtx-bench" +publish = false +version = "0.0.1" diff --git a/wtx-bench/README.md b/wtx-bench/README.md new file mode 100644 index 00000000..f950b889 --- /dev/null +++ b/wtx-bench/README.md @@ -0,0 +1,7 @@ +# Benchmarks + +Call the `wtx-bench` binary passing the URLs of all different available echo servers. + +``` +cargo run --bin wtx-bench --release -- http://127.0.0.1:8080/some_server_name http://127.0.0.1:8081/another_server_name .. +``` \ No newline at end of file diff --git a/wtx-bench/src/main.rs b/wtx-bench/src/main.rs new file mode 100644 index 00000000..5970dc53 --- /dev/null +++ b/wtx-bench/src/main.rs @@ -0,0 +1,231 @@ +//! WebSocket benchmark + +#![allow( + // Does not matter + clippy::arithmetic_side_effects, + // Does not matter + clippy::unwrap_used +)] + +use plotters::{ + prelude::{ + ChartBuilder, IntoDrawingArea, IntoSegmentedCoord, LabelAreaPosition, PathElement, + SVGBackend, SeriesLabelPosition, + }, + series::Histogram, + style::{AsRelative, Color, Palette99, PaletteColor, BLACK, WHITE}, +}; +use std::time::Instant; +use tokio::{net::TcpStream, task::JoinSet}; +use wtx::{ + web_socket::{ + handshake::{WebSocketHandshake, WebSocketHandshakeRaw}, + FrameBufferVec, FrameVecMut, OpCode, WebSocketClientOwned, + }, + UriParts, +}; + +// Verifies the handling of concurrent calls. +const CONNECTIONS: usize = 1; +// Some applications use WebSocket to perform streaming so the length of a frame can be quite large +// but statistically it is generally low. +const FRAME_LEN: usize = 1; +// For each message, the client always verifies the content sent back from a server and this +// leads to a sequential-like behavior. +// +// If this is the only high metric, all different servers end-up performing similarly effectively +// making this criteria an "augmenting factor" when combined with other parameters. +const NUM_MESSAGES: usize = 1; + +// Automatically calculated. +const NUM_FRAMES: usize = { + let n = NUM_MESSAGES / 4; + if n == 0 { + 1 + } else { + n + } +}; + +static FRAME_DATA: &[u8; FRAME_LEN] = &[53; FRAME_LEN]; + +#[tokio::main] +async fn main() { + let uris: Vec<_> = std::env::args().skip(1).collect(); + let mut agents = Vec::new(); + for uri in uris { + let uri_parts = UriParts::from(uri.as_str()); + let mut agent = Agent { + result: 0, + name: uri_parts.href.to_owned(), + }; + bench(uri_parts.authority, &mut agent, &uri).await; + agents.push(agent); + } + flush(&agents); +} + +async fn bench(addr: &str, agent: &mut Agent, uri: &str) { + let instant = Instant::now(); + let mut set = JoinSet::new(); + for _ in 0..CONNECTIONS { + let _handle = set.spawn({ + let local_addr: String = addr.to_owned(); + let local_uri = uri.to_owned(); + async move { + let fb = &mut FrameBufferVec::default(); + let mut ws = ws(&local_addr, fb, &local_uri).await; + for _ in 0..NUM_MESSAGES { + match NUM_FRAMES { + 0 => break, + 1 => { + ws.write_frame( + &mut FrameVecMut::new_fin(fb.into(), OpCode::Text, FRAME_DATA) + .unwrap(), + ) + .await + .unwrap(); + } + 2 => { + ws.write_frame( + &mut FrameVecMut::new_unfin(fb.into(), OpCode::Text, FRAME_DATA) + .unwrap(), + ) + .await + .unwrap(); + ws.write_frame( + &mut FrameVecMut::new_fin( + fb.into(), + OpCode::Continuation, + FRAME_DATA, + ) + .unwrap(), + ) + .await + .unwrap(); + } + _ => { + ws.write_frame( + &mut FrameVecMut::new_unfin(fb.into(), OpCode::Text, FRAME_DATA) + .unwrap(), + ) + .await + .unwrap(); + for _ in (0..NUM_FRAMES).skip(2) { + ws.write_frame( + &mut FrameVecMut::new_unfin( + fb.into(), + OpCode::Continuation, + FRAME_DATA, + ) + .unwrap(), + ) + .await + .unwrap(); + } + ws.write_frame( + &mut FrameVecMut::new_fin( + fb.into(), + OpCode::Continuation, + FRAME_DATA, + ) + .unwrap(), + ) + .await + .unwrap(); + } + } + assert_eq!( + ws.read_frame(fb).await.unwrap().fb().payload().len(), + FRAME_LEN * NUM_FRAMES + ); + } + ws.write_frame(&mut FrameVecMut::new_fin(fb.into(), OpCode::Close, &[]).unwrap()) + .await + .unwrap(); + } + }); + } + while let Some(rslt) = set.join_next().await { + rslt.unwrap(); + } + agent.result = instant.elapsed().as_millis(); +} + +fn flush(agents: &[Agent]) { + if agents.is_empty() { + return; + } + let x_spec = agents + .iter() + .map(|el| &el.name) + .cloned() + .collect::>(); + let root = SVGBackend::new("/tmp/wtx-bench.png", (1000, 500)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let mut ctx = ChartBuilder::on(&root) + .caption( + format!("{CONNECTIONS} connection(s) sending {NUM_MESSAGES} message(s) composed by {NUM_FRAMES} frame(s) of {FRAME_LEN} byte(s)"), + ("sans-serif", (4).percent_height()), + ) + .margin((1).percent()) + .set_label_area_size(LabelAreaPosition::Left, (15).percent()) + .set_label_area_size(LabelAreaPosition::Bottom, (5).percent()) + .build_cartesian_2d(x_spec.into_segmented(), { + let start = 0u128; + let exact_end = agents.iter().map(|el| el.result).max().unwrap_or(5000); + let surplus_end = ((exact_end / 500) + 1) * 500; + start..surplus_end + }) + .unwrap(); + ctx.configure_mesh() + .axis_desc_style(("sans-serif", 15)) + .bold_line_style(WHITE.mix(0.3)) + .y_desc("Time (ms)") + .draw() + .unwrap(); + for (idx, agent) in agents.iter().enumerate() { + let _ = ctx + .draw_series( + Histogram::vertical(&ctx) + .style(PaletteColor::::pick(idx).mix(0.5).filled()) + .data([(&agent.name, agent.result)]), + ) + .unwrap() + .label(format!("{} ({}ms)", &agent.name, agent.result)) + .legend(move |(x, y)| { + PathElement::new([(x, y), (x + 20, y)], PaletteColor::::pick(idx)) + }); + } + ctx.configure_series_labels() + .border_style(BLACK) + .background_style(WHITE.mix(0.8)) + .position(SeriesLabelPosition::UpperRight) + .draw() + .unwrap(); + root.present().unwrap(); +} + +async fn ws( + authority: &str, + fb: &mut FrameBufferVec, + uri: &str, +) -> WebSocketClientOwned { + WebSocketHandshakeRaw { + fb, + headers_buffer: &mut <_>::default(), + rb: <_>::default(), + stream: TcpStream::connect(authority).await.unwrap(), + uri, + } + .handshake() + .await + .unwrap() + .1 +} + +#[derive(Debug)] +struct Agent { + result: u128, + name: String, +} diff --git a/wtx-fuzz/Cargo.toml b/wtx-fuzz/Cargo.toml new file mode 100644 index 00000000..5173305b --- /dev/null +++ b/wtx-fuzz/Cargo.toml @@ -0,0 +1,21 @@ +[[bin]] +name = "parse-frame" +path = "parse_frame.rs" + +[[bin]] +name = "unmask" +path = "unmask.rs" + +[dependencies] +libfuzzer-sys = { default-features = false, version = "0.4" } +tokio = { default-features = false, features = ["rt"], version = "1.0" } +wtx = { default-features = false, path = "../wtx" } + +[package] +name = "wtx-fuzz" +version = "0.0.0" +publish = false +edition = "2021" + +[package.metadata] +cargo-fuzz = true diff --git a/wtx-fuzz/parse_frame.rs b/wtx-fuzz/parse_frame.rs new file mode 100644 index 00000000..66edc2fa --- /dev/null +++ b/wtx-fuzz/parse_frame.rs @@ -0,0 +1,25 @@ +//! Parse + +#![allow( + // Does not matter + clippy::unwrap_used +)] +#![no_main] + +use tokio::runtime::Handle; +use wtx::{ + web_socket::{FrameBufferVec, FrameVecMut, OpCode, WebSocketServer}, + BytesStream, ReadBuffer, +}; + +libfuzzer_sys::fuzz_target!(|data: &[u8]| { + let mut ws = WebSocketServer::new(ReadBuffer::default(), BytesStream::default()); + ws.set_max_payload_len(u16::MAX.into()); + let fb = &mut FrameBufferVec::default(); + Handle::current().block_on(async move { + ws.write_frame(&mut FrameVecMut::new_fin(fb.into(), OpCode::Text, data).unwrap()) + .await + .unwrap(); + let _frame = ws.read_frame(fb).await.unwrap(); + }); +}); diff --git a/wtx-fuzz/unmask.rs b/wtx-fuzz/unmask.rs new file mode 100644 index 00000000..c4163d01 --- /dev/null +++ b/wtx-fuzz/unmask.rs @@ -0,0 +1,8 @@ +//! Unmask + +#![no_main] + +libfuzzer_sys::fuzz_target!(|data: &[u8]| { + let mut data = data.to_vec(); + wtx::web_socket::unmask(&mut data, [1, 2, 3, 4]); +}); diff --git a/wtx/Cargo.toml b/wtx/Cargo.toml new file mode 100644 index 00000000..685621a4 --- /dev/null +++ b/wtx/Cargo.toml @@ -0,0 +1,80 @@ +[[example]] +name = "web_socket_client_autobahn_raw_tokio" +required-features = ["tokio", "web-socket-handshake"] + +[[example]] +name = "web_socket_client_cli_raw_tokio_rustls" +required-features = ["tokio-rustls", "web-socket-handshake"] + +[[example]] +name = "web_socket_server_echo_hyper" +required-features = ["web-socket-hyper"] + +[[example]] +name = "web_socket_server_echo_raw_async_std" +required-features = ["async-std", "web-socket-handshake"] + +[[example]] +name = "web_socket_server_echo_raw_glommio" +required-features = ["glommio", "web-socket-handshake"] + +[[example]] +name = "web_socket_server_echo_raw_tokio" +required-features = ["tokio", "web-socket-handshake"] + +[[example]] +name = "web_socket_server_echo_raw_tokio_rustls" +required-features = ["tokio-rustls", "web-socket-handshake"] + +[[example]] +name = "web_socket" +path = "profiling/web_socket.rs" + +[dependencies] +async-std = { default-features = false, features = ["default"], optional = true, version = "1.0" } +async-trait = { default-features = false, optional = true, version = "0.1" } +base64 = { default-features = false, features = ["alloc"], optional = true, version = "0.21" } +futures-lite = { default-features = false, optional = true, version = "1.0" } +glommio = { default-features = false, optional = true, version = "0.8" } +http = { default-features = false, optional = true, version = "0.2" } +httparse = { default-features = false, optional = true, version = "1.0" } +hyper = { default-features = false, features = ["client", "http1", "server"], optional = true, version = "0.14" } +rand = { default-features = false, features = ["getrandom", "small_rng"], version = "0.8" } +sha1 = { default-features = false, optional = true, version = "0.10" } +simdutf8 = { default-features = false, features = ["aarch64_neon"], optional = true, version = "0.1" } +tokio = { default-features = false, features = ["io-util", "net"], optional = true, version = "1.0" } +tokio-rustls = { default-features = false, optional = true, version = "0.24" } + +[dev-dependencies] +async-std = { default-features = false, features = ["attributes"], version = "1.0" } +tokio = { default-features = false, features = ["macros", "rt-multi-thread", "time"], version = "1.0" } +tokio-rustls = { default-features = false, features = ["tls12"], version = "0.24" } +rustls-pemfile = { default-features = false, version = "1.0" } +webpki-roots = { default-features = false, version = "0.25" } +wtx = { default-features = false, features = ["std", "tokio"], path = "." } + +[features] +async-std = ["dep:async-std", "std"] +default = [] +glommio = ["futures-lite", "dep:glommio"] +hyper = ["http", "dep:hyper", "tokio"] +std = [] +tokio = ["std", "dep:tokio"] +tokio-rustls = ["tokio", "dep:tokio-rustls"] +web-socket-handshake = ["base64", "httparse", "sha1"] +web-socket-hyper = ["hyper", "web-socket-handshake"] + +[package] +authors = ["Caio Fernandes "] +categories = ["asynchronous", "data-structures", "network-programming", "no-std", "web-programming"] +description = "Asynchronous WebSocket implementations" +edition = "2021" +keywords = ["client", "io", "network", "server", "websocket"] +license = "Apache-2.0" +name = "wtx" +readme = "README.md" +repository = "https://github.com/c410-f3r/wtx" +version = "0.5.2" + +[package.metadata.docs.rs] +all-features = true diff --git a/wtx/LICENSE b/wtx/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/wtx/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/wtx/README.md b/wtx/README.md new file mode 100644 index 00000000..b34513e7 --- /dev/null +++ b/wtx/README.md @@ -0,0 +1,51 @@ +# WTX + +[![CI](https://github.com/c410-f3r/wtx/workflows/CI/badge.svg)](https://github.com/c410-f3r/wtx/actions?query=workflow%3ACI) +[![crates.io](https://img.shields.io/crates/v/wtx.svg)](https://crates.io/crates/wtx) +[![Documentation](https://docs.rs/wtx/badge.svg)](https://docs.rs/wtx) +[![License](https://img.shields.io/badge/license-APACHE2-blue.svg)](./LICENSE) +[![Rustc](https://img.shields.io/badge/rustc-1.71-lightgray")](https://blog.rust-lang.org/2020/03/12/Rust-1.71.html) + +Different web transport implementations. + +## WebSocket + +Provides low and high level abstractions to dispatch frames, as such, it is up to you to implement [Stream](https://docs.rs/wtx/latest/wtx/trait.Stream.html) with any desired logic or use any of the built-in strategies through the selection of features. + +[fastwebsockets](https://github.com/denoland/fastwebsockets) served as an initial inspiration for the skeleton of this implementation so thanks to the authors. + +```rust +use wtx::{ + Stream, web_socket::{FrameBufferVec, FrameMutVec, FrameVecMut, OpCode, WebSocketClientOwned} +}; + +pub async fn handle_client_frames( + fb: &mut FrameBufferVec, + ws: &mut WebSocketClientOwned + ) -> wtx::Result<()> { + loop { + let frame = match ws.read_msg(fb).await { + Err(err) => { + println!("Error: {err}"); + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Close, &[])?).await?; + break; + } + Ok(elem) => elem, + }; + match (frame.op_code(), frame.text_payload()) { + (_, Some(elem)) => println!("{elem}"), + (OpCode::Close, _) => break, + _ => {} + } + } + Ok(()) +} +``` + +See the `examples` directory for more suggestions. + +### Performance + +There are mainly 2 things that impact performance, the chosen runtime and the number of pre-allocated bytes. Specially for servers that have to create a new `WebSocket` instance for each handshake, pre-allocating a high number of bytes for short-lived or low-transfer connections can have a negative impact. + +![Benchmark](https://i.imgur.com/ZZU3Hay.jpeg) \ No newline at end of file diff --git a/wtx/benches/simple.rs b/wtx/benches/simple.rs new file mode 100644 index 00000000..6c79655f --- /dev/null +++ b/wtx/benches/simple.rs @@ -0,0 +1,18 @@ +#![feature(test)] + +extern crate test; + +use test::Bencher; + +#[bench] +fn unmask(b: &mut Bencher) { + const DATA_LEN: usize = 64 << 20; + let mut data: Vec = (0..DATA_LEN) + .map(|el| { + let n = el % usize::try_from(u8::MAX).unwrap(); + n.try_into().unwrap() + }) + .collect(); + let mask = [3, 5, 7, 11]; + b.iter(|| wtx::web_socket::unmask(&mut data, mask)); +} diff --git a/wtx/examples/common/mod.rs b/wtx/examples/common/mod.rs new file mode 100644 index 00000000..d3e7f74e --- /dev/null +++ b/wtx/examples/common/mod.rs @@ -0,0 +1,94 @@ +use std::borrow::BorrowMut; +use wtx::{ + web_socket::{ + handshake::{ + WebSocketAccept, WebSocketAcceptRaw, WebSocketHandshake, WebSocketHandshakeRaw, + }, + FrameBufferVec, OpCode, WebSocketClient, WebSocketServer, + }, + ReadBuffer, Stream, +}; + +#[cfg(not(feature = "async-trait"))] +pub(crate) trait AsyncBounds {} + +#[cfg(not(feature = "async-trait"))] +impl AsyncBounds for T where T: ?Sized {} + +#[cfg(feature = "async-trait")] +pub(crate) trait AsyncBounds: Send + Sync {} + +#[cfg(feature = "async-trait")] +impl AsyncBounds for T where T: Send + Sync + ?Sized {} + +pub(crate) async fn _accept_conn_and_echo_frames( + fb: &mut FrameBufferVec, + rb: &mut ReadBuffer, + stream: impl AsyncBounds + Stream, +) -> wtx::Result<()> { + let (_, mut ws) = WebSocketAcceptRaw { + fb, + headers_buffer: &mut <_>::default(), + key_buffer: &mut <_>::default(), + rb, + stream, + } + .accept() + .await?; + _handle_frames(fb, &mut ws).await?; + Ok(()) +} + +pub(crate) async fn _connect( + fb: &mut FrameBufferVec, + uri: &str, + rb: RB, + stream: S, +) -> wtx::Result> +where + RB: AsyncBounds + BorrowMut, + S: AsyncBounds + Stream, +{ + Ok(WebSocketHandshakeRaw { + fb, + headers_buffer: &mut <_>::default(), + rb, + uri, + stream, + } + .handshake() + .await? + .1) +} + +pub(crate) async fn _handle_frames( + fb: &mut FrameBufferVec, + ws: &mut WebSocketServer, +) -> wtx::Result<()> +where + RB: BorrowMut, +{ + loop { + let mut frame = ws.read_msg(fb).await?; + match frame.op_code() { + OpCode::Binary | OpCode::Text => { + ws.write_frame(&mut frame).await?; + } + OpCode::Close => break, + _ => {} + } + } + Ok(()) +} + +pub(crate) fn _host_from_args() -> String { + std::env::args() + .nth(1) + .unwrap_or_else(|| "127.0.0.1:8080".to_owned()) +} + +pub(crate) fn _uri_from_args() -> String { + std::env::args() + .nth(1) + .unwrap_or_else(|| "http://127.0.0.1:8080".to_owned()) +} diff --git a/wtx/examples/localhost.crt b/wtx/examples/localhost.crt new file mode 100644 index 00000000..9d259427 --- /dev/null +++ b/wtx/examples/localhost.crt @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEtDCCApwCCQDtm3HlNW4u5TANBgkqhkiG9w0BAQsFADAcMQswCQYDVQQGEwJJ +TjENMAsGA1UEAwwERGl2eTAeFw0yMzA0MjIxMjAzMjZaFw0yODA0MjAxMjAzMjZa +MBwxCzAJBgNVBAYTAklOMQ0wCwYDVQQDDAREaXZ5MIICIjANBgkqhkiG9w0BAQEF +AAOCAg8AMIICCgKCAgEAvqQG1wPW3F53JjydkfDJSnHMJYtvqjsVIHbJWVV/Aes8 +OKp/JvdpzlP8YRLu6KI/mutya6iuGt+xHLXJdRJYAThoke5QML27s9raxOfl3+wO +AwUtGYP9G0KcwVFVbxOD/edJ84NSwSL6o0MqfiHReydi7Gc6xyRa6R8PPpJ2ckWV +nx8r/m/LCG5TxAPCU1GbGx3sWhvDJyzL7Yj/X2y7wqIVsJy/lMz765ND01LtvmlJ +IG7N9hnmcoVgxCxrWmBQ+x4YIAJx7OWcs/vuvSjsxJuxlRl+YeiZqilvm5u3Fopn +x5xzE1oN+vBU5ncDVqfoidsh5w7BkPHgHbZWE7Ba1wlp9mJqBMBe1ko9/xVJjmlb +ot+EinTDYGxhUfngh7tGt45bJjNHFINPf3WSCRPUancF/lJjHoTvlVAGYMZUMwNz +ENo+chYCg2Bb5c8+/OuYgtfqtSCttdw+Eo0V49zue/i7leGD1IQ+pgskdGvCa1UG +bwHkSbC61U9tDwHyjju8oi0wMEYsVBMyjy25wuS/iYCte5J1pfrCIuziR/xAiYfF ++oC0Hd828Tujtbii5YtXXr3Bjb+A78lnkadecXUBfIe8yqtPkgiMOPSWUM0KEuJr +EYvnZX4wdhfz9AD0NmgZrIvTlXE0s0hjVHvBzgJLzmIBHLMHGcZVeDoCe9oPF0sC +AwEAATANBgkqhkiG9w0BAQsFAAOCAgEAAAQNpXGAqng6YNy1AOVsPSXecVrcqxfp +l28SSfDzrvhtLs5FaK9MxER5Lt3c4Fmlq+fC5mg7eAHQcVQvpvnCufMDH17yZF04 +X/EzIniQ4fVkhyWtukDm18orpCQfDIyJ0iRSsFDzDyi/UZFIsPkD0lumNjGZY4m/ +VoELlwIAeoDgDplSDBqJ2N0dgxZYUKDjqS0hLBnp8RfETLTbXtpQCVN3Q35gJApB +gGRtwOKYf5GZPIcp0iDNumRLPLtqYanT/cD8nd54Bil13925l5dqy0/ozfm2I+NT +TYAd3b2q8Mexs/rJD8naCE7BM+zfbUkoUPOy1Q1y9A/5CfhjCAdJlnDnK0B+isW8 +HNl9U4pySDQRg66oUUZDboRGh+G7qQuPA2ewAz42KvtfS3XX6zaYxpQlrrjut7db +Df0y7fYumdmQtqZQJ2MtJHI9pZQ3+zxGq5RN/xZNh53XPIurCiqBypfHj4nSFNIq +VADjJEITr+oiFabDjp5jiwoewtEGCdT0PuzaY/iADvlxTOMdy6AUdRTkhLob930F +1QtKU45rwHTbaPxLdjvnKMI2ElwqVyFS5H5YNgM2xSWkRMmqPlvihh3a7M+Ux/Ri +C878EKTdkCNTXUpCCJhMGhrXTzYkJ5G+Nh9ERcTGuLkw6uzkUdbAmYgOn2GN3xvl +Q26ks4m/6Fs= +-----END CERTIFICATE----- diff --git a/wtx/examples/localhost.key b/wtx/examples/localhost.key new file mode 100644 index 00000000..e13c8349 --- /dev/null +++ b/wtx/examples/localhost.key @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQC+pAbXA9bcXncm +PJ2R8MlKccwli2+qOxUgdslZVX8B6zw4qn8m92nOU/xhEu7ooj+a63JrqK4a37Ec +tcl1ElgBOGiR7lAwvbuz2trE5+Xf7A4DBS0Zg/0bQpzBUVVvE4P950nzg1LBIvqj +Qyp+IdF7J2LsZzrHJFrpHw8+knZyRZWfHyv+b8sIblPEA8JTUZsbHexaG8MnLMvt +iP9fbLvCohWwnL+UzPvrk0PTUu2+aUkgbs32GeZyhWDELGtaYFD7HhggAnHs5Zyz +++69KOzEm7GVGX5h6JmqKW+bm7cWimfHnHMTWg368FTmdwNWp+iJ2yHnDsGQ8eAd +tlYTsFrXCWn2YmoEwF7WSj3/FUmOaVui34SKdMNgbGFR+eCHu0a3jlsmM0cUg09/ +dZIJE9RqdwX+UmMehO+VUAZgxlQzA3MQ2j5yFgKDYFvlzz7865iC1+q1IK213D4S +jRXj3O57+LuV4YPUhD6mCyR0a8JrVQZvAeRJsLrVT20PAfKOO7yiLTAwRixUEzKP +LbnC5L+JgK17knWl+sIi7OJH/ECJh8X6gLQd3zbxO6O1uKLli1devcGNv4DvyWeR +p15xdQF8h7zKq0+SCIw49JZQzQoS4msRi+dlfjB2F/P0APQ2aBmsi9OVcTSzSGNU +e8HOAkvOYgEcswcZxlV4OgJ72g8XSwIDAQABAoICACH5GxrwHTcSQot22+GpFkYE +94ttSM3+T2qEoKch3EtcP1Qd1iD8kEdrohsug5LDbzBNawuSeMxjNq3WG3uYdERr +Z/8xh+rXtP59LuVOKiH4cBrLrljQs6dK/KJauy3bPXde40fZDENM13uGuajWn/0h +bLiSQOBCM0098rqE4UTF7772kCF8jKMI/jZ9MQEmFs0DTR5VujZd/k1rT48S0ncB +6XmaxW1gBjjZ+olLSwDWxGhaNqv3u6CG8lKjU9I8PdIyb7wsk17TIFTWvZnKFD+J +O2FFtMb/63pufewuGLeUnJ/u2ncFYl5ou8iCRv8HVyJSAb2qXIZXBEhnOPmzQMyn ++NEkX/3w8LpWhbd0RCsYeIQqYWTBX97kUXgbDtMCOCPH/DdDcJBQWRm5TUJt062K +dDOwqg1jWOQ8F8nr3OwJ9E7NPoBQO5zWVfG6i3opxcFhoWGdTF0oIEhT42aX20ip +AqQGwrW72j+qZdLt8nHK38kvnq8RinS7I2mw8QVhGFqN84x2WIUSiny9+lVD4lfT +ckcSblKL9BpjYVkDvQ3s6BS92RCYqFaBwES4s5oFd2V+Ffrz3cCqp59k/5bZ6x7h +ia+Hw2/Mmtp4TaAVcGrHOvAmnZcnfV9jvsdRzBeihom/gtnVrA8ctsipxhE+Ylkx +q4g/xHHuBpRLMxusD3cBAoIBAQD8fR27xKLYw6aCgLW5pYMDUp8YaFW3uhsA3VPu +UlErdF7QJwGHq7e4eiWp9W1lsc7Wq99PnetRxrs1HJNjExVcLSpCOdXOA0BGO8hI +VgnmFr2doAmjLAXZzSDWnt7jl2FP4Casyv01/EmtpfFdnsXtdJa4s8ITZ5k9cRaV +z28YkpMSXkoY83E4YOK2TrwzvgK7mwFjjV1x3xZZu6QXeMppCGxYQoCooNUvO6uO +r4npHXJVBUdz+mMxci08v0sqhmakk4YRPq0A9A5b55zlSZwwE51d8wSgwfqdFygx +EULkLCs881tM5jfIXcnHSxOQ23jd7bUHy8Hdll1U741AYXPLAoIBAQDBSrnORhzQ +urN33OUvdYBIQT8pFDPNC7hoEPhxrk1V3N47Kvmxgc88pzrscqNI0e93XixRqjc5 +bCeR07qWl/xWlD8/MVu2MGHxtdjoXEDUtynUoiVe6D1hea4hcF4BkBWAhLKIz81R +5fU9p/RzZmWy8Fbc+ZV8GvX64LS/orWGmQvJx9Q0byZui6JUUrChRrYZWeFX+ehB +5Y/5BHsOL15HntqBfF3v6zK7vzQ03Aqd9vWxR3xUlbpUvxiax5Kg+fuiksmwNKIh +P3/nhoP3LBtBZUx4h/Jdt0e/NFHtDXdIbxaaHO/jbTfy1tg4/wCp9OneMZF2kMVj +PpU7wetwr3qBAoIBAFOEV0d63Zrx7KwSQworc1CwDawXJvNk/fWlQFP+qpbDIXGc +1Wa5KEY/MSIs6ojO7eoYY/+D7wjXwajp0N7euxwIXIgXdV91t9cDg1ZaD2AqeYIg +I8/ziePndEtJtdR2iFvRezmA04z97Kkh0Nr03+eRvyFNZI7in8+xDpVzTf5EzZ0v +zza9n9/UPGmtVZeP7Ht95FG3uwclkdEQvlB9RgbEIIJ5TPF6ccnz5OWHrwiLEvyI +iIAWfKUobUpAxG5GksExgxFFOBiuoelIjZ9SX/WPJ2iiMA+02l8H/+VrHkM3UP4S +SUsAg8clLs9bSBeMYUiXjmALyA6x5CFqM8Dt+00CggEAPnVgDvh27Te3MF8vq6ND +XZW/zA1cI8DKyM3bChjxonIpWWMspiA1D/tVvfvZKXm08JR8q7Ld/28kZinNnEXm +Yy+qNEhFw1xk+c7yFTtiM5owKSZv/vf6hZnlG6cMqWKeoBXA/xZu2Sz+jvrLsdJ/ +wE+LMgJwPFcV7whXP6lbEPA5b+1jc8IK4CO8w5SowKRxyUVS3LPDSi/c0vGQtee2 +hlwdbUP7ssAEd8h0HTSRNbQMdkmMMmTjfej2EWW1ytCccE8QXyDS1v2G3hCIagFV +mU8bY8NCHOhRhcZpRrlYNw62dfwtxAaR0qV73wb/duvN+l94CqEDN2uMm2+xHYuG +gQKCAQEA76Lce+5FCR31hD3Gm0/RbA4jFDzkV5dY8JxOwzZHGzOUJpVd2sZ3a+rQ +BIGC5EiSh9cdIKPqfuxjf+ufzhrA1rzxJVLPWHiUT+37ionXsUpIQQLeeaEpeHFb +Dqg+vu2y+Fg9vYDXTKVZWXADp9kH+KtgpvrcaL2k4UkY2q+jKVLvTt+ezwWTWZFF +QSFSMpTiAAo/kSEryG9DGnyvC5UZsgKsV9eQe7rkMg8p6TjFANcx6oDR6M6fchtn +YmrKkFivZU2bhmGM1HJCIcmAIXtqsf6gb8CqqqX0NQb5m23OJU3NC7N9g34ofhCm +GPx3/+N92+2q031KtpGtHOvcSrHFMA== +-----END PRIVATE KEY----- diff --git a/wtx/examples/web_socket_client_autobahn_raw_tokio.rs b/wtx/examples/web_socket_client_autobahn_raw_tokio.rs new file mode 100644 index 00000000..e187b325 --- /dev/null +++ b/wtx/examples/web_socket_client_autobahn_raw_tokio.rs @@ -0,0 +1,95 @@ +//! WebSocket autobahn client. + +mod common; + +use tokio::net::TcpStream; +use wtx::{ + web_socket::{FrameBufferVec, FrameMutVec, OpCode}, + ReadBuffer, +}; + +#[tokio::main] +async fn main() -> Result<(), Error> { + let fb = &mut <_>::default(); + let host = &common::_host_from_args(); + let rb = &mut <_>::default(); + for case in 1..=get_case_count(fb, &host, rb).await? { + let mut ws = common::_connect( + fb, + &format!("http://{host}/runCase?case={case}&agent=wtx"), + &mut *rb, + TcpStream::connect(host).await.map_err(wtx::Error::from)?, + ) + .await?; + loop { + let mut frame = match ws.read_msg(fb).await { + Err(err) => { + println!("Error: {err}"); + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Close, &[])?) + .await?; + break; + } + Ok(elem) => elem, + }; + match frame.op_code() { + OpCode::Binary | OpCode::Text => ws.write_frame(&mut frame).await?, + OpCode::Close => break, + _ => {} + } + } + } + common::_connect( + fb, + &format!("http://{host}/updateReports?agent=wtx"), + rb, + TcpStream::connect(host).await.map_err(wtx::Error::from)?, + ) + .await? + .write_frame(&mut FrameMutVec::close_from_params(1000, fb, &[])?) + .await?; + Ok(()) +} + +/// Error +#[derive(Debug)] +pub enum Error { + /// ParseIntError + ParseIntError(std::num::ParseIntError), + /// Wtx + Wtx(wtx::Error), +} + +impl From for Error { + fn from(from: std::num::ParseIntError) -> Self { + Self::ParseIntError(from) + } +} + +impl From for Error { + fn from(from: wtx::Error) -> Self { + Self::Wtx(from) + } +} + +async fn get_case_count( + fb: &mut FrameBufferVec, + host: &str, + rb: &mut ReadBuffer, +) -> Result { + let mut ws = common::_connect( + fb, + &format!("http://{host}/getCaseCount"), + rb, + TcpStream::connect(host).await.map_err(wtx::Error::from)?, + ) + .await?; + let rslt = ws + .read_msg(fb) + .await? + .text_payload() + .unwrap_or_default() + .parse()?; + ws.write_frame(&mut FrameMutVec::close_from_params(1000, fb, &[])?) + .await?; + Ok(rslt) +} diff --git a/wtx/examples/web_socket_client_cli_raw_tokio_rustls.rs b/wtx/examples/web_socket_client_cli_raw_tokio_rustls.rs new file mode 100644 index 00000000..3c029b84 --- /dev/null +++ b/wtx/examples/web_socket_client_cli_raw_tokio_rustls.rs @@ -0,0 +1,62 @@ +//! WebSocket CLI client. + +mod common; + +use std::{ + io::{self, ErrorKind}, + sync::Arc, +}; +use tokio::net::TcpStream; +use tokio_rustls::{ + rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}, + TlsConnector, +}; +use webpki_roots::TLS_SERVER_ROOTS; +use wtx::{web_socket::OpCode, UriParts}; + +#[tokio::main] +async fn main() -> wtx::Result<()> { + let fb = &mut <_>::default(); + let map_err = |_err| io::Error::new(ErrorKind::InvalidInput, "invalid dnsname"); + let rb = &mut <_>::default(); + let uri = common::_uri_from_args(); + let uri_parts = UriParts::from(uri.as_str()); + let mut ws = common::_connect( + fb, + &uri, + rb, + tls_connector() + .connect( + ServerName::try_from(uri_parts.hostname).map_err(map_err)?, + TcpStream::connect(uri_parts.host).await?, + ) + .await?, + ) + .await?; + + loop { + let frame = ws.read_msg(fb).await?; + match (frame.op_code(), frame.text_payload()) { + (_, Some(elem)) => println!("{elem}"), + (OpCode::Close, _) => break, + _ => {} + } + } + Ok(()) +} + +fn tls_connector() -> TlsConnector { + let mut root_store = RootCertStore::empty(); + root_store.add_trust_anchors(TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + TlsConnector::from(Arc::new(config)) +} diff --git a/wtx/examples/web_socket_server_echo_hyper.rs b/wtx/examples/web_socket_server_echo_hyper.rs new file mode 100644 index 00000000..9e5270c1 --- /dev/null +++ b/wtx/examples/web_socket_server_echo_hyper.rs @@ -0,0 +1,48 @@ +//! WebSocket echo server. + +mod common; + +use hyper::{server::conn::Http, service::service_fn, Body, Request, Response}; +use tokio::{net::TcpListener, task}; +use wtx::{ + web_socket::{ + handshake::{WebSocketUpgrade, WebSocketUpgradeHyper}, + WebSocketServer, + }, + ReadBuffer, +}; + +#[tokio::main] +async fn main() -> wtx::Result<()> { + let listener = TcpListener::bind(common::_host_from_args()).await?; + loop { + let (stream, _) = listener.accept().await?; + let _jh = tokio::spawn(async move { + let service = service_fn(server_upgrade); + if let Err(err) = Http::new() + .serve_connection(stream, service) + .with_upgrades() + .await + { + println!("An error occurred: {err}"); + } + }); + } +} + +async fn server_upgrade(req: Request) -> wtx::Result> { + let (res, fut) = WebSocketUpgradeHyper { req }.upgrade()?; + let _jh = task::spawn(async move { + let fut = async move { + common::_handle_frames( + &mut <_>::default(), + &mut WebSocketServer::new(ReadBuffer::default(), fut.await?), + ) + .await + }; + if let Err(err) = tokio::task::unconstrained(fut).await { + eprintln!("Error in WebSocket connection: {err}"); + } + }); + Ok(res) +} diff --git a/wtx/examples/web_socket_server_echo_raw_async_std.rs b/wtx/examples/web_socket_server_echo_raw_async_std.rs new file mode 100644 index 00000000..c5c361d3 --- /dev/null +++ b/wtx/examples/web_socket_server_echo_raw_async_std.rs @@ -0,0 +1,25 @@ +//! WebSocket echo server. + +mod common; + +use async_std::net::TcpListener; +use wtx::{web_socket::FrameBufferVec, ReadBuffer}; + +#[async_std::main] +async fn main() -> wtx::Result<()> { + let listener = TcpListener::bind(common::_host_from_args()).await?; + loop { + let (stream, _) = listener.accept().await?; + let _jh = async_std::task::spawn(async move { + if let Err(err) = common::_accept_conn_and_echo_frames( + &mut FrameBufferVec::default(), + &mut ReadBuffer::default(), + stream, + ) + .await + { + println!("{err}"); + } + }); + } +} diff --git a/wtx/examples/web_socket_server_echo_raw_glommio.rs b/wtx/examples/web_socket_server_echo_raw_glommio.rs new file mode 100644 index 00000000..3a072eab --- /dev/null +++ b/wtx/examples/web_socket_server_echo_raw_glommio.rs @@ -0,0 +1,47 @@ +//! WebSocket echo server. + +mod common; + +#[cfg(feature = "async-trait")] +mod cfg_hack { + pub(crate) fn hack() -> wtx::Result<()> { + Ok(()) + } +} + +#[cfg(not(feature = "async-trait"))] +mod cfg_hack { + use glommio::{net::TcpListener, CpuSet, LocalExecutorPoolBuilder, PoolPlacement}; + use std::thread::available_parallelism; + + pub(crate) fn hack() -> wtx::Result<()> { + let builder = LocalExecutorPoolBuilder::new(PoolPlacement::MaxSpread( + available_parallelism()?.into(), + CpuSet::online().ok(), + )); + for result in builder.on_all_shards(exec)?.join_all() { + result??; + } + Ok(()) + } + + async fn exec() -> wtx::Result<()> { + let listener = TcpListener::bind(crate::common::_host_from_args())?; + loop { + let stream = listener.accept().await?; + let _jh = glommio::spawn_local(async move { + let fb = &mut <_>::default(); + let rb = &mut <_>::default(); + if let Err(err) = crate::common::_accept_conn_and_echo_frames(fb, rb, stream).await + { + println!("{err}"); + } + }) + .detach(); + } + } +} + +fn main() -> wtx::Result<()> { + cfg_hack::hack() +} diff --git a/wtx/examples/web_socket_server_echo_raw_tokio.rs b/wtx/examples/web_socket_server_echo_raw_tokio.rs new file mode 100644 index 00000000..8c7ca32b --- /dev/null +++ b/wtx/examples/web_socket_server_echo_raw_tokio.rs @@ -0,0 +1,24 @@ +//! WebSocket echo server. + +mod common; + +use tokio::net::TcpListener; + +#[tokio::main] +async fn main() -> wtx::Result<()> { + let listener = TcpListener::bind(common::_host_from_args()).await?; + loop { + let (stream, _) = listener.accept().await?; + let _jh = tokio::spawn(async move { + if let Err(err) = tokio::task::unconstrained(common::_accept_conn_and_echo_frames( + &mut <_>::default(), + &mut <_>::default(), + stream, + )) + .await + { + println!("{err}"); + } + }); + } +} diff --git a/wtx/examples/web_socket_server_echo_raw_tokio_rustls.rs b/wtx/examples/web_socket_server_echo_raw_tokio_rustls.rs new file mode 100644 index 00000000..e1ede742 --- /dev/null +++ b/wtx/examples/web_socket_server_echo_raw_tokio_rustls.rs @@ -0,0 +1,52 @@ +//! WebSocket echo server. + +mod common; + +use rustls_pemfile::{certs, pkcs8_private_keys}; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio_rustls::{ + rustls::{Certificate, PrivateKey, ServerConfig}, + TlsAcceptor, +}; + +static CERT: &[u8] = include_bytes!("./localhost.crt"); +static KEY: &[u8] = include_bytes!("./localhost.key"); + +#[tokio::main] +async fn main() -> wtx::Result<()> { + let listener = TcpListener::bind(common::_host_from_args()).await?; + let tls_acceptor = tls_acceptor()?; + loop { + let (stream, _) = listener.accept().await?; + let local_tls_acceptor = tls_acceptor.clone(); + let _jh = tokio::spawn(async move { + let fun = || async move { + let stream = local_tls_acceptor.accept(stream).await?; + tokio::task::unconstrained(common::_accept_conn_and_echo_frames( + &mut <_>::default(), + &mut <_>::default(), + stream, + )) + .await + }; + if let Err(err) = fun().await { + println!("{err}"); + } + }); + } +} + +fn tls_acceptor() -> wtx::Result { + let mut keys: Vec = pkcs8_private_keys(&mut &*KEY) + .map(|certs| certs.into_iter().map(PrivateKey).collect()) + .map_err(wtx::Error::from)?; + let certs = certs(&mut &*CERT) + .map(|certs| certs.into_iter().map(Certificate).collect()) + .map_err(wtx::Error::from)?; + let config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, keys.remove(0))?; + Ok(TlsAcceptor::from(Arc::new(config))) +} diff --git a/wtx/profiling/web_socket.rs b/wtx/profiling/web_socket.rs new file mode 100644 index 00000000..ab50566e --- /dev/null +++ b/wtx/profiling/web_socket.rs @@ -0,0 +1,62 @@ +//! Profiling + +use std::hint::black_box; +use wtx::{ + web_socket::{ + FrameBufferVec, FrameMutVec, OpCode, WebSocket, WebSocketClient, WebSocketServer, + }, + BytesStream, ReadBuffer, +}; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> wtx::Result<()> { + let data = vec![52; 16 * 1024 * 1024]; + let mut fb = FrameBufferVec::default(); + let mut stream = BytesStream::default(); + black_box(from_client_to_server(&data, &mut fb, &mut stream).await?); + stream.clear(); + black_box(from_server_to_client(&data, &mut fb, &mut stream).await?); + Ok(()) +} + +async fn from_client_to_server( + data: &[u8], + fb: &mut FrameBufferVec, + stream: &mut BytesStream, +) -> wtx::Result<()> { + write(data, fb, WebSocketClient::new(<_>::default(), stream)).await?; + read(fb, WebSocketServer::new(<_>::default(), stream)).await?; + Ok(()) +} + +async fn from_server_to_client( + data: &[u8], + fb: &mut FrameBufferVec, + stream: &mut BytesStream, +) -> wtx::Result<()> { + write(data, fb, WebSocketServer::new(<_>::default(), stream)).await?; + read(fb, WebSocketClient::new(<_>::default(), stream)).await?; + Ok(()) +} + +async fn read( + fb: &mut FrameBufferVec, + mut ws: WebSocket, +) -> wtx::Result<()> { + let _frame = ws.read_msg(fb).await?; + Ok(()) +} + +async fn write( + data: &[u8], + fb: &mut FrameBufferVec, + mut ws: WebSocket, +) -> wtx::Result<()> { + ws.write_frame(&mut FrameMutVec::new_unfin(fb, OpCode::Text, data)?) + .await?; + ws.write_frame(&mut FrameMutVec::new_unfin(fb, OpCode::Continuation, data)?) + .await?; + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Continuation, data)?) + .await?; + Ok(()) +} diff --git a/wtx/src/cache.rs b/wtx/src/cache.rs new file mode 100644 index 00000000..bbd33b4d --- /dev/null +++ b/wtx/src/cache.rs @@ -0,0 +1,49 @@ +#![allow( + // False positive + clippy::arithmetic_side_effects, + // Indices are within bounds + clippy::indexing_slicing +)] + +use core::{ + array, + sync::atomic::{AtomicUsize, Ordering}, +}; + +/// Helper intended to avoid excessive allocations between multiple tasks/threads through +/// the sharing of `N` elements behind some provided locking mechanism. +/// +/// Note that the current approach locks the maximum number of simultaneous accesses to `N`. If +/// it is not desirable, you can create your own strategy or always allocate a new instance. +#[derive(Debug)] +pub struct Cache { + array: [T; N], + idx: AtomicUsize, +} + +impl Cache { + /// It is up to the caller to provide all elements. + #[inline] + pub const fn new(array: [T; N]) -> Self { + Self { + array, + idx: AtomicUsize::new(0), + } + } + + /// Each array element is constructed using `cb`. + #[inline] + pub fn from_cb(cb: impl FnMut(usize) -> T) -> Self { + Self { + array: array::from_fn(cb), + idx: AtomicUsize::new(0), + } + } + + /// Provides the next available element returning back to the begging when the internal + /// counter overflows `N`. + #[inline] + pub fn next(&self) -> &T { + &self.array[self.idx.fetch_add(1, Ordering::Relaxed) & (N - 1)] + } +} diff --git a/wtx/src/error.rs b/wtx/src/error.rs new file mode 100644 index 00000000..0ed058f5 --- /dev/null +++ b/wtx/src/error.rs @@ -0,0 +1,174 @@ +use crate::ExpectedHeader; +use core::{ + fmt::{Debug, Display, Formatter}, + num::TryFromIntError, +}; + +/// Grouped individual errors +// +// * `Invalid` Something is present but has invalid state. +// * `Missing`: Not present when expected to be. +// * `Unexpected`: Received something that was not intended. +#[derive(Debug)] +pub enum Error { + /// Invalid UTF-8. + InvalidUTF8, + + /// Missing Header + MissingHeader { + /// See [ExpectedHeader]. + expected: ExpectedHeader, + }, + /// Url does not contain a host. + MissingHost, + + /// HTTP version does not match the expected method. + UnexpectedHttpMethod, + /// HTTP version does not match the expected value. + UnexpectedHttpVersion, + /// Unexpected end of file when reading. + UnexpectedEOF, + + /// The system does not process HTTP messages greater than 2048 bytes. + VeryLargeHttp, + + // External + // + /// See [glommio::GlommioError]. + #[cfg(all(feature = "glommio", feature = "hyper"))] + Glommio(std::sync::Mutex>), + /// See [glommio::GlommioError]. + #[cfg(all(feature = "glommio", not(feature = "hyper")))] + Glommio(Box>), + #[cfg(feature = "http")] + /// See [hyper::Error] + HttpError(http::Error), + /// See [http::header::InvalidHeaderName] + #[cfg(feature = "http")] + HttpInvalidHeaderName(http::header::InvalidHeaderName), + /// See [http::header::InvalidHeaderValue] + #[cfg(feature = "http")] + HttpInvalidHeaderValue(http::header::InvalidHeaderValue), + /// See [http::status::InvalidStatusCode] + #[cfg(feature = "http")] + HttpInvalidStatusCode(http::status::InvalidStatusCode), + #[cfg(feature = "web-socket-handshake")] + /// See [httparse::Error]. + HttpParse(httparse::Error), + #[cfg(feature = "hyper")] + /// See [hyper::Error] + HyperError(hyper::Error), + #[cfg(feature = "std")] + /// See [std::io::Error] + IoError(std::io::Error), + #[cfg(feature = "tokio-rustls")] + /// See [tokio_rustls::rustls::Error]. + TokioRustLsError(Box), + /// See [TryFromIntError] + TryFromIntError(TryFromIntError), + /// See [crate::web_socket::WebSocketError]. + WebSocketError(crate::web_socket::WebSocketError), +} + +impl Display for Error { + #[inline] + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + ::fmt(self, f) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +#[cfg(feature = "glommio")] +impl From> for Error { + #[inline] + fn from(from: glommio::GlommioError<()>) -> Self { + Self::Glommio(from.into()) + } +} + +#[cfg(feature = "hyper")] +impl From for Error { + #[inline] + fn from(from: hyper::Error) -> Self { + Self::HyperError(from) + } +} + +#[cfg(feature = "http")] +impl From for Error { + #[inline] + fn from(from: http::Error) -> Self { + Self::HttpError(from) + } +} + +#[cfg(feature = "http")] +impl From for Error { + #[inline] + fn from(from: http::header::InvalidHeaderName) -> Self { + Self::HttpInvalidHeaderName(from) + } +} + +#[cfg(feature = "http")] +impl From for Error { + #[inline] + fn from(from: http::header::InvalidHeaderValue) -> Self { + Self::HttpInvalidHeaderValue(from) + } +} + +#[cfg(feature = "http")] +impl From for Error { + #[inline] + fn from(from: http::status::InvalidStatusCode) -> Self { + Self::HttpInvalidStatusCode(from) + } +} + +#[cfg(feature = "web-socket-handshake")] +impl From for Error { + #[inline] + fn from(from: httparse::Error) -> Self { + Self::HttpParse(from) + } +} + +impl From for Error { + #[inline] + fn from(_: core::str::Utf8Error) -> Self { + Self::InvalidUTF8 + } +} + +#[cfg(feature = "tokio-rustls")] +impl From for Error { + #[inline] + fn from(from: tokio_rustls::rustls::Error) -> Self { + Self::TokioRustLsError(from.into()) + } +} + +impl From for Error { + #[inline] + fn from(from: TryFromIntError) -> Self { + Self::TryFromIntError(from) + } +} + +impl From for Error { + #[inline] + fn from(from: crate::web_socket::WebSocketError) -> Self { + Self::WebSocketError(from) + } +} + +#[cfg(feature = "std")] +impl From for Error { + #[inline] + fn from(from: std::io::Error) -> Self { + Self::IoError(from) + } +} diff --git a/wtx/src/expected_header.rs b/wtx/src/expected_header.rs new file mode 100644 index 00000000..a89ad9c1 --- /dev/null +++ b/wtx/src/expected_header.rs @@ -0,0 +1,13 @@ +/// Expected HTTP headers +#[allow(non_camel_case_types)] +#[derive(Debug)] +pub enum ExpectedHeader { + /// `connection` key with `upgrade` value. + Connection_Upgrade, + /// `sec-websocket-key` key. + SecWebSocketKey, + /// `sec-websocket-version` key with `13` value. + SecWebSocketVersion_13, + /// `upgrade` key with `websocket` value. + Upgrade_WebSocket, +} diff --git a/wtx/src/lib.rs b/wtx/src/lib.rs new file mode 100644 index 00000000..32686321 --- /dev/null +++ b/wtx/src/lib.rs @@ -0,0 +1,34 @@ +#![cfg_attr( + not(feature = "async-trait"), + feature(array_chunks, async_fn_in_trait, impl_trait_projections, inline_const) +)] +#![cfg_attr(not(feature = "std"), no_std)] +#![doc = include_str!("../README.md")] + +extern crate alloc; + +mod cache; +mod error; +mod expected_header; +mod misc; +mod read_buffer; +#[cfg(feature = "web-socket-handshake")] +mod request; +#[cfg(feature = "web-socket-handshake")] +mod response; +mod stream; +pub mod web_socket; + +pub use crate::stream::{BytesStream, DummyStream, Stream}; +pub use cache::Cache; +pub use error::Error; +pub use expected_header::ExpectedHeader; +pub use misc::uri_parts::UriParts; +pub use read_buffer::ReadBuffer; +#[cfg(feature = "web-socket-handshake")] +pub use request::Request; +#[cfg(feature = "web-socket-handshake")] +pub use response::Response; + +/// Shortcut of [core::result::Result]. +pub type Result = core::result::Result; diff --git a/wtx/src/misc.rs b/wtx/src/misc.rs new file mode 100644 index 00000000..3d93a985 --- /dev/null +++ b/wtx/src/misc.rs @@ -0,0 +1,65 @@ +mod incomplete_utf8_char; +mod rng; +mod traits; +pub(crate) mod uri_parts; +mod utf8_errors; + +pub(crate) use incomplete_utf8_char::{CompleteErr, IncompleteUtf8Char}; +pub(crate) use rng::Rng; +pub(crate) use traits::{AsyncBounds, Expand, SingleTypeStorage}; +pub(crate) use utf8_errors::{ExtUtf8Error, StdUtf8Error}; + +pub(crate) fn from_utf8_opt(bytes: &[u8]) -> Option<&str> { + #[cfg(feature = "simdutf8")] + return simdutf8::basic::from_utf8(bytes).ok(); + #[cfg(not(feature = "simdutf8"))] + return core::str::from_utf8(bytes).ok(); +} + +pub(crate) fn from_utf8_ext_rslt(bytes: &[u8]) -> Result<&str, ExtUtf8Error> { + let err = match from_utf8_std_rslt(bytes) { + Ok(elem) => return Ok(elem), + Err(error) => error, + }; + let (_valid_bytes, after_valid) = bytes.split_at(err.valid_up_to); + match err.error_len { + None => Err(ExtUtf8Error::Incomplete { + incomplete_ending_char: { + let opt = IncompleteUtf8Char::new(after_valid); + opt.ok_or(ExtUtf8Error::Invalid)? + }, + }), + Some(_) => Err(ExtUtf8Error::Invalid), + } +} + +pub(crate) fn from_utf8_std_rslt(bytes: &[u8]) -> Result<&str, StdUtf8Error> { + #[cfg(feature = "simdutf8")] + return simdutf8::compat::from_utf8(bytes).map_err(|element| StdUtf8Error { + valid_up_to: element.valid_up_to(), + error_len: element.error_len(), + }); + #[cfg(not(feature = "simdutf8"))] + return core::str::from_utf8(bytes).map_err(|element| StdUtf8Error { + valid_up_to: element.valid_up_to(), + error_len: element.error_len(), + }); +} + +#[cfg(test)] +mod tests { + use crate::UriParts; + + #[test] + fn uri_parts_generates_correct_output() { + assert_eq!( + UriParts::from("foo://user:pass@sub.domain.com:80/pa/th?query=value#hash"), + UriParts { + authority: "user:pass@sub.domain.com:80", + host: "sub.domain.com:80", + hostname: "sub.domain.com", + href: "/pa/th?query=value#hash" + } + ); + } +} diff --git a/wtx/src/misc/incomplete_utf8_char.rs b/wtx/src/misc/incomplete_utf8_char.rs new file mode 100644 index 00000000..cacc795d --- /dev/null +++ b/wtx/src/misc/incomplete_utf8_char.rs @@ -0,0 +1,76 @@ +use crate::misc::from_utf8_std_rslt; + +pub(crate) struct IncompleteUtf8Char { + buffer: [u8; 4], + len: usize, +} + +impl IncompleteUtf8Char { + pub(crate) fn new(bytes: &[u8]) -> Option { + let len = bytes.len().min(4); + let bytes_slice = bytes.get(..len)?; + let mut buffer = [0, 0, 0, 0]; + buffer.get_mut(..len)?.copy_from_slice(bytes_slice); + Some(Self { + buffer, + len: bytes.len(), + }) + } + + pub(crate) fn complete<'bytes>( + &mut self, + bytes: &'bytes [u8], + ) -> (Result<(), CompleteErr>, &'bytes [u8]) { + let (consumed, tce) = self.push_to_build_valid_char(bytes); + let remaining = bytes.get(consumed..).unwrap_or_default(); + match tce { + None => (Ok(()), remaining), + Some(elem) => (Err(elem), remaining), + } + } + + fn push_to_build_valid_char(&mut self, bytes: &[u8]) -> (usize, Option) { + let initial_len = self.len; + let to_write_len; + { + let unwritten = self.buffer.get_mut(initial_len..).unwrap_or_default(); + to_write_len = unwritten.len().min(bytes.len()); + unwritten + .get_mut(..to_write_len) + .unwrap_or_default() + .copy_from_slice(bytes.get(..to_write_len).unwrap_or_default()); + }; + let new_bytes = { + let len = initial_len.wrapping_add(to_write_len); + self.buffer.get(..len).unwrap_or_default() + }; + if let Err(err) = from_utf8_std_rslt(new_bytes) { + if err.valid_up_to > 0 { + self.len = err.valid_up_to; + (err.valid_up_to.wrapping_sub(initial_len), None) + } else { + match err.error_len { + None => { + self.len = new_bytes.len(); + (to_write_len, Some(CompleteErr::InsufficientInput)) + } + Some(invalid_seq_len) => { + self.len = invalid_seq_len; + ( + invalid_seq_len.wrapping_sub(initial_len), + Some(CompleteErr::HasInvalidBytes), + ) + } + } + } + } else { + self.len = new_bytes.len(); + (to_write_len, None) + } + } +} + +pub(crate) enum CompleteErr { + HasInvalidBytes, + InsufficientInput, +} diff --git a/wtx/src/misc/rng.rs b/wtx/src/misc/rng.rs new file mode 100644 index 00000000..5771e807 --- /dev/null +++ b/wtx/src/misc/rng.rs @@ -0,0 +1,26 @@ +use rand::{rngs::SmallRng, Rng as _, SeedableRng}; + +// Used for compatibility reasons +#[derive(Debug)] +pub(crate) struct Rng { + rng: SmallRng, +} + +impl Rng { + pub(crate) fn random_u8_4(&mut self) -> [u8; 4] { + self.rng.gen() + } + + pub(crate) fn _random_u8_16(&mut self) -> [u8; 16] { + self.rng.gen() + } +} + +impl Default for Rng { + #[inline] + fn default() -> Self { + Self { + rng: SmallRng::from_entropy(), + } + } +} diff --git a/wtx/src/misc/traits.rs b/wtx/src/misc/traits.rs new file mode 100644 index 00000000..12de4d40 --- /dev/null +++ b/wtx/src/misc/traits.rs @@ -0,0 +1,84 @@ +use alloc::vec::Vec; + +/// Internal trait not intended for public usage +#[cfg(not(feature = "async-trait"))] +pub trait AsyncBounds {} + +#[cfg(not(feature = "async-trait"))] +impl AsyncBounds for T where T: ?Sized {} + +/// Internal trait not intended for public usage +#[cfg(feature = "async-trait")] +pub trait AsyncBounds: Send + Sync {} + +#[cfg(feature = "async-trait")] +impl AsyncBounds for T where T: Send + Sync + ?Sized {} + +/// Internal trait not intended for public usage +pub trait Expand { + /// Internal method not intended for public usage + fn expand(&mut self, len: usize); +} + +impl Expand for &mut T +where + T: Expand, +{ + fn expand(&mut self, len: usize) { + (*self).expand(len); + } +} + +impl Expand for Vec +where + T: Clone + Default, +{ + fn expand(&mut self, len: usize) { + if len > self.len() { + self.resize(len, <_>::default()); + } + } +} + +impl Expand for &mut [T] { + fn expand(&mut self, _: usize) {} +} + +impl Expand for [T; N] { + fn expand(&mut self, _: usize) {} +} + +/// Internal trait not intended for public usage +pub trait SingleTypeStorage { + /// Internal method not intended for public usage + type Item; +} + +impl SingleTypeStorage for &T +where + T: SingleTypeStorage, +{ + type Item = T::Item; +} + +impl SingleTypeStorage for &mut T +where + T: SingleTypeStorage, +{ + type Item = T::Item; +} + +impl SingleTypeStorage for [T; N] { + type Item = T; +} + +impl SingleTypeStorage for &'_ [T] { + type Item = T; +} + +impl SingleTypeStorage for &'_ mut [T] { + type Item = T; +} +impl SingleTypeStorage for Vec { + type Item = T; +} diff --git a/wtx/src/misc/uri_parts.rs b/wtx/src/misc/uri_parts.rs new file mode 100644 index 00000000..39ff8f8d --- /dev/null +++ b/wtx/src/misc/uri_parts.rs @@ -0,0 +1,36 @@ +/// Elements that compose an URI. +/// +/// ```txt +/// foo://user:pass@sub.domain.com:80/pa/th?query=value#hash +/// ``` +#[derive(Debug, Eq, PartialEq)] +pub struct UriParts<'uri> { + /// `user:pass@sub.domain.com:80` + pub authority: &'uri str, + /// `sub.domain.com:80` + pub host: &'uri str, + /// `sub.domain.com` + pub hostname: &'uri str, + /// `/pa/th?query=value#hash` + pub href: &'uri str, +} + +impl<'str> From<&'str str> for UriParts<'str> { + #[inline] + fn from(from: &'str str) -> Self { + let after_schema = from.split("://").nth(1).unwrap_or(from); + let (authority, href) = after_schema + .as_bytes() + .iter() + .position(|el| el == &b'/') + .map_or((after_schema, "/"), |el| after_schema.split_at(el)); + let host = authority.split('@').nth(1).unwrap_or(authority); + let hostname = host.rsplit(':').nth(1).unwrap_or(host); + Self { + authority, + host, + hostname, + href, + } + } +} diff --git a/wtx/src/misc/utf8_errors.rs b/wtx/src/misc/utf8_errors.rs new file mode 100644 index 00000000..d2219570 --- /dev/null +++ b/wtx/src/misc/utf8_errors.rs @@ -0,0 +1,13 @@ +use crate::misc::IncompleteUtf8Char; + +pub(crate) enum ExtUtf8Error { + Incomplete { + incomplete_ending_char: IncompleteUtf8Char, + }, + Invalid, +} + +pub(crate) struct StdUtf8Error { + pub(crate) error_len: Option, + pub(crate) valid_up_to: usize, +} diff --git a/wtx/src/read_buffer.rs b/wtx/src/read_buffer.rs new file mode 100644 index 00000000..0f55761e --- /dev/null +++ b/wtx/src/read_buffer.rs @@ -0,0 +1,103 @@ +use crate::web_socket::DFLT_READ_BUFFER_LEN; +use alloc::{vec, vec::Vec}; + +/// Internal buffer used to read external data. +// +// This structure isn't strictly necessary but it tries to optimize two things: +// +// 1. Avoid syscalls by reading the maximum possible number of bytes at once. +// 2. The transposition of **payloads** of frames that compose a message into the `FrameBuffer` +// of the same message. Frames are composed by headers and payloads as such it is necessary to +// have some transfer strategy. +#[derive(Debug)] +pub struct ReadBuffer { + antecedent_end_idx: usize, + buffer: Vec, + current_end_idx: usize, + following_end_idx: usize, +} + +impl ReadBuffer { + pub(crate) fn with_capacity(len: usize) -> Self { + Self { + antecedent_end_idx: 0, + buffer: vec![0; len], + current_end_idx: 0, + following_end_idx: 0, + } + } + + pub(crate) fn antecedent_end_idx(&self) -> usize { + self.antecedent_end_idx + } + + pub(crate) fn after_current_mut(&mut self) -> &mut [u8] { + self.buffer + .get_mut(self.current_end_idx..) + .unwrap_or_default() + } + + pub(crate) fn clear_if_following_is_empty(&mut self) { + if !self.has_following() { + self.antecedent_end_idx = 0; + self.current_end_idx = 0; + self.following_end_idx = 0; + } + } + + pub(crate) fn current(&self) -> &[u8] { + self.buffer + .get(self.antecedent_end_idx..self.current_end_idx) + .unwrap_or_default() + } + + pub(crate) fn current_mut(&mut self) -> &mut [u8] { + self.buffer + .get_mut(self.antecedent_end_idx..self.current_end_idx) + .unwrap_or_default() + } + + pub(crate) fn expand_after_current(&mut self, mut new_len: usize) { + new_len = self.current_end_idx.wrapping_add(new_len); + if new_len > self.buffer.len() { + self.buffer.resize(new_len, 0); + } + } + + pub(crate) fn expand_buffer(&mut self, new_len: usize) { + if new_len > self.buffer.len() { + self.buffer.resize(new_len, 0); + } + } + + pub(crate) fn following_len(&self) -> usize { + self.following_end_idx.wrapping_sub(self.current_end_idx) + } + + pub(crate) fn has_following(&self) -> bool { + self.following_end_idx > self.current_end_idx + } + + pub(crate) fn merge_current_with_antecedent(&mut self) { + self.antecedent_end_idx = self.current_end_idx; + } + + pub(crate) fn set_indices_through_expansion( + &mut self, + antecedent_end_idx: usize, + current_end_idx: usize, + following_end_idx: usize, + ) { + self.antecedent_end_idx = antecedent_end_idx; + self.current_end_idx = self.antecedent_end_idx.max(current_end_idx); + self.following_end_idx = self.current_end_idx.max(following_end_idx); + self.expand_buffer(self.following_end_idx); + } +} + +impl Default for ReadBuffer { + #[inline] + fn default() -> Self { + Self::with_capacity(DFLT_READ_BUFFER_LEN) + } +} diff --git a/wtx/src/request.rs b/wtx/src/request.rs new file mode 100644 index 00000000..a52bbc43 --- /dev/null +++ b/wtx/src/request.rs @@ -0,0 +1,75 @@ +/// Raw request that can be converted to other high-level requests. +#[derive(Debug)] +pub struct Request<'buffer, 'headers> { + body: &'buffer [u8], + req: httparse::Request<'headers, 'buffer>, +} + +impl<'buffer> Request<'buffer, '_> { + /// Body + #[inline] + pub fn body(&self) -> &'buffer [u8] { + self.body + } + + /// Method + #[inline] + pub fn method(&self) -> Option<&'buffer str> { + self.req.method + } +} + +#[cfg(feature = "http")] +mod http { + use crate::Request; + use http::{HeaderMap, HeaderName, HeaderValue, Method}; + + impl<'buffer, 'headers> TryFrom> for http::Request<&'buffer [u8]> { + type Error = crate::Error; + + #[inline] + fn try_from(from: Request<'buffer, 'headers>) -> Result { + let method = + Method::try_from(from.req.method.ok_or(crate::Error::UnexpectedHttpVersion)?) + .unwrap(); + let version = if let Some(1) = from.req.version { + http::Version::HTTP_11 + } else { + return Err(crate::Error::UnexpectedHttpVersion); + }; + let mut headers = HeaderMap::with_capacity(from.req.headers.len()); + for h in from.req.headers { + let key = HeaderName::from_bytes(h.name.as_bytes())?; + let value = HeaderValue::from_bytes(h.value)?; + let _ = headers.append(key, value); + } + let mut req = http::Request::new(from.body); + *req.headers_mut() = headers; + *req.method_mut() = method; + *req.uri_mut() = from.req.path.unwrap().parse().unwrap(); + *req.version_mut() = version; + //*req.status_mut() = status; + Ok(req) + } + } + + impl<'buffer, 'headers> TryFrom> for http::Request<()> { + type Error = crate::Error; + + #[inline] + fn try_from(from: Request<'buffer, 'headers>) -> Result { + let (parts, _) = http::Request::<&'buffer [u8]>::try_from(from)?.into_parts(); + Ok(http::Request::from_parts(parts, ())) + } + } + + impl<'buffer, 'headers> TryFrom> for http::Request> { + type Error = crate::Error; + + #[inline] + fn try_from(from: Request<'buffer, 'headers>) -> Result { + let (parts, body) = http::Request::<&'buffer [u8]>::try_from(from)?.into_parts(); + Ok(http::Request::from_parts(parts, body.to_vec())) + } + } +} diff --git a/wtx/src/response.rs b/wtx/src/response.rs new file mode 100644 index 00000000..0c3f1a73 --- /dev/null +++ b/wtx/src/response.rs @@ -0,0 +1,82 @@ +use httparse::Header; + +/// Raw response that can be converted to other high-level responses. +#[derive(Debug)] +pub struct Response<'buffer, 'headers> { + body: &'buffer [u8], + res: httparse::Response<'headers, 'buffer>, +} + +impl<'buffer, 'headers> Response<'buffer, 'headers> { + pub(crate) fn new(body: &'buffer [u8], res: httparse::Response<'buffer, 'headers>) -> Self { + Self { body, res } + } + + /// Body + #[inline] + pub fn body(&self) -> &'buffer [u8] { + self.body + } + + /// Status code + #[inline] + pub fn code(&self) -> Option { + self.res.code + } + + pub(crate) fn headers(&self) -> &&'headers mut [Header<'buffer>] { + &self.res.headers + } +} + +#[cfg(feature = "http")] +mod http { + use crate::Response; + use http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; + + impl<'buffer, 'headers> TryFrom> for http::Response<&'buffer [u8]> { + type Error = crate::Error; + + #[inline] + fn try_from(from: Response<'buffer, 'headers>) -> Result { + let status = + StatusCode::from_u16(from.res.code.ok_or(crate::Error::UnexpectedHttpVersion)?)?; + let version = if let Some(1) = from.res.version { + http::Version::HTTP_11 + } else { + return Err(crate::Error::UnexpectedHttpVersion); + }; + let mut headers = HeaderMap::with_capacity(from.res.headers.len()); + for h in from.res.headers { + let key = HeaderName::from_bytes(h.name.as_bytes())?; + let value = HeaderValue::from_bytes(h.value)?; + let _ = headers.append(key, value); + } + let mut res = http::Response::new(from.body); + *res.headers_mut() = headers; + *res.status_mut() = status; + *res.version_mut() = version; + Ok(res) + } + } + + impl<'buffer, 'headers> TryFrom> for http::Response<()> { + type Error = crate::Error; + + #[inline] + fn try_from(from: Response<'buffer, 'headers>) -> Result { + let (parts, _) = http::Response::<&'buffer [u8]>::try_from(from)?.into_parts(); + Ok(http::Response::from_parts(parts, ())) + } + } + + impl<'buffer, 'headers> TryFrom> for http::Response> { + type Error = crate::Error; + + #[inline] + fn try_from(from: Response<'buffer, 'headers>) -> Result { + let (parts, body) = http::Response::<&'buffer [u8]>::try_from(from)?.into_parts(); + Ok(http::Response::from_parts(parts, body.to_vec())) + } + } +} diff --git a/wtx/src/stream.rs b/wtx/src/stream.rs new file mode 100644 index 00000000..06dd6bf6 --- /dev/null +++ b/wtx/src/stream.rs @@ -0,0 +1,260 @@ +use crate::misc::AsyncBounds; +#[cfg(feature = "async-trait")] +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::cmp::Ordering; + +/// A stream of values produced asynchronously. +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +pub trait Stream { + /// Pulls some bytes from this source into the specified buffer, returning how many bytes + /// were read. + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result; + + /// Attempts to write all elements of `bytes`. + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()>; +} + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Stream for &mut T +where + T: AsyncBounds + Stream, +{ + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + (*self).read(bytes).await + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + (*self).write_all(bytes).await + } +} + +/// Stores written data to transfer when read. +#[derive(Debug, Default)] +pub struct BytesStream { + buffer: Vec, + idx: usize, +} + +impl BytesStream { + /// Empties the internal buffer. + #[inline] + pub fn clear(&mut self) { + self.buffer.clear(); + self.idx = 0; + } +} + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Stream for BytesStream { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + let working_buffer = self.buffer.get(self.idx..).unwrap_or_default(); + let working_buffer_len = working_buffer.len(); + Ok(match working_buffer_len.cmp(&bytes.len()) { + Ordering::Less => { + bytes + .get_mut(..working_buffer_len) + .unwrap_or_default() + .copy_from_slice(working_buffer); + self.clear(); + working_buffer_len + } + Ordering::Equal => { + bytes.copy_from_slice(working_buffer); + self.clear(); + working_buffer_len + } + Ordering::Greater => { + bytes.copy_from_slice(working_buffer.get(..bytes.len()).unwrap_or_default()); + self.idx = self.idx.wrapping_add(bytes.len()); + bytes.len() + } + }) + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + self.buffer.extend_from_slice(bytes); + Ok(()) + } +} + +/// Does nothing. +#[derive(Debug)] +pub struct DummyStream; + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Stream for DummyStream { + #[inline] + async fn read(&mut self, _: &mut [u8]) -> crate::Result { + Ok(0) + } + + #[inline] + async fn write_all(&mut self, _: &[u8]) -> crate::Result<()> { + Ok(()) + } +} + +#[cfg(feature = "async-std")] +mod async_std { + use crate::Stream; + #[cfg(feature = "async-trait")] + use alloc::boxed::Box; + use async_std::{ + io::{ReadExt, WriteExt}, + net::TcpStream, + }; + + #[cfg_attr(feature = "async-trait", async_trait::async_trait)] + impl Stream for TcpStream { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + Ok(::read(self, bytes).await?) + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + ::write_all(self, bytes).await?; + Ok(()) + } + } +} + +#[cfg(all(feature = "glommio", not(feature = "async-trait")))] +mod glommio { + use crate::Stream; + use futures_lite::io::{AsyncReadExt, AsyncWriteExt}; + use glommio::net::TcpStream; + + impl Stream for TcpStream { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + Ok(::read(self, bytes).await?) + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + ::write_all(self, bytes).await?; + Ok(()) + } + } +} + +#[cfg(feature = "hyper")] +mod hyper { + use crate::Stream; + #[cfg(feature = "async-trait")] + use alloc::boxed::Box; + use hyper::upgrade::Upgraded; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[cfg_attr(feature = "async-trait", async_trait::async_trait)] + impl Stream for Upgraded { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + Ok(::read(self, bytes).await?) + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + ::write_all(self, bytes).await?; + Ok(()) + } + } +} + +#[cfg(feature = "std")] +mod std { + use crate::Stream; + #[cfg(feature = "async-trait")] + use alloc::boxed::Box; + use std::{ + io::{Read, Write}, + net::TcpStream, + }; + + #[cfg_attr(feature = "async-trait", async_trait::async_trait)] + impl Stream for TcpStream { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + Ok(::read(self, bytes)?) + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + ::write_all(self, bytes)?; + Ok(()) + } + } +} + +#[cfg(feature = "tokio")] +mod tokio { + use crate::Stream; + #[cfg(feature = "async-trait")] + use alloc::boxed::Box; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, + }; + + #[cfg_attr(feature = "async-trait", async_trait::async_trait)] + impl Stream for TcpStream { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + Ok(::read(self, bytes).await?) + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + ::write_all(self, bytes).await?; + Ok(()) + } + } +} + +#[cfg(feature = "tokio-rustls")] +mod tokio_rustls { + use crate::{misc::AsyncBounds, Stream}; + #[cfg(feature = "async-trait")] + use alloc::boxed::Box; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + + #[cfg_attr(feature = "async-trait", async_trait::async_trait)] + impl Stream for tokio_rustls::client::TlsStream + where + T: AsyncBounds + AsyncRead + AsyncWrite + Unpin, + { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + Ok(::read(self, bytes).await?) + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + ::write_all(self, bytes).await?; + Ok(()) + } + } + + #[cfg_attr(feature = "async-trait", async_trait::async_trait)] + impl Stream for tokio_rustls::server::TlsStream + where + T: AsyncBounds + AsyncRead + AsyncWrite + Unpin, + { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + Ok(::read(self, bytes).await?) + } + + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + ::write_all(self, bytes).await?; + Ok(()) + } + } +} diff --git a/wtx/src/web_socket.rs b/wtx/src/web_socket.rs new file mode 100644 index 00000000..dbaafb12 --- /dev/null +++ b/wtx/src/web_socket.rs @@ -0,0 +1,723 @@ +//! A computer communications protocol, providing full-duplex communication channels over a single +//! TCP connection. + +mod close_code; +mod frame; +mod frame_buffer; +#[cfg(feature = "web-socket-handshake")] +pub mod handshake; +mod mask; +mod op_code; +mod web_socket_error; + +use crate::{ + misc::{from_utf8_ext_rslt, from_utf8_opt, CompleteErr, ExtUtf8Error, Rng}, + web_socket::close_code::CloseCode, + ReadBuffer, Stream, +}; +use alloc::vec::Vec; +use core::borrow::BorrowMut; +pub use frame::{ + Frame, FrameControlArray, FrameControlArrayMut, FrameMut, FrameMutControlArray, + FrameMutControlArrayMut, FrameMutMut, FrameMutVec, FrameMutVecMut, FrameVec, FrameVecMut, +}; +pub use frame_buffer::{ + FrameBuffer, FrameBufferControlArray, FrameBufferControlArrayMut, FrameBufferMut, + FrameBufferVec, FrameBufferVecMut, +}; +pub use mask::unmask; +pub use op_code::OpCode; +pub use web_socket_error::WebSocketError; + +pub(crate) const DFLT_FRAME_BUFFER_VEC_LEN: usize = 16 * 1024; +pub(crate) const DFLT_READ_BUFFER_LEN: usize = 2 * DFLT_FRAME_BUFFER_VEC_LEN; +pub(crate) const MAX_CONTROL_FRAME_LEN: usize = MAX_HDR_LEN_USIZE + MAX_CONTROL_FRAME_PAYLOAD_LEN; +pub(crate) const MAX_CONTROL_FRAME_PAYLOAD_LEN: usize = 125; +pub(crate) const MAX_HDR_LEN_U8: u8 = 14; +pub(crate) const MAX_HDR_LEN_USIZE: usize = 14; +pub(crate) const MAX_PAYLOAD_LEN: usize = 64 * 1024 * 1024; +pub(crate) const MIN_HEADER_LEN_USIZE: usize = 2; + +/// Always masks the payload before sending. +pub type WebSocketClient = WebSocket; +/// [WebSocketClient] with a mutable reference of [ReadBuffer]. +pub type WebSocketClientMut<'rb, S> = WebSocketClient<&'rb mut ReadBuffer, S>; +/// [WebSocketClient] with an owned [ReadBuffer]. +pub type WebSocketClientOwned = WebSocketClient; +/// Always decode the payload after receiving. +pub type WebSocketServer = WebSocket; +/// [WebSocketServer] with a mutable reference of [ReadBuffer]. +pub type WebSocketServerMut<'rb, S> = WebSocketServer<&'rb mut ReadBuffer, S>; +/// [WebSocketServer] with an owned [ReadBuffer]. +pub type WebSocketServerOwned = WebSocketServer; + +/// WebSocket protocol implementation over an asynchronous stream. +#[derive(Debug)] +pub struct WebSocket { + auto_close: bool, + auto_pong: bool, + is_stream_closed: bool, + max_payload_len: usize, + rb: RB, + rng: Rng, + stream: S, +} + +impl WebSocket { + /// Sets whether to automatically close the connection when a close frame is received. Defaults + /// to `true`. + #[inline] + pub fn set_auto_close(&mut self, auto_close: bool) { + self.auto_close = auto_close; + } + + /// Sets whether to automatically send a pong frame when a ping frame is received. Defaults + /// to `true`. + #[inline] + pub fn set_auto_pong(&mut self, auto_pong: bool) { + self.auto_pong = auto_pong; + } + + /// Sets whether to automatically close the connection when a received frame payload length + /// exceeds `max_payload_len`. Defaults to `64 * 1024 * 1024` bytes (64 MiB). + #[inline] + pub fn set_max_payload_len(&mut self, max_payload_len: usize) { + self.max_payload_len = max_payload_len; + } +} + +impl WebSocket +where + RB: BorrowMut, + S: Stream, +{ + /// Creates a new instance from a stream that supposedly has already completed the WebSocket + /// handshake. + #[inline] + pub fn new(mut rb: RB, stream: S) -> Self { + rb.borrow_mut().clear_if_following_is_empty(); + Self { + auto_close: true, + auto_pong: true, + is_stream_closed: false, + max_payload_len: MAX_PAYLOAD_LEN, + rb, + rng: Rng::default(), + stream, + } + } + + /// Reads a frame from the stream unmasking and validating its payload. + #[inline] + pub async fn read_frame<'fb, B>( + &mut self, + fb: &'fb mut FrameBuffer, + ) -> crate::Result, IS_CLIENT>> + where + B: AsMut> + AsRef<[u8]>, + { + let rbfi = self.do_read_frame::().await?; + Self::copy_from_rb_to_fb(CopyType::Normal, fb, self.rb.borrow(), &rbfi); + self.rb.borrow_mut().clear_if_following_is_empty(); + Frame::from_fb(fb) + } + + /// Collects frames and returns the completed message once all fragments have been received. + #[inline] + pub async fn read_msg<'fb, B>( + &mut self, + fb: &'fb mut FrameBuffer, + ) -> crate::Result, IS_CLIENT>> + where + B: AsMut<[u8]> + AsMut> + AsRef<[u8]>, + { + let mut iuc_opt = None; + let mut is_binary = true; + let rbfi = self.do_read_frame::().await?; + if rbfi.op_code.is_continuation() { + return Err(WebSocketError::UnexpectedMessageFrame.into()); + } + let should_stop_at_the_first_frame = match rbfi.op_code { + OpCode::Binary => rbfi.fin, + OpCode::Text => { + let range = rbfi.header_end_idx..; + let curr_payload = self.rb.borrow().current().get(range).unwrap_or_default(); + if rbfi.fin { + if from_utf8_opt(curr_payload).is_none() { + return Err(crate::Error::InvalidUTF8); + } + true + } else { + is_binary = false; + match from_utf8_ext_rslt(curr_payload) { + Err(ExtUtf8Error::Incomplete { + incomplete_ending_char, + .. + }) => { + iuc_opt = Some(incomplete_ending_char); + false + } + Err(ExtUtf8Error::Invalid { .. }) => { + return Err(crate::Error::InvalidUTF8); + } + Ok(_) => false, + } + } + } + OpCode::Continuation | OpCode::Close | OpCode::Ping | OpCode::Pong => true, + }; + if should_stop_at_the_first_frame { + Self::copy_from_rb_to_fb(CopyType::Normal, fb, self.rb.borrow(), &rbfi); + self.rb.borrow_mut().clear_if_following_is_empty(); + return Frame::from_fb(fb); + } + let mut total_frame_len = msg_header_placeholder::().into(); + Self::copy_from_rb_to_fb( + CopyType::Msg(&mut total_frame_len), + fb, + self.rb.borrow(), + &rbfi, + ); + if is_binary { + self.manage_read_msg_loop(fb, rbfi.op_code, &mut total_frame_len, |_| Ok(())) + .await?; + } else { + self.manage_read_msg_loop(fb, rbfi.op_code, &mut total_frame_len, |payload| { + let tail = if let Some(mut incomplete) = iuc_opt.take() { + let (rslt, remaining) = incomplete.complete(payload); + match rslt { + Err(CompleteErr::HasInvalidBytes) => { + return Err(crate::Error::InvalidUTF8); + } + Err(CompleteErr::InsufficientInput) => { + let _ = iuc_opt.replace(incomplete); + &[] + } + Ok(_) => remaining, + } + } else { + payload + }; + match from_utf8_ext_rslt(tail) { + Err(ExtUtf8Error::Incomplete { + incomplete_ending_char, + .. + }) => { + iuc_opt = Some(incomplete_ending_char); + } + Err(ExtUtf8Error::Invalid { .. }) => { + return Err(crate::Error::InvalidUTF8); + } + Ok(_) => {} + } + Ok(()) + }) + .await?; + }; + Frame::from_fb(fb) + } + + /// Writes a frame to the stream without masking its payload. + #[inline] + pub async fn write_frame( + &mut self, + frame: &mut Frame, + ) -> crate::Result<()> + where + B: AsMut<[u8]> + AsRef<[u8]>, + FB: BorrowMut>, + { + Self::do_write_frame( + frame, + &mut self.is_stream_closed, + &mut self.rng, + &mut self.stream, + ) + .await + } + + fn copy_from_rb_to_fb( + ct: CopyType<'_>, + fb: &mut FrameBuffer, + rb: &ReadBuffer, + rbfi: &ReadBufferFrameInfo, + ) where + B: AsMut>, + { + let current_frame = rb.current(); + let range = match ct { + CopyType::Msg(total_frame_len) => { + let prev = *total_frame_len; + *total_frame_len = total_frame_len.wrapping_add(rbfi.payload_len); + fb.set_params_through_expansion( + 0, + msg_header_placeholder::(), + *total_frame_len, + ); + prev..*total_frame_len + } + CopyType::Normal => { + let mask_placeholder = if IS_CLIENT { 4 } else { 0 }; + let header_len_total = rbfi.header_len.wrapping_add(mask_placeholder); + let header_len_total_usize = rbfi.header_len.wrapping_add(mask_placeholder).into(); + fb.set_params_through_expansion( + 0, + header_len_total, + rbfi.payload_len.wrapping_add(header_len_total_usize), + ); + fb.buffer_mut() + .as_mut() + .get_mut(..rbfi.header_len.into()) + .unwrap_or_default() + .copy_from_slice( + current_frame + .get(rbfi.header_begin_idx..rbfi.header_end_idx) + .unwrap_or_default(), + ); + let start = header_len_total_usize; + let end = current_frame + .len() + .wrapping_sub(rbfi.header_begin_idx) + .wrapping_add(mask_placeholder.into()); + start..end + } + }; + fb.buffer_mut() + .as_mut() + .get_mut(range) + .unwrap_or_default() + .copy_from_slice(current_frame.get(rbfi.header_end_idx..).unwrap_or_default()); + } + + #[inline] + async fn do_read_frame( + &mut self, + ) -> crate::Result { + loop { + let mut rbfi = self.fill_rb_from_stream().await?; + let curr_frame = self.rb.borrow_mut().current_mut(); + if !IS_CLIENT { + unmask( + curr_frame + .get_mut(rbfi.header_end_idx..) + .unwrap_or_default(), + rbfi.mask.ok_or(WebSocketError::MissingFrameMask)?, + ); + let n = remove_mask( + curr_frame + .get_mut(rbfi.header_begin_idx..rbfi.header_end_idx) + .unwrap_or_default(), + ); + let n_usize = n.into(); + rbfi.frame_len = rbfi.frame_len.wrapping_sub(n_usize); + rbfi.header_begin_idx = rbfi.header_begin_idx.wrapping_add(n_usize); + rbfi.header_len = rbfi.header_len.wrapping_sub(n); + } + let payload: &[u8] = curr_frame.get(rbfi.header_end_idx..).unwrap_or_default(); + match rbfi.op_code { + OpCode::Close if self.auto_close && !self.is_stream_closed => { + match payload { + [] => {} + [_] => return Err(WebSocketError::InvalidCloseFrame.into()), + [a, b, rest @ ..] => { + if from_utf8_opt(rest).is_none() { + return Err(crate::Error::InvalidUTF8); + }; + let is_not_allowed = + !CloseCode::from(u16::from_be_bytes([*a, *b])).is_allowed(); + if is_not_allowed || rest.len() > MAX_CONTROL_FRAME_PAYLOAD_LEN - 2 { + Self::write_control_frame( + &mut FrameControlArray::close_from_params( + 1002, + <_>::default(), + rest, + )?, + &mut self.is_stream_closed, + &mut self.rng, + &mut self.stream, + ) + .await?; + return Err(WebSocketError::InvalidCloseFrame.into()); + } + } + } + Self::write_control_frame( + &mut FrameControlArray::new_fin(<_>::default(), OpCode::Close, payload)?, + &mut self.is_stream_closed, + &mut self.rng, + &mut self.stream, + ) + .await?; + break Ok(rbfi); + } + OpCode::Ping if self.auto_pong => { + Self::write_control_frame( + &mut FrameControlArray::new_fin(<_>::default(), OpCode::Pong, payload)?, + &mut self.is_stream_closed, + &mut self.rng, + &mut self.stream, + ) + .await?; + } + OpCode::Text => { + if CHECK_TEXT_UTF8 && from_utf8_opt(payload).is_none() { + return Err(crate::Error::InvalidUTF8); + } + break Ok(rbfi); + } + OpCode::Continuation + | OpCode::Binary + | OpCode::Close + | OpCode::Ping + | OpCode::Pong => { + break Ok(rbfi); + } + } + } + } + + async fn do_write_frame( + frame: &mut Frame, + is_stream_closed: &mut bool, + rng: &mut Rng, + stream: &mut S, + ) -> crate::Result<()> + where + B: AsMut<[u8]> + AsRef<[u8]>, + FB: BorrowMut>, + { + if IS_CLIENT { + let mut mask_opt = None; + if let [_, second_byte, .., a, b, c, d] = frame.fb_mut().borrow_mut().header_mut() { + if !has_masked_frame(*second_byte) { + *second_byte |= 0b1000_0000; + let mask = rng.random_u8_4(); + *a = mask[0]; + *b = mask[1]; + *c = mask[2]; + *d = mask[3]; + mask_opt = Some(mask); + } + } + if let Some(mask) = mask_opt { + unmask(frame.fb_mut().borrow_mut().payload_mut(), mask); + } + } + if frame.op_code() == OpCode::Close { + *is_stream_closed = true; + } + stream.write_all(frame.fb().borrow().frame()).await?; + Ok(()) + } + + async fn fill_initial_rb_from_stream( + buffer: &mut [u8], + max_payload_len: usize, + read: &mut usize, + stream: &mut S, + ) -> crate::Result + where + S: Stream, + { + async fn read_until( + buffer: &mut [u8], + read: &mut usize, + start: usize, + stream: &mut S, + ) -> crate::Result<[u8; LEN]> + where + [u8; LEN]: Default, + S: Stream, + { + let until = start.wrapping_add(LEN); + while *read < until { + let actual_buffer = buffer.get_mut(*read..).unwrap_or_default(); + let local_read = stream.read(actual_buffer).await?; + if local_read == 0 { + return Err(crate::Error::UnexpectedEOF); + } + *read = read.wrapping_add(local_read); + } + Ok(buffer + .get(start..until) + .and_then(|el| el.try_into().ok()) + .unwrap_or_default()) + } + + let first_two = read_until::<_, 2>(buffer, read, 0, stream).await?; + + let fin = first_two[0] & 0b1000_0000 != 0; + let rsv1 = first_two[0] & 0b0100_0000 != 0; + let rsv2 = first_two[0] & 0b0010_0000 != 0; + let rsv3 = first_two[0] & 0b0001_0000 != 0; + + if rsv1 || rsv2 || rsv3 { + return Err(WebSocketError::ReservedBitsAreNotZero.into()); + } + + let is_masked = has_masked_frame(first_two[1]); + let length_code = first_two[1] & 0b0111_1111; + let op_code = op_code(first_two[0])?; + + let (mut header_len, payload_len) = match length_code { + 126 => ( + 4, + u16::from_be_bytes(read_until::<_, 2>(buffer, read, 2, stream).await?).into(), + ), + 127 => { + let payload_len = read_until::<_, 8>(buffer, read, 2, stream).await?; + (10, u64::from_be_bytes(payload_len).try_into()?) + } + _ => (2, length_code.into()), + }; + + let mut mask = None; + if is_masked { + mask = Some(read_until::<_, 4>(buffer, read, header_len, stream).await?); + header_len = header_len.wrapping_add(4); + } + + if op_code.is_control() && !fin { + return Err(WebSocketError::UnexpectedFragmentedControlFrame.into()); + } + if op_code == OpCode::Ping && payload_len > MAX_CONTROL_FRAME_PAYLOAD_LEN { + return Err(WebSocketError::VeryLargeControlFrame.into()); + } + if payload_len >= max_payload_len { + return Err(WebSocketError::VeryLargePayload.into()); + } + + Ok(ReadBufferFrameInfo { + fin, + frame_len: header_len.wrapping_add(payload_len), + header_begin_idx: 0, + header_end_idx: header_len, + header_len: header_len.try_into().unwrap_or_default(), + mask, + op_code, + payload_len, + }) + } + + async fn fill_rb_from_stream(&mut self) -> crate::Result { + let mut read = self.rb.borrow().following_len(); + self.rb.borrow_mut().merge_current_with_antecedent(); + self.rb.borrow_mut().expand_after_current(MAX_HDR_LEN_USIZE); + let rbfi = Self::fill_initial_rb_from_stream( + self.rb.borrow_mut().after_current_mut(), + self.max_payload_len, + &mut read, + &mut self.stream, + ) + .await?; + if self.is_stream_closed && rbfi.op_code != OpCode::Close { + return Err(WebSocketError::ConnectionClosed.into()); + } + loop { + if read >= rbfi.frame_len { + break; + } + self.rb.borrow_mut().expand_after_current(rbfi.frame_len); + let local_read = self + .stream + .read( + self.rb + .borrow_mut() + .after_current_mut() + .get_mut(read..) + .unwrap_or_default(), + ) + .await?; + read = read.wrapping_add(local_read); + } + let rb = self.rb.borrow_mut(); + rb.set_indices_through_expansion( + rb.antecedent_end_idx(), + rb.antecedent_end_idx().wrapping_add(rbfi.frame_len), + rb.antecedent_end_idx().wrapping_add(read), + ); + Ok(rbfi) + } + + async fn manage_read_msg_loop( + &mut self, + fb: &mut FrameBuffer, + first_frame_op_code: OpCode, + total_frame_len: &mut usize, + mut cb: impl FnMut(&[u8]) -> crate::Result<()>, + ) -> crate::Result<()> + where + B: AsMut<[u8]> + AsMut> + AsRef<[u8]>, + S: Stream, + { + loop { + let rbfi = self.do_read_frame::().await?; + Self::copy_from_rb_to_fb(CopyType::Msg(total_frame_len), fb, self.rb.borrow(), &rbfi); + match rbfi.op_code { + OpCode::Continuation => { + cb(self + .rb + .borrow() + .current() + .get(rbfi.header_end_idx..) + .unwrap_or_default())?; + if rbfi.fin { + let mut buffer = [0; MAX_HDR_LEN_USIZE]; + let header_len = copy_header_params_to_buffer::( + &mut buffer, + true, + first_frame_op_code, + fb.payload().len(), + )?; + let start_idx = + msg_header_placeholder::().wrapping_sub(header_len); + fb.header_mut() + .get_mut(start_idx.into()..) + .unwrap_or_default() + .copy_from_slice(buffer.get(..header_len.into()).unwrap_or_default()); + fb.set_params_through_expansion(start_idx, header_len, *total_frame_len); + self.rb.borrow_mut().clear_if_following_is_empty(); + break; + } + } + OpCode::Binary | OpCode::Close | OpCode::Ping | OpCode::Pong | OpCode::Text => { + return Err(WebSocketError::UnexpectedMessageFrame.into()); + } + } + } + Ok(()) + } + + async fn write_control_frame( + frame: &mut FrameControlArray, + is_stream_closed: &mut bool, + rng: &mut Rng, + stream: &mut S, + ) -> crate::Result<()> { + Self::do_write_frame(frame, is_stream_closed, rng, stream).await?; + Ok(()) + } +} + +#[derive(Debug)] +enum CopyType<'read> { + Msg(&'read mut usize), + Normal, +} + +#[derive(Debug)] +struct ReadBufferFrameInfo { + fin: bool, + frame_len: usize, + header_begin_idx: usize, + header_end_idx: usize, + header_len: u8, + mask: Option<[u8; 4]>, + op_code: OpCode, + payload_len: usize, +} + +pub(crate) fn copy_header_params_to_buffer( + buffer: &mut [u8], + fin: bool, + op_code: OpCode, + payload_len: usize, +) -> crate::Result { + fn first_header_byte(fin: bool, op_code: OpCode) -> u8 { + u8::from(fin) << 7 | u8::from(op_code) + } + + fn manage_mask( + rest: &mut [u8], + second_byte: &mut u8, + ) -> crate::Result { + Ok(if IS_CLIENT { + *second_byte &= 0b0111_1111; + let [a, b, c, d, ..] = rest else { + return Err(WebSocketError::InvalidFrameHeaderBounds.into()); + }; + *a = 0; + *b = 0; + *c = 0; + *d = 0; + N.wrapping_add(4) + } else { + N + }) + } + match payload_len { + 0..=125 => { + if let ([a, b, rest @ ..], Ok(u8_len)) = (buffer, u8::try_from(payload_len)) { + *a = first_header_byte(fin, op_code); + *b = u8_len; + return manage_mask::(rest, b); + } + } + 126..=0xFFFF => { + let rslt = u16::try_from(payload_len).map(u16::to_be_bytes); + if let ([a, b, c, d, rest @ ..], Ok([len_c, len_d])) = (buffer, rslt) { + *a = first_header_byte(fin, op_code); + *b = 126; + *c = len_c; + *d = len_d; + return manage_mask::(rest, b); + } + } + _ => { + if let ( + [a, b, c, d, e, f, g, h, i, j, rest @ ..], + Ok([len_c, len_d, len_e, len_f, len_g, len_h, len_i, len_j]), + ) = (buffer, u64::try_from(payload_len).map(u64::to_be_bytes)) + { + *a = first_header_byte(fin, op_code); + *b = 127; + *c = len_c; + *d = len_d; + *e = len_e; + *f = len_f; + *g = len_g; + *h = len_h; + *i = len_i; + *j = len_j; + return manage_mask::(rest, b); + } + } + } + + Err(WebSocketError::InvalidFrameHeaderBounds.into()) +} + +pub(crate) fn has_masked_frame(second_header_byte: u8) -> bool { + second_header_byte & 0b1000_0000 != 0 +} + +pub(crate) fn op_code(first_header_byte: u8) -> crate::Result { + OpCode::try_from(first_header_byte & 0b0000_1111) +} +const fn msg_header_placeholder() -> u8 { + if IS_CLIENT { + MAX_HDR_LEN_U8 + } else { + MAX_HDR_LEN_U8 - 4 + } +} + +fn remove_mask(header: &mut [u8]) -> u8 { + let Some(second_header_byte) = header.get_mut(1) else { + return 0; + }; + if !has_masked_frame(*second_header_byte) { + return 0; + } + *second_header_byte &= 0b0111_1111; + let prev_header_len = header.len(); + let until_mask = header + .get_mut(..prev_header_len.wrapping_sub(4)) + .unwrap_or_default(); + let mut buffer = [0u8; MAX_HDR_LEN_USIZE - 4]; + let swap_bytes = buffer.get_mut(..until_mask.len()).unwrap_or_default(); + swap_bytes.copy_from_slice(until_mask); + let new_header = header.get_mut(4..prev_header_len).unwrap_or_default(); + new_header.copy_from_slice(swap_bytes); + 4 +} diff --git a/wtx/src/web_socket/close_code.rs b/wtx/src/web_socket/close_code.rs new file mode 100644 index 00000000..67eb8972 --- /dev/null +++ b/wtx/src/web_socket/close_code.rs @@ -0,0 +1,101 @@ +/// Status code used to indicate why an endpoint is closing the WebSocket connection. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum CloseCode { + /// Normal closure. + Normal, + /// An endpoint is not longer active. + Away, + /// Closing connection due to a protocol error. + Protocol, + /// An endpoint does not support a certain type of data. + Unsupported, + /// Closing frame without a status code. + Status, + /// Connection dropped without an error. + Abnormal, + /// Received data that differs from the frame type. + Invalid, + /// Generic error. + Policy, + /// Received a very large payload. + Size, + /// Client didn't receive extension from the server. + Extension, + /// An unexpected condition occurred. + Error, + /// Server is restarting. + Restart, + /// Server is busy and the client should reconnect. + Again, + #[doc(hidden)] + Tls, + #[doc(hidden)] + Reserved(u16), + #[doc(hidden)] + Iana(u16), + #[doc(hidden)] + Library(u16), + #[doc(hidden)] + Bad(u16), +} + +impl CloseCode { + /// Checks if this instances is allowed. + pub fn is_allowed(self) -> bool { + !matches!( + self, + Self::Bad(_) | Self::Reserved(_) | Self::Status | Self::Abnormal | Self::Tls + ) + } +} + +impl From for CloseCode { + fn from(code: u16) -> CloseCode { + match code { + 1000 => Self::Normal, + 1001 => Self::Away, + 1002 => Self::Protocol, + 1003 => Self::Unsupported, + 1005 => Self::Status, + 1006 => Self::Abnormal, + 1007 => Self::Invalid, + 1008 => Self::Policy, + 1009 => Self::Size, + 1010 => Self::Extension, + 1011 => Self::Error, + 1012 => Self::Restart, + 1013 => Self::Again, + 1015 => Self::Tls, + 1016..=2999 => Self::Reserved(code), + 3000..=3999 => Self::Iana(code), + 4000..=4999 => Self::Library(code), + _ => Self::Bad(code), + } + } +} + +impl From for u16 { + #[inline] + fn from(from: CloseCode) -> u16 { + match from { + CloseCode::Normal => 1000, + CloseCode::Away => 1001, + CloseCode::Protocol => 1002, + CloseCode::Unsupported => 1003, + CloseCode::Status => 1005, + CloseCode::Abnormal => 1006, + CloseCode::Invalid => 1007, + CloseCode::Policy => 1008, + CloseCode::Size => 1009, + CloseCode::Extension => 1010, + CloseCode::Error => 1011, + CloseCode::Restart => 1012, + CloseCode::Again => 1013, + CloseCode::Tls => 1015, + CloseCode::Bad(code) + | CloseCode::Iana(code) + | CloseCode::Library(code) + | CloseCode::Reserved(code) => code, + } + } +} diff --git a/wtx/src/web_socket/frame.rs b/wtx/src/web_socket/frame.rs new file mode 100644 index 00000000..c7e9f00b --- /dev/null +++ b/wtx/src/web_socket/frame.rs @@ -0,0 +1,194 @@ +use crate::{ + misc::{from_utf8_opt, Expand, SingleTypeStorage}, + web_socket::{ + copy_header_params_to_buffer, + frame_buffer::{ + FrameBufferControlArray, FrameBufferControlArrayMut, FrameBufferMut, FrameBufferVecMut, + }, + op_code, FrameBuffer, FrameBufferVec, OpCode, WebSocketError, + MAX_CONTROL_FRAME_PAYLOAD_LEN, MAX_HDR_LEN_USIZE, MIN_HEADER_LEN_USIZE, + }, +}; +use core::{ + borrow::{Borrow, BorrowMut}, + str, +}; + +/// Composed by a [FrameBufferControlArray]. +pub type FrameControlArray = Frame; +/// Composed by a [FrameBufferControlArrayMut]. +pub type FrameControlArrayMut<'bytes, const IS_CLIENT: bool> = + Frame, IS_CLIENT>; +/// Composed by a [FrameBufferMut]. +pub type FrameMut<'bytes, const IS_CLIENT: bool> = Frame, IS_CLIENT>; +/// Composed by a [FrameBufferVec]. +pub type FrameVec = Frame; +/// Composed by an mutable [FrameBufferVecMut] reference. +pub type FrameVecMut<'bytes, const IS_CLIENT: bool> = Frame, IS_CLIENT>; + +/// Composed by an mutable [FrameBufferControlArray] reference. +pub type FrameMutControlArray<'fb, const IS_CLIENT: bool> = + Frame<&'fb mut FrameBufferControlArray, IS_CLIENT>; +/// Composed by an mutable [FrameBufferControlArrayMut] reference. +pub type FrameMutControlArrayMut<'fb, const IS_CLIENT: bool> = + Frame<&'fb mut FrameBufferControlArray, IS_CLIENT>; +/// Composed by an mutable [FrameBufferMut] reference. +pub type FrameMutMut<'bytes, 'fb, const IS_CLIENT: bool> = + Frame<&'fb mut FrameBufferMut<'bytes>, IS_CLIENT>; +/// Composed by an mutable [FrameBufferVec] reference. +pub type FrameMutVec<'fb, const IS_CLIENT: bool> = Frame<&'fb mut FrameBufferVec, IS_CLIENT>; +/// Composed by an mutable [FrameBufferVecMut] reference. +pub type FrameMutVecMut<'bytes, 'fb, const IS_CLIENT: bool> = + Frame<&'fb mut FrameBufferVecMut<'bytes>, IS_CLIENT>; + +/// Represents a WebSocket frame +#[derive(Debug)] +pub struct Frame { + fin: bool, + op_code: OpCode, + fb: FB, +} + +impl Frame { + /// Contains the raw bytes that compose this frame. + #[inline] + pub fn fb(&self) -> &FB { + &self.fb + } + + pub(crate) fn fb_mut(&mut self) -> &mut FB { + &mut self.fb + } + + /// Indicates if this is the final frame in a message. + #[inline] + pub fn fin(&self) -> bool { + self.fin + } + + /// See [OpCode]. + #[inline] + pub fn op_code(&self) -> OpCode { + self.op_code + } +} + +impl Frame +where + B: AsRef<[u8]>, + FB: Borrow> + SingleTypeStorage, +{ + /// Creates a new instance based on the contained bytes of `fb`. + #[inline] + pub fn from_fb(fb: FB) -> crate::Result { + let header = fb.borrow().header(); + let len = header.len(); + let has_valid_header = (MIN_HEADER_LEN_USIZE..=MAX_HDR_LEN_USIZE).contains(&len); + let (true, Some(first_header_byte)) = (has_valid_header, header.first().copied()) else { + return Err(WebSocketError::InvalidFrameHeaderBounds.into()); + }; + Ok(Self { + fb, + fin: first_header_byte & 0b1000_0000 != 0, + op_code: op_code(first_header_byte)?, + }) + } + + /// Checks if the frame payload is valid UTF-8, regardless of its type. + #[inline] + pub fn is_utf8(&self) -> bool { + self.op_code.is_text() || from_utf8_opt(self.fb.borrow().payload()).is_some() + } + + /// If the frame is of type [OpCode::Text], returns its payload interpreted as a string. + #[inline] + pub fn text_payload<'this>(&'this self) -> Option<&'this str> + where + B: 'this, + { + self.op_code.is_text().then(|| { + #[allow(unsafe_code)] + // SAFETY: All text frames have valid UTF-8 contents when read. + unsafe { + str::from_utf8_unchecked(self.fb.borrow().payload()) + } + }) + } +} + +impl Frame +where + B: AsMut<[u8]> + AsRef<[u8]> + Expand, + FB: BorrowMut> + SingleTypeStorage, +{ + /// Creates based on the individual parameters that compose a close frame. + /// + /// `reason` is capped based on the maximum allowed size of a control frame minus 2. + #[inline] + pub fn close_from_params(code: u16, fb: FB, reason: &[u8]) -> crate::Result { + let reason_len = reason.len().min(MAX_CONTROL_FRAME_PAYLOAD_LEN - 2); + let payload_len = reason_len.wrapping_add(2); + Self::build_frame(fb, true, OpCode::Close, payload_len, |local_fb| { + let payload = local_fb.borrow_mut().payload_mut(); + payload + .get_mut(..2) + .unwrap_or_default() + .copy_from_slice(&code.to_be_bytes()); + payload + .get_mut(2..) + .unwrap_or_default() + .copy_from_slice(reason.get(..reason_len).unwrap_or_default()); + Ok(()) + }) + } + + /// Creates a new instance that is considered final. + #[inline] + pub fn new_fin(fb: FB, op_code: OpCode, payload: &[u8]) -> crate::Result { + Self::new(fb, true, op_code, payload) + } + + /// Creates a new instance that is meant to be a continuation of previous frames. + #[inline] + pub fn new_unfin(fb: FB, op_code: OpCode, payload: &[u8]) -> crate::Result { + Self::new(fb, false, op_code, payload) + } + + fn build_frame( + mut fb: FB, + fin: bool, + op_code: OpCode, + payload_len: usize, + cb: impl FnOnce(&mut FB) -> crate::Result<()>, + ) -> crate::Result { + fb.borrow_mut().clear(); + fb.borrow_mut() + .buffer_mut() + .expand(MAX_HDR_LEN_USIZE.saturating_add(payload_len)); + let n = copy_header_params_to_buffer::( + fb.borrow_mut().buffer_mut().as_mut(), + fin, + op_code, + payload_len, + )?; + fb.borrow_mut().set_header_indcs(0, n)?; + fb.borrow_mut().set_payload_len(payload_len)?; + cb(&mut fb)?; + Ok(Self { fin, op_code, fb }) + } + + fn new(fb: FB, fin: bool, op_code: OpCode, payload: &[u8]) -> crate::Result { + let payload_len = if op_code.is_control() { + payload.len().min(MAX_CONTROL_FRAME_PAYLOAD_LEN) + } else { + payload.len() + }; + Self::build_frame(fb, fin, op_code, payload_len, |local_fb| { + local_fb + .borrow_mut() + .payload_mut() + .copy_from_slice(payload.get(..payload_len).unwrap_or_default()); + Ok(()) + }) + } +} diff --git a/wtx/src/web_socket/frame_buffer.rs b/wtx/src/web_socket/frame_buffer.rs new file mode 100644 index 00000000..d38f64fe --- /dev/null +++ b/wtx/src/web_socket/frame_buffer.rs @@ -0,0 +1,269 @@ +use crate::{ + misc::SingleTypeStorage, + web_socket::{ + WebSocketError, DFLT_FRAME_BUFFER_VEC_LEN, MAX_CONTROL_FRAME_LEN, MAX_HDR_LEN_U8, + }, +}; +use alloc::{vec, vec::Vec}; +use core::array; + +/// Composed by an array with the maximum allowed size of a frame control. +pub type FrameBufferControlArray = FrameBuffer<[u8; MAX_CONTROL_FRAME_LEN]>; +/// Composed by an mutable array reference with the maximum allowed size of a frame control. +pub type FrameBufferControlArrayMut<'bytes> = FrameBuffer<&'bytes mut [u8; MAX_CONTROL_FRAME_LEN]>; +/// Composed by a sequence of mutable bytes. +pub type FrameBufferMut<'bytes> = FrameBuffer<&'bytes mut [u8]>; +/// Composed by an owned vector. +pub type FrameBufferVec = FrameBuffer>; +/// Composed by a mutable vector reference. +pub type FrameBufferVecMut<'bytes> = FrameBuffer<&'bytes mut Vec>; + +/// Concentrates all data necessary to read or write to a stream. +// +// ``` +// [ prefix | header | payload | suffix ] +// ``` +#[derive(Debug)] +#[repr(C)] +pub struct FrameBuffer { + header_begin_idx: u8, + header_end_idx: u8, + payload_end_idx: usize, + // Tail field to hopefully help transforms + buffer: B, +} + +impl FrameBuffer { + /// The underlying byte collection. + #[inline] + pub fn buffer(&self) -> &B { + &self.buffer + } + + /// The indices that represent all frame parts contained in the underlying byte collection. + /// + /// ```rust + /// let fb = wtx::web_socket::FrameBufferVec::default(); + /// let (header_begin_idx, header_end_idx, payload_end_idx) = fb.indcs(); + /// assert_eq!( + /// fb.buffer().get(header_begin_idx.into()..header_end_idx.into()).unwrap_or_default(), + /// fb.header() + /// ); + /// assert_eq!( + /// fb.buffer().get(header_end_idx.into()..payload_end_idx).unwrap_or_default(), + /// fb.payload() + /// ); + /// ``` + #[inline] + pub fn indcs(&self) -> (u8, u8, usize) { + ( + self.header_begin_idx, + self.header_end_idx, + self.payload_end_idx, + ) + } + + pub(crate) fn buffer_mut(&mut self) -> &mut B { + &mut self.buffer + } + + pub(crate) fn clear(&mut self) { + self.header_begin_idx = 0; + self.header_end_idx = 0; + self.payload_end_idx = 0; + } + + fn header_end_idx_from_parts(header_begin_idx: u8, header_len: u8) -> u8 { + header_begin_idx.saturating_add(header_len) + } + + fn payload_end_idx_from_parts(header_end: u8, payload_len: usize) -> usize { + usize::from(header_end).wrapping_add(payload_len) + } +} + +impl FrameBuffer +where + B: AsRef<[u8]>, +{ + /// Creates a new instance from the given `buffer`. + #[inline] + pub fn new(buffer: B) -> Self { + Self { + header_begin_idx: 0, + header_end_idx: 0, + payload_end_idx: 0, + buffer, + } + } + + /// Sequence of bytes that composes the frame header. + #[inline] + pub fn header(&self) -> &[u8] { + self.buffer + .as_ref() + .get(self.header_begin_idx.into()..self.header_end_idx.into()) + .unwrap_or_default() + } + + /// Sequence of bytes that composes the frame payload. + #[inline] + pub fn payload(&self) -> &[u8] { + self.buffer + .as_ref() + .get(self.header_end_idx.into()..self.payload_end_idx) + .unwrap_or_default() + } + + pub(crate) fn frame(&self) -> &[u8] { + self.buffer + .as_ref() + .get(self.header_begin_idx.into()..self.payload_end_idx) + .unwrap_or_default() + } + + pub(crate) fn set_header_indcs(&mut self, begin_idx: u8, len: u8) -> crate::Result<()> { + let header_end_idx = Self::header_end_idx_from_parts(begin_idx, len); + if len > MAX_HDR_LEN_U8 || usize::from(header_end_idx) > self.buffer.as_ref().len() { + return Err(WebSocketError::InvalidFrameHeaderBounds.into()); + } + self.header_begin_idx = begin_idx; + self.header_end_idx = header_end_idx; + self.payload_end_idx = usize::from(header_end_idx).max(self.payload_end_idx); + Ok(()) + } + + pub(crate) fn set_payload_len(&mut self, payload_len: usize) -> crate::Result<()> { + let payload_end_idx = Self::payload_end_idx_from_parts(self.header_end_idx, payload_len); + if payload_end_idx > self.buffer.as_ref().len() { + return Err(WebSocketError::InvalidPayloadBounds.into()); + } + self.payload_end_idx = payload_end_idx; + Ok(()) + } +} + +impl FrameBuffer +where + B: AsMut<[u8]>, +{ + pub(crate) fn header_mut(&mut self) -> &mut [u8] { + self.buffer + .as_mut() + .get_mut(self.header_begin_idx.into()..self.header_end_idx.into()) + .unwrap_or_default() + } + + pub(crate) fn payload_mut(&mut self) -> &mut [u8] { + self.buffer + .as_mut() + .get_mut(self.header_end_idx.into()..self.payload_end_idx) + .unwrap_or_default() + } +} + +impl FrameBuffer +where + B: AsMut>, +{ + pub(crate) fn set_params_through_expansion( + &mut self, + header_begin_idx: u8, + header_len: u8, + mut payload_end_idx: usize, + ) { + let header_end_idx = Self::header_end_idx_from_parts(header_begin_idx, header_len); + payload_end_idx = payload_end_idx.max(header_len.into()); + if payload_end_idx > self.buffer.as_mut().len() { + self.buffer.as_mut().resize(payload_end_idx, 0); + } + self.header_begin_idx = header_begin_idx; + self.header_end_idx = header_end_idx; + self.payload_end_idx = payload_end_idx; + } +} + +impl FrameBufferVec { + /// Creates a new instance with pre-allocated bytes. + #[inline] + pub fn with_capacity(n: usize) -> Self { + Self { + header_begin_idx: 0, + header_end_idx: 0, + payload_end_idx: 0, + buffer: vec![0; n], + } + } +} + +impl SingleTypeStorage for FrameBuffer { + type Item = B; +} + +impl Default for FrameBufferControlArray { + #[inline] + fn default() -> Self { + Self { + header_begin_idx: 0, + header_end_idx: 0, + payload_end_idx: 0, + buffer: array::from_fn(|_| 0), + } + } +} + +impl Default for FrameBufferVec { + #[inline] + fn default() -> Self { + Self { + header_begin_idx: 0, + header_end_idx: 0, + payload_end_idx: 0, + buffer: vec![0; DFLT_FRAME_BUFFER_VEC_LEN], + } + } +} + +impl<'fb, B> From<&'fb mut FrameBuffer> for FrameBufferMut<'fb> +where + B: AsMut<[u8]>, +{ + #[inline] + fn from(from: &'fb mut FrameBuffer) -> Self { + Self { + header_begin_idx: from.header_begin_idx, + header_end_idx: from.header_end_idx, + payload_end_idx: from.payload_end_idx, + buffer: from.buffer.as_mut(), + } + } +} + +impl<'bytes, 'fb> From<&'fb mut FrameBufferVec> for FrameBufferVecMut<'bytes> +where + 'fb: 'bytes, +{ + #[inline] + fn from(from: &'fb mut FrameBufferVec) -> Self { + Self { + header_begin_idx: from.header_begin_idx, + header_end_idx: from.header_end_idx, + payload_end_idx: from.payload_end_idx, + buffer: &mut from.buffer, + } + } +} + +impl From> for FrameBufferVec { + #[inline] + fn from(from: Vec) -> Self { + Self::new(from) + } +} + +impl<'bytes> From<&'bytes mut Vec> for FrameBufferVecMut<'bytes> { + #[inline] + fn from(from: &'bytes mut Vec) -> Self { + Self::new(from) + } +} diff --git a/wtx/src/web_socket/handshake.rs b/wtx/src/web_socket/handshake.rs new file mode 100644 index 00000000..beb4168b --- /dev/null +++ b/wtx/src/web_socket/handshake.rs @@ -0,0 +1,70 @@ +//! Handshake + +#[cfg(feature = "web-socket-hyper")] +pub(super) mod hyper; +mod misc; +pub(super) mod raw; +#[cfg(test)] +mod tests; + +#[cfg(feature = "web-socket-hyper")] +pub use self::hyper::{UpgradeFutHyper, WebSocketHandshakeHyper, WebSocketUpgradeHyper}; +use crate::web_socket::{Stream, WebSocketClient, WebSocketServer}; +#[cfg(feature = "async-trait")] +use alloc::boxed::Box; +use core::future::Future; +#[cfg(feature = "web-socket-handshake")] +pub use raw::{WebSocketAcceptRaw, WebSocketHandshakeRaw}; + +/// Manages incoming data to establish WebSocket connections. +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +pub trait WebSocketAccept { + /// Specific implementation response. + type Response; + /// Specific implementation stream. + type Stream: Stream; + + /// Try to upgrade a received request to a WebSocket connection. + async fn accept(self) -> crate::Result<(Self::Response, WebSocketServer)>; +} + +/// Initial negotiation sent by a client to start exchanging frames. +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +pub trait WebSocketHandshake { + /// Specific implementation response. + type Response; + /// Specific implementation stream. + type Stream: Stream; + + /// Performs the client handshake. + async fn handshake(self) -> crate::Result<(Self::Response, WebSocketClient)>; +} + +/// Manages the upgrade of already established requests into WebSocket connections. +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +pub trait WebSocketUpgrade { + /// Specific implementation response. + type Response; + /// Specific implementation stream. + type Stream: Stream; + /// Specific implementation future that resolves to [WebSocketServer]. + type Upgrade: Future>; + + /// Try to upgrade a received request to a WebSocket connection. + fn upgrade(self) -> crate::Result<(Self::Response, Self::Upgrade)>; +} + +/// Necessary to decode incoming bytes of responses or requests. +#[derive(Debug)] +pub struct HeadersBuffer<'buffer, const N: usize> { + pub(crate) headers: [httparse::Header<'buffer>; N], +} + +impl Default for HeadersBuffer<'_, N> { + #[inline] + fn default() -> Self { + Self { + headers: core::array::from_fn(|_| httparse::EMPTY_HEADER), + } + } +} diff --git a/wtx/src/web_socket/handshake/hyper.rs b/wtx/src/web_socket/handshake/hyper.rs new file mode 100644 index 00000000..766c1960 --- /dev/null +++ b/wtx/src/web_socket/handshake/hyper.rs @@ -0,0 +1,191 @@ +use crate::{ + misc::AsyncBounds, + web_socket::{ + handshake::{ + misc::{derived_key, gen_key, trim}, + WebSocketHandshake, WebSocketUpgrade, + }, + WebSocketClient, WebSocketError, + }, + Error::MissingHeader, + ExpectedHeader, ReadBuffer, +}; +#[cfg(feature = "async-trait")] +use alloc::boxed::Box; +use core::{ + borrow::BorrowMut, + future::Future, + pin::{pin, Pin}, + task::{ready, Context, Poll}, +}; +use hyper::{ + client::conn::{self, Connection}, + header::{CONNECTION, HOST, UPGRADE}, + http::{HeaderMap, HeaderValue}, + rt::Executor, + upgrade::{self, OnUpgrade, Upgraded}, + Body, Request, Response, StatusCode, +}; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// A future that resolves to a WebSocket stream when the associated HTTP upgrade completes. +#[derive(Debug)] +pub struct UpgradeFutHyper { + inner: OnUpgrade, +} + +impl Future for UpgradeFutHyper { + type Output = crate::Result; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let stream = ready!(pin!(&mut self.inner).poll(cx))?; + Poll::Ready(Ok(stream)) + } +} + +/// Marker used to implement [WebSocketHandshake]. +#[derive(Debug)] +pub struct WebSocketHandshakeHyper<'executor, E, RB, S> { + /// Executor + pub executor: &'executor E, + /// Read buffer + pub rb: RB, + /// Request + pub req: Request, + /// Stream + pub stream: S, +} + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl<'executor, E, RB, S> WebSocketHandshake for WebSocketHandshakeHyper<'executor, E, RB, S> +where + E: AsyncBounds + Executor> + 'executor, + RB: AsyncBounds + BorrowMut, + S: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + type Response = Response; + type Stream = Upgraded; + + #[inline] + async fn handshake( + mut self, + ) -> crate::Result<(Self::Response, WebSocketClient)> { + let fun = || { + let authority = self.req.uri().authority().map(|el| el.as_str())?; + let mut iter = authority.split('@'); + let before_at = iter.next()?; + Some(iter.next().unwrap_or(before_at)) + }; + let host = fun().ok_or(crate::Error::MissingHost)?.parse()?; + drop( + self.req + .headers_mut() + .insert(CONNECTION, HeaderValue::from_static("upgrade")), + ); + drop(self.req.headers_mut().insert(HOST, host)); + drop( + self.req + .headers_mut() + .insert("Sec-WebSocket-Key", gen_key(&mut <_>::default()).parse()?), + ); + drop( + self.req + .headers_mut() + .insert("Sec-WebSocket-Version", HeaderValue::from_static("13")), + ); + drop( + self.req + .headers_mut() + .insert(UPGRADE, HeaderValue::from_static("websocket")), + ); + let (mut sender, conn) = conn::handshake(self.stream).await?; + self.executor.execute(conn); + let mut res = sender.send_request(self.req).await?; + verify_res(&res)?; + match upgrade::on(&mut res).await { + Err(err) => Err(err.into()), + Ok(elem) => Ok((res, WebSocketClient::new(self.rb, elem))), + } + } +} + +/// Structured used to implement [WebSocketUpgrade]. +#[derive(Debug)] +pub struct WebSocketUpgradeHyper { + /// Request + pub req: Request, +} + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl WebSocketUpgrade for WebSocketUpgradeHyper +where + T: AsyncBounds, +{ + type Response = Response; + type Stream = Upgraded; + type Upgrade = UpgradeFutHyper; + + #[inline] + fn upgrade(self) -> crate::Result<(Self::Response, Self::Upgrade)> { + verify_headers(self.req.headers())?; + let sws_opt = self.req.headers().get("Sec-WebSocket-Key"); + let swk = sws_opt.ok_or(MissingHeader { + expected: ExpectedHeader::SecWebSocketKey, + })?; + if self + .req + .headers() + .get("Sec-WebSocket-Version") + .map(HeaderValue::as_bytes) + != Some(b"13") + { + return Err(MissingHeader { + expected: ExpectedHeader::SecWebSocketVersion_13, + }); + } + let res = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(CONNECTION, "upgrade") + .header(UPGRADE, "websocket") + .header( + "Sec-WebSocket-Accept", + derived_key(&mut <_>::default(), swk.as_bytes()), + ) + .body(Body::from("switching to websocket protocol"))?; + let stream = UpgradeFutHyper { + inner: upgrade::on(self.req), + }; + Ok((res, stream)) + } +} + +fn verify_headers(hm: &HeaderMap) -> crate::Result<()> { + if !hm + .get("Upgrade") + .map(|h| h.as_bytes()) + .map_or(false, |h| trim(h).eq_ignore_ascii_case(b"websocket")) + { + return Err(MissingHeader { + expected: ExpectedHeader::Upgrade_WebSocket, + }); + } + if !hm + .get("Connection") + .map(|h| h.as_bytes()) + .map_or(false, |h| trim(h).eq_ignore_ascii_case(b"upgrade")) + { + return Err(MissingHeader { + expected: ExpectedHeader::Connection_Upgrade, + }); + } + Ok(()) +} + +fn verify_res(res: &Response) -> crate::Result<()> { + if res.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(WebSocketError::MissingSwitchingProtocols.into()); + } + verify_headers(res.headers())?; + Ok(()) +} diff --git a/wtx/src/web_socket/handshake/misc.rs b/wtx/src/web_socket/handshake/misc.rs new file mode 100644 index 00000000..31e53033 --- /dev/null +++ b/wtx/src/web_socket/handshake/misc.rs @@ -0,0 +1,62 @@ +use crate::misc::{from_utf8_opt, Rng}; +use base64::{engine::general_purpose::STANDARD, Engine}; +use sha1::{Digest, Sha1}; + +pub(crate) fn derived_key<'buffer>(buffer: &'buffer mut [u8; 30], key: &[u8]) -> &'buffer str { + let mut sha1 = Sha1::new(); + sha1.update(key); + sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + base64_from_array(&sha1.finalize().into(), buffer) +} + +pub(crate) fn gen_key(buffer: &mut [u8; 26]) -> &str { + base64_from_array(&Rng::default()._random_u8_16(), buffer) +} + +pub(crate) fn trim(bytes: &[u8]) -> &[u8] { + trim_end(trim_begin(bytes)) +} + +#[allow( + // False positive + clippy::arithmetic_side_effects, + // Buffer has enough capacity and `base64` already returns a valid string + clippy::unwrap_used +)] +fn base64_from_array<'output, const I: usize, const O: usize>( + input: &[u8; I], + output: &'output mut [u8; O], +) -> &'output str { + fn div_ceil(x: usize, y: usize) -> usize { + let fun = || { + let num = x.checked_add(y)?.checked_sub(1)?; + num.checked_div(y) + }; + fun().unwrap_or_default() + } + assert!(O >= div_ceil(I, 3).wrapping_mul(4)); + let len = STANDARD.encode_slice(input, output).unwrap(); + from_utf8_opt(output.get(..len).unwrap_or_default()).unwrap() +} + +fn trim_begin(mut bytes: &[u8]) -> &[u8] { + while let [first, rest @ ..] = bytes { + if first.is_ascii_whitespace() { + bytes = rest; + } else { + break; + } + } + bytes +} + +fn trim_end(mut bytes: &[u8]) -> &[u8] { + while let [rest @ .., last] = bytes { + if last.is_ascii_whitespace() { + bytes = rest; + } else { + break; + } + } + bytes +} diff --git a/wtx/src/web_socket/handshake/raw.rs b/wtx/src/web_socket/handshake/raw.rs new file mode 100644 index 00000000..0ca3da13 --- /dev/null +++ b/wtx/src/web_socket/handshake/raw.rs @@ -0,0 +1,259 @@ +use crate::{ + misc::AsyncBounds, + web_socket::{ + handshake::{ + misc::{derived_key, gen_key, trim}, + HeadersBuffer, WebSocketAccept, WebSocketHandshake, + }, + FrameBufferVec, WebSocketClient, WebSocketError, WebSocketServer, + }, + ExpectedHeader, ReadBuffer, Stream, UriParts, +}; +#[cfg(feature = "async-trait")] +use alloc::boxed::Box; +use core::borrow::BorrowMut; +use httparse::{Header, Status}; + +const MAX_READ_HEADER_LEN: usize = 64; +const MAX_READ_LEN: usize = 2 * 1024; + +/// Marker used to implement [WebSocketAccept]. +#[derive(Debug)] +pub struct WebSocketAcceptRaw<'any, RB, S> { + /// Frame buffer + pub fb: &'any mut FrameBufferVec, + /// Headers buffer + pub headers_buffer: &'any mut HeadersBuffer<'any, 3>, + /// Key buffer + pub key_buffer: &'any mut [u8; 30], + /// Read buffer + pub rb: RB, + /// Stream + pub stream: S, +} + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl<'any, RB, S> WebSocketAccept for WebSocketAcceptRaw<'any, RB, S> +where + RB: AsyncBounds + BorrowMut, + S: AsyncBounds + Stream, +{ + type Response = crate::Response<'any, 'any>; + type Stream = S; + + #[inline] + async fn accept( + mut self, + ) -> crate::Result<(Self::Response, WebSocketServer)> { + self.fb.set_params_through_expansion(0, 0, MAX_READ_LEN); + let mut read = 0; + let (key, version) = loop { + let read_buffer = self.fb.payload_mut().get_mut(read..).unwrap_or_default(); + let local_read = self.stream.read(read_buffer).await?; + read = read.wrapping_add(local_read); + if read > MAX_READ_LEN { + return Err(crate::Error::VeryLargeHttp); + } + if local_read == 0 { + return Err(crate::Error::UnexpectedEOF); + } + let working_buffer = self.fb.payload().get(..read).unwrap_or_default(); + let mut req_buffer = [httparse::EMPTY_HEADER; MAX_READ_HEADER_LEN]; + let mut req = httparse::Request::new(&mut req_buffer); + match req.parse(working_buffer)? { + Status::Complete(_) => { + if !req + .method + .map_or(false, |el| trim(el.as_bytes()).eq_ignore_ascii_case(b"get")) + { + return Err(crate::Error::UnexpectedHttpMethod); + } + verify_common_header(req.headers)?; + if !has_header_key_and_value(req.headers, "sec-websocket-version", b"13") { + return Err(crate::Error::MissingHeader { + expected: ExpectedHeader::SecWebSocketVersion_13, + }); + }; + let Some(key) = req.headers.iter().find_map(|el| { + (el.name.eq_ignore_ascii_case("sec-websocket-key")).then_some(el.value) + }) else { + return Err(crate::Error::MissingHeader { + expected: ExpectedHeader::SecWebSocketKey, + }); + }; + break (key, req.version); + } + Status::Partial => {} + } + }; + self.headers_buffer.headers[0] = Header { + name: "Connection", + value: b"Upgrade", + }; + self.headers_buffer.headers[1] = Header { + name: "Sec-WebSocket-Accept", + value: derived_key(self.key_buffer, key).as_bytes(), + }; + self.headers_buffer.headers[2] = Header { + name: "Upgrade", + value: b"websocket", + }; + let mut httparse_res = httparse::Response::new(&mut self.headers_buffer.headers); + httparse_res.code = Some(101); + httparse_res.version = version; + let res = crate::Response::new(&[], httparse_res); + let res_bytes = build_101_res(self.fb, res.headers()); + self.stream.write_all(res_bytes).await?; + Ok((res, WebSocketServer::new(self.rb, self.stream))) + } +} + +/// Marker used to implement [WebSocketHandshake]. +#[derive(Debug)] +pub struct WebSocketHandshakeRaw<'any, RB, S> { + /// Frame buffer + pub fb: &'any mut FrameBufferVec, + /// Headers buffer + pub headers_buffer: &'any mut HeadersBuffer<'any, MAX_READ_HEADER_LEN>, + /// Read buffer + pub rb: RB, + /// Stream + pub stream: S, + /// Uri + pub uri: &'any str, +} + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl<'any, RB, S> WebSocketHandshake for WebSocketHandshakeRaw<'any, RB, S> +where + RB: AsyncBounds + BorrowMut, + S: AsyncBounds + Stream, +{ + type Response = crate::Response<'any, 'any>; + type Stream = S; + + #[inline] + async fn handshake( + mut self, + ) -> crate::Result<(Self::Response, WebSocketClient)> { + self.fb.set_params_through_expansion(0, 0, MAX_READ_LEN); + let key_buffer = &mut <_>::default(); + let (key, req) = build_upgrade_req(self.fb, key_buffer, self.uri); + self.stream.write_all(req).await?; + let mut read = 0; + let res_len = loop { + let read_buffer = self.fb.payload_mut().get_mut(read..).unwrap_or_default(); + let local_read = self.stream.read(read_buffer).await?; + read = read.wrapping_add(local_read); + if read > MAX_READ_LEN { + return Err(crate::Error::VeryLargeHttp); + } + if local_read == 0 { + return Err(crate::Error::UnexpectedEOF); + } + let mut headers = [httparse::EMPTY_HEADER; MAX_READ_HEADER_LEN]; + let mut httparse_res = httparse::Response::new(&mut headers); + match httparse_res.parse(self.fb.payload().get(..read).unwrap_or_default())? { + Status::Complete(el) => break el, + Status::Partial => {} + } + }; + let mut httparse_res = httparse::Response::new(&mut self.headers_buffer.headers); + let _rslt = httparse_res.parse(self.fb.payload().get(..res_len).unwrap_or_default())?; + let res = crate::Response::new(&[], httparse_res); + if res.code() != Some(101) { + return Err(WebSocketError::MissingSwitchingProtocols.into()); + } + verify_common_header(res.headers())?; + if !has_header_key_and_value( + res.headers(), + "sec-websocket-accept", + derived_key(&mut <_>::default(), key.as_bytes()).as_bytes(), + ) { + return Err(crate::Error::MissingHeader { + expected: crate::ExpectedHeader::SecWebSocketKey, + }); + } + let idx = read.wrapping_sub(res_len); + self.rb + .borrow_mut() + .set_indices_through_expansion(0, 0, idx); + self.rb + .borrow_mut() + .after_current_mut() + .get_mut(..idx) + .unwrap_or_default() + .copy_from_slice(self.fb.payload().get(res_len..read).unwrap_or_default()); + Ok((res, WebSocketClient::new(self.rb, self.stream))) + } +} + +fn build_upgrade_req<'fb, 'kb>( + fb: &'fb mut FrameBufferVec, + key_buffer: &'kb mut [u8; 26], + uri: &str, +) -> (&'kb str, &'fb [u8]) { + let uri_parts = UriParts::from(uri); + let key = gen_key(key_buffer); + + let idx = fb.buffer().len(); + fb.buffer_mut().extend(b"GET "); + fb.buffer_mut().extend(uri_parts.href.as_bytes()); + fb.buffer_mut().extend(b" HTTP/1.1\r\n"); + + fb.buffer_mut().extend(b"Connection: Upgrade\r\n"); + fb.buffer_mut().extend(b"Host: "); + fb.buffer_mut().extend(uri_parts.host.as_bytes()); + fb.buffer_mut().extend(b"\r\n"); + fb.buffer_mut().extend(b"Sec-WebSocket-Key: "); + fb.buffer_mut().extend(key.as_bytes()); + fb.buffer_mut().extend(b"\r\n"); + fb.buffer_mut().extend(b"Sec-WebSocket-Version: 13\r\n"); + fb.buffer_mut().extend(b"Upgrade: websocket\r\n"); + + fb.buffer_mut().extend(b"\r\n"); + + (key, fb.buffer().get(idx..).unwrap_or_default()) +} + +fn build_101_res<'fb>(fb: &'fb mut FrameBufferVec, headers: &[Header<'_>]) -> &'fb [u8] { + let idx = fb.buffer().len(); + fb.buffer_mut() + .extend(b"HTTP/1.1 101 Switching Protocols\r\n"); + for header in headers { + fb.buffer_mut().extend(header.name.as_bytes()); + fb.buffer_mut().extend(b": "); + fb.buffer_mut().extend(header.value); + fb.buffer_mut().extend(b"\r\n"); + } + fb.buffer_mut().extend(b"\r\n"); + fb.buffer().get(idx..).unwrap_or_default() +} + +fn has_header_key_and_value(buffer: &[Header<'_>], key: &str, value: &[u8]) -> bool { + buffer + .iter() + .find_map(|h| { + let has_key = trim(h.name.as_bytes()).eq_ignore_ascii_case(key.as_bytes()); + let has_value = h + .value + .split(|el| el == &b',') + .any(|el| trim(el).eq_ignore_ascii_case(value)); + (has_key && has_value).then_some(true) + }) + .unwrap_or(false) +} + +fn verify_common_header(buffer: &[Header<'_>]) -> crate::Result<()> { + if !has_header_key_and_value(buffer, "connection", b"upgrade") { + return Err(crate::Error::MissingHeader { + expected: ExpectedHeader::Connection_Upgrade, + }); + } + if !has_header_key_and_value(buffer, "upgrade", b"websocket") { + return Err(crate::Error::MissingHeader { + expected: ExpectedHeader::Upgrade_WebSocket, + }); + } + Ok(()) +} diff --git a/wtx/src/web_socket/handshake/tests.rs b/wtx/src/web_socket/handshake/tests.rs new file mode 100644 index 00000000..32eb38ee --- /dev/null +++ b/wtx/src/web_socket/handshake/tests.rs @@ -0,0 +1,336 @@ +macro_rules! call_tests { + (($ty:ident, $fb:expr, $ws:expr), $($struct:ident),+ $(,)?) => { + $( + println!("***** {} - {}", stringify!($ty), stringify!($struct)); + $struct::$ty($fb, $ws).await; + tokio::time::sleep(Duration::from_millis(200)).await; + )+ + }; +} + +use crate::web_socket::{ + frame::FrameMutVec, + handshake::{WebSocketAccept, WebSocketAcceptRaw, WebSocketHandshake, WebSocketHandshakeRaw}, + FrameBufferVec, OpCode, WebSocketClientOwned, WebSocketServerOwned, +}; +use core::{ + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio::net::{TcpListener, TcpStream}; + +static HAS_SERVER_FINISHED: AtomicBool = AtomicBool::new(false); + +#[tokio::test] +async fn client_and_server_frames() { + let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); + + let _server_jh = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut fb = <_>::default(); + let (_, mut ws) = WebSocketAcceptRaw { + fb: &mut fb, + headers_buffer: &mut <_>::default(), + rb: <_>::default(), + key_buffer: &mut <_>::default(), + stream, + } + .accept() + .await + .unwrap(); + call_tests!( + (server, &mut fb, &mut ws), + FragmentedMessage, + LargeFragmentedMessage, + PingAndText, + SeveralBytes, + TwoPings, + // Last + HelloAndGoodbye, + ); + HAS_SERVER_FINISHED.store(true, Ordering::Relaxed); + }); + + let mut fb = <_>::default(); + let (_res, mut ws) = WebSocketHandshakeRaw { + fb: &mut fb, + headers_buffer: &mut <_>::default(), + rb: <_>::default(), + stream: TcpStream::connect("127.0.0.1:8080").await.unwrap(), + uri: "http://127.0.0.1:8080", + } + .handshake() + .await + .unwrap(); + call_tests!( + (client, &mut fb, &mut ws), + FragmentedMessage, + LargeFragmentedMessage, + PingAndText, + SeveralBytes, + TwoPings, + // Last + HelloAndGoodbye, + ); + + let mut has_server_finished = false; + for _ in 0..15 { + let local_has_server_finished = HAS_SERVER_FINISHED.load(Ordering::Relaxed); + if local_has_server_finished { + has_server_finished = local_has_server_finished; + break; + } + tokio::time::sleep(Duration::from_millis(200)).await; + } + if !has_server_finished { + panic!("Server didn't finish"); + } +} + +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +trait Test { + async fn client(fb: &mut FrameBufferVec, ws: &mut WebSocketClientOwned); + + async fn server(fb: &mut FrameBufferVec, ws: &mut WebSocketServerOwned); +} + +struct FragmentedMessage; +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Test for FragmentedMessage { + async fn client(fb: &mut FrameBufferVec, ws: &mut WebSocketClientOwned) { + ws.write_frame(&mut FrameMutVec::new_unfin(fb, OpCode::Text, b"1").unwrap()) + .await + .unwrap(); + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Continuation, b"23").unwrap()) + .await + .unwrap(); + } + + async fn server(fb: &mut FrameBufferVec, ws: &mut WebSocketServerOwned) { + let text = ws.read_msg(fb).await.unwrap(); + assert_eq!(OpCode::Text, text.op_code()); + assert_eq!(b"123", text.fb().payload()); + } +} + +struct HelloAndGoodbye; +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Test for HelloAndGoodbye { + async fn client(fb: &mut FrameBufferVec, ws: &mut WebSocketClientOwned) { + let hello = ws.read_frame(fb).await.unwrap(); + assert_eq!(OpCode::Text, hello.op_code()); + assert_eq!(b"Hello!", hello.fb().payload()); + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Text, b"Goodbye!").unwrap()) + .await + .unwrap(); + assert_eq!(OpCode::Close, ws.read_frame(fb).await.unwrap().op_code()); + } + + async fn server(fb: &mut FrameBufferVec, ws: &mut WebSocketServerOwned) { + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Text, b"Hello!").unwrap()) + .await + .unwrap(); + assert_eq!( + ws.read_frame(&mut *fb).await.unwrap().fb().payload(), + b"Goodbye!" + ); + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Close, &[]).unwrap()) + .await + .unwrap(); + } +} + +struct LargeFragmentedMessage; +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Test for LargeFragmentedMessage { + async fn client(fb: &mut FrameBufferVec, ws: &mut WebSocketClientOwned) { + async fn write( + frame: &mut FrameMutVec<'_, true>, + ws: &mut WebSocketClientOwned, + ) { + ws.write_frame(frame).await.unwrap(); + } + let bytes = vec![51; 256 * 1024]; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Text, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_fin(fb, OpCode::Continuation, &bytes).unwrap(), + ws, + ) + .await; + } + + async fn server(fb: &mut FrameBufferVec, ws: &mut WebSocketServerOwned) { + let text = ws.read_msg(fb).await.unwrap(); + assert_eq!(OpCode::Text, text.op_code()); + assert_eq!(&vec![51; 10 * 256 * 1024], text.fb().payload()); + } +} + +struct PingAndText; +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Test for PingAndText { + async fn client(fb: &mut FrameBufferVec, ws: &mut WebSocketClientOwned) { + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Ping, b"").unwrap()) + .await + .unwrap(); + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Text, b"ipat").unwrap()) + .await + .unwrap(); + assert_eq!(OpCode::Pong, ws.read_frame(fb).await.unwrap().op_code()); + } + + async fn server(fb: &mut FrameBufferVec, ws: &mut WebSocketServerOwned) { + assert_eq!(b"ipat", ws.read_frame(fb).await.unwrap().fb().payload()); + } +} + +struct SeveralBytes; +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Test for SeveralBytes { + async fn client(fb: &mut FrameBufferVec, ws: &mut WebSocketClientOwned) { + async fn write( + frame: &mut FrameMutVec<'_, true>, + ws: &mut WebSocketClientOwned, + ) { + ws.write_frame(frame).await.unwrap(); + } + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Text, &[206]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[186]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[225]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[189]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[185]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[207]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[131]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[206]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[188]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[206]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_unfin(fb, OpCode::Continuation, &[181]).unwrap(), + ws, + ) + .await; + write( + &mut FrameMutVec::new_fin(fb, OpCode::Continuation, &[]).unwrap(), + ws, + ) + .await; + } + + async fn server(fb: &mut FrameBufferVec, ws: &mut WebSocketServerOwned) { + let text = ws.read_msg(fb).await.unwrap(); + assert_eq!(OpCode::Text, text.op_code()); + assert_eq!("κόσμε".as_bytes(), text.fb().payload()); + } +} + +struct TwoPings; +#[cfg_attr(feature = "async-trait", async_trait::async_trait)] +impl Test for TwoPings { + async fn client(fb: &mut FrameBufferVec, ws: &mut WebSocketClientOwned) { + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Ping, b"0").unwrap()) + .await + .unwrap(); + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Ping, b"1").unwrap()) + .await + .unwrap(); + let _0 = ws.read_frame(fb).await.unwrap(); + assert_eq!(OpCode::Pong, _0.op_code()); + assert_eq!(b"0", _0.fb().payload()); + let _1 = ws.read_frame(fb).await.unwrap(); + assert_eq!(OpCode::Pong, _1.op_code()); + assert_eq!(b"1", _1.fb().payload()); + ws.write_frame(&mut FrameMutVec::new_fin(fb, OpCode::Text, b"").unwrap()) + .await + .unwrap(); + } + + async fn server(fb: &mut FrameBufferVec, ws: &mut WebSocketServerOwned) { + let _0 = ws.read_frame(fb).await.unwrap(); + assert_eq!(OpCode::Text, _0.op_code()); + assert_eq!(b"", _0.fb().payload()); + } +} diff --git a/wtx/src/web_socket/mask.rs b/wtx/src/web_socket/mask.rs new file mode 100644 index 00000000..f6d18743 --- /dev/null +++ b/wtx/src/web_socket/mask.rs @@ -0,0 +1,100 @@ +/// Unmasks a sequence of bytes using the given 4-byte `mask`. +#[inline] +pub fn unmask(bytes: &mut [u8], mask: [u8; 4]) { + let mut mask_u32 = u32::from_ne_bytes(mask); + #[allow(unsafe_code)] + // SAFETY: Changing a sequence of `u8` to `u32` should be fine + let (prefix, words, suffix) = unsafe { bytes.align_to_mut::() }; + unmask_u8_slice(prefix, mask); + let mut shift = u32::try_from(prefix.len() & 3).unwrap_or_default(); + if shift > 0 { + shift = shift.wrapping_mul(8); + if cfg!(target_endian = "big") { + mask_u32 = mask_u32.rotate_left(shift); + } else { + mask_u32 = mask_u32.rotate_right(shift); + } + } + unmask_u32_slice(words, mask_u32); + unmask_u8_slice(suffix, mask_u32.to_ne_bytes()); +} + +#[allow( + // Index will always by in-bounds. + clippy::indexing_slicing +)] +fn unmask_u8_slice(bytes: &mut [u8], mask: [u8; 4]) { + for (idx, elem) in bytes.iter_mut().enumerate() { + *elem ^= mask[idx & 3]; + } +} + +fn unmask_u32_slice(bytes: &mut [u32], mask: u32) { + macro_rules! loop_chunks { + ($bytes:expr, $mask:expr, $($elem:ident),* $(,)?) => {{ + let mut iter; + #[cfg(feature = "async-trait")] + { + iter = $bytes.chunks_exact_mut(0 $( + { let $elem = 1; $elem })*); + while let Some([$($elem,)*]) = iter.next() { + $( *$elem ^= $mask; )* + } + } + #[cfg(not(feature = "async-trait"))] + { + iter = $bytes.array_chunks_mut::<{ 0 $( + { let $elem = 1; $elem })* }>(); + for [$($elem,)*] in iter.by_ref() { + $( *$elem ^= $mask; )* + } + } + iter + }}; + } + loop_chunks!(bytes, mask, _1, _2, _3, _4) + .into_remainder() + .iter_mut() + .for_each(|elem| *elem ^= mask); +} + +#[cfg(test)] +mod tests { + use crate::{misc::Rng, web_socket::mask::unmask}; + use alloc::{vec, vec::Vec}; + + #[test] + fn test_unmask() { + let mut payload = [0u8; 33]; + let mask = [1, 2, 3, 4]; + unmask(&mut payload, mask); + assert_eq!( + &payload, + &[ + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 1, 2, 3, 4, 1 + ] + ); + } + + #[test] + fn length_variation_unmask() { + for len in &[0, 2, 3, 8, 16, 18, 31, 32, 40] { + let mut payload = vec![0u8; *len]; + let mask = [1, 2, 3, 4]; + unmask(&mut payload, mask); + + let expected = (0..*len).map(|i| (i & 3) as u8 + 1).collect::>(); + assert_eq!(payload, expected); + } + } + + #[test] + fn length_variation_unmask_2() { + for len in &[0, 2, 3, 8, 16, 18, 31, 32, 40] { + let mut payload = vec![0u8; *len]; + let mask = Rng::default().random_u8_4(); + unmask(&mut payload, mask); + let expected = (0..*len).map(|i| mask[i & 3]).collect::>(); + assert_eq!(payload, expected); + } + } +} diff --git a/wtx/src/web_socket/op_code.rs b/wtx/src/web_socket/op_code.rs new file mode 100644 index 00000000..0bca667e --- /dev/null +++ b/wtx/src/web_socket/op_code.rs @@ -0,0 +1,68 @@ +macro_rules! create_enum { + ($(#[$meta:meta])* $vis:vis enum $name:ident { + $($(#[$variant_meta:meta])* $variant_ident:ident = $variant_value:expr,)* + }) => { + $(#[$meta])* + $vis enum $name { + $($(#[$variant_meta])* $variant_ident = $variant_value,)* + } + + impl From<$name> for u8 { + #[inline] + fn from(from: $name) -> Self { + match from { + $($name::$variant_ident => $variant_value,)* + } + } + } + + impl TryFrom for $name { + type Error = crate::Error; + + #[inline] + fn try_from(from: u8) -> Result { + match from { + $(x if x == u8::from($name::$variant_ident) => Ok($name::$variant_ident),)* + _ => Err(crate::web_socket::WebSocketError::InvalidOpCodeByte { provided: from }.into()), + } + } + } + } +} + +create_enum! { + /// Defines how to interpret the payload data. + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + #[repr(u8)] + pub enum OpCode { + /// Continuation of a previous frame. + Continuation = 0b0000_0000, + /// UTF-8 text. + Text = 0b0000_0001, + /// Opaque bytes. + Binary = 0b0000_0010, + /// Connection is closed. + Close = 0b0000_1000, + /// Test reachability. + Ping = 0b0000_1001, + /// Response of a ping frame. + Pong = 0b0000_1010, + } +} + +impl OpCode { + #[inline] + pub(crate) fn is_continuation(self) -> bool { + matches!(self, OpCode::Continuation) + } + + #[inline] + pub(crate) fn is_control(self) -> bool { + matches!(self, OpCode::Close | OpCode::Ping | OpCode::Pong) + } + + #[inline] + pub(crate) fn is_text(self) -> bool { + matches!(self, OpCode::Text) + } +} diff --git a/wtx/src/web_socket/web_socket_error.rs b/wtx/src/web_socket/web_socket_error.rs new file mode 100644 index 00000000..8feba292 --- /dev/null +++ b/wtx/src/web_socket/web_socket_error.rs @@ -0,0 +1,35 @@ +/// Errors related to the WebSocket module +#[derive(Debug)] +pub enum WebSocketError { + /// Received close frame has invalid parameters. + InvalidCloseFrame, + /// Header indices are out-of-bounds or the number of bytes are too small. + InvalidFrameHeaderBounds, + /// No op code can be represented with the provided byte. + InvalidOpCodeByte { + /// Provided byte + provided: u8, + }, + /// Payload indices are out-of-bounds or the number of bytes are too small. + InvalidPayloadBounds, + + /// Server received a frame without a mask. + MissingFrameMask, + /// Status code is expected to be + MissingSwitchingProtocols, + + /// Received control frame wasn't supposed to be fragmented. + UnexpectedFragmentedControlFrame, + /// The first frame of a message is a continuation or the following frames are not a + /// continuation. + UnexpectedMessageFrame, + + /// It it not possible to read a frame of a connection that was previously closed. + ConnectionClosed, + /// Reserved bits are not zero. + ReservedBitsAreNotZero, + /// Control frames have a maximum allowed size. + VeryLargeControlFrame, + /// Frame payload exceeds the defined threshold. + VeryLargePayload, +}