diff --git a/.gitignore b/.gitignore index 29dde039..b347bd3e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,6 @@ **/curl **/hpack-test-case **/profile.json -**/proptest-regressions **/rustc-ice-*.txt **/target mdbook-target diff --git a/.scripts/internal-tests-0.sh b/.scripts/internal-tests-0.sh index 9cf6e244..cd04eee4 100755 --- a/.scripts/internal-tests-0.sh +++ b/.scripts/internal-tests-0.sh @@ -5,7 +5,7 @@ $rt rustfmt $rt clippy -Aclippy::little-endian-bytes,-Aclippy::panic-in-result-fn -cargo miri test --features http2,postgres,web-socket -p wtx +cargo miri test --features http2,postgres,web-socket -p wtx two_sta # WTX @@ -36,7 +36,6 @@ $rt test-with-features wtx matchit $rt test-with-features wtx memchr $rt test-with-features wtx pool $rt test-with-features wtx postgres -$rt test-with-features wtx proptest $rt test-with-features wtx quick-protobuf $rt test-with-features wtx rand_chacha $rt test-with-features wtx ring @@ -49,7 +48,6 @@ $rt test-with-features wtx sha1 $rt test-with-features wtx sha2 $rt test-with-features wtx simdutf8 $rt test-with-features wtx std -$rt test-with-features wtx test-strategy $rt test-with-features wtx tokio $rt test-with-features wtx tokio-rustls $rt test-with-features wtx tracing @@ -61,7 +59,6 @@ $rt test-with-features wtx x509-certificate $rt test-with-features wtx _async-tests $rt test-with-features wtx _bench $rt test-with-features wtx _integration-tests -$rt test-with-features wtx _proptest $rt test-with-features wtx _tracing-tree # WTX Macros diff --git a/.scripts/internal-tests-all.sh b/.scripts/internal-tests.sh similarity index 100% rename from .scripts/internal-tests-all.sh rename to .scripts/internal-tests.sh diff --git a/.scripts/podman-start.sh b/.scripts/podman-start.sh index f2b3f944..38081e43 100755 --- a/.scripts/podman-start.sh +++ b/.scripts/podman-start.sh @@ -5,7 +5,7 @@ podman run \ -e POSTGRES_PASSWORD=wtx \ -p 5432:5432 \ -v ./.test-utils/postgres.sh:/docker-entrypoint-initdb.d/setup.sh \ - docker.io/library/postgres:16 + docker.io/library/postgres:17 # Utils diff --git a/.test-utils/docker-compose.yml b/.test-utils/docker-compose.yml index c8edacf0..b92e3c80 100644 --- a/.test-utils/docker-compose.yml +++ b/.test-utils/docker-compose.yml @@ -5,7 +5,7 @@ services: environment: POSTGRES_DB: wtx POSTGRES_PASSWORD: wtx - image: postgres:16 + image: postgres:17 ports: - 5432:5432 volumes: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 397a80fc..01eb5e6c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Contributing -Before submitting a PR, you should probably run `./scripts/internal-tests-all.sh` and/or `./scripts/intergration-tests.sh` to make sure everything is fine. +Before submitting a PR, you should probably run `./scripts/internal-tests.sh` and/or `./scripts/intergration-tests.sh` to make sure everything is fine. Integration tests interact with external programs like `podman` or require an internet connection, therefore, they usually aren't good candidates for offline development. On the other hand, internal tests are composed by unit tests, code formatting, `clippy` lints and fuzzing targets. diff --git a/Cargo.lock b/Cargo.lock index 75d9d5ac..78ae4660 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,19 +19,18 @@ checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "aead" -version = "0.5.2" +version = "0.6.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +checksum = "b5f451b77e2f92932dc411da6ef9f3d33efad68a6f14a7a83e559453458e85ac" dependencies = [ - "crypto-common", - "generic-array", + "crypto-common 0.2.0-rc.1", ] [[package]] name = "aes" -version = "0.8.4" +version = "0.9.0-pre.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +checksum = "e7856582c758ade85d71daf27ec6bcea6c1c73913692b07b8dffea2dc03531c9" dependencies = [ "cfg-if", "cipher", @@ -40,9 +39,9 @@ dependencies = [ [[package]] name = "aes-gcm" -version = "0.10.3" +version = "0.11.0-pre.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +checksum = "0cce27af05d45b901bb28da33ff8b2b2b2044f595b24fc0f36d4882dae91d484" dependencies = [ "aead", "aes", @@ -95,9 +94,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "arbitrary" @@ -169,12 +168,6 @@ dependencies = [ "smallvec", ] -[[package]] -name = "bitflags" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" - [[package]] name = "bitvec" version = "1.0.1" @@ -205,6 +198,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-padding" +version = "0.4.0-rc.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d7992d59cd95a984bde8833d4d025886eec3718777971ad15c58df0b070254a" +dependencies = [ + "hybrid-array", +] + [[package]] name = "borsh" version = "1.5.1" @@ -225,7 +227,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", "syn_derive", ] @@ -307,11 +309,11 @@ dependencies = [ [[package]] name = "cipher" -version = "0.4.4" +version = "0.5.0-pre.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +checksum = "5b1425e6ce000f05a73096556cabcfb6a10a3ffe3bb4d75416ca8f00819c0b6a" dependencies = [ - "crypto-common", + "crypto-common 0.2.0-rc.1", "inout", ] @@ -354,7 +356,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -403,11 +405,22 @@ dependencies = [ "typenum", ] +[[package]] +name = "crypto-common" +version = "0.2.0-rc.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0b8ce8218c97789f16356e7896b3714f26c2ee1079b79c0b7ae7064bb9089fa" +dependencies = [ + "getrandom", + "hybrid-array", + "rand_core", +] + [[package]] name = "ctr" -version = "0.9.2" +version = "0.10.0-pre.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +checksum = "77e1482d284b80d7fddb211666d513dc5e23b0cc3a03ad398ff70543827c789f" dependencies = [ "cipher", ] @@ -430,7 +443,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -440,7 +453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", - "crypto-common", + "crypto-common 0.1.6", "subtle", ] @@ -508,9 +521,9 @@ dependencies = [ [[package]] name = "ghash" -version = "0.5.1" +version = "0.6.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +checksum = "3b92860fda25ab571512af210134cde2c42732cd53253bcee3f21b288b7afbc4" dependencies = [ "opaque-debug", "polyval", @@ -580,6 +593,15 @@ version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +[[package]] +name = "hybrid-array" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45a9a965bb102c1c891fb017c09a05c965186b1265a207640f323ddd009f9deb" +dependencies = [ + "typenum", +] + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -615,11 +637,12 @@ dependencies = [ [[package]] name = "inout" -version = "0.1.3" +version = "0.2.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +checksum = "bbc33218cf9ce7b927426ee4ad3501bcc5d8c26bf5fb4a82849a083715aca427" dependencies = [ - "generic-array", + "block-padding", + "hybrid-array", ] [[package]] @@ -669,12 +692,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "libm" -version = "0.2.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a00419de735aac21d53b0de5ce2c03bd3627277cf471300f27ebc89f7d828047" - [[package]] name = "libz-rs-sys" version = "0.3.1" @@ -764,7 +781,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", - "libm", ] [[package]] @@ -827,9 +843,9 @@ checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "polyval" -version = "0.6.2" +version = "0.7.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +checksum = "b01cbf5c028f9f862c6f7f5a5544307d7858634df190488d432ec470c8fbc063" dependencies = [ "cfg-if", "cpufeatures", @@ -887,22 +903,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "proptest" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c2511913b88df1637da85cc8d96ec8e43a3f8bb8ccb71ee1ac240d6f3df58d" -dependencies = [ - "bitflags", - "lazy_static", - "num-traits", - "rand", - "rand_chacha", - "rand_xorshift", - "regex-syntax 0.8.5", - "unarray", -] - [[package]] name = "ptr_meta" version = "0.1.4" @@ -977,15 +977,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "rand_xorshift" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" -dependencies = [ - "rand_core", -] - [[package]] name = "regex" version = "1.11.1" @@ -1093,7 +1084,6 @@ dependencies = [ "borsh", "bytes", "num-traits", - "proptest", "rand", "rkyv", "serde", @@ -1175,7 +1165,7 @@ checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -1283,29 +1273,6 @@ dependencies = [ "der", ] -[[package]] -name = "structmeta" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e1575d8d40908d70f6fd05537266b90ae71b15dbbe7a8b7dffa2b759306d329" -dependencies = [ - "proc-macro2", - "quote", - "structmeta-derive", - "syn 2.0.85", -] - -[[package]] -name = "structmeta-derive" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.85", -] - [[package]] name = "subtle" version = "2.6.1" @@ -1325,9 +1292,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.85" +version = "2.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" +checksum = "e89275301d38033efb81a6e60e3497e734dfcc62571f2854bf4b16690398824c" dependencies = [ "proc-macro2", "quote", @@ -1343,7 +1310,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -1367,36 +1334,24 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "test-strategy" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bf41af45e3f54cc184831d629d41d5b2bda8297e29c81add7ae4f362ed5e01b" -dependencies = [ - "proc-macro2", - "quote", - "structmeta", - "syn 2.0.85", -] - [[package]] name = "thiserror" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" +checksum = "5d171f59dbaa811dbbb1aee1e73db92ec2b122911a48e1390dfe327a821ddede" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" +checksum = "b08be0f17bd307950653ce45db00cd31200d82b624b36e181337d9c7d92765b5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -1448,7 +1403,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -1515,7 +1470,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -1575,12 +1530,6 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" -[[package]] -name = "unarray" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" - [[package]] name = "unicode-ident" version = "1.0.13" @@ -1589,11 +1538,11 @@ checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "universal-hash" -version = "0.5.1" +version = "0.6.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +checksum = "3517d72c5ca6d60f9f2e85d2c772e2652830062a685105a528d19dd823cf87d5" dependencies = [ - "crypto-common", + "crypto-common 0.2.0-rc.1", "subtle", ] @@ -1649,7 +1598,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", "wasm-bindgen-shared", ] @@ -1671,7 +1620,7 @@ checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1811,7 +1760,7 @@ dependencies = [ "borsh", "chrono", "cl-aux", - "crypto-common", + "crypto-common 0.1.6", "digest", "fastrand", "flate2", @@ -1821,7 +1770,6 @@ dependencies = [ "httparse", "matchit", "memchr", - "proptest", "quick-protobuf", "rand_chacha", "rand_core", @@ -1834,7 +1782,6 @@ dependencies = [ "sha1", "sha2", "simdutf8", - "test-strategy", "tokio", "tokio-rustls", "tracing", @@ -1937,7 +1884,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] @@ -1957,7 +1904,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.86", ] [[package]] diff --git a/README.md b/README.md index 9db13509..83bbcbdb 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ A collection of different transport implementations and related tools focused pr 4. [gRPC Client/Server](https://c410-f3r.github.io/wtx/grpc/index.html) 5. [HTTP Client Framework](https://c410-f3r.github.io/wtx/http-client-framework/index.html) 6. [HTTP Server Framework](https://c410-f3r.github.io/wtx/http-server-framework/index.html) -7. [HTTP2 Client/Server](https://c410-f3r.github.io/wtx/http2/index.html) +7. [HTTP/2 Client/Server](https://c410-f3r.github.io/wtx/http2/index.html) 8. [Pool Manager](https://c410-f3r.github.io/wtx/pool/index.html) 9. [UI tools](https://c410-f3r.github.io/wtx/ui-tools/index.html) 10. [WebSocket Client/Server](https://c410-f3r.github.io/wtx/web-socket/index.html) diff --git a/wtx-docs/src/SUMMARY.md b/wtx-docs/src/SUMMARY.md index 25ab9070..544c9782 100644 --- a/wtx-docs/src/SUMMARY.md +++ b/wtx-docs/src/SUMMARY.md @@ -4,7 +4,7 @@ - [Client API Framework](client-api-framework/README.md) - [Database Client](database-client/README.md) - [Database Schema Manager](database-schema-manager/README.md) -- [Grpc](grpc/README.md) +- [gRPC](grpc/README.md) - [HTTP/2](http2/README.md) - [HTTP Client Framework](http-client-framework/README.md) - [HTTP Server Framework](http-server-framework/README.md) diff --git a/wtx-docs/src/client-api-framework/README.md b/wtx-docs/src/client-api-framework/README.md index 0da719cc..228a0d34 100644 --- a/wtx-docs/src/client-api-framework/README.md +++ b/wtx-docs/src/client-api-framework/README.md @@ -4,7 +4,7 @@ A flexible client API framework for writing asynchronous, fast, organizable, sca Checkout the `wtx-apis` project to see a collection of APIs based on `wtx`. -To use this functionality, it necessary to activate the `client-api-framework` feature. +To use this functionality, it is necessary to activate the `client-api-framework` feature. ## Objective diff --git a/wtx-docs/src/database-client/README.md b/wtx-docs/src/database-client/README.md index 4a1ccb86..31ab4031 100644 --- a/wtx-docs/src/database-client/README.md +++ b/wtx-docs/src/database-client/README.md @@ -7,7 +7,7 @@ At the current time PostgreSQL is the only supported database. Implements . -To use this functionality, it necessary to activate the `postgres` feature. +To use this functionality, it is necessary to activate the `postgres` feature. ![PostgreSQL Benchmark](https://i.imgur.com/vf2tYxY.jpeg) diff --git a/wtx-docs/src/database-schema-manager/README.md b/wtx-docs/src/database-schema-manager/README.md index e44b3c89..eca457f4 100644 --- a/wtx-docs/src/database-schema-manager/README.md +++ b/wtx-docs/src/database-schema-manager/README.md @@ -2,7 +2,7 @@ Embedded and CLI workflows using raw SQL commands. A schema manager is a tool thats allows developers to define, track and apply changes to database structures over time, ensuring consistency across different environments. -Activation feature is called `schema-manager`. +To use this functionality, it is necessary to activate the `schema-manager` feature. ## CLI diff --git a/wtx-docs/src/grpc/README.md b/wtx-docs/src/grpc/README.md index e2026a14..5c03e438 100644 --- a/wtx-docs/src/grpc/README.md +++ b/wtx-docs/src/grpc/README.md @@ -7,7 +7,7 @@ Basic implementation that currently only supports unary calls. gRPC is an high-p Due to the lack of an official parser, the definitions of a `Service` must be manually typed. -To use this functionality, it necessary to activate the `grpc` feature. +To use this functionality, it is necessary to activate the `grpc` feature. ## Client Example diff --git a/wtx-docs/src/http-client-framework/README.md b/wtx-docs/src/http-client-framework/README.md index 2a403984..ecaea0d0 100644 --- a/wtx-docs/src/http-client-framework/README.md +++ b/wtx-docs/src/http-client-framework/README.md @@ -2,7 +2,7 @@ High-level pool of HTTP clients that currently only supports HTTP/2. Allows multiple connections that can be referenced in concurrent scenarios. -To use this functionality, it necessary to activate the `http-client-framework` feature. +To use this functionality, it is necessary to activate the `http-client-framework` feature. ## Example diff --git a/wtx-docs/src/http-server-framework/README.md b/wtx-docs/src/http-server-framework/README.md index 534549c2..f648c885 100644 --- a/wtx-docs/src/http-server-framework/README.md +++ b/wtx-docs/src/http-server-framework/README.md @@ -9,9 +9,9 @@ A small and fast to compile framework that can interact with many built-in featu * URI router * WebSocket -If dynamic or nested routes are needed, then it is necessary to activate the `matchit` feature. Without it, only simple and flat routes will work. +If dynamic or nested routes are needed, then please activate the `matchit` feature. Without it, only simple and flat routes will work. -To use this functionality, it necessary to activate the `http-server-framework` feature. +To use this functionality, it is necessary to activate the `http-server-framework` feature. ![HTTP/2 Benchmarks](https://i.imgur.com/lUOX3iM.png) diff --git a/wtx-docs/src/http2/README.md b/wtx-docs/src/http2/README.md index 80401b95..d74def63 100644 --- a/wtx-docs/src/http2/README.md +++ b/wtx-docs/src/http2/README.md @@ -4,7 +4,7 @@ Implementation of [RFC7541](https://datatracker.ietf.org/doc/html/rfc7541) and [ Passes the `hpack-test-case` and the `h2spec` test suites. Due to official deprecation, prioritization is not supported and due to the lack of third-party support, server-push is also not supported. -To use this functionality, it necessary to activate the `http2` feature. +To use this functionality, it is necessary to activate the `http2` feature. ## Client Example diff --git a/wtx-docs/src/pool/README.md b/wtx-docs/src/pool/README.md index 7a2fa7f0..c2443f67 100644 --- a/wtx-docs/src/pool/README.md +++ b/wtx-docs/src/pool/README.md @@ -4,7 +4,7 @@ An asynchronous pool of arbitrary objects where each element is dynamically crea Can also be used for database connections, which is quite handy because it enhances the performance of executing commands and alleviates the use of hardware resources. -To use this functionality, it necessary to activate the `pool` feature. +To use this functionality, it is necessary to activate the `pool` feature. ## Example diff --git a/wtx-docs/src/web-socket-over-http2/README.md b/wtx-docs/src/web-socket-over-http2/README.md index acc3a2dc..3f20bbe4 100644 --- a/wtx-docs/src/web-socket-over-http2/README.md +++ b/wtx-docs/src/web-socket-over-http2/README.md @@ -7,9 +7,9 @@ While HTTP/2 inherently supports full-duplex communication, web browsers typical 1. Servers can efficiently handle multiple concurrent streams within a single TCP connection 2. Client applications can continue using existing WebSocket APIs without modification -For this particular scenario, the `no-masking` parameter defined in https://datatracker.ietf.org/doc/html/draft-damjanovic-websockets-nomasking-02 is also supported. +For this particular scenario, the `no-masking` parameter defined in is also supported. -To use this functionality, it necessary to activate the `http2` and `web-socket` features. +To use this functionality, it is necessary to activate the `http2` and `web-socket` features. ## Example diff --git a/wtx-docs/src/web-socket/README.md b/wtx-docs/src/web-socket/README.md index f8cd7fda..3386cf75 100644 --- a/wtx-docs/src/web-socket/README.md +++ b/wtx-docs/src/web-socket/README.md @@ -2,7 +2,7 @@ Implementation of [RFC6455](https://datatracker.ietf.org/doc/html/rfc6455) and [RFC7692](https://datatracker.ietf.org/doc/html/rfc7692). WebSocket is a communication protocol that enables full-duplex communication between a client (typically a web browser) and a server over a single TCP connection. Unlike traditional HTTP, which is request-response based, WebSocket allows real-time data exchange without the need for polling. -To use this functionality, it necessary to activate the `web-socket` feature. +To use this functionality, it is necessary to activate the `web-socket` feature. ![WebSocket Benchmark](https://i.imgur.com/Iv2WzJV.jpg) @@ -19,9 +19,9 @@ To get the most performance possible, try compiling your program with `RUSTFLAGS ## No masking -Although not officially endorsed, the `no-masking` parameter described at https://datatracker.ietf.org/doc/html/draft-damjanovic-websockets-nomasking-02 is supported to increase performance. If such a thing is not desirable, please make sure to check the handshake parameters to avoid accidental scenarios. +Although not officially endorsed, the `no-masking` parameter described at is supported to increase performance. If such a thing is not desirable, please make sure to check the handshake parameters to avoid accidental scenarios. -To make everything work as intended it is necessary that both parts, client and server, implement this feature. For example, web browser won't stop masking frames. +To make everything work as intended both partys, client and server, need to implement this feature. For example, web browser won't stop masking frames. ## Client Example diff --git a/wtx-instances/generic-examples/grpc-server.rs b/wtx-instances/generic-examples/grpc-server.rs index aff657f7..df24a9ed 100644 --- a/wtx-instances/generic-examples/grpc-server.rs +++ b/wtx-instances/generic-examples/grpc-server.rs @@ -8,7 +8,7 @@ extern crate wtx_instances; use std::borrow::Cow; use wtx::{ data_transformation::dnsn::QuickProtobuf, - grpc::{GrpcManager, GrpcResMiddleware}, + grpc::{GrpcManager, GrpcMiddleware}, http::{ server_framework::{post, Router, ServerFrameworkBuilder, State}, ReqResBuffer, StatusCode, @@ -21,8 +21,7 @@ use wtx_instances::grpc_bindings::wtx::{GenericRequest, GenericResponse}; async fn main() -> wtx::Result<()> { let router = Router::new( wtx::paths!(("wtx.GenericService/generic_method", post(wtx_generic_service_generic_method))), - (), - GrpcResMiddleware, + GrpcMiddleware, )?; ServerFrameworkBuilder::new(router) .with_req_aux(|| QuickProtobuf::default()) diff --git a/wtx-instances/http-server-framework-examples/http-server-framework-cors.rs b/wtx-instances/http-server-framework-examples/http-server-framework-cors.rs index 713f7eec..c2c506be 100644 --- a/wtx-instances/http-server-framework-examples/http-server-framework-cors.rs +++ b/wtx-instances/http-server-framework-examples/http-server-framework-cors.rs @@ -7,7 +7,7 @@ use wtx::{ #[tokio::main] async fn main() -> wtx::Result<()> { - let router = Router::new(wtx::paths!(("/hello", get(hello))), (), CorsMiddleware::permissive())?; + let router = Router::new(wtx::paths!(("/hello", get(hello))), CorsMiddleware::permissive())?; ServerFrameworkBuilder::new(router) .without_aux() .listen_tokio("0.0.0.0:9000", Xorshift64::from(simple_seed()), |error: wtx::Error| { diff --git a/wtx-instances/http-server-framework-examples/http-server-framework-session.rs b/wtx-instances/http-server-framework-examples/http-server-framework-session.rs index 5d8c7e77..c9c5e74a 100644 --- a/wtx-instances/http-server-framework-examples/http-server-framework-session.rs +++ b/wtx-instances/http-server-framework-examples/http-server-framework-session.rs @@ -7,17 +7,18 @@ //! CREATE TABLE "user" ( //! id INT NOT NULL PRIMARY KEY, //! email VARCHAR(128) NOT NULL, +//! first_name VARCHAR(32) NOT NULL, //! password BYTEA NOT NULL, -//! salt BYTEA NOT NULL +//! salt CHAR(32) NOT NULL //! ); //! ALTER TABLE "user" ADD CONSTRAINT user__email__uq UNIQUE (email); //! -//! CREATE TABLE session ( +//! CREATE TABLE "session" ( //! id BYTEA NOT NULL PRIMARY KEY, //! user_id INT NOT NULL, //! expires_at TIMESTAMPTZ NOT NULL //! ); -//! ALTER TABLE session ADD CONSTRAINT session__user__fk FOREIGN KEY (user_id) REFERENCES "user" (id); +//! ALTER TABLE "session" ADD CONSTRAINT session__user__fk FOREIGN KEY (user_id) REFERENCES "user" (id); //! ``` use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; @@ -28,7 +29,7 @@ use wtx::{ server_framework::{get, post, Router, ServerFrameworkBuilder, State, StateClean}, ReqResBuffer, ReqResData, SessionDecoder, SessionEnforcer, SessionTokio, StatusCode, }, - misc::argon2_pwd, + misc::{argon2_pwd, Vector}, pool::{PostgresRM, SimplePoolTokio}, }; @@ -38,14 +39,13 @@ type Session = SessionTokio; #[tokio::main] async fn main() -> wtx::Result<()> { - let router = Router::new( - wtx::paths!(("/login", post(login)), ("/logout", get(logout)),), - (SessionDecoder::new(), SessionEnforcer::new(["/admin"])), - (), - )?; let pool = Pool::new(4, PostgresRM::tokio("postgres://USER:PASSWORD@localhost/DB_NAME".into())); let mut rng = ChaCha20Rng::from_entropy(); let (expired_sessions, session) = Session::builder(pool).build_generating_key(&mut rng); + let router = Router::new( + wtx::paths!(("/login", post(login)), ("/logout", get(logout)),), + (SessionDecoder::new(session.clone()), SessionEnforcer::new(["/admin"], session.clone())), + )?; tokio::spawn(async move { if let Err(err) = expired_sessions.await { eprintln!("{err}"); @@ -61,33 +61,35 @@ async fn main() -> wtx::Result<()> { #[inline] async fn login(state: State<'_, ConnAux, (), ReqResBuffer>) -> wtx::Result { - let (session, rng) = state.conn_aux; - if session.content.lock().await.state().is_some() { - session.delete_session_cookie(&mut state.req.rrd).await?; + let (Session { manager, store }, rng) = state.conn_aux; + if manager.inner.lock().await.state().is_some() { + manager.delete_session_cookie(&mut state.req.rrd, store).await?; return Ok(StatusCode::Forbidden); } let user: UserLoginReq<'_> = serde_json::from_slice(state.req.rrd.body())?; - let mut executor_guard = session.store.get().await?; - let record = executor_guard - .fetch_with_stmt("SELECT id,password,salt FROM user WHERE email = $1", (user.email,)) + let mut guard = store.get().await?; + let record = guard + .fetch_with_stmt("SELECT id,first_name,password,salt FROM user WHERE email = $1", (user.email,)) .await?; let id = record.decode::<_, u32>(0)?; - let password_db = record.decode::<_, &[u8]>(1)?; - let salt = record.decode::<_, &[u8]>(2)?; - let password_req = argon2_pwd(user.password.as_bytes(), salt)?; + let first_name = record.decode::<_, &str>(1)?; + let pw_db = record.decode::<_, &[u8]>(2)?; + let salt = record.decode::<_, &str>(3)?; + let pw_req = argon2_pwd::<32>(&mut Vector::new(), user.password.as_bytes(), salt.as_bytes())?; state.req.rrd.clear(); - if password_db != &password_req { + if pw_db != &pw_req { return Ok(StatusCode::Unauthorized); } - drop(executor_guard); - session.set_session_cookie(id, rng, &mut state.req.rrd).await?; - serde_json::to_writer(&mut state.req.rrd.body, &UserLoginRes { id })?; + serde_json::to_writer(&mut state.req.rrd.body, &UserLoginRes { id, name: first_name })?; + drop(guard); + manager.set_session_cookie(id, rng, &mut state.req.rrd, store).await?; Ok(StatusCode::Ok) } #[inline] async fn logout(state: StateClean<'_, ConnAux, (), ReqResBuffer>) -> wtx::Result { - state.conn_aux.0.delete_session_cookie(&mut state.req.rrd).await?; + let (Session { manager, store }, _) = state.conn_aux; + manager.delete_session_cookie(&mut state.req.rrd, store).await?; Ok(StatusCode::Ok) } @@ -98,6 +100,7 @@ struct UserLoginReq<'req> { } #[derive(Debug, serde::Serialize)] -struct UserLoginRes { +struct UserLoginRes<'se> { id: u32, + name: &'se str, } diff --git a/wtx-instances/http-server-framework-examples/http-server-framework.rs b/wtx-instances/http-server-framework-examples/http-server-framework.rs index f65c05f9..e2aa1ecc 100644 --- a/wtx-instances/http-server-framework-examples/http-server-framework.rs +++ b/wtx-instances/http-server-framework-examples/http-server-framework.rs @@ -10,17 +10,17 @@ extern crate tokio; extern crate wtx; extern crate wtx_instances; -use core::fmt::Write; +use core::{fmt::Write, ops::ControlFlow}; use tokio::net::TcpStream; use wtx::{ database::{Executor, Record}, http::{ server_framework::{ - get, post, PathOwned, Router, SerdeJson, ServerFrameworkBuilder, StateClean, + get, post, Middleware, PathOwned, Router, SerdeJson, ServerFrameworkBuilder, StateClean, }, ReqResBuffer, Request, Response, StatusCode, }, - misc::{simple_seed, FnFutWrapper, Xorshift64}, + misc::{simple_seed, Xorshift64}, pool::{PostgresRM, SimplePoolTokio}, }; @@ -33,11 +33,7 @@ async fn main() -> wtx::Result<()> { ("/json", post(json)), ( "/say", - Router::new( - wtx::paths!(("/hello", get(hello)), ("/world", get(world))), - FnFutWrapper::from(request_middleware), - FnFutWrapper::from(response_middleware), - )?, + Router::new(wtx::paths!(("/hello", get(hello)), ("/world", get(world))), CustomMiddleware,)?, ), ))?; let rm = PostgresRM::tokio("postgres://USER:PASSWORD@localhost/DB_NAME".into()); @@ -80,24 +76,39 @@ async fn json(_: SerdeJson) -> wtx::Result, -) -> wtx::Result<()> { - println!("Before response"); - Ok(()) +async fn world() -> &'static str { + "world" } -async fn response_middleware( - _: &mut (), - _: &mut Pool, - _: Response<&mut ReqResBuffer>, -) -> wtx::Result<()> { - println!("After response"); - Ok(()) -} +struct CustomMiddleware; -async fn world() -> &'static str { - "world" +impl Middleware<(), wtx::Error, Pool> for CustomMiddleware { + type Aux = (); + + #[inline] + fn aux(&self) -> Self::Aux { + () + } + + async fn req( + &self, + _: &mut (), + _: &mut Self::Aux, + _: &mut Request, + _: &mut Pool, + ) -> wtx::Result> { + println!("Inspecting request"); + Ok(ControlFlow::Continue(())) + } + + async fn res( + &self, + _: &mut (), + _: &mut Self::Aux, + _: Response<&mut ReqResBuffer>, + _: &mut Pool, + ) -> wtx::Result> { + println!("Inspecting response"); + Ok(ControlFlow::Continue(())) + } } diff --git a/wtx-instances/src/bin/h2spec-low-server.rs b/wtx-instances/src/bin/h2spec-low-server.rs index cbc9c158..4649fbef 100644 --- a/wtx-instances/src/bin/h2spec-low-server.rs +++ b/wtx-instances/src/bin/h2spec-low-server.rs @@ -2,8 +2,7 @@ #![expect(clippy::print_stderr, reason = "internal")] -use std::mem; - +use core::mem; use tokio::net::TcpListener; use wtx::{ http::{ReqResBuffer, StatusCode}, diff --git a/wtx-instances/src/lib.rs b/wtx-instances/src/lib.rs index 2f7b5986..8151c769 100644 --- a/wtx-instances/src/lib.rs +++ b/wtx-instances/src/lib.rs @@ -39,11 +39,13 @@ pub static ROOT_CA: &[u8] = include_bytes!("../../.certs/root-ca.crt"); pub async fn executor_postgres( uri_str: &str, ) -> wtx::Result> { + use std::usize; + let uri = Uri::new(uri_str); let mut rng = Xorshift64::from(simple_seed()); Executor::connect( &Config::from_uri(&uri)?, - ExecutorBuffer::with_default_params(&mut rng)?, + ExecutorBuffer::new(usize::MAX, &mut rng), &mut rng, TcpStream::connect(uri.hostname_with_implied_port()).await?, ) diff --git a/wtx-ui/src/schema_manager.rs b/wtx-ui/src/schema_manager.rs index 7ecbb2db..63212ec9 100644 --- a/wtx-ui/src/schema_manager.rs +++ b/wtx-ui/src/schema_manager.rs @@ -24,7 +24,7 @@ pub(crate) async fn schema_manager(sm: SchemaManager) -> wtx::Result<()> { let mut rng = Xorshift64::from(simple_seed()); let executor = Executor::connect( &Config::from_uri(&uri)?, - ExecutorBuffer::with_default_params(&mut rng)?, + ExecutorBuffer::new(usize::MAX, &mut rng), &mut rng, TcpStream::connect(uri.hostname_with_implied_port()).await.map_err(wtx::Error::from)?, ) diff --git a/wtx/Cargo.toml b/wtx/Cargo.toml index 565c5588..8fa5b6e4 100644 --- a/wtx/Cargo.toml +++ b/wtx/Cargo.toml @@ -1,5 +1,5 @@ [dependencies] -aes-gcm = { default-features = false, optional = true, version = "0.10" } +aes-gcm = { default-features = false, optional = true, version = "0.11.0-pre.2" } arbitrary = { default-features = false, features = ["derive_arbitrary"], optional = true, version = "1.0" } argon2 = { default-features = false, optional = true, version = "0.5" } base64 = { default-features = false, features = ["alloc"], optional = true, version = "0.22" } @@ -16,7 +16,6 @@ hmac = { default-features = false, optional = true, version = "0.12" } httparse = { default-features = false, optional = true, version = "1.0" } matchit = { default-features = false, optional = true, version = "0.8" } memchr = { default-features = false, optional = true, version = "2.0" } -proptest = { default-features = false, features = ["alloc"], optional = true, version = "1.0" } quick-protobuf = { default-features = false, optional = true, version = "0.8" } rand_chacha = { default-features = false, optional = true, version = "0.3" } rand_core = { default-features = false, optional = true, version = "0.6" } @@ -29,7 +28,6 @@ serde_json = { default-features = false, features = ["alloc"], optional = true, sha1 = { default-features = false, optional = true, version = "0.10" } sha2 = { default-features = false, optional = true, version = "0.10" } simdutf8 = { default-features = false, features = ["aarch64_neon"], optional = true, version = "0.1" } -test-strategy = { default-features = false, optional = true, version = "0.4" } tokio = { default-features = false, features = ["io-util", "net", "rt", "sync", "time"], optional = true, version = "1.0" } tokio-rustls = { default-features = false, features = ["ring"], optional = true, version = "0.26" } tracing = { default-features = false, features = ["attributes"], optional = true, version = "0.1" } @@ -70,7 +68,6 @@ nightly = ["hashbrown?/nightly"] optimization = ["memchr", "simdutf8"] pool = [] postgres = ["base64", "crypto-common", "database", "digest", "foldhash", "hashbrown", "hmac", "sha2"] -proptest = ["dep:proptest"] quick-protobuf = ["dep:quick-protobuf", "std"] rand_chacha = ["dep:rand_chacha", "dep:rand_core"] ring = ["dep:ring"] @@ -82,8 +79,7 @@ serde_json = ["serde", "dep:serde_json", "std"] sha1 = ["dep:sha1"] sha2 = ["dep:sha2"] simdutf8 = ["dep:simdutf8"] -std = ["aes-gcm?/std", "argon2?/std", "base64?/std", "borsh?/std", "chrono?/std", "cl-aux?/std", "crypto-common?/std", "digest?/std", "fastrand?/std", "foldhash?/std", "hmac?/std", "httparse?/std", "memchr?/std", "proptest?/std", "quick-protobuf?/std", "rand_chacha?/std", "rand_core?/std", "ring?/std", "rust_decimal?/std", "rustls-pemfile?/std", "rustls-pki-types?/std", "serde?/std", "serde_json?/std", "sha1?/std", "sha2?/std", "simdutf8?/std", "tracing?/std", "tracing-subscriber?/std"] -test-strategy = ["dep:test-strategy", "proptest", "std"] +std = ["aes-gcm?/std", "argon2?/std", "base64?/std", "borsh?/std", "chrono?/std", "cl-aux?/std", "crypto-common?/std", "digest?/std", "fastrand?/std", "foldhash?/std", "hmac?/std", "httparse?/std", "memchr?/std", "quick-protobuf?/std", "rand_chacha?/std", "rand_core?/std", "ring?/std", "rust_decimal?/std", "rustls-pemfile?/std", "rustls-pki-types?/std", "serde?/std", "serde_json?/std", "sha1?/std", "sha2?/std", "simdutf8?/std", "tracing?/std", "tracing-subscriber?/std"] tokio = ["std", "dep:tokio"] tokio-rustls = ["ring", "dep:rustls-pemfile", "dep:rustls-pki-types", "tokio", "dep:tokio-rustls"] tracing = ["dep:tracing"] @@ -96,9 +92,11 @@ x509-certificate = ["dep:x509-certificate"] _async-tests = ["tokio/macros", "tokio/net", "tokio/rt-multi-thread", "tokio/time"] _bench = [] _integration-tests = ["serde_json?/raw_value"] -_proptest = ["proptest/std", "rust_decimal?/proptest", "std", "test-strategy"] _tracing-tree = ["tracing", "tracing-subscriber", "dep:tracing-tree"] +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(kani)'] } + [package] authors = ["Caio Fernandes "] categories = ["asynchronous", "database", "network-programming", "no-std", "web-programming"] diff --git a/wtx/src/client_api_framework/network/transport/mock.rs b/wtx/src/client_api_framework/network/transport/mock.rs index da3add93..17d3c28d 100644 --- a/wtx/src/client_api_framework/network/transport/mock.rs +++ b/wtx/src/client_api_framework/network/transport/mock.rs @@ -10,12 +10,9 @@ use crate::{ pkg::{Package, PkgsAux}, Api, ClientApiFrameworkError, }, - misc::{Lease, Vector}, -}; -use alloc::{ - borrow::{Cow, ToOwned}, - collections::VecDeque, + misc::{Deque, Lease, Vector}, }; +use alloc::borrow::{Cow, ToOwned}; use core::{fmt::Debug, marker::PhantomData, ops::Range}; /// For API's that send and received raw bytes. @@ -47,7 +44,7 @@ where asserted: usize, phantom: PhantomData, requests: Vector>, - responses: VecDeque>, + responses: Deque>, } impl Mock @@ -81,7 +78,7 @@ where /// Stores `res` into the inner response storage #[inline] pub fn push_response(&mut self, res: Cow<'static, T>) { - self.responses.push_back(res); + self.responses.push_back(res).unwrap(); } fn pop_response(&mut self) -> crate::Result> { @@ -143,6 +140,6 @@ where { #[inline] fn default() -> Self { - Self { asserted: 0, phantom: PhantomData, requests: Vector::new(), responses: VecDeque::new() } + Self { asserted: 0, phantom: PhantomData, requests: Vector::new(), responses: Deque::new() } } } diff --git a/wtx/src/database.rs b/wtx/src/database.rs index e9384055..65b0279c 100644 --- a/wtx/src/database.rs +++ b/wtx/src/database.rs @@ -38,6 +38,8 @@ pub use transaction_manager::TransactionManager; pub use typed::Typed; pub use value_ident::ValueIdent; +/// The default value for the maximum number of cached statements +pub const DEFAULT_MAX_STMTS: usize = 128; /// Default environment variable name for the database URL pub const DEFAULT_URI_VAR: &str = "DATABASE_URI"; diff --git a/wtx/src/database/client/postgres.rs b/wtx/src/database/client/postgres.rs index cf7cd758..f761d180 100644 --- a/wtx/src/database/client/postgres.rs +++ b/wtx/src/database/client/postgres.rs @@ -78,18 +78,18 @@ mod tests { use crate::database::client::postgres::{statements::Column, Ty}; pub(crate) fn column0() -> Column { - Column { name: "a".try_into().unwrap(), ty: Ty::VarcharArray } + Column::new("a".try_into().unwrap(), Ty::VarcharArray) } pub(crate) fn column1() -> Column { - Column { name: "b".try_into().unwrap(), ty: Ty::Int8 } + Column::new("b".try_into().unwrap(), Ty::Int8) } pub(crate) fn column2() -> Column { - Column { name: "c".try_into().unwrap(), ty: Ty::Char } + Column::new("c".try_into().unwrap(), Ty::Char) } pub(crate) fn column3() -> Column { - Column { name: "d".try_into().unwrap(), ty: Ty::Date } + Column::new("d".try_into().unwrap(), Ty::Date) } } diff --git a/wtx/src/database/client/postgres/authentication.rs b/wtx/src/database/client/postgres/authentication.rs index 3c85b9ed..d6bfe067 100644 --- a/wtx/src/database/client/postgres/authentication.rs +++ b/wtx/src/database/client/postgres/authentication.rs @@ -14,6 +14,8 @@ pub(crate) enum Authentication<'bytes> { impl<'bytes> TryFrom<&'bytes [u8]> for Authentication<'bytes> { type Error = crate::Error; + + #[inline] fn try_from(bytes: &'bytes [u8]) -> Result { let (n, rest) = if let [a, b, c, d, rest @ ..] = bytes { (u32::from_be_bytes([*a, *b, *c, *d]), rest) diff --git a/wtx/src/database/client/postgres/config.rs b/wtx/src/database/client/postgres/config.rs index 2130dec4..a2f393bd 100644 --- a/wtx/src/database/client/postgres/config.rs +++ b/wtx/src/database/client/postgres/config.rs @@ -48,6 +48,7 @@ impl<'data> Config<'data> { Ok(this) } + #[inline] fn set_param(&mut self, key: &str, value: &'data str) -> crate::Result<()> { match key { "application_name" => { diff --git a/wtx/src/database/client/postgres/executor.rs b/wtx/src/database/client/postgres/executor.rs index 1dab7b06..a2835805 100644 --- a/wtx/src/database/client/postgres/executor.rs +++ b/wtx/src/database/client/postgres/executor.rs @@ -143,16 +143,15 @@ where RV: RecordValues, SC: StmtCmd, { - let ExecutorBufferPartsMut { nb, rb, stmts, vb, .. } = self.eb.lease_mut().parts_mut(); + let Self { cs, eb, phantom: _, stream } = self; + let ExecutorBufferPartsMut { nb, rb, stmts, vb, .. } = eb.lease_mut().parts_mut(); ExecutorBuffer::clear_cmd_buffers(nb, rb, vb); let mut rows = 0; - let mut fwsc = - FetchWithStmtCommons { cs: &mut self.cs, rb, stream: &mut self.stream, tys: &[] }; - let (_, stmt_id_str, stmt) = - Self::write_send_await_stmt_prot(&mut fwsc, nb, sc, stmts, vb).await?; - Self::write_send_await_stmt_initial(&mut fwsc, nb, rv, &stmt, &stmt_id_str).await?; + let mut fwsc = FetchWithStmtCommons { cs, stream, tys: &[] }; + let (_, stmt_id, stmt) = Self::write_send_await_stmt_prot(&mut fwsc, nb, sc, stmts, vb).await?; + Self::write_send_await_stmt_initial(&mut fwsc, nb, rv, &stmt, &stmt_id).await?; loop { - let msg = Self::fetch_msg_from_stream(fwsc.cs, nb, fwsc.stream).await?; + let msg = Self::fetch_msg_from_stream(cs, nb, stream).await?; match msg.ty { MessageTy::CommandComplete(local_rows) => { rows = local_rows; @@ -179,11 +178,11 @@ where RV: RecordValues, SC: StmtCmd, { - let ExecutorBufferPartsMut { nb, rb, stmts, vb, .. } = self.eb.lease_mut().parts_mut(); - let fwsc = - &mut FetchWithStmtCommons { cs: &mut self.cs, rb, stream: &mut self.stream, tys: &[] }; - let (_, stmt_id_str, stmt) = Self::write_send_await_stmt_prot(fwsc, nb, sc, stmts, vb).await?; - Self::write_send_await_fetch_with_stmt_wo_prot(fwsc, nb, rv, stmt, &stmt_id_str, vb).await + let Self { cs, eb, phantom: _, stream } = self; + let ExecutorBufferPartsMut { nb, stmts, vb, .. } = eb.lease_mut().parts_mut(); + let mut fwsc = FetchWithStmtCommons { cs, stream, tys: &[] }; + let (_, stmt_id, stmt) = Self::write_send_await_stmt_prot(&mut fwsc, nb, sc, stmts, vb).await?; + Self::write_send_await_fetch_with_stmt_wo_prot(&mut fwsc, nb, rv, stmt, &stmt_id, vb).await } #[inline] @@ -197,29 +196,28 @@ where RV: RecordValues, SC: StmtCmd, { - let ExecutorBufferPartsMut { nb, rb, stmts, vb, .. } = self.eb.lease_mut().parts_mut(); + let Self { cs, eb, phantom: _, stream } = self; + let ExecutorBufferPartsMut { nb, rb, stmts, vb, .. } = eb.lease_mut().parts_mut(); ExecutorBuffer::clear_cmd_buffers(nb, rb, vb); - let mut fwsc = - FetchWithStmtCommons { cs: &mut self.cs, rb, stream: &mut self.stream, tys: &[] }; - let (_, stmt_id_str, stmt) = - Self::write_send_await_stmt_prot(&mut fwsc, nb, sc, stmts, vb).await?; - Self::write_send_await_stmt_initial(&mut fwsc, nb, rv, &stmt, &stmt_id_str).await?; + let mut fwsc = FetchWithStmtCommons { cs, stream, tys: &[] }; + let (_, stmt_id, stmt) = Self::write_send_await_stmt_prot(&mut fwsc, nb, sc, stmts, vb).await?; + Self::write_send_await_stmt_initial(&mut fwsc, nb, rv, &stmt, &stmt_id).await?; let begin = nb._current_end_idx(); let begin_data = nb._current_end_idx().wrapping_add(7); loop { - let msg = Self::fetch_msg_from_stream(fwsc.cs, nb, fwsc.stream).await?; + let msg = Self::fetch_msg_from_stream(cs, nb, stream).await?; match msg.ty { + MessageTy::CommandComplete(_) | MessageTy::EmptyQueryResponse => {} MessageTy::DataRow(len) => { let bytes = nb._buffer().get(begin_data..nb._current_end_idx()).unwrap_or_default(); let range_begin = nb._antecedent_end_idx().wrapping_sub(begin); let range_end = nb._current_end_idx().wrapping_sub(begin_data); cb(&Record::parse(bytes, range_begin..range_end, stmt.clone(), vb, len)?)?; - fwsc.rb.push(vb.len()).map_err(Into::into)?; + rb.push(vb.len()).map_err(Into::into)?; } MessageTy::ReadyForQuery => { break; } - MessageTy::CommandComplete(_) | MessageTy::EmptyQueryResponse => {} _ => { return Err(<_>::from( PostgresError::UnexpectedDatabaseMessage { received: msg.tag }.into(), @@ -241,19 +239,11 @@ where #[inline] async fn prepare(&mut self, cmd: &str) -> Result { - let ExecutorBufferPartsMut { nb, rb, stmts, vb, .. } = self.eb.lease_mut().parts_mut(); + let Self { cs, eb, phantom: _, stream } = self; + let ExecutorBufferPartsMut { nb, rb, stmts, vb, .. } = eb.lease_mut().parts_mut(); ExecutorBuffer::clear_cmd_buffers(nb, rb, vb); - Ok( - Self::write_send_await_stmt_prot( - &mut FetchWithStmtCommons { cs: &mut self.cs, rb, stream: &mut self.stream, tys: &[] }, - nb, - cmd, - stmts, - vb, - ) - .await? - .0, - ) + let mut fwsc = FetchWithStmtCommons { cs, stream, tys: &[] }; + Ok(Self::write_send_await_stmt_prot(&mut fwsc, nb, cmd, stmts, vb).await?.0) } #[inline] diff --git a/wtx/src/database/client/postgres/executor/authentication.rs b/wtx/src/database/client/postgres/executor/authentication.rs index 85123c7e..8ce9c7b4 100644 --- a/wtx/src/database/client/postgres/executor/authentication.rs +++ b/wtx/src/database/client/postgres/executor/authentication.rs @@ -24,12 +24,15 @@ where EB: LeaseMut, S: Stream, { - /// Ascending sequence of extra parameters received from the database. + /// Connection parameters + /// + /// Extra parameters received from the database. #[inline] - pub fn conn_params(&self) -> &[(Identifier, Identifier)] { - &self.eb.lease().conn_params + pub fn conn_params(&self) -> impl Iterator { + self.eb.lease().cp.iter() } + #[inline] pub(crate) async fn manage_authentication( &mut self, config: &Config<'_>, @@ -109,18 +112,18 @@ where } } + #[inline] pub(crate) async fn read_after_authentication_data(&mut self) -> crate::Result<()> { self.eb.lease_mut().nb._reserve(2048)?; loop { - let ExecutorBufferPartsMut { conn_params, nb, .. } = self.eb.lease_mut().parts_mut(); + let ExecutorBufferPartsMut { cp, nb, .. } = self.eb.lease_mut().parts_mut(); let msg = Self::fetch_msg_from_stream(&mut self.cs, nb, &mut self.stream).await?; match msg.ty { MessageTy::BackendKeyData => {} MessageTy::ParameterStatus(name, value) => { - conn_params.insert( - conn_params.partition_point(|(local_name, _)| local_name.as_bytes() < name), - (from_utf8_basic(name)?.try_into()?, from_utf8_basic(value)?.try_into()?), - )?; + let name = from_utf8_basic(name)?.try_into()?; + let value = from_utf8_basic(value)?.try_into()?; + let _ = cp.insert(name, value); } MessageTy::ReadyForQuery => return Ok(()), _ => { @@ -132,6 +135,7 @@ where // The 'null' case of `tls_server_end_point` is already handled by `method_header`, as such, // it is fine to use an empty slice. + #[inline] async fn sasl_authenticate( config: &Config<'_>, cs: &mut ConnectionState, @@ -212,6 +216,7 @@ where } } +#[inline] fn nonce(rng: &mut RNG) -> [u8; 24] where RNG: Rng, @@ -233,6 +238,7 @@ where rslt } +#[inline] fn salted_password(len: u32, salt: &[u8], str: &str) -> crate::Result<[u8; 32]> { let mut array: [u8; 32] = { let mut hmac = Hmac::::new_from_slice(str.as_bytes())?; diff --git a/wtx/src/database/client/postgres/executor/commons.rs b/wtx/src/database/client/postgres/executor/commons.rs index 9b915cf8..06f413c3 100644 --- a/wtx/src/database/client/postgres/executor/commons.rs +++ b/wtx/src/database/client/postgres/executor/commons.rs @@ -1,11 +1,7 @@ -use crate::{ - database::client::postgres::Ty, - misc::{ConnectionState, Vector}, -}; +use crate::{database::client::postgres::Ty, misc::ConnectionState}; pub(crate) struct FetchWithStmtCommons<'others, S> { pub(crate) cs: &'others mut ConnectionState, - pub(crate) rb: &'others mut Vector, pub(crate) stream: &'others mut S, /// Pre-specified types pub(crate) tys: &'others [Ty], diff --git a/wtx/src/database/client/postgres/executor/fetch.rs b/wtx/src/database/client/postgres/executor/fetch.rs index 3b7d1684..0cf9bb2d 100644 --- a/wtx/src/database/client/postgres/executor/fetch.rs +++ b/wtx/src/database/client/postgres/executor/fetch.rs @@ -8,7 +8,10 @@ use crate::{ }, RecordValues, }, - misc::{ConnectionState, LeaseMut, PartitionedFilledBuffer, Stream, Vector, _read_until}, + misc::{ + ConnectionState, LeaseMut, PartitionedFilledBuffer, Stream, Usize, Vector, _read_header, + _read_payload, + }, }; use core::ops::Range; @@ -17,6 +20,7 @@ where EB: LeaseMut, S: Stream, { + #[inline] pub(crate) async fn write_send_await_fetch_with_stmt_wo_prot<'any, RV>( fwsc: &mut FetchWithStmtCommons<'_, S>, nb: &'any mut PartitionedFilledBuffer, @@ -34,11 +38,11 @@ where loop { let msg = Self::fetch_msg_from_stream(fwsc.cs, nb, fwsc.stream).await?; match msg.ty { + MessageTy::CommandComplete(_) | MessageTy::EmptyQueryResponse => {} MessageTy::DataRow(len) => { data_row_msg_range = Some((len, nb._current_range())); } MessageTy::ReadyForQuery => break, - MessageTy::CommandComplete(_) | MessageTy::EmptyQueryResponse => {} _ => { return Err(E::from( PostgresError::UnexpectedDatabaseMessage { received: msg.tag }.into(), @@ -56,6 +60,7 @@ where } } + #[inline] pub(crate) async fn fetch_msg_from_stream<'nb>( cs: &mut ConnectionState, nb: &'nb mut PartitionedFilledBuffer, @@ -65,52 +70,25 @@ where Ok(Message { tag, ty: MessageTy::try_from((cs, nb._current()))? }) } - async fn fetch_one_header_from_stream( - nb: &mut PartitionedFilledBuffer, - read: &mut usize, - stream: &mut S, - ) -> crate::Result<(u8, usize)> { - let buffer = nb._following_rest_mut(); - let [mt_n, b, c, d, e] = _read_until::<5, S>(buffer, read, 0, stream).await?; - let len: usize = u32::from_be_bytes([b, c, d, e]).try_into()?; - Ok((mt_n, len.wrapping_add(1))) - } - + // | Ty | Len | Payload | + // | 1 | 4 | x | + // + // The value of `Len` is payload length plus 4, therefore, the frame length is `Len` plus 1. + #[inline] async fn fetch_one_msg_from_stream<'nb>( nb: &'nb mut PartitionedFilledBuffer, stream: &mut S, ) -> crate::Result { + nb._reserve(5)?; let mut read = nb._following_len(); - let (ty, len) = Self::fetch_one_header_from_stream(nb, &mut read, stream).await?; - Self::fetch_one_payload_from_stream(len, nb, &mut read, stream).await?; - let current_end_idx = nb._current_end_idx(); - nb._set_indices(current_end_idx, len, read.wrapping_sub(len))?; - Ok(ty) - } - - async fn fetch_one_payload_from_stream( - len: usize, - nb: &mut PartitionedFilledBuffer, - read: &mut usize, - stream: &mut S, - ) -> crate::Result<()> { - let mut is_payload_filled = false; - nb._reserve(len)?; - for _ in 0..=len { - if *read >= len { - is_payload_filled = true; - break; - } - *read = read.wrapping_add( - stream.read(nb._following_rest_mut().get_mut(*read..).unwrap_or_default()).await?, - ); - } - if !is_payload_filled { - return Err(crate::Error::UnexpectedBufferState); - } - Ok(()) + let buffer = nb._following_rest_mut(); + let [a, b, c, d, e] = _read_header::<0, 5, S>(buffer, &mut read, stream).await?; + let len = Usize::from(u32::from_be_bytes([b, c, d, e])).into_usize().wrapping_add(1); + _read_payload((0, len), nb, &mut read, stream).await?; + Ok(a) } + #[inline] async fn fetch_representative_msg_from_stream<'nb>( nb: &'nb mut PartitionedFilledBuffer, stream: &mut S, diff --git a/wtx/src/database/client/postgres/executor/prepare.rs b/wtx/src/database/client/postgres/executor/prepare.rs index 87822750..560ec40b 100644 --- a/wtx/src/database/client/postgres/executor/prepare.rs +++ b/wtx/src/database/client/postgres/executor/prepare.rs @@ -6,15 +6,13 @@ use crate::{ message::MessageTy, msg_field::MsgField, protocol::{bind, describe, execute, parse, sync}, - statements::{Column, PushRslt, Statement}, + statements::{Column, Statement, StatementsMisc}, ty::Ty, Executor, Postgres, PostgresError, Statements, }, RecordValues, StmtCmd, }, - misc::{ - ArrayString, FilledBufferWriter, LeaseMut, PartitionedFilledBuffer, Stream, _unreachable, - }, + misc::{ArrayString, FilledBufferWriter, LeaseMut, PartitionedFilledBuffer, Stream}, }; use core::ops::Range; @@ -24,6 +22,7 @@ where EB: LeaseMut, S: Stream, { + #[inline] pub(crate) async fn write_send_await_stmt_initial( fwsc: &mut FetchWithStmtCommons<'_, S>, nb: &mut PartitionedFilledBuffer, @@ -48,6 +47,7 @@ where Ok(()) } + #[inline] pub(crate) async fn write_send_await_stmt_prot<'stmts, SC>( fwsc: &mut FetchWithStmtCommons<'_, S>, nb: &mut PartitionedFilledBuffer, @@ -59,10 +59,11 @@ where SC: StmtCmd, { let stmt_hash = sc.hash(stmts.hasher_mut()); - let (stmt_id_str, mut builder) = match stmts.push(stmt_hash) { - PushRslt::Builder(builder) => (Self::stmt_id_str(stmt_hash)?, builder), - PushRslt::Stmt(stmt) => return Ok((stmt_hash, Self::stmt_id_str(stmt_hash)?, stmt)), - }; + let stmt_id_str = Self::stmt_id_str(stmt_hash)?; + if stmts.get_by_stmt_hash(stmt_hash).is_some() { + // FIXME(stable): Use `if let Some ...` with polonius + return Ok((stmt_hash, stmt_id_str, stmts.get_by_stmt_hash(stmt_hash).unwrap())); + } let stmt_cmd = sc.cmd().ok_or_else(|| E::from(PostgresError::UnknownStatementId.into()))?; @@ -80,45 +81,65 @@ where }; let msg1 = Self::fetch_msg_from_stream(fwsc.cs, nb, fwsc.stream).await?; - let MessageTy::ParameterDescription(mut pd) = msg1.ty else { + let MessageTy::ParameterDescription(types_len, mut pd) = msg1.ty else { return Err(E::from(PostgresError::UnexpectedDatabaseMessage { received: msg1.tag }.into())); }; - while let [a, b, c, d, sub_data @ ..] = pd { - let id = u32::from_be_bytes([*a, *b, *c, *d]); - builder.push_param(Ty::Custom(id)); - pd = sub_data; + + let mut builder = stmts.builder(); + let _ = builder.expand(types_len.into())?; + + { + let elements = builder.inserted_elements(); + for idx in 0..types_len { + let element_opt = elements.get_mut(usize::from(idx)); + let ([a, b, c, d, sub_data @ ..], Some(element)) = (pd, element_opt) else { break }; + element.1 = Ty::Custom(u32::from_be_bytes([*a, *b, *c, *d])); + pd = sub_data; + } } let msg2 = Self::fetch_msg_from_stream(fwsc.cs, nb, fwsc.stream).await?; - match msg2.ty { - MessageTy::NoData => {} - MessageTy::RowDescription(mut rd) => loop { - let (read, msg_field) = MsgField::parse(rd)?; - let ty = Ty::Custom(msg_field.type_oid); - builder.push_column(Column { name: msg_field.name.try_into().map_err(Into::into)?, ty }); - if let Some(elem @ [_not_empty, ..]) = rd.get(read..) { - rd = elem; - } else { - break; + let columns_len = match msg2.ty { + MessageTy::NoData => 0, + MessageTy::RowDescription(columns_len, mut rd) => { + if let Some(diff @ 1..=u16::MAX) = columns_len.checked_sub(types_len) { + let _ = builder.expand(diff.into())?; } - }, + let elements = builder.inserted_elements(); + for idx in 0..columns_len { + let (read, msg_field) = MsgField::parse(rd)?; + let ty = Ty::Custom(msg_field.type_oid); + let Some(element) = elements.get_mut(usize::from(idx)) else { + break; + }; + element.0 = Column::new(msg_field.name.try_into().map_err(Into::into)?, ty); + if let Some(elem @ [_not_empty, ..]) = rd.get(read..) { + rd = elem; + } else { + break; + } + } + columns_len + } _ => { return Err(E::from(PostgresError::UnexpectedDatabaseMessage { received: msg2.tag }.into())) } - } + }; let msg3 = Self::fetch_msg_from_stream(fwsc.cs, nb, fwsc.stream).await?; let MessageTy::ReadyForQuery = msg3.ty else { return Err(E::from(PostgresError::UnexpectedDatabaseMessage { received: msg3.tag }.into())); }; - if let Some(stmt) = builder.finish().get_by_stmt_hash(stmt_hash) { - Ok((stmt_hash, stmt_id_str, stmt)) - } else { - _unreachable() - } + let sm = StatementsMisc::new(columns_len.into(), types_len.into()); + let idx = builder.build(stmt_hash, sm)?; + let Some(stmt) = stmts.get_by_idx(idx) else { + return Err(crate::Error::ProgrammingError.into()); + }; + Ok((stmt_hash, stmt_id_str, stmt)) } + #[inline] fn stmt_id_str(stmt_hash: u64) -> crate::Result> { Ok(ArrayString::try_from(format_args!("s{stmt_hash}"))?) } diff --git a/wtx/src/database/client/postgres/executor/simple_query.rs b/wtx/src/database/client/postgres/executor/simple_query.rs index edbb25d6..e3da6691 100644 --- a/wtx/src/database/client/postgres/executor/simple_query.rs +++ b/wtx/src/database/client/postgres/executor/simple_query.rs @@ -11,6 +11,7 @@ where EB: LeaseMut, S: Stream, { + #[inline] pub(crate) async fn simple_query_execute( &mut self, cmd: &str, diff --git a/wtx/src/database/client/postgres/executor_buffer.rs b/wtx/src/database/client/postgres/executor_buffer.rs index 91c791ac..2a4c654e 100644 --- a/wtx/src/database/client/postgres/executor_buffer.rs +++ b/wtx/src/database/client/postgres/executor_buffer.rs @@ -1,32 +1,21 @@ use crate::{ - database::{ - client::postgres::{ty::Ty, Statements}, - Identifier, - }, + database::{client::postgres::Statements, Identifier}, misc::{Lease, LeaseMut, PartitionedFilledBuffer, Rng, Vector}, }; use core::ops::Range; use hashbrown::HashMap; -pub(crate) const DFLT_PARAMS_LEN: usize = 16; -pub(crate) const DFLT_RECORDS_LEN: usize = 32; -pub(crate) const DFLT_VALUES_LEN: usize = 16; - #[derive(Debug)] #[doc = _internal_buffer_doc!()] pub struct ExecutorBuffer { - /// Asynchronous parameters received from the database. - pub(crate) conn_params: Vector<(Identifier, Identifier)>, - /// Fetch type buffer - pub(crate) ftb: Vector<(usize, u32)>, + /// Connection parameters. + pub(crate) cp: HashMap, /// Network Buffer. pub(crate) nb: PartitionedFilledBuffer, /// Records Buffer. pub(crate) rb: Vector, /// Statements pub(crate) stmts: Statements, - /// Types buffer - pub(crate) tb: HashMap, /// Values Buffer. pub(crate) vb: Vector<(bool, Range)>, } @@ -34,67 +23,51 @@ pub struct ExecutorBuffer { impl ExecutorBuffer { /// With provided capacity. #[inline] - pub fn new( - (network_buffer_cap, records_buffer_cap, values_buffer_cap): (usize, usize, usize), - rng: &mut RNG, - max_queries: usize, - ) -> crate::Result + pub fn new(max_stmts: usize, rng: RNG) -> Self where RNG: Rng, { - Ok(Self { - conn_params: Vector::with_capacity(DFLT_PARAMS_LEN)?, - ftb: Vector::new(), - nb: PartitionedFilledBuffer::_with_capacity(network_buffer_cap)?, - rb: Vector::with_capacity(records_buffer_cap)?, - stmts: Statements::new(max_queries, rng), - tb: HashMap::new(), - vb: Vector::with_capacity(values_buffer_cap)?, - }) + Self { + cp: HashMap::new(), + nb: PartitionedFilledBuffer::new(), + rb: Vector::new(), + stmts: Statements::new(max_stmts, rng), + vb: Vector::new(), + } } /// With default capacity. #[inline] - pub fn with_default_params(rng: &mut RNG) -> crate::Result + pub fn with_capacity( + (columns_cap, network_buffer_cap, rows_cap, stmts_cap): (usize, usize, usize, usize), + max_stmts: usize, + rng: &mut RNG, + ) -> crate::Result where RNG: Rng, { Ok(Self { - conn_params: Vector::with_capacity(DFLT_PARAMS_LEN)?, - ftb: Vector::new(), - nb: PartitionedFilledBuffer::default(), - rb: Vector::with_capacity(DFLT_RECORDS_LEN)?, - stmts: Statements::with_default_params(rng), - tb: HashMap::new(), - vb: Vector::with_capacity(DFLT_VALUES_LEN)?, + cp: HashMap::with_capacity(4), + nb: PartitionedFilledBuffer::_with_capacity(network_buffer_cap)?, + rb: Vector::with_capacity(rows_cap)?, + stmts: Statements::with_capacity(columns_cap, max_stmts, rng, stmts_cap)?, + vb: Vector::with_capacity(rows_cap.saturating_mul(columns_cap))?, }) } - pub(crate) fn _empty() -> Self { - Self { - conn_params: Vector::new(), - ftb: Vector::new(), - nb: PartitionedFilledBuffer::new(), - rb: Vector::new(), - stmts: Statements::_empty(), - tb: HashMap::new(), - vb: Vector::new(), - } - } - /// Should be used in a new instance. + #[inline] pub(crate) fn clear(&mut self) { - let Self { conn_params, ftb, nb, rb, stmts, tb, vb } = self; - conn_params.clear(); - ftb.clear(); + let Self { cp, nb, rb, stmts, vb } = self; + cp.clear(); nb._clear(); rb.clear(); stmts.clear(); - tb.clear(); vb.clear(); } /// Should be called before executing commands. + #[inline] pub(crate) fn clear_cmd_buffers( nb: &mut PartitionedFilledBuffer, rb: &mut Vector, @@ -105,9 +78,10 @@ impl ExecutorBuffer { vb.clear(); } + #[inline] pub(crate) fn parts_mut(&mut self) -> ExecutorBufferPartsMut<'_> { ExecutorBufferPartsMut { - conn_params: &mut self.conn_params, + cp: &mut self.cp, nb: &mut self.nb, rb: &mut self.rb, stmts: &mut self.stmts, @@ -131,7 +105,7 @@ impl LeaseMut for ExecutorBuffer { } pub(crate) struct ExecutorBufferPartsMut<'eb> { - pub(crate) conn_params: &'eb mut Vector<(Identifier, Identifier)>, + pub(crate) cp: &'eb mut HashMap, pub(crate) nb: &'eb mut PartitionedFilledBuffer, pub(crate) rb: &'eb mut Vector, pub(crate) stmts: &'eb mut Statements, diff --git a/wtx/src/database/client/postgres/integration_tests.rs b/wtx/src/database/client/postgres/integration_tests.rs index 49f32e23..bbba3e08 100644 --- a/wtx/src/database/client/postgres/integration_tests.rs +++ b/wtx/src/database/client/postgres/integration_tests.rs @@ -20,7 +20,7 @@ async fn conn_scram_tls() { let mut rng = Xorshift64::from(simple_seed()); let _executor = Executor::::connect_encrypted( &Config::from_uri(&uri).unwrap(), - ExecutorBuffer::with_default_params(&mut rng).unwrap(), + ExecutorBuffer::new(usize::MAX, &mut rng), TcpStream::connect(uri.hostname_with_implied_port()).await.unwrap(), &mut rng, |stream| async { @@ -436,7 +436,7 @@ async fn executor() -> Executor { let mut rng = Xorshift64::from(simple_seed()); Executor::connect( &Config::from_uri(&uri).unwrap(), - ExecutorBuffer::with_default_params(&mut rng).unwrap(), + ExecutorBuffer::new(usize::MAX, &mut rng), &mut rng, TcpStream::connect(uri.hostname_with_implied_port()).await.unwrap(), ) diff --git a/wtx/src/database/client/postgres/message.rs b/wtx/src/database/client/postgres/message.rs index 3f95e629..b4abfe79 100644 --- a/wtx/src/database/client/postgres/message.rs +++ b/wtx/src/database/client/postgres/message.rs @@ -42,7 +42,7 @@ pub(crate) enum MessageTy<'bytes> { /// Notification response. NotificationResponse, /// Parameters of a query. - ParameterDescription(&'bytes [u8]), + ParameterDescription(u16, &'bytes [u8]), /// Parameter status report. ParameterStatus(&'bytes [u8], &'bytes [u8]), /// Parse request was successful. @@ -52,7 +52,7 @@ pub(crate) enum MessageTy<'bytes> { /// Backend is ready to process another query. ReadyForQuery, /// Single row data. - RowDescription(&'bytes [u8]), + RowDescription(u16, &'bytes [u8]), } impl<'bytes> TryFrom<(&mut ConnectionState, &'bytes [u8])> for MessageTy<'bytes> { @@ -100,13 +100,17 @@ impl<'bytes> TryFrom<(&mut ConnectionState, &'bytes [u8])> for MessageTy<'bytes> let (name, value) = rslt().ok_or(PostgresError::UnexpectedDatabaseMessageBytes)?; Self::ParameterStatus(name, value) } - [b'T', _, _, _, _, _a, _b, rest @ ..] => Self::RowDescription(rest), + [b'T', _, _, _, _, a, b, rest @ ..] => { + Self::RowDescription(u16::from_be_bytes([*a, *b]), rest) + } [b'Z', _, _, _, _, _] => Self::ReadyForQuery, [b'c', ..] => Self::CopyDone, [b'd', ..] => Self::CopyData, [b'n', ..] => Self::NoData, [b's', ..] => Self::PortalSuspended, - [b't', _, _, _, _, _a, _b, rest @ ..] => Self::ParameterDescription(rest), + [b't', _, _, _, _, a, b, rest @ ..] => { + Self::ParameterDescription(u16::from_be_bytes([*a, *b]), rest) + } _ => { return Err( PostgresError::UnexpectedValueFromBytes { expected: type_name::() }.into(), diff --git a/wtx/src/database/client/postgres/postgres_error.rs b/wtx/src/database/client/postgres/postgres_error.rs index 6a21d203..461ab604 100644 --- a/wtx/src/database/client/postgres/postgres_error.rs +++ b/wtx/src/database/client/postgres/postgres_error.rs @@ -7,6 +7,8 @@ pub enum PostgresError { DecodingError, /// There are no bytes left to build a `DbError` InsufficientDbErrorBytes, + /// Invalid IP format + InvalidIpFormat, /// JSONB is the only supported JSON format InvalidJsonFormat, /// Postgres does not support large unsigned integers. For example, `u8` can only be stored diff --git a/wtx/src/database/client/postgres/record.rs b/wtx/src/database/client/postgres/record.rs index ab44278e..9940bc49 100644 --- a/wtx/src/database/client/postgres/record.rs +++ b/wtx/src/database/client/postgres/record.rs @@ -52,7 +52,8 @@ impl<'exec, E> Record<'exec, E> { let initial_value_offset = bytes_range.start; let mut curr_value_offset = bytes_range.start; - match (bytes.get(bytes_range), values_len) { + let local_bytes = bytes.get(bytes_range); + match (local_bytes, values_len) { (Some([a, b, c, d, rest @ ..]), 1..=u16::MAX) => { bytes = rest; fun(&mut curr_value_offset, [*a, *b, *c, *d])?; @@ -114,7 +115,7 @@ where None } else { let begin = range.start.wrapping_sub(self.initial_value_offset); - let column = match self.stmt.columns.get(idx) { + let column = match self.stmt.column(idx) { None => return _unlikely_elem(None), Some(elem) => elem, }; @@ -131,7 +132,7 @@ where impl<'exec, E> ValueIdent> for str { #[inline] fn idx(&self, input: &Record<'exec, E>) -> Option { - input.stmt.columns.iter().position(|column| column.name.as_str() == self) + input.stmt.columns().position(|column| column.name.as_str() == self) } } @@ -177,7 +178,7 @@ mod tests { client::postgres::{ statements::Statement, tests::{column0, column1, column2}, - DecodeValue, Record, + DecodeValue, Record, Ty, }, Record as _, }, @@ -188,9 +189,9 @@ mod tests { #[test] fn returns_correct_values() { let bytes = &[0, 0, 0, 1, 1, 0, 0, 0, 2, 2, 3, 0, 0, 0, 1, 4]; - let columns = &[column0(), column1(), column2()]; + let values = &[(column0(), Ty::Any), (column1(), Ty::Any), (column2(), Ty::Any)]; let mut values_bytes_offsets = Vector::new(); - let stmt = Statement::new(columns, &[]); + let stmt = Statement::new(3, 0, values); let record = Record::::parse( bytes, 0..bytes.len(), diff --git a/wtx/src/database/client/postgres/records.rs b/wtx/src/database/client/postgres/records.rs index d15bb7f4..107de3cd 100644 --- a/wtx/src/database/client/postgres/records.rs +++ b/wtx/src/database/client/postgres/records.rs @@ -91,7 +91,7 @@ mod tests { client::postgres::{ statements::Statement, tests::{column0, column1, column2}, - DecodeValue, Record, Records, + DecodeValue, Record, Records, Ty, }, Record as _, Records as _, }, @@ -101,8 +101,8 @@ mod tests { #[test] fn returns_correct_values() { let bytes = &[0, 0, 0, 2, 1, 2, 0, 0, 0, 2, 3, 4, 9, 9, 9, 0, 1, 0, 0, 0, 4, 5, 6, 7, 8]; - let columns = &[column0(), column1(), column2()]; - let stmt = Statement::new(columns, &[]); + let values = &[(column0(), Ty::Any), (column1(), Ty::Any), (column2(), Ty::Any)]; + let stmt = Statement::new(3, 0, values); let mut records_values_offsets = Vector::new(); let mut values_bytes_offsets = Vector::new(); assert_eq!( diff --git a/wtx/src/database/client/postgres/statements.rs b/wtx/src/database/client/postgres/statements.rs index df94ef55..03a404fe 100644 --- a/wtx/src/database/client/postgres/statements.rs +++ b/wtx/src/database/client/postgres/statements.rs @@ -1,378 +1,186 @@ +mod column; +mod statement; +mod statement_builder; +mod statements_misc; + use crate::{ - database::{client::postgres::ty::Ty, Identifier}, - misc::{Rng, _random_state, _unreachable}, + database::client::postgres::ty::Ty, + misc::{BlocksDeque, Rng, _random_state}, }; -use alloc::collections::VecDeque; +pub(crate) use column::Column; use foldhash::fast::FixedState; use hashbrown::HashMap; - -const AVG_STMT_COLUMNS_LEN: usize = 4; -const AVG_STMT_PARAMS_LEN: usize = 4; -const DFLT_MAX_STMTS: usize = 128; -const INITIAL_ELEMENTS_CAP: usize = 8; -const NUM_OF_ELEMENTS_TO_REMOVE_WHEN_FULL: u8 = 8; +pub(crate) use statement::Statement; +pub(crate) use statement_builder::StatementBuilder; +pub(crate) use statements_misc::StatementsMisc; /// Statements #[derive(Debug)] pub struct Statements { - columns: VecDeque, - columns_start: usize, - info_by_cmd_hash: HashMap, - info_by_cmd_hash_start: usize, - info: VecDeque, max_stmts: usize, - num_of_elements_to_remove_when_full: u8, - params: VecDeque, - params_start: usize, rs: FixedState, + stmts: BlocksDeque<(Column, Ty), StatementsMisc>, + stmts_indcs: HashMap, } impl Statements { + #[inline] pub(crate) fn new(max_stmts: usize, rng: RNG) -> Self where RNG: Rng, { Self { - columns: VecDeque::with_capacity(INITIAL_ELEMENTS_CAP.saturating_mul(AVG_STMT_COLUMNS_LEN)), - columns_start: 0, - info_by_cmd_hash: HashMap::with_capacity(INITIAL_ELEMENTS_CAP), - info_by_cmd_hash_start: 0, - info: VecDeque::with_capacity(INITIAL_ELEMENTS_CAP), - max_stmts, - num_of_elements_to_remove_when_full: NUM_OF_ELEMENTS_TO_REMOVE_WHEN_FULL, - params: VecDeque::with_capacity(INITIAL_ELEMENTS_CAP.saturating_mul(AVG_STMT_PARAMS_LEN)), - params_start: 0, + max_stmts: max_stmts.max(1), rs: _random_state(rng), + stmts: BlocksDeque::new(), + stmts_indcs: HashMap::new(), } } - pub(crate) fn _empty() -> Self { - Self { - columns: VecDeque::new(), - columns_start: 0, - info: VecDeque::new(), - info_by_cmd_hash: HashMap::new(), - info_by_cmd_hash_start: 0, - rs: FixedState::with_seed(0), - max_stmts: 0, - num_of_elements_to_remove_when_full: 0, - params: VecDeque::new(), - params_start: 0, - } - } - - pub(crate) fn with_default_params(rng: &mut RNG) -> Self + #[inline] + pub(crate) fn with_capacity( + columns: usize, + max_stmts: usize, + rng: RNG, + stmts: usize, + ) -> crate::Result where RNG: Rng, { - Self::new(DFLT_MAX_STMTS, rng) - } - - pub(crate) fn clear(&mut self) { - let Self { - columns, - columns_start, - rs: _, - info_by_cmd_hash, - info_by_cmd_hash_start, - info, - max_stmts: _, - num_of_elements_to_remove_when_full: _, - params, - params_start, - } = self; - columns.clear(); - *columns_start = 0; - info_by_cmd_hash.clear(); - *info_by_cmd_hash_start = 0; - info.clear(); - params.clear(); - *params_start = 0; - } - - pub(crate) fn get_by_stmt_hash(&self, stmt_hash: u64) -> Option> { - let mut info_idx = *self.info_by_cmd_hash.get(&stmt_hash)?; - info_idx = info_idx.wrapping_sub(self.info_by_cmd_hash_start); - let info_slice_opt = self.info.as_slices().0.get(..=info_idx); - let (columns_range, params_range) = match info_slice_opt { - None | Some([]) => _unreachable(), - Some([a]) => ( - 0..a.columns_offset.wrapping_sub(self.columns_start), - 0..a.params_offset.wrapping_sub(self.params_start), - ), - Some([.., a, b]) => ( - { - let start = a.columns_offset.wrapping_sub(self.columns_start); - let end = b.columns_offset.wrapping_sub(self.columns_start); - start..end - }, - { - let start = a.params_offset.wrapping_sub(self.params_start); - let end = b.params_offset.wrapping_sub(self.params_start); - start..end - }, - ), - }; - let columns = self.columns.as_slices().0; - let params = self.params.as_slices().0; - if let (Some(a), Some(b)) = (columns.get(columns_range), params.get(params_range)) { - Some(Statement::new(a, b)) - } else { - _unreachable(); - } - } - - pub(crate) fn hasher_mut(&mut self) -> &mut FixedState { - &mut self.rs + Ok(Self { + max_stmts: max_stmts.max(1), + rs: _random_state(rng), + stmts: BlocksDeque::with_capacity(stmts, columns)?, + stmts_indcs: HashMap::with_capacity(stmts), + }) } - pub(crate) fn push(&mut self, stmt_hash: u64) -> PushRslt<'_> { - if self.info_by_cmd_hash.get(&stmt_hash).is_some() { - #[expect(clippy::unwrap_used, reason = "borrow checker woes")] - return PushRslt::Stmt(self.get_by_stmt_hash(stmt_hash).unwrap()); - } - if self.info.len() >= self.max_stmts { - let remove = usize::from(self.num_of_elements_to_remove_when_full).min(self.max_stmts / 2); - for _ in 0..remove { - self.remove_first_stmt(); + #[inline] + pub(crate) fn builder(&mut self) -> StatementBuilder<'_> { + if self.stmts.blocks_len() >= self.max_stmts { + let to_remove = (self.max_stmts / 2).max(1); + for _ in 0..to_remove { + let _ = self.stmts.pop_front(); } + self.stmts_indcs.retain(|_, value| { + if *value < to_remove { + return false; + } + *value = value.wrapping_sub(to_remove); + true + }) } - PushRslt::Builder(StatementBuilder { columns_len: 0, params_len: 0, stmt_hash, stmts: self }) - } - - fn remove_first_stmt(&mut self) { - let Some(info) = self.info.pop_front() else { - return; - }; - - let columns_len = info.columns_offset.wrapping_sub(self.columns_start); - for _ in 0..columns_len { - let _ = self.columns.pop_front(); - } - self.columns_start = self.columns_start.wrapping_add(columns_len); - - let params_len = info.params_offset.wrapping_sub(self.params_start); - for _ in 0..params_len { - let _ = self.params.pop_front(); - } - self.params_start = self.params_start.wrapping_add(params_len); - - let _ = self.info_by_cmd_hash.remove(&info.stmt_hash); - self.info_by_cmd_hash_start = self.info_by_cmd_hash_start.wrapping_add(1); + StatementBuilder::new(&mut self.stmts, &mut self.stmts_indcs) } -} -#[derive(Debug)] -pub(crate) enum PushRslt<'stmts> { - Builder(StatementBuilder<'stmts>), - Stmt(Statement<'stmts>), -} - -#[cfg_attr(test, derive(Clone))] -#[derive(Debug, Eq, PartialEq)] -pub(crate) struct Column { - pub(crate) name: Identifier, - pub(crate) ty: Ty, -} - -#[derive(Clone, Debug, Default, Eq, PartialEq)] -pub(crate) struct Statement<'stmts> { - pub(crate) columns: &'stmts [Column], - pub(crate) params: &'stmts [Ty], -} - -impl<'stmts> Statement<'stmts> { - pub(crate) const fn new(columns: &'stmts [Column], params: &'stmts [Ty]) -> Self { - Self { columns, params } + #[inline] + pub(crate) fn clear(&mut self) { + let Self { max_stmts: _, rs: _, stmts, stmts_indcs } = self; + stmts.clear(); + stmts_indcs.clear(); } -} - -#[derive(Debug)] -pub(crate) struct StatementBuilder<'stmts> { - columns_len: usize, - params_len: usize, - stmt_hash: u64, - stmts: &'stmts mut Statements, -} -impl<'stmts> StatementBuilder<'stmts> { - // Returning `&'stmts mut Statements` because of borrow checker limitations. - pub(crate) fn finish(self) -> &'stmts mut Statements { - let (last_columns_offset, last_params_offset) = self - .stmts - .info - .as_slices() - .0 - .last() - .map_or((self.stmts.columns_start, self.stmts.params_start), |el| { - (el.columns_offset, el.params_offset) - }); - let _ = self.stmts.info_by_cmd_hash.insert( - self.stmt_hash, - self.stmts.info_by_cmd_hash_start.wrapping_add(self.stmts.info.len()), - ); - self.stmts.info.push_back(StatementInfo { - columns_offset: last_columns_offset.wrapping_add(self.columns_len), - params_offset: last_params_offset.wrapping_add(self.params_len), - stmt_hash: self.stmt_hash, - }); - self.stmts + #[inline] + pub(crate) fn get_by_idx(&self, idx: usize) -> Option> { + let stmt = self.stmts.get(idx)?; + Some(Statement::new(stmt.misc.columns_len, stmt.misc.types_len, stmt.data)) } - pub(crate) fn push_column(&mut self, column: Column) { - self.stmts.columns.push_back(column); - self.columns_len = self.columns_len.wrapping_add(1); + #[inline] + pub(crate) fn get_by_stmt_hash(&self, stmt_hash: u64) -> Option> { + self.get_by_idx(*self.stmts_indcs.get(&stmt_hash)?) } - pub(crate) fn push_param(&mut self, param: Ty) { - self.stmts.params.push_back(param); - self.params_len = self.params_len.wrapping_add(1); + #[inline] + pub(crate) fn hasher_mut(&mut self) -> &mut FixedState { + &mut self.rs } } -#[derive(Debug, Eq, PartialEq)] -pub(crate) struct StatementInfo { - pub(crate) columns_offset: usize, - pub(crate) params_offset: usize, - pub(crate) stmt_hash: u64, -} - #[cfg(test)] mod tests { use crate::{ database::client::postgres::{ - statements::{Column, PushRslt, Statement}, + statements::StatementsMisc, tests::{column0, column1, column2, column3}, ty::Ty, Statements, }, - misc::{simple_seed, Vector, Xorshift64}, + misc::{simple_seed, Xorshift64}, }; - #[test] - fn stmt_if_duplicated() { - let stmt_hash = 123; - let mut stmts = Statements::new(100, &mut Xorshift64::from(simple_seed())); - let PushRslt::Builder(builder) = stmts.push(stmt_hash) else { panic!() }; - let _ = builder.finish(); - let PushRslt::Stmt(_) = stmts.push(stmt_hash) else { panic!() }; - } - + // FIXME(MIRI): The modification of the vector's length makes MIRI think that there is an + // invalid pointer using stacked borrows. + // + // | A | B | | <- Push back one block of 2 elements. Length is 2 + // | A | B | C | <- Push back one block of 1 element. Length is 3 + // | | | C | <- Pop front one block. Length is 1 + // + // Such behaviour does not occur with "miri-tree-borrows". + #[cfg_attr(miri, ignore)] #[test] fn two_statements() { let mut stmts = Statements::new(2, &mut Xorshift64::from(simple_seed())); - stmts.num_of_elements_to_remove_when_full = 1; let stmt_id0 = 123; - let PushRslt::Builder(mut builder) = stmts.push(stmt_id0) else { panic!() }; - builder.push_column(column0()); - builder.push_column(column1()); - builder.push_param(Ty::Int2); - let _ = builder.finish(); - assert_stmts( - AssertStatements { - columns: &[column0(), column1()], - columns_offset_start: 0, - info: &[(2, 1)], - info_by_cmd_hash: &[0], - params: &[Ty::Int2], - params_offset_start: 0, - }, - &stmts, - ); - assert_eq!( - stmts.get_by_stmt_hash(stmt_id0), - Some(Statement::new(&[column0(), column1()], &[Ty::Int2])) - ); + let mut builder = stmts.builder(); + let _ = builder.expand(2).unwrap(); + builder.inserted_elements()[0] = (column0(), Ty::Int2); + builder.inserted_elements()[1] = (column1(), Ty::Int2); + let _ = builder.build(stmt_id0, StatementsMisc::new(2, 1)).unwrap(); + { + let stmt = stmts.get_by_stmt_hash(stmt_id0).unwrap(); + assert_eq!(stmt.columns().count(), 2); + assert_eq!(stmt.column(0).unwrap(), &column0()); + assert_eq!(stmt.column(1).unwrap(), &column1()); + assert_eq!(stmt.tys().count(), 1); + assert_eq!(stmt.ty(0).unwrap(), Ty::Int2); + } let stmt_id1 = 456; - let PushRslt::Builder(mut builder) = stmts.push(stmt_id1) else { panic!() }; - builder.push_column(column2()); - builder.push_param(Ty::Int4); - let _ = builder.finish(); - assert_stmts( - AssertStatements { - columns: &[column0(), column1(), column2()], - columns_offset_start: 0, - info: &[(2, 1), (3, 2)], - info_by_cmd_hash: &[0, 1], - params: &[Ty::Int2, Ty::Int4], - params_offset_start: 0, - }, - &stmts, - ); - assert_eq!( - stmts.get_by_stmt_hash(stmt_id0), - Some(Statement::new(&[column0(), column1()], &[Ty::Int2])) - ); - assert_eq!(stmts.get_by_stmt_hash(stmt_id1), Some(Statement::new(&[column2()], &[Ty::Int4]))); + let mut builder = stmts.builder(); + let _ = builder.expand(1).unwrap(); + builder.inserted_elements()[0] = (column2(), Ty::Int4); + let _ = builder.build(stmt_id1, StatementsMisc::new(1, 1)).unwrap(); + { + let stmt = stmts.get_by_stmt_hash(stmt_id0).unwrap(); + assert_eq!(stmt.columns().count(), 2); + assert_eq!(stmt.column(0).unwrap(), &column0()); + assert_eq!(stmt.column(1).unwrap(), &column1()); + assert_eq!(stmt.tys().count(), 1); + assert_eq!(stmt.ty(0).unwrap(), Ty::Int2); + } + { + let stmt = stmts.get_by_stmt_hash(stmt_id1).unwrap(); + assert_eq!(stmt.columns().count(), 1); + assert_eq!(stmt.column(0).unwrap(), &column2()); + assert_eq!(stmt.tys().count(), 1); + assert_eq!(stmt.ty(0).unwrap(), Ty::Int4); + } let stmt_id2 = 789; - let PushRslt::Builder(mut builder) = stmts.push(stmt_id2) else { panic!() }; - builder.push_column(column3()); - let _ = builder.finish(); - assert_stmts( - AssertStatements { - columns: &[column2(), column3()], - columns_offset_start: 2, - info: &[(3, 2), (4, 2)], - info_by_cmd_hash: &[1, 2], - params: &[Ty::Int4], - params_offset_start: 1, - }, - &stmts, - ); + let mut builder = stmts.builder(); + let _ = builder.expand(1).unwrap(); + builder.inserted_elements()[0].0 = column3(); + let _ = builder.build(stmt_id2, StatementsMisc::new(1, 0)).unwrap(); assert_eq!(stmts.get_by_stmt_hash(stmt_id0), None); - assert_eq!(stmts.get_by_stmt_hash(stmt_id1), Some(Statement::new(&[column2()], &[Ty::Int4]))); - assert_eq!(stmts.get_by_stmt_hash(stmt_id2), Some(Statement::new(&[column3()], &[]))); + { + let stmt = stmts.get_by_stmt_hash(stmt_id1).unwrap(); + assert_eq!(stmt.columns().count(), 1); + assert_eq!(stmt.column(0).unwrap(), &column2()); + assert_eq!(stmt.tys().count(), 1); + assert_eq!(stmt.ty(0).unwrap(), Ty::Int4); + } + { + let stmt = stmts.get_by_stmt_hash(stmt_id2).unwrap(); + assert_eq!(stmt.columns().count(), 1); + assert_eq!(stmt.column(0).unwrap(), &column3()); + assert_eq!(stmt.tys().count(), 0); + } stmts.clear(); - assert_stmts( - AssertStatements { - columns: &[], - columns_offset_start: 0, - info: &[], - info_by_cmd_hash: &[], - params: &[], - params_offset_start: 0, - }, - &stmts, - ); assert_eq!(stmts.get_by_stmt_hash(stmt_id0), None); assert_eq!(stmts.get_by_stmt_hash(stmt_id1), None); assert_eq!(stmts.get_by_stmt_hash(stmt_id2), None); } - - #[track_caller] - fn assert_stmts(cs: AssertStatements<'_>, stmts: &Statements) { - assert_eq!(stmts.columns.as_slices().0, cs.columns); - assert_eq!(stmts.columns.as_slices().1, &[]); - assert_eq!(stmts.columns_start, cs.columns_offset_start); - assert_eq!( - Vector::from_iter(stmts.info.iter().map(|el| (el.columns_offset, el.params_offset))) - .unwrap() - .as_slice(), - cs.info - ); - assert_eq!( - &{ - let mut vec = Vector::from_iter(stmts.info_by_cmd_hash.iter().map(|el| *el.1)).unwrap(); - vec.sort(); - vec - }[..], - cs.info_by_cmd_hash - ); - assert_eq!(stmts.params.as_slices().0, cs.params); - assert_eq!(stmts.params.as_slices().1, &[]); - assert_eq!(stmts.params_start, cs.params_offset_start); - } - - struct AssertStatements<'data> { - columns: &'data [Column], - columns_offset_start: usize, - info: &'data [(usize, usize)], - info_by_cmd_hash: &'data [usize], - params: &'data [Ty], - params_offset_start: usize, - } } diff --git a/wtx/src/database/client/postgres/statements/column.rs b/wtx/src/database/client/postgres/statements/column.rs new file mode 100644 index 00000000..be66e15a --- /dev/null +++ b/wtx/src/database/client/postgres/statements/column.rs @@ -0,0 +1,14 @@ +use crate::database::{client::postgres::Ty, Identifier}; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct Column { + pub(crate) name: Identifier, + pub(crate) ty: Ty, +} + +impl Column { + #[inline] + pub(crate) fn new(name: Identifier, ty: Ty) -> Self { + Self { name, ty } + } +} diff --git a/wtx/src/database/client/postgres/statements/statement.rs b/wtx/src/database/client/postgres/statements/statement.rs new file mode 100644 index 00000000..4ba23ed7 --- /dev/null +++ b/wtx/src/database/client/postgres/statements/statement.rs @@ -0,0 +1,48 @@ +use crate::database::client::postgres::{statements::column::Column, Ty}; + +/// ```sql +/// SELECT a,b,c,d FROM table WHERE e = $1 AND f = $2 +/// ``` +/// +/// The columns are "a", "b", "c", "d" and the types are "$1" and "$2". +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub(crate) struct Statement<'stmts> { + columns_len: usize, + tys_len: usize, + values: &'stmts [(Column, Ty)], +} + +impl<'stmts> Statement<'stmts> { + #[inline] + pub(crate) const fn new( + columns_len: usize, + tys_len: usize, + values: &'stmts [(Column, Ty)], + ) -> Self { + Self { columns_len, tys_len, values } + } + + #[inline] + pub(crate) fn column(&self, idx: usize) -> Option<&Column> { + let columns = self.values.get(..self.columns_len)?; + Some(&columns.get(idx)?.0) + } + + #[inline] + pub(crate) fn columns(&self) -> impl Iterator { + let columns = self.values.get(..self.columns_len).unwrap_or_default(); + columns.iter().map(|el| &el.0) + } + + #[cfg(test)] + #[inline] + pub(crate) fn ty(&self, idx: usize) -> Option { + Some(self.values.get(..self.tys_len)?.get(idx)?.1) + } + + #[cfg(test)] + #[inline] + pub(crate) fn tys(&self) -> impl Iterator { + self.values.get(..self.tys_len).unwrap_or_default().iter().map(|el| &el.1) + } +} diff --git a/wtx/src/database/client/postgres/statements/statement_builder.rs b/wtx/src/database/client/postgres/statements/statement_builder.rs new file mode 100644 index 00000000..dc1c6b4a --- /dev/null +++ b/wtx/src/database/client/postgres/statements/statement_builder.rs @@ -0,0 +1,53 @@ +use crate::{ + database::{ + client::postgres::{ + statements::{column::Column, StatementsMisc}, + Ty, + }, + Identifier, + }, + misc::{BlocksDeque, BlocksDequeBuilder, BufferMode}, +}; +use hashbrown::HashMap; + +#[derive(Debug)] +pub(crate) struct StatementBuilder<'stmts> { + pub(crate) builder: BlocksDequeBuilder<'stmts, (Column, Ty), StatementsMisc, true>, + pub(crate) curr_len: usize, + pub(crate) indcs: &'stmts mut HashMap, +} + +impl<'stmts> StatementBuilder<'stmts> { + #[inline] + pub(crate) fn new( + stmts: &'stmts mut BlocksDeque<(Column, Ty), StatementsMisc>, + stmts_indcs: &'stmts mut HashMap, + ) -> Self { + let curr_len = stmts.blocks_len(); + Self { builder: stmts.builder_back(), curr_len, indcs: stmts_indcs } + } + + #[inline] + pub(crate) fn build(mut self, hash: u64, mut sm: StatementsMisc) -> crate::Result { + let len = self.builder.inserted_elements().len(); + sm.columns_len = sm.columns_len.min(len); + sm.types_len = sm.types_len.min(len); + let _ = self.indcs.insert(hash, self.curr_len); + self.builder.build(sm)?; + Ok(self.curr_len) + } + + #[inline] + pub(crate) fn expand(&mut self, additional: usize) -> crate::Result<&mut Self> { + let _ = self.builder.expand( + BufferMode::Additional(additional), + (Column::new(Identifier::new(), Ty::Any), Ty::Any), + )?; + Ok(self) + } + + #[inline] + pub(crate) fn inserted_elements(&mut self) -> &mut [(Column, Ty)] { + self.builder.inserted_elements() + } +} diff --git a/wtx/src/database/client/postgres/statements/statements_misc.rs b/wtx/src/database/client/postgres/statements/statements_misc.rs new file mode 100644 index 00000000..dc114ea6 --- /dev/null +++ b/wtx/src/database/client/postgres/statements/statements_misc.rs @@ -0,0 +1,12 @@ +#[derive(Debug)] +pub(crate) struct StatementsMisc { + pub(crate) columns_len: usize, + pub(crate) types_len: usize, +} + +impl StatementsMisc { + #[inline] + pub(crate) fn new(columns_len: usize, types_len: usize) -> Self { + Self { columns_len, types_len } + } +} diff --git a/wtx/src/database/client/postgres/struct_encoder.rs b/wtx/src/database/client/postgres/struct_encoder.rs index d718dd68..d26b5a02 100644 --- a/wtx/src/database/client/postgres/struct_encoder.rs +++ b/wtx/src/database/client/postgres/struct_encoder.rs @@ -63,6 +63,7 @@ impl<'buffer, 'fbw, 'vec, E> Drop for StructEncoder<'buffer, 'fbw, 'vec, E> { } } +#[inline] fn write_len(ev: &mut EncodeValue<'_, '_>, start: usize, len: u32) { let Some([a, b, c, d, ..]) = ev.fbw()._curr_bytes_mut().get_mut(start..) else { return; diff --git a/wtx/src/database/client/postgres/tys.rs b/wtx/src/database/client/postgres/tys.rs index 84650eed..14d2170a 100644 --- a/wtx/src/database/client/postgres/tys.rs +++ b/wtx/src/database/client/postgres/tys.rs @@ -1,7 +1,7 @@ -macro_rules! proptest { +macro_rules! kani { ($name:ident, $ty:ty) => { - #[cfg(all(feature = "_proptest", test))] - #[test_strategy::proptest] + #[cfg(kani)] + #[kani::proof] fn $name(instance: $ty) { let mut vec = &mut crate::misc::FilledBuffer::_new(); { @@ -296,7 +296,10 @@ mod collections { { #[inline] fn decode(dv: &DecodeValue<'_>) -> Result { - Ok(from_utf8_basic(dv.bytes()).map_err(crate::Error::from)?.into()) + match from_utf8_basic(dv.bytes()).map_err(crate::Error::from) { + Ok(elem) => Ok(elem.into()), + Err(err) => Err(err.into()), + } } } impl Encode> for String @@ -315,12 +318,12 @@ mod collections { { const TY: Ty = Ty::Text; } - proptest!(string, String); + kani!(string, String); } mod ip { use crate::database::{ - client::postgres::{DecodeValue, EncodeValue, Postgres, Ty}, + client::postgres::{DecodeValue, EncodeValue, Postgres, PostgresError, Ty}, Decode, Encode, Typed, }; use core::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -334,7 +337,7 @@ mod ip { Ok(match dv.bytes() { [2, ..] => IpAddr::V4(Ipv4Addr::decode(dv)?), [3, ..] => IpAddr::V6(Ipv6Addr::decode(dv)?), - _ => panic!(), + _ => return Err(E::from(PostgresError::InvalidIpFormat.into())), }) } } @@ -366,7 +369,7 @@ mod ip { #[inline] fn decode(dv: &DecodeValue<'exec>) -> Result { let [2, 32, 0, 4, e, f, g, h] = dv.bytes() else { - panic!(); + return Err(E::from(PostgresError::InvalidIpFormat.into())); }; Ok(Ipv4Addr::from([*e, *f, *g, *h])) } @@ -396,7 +399,7 @@ mod ip { #[inline] fn decode(dv: &DecodeValue<'exec>) -> Result { let [3, 128, 0, 16, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t] = dv.bytes() else { - panic!(); + return Err(E::from(PostgresError::InvalidIpFormat.into())); }; Ok(Ipv6Addr::from([*e, *f, *g, *h, *i, *j, *k, *l, *m, *n, *o, *p, *q, *r, *s, *t])) } @@ -588,8 +591,8 @@ mod primitives { const TY: Ty = Ty::Bool; } - proptest!(bool_true, bool); - proptest!(bool_false, bool); + kani!(bool_true, bool); + kani!(bool_false, bool); macro_rules! impl_integer_from_array { ($instance:expr, [$($elem:ident),+], ($signed:ident, $signed_pg_ty:expr), ($unsigned:ident, $unsigned_pg_ty:expr)) => { @@ -789,7 +792,7 @@ mod rust_decimal { const TY: Ty = Ty::Numeric; } - proptest!(rust_decimal, Decimal); + kani!(rust_decimal, Decimal); } #[cfg(feature = "serde_json")] diff --git a/wtx/src/database/schema_manager/integration_tests.rs b/wtx/src/database/schema_manager/integration_tests.rs index d2d05cd7..5ab4ddce 100644 --- a/wtx/src/database/schema_manager/integration_tests.rs +++ b/wtx/src/database/schema_manager/integration_tests.rs @@ -47,7 +47,7 @@ macro_rules! create_integration_tests { let mut rng = Xorshift64::from(crate::misc::simple_seed()); crate::database::client::postgres::Executor::connect( &config, - crate::database::client::postgres::ExecutorBuffer::with_default_params(&mut rng).unwrap(), + crate::database::client::postgres::ExecutorBuffer::new(usize::MAX, &mut rng), &mut rng, stream, ).await.unwrap() diff --git a/wtx/src/database/schema_manager/misc.rs b/wtx/src/database/schema_manager/misc.rs index 725e3daf..c5535d26 100644 --- a/wtx/src/database/schema_manager/misc.rs +++ b/wtx/src/database/schema_manager/misc.rs @@ -28,9 +28,8 @@ use { misc::{ArrayVector, FromRadix10, Vector}, }, alloc::string::String, - core::cmp::Ordering, + core::{cmp::Ordering, fmt::Write}, std::{ - fmt::Write, fs::{read_to_string, DirEntry, File}, io::Read, path::{Path, PathBuf}, diff --git a/wtx/src/error.rs b/wtx/src/error.rs index d4d99e22..f1c50b02 100644 --- a/wtx/src/error.rs +++ b/wtx/src/error.rs @@ -1,5 +1,5 @@ use crate::misc::{ - ArrayStringError, ArrayVectorError, BlocksQueueError, FromRadix10Error, QueueError, VectorError, + ArrayStringError, ArrayVectorError, BlocksDequeError, DequeueError, FromRadix10Error, VectorError, }; #[allow(unused_imports, reason = "Depends on the selection of features")] use alloc::boxed::Box; @@ -131,7 +131,7 @@ pub enum Error { // ArrayStringError(ArrayStringError), ArrayVectorError(ArrayVectorError), - BlocksQueueError(BlocksQueueError), + BlocksQueueError(BlocksDequeError), #[cfg(feature = "client-api-framework")] ClientApiFrameworkError(crate::client_api_framework::ClientApiFrameworkError), #[cfg(feature = "database")] @@ -147,7 +147,7 @@ pub enum Error { Http2ErrorReset(crate::http2::Http2ErrorCode, Option, u32), #[cfg(feature = "postgres")] PostgresError(crate::database::client::postgres::PostgresError), - QueueError(QueueError), + QueueError(DequeueError), #[cfg(feature = "schema-manager")] SchemaManagerError(crate::database::schema_manager::SchemaManagerError), VectorError(VectorError), @@ -443,9 +443,9 @@ impl From for Error { } } -impl From for Error { +impl From for Error { #[inline] - fn from(from: BlocksQueueError) -> Self { + fn from(from: BlocksDequeError) -> Self { Self::BlocksQueueError(from) } } @@ -497,9 +497,9 @@ impl From for Error { } } -impl From for Error { +impl From for Error { #[inline] - fn from(from: QueueError) -> Self { + fn from(from: DequeueError) -> Self { Self::QueueError(from) } } diff --git a/wtx/src/grpc.rs b/wtx/src/grpc.rs index c79217ab..ba2f701e 100644 --- a/wtx/src/grpc.rs +++ b/wtx/src/grpc.rs @@ -9,7 +9,7 @@ mod grpc_status_code; use crate::{data_transformation::dnsn::Serialize, misc::Vector}; pub use client::Client; pub use grpc_manager::GrpcManager; -pub use grpc_res_middleware::GrpcResMiddleware; +pub use grpc_res_middleware::GrpcMiddleware; pub use grpc_status_code::GrpcStatusCode; #[inline] diff --git a/wtx/src/grpc/grpc_res_middleware.rs b/wtx/src/grpc/grpc_res_middleware.rs index 8892628d..fa7cf514 100644 --- a/wtx/src/grpc/grpc_res_middleware.rs +++ b/wtx/src/grpc/grpc_res_middleware.rs @@ -1,26 +1,36 @@ +use core::ops::ControlFlow; + use crate::{ grpc::GrpcManager, http::{ - server_framework::ResMiddleware, Header, KnownHeaderName, Mime, ReqResBuffer, ReqResDataMut, - Response, + server_framework::Middleware, Header, KnownHeaderName, Mime, ReqResBuffer, ReqResDataMut, + Request, Response, StatusCode, }, }; /// Applies gRPC headers #[derive(Debug)] -pub struct GrpcResMiddleware; +pub struct GrpcMiddleware; -impl ResMiddleware> for GrpcResMiddleware +impl Middleware> for GrpcMiddleware where E: From, { + type Aux = (); + + #[inline] + fn aux(&self) -> Self::Aux { + () + } + #[inline] - async fn apply_res_middleware( + async fn req( &self, _: &mut CA, - res: Response<&mut ReqResBuffer>, + _: &mut Self::Aux, + res: &mut Request, sa: &mut GrpcManager, - ) -> Result<(), E> { + ) -> Result, E> { res.rrd.headers_mut().push_from_iter_many([ Header::from_name_and_value( KnownHeaderName::ContentType.into(), @@ -33,6 +43,17 @@ where value: [sa.status_code_mut().number_as_str().as_bytes()].into_iter(), }, ])?; - Ok(()) + Ok(ControlFlow::Continue(())) + } + + #[inline] + async fn res( + &self, + _: &mut CA, + _: &mut Self::Aux, + _: Response<&mut ReqResBuffer>, + _: &mut GrpcManager, + ) -> Result, E> { + Ok(ControlFlow::Continue(())) } } diff --git a/wtx/src/http/cookie.rs b/wtx/src/http/cookie.rs index e52c5a99..3a509b57 100644 --- a/wtx/src/http/cookie.rs +++ b/wtx/src/http/cookie.rs @@ -22,14 +22,11 @@ static FMT4: &str = "%a, %d-%b-%Y %H:%M:%S GMT"; #[inline] pub(crate) fn decrypt( buffer: &mut Vector, - key: &[u8], + key: &[u8; 32], (name, value): (&[u8], &[u8]), ) -> crate::Result<()> { use crate::misc::BufferMode; - use aes_gcm::{ - aead::{generic_array::GenericArray, AeadInPlace}, - Aes256Gcm, Tag, - }; + use aes_gcm::{aead::AeadInPlace, aes::cipher::Array, Aes256Gcm}; use base64::{engine::general_purpose::STANDARD, Engine}; let start = buffer.len(); @@ -56,13 +53,12 @@ pub(crate) fn decrypt( }; rslt }; - ::new(GenericArray::from_slice(key)) - .decrypt_in_place_detached( - GenericArray::from_slice(&nonce), - name, - content, - Tag::from_slice(&tag), - )?; + ::new(&Array(*key)).decrypt_in_place_detached( + &Array(nonce), + name, + content, + &Array(tag), + )?; let idx = start.wrapping_sub(TAG_LEN); let _ = _shift_copyable_chunks(0, buffer, [NONCE_LEN..idx]); buffer.truncate(idx.wrapping_sub(NONCE_LEN)); @@ -73,7 +69,7 @@ pub(crate) fn decrypt( #[inline] pub(crate) fn encrypt( buffer: &mut Vector, - key: &[u8], + key: &[u8; 32], (name, value): (&[u8], &[u8]), mut rng: RNG, ) -> crate::Result<()> @@ -81,10 +77,7 @@ where RNG: Rng, { use crate::misc::BufferMode; - use aes_gcm::{ - aead::{generic_array::GenericArray, AeadInPlace}, - Aes256Gcm, - }; + use aes_gcm::{aead::AeadInPlace, aes::cipher::Array, Aes256Gcm}; use base64::{engine::general_purpose::STANDARD, Engine}; let start = buffer.len(); @@ -120,9 +113,9 @@ where *a9 = c9; *a10 = c10; *a11 = c11; - let aes = ::new(GenericArray::from_slice(key)); + let aes = ::new(&Array(*key)); let nonce = [*a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7, *a8, *a9, *a10, *a11]; - let tag = aes.encrypt_in_place_detached(GenericArray::from_slice(&nonce), name, content)?; + let tag = aes.encrypt_in_place_detached(&Array(nonce), name, content)?; let [d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15] = tag.into(); *b0 = d0; *b1 = d1; @@ -141,9 +134,8 @@ where *b14 = d14; *b15 = d15; }; - let Some((base64, content)) = - buffer.get_mut(start..).and_then(|el| el.split_at_mut_checked(base64_len)) - else { + let slice_mut = buffer.get_mut(start..).and_then(|el| el.split_at_mut_checked(base64_len)); + let Some((base64, content)) = slice_mut else { return Ok(()); }; let base64_idx = STANDARD.encode_slice(content, base64)?; diff --git a/wtx/src/http/http_error.rs b/wtx/src/http/http_error.rs index b8da2461..af4fb17a 100644 --- a/wtx/src/http/http_error.rs +++ b/wtx/src/http/http_error.rs @@ -3,6 +3,12 @@ use crate::http::{KnownHeaderName, Method}; /// Http error #[derive(Debug)] pub enum HttpError { + /// Client requrested a CORS header that isn't allowed + ForbiddenCorsHeader, + /// Client requrested a CORS method that isn't allowed + ForbiddenCorsMethod, + /// Client requrested a CORS origin that isn't allowed + ForbiddenCorsOrigin, /// The length of a header field must be within a threshold. HeaderFieldIsTooLarge, /// Invalid HTTP/2 or HTTP/3 header diff --git a/wtx/src/http/method.rs b/wtx/src/http/method.rs index da2aac5a..c2689d1e 100644 --- a/wtx/src/http/method.rs +++ b/wtx/src/http/method.rs @@ -23,6 +23,23 @@ _create_enum! { } } +impl Method { + /// An array containing the whole set of variants + pub const ALL: [Self; 9] = [ + Self::Connect, + Self::Delete, + Self::Get, + Self::Head, + Self::Options, + Self::Patch, + Self::Post, + Self::Put, + Self::Trace, + ]; + /// The number of variants + pub const VARIANTS: u8 = 9; +} + #[cfg(feature = "serde")] mod serde { use crate::http::Method; diff --git a/wtx/src/http/server_framework.rs b/wtx/src/http/server_framework.rs index 17c6d3de..422f0e61 100644 --- a/wtx/src/http/server_framework.rs +++ b/wtx/src/http/server_framework.rs @@ -25,7 +25,7 @@ use alloc::sync::Arc; pub use conn_aux::ConnAux; pub use cors_middleware::CorsMiddleware; pub use endpoint::Endpoint; -pub use middleware::{ReqMiddleware, ResMiddleware}; +pub use middleware::Middleware; pub use param_wrappers::*; pub use path_management::PathManagement; pub use path_params::PathParams; @@ -39,23 +39,22 @@ pub use stream_aux::StreamAux; /// Server #[derive(Debug)] -pub struct ServerFramework { +pub struct ServerFramework { _ca_cb: CAC, _cp: ConnParams, _sa_cb: SAC, - _router: Arc>, + _router: Arc>, } -impl ServerFramework +impl ServerFramework where E: From, + M: Middleware, P: PathManagement, - REQM: ReqMiddleware, - RESM: ResMiddleware, SA: StreamAux, { async fn _auto( - mut _as: AutoStream SA::Init, Arc>)>, + mut _as: AutoStream SA::Init, Arc>)>, ) -> Result, E> { let mut stream_aux = SA::req_aux(_as.stream_aux.0(), &mut _as.req)?; #[cfg(feature = "matchit")] diff --git a/wtx/src/http/server_framework/cors_middleware.rs b/wtx/src/http/server_framework/cors_middleware.rs index 06e7b4c9..82e13f98 100644 --- a/wtx/src/http/server_framework/cors_middleware.rs +++ b/wtx/src/http/server_framework/cors_middleware.rs @@ -1,100 +1,169 @@ use crate::{ http::{ - server_framework::{ConnAux, ResMiddleware}, - Header, KnownHeaderName, Method, ReqResBuffer, ReqResDataMut, Response, + server_framework::{ConnAux, Middleware}, + Header, Headers, HttpError, KnownHeaderName, Method, ReqResBuffer, Request, Response, + StatusCode, }, - misc::ArrayVector, + misc::{bytes_split1, ArrayVector, Intersperse, Vector}, }; +use core::ops::ControlFlow; +use hashbrown::HashSet; + +const MAX_HEADEARS: usize = 8; +const MAX_METHODS: usize = Method::VARIANTS as usize; +const MAX_ORIGINS: usize = 2; + +type AllowHeaders = (bool, ArrayVector<&'static str, MAX_HEADEARS>); +type AllowMethods = (bool, ArrayVector); +type AllowOrigins = (bool, ArrayVector<&'static str, MAX_ORIGINS>); +type ExposeHeaders = (bool, ArrayVector<&'static str, MAX_HEADEARS>); /// Cross-origin resource sharing #[derive(Debug)] pub struct CorsMiddleware { allow_credentials: bool, - allow_headers: ArrayVector<&'static str, 8>, - allow_methods: (bool, ArrayVector), - allow_origin: Option<&'static str>, - expose_headers: ArrayVector<&'static str, 8>, + // Many local options, many request/response options. + allow_headers: AllowHeaders, + // Many local options, many request/response options. + allow_methods: AllowMethods, + // Many local options, single request/response option. + allow_origins: AllowOrigins, + // Many local options, many request/response options. + expose_headers: ExposeHeaders, max_age: Option, } +impl ConnAux for CorsMiddleware { + type Init = Self; + + #[inline] + fn conn_aux(init: Self::Init) -> crate::Result { + Ok(init) + } +} + impl CorsMiddleware { /// New empty instance #[inline] pub const fn new() -> Self { Self { allow_credentials: false, - allow_headers: ArrayVector::new(), + allow_headers: (false, ArrayVector::new()), allow_methods: (false, ArrayVector::new()), - allow_origin: None, - expose_headers: ArrayVector::new(), + allow_origins: (false, ArrayVector::new()), + expose_headers: (false, ArrayVector::new()), max_age: None, } } - /// * All request headers allowed. + /// * Credentials are allowed + /// * All request headers allowed (wildcard). /// * All methods are allowed. - /// * All origins are allowed. - /// * All headers are exposed. + /// * All origins are allowed (wildcard). + /// * All headers are exposed (wildcard). + /// * No caching #[inline] #[must_use] pub fn permissive() -> Self { - Self::new().allow_headers(["*"]).allow_methods(true, []).allow_origin("*").expose_headers(["*"]) - } - - /// Like [`Self::permissive`] with the additional feature of allowing credentials. - #[inline] - #[must_use] - pub fn unrestricted() -> Self { - Self::new().allow_credentials().allow_headers(["*"]).allow_methods(true, []).allow_origin("*") + Self { + allow_credentials: true, + allow_headers: (true, ArrayVector::new()), + allow_methods: (false, ArrayVector::from_array(Method::ALL.into())), + allow_origins: (true, ArrayVector::new()), + expose_headers: (true, ArrayVector::new()), + max_age: None, + } } /// #[inline] #[must_use] - pub const fn allow_credentials(mut self) -> Self { + pub fn allow_credentials(mut self) -> Self { self.allow_credentials = true; self } /// + /// + /// Wildcard is only allowed in requests without credentials. #[inline] #[must_use] - pub fn allow_headers(mut self, elem: impl IntoIterator) -> Self { - let _rslt = self.allow_headers.extend_from_iter(elem); + pub fn allow_headers( + mut self, + is_wildcard: bool, + specifics: impl IntoIterator, + ) -> Self { + if is_wildcard { + self.allow_headers.0 = true; + } else { + self.allow_headers.0 = false; + self.allow_headers.1.clear(); + let iter = specifics.into_iter().take(MAX_HEADEARS); + let _rslt = self.allow_headers.1.extend_from_iter(iter); + } self } /// + /// + /// Wildcard is only allowed in requests without credentials. #[inline] #[must_use] pub fn allow_methods( mut self, - is_all: bool, + is_wildcard: bool, specifics: impl IntoIterator, ) -> Self { - if is_all { + if is_wildcard { self.allow_methods.0 = true; } else { self.allow_methods.0 = false; self.allow_methods.1.clear(); - let _rslt = self.allow_methods.1.extend_from_iter(specifics); + let iter = specifics.into_iter().take(MAX_METHODS); + let _rslt = self.allow_methods.1.extend_from_iter(iter); } self } /// + /// + /// Wildcard is only allowed in requests without credentials. #[inline] #[must_use] - pub const fn allow_origin(mut self, elem: &'static str) -> Self { - self.allow_origin = Some(elem); + pub fn allow_origins( + mut self, + is_wildcard: bool, + specifics: impl IntoIterator, + ) -> Self { + if is_wildcard { + self.allow_origins.0 = true; + } else { + self.allow_origins.0 = false; + self.allow_origins.1.clear(); + let iter = specifics.into_iter().take(MAX_ORIGINS); + let _rslt = self.allow_origins.1.extend_from_iter(iter); + } self } /// + /// + /// Wildcard is only allowed in requests without credentials. #[inline] #[must_use] - pub fn expose_headers(mut self, elem: impl IntoIterator) -> Self { - let _rslt = self.expose_headers.extend_from_iter(elem); + pub fn expose_headers( + mut self, + is_wildcard: bool, + specifics: impl IntoIterator, + ) -> Self { + if is_wildcard { + self.expose_headers.0 = true; + } else { + self.expose_headers.0 = false; + self.expose_headers.1.clear(); + let iter = specifics.into_iter().take(MAX_HEADEARS); + let _rslt = self.expose_headers.1.extend_from_iter(iter); + } self } @@ -105,80 +174,282 @@ impl CorsMiddleware { self.max_age = Some(elem); self } -} - -impl ConnAux for CorsMiddleware { - type Init = CorsMiddleware; #[inline] - fn conn_aux(init: Self::Init) -> crate::Result { - Ok(init) + fn allowed_origin<'this>(&'this self, origin: &[u8]) -> Option<&'static str> { + self.allow_origins.1.iter().find(|el| el.as_bytes() == origin).copied() } -} -impl ResMiddleware for CorsMiddleware -where - E: From, -{ #[inline] - async fn apply_res_middleware( - &self, - _: &mut CA, - res: Response<&mut ReqResBuffer>, - _: &mut SA, - ) -> Result<(), E> { - let Self { - allow_credentials, - allow_headers, - allow_methods, - allow_origin, - expose_headers, - max_age, - } = self; - if *allow_credentials { - res.rrd.headers_mut().push_from_iter(Header::from_name_and_value( + fn apply_allow_credentials(allow_credentials: bool, headers: &mut Headers) -> crate::Result<()> { + if allow_credentials { + headers.push_from_iter(Header::from_name_and_value( KnownHeaderName::AccessControlAllowCredentials.into(), ["true".as_bytes()], ))?; } + Ok(()) + } + + #[inline] + fn apply_allow_headers(allow_headers: &[u8], headers: &mut Headers) -> crate::Result<()> { if !allow_headers.is_empty() { - res.rrd.headers_mut().push_from_iter(Header::from_name_and_value( + headers.push_from_iter(Header::from_name_and_value( KnownHeaderName::AccessControlAllowHeaders.into(), - allow_headers.iter().map(|el| el.as_bytes()), + [allow_headers], ))?; } - if allow_methods.0 { - res.rrd.headers_mut().push_from_iter(Header::from_name_and_value( + Ok(()) + } + + #[inline] + fn apply_allow_methods( + (is_wildcard, specifics): &AllowMethods, + headers: &mut Headers, + ) -> crate::Result<()> { + if *is_wildcard { + headers.push_from_iter(Header::from_name_and_value( KnownHeaderName::AccessControlAllowMethods.into(), ["*".as_bytes()], ))?; - } else if !allow_methods.1.is_empty() { - res.rrd.headers_mut().push_from_iter(Header::from_name_and_value( + } else if !specifics.is_empty() { + headers.push_from_iter(Header::from_name_and_value( KnownHeaderName::AccessControlAllowMethods.into(), - allow_methods.1.iter().map(|el| el.strings().custom[0].as_bytes()), + Intersperse::new( + specifics.iter().map(|el| { + let [_, name] = el.strings().custom; + name.as_bytes() + }), + b",", + ), ))?; } else { } - if let Some(elem) = allow_origin { - res.rrd.headers_mut().push_from_iter(Header::from_name_and_value( - KnownHeaderName::AccessControlAllowOrigin.into(), - [elem.as_bytes()], + Ok(()) + } + + #[inline] + fn apply_allow_origin(origin: &[u8], headers: &mut Headers) -> crate::Result<()> { + headers.push_from_iter(Header::from_name_and_value( + KnownHeaderName::AccessControlAllowOrigin.into(), + [origin], + ))?; + Ok(()) + } + + #[inline] + fn apply_expose_headers( + (is_wildcard, specifics): &ExposeHeaders, + headers: &mut Headers, + ) -> crate::Result<()> { + if *is_wildcard { + headers.push_from_iter(Header::from_name_and_value( + KnownHeaderName::AccessControlExposeHeaders.into(), + ["*".as_bytes()], ))?; - } - if !expose_headers.is_empty() { - res.rrd.headers_mut().push_from_iter(Header::from_name_and_value( - KnownHeaderName::AccessControlAllowHeaders.into(), - expose_headers.iter().map(|el| el.as_bytes()), + } else if !specifics.is_empty() { + headers.push_from_iter(Header::from_name_and_value( + KnownHeaderName::AccessControlExposeHeaders.into(), + Intersperse::new(specifics.iter().map(|el| el.as_bytes()), b","), ))?; + } else { } + Ok(()) + } + + #[inline] + fn apply_max_age(max_age: Option, headers: &mut Headers) -> crate::Result<()> { if let Some(elem) = max_age { - res.rrd.headers_mut().push_from_fmt(Header::from_name_and_value( + headers.push_from_fmt(Header::from_name_and_value( KnownHeaderName::AccessControlMaxAge.into(), format_args!("{elem}"), ))?; } Ok(()) } + + #[inline] + async fn apply_normal_response( + &self, + origin: Option<&str>, + headers: &mut Headers, + ) -> crate::Result<()> { + let Self { + allow_credentials, + allow_headers: _, + allow_methods: _, + allow_origins: _, + expose_headers, + max_age: _, + } = self; + Self::apply_allow_credentials(*allow_credentials, headers)?; + if let Some(elem) = origin { + Self::apply_allow_origin(elem.as_bytes(), headers)?; + } + Self::apply_expose_headers(expose_headers, headers)?; + Ok(()) + } + + #[inline] + async fn apply_preflight_response( + &self, + evaluated_allow_headers: &[u8], + evaluated_allow_origin: &[u8], + headers: &mut Headers, + ) -> crate::Result<()> { + let Self { + allow_credentials, + allow_headers: _, + allow_methods, + allow_origins: _, + expose_headers: _, + max_age, + } = self; + Self::apply_allow_credentials(*allow_credentials, headers)?; + Self::apply_allow_headers(evaluated_allow_headers, headers)?; + Self::apply_allow_methods(allow_methods, headers)?; + Self::apply_allow_origin(evaluated_allow_origin, headers)?; + Self::apply_max_age(*max_age, headers)?; + Ok(()) + } + + #[inline] + fn extract_origin<'any>( + opt: Option>, + ) -> crate::Result> { + Ok(opt.ok_or_else(|| HttpError::MissingHeader(KnownHeaderName::Origin))?) + } + + #[inline] + fn manage_preflight_headers<'bytes>( + &self, + acrh: Header<'_, &'bytes [u8]>, + body: &mut Vector, + ) -> crate::Result<()> { + if self.allow_headers.0 { + body.extend_from_copyable_slice(acrh.value).map_err(crate::Error::from)?; + return Ok(()); + } + let mut uniques = HashSet::new(); + for sub_header in bytes_split1(acrh.value, b',') { + let _ = uniques.insert(sub_header.trim_ascii()); + } + let mut matched_headers: usize = 0; + for allow_header in self.allow_headers.1.iter() { + if uniques.contains(allow_header.as_bytes()) { + matched_headers = matched_headers.wrapping_add(1); + } + } + if matched_headers != uniques.len() { + return Err(HttpError::ForbiddenCorsHeader.into()); + } + let mut iter = uniques.iter(); + if let Some(elem) = iter.next() { + body.extend_from_copyable_slice(elem).map_err(crate::Error::from)?; + } + for elem in iter { + let slices = [",".as_bytes(), elem]; + let _ = body.extend_from_copyable_slices(slices).map_err(crate::Error::from)?; + } + Ok(()) + } + + #[inline] + fn manage_preflight_methods(&self, acrm: Header<'_, &[u8]>) -> crate::Result<()> { + if self.allow_methods.0 { + return Ok(()); + } + if !self.allow_methods.1.iter().any(|method| { + let strings = method.strings(); + let [a, b] = strings.custom; + strings.ident.as_bytes() == acrm.value + || a.as_bytes() == acrm.value + || b.as_bytes() == acrm.value + }) { + return Err(HttpError::ForbiddenCorsMethod.into()); + } + Ok(()) + } + + #[inline] + fn manage_preflight_origin( + &self, + body: &mut Vector, + origin: Header<'_, &[u8]>, + ) -> crate::Result<()> { + let actual_origin = if self.allow_origins.0 { + origin.value + } else if let Some(allowed_origin) = self.allowed_origin(origin.value) { + allowed_origin.as_bytes() + } else { + return Err(HttpError::ForbiddenCorsOrigin.into()); + }; + body.extend_from_copyable_slice(actual_origin).map_err(crate::Error::from)?; + Ok(()) + } +} + +impl Middleware for CorsMiddleware +where + E: From, +{ + type Aux = Option<&'static str>; + + #[inline] + fn aux(&self) -> Self::Aux { + None + } + + #[inline] + async fn req( + &self, + _: &mut CA, + mw_aux: &mut Self::Aux, + req: &mut Request, + _: &mut SA, + ) -> Result, E> { + let origin_opt = if req.method == Method::Options { + let [acrh_opt, acrm_opt, origin_opt] = req.rrd.headers.get_many_by_name([ + KnownHeaderName::AccessControlRequestHeaders.into(), + KnownHeaderName::AccessControlRequestMethod.into(), + KnownHeaderName::Origin.into(), + ]); + if let (Some(acrh), Some(acrm)) = (acrh_opt, acrm_opt) { + req.rrd.body.clear(); + self.manage_preflight_headers(acrh, &mut req.rrd.body)?; + self.manage_preflight_methods(acrm)?; + let idx = req.rrd.body.len(); + self.manage_preflight_origin(&mut req.rrd.body, Self::extract_origin(origin_opt)?)?; + let (headers, origin) = req.rrd.body.split_at_checked(idx).unwrap_or_default(); + req.rrd.headers.clear(); + self.apply_preflight_response(headers, origin, &mut req.rrd.headers).await?; + req.rrd.body.clear(); + return Ok(ControlFlow::Break(StatusCode::Ok)); + } else { + origin_opt + } + } else { + req.rrd.headers.get_by_name(KnownHeaderName::Origin.into()) + }; + if self.allow_origins.0 { + *mw_aux = Some("*"); + } else if let Some(origin) = self.allowed_origin(Self::extract_origin(origin_opt)?.value) { + *mw_aux = Some(origin); + } + Ok(ControlFlow::Continue(())) + } + + #[inline] + async fn res( + &self, + _: &mut CA, + mw_aux: &mut Self::Aux, + res: Response<&mut ReqResBuffer>, + _: &mut SA, + ) -> Result, E> { + self.apply_normal_response(*mw_aux, &mut res.rrd.headers).await?; + Ok(ControlFlow::Continue(())) + } } impl Default for CorsMiddleware { diff --git a/wtx/src/http/server_framework/middleware.rs b/wtx/src/http/server_framework/middleware.rs index 90f2eed1..5b3d7e4d 100644 --- a/wtx/src/http/server_framework/middleware.rs +++ b/wtx/src/http/server_framework/middleware.rs @@ -1,161 +1,32 @@ -use crate::{ - http::{ReqResBuffer, Request, Response}, - misc::{FnFut, FnFutWrapper}, -}; -use core::future::Future; +use crate::http::{ReqResBuffer, Request, Response, StatusCode}; +use core::{future::Future, ops::ControlFlow}; /// Request middleware -pub trait ReqMiddleware +pub trait Middleware where E: From, { - /// Modifies or halts requests. - fn apply_req_middleware( - &self, - conn_aux: &mut CA, - req: &mut Request, - stream_aux: &mut SA, - ) -> impl Future>; -} + /// Auxiliary structure + type Aux; -impl ReqMiddleware for &T -where - E: From, - T: for<'any> FnFut< - (&'any mut CA, &'any mut SA, &'any mut Request), - Result = Result<(), E>, - >, -{ - #[inline] - async fn apply_req_middleware( - &self, - conn_aux: &mut CA, - req: &mut Request, - stream_aux: &mut SA, - ) -> Result<(), E> { - self.call((conn_aux, stream_aux, req)).await?; - Ok(()) - } -} + /// Auxiliary structure + fn aux(&self) -> Self::Aux; -impl ReqMiddleware for [T] -where - E: From, - T: for<'any> FnFut< - (&'any mut CA, &'any mut SA, &'any mut Request), - Result = Result<(), E>, - >, -{ - #[inline] - async fn apply_req_middleware( - &self, - conn_aux: &mut CA, - req: &mut Request, - stream_aux: &mut SA, - ) -> Result<(), E> { - for elem in self { - elem.call((conn_aux, stream_aux, req)).await?; - } - Ok(()) - } -} - -impl ReqMiddleware - for FnFutWrapper<(&mut CA, &mut SA, &mut Request), F> -where - F: for<'any> FnFut< - (&'any mut CA, &'any mut SA, &'any mut Request), - Result = Result<(), E>, - >, - E: From, -{ - #[inline] - async fn apply_req_middleware( + /// Modifies or halts requests. + fn req( &self, conn_aux: &mut CA, + mw_aux: &mut Self::Aux, req: &mut Request, stream_aux: &mut SA, - ) -> Result<(), E> { - self.0.call((conn_aux, stream_aux, req)).await?; - Ok(()) - } -} + ) -> impl Future, E>>; -/// Response middleware -pub trait ResMiddleware -where - E: From, -{ /// Modifies or halts responses. - fn apply_res_middleware( - &self, - conn_aux: &mut CA, - res: Response<&mut ReqResBuffer>, - stream_aux: &mut SA, - ) -> impl Future>; -} - -impl ResMiddleware for &T -where - E: From, - T: for<'any> FnFut< - (&'any mut CA, &'any mut SA, Response<&'any mut ReqResBuffer>), - Result = Result<(), E>, - >, -{ - #[inline] - async fn apply_res_middleware( - &self, - conn_aux: &mut CA, - res: Response<&mut ReqResBuffer>, - stream_aux: &mut SA, - ) -> Result<(), E> { - self.call((conn_aux, stream_aux, res)).await?; - Ok(()) - } -} - -impl ResMiddleware for [T] -where - E: From, - T: for<'any> FnFut< - (&'any mut CA, &'any mut SA, Response<&'any mut ReqResBuffer>), - Result = Result<(), E>, - >, -{ - #[inline] - async fn apply_res_middleware( - &self, - conn_aux: &mut CA, - res: Response<&mut ReqResBuffer>, - stream_aux: &mut SA, - ) -> Result<(), E> { - for elem in self { - let local_res = - Response { rrd: &mut *res.rrd, status_code: res.status_code, version: res.version }; - elem.call((conn_aux, stream_aux, local_res)).await?; - } - Ok(()) - } -} - -impl ResMiddleware - for FnFutWrapper<(&mut CA, &mut SA, Response<&mut ReqResBuffer>), F> -where - F: for<'any> FnFut< - (&'any mut CA, &'any mut SA, Response<&'any mut ReqResBuffer>), - Result = Result<(), E>, - >, - E: From, -{ - #[inline] - async fn apply_res_middleware( + fn res( &self, conn_aux: &mut CA, + mw_aux: &mut Self::Aux, res: Response<&mut ReqResBuffer>, stream_aux: &mut SA, - ) -> Result<(), E> { - self.0.call((conn_aux, stream_aux, res)).await?; - Ok(()) - } + ) -> impl Future, E>>; } diff --git a/wtx/src/http/server_framework/router.rs b/wtx/src/http/server_framework/router.rs index 758e2ee9..1bee0964 100644 --- a/wtx/src/http/server_framework/router.rs +++ b/wtx/src/http/server_framework/router.rs @@ -1,38 +1,36 @@ use crate::{ http::{ - server_framework::{PathManagement, ReqMiddleware, ResMiddleware}, + server_framework::{Middleware, PathManagement}, ReqResBuffer, Request, Response, StatusCode, }, misc::{ArrayVector, Vector}, }; -use core::marker::PhantomData; +use core::{marker::PhantomData, ops::ControlFlow}; /// Redirects requests to specific asynchronous functions based on the set of inner URIs. #[derive(Debug)] -pub struct Router { +pub struct Router { + pub(crate) middlewares: M, pub(crate) paths: P, pub(crate) phantom: PhantomData<(CA, E, SA)>, - pub(crate) req_middlewares: REQM, - pub(crate) res_middlewares: RESM, #[cfg(feature = "matchit")] pub(crate) router: matchit::Router>, } -impl Router +impl Router where E: From, P: PathManagement, { /// Creates a new instance with paths and middlewares. #[inline] - pub fn new(paths: P, req_middlewares: REQM, res_middlewares: RESM) -> crate::Result { + pub fn new(paths: P, middlewares: M) -> crate::Result { #[cfg(feature = "matchit")] let router = Self::router(&paths)?; Ok(Self { + middlewares, paths, phantom: PhantomData, - req_middlewares, - res_middlewares, #[cfg(feature = "matchit")] router, }) @@ -54,7 +52,7 @@ where } } -impl Router +impl Router where E: From, P: PathManagement, @@ -65,22 +63,20 @@ where #[cfg(feature = "matchit")] let router = Self::router(&paths)?; Ok(Self { + middlewares: (), paths, phantom: PhantomData, - req_middlewares: (), - res_middlewares: (), #[cfg(feature = "matchit")] router, }) } } -impl PathManagement for Router +impl PathManagement for Router where E: From, + M: Middleware, P: PathManagement, - REQM: ReqMiddleware, - RESM: ResMiddleware, { const IS_ROUTER: bool = true; @@ -92,10 +88,15 @@ where req: &mut Request, stream_aux: &mut SA, ) -> Result { - self.req_middlewares.apply_req_middleware(conn_aux, req, stream_aux).await?; + let mw_aux = &mut self.middlewares.aux(); + if let ControlFlow::Break(el) = self.middlewares.req(conn_aux, mw_aux, req, stream_aux).await? { + return Ok(el); + } let status_code = self.paths.manage_path(conn_aux, path_defs, req, stream_aux).await?; let res = Response { rrd: &mut req.rrd, status_code, version: req.version }; - self.res_middlewares.apply_res_middleware(conn_aux, res, stream_aux).await?; + if let ControlFlow::Break(el) = self.middlewares.res(conn_aux, mw_aux, res, stream_aux).await? { + return Ok(el); + } Ok(status_code) } diff --git a/wtx/src/http/server_framework/server_framework_builder.rs b/wtx/src/http/server_framework/server_framework_builder.rs index 3a740ca1..f2c806e0 100644 --- a/wtx/src/http/server_framework/server_framework_builder.rs +++ b/wtx/src/http/server_framework/server_framework_builder.rs @@ -6,19 +6,19 @@ use alloc::sync::Arc; /// Server #[derive(Debug)] -pub struct ServerFrameworkBuilder { +pub struct ServerFrameworkBuilder { cp: ConnParams, - router: Arc>, + router: Arc>, } -impl ServerFrameworkBuilder +impl ServerFrameworkBuilder where CA: ConnAux, SA: StreamAux, { /// New instance with default connection values. #[inline] - pub fn new(router: Router) -> Self { + pub fn new(router: Router) -> Self { Self { cp: ConnParams::default(), router: Arc::new(router) } } @@ -28,7 +28,7 @@ where self, ca_cb: CAC, ra_cb: SAC, - ) -> ServerFramework + ) -> ServerFramework where CAC: Fn() -> CA::Init, SAC: Fn() -> SA::Init, @@ -38,9 +38,7 @@ where /// Fills the initialization structures for all auxiliaries with default values. #[inline] - pub fn with_dflt_aux( - self, - ) -> ServerFramework CA::Init, E, P, REQM, RESM, SA, fn() -> SA::Init> + pub fn with_dflt_aux(self) -> ServerFramework CA::Init, E, M, P, SA, fn() -> SA::Init> where CA::Init: Default, SA::Init: Default, @@ -55,24 +53,21 @@ where } } -impl ServerFrameworkBuilder<(), E, P, REQM, RESM, ()> { +impl ServerFrameworkBuilder<(), E, M, P, ()> { /// Build without state #[inline] - pub fn without_aux(self) -> ServerFramework<(), fn() -> (), E, P, REQM, RESM, (), fn() -> ()> { + pub fn without_aux(self) -> ServerFramework<(), fn() -> (), E, M, P, (), fn() -> ()> { ServerFramework { _ca_cb: nothing, _cp: self.cp, _sa_cb: nothing, _router: self.router } } } -impl ServerFrameworkBuilder +impl ServerFrameworkBuilder where CA: ConnAux, { /// Sets the initializing strut for `CAA` and sets the request auxiliary to `()`. #[inline] - pub fn with_conn_aux( - self, - ca_cb: CAC, - ) -> ServerFramework ()> + pub fn with_conn_aux(self, ca_cb: CAC) -> ServerFramework ()> where CAC: Fn() -> CA::Init, { @@ -80,16 +75,13 @@ where } } -impl ServerFrameworkBuilder<(), E, P, REQM, RESM, SA> +impl ServerFrameworkBuilder<(), E, M, P, SA> where SA: StreamAux, { /// Sets the initializing strut for `SA` and sets the connection auxiliary to `()`. #[inline] - pub fn with_req_aux( - self, - ra_cb: SAC, - ) -> ServerFramework<(), fn() -> (), E, P, REQM, RESM, SA, SAC> + pub fn with_req_aux(self, ra_cb: SAC) -> ServerFramework<(), fn() -> (), E, M, P, SA, SAC> where SAC: Fn() -> SA::Init, { diff --git a/wtx/src/http/server_framework/tokio.rs b/wtx/src/http/server_framework/tokio.rs index 655a8441..1a110abe 100644 --- a/wtx/src/http/server_framework/tokio.rs +++ b/wtx/src/http/server_framework/tokio.rs @@ -1,30 +1,28 @@ use crate::{ http::{ - server_framework::{ - ConnAux, PathManagement, ReqMiddleware, ResMiddleware, Router, ServerFramework, StreamAux, - }, + server_framework::{ConnAux, Middleware, PathManagement, Router, ServerFramework, StreamAux}, ManualServerStreamTokio, OptionedServer, ReqResBuffer, StreamMode, }, http2::Http2Buffer, misc::Rng, }; -use std::sync::Arc; +use alloc::sync::Arc; use tokio::net::tcp::OwnedWriteHalf; -impl ServerFramework +impl ServerFramework where CA: Clone + ConnAux + Send + 'static, CAC: Clone + Fn() -> CA::Init + Send + 'static, E: From + Send + 'static, + M: Middleware + Send + 'static, + M::Aux: Send + 'static, P: PathManagement + Send + 'static, - REQM: ReqMiddleware + Send + 'static, - RESM: ResMiddleware + Send + 'static, SA: StreamAux + Send + 'static, SAC: Clone + Fn() -> SA::Init + Send + 'static, - Arc>: Send, - Router: Send, - for<'any> &'any Arc>: Send, - for<'any> &'any Router: Send, + Arc>: Send, + Router: Send, + for<'any> &'any Arc>: Send, + for<'any> &'any Router: Send, { /// Starts listening to incoming requests based on the given `host`. #[inline] @@ -91,7 +89,7 @@ where _: ManualServerStreamTokio< CA, Http2Buffer, - (impl Fn() -> SA::Init, Arc>), + (impl Fn() -> SA::Init, Arc>), (), OwnedWriteHalf, >, @@ -105,7 +103,7 @@ where _: ManualServerStreamTokio< CA, Http2Buffer, - (impl Fn() -> SA::Init, Arc>), + (impl Fn() -> SA::Init, Arc>), (), tokio::io::WriteHalf>, >, diff --git a/wtx/src/http/session.rs b/wtx/src/http/session.rs index c8f617e2..2f641516 100644 --- a/wtx/src/http/session.rs +++ b/wtx/src/http/session.rs @@ -2,169 +2,48 @@ mod session_builder; mod session_decoder; mod session_enforcer; mod session_error; +mod session_manager; mod session_state; mod session_store; -use crate::{ - http::{ - cookie::{encrypt, CookieGeneric}, - server_framework::ConnAux, - Header, KnownHeaderName, ReqResBuffer, ReqResDataMut, - }, - misc::{GenericTime, Lease, LeaseMut, Lock, Rng, Vector}, -}; -use chrono::DateTime; -use core::marker::PhantomData; -use serde::Serialize; +use crate::http::server_framework::ConnAux; pub use session_builder::SessionBuilder; pub use session_decoder::SessionDecoder; pub use session_enforcer::SessionEnforcer; pub use session_error::SessionError; +pub use session_manager::{SessionManager, SessionManagerInner}; pub use session_state::SessionState; pub use session_store::SessionStore; type SessionId = [u8; 16]; -type SessionKey = [u8; 16]; +type SessionKey = [u8; 32]; /// [`Session`] backed by `tokio` #[cfg(feature = "tokio")] -pub type SessionTokio = - Session>>, SS>; +pub type SessionTokio = + Session>>, S>; /// Allows the management of state across requests within a connection. #[derive(Clone, Debug)] -pub struct Session { - /// Content - pub content: L, +pub struct Session { + /// Manager + pub manager: SessionManager, /// Store - pub store: SS, + pub store: S, } -impl Session -where - E: From, - L: Lock>, - SS: SessionStore, -{ +impl Session { /// Allows the specification of custom parameters. #[inline] - pub fn builder(store: SS) -> SessionBuilder { + pub fn builder(store: S) -> SessionBuilder { SessionBuilder::new(store) } - - /// Removes the session from the store and also modifies headers. - #[inline] - pub async fn delete_session_cookie(&mut self, rrd: &mut RRD) -> Result<(), E> - where - RRD: ReqResDataMut, - { - let SessionInner { cookie_def, phantom: _, key: _, state } = &mut *self.content.lock().await; - if let Some(elem) = state.take() { - self.store.delete(&elem.id).await?; - } - cookie_def.expire = Some(DateTime::from_timestamp_nanos(0)); - cookie_def.value.clear(); - rrd.headers_mut().push_from_fmt(Header::from_name_and_value( - KnownHeaderName::SetCookie.into(), - format_args!("{cookie_def}"), - ))?; - Ok(()) - } - - /// Saves the session in the store and also modifies headers. - #[inline] - pub async fn set_session_cookie( - &mut self, - custom_state: CS, - rng: RNG, - rrd: &mut RRD, - ) -> Result<(), E> - where - CS: Serialize, - RNG: Rng, - RRD: LeaseMut, - { - let SessionInner { cookie_def, phantom: _, key, state } = &mut *self.content.lock().await; - cookie_def.value.clear(); - let id = GenericTime::timestamp().map_err(Into::into)?.as_nanos().to_be_bytes(); - let local_state = if let Some(elem) = cookie_def.expire { - let local_state = SessionState { custom_state, expire: Some(elem), id }; - self.store.create(&local_state).await?; - local_state - } else { - SessionState { custom_state, expire: None, id } - }; - let idx = rrd.lease().body.len(); - serde_json::to_writer(&mut rrd.lease_mut().body, &local_state).map_err(Into::into)?; - *state = Some(local_state); - let rslt = encrypt( - &mut cookie_def.value, - key, - (cookie_def.name, rrd.lease().body.get(idx..).unwrap_or_default()), - rng, - ); - rrd.lease_mut().body.truncate(idx); - rslt?; - rrd.lease_mut().headers.push_from_fmt(Header::from_name_and_value( - KnownHeaderName::SetCookie.into(), - format_args!("{}", &cookie_def), - ))?; - Ok(()) - } } -impl ConnAux for Session -where - L: Lock>, -{ - type Init = Session; +impl ConnAux for Session { + type Init = Self; #[inline] fn conn_aux(init: Self::Init) -> crate::Result { Ok(init) } } - -impl Lease for Session { - #[inline] - fn lease(&self) -> &Self { - self - } -} - -impl LeaseMut for Session { - #[inline] - fn lease_mut(&mut self) -> &mut Self { - self - } -} - -impl Lease> for (Session, A) { - #[inline] - fn lease(&self) -> &Session { - &self.0 - } -} - -impl LeaseMut> for (Session, A) { - #[inline] - fn lease_mut(&mut self) -> &mut Session { - &mut self.0 - } -} - -/// Allows the management of state across requests within a connection. -#[derive(Debug)] -pub struct SessionInner { - cookie_def: CookieGeneric<&'static [u8], Vector>, - key: SessionKey, - phantom: PhantomData, - state: Option>, -} - -impl SessionInner { - /// State saved in the store or in the current session. - #[inline] - pub fn state(&self) -> &Option> { - &self.state - } -} diff --git a/wtx/src/http/session/session_builder.rs b/wtx/src/http/session/session_builder.rs index bd788efe..95088506 100644 --- a/wtx/src/http/session/session_builder.rs +++ b/wtx/src/http/session/session_builder.rs @@ -1,8 +1,8 @@ use crate::{ http::{ cookie::{CookieGeneric, SameSite}, - session::{SessionInner, SessionKey}, - Session, SessionStore, + session::{SessionKey, SessionManagerInner}, + Session, SessionManager, SessionStore, }, misc::{sleep, Lock, Rng, Vector}, }; @@ -46,17 +46,17 @@ impl SessionBuilder { /// If the backing store already has a system that automatically removes outdated sessions like /// SQL triggers, then the [`Future`] can be ignored. #[inline] - pub fn build_generating_key( + pub fn build_generating_key( self, rng: &mut RNG, - ) -> (impl Future>, Session) + ) -> (impl Future>, Session) where E: From, - L: Lock>, + I: Lock>, RNG: Rng, SS: Clone + SessionStore, { - let mut key = [0; 16]; + let mut key = [0; 32]; rng.fill_slice(&mut key); Self::build_with_key(self, key) } @@ -69,13 +69,13 @@ impl SessionBuilder { /// If the backing store already has a system that automatically removes outdated sessions like /// SQL triggers, then the [`Future`] can be ignored. #[inline] - pub fn build_with_key( + pub fn build_with_key( self, key: SessionKey, - ) -> (impl Future>, Session) + ) -> (impl Future>, Session) where E: From, - L: Lock>, + I: Lock>, SS: Clone + SessionStore, { let Self { cookie_def, inspection_interval, store } = self; @@ -88,7 +88,9 @@ impl SessionBuilder { } }, Session { - content: L::new(SessionInner { cookie_def, phantom: PhantomData, key, state: None }), + manager: SessionManager { + inner: I::new(SessionManagerInner { cookie_def, phantom: PhantomData, key, state: None }), + }, store, }, ) diff --git a/wtx/src/http/session/session_decoder.rs b/wtx/src/http/session/session_decoder.rs index b6f1b859..a3f5ac64 100644 --- a/wtx/src/http/session/session_decoder.rs +++ b/wtx/src/http/session/session_decoder.rs @@ -1,14 +1,15 @@ use crate::{ http::{ cookie::{decrypt, CookieBytes}, - server_framework::ReqMiddleware, - KnownHeaderName, ReqResBuffer, Request, Session, SessionError, SessionInner, SessionState, - SessionStore, + server_framework::Middleware, + KnownHeaderName, ReqResBuffer, Request, Response, Session, SessionError, SessionManagerInner, + SessionState, SessionStore, StatusCode, }, misc::{GenericTime, LeaseMut, Lock}, + pool::{Pool, ResourceManager}, }; use chrono::DateTime; -use core::marker::PhantomData; +use core::ops::ControlFlow; use serde::de::DeserializeOwned; /// Decodes cookies received from requests and manages them. @@ -16,45 +17,55 @@ use serde::de::DeserializeOwned; /// The use of this structure without [`Session`] or used after the applicability of [`Session`] /// is a NO-OP. #[derive(Debug)] -pub struct SessionDecoder { - phantom: PhantomData<(L, SS)>, +pub struct SessionDecoder { + session: Session, } -impl SessionDecoder { +impl SessionDecoder { /// New instance #[inline] - pub fn new() -> Self { - Self { phantom: PhantomData } + pub fn new(session: Session) -> Self { + Self { session } } } -impl ReqMiddleware for SessionDecoder +impl Middleware for SessionDecoder where - CA: LeaseMut>, CS: DeserializeOwned + PartialEq, E: From, - L: Lock>, - SS: SessionStore, + I: Lock>, + S: Pool, + for<'any> S::GetElem<'any>: LeaseMut, + RM: ResourceManager, + RM::Resource: SessionStore, { + type Aux = (); + + #[inline] + fn aux(&self) -> Self::Aux { + () + } + #[inline] - async fn apply_req_middleware( + async fn req( &self, - conn_aux: &mut CA, + _: &mut CA, + _: &mut Self::Aux, req: &mut Request, _: &mut SA, - ) -> Result<(), E> { - let Session { content, store } = conn_aux.lease_mut(); - let SessionInner { cookie_def, phantom: _, key, state } = &mut *content.lock().await; + ) -> Result, E> { + let SessionManagerInner { cookie_def, key, state, .. } = + &mut *self.session.manager.inner.lock().await; if let Some(elem) = state { if let Some(expire) = &elem.expire { let millis = i64::try_from(GenericTime::timestamp()?.as_millis()).unwrap_or_default(); let date_time = DateTime::from_timestamp_millis(millis).unwrap_or_default(); if expire >= &date_time { - let _rslt = store.delete(&elem.id).await; + let _rslt = self.session.store.get(&(), &()).await?.lease_mut().delete(&elem.id).await; return Err(crate::Error::from(SessionError::ExpiredSession).into()); } } - return Ok(()); + return Ok(ControlFlow::Continue(())); } let lease = req.rrd.lease_mut(); let (vector, headers) = (&mut lease.body, &mut lease.headers); @@ -79,24 +90,29 @@ where let rslt_des = serde_json::from_slice(&cookie_def.value).map_err(Into::into); cookie_def.value.clear(); let state_des: SessionState = rslt_des?; - let state_db_opt = store.read(&state_des.id).await?; + let state_db_opt = + self.session.store.get(&(), &()).await?.lease_mut().read(&state_des.id).await?; let Some(state_db) = state_db_opt else { return Err(crate::Error::from(SessionError::MissingStoredSession).into()); }; if state_db != state_des { - store.delete(&state_des.id).await?; + self.session.store.get(&(), &()).await?.lease_mut().delete(&state_des.id).await?; return Err(crate::Error::from(SessionError::InvalidStoredSession).into()); } *state = Some(state_des); break; } - Ok(()) + Ok(ControlFlow::Continue(())) } -} -impl Default for SessionDecoder { #[inline] - fn default() -> Self { - Self::new() + async fn res( + &self, + _: &mut CA, + _: &mut Self::Aux, + _: Response<&mut ReqResBuffer>, + _: &mut SA, + ) -> Result, E> { + Ok(ControlFlow::Continue(())) } } diff --git a/wtx/src/http/session/session_enforcer.rs b/wtx/src/http/session/session_enforcer.rs index a58cca1b..df0a2d3d 100644 --- a/wtx/src/http/session/session_enforcer.rs +++ b/wtx/src/http/session/session_enforcer.rs @@ -1,51 +1,68 @@ -use core::marker::PhantomData; - use crate::{ http::{ - server_framework::ReqMiddleware, ReqResBuffer, ReqResData, Request, Session, SessionError, - SessionInner, + server_framework::Middleware, ReqResBuffer, ReqResData, Request, Response, Session, + SessionError, SessionManagerInner, StatusCode, }, - misc::{Lease, Lock}, + misc::Lock, }; +use core::ops::ControlFlow; /// Enforces stored session in all requests. /// /// #[derive(Debug)] -pub struct SessionEnforcer { +pub struct SessionEnforcer { denied: [&'static str; N], - phantom: PhantomData<(L, SS)>, + session: Session, } -impl SessionEnforcer { +impl SessionEnforcer { /// Creates a new instance with paths that are not taken into consideration. #[inline] - pub fn new(denied: [&'static str; N]) -> Self { - Self { denied, phantom: PhantomData } + pub fn new(denied: [&'static str; N], session: Session) -> Self { + Self { denied, session } } } -impl ReqMiddleware for SessionEnforcer +impl Middleware for SessionEnforcer where - CA: Lease>, E: From, - L: Lock>, + I: Lock>, { + type Aux = (); + + #[inline] + fn aux(&self) -> Self::Aux { + () + } + #[inline] - async fn apply_req_middleware( + async fn req( &self, - conn_aux: &mut CA, + _: &mut CA, + _: &mut Self::Aux, req: &mut Request, _: &mut SA, - ) -> Result<(), E> { + ) -> Result, E> { let uri = req.rrd.uri(); let path = uri.path(); if self.denied.iter().all(|elem| *elem != path) { - return Ok(()); + return Ok(ControlFlow::Continue(())); } - if conn_aux.lease().content.lock().await.state().is_none() { + if self.session.manager.inner.lock().await.state().is_none() { return Err(crate::Error::from(SessionError::RequiredSessionInPath).into()); } - Ok(()) + Ok(ControlFlow::Continue(())) + } + + #[inline] + async fn res( + &self, + _: &mut CA, + _: &mut Self::Aux, + _: Response<&mut ReqResBuffer>, + _: &mut SA, + ) -> Result, E> { + Ok(ControlFlow::Continue(())) } } diff --git a/wtx/src/http/session/session_manager.rs b/wtx/src/http/session/session_manager.rs new file mode 100644 index 00000000..398f2abe --- /dev/null +++ b/wtx/src/http/session/session_manager.rs @@ -0,0 +1,111 @@ +use crate::{ + http::{ + cookie::{encrypt, CookieGeneric}, + session::SessionKey, + Header, KnownHeaderName, ReqResBuffer, ReqResDataMut, SessionState, SessionStore, + }, + misc::{GenericTime, Lease, LeaseMut, Lock, Rng, Vector}, +}; +use chrono::DateTime; +use core::marker::PhantomData; +use serde::Serialize; + +/// Manages sessions +#[derive(Clone, Debug)] +pub struct SessionManager { + /// Inner content + pub inner: I, +} + +impl SessionManager +where + E: From, + I: Lock>, +{ + /// Removes the session from the store and also modifies headers. + #[inline] + pub async fn delete_session_cookie( + &mut self, + rrd: &mut RRD, + store: &mut S, + ) -> Result<(), E> + where + RRD: ReqResDataMut, + S: SessionStore, + { + let SessionManagerInner { cookie_def, phantom: _, key: _, state } = + &mut *self.inner.lock().await; + if let Some(elem) = state.take() { + store.delete(&elem.id).await?; + } + cookie_def.expire = Some(DateTime::from_timestamp_nanos(0)); + cookie_def.value.clear(); + rrd.headers_mut().push_from_fmt(Header::from_name_and_value( + KnownHeaderName::SetCookie.into(), + format_args!("{cookie_def}"), + ))?; + Ok(()) + } + + /// Saves the session in the store and also modifies headers. + /// + /// The `rrd` body is used as a temporary buffer but no existing content is erased. + #[inline] + pub async fn set_session_cookie( + &mut self, + custom_state: CS, + rng: RNG, + rrd: &mut RRD, + store: &mut S, + ) -> Result<(), E> + where + CS: Serialize, + RNG: Rng, + RRD: LeaseMut, + S: SessionStore, + { + let SessionManagerInner { cookie_def, phantom: _, key, state } = &mut *self.inner.lock().await; + cookie_def.value.clear(); + let id = GenericTime::timestamp().map_err(Into::into)?.as_nanos().to_be_bytes(); + let local_state = if let Some(elem) = cookie_def.expire { + let local_state = SessionState { custom_state, expire: Some(elem), id }; + store.create(&local_state).await?; + local_state + } else { + SessionState { custom_state, expire: None, id } + }; + let idx = rrd.lease().body.len(); + serde_json::to_writer(&mut rrd.lease_mut().body, &local_state).map_err(Into::into)?; + *state = Some(local_state); + let rslt = encrypt( + &mut cookie_def.value, + key, + (cookie_def.name, rrd.lease().body.get(idx..).unwrap_or_default()), + rng, + ); + rrd.lease_mut().body.truncate(idx); + rslt?; + rrd.lease_mut().headers.push_from_fmt(Header::from_name_and_value( + KnownHeaderName::SetCookie.into(), + format_args!("{}", &cookie_def), + ))?; + Ok(()) + } +} + +/// Allows the management of state across requests within a connection. +#[derive(Debug)] +pub struct SessionManagerInner { + pub(crate) cookie_def: CookieGeneric<&'static [u8], Vector>, + pub(crate) key: SessionKey, + pub(crate) phantom: PhantomData, + pub(crate) state: Option>, +} + +impl SessionManagerInner { + /// State saved in the store or in the current session. + #[inline] + pub fn state(&self) -> &Option> { + &self.state + } +} diff --git a/wtx/src/http2/hpack_headers.rs b/wtx/src/http2/hpack_headers.rs index 1507ec48..c4897df2 100644 --- a/wtx/src/http2/hpack_headers.rs +++ b/wtx/src/http2/hpack_headers.rs @@ -1,8 +1,8 @@ -use crate::misc::{Block, BlocksQueue}; +use crate::misc::{Block, BlocksDeque}; #[derive(Debug)] pub(crate) struct HpackHeaders { - bq: BlocksQueue>, + bq: BlocksDeque>, max_bytes: usize, } @@ -12,7 +12,7 @@ where { #[inline] pub(crate) const fn new(max_bytes: usize) -> Self { - Self { bq: BlocksQueue::new(), max_bytes } + Self { bq: BlocksDeque::new(), max_bytes } } #[inline] @@ -63,7 +63,7 @@ where return Ok(()); } self.remove_until_max_bytes(local_len, cb); - self.bq.push_front( + self.bq.push_front_from_coyable_data( [name].into_iter().chain(iter), Metadata { is_sensitive, misc, name_len: name.len() }, )?; diff --git a/wtx/src/http2/huffman.rs b/wtx/src/http2/huffman.rs index 67eb0a81..a5f720cf 100644 --- a/wtx/src/http2/huffman.rs +++ b/wtx/src/http2/huffman.rs @@ -166,16 +166,17 @@ mod bench { } } -#[cfg(all(feature = "_proptest", test))] -mod proptest { +#[cfg(kani)] +mod kani { use crate::{ http::_HeaderValueBuffer, http2::huffman::{huffman_decode, huffman_encode}, misc::Vector, }; - #[test_strategy::proptest] + #[kani::proof] fn encode_and_decode(data: Vector) { + let data = kani::any(); let mut encoded = Vector::with_capacity(data.len()).unwrap(); huffman_encode(&data, &mut encoded).unwrap(); let mut decoded = _HeaderValueBuffer::default(); diff --git a/wtx/src/http2/index_map.rs b/wtx/src/http2/index_map.rs index b05e3e21..e17e0c68 100644 --- a/wtx/src/http2/index_map.rs +++ b/wtx/src/http2/index_map.rs @@ -1,4 +1,4 @@ -use alloc::collections::VecDeque; +use crate::misc::Deque; use core::{borrow::Borrow, hash::Hash}; use hashbrown::HashMap; @@ -6,7 +6,7 @@ use hashbrown::HashMap; pub(crate) struct IndexMap { cursor: usize, elements: HashMap, - keys: VecDeque, + keys: Deque, } impl IndexMap @@ -15,7 +15,7 @@ where { #[inline] pub(crate) fn new() -> Self { - Self { cursor: 0, elements: HashMap::new(), keys: VecDeque::new() } + Self { cursor: 0, elements: HashMap::new(), keys: Deque::new() } } #[inline] @@ -35,7 +35,7 @@ where if self.cursor >= self.elements.len() { return None; } - let key = self.keys.front()?; + let key = self.keys.get(0)?; let value = self.elements.get_mut(key)?; Some(value) } @@ -54,7 +54,7 @@ where pub(crate) fn push_back(&mut self, key: K, value: V) -> Option { let prev_value = self.elements.insert(key.clone(), value); if prev_value.is_none() { - self.keys.push_back(key); + self.keys.push_back(key).ok()?; } prev_value } diff --git a/wtx/src/http2/misc.rs b/wtx/src/http2/misc.rs index f7fac57f..d17836c0 100644 --- a/wtx/src/http2/misc.rs +++ b/wtx/src/http2/misc.rs @@ -17,7 +17,7 @@ use crate::{ }, misc::{ AtomicWaker, LeaseMut, Lock, PartitionedFilledBuffer, RefCounter, StreamReader, StreamWriter, - Usize, _read_until, + Usize, _read_header, _read_payload, }, }; use core::{ @@ -191,9 +191,10 @@ where let mut fut = pin!(async move { for _ in 0.._max_frames_mismatches!() { pfb._clear_if_following_is_empty(); + pfb._reserve(9)?; let mut read = pfb._following_len(); let buffer = pfb._following_rest_mut(); - let array = _read_until::<9, _>(buffer, &mut read, 0, stream_reader).await?; + let array = _read_header::<0, 9, _>(buffer, &mut read, stream_reader).await?; let (fi_opt, data_len) = FrameInit::from_array(array); if data_len > max_frame_len { return Err(crate::Error::Http2ErrorGoAway( @@ -201,7 +202,7 @@ where Some(Http2Error::LargeArbitraryFrameLen), )); } - let frame_len = *Usize::from_u32(data_len.wrapping_add(9)); + let data_len_usize = *Usize::from_u32(data_len); let Some(fi) = fi_opt else { if IS_HEADER_BLOCK { return Err(protocol_err(Http2Error::UnexpectedContinuationFrame)); @@ -209,6 +210,7 @@ where if data_len > 32 { return Err(protocol_err(Http2Error::LargeIgnorableFrameLen)); } + let frame_len = data_len_usize.wrapping_add(9); let (antecedent_len, following_len) = if let Some(to_read) = frame_len.checked_sub(read) { stream_reader.read_skip(to_read).await?; (pfb._buffer().len(), 0) @@ -219,25 +221,7 @@ where continue; }; _trace!("Received frame: {fi:?}"); - let mut is_fulfilled = false; - pfb._reserve(*Usize::from(data_len))?; - for _ in 0..=data_len { - if read >= frame_len { - is_fulfilled = true; - break; - } - read = read.wrapping_add( - stream_reader.read(pfb._following_rest_mut().get_mut(read..).unwrap_or_default()).await?, - ); - } - if !is_fulfilled { - return Err(crate::Error::UnexpectedBufferState); - } - pfb._set_indices( - pfb._current_end_idx().wrapping_add(9), - *Usize::from(data_len), - read.wrapping_sub(frame_len), - )?; + _read_payload((9, data_len_usize), pfb, &mut read, stream_reader).await?; return Ok(fi); } Err(protocol_err(Http2Error::VeryLargeAmountOfFrameMismatches)) diff --git a/wtx/src/macros.rs b/wtx/src/macros.rs index dd5fa85e..bd26b1d8 100644 --- a/wtx/src/macros.rs +++ b/wtx/src/macros.rs @@ -307,10 +307,10 @@ macro_rules! _max_frames_mismatches { macro_rules! _simd { ( - 512 => $_512:expr, - 256 => $_256:expr, + fallback => $fallback:expr, 128 => $_128:expr, - _ => $fallback:expr $(,)? + 256 => $_256:expr, + 512 => $_512:expr $(,)? ) => {{ #[cfg(target_feature = "avx512f")] let rslt = $_512; @@ -342,6 +342,52 @@ macro_rules! _simd { }}; } +macro_rules! _simd_bytes { + ( + ($align:ident, $bytes:expr), + (|$bytes_ident_a:ident| $bytes_expr_a:expr, |$bytes_ident_b:ident| $bytes_expr_b:expr), + |$_16:ident| $_128:expr, + |$_32:ident| $_256:expr, + |$_64:ident| $_512:expr $(,)? + ) => {{ + // SAFETY: Changing a sequence of `u8` should be fine + let (_prefix, _chunks, _suffix) = unsafe { $bytes.$align() }; + _simd! { + fallback => { + let $bytes_ident_a = _prefix; + $bytes_expr_a; + } + 128 => { + let $bytes_ident_a = _prefix; + $bytes_expr_a; + let _: [[u8; 64]] = *_chunks; + let $_16 = _chunks; + $_128 + let $bytes_ident_b = _suffix; + $bytes_expr_b; + }, + 256 => { + let $bytes_ident_a = _prefix; + $bytes_expr_a; + let _: [[u8; 32]] = *_chunks; + let $_16 = _chunks; + $_128 + let $bytes_ident_b = _suffix; + $bytes_expr_b; + }, + 512 => { + let $bytes_ident_a = _prefix; + $bytes_expr_a; + let _: [[u8; 16]] = *_chunks; + let $_16 = _chunks; + $_128 + let $bytes_ident_b = _suffix; + $bytes_expr_b; + }, + } + }}; +} + macro_rules! _trace { ($($tt:tt)+) => { #[cfg(feature = "tracing")] diff --git a/wtx/src/misc.rs b/wtx/src/misc.rs index 7db9a1c6..8e5ac52b 100644 --- a/wtx/src/misc.rs +++ b/wtx/src/misc.rs @@ -3,10 +3,11 @@ mod array_chunks; mod array_string; mod array_vector; -mod blocks_queue; +mod blocks_deque; mod buffer_mode; mod bytes_fmt; mod connection_state; +mod deque; mod either; mod enum_var_strings; mod filled_buffer; @@ -15,6 +16,7 @@ mod fn_fut; mod from_radix_10; mod generic_time; mod incomplete_utf8_char; +mod interspace; mod iter_wrapper; mod lease; mod lock; @@ -23,7 +25,6 @@ mod noop_waker; mod optimization; mod partitioned_filled_buffer; mod query_writer; -mod queue; mod ref_counter; mod rng; mod role; @@ -44,11 +45,12 @@ pub use self::tokio_rustls::{TokioRustlsAcceptor, TokioRustlsConnector}; pub use array_chunks::{ArrayChunks, ArrayChunksMut}; pub use array_string::{ArrayString, ArrayStringError}; pub use array_vector::{ArrayVector, ArrayVectorError, IntoIter}; -pub use blocks_queue::{Block, BlocksQueue, BlocksQueueError}; +pub use blocks_deque::{Block, BlocksDeque, BlocksDequeBuilder, BlocksDequeError}; pub use buffer_mode::BufferMode; pub use bytes_fmt::BytesFmt; pub use connection_state::ConnectionState; use core::{any::type_name, fmt::Write, ops::Range, time::Duration}; +pub use deque::{Deque, DequeueError}; pub use either::Either; pub use enum_var_strings::EnumVarStrings; pub use filled_buffer_writer::FilledBufferWriter; @@ -56,13 +58,13 @@ pub use fn_fut::*; pub use from_radix_10::{FromRadix10, FromRadix10Error}; pub use generic_time::GenericTime; pub use incomplete_utf8_char::{CompletionErr, IncompleteUtf8Char}; +pub use interspace::Intersperse; pub use iter_wrapper::IterWrapper; pub use lease::{Lease, LeaseMut}; pub use lock::{Lock, SyncLock}; pub use noop_waker::NOOP_WAKER; pub use optimization::*; pub use query_writer::QueryWriter; -pub use queue::{Queue, QueueError}; pub use ref_counter::RefCounter; pub use rng::*; pub use role::Role; @@ -85,27 +87,28 @@ pub(crate) use { /// Hashes a password using the `argon2` algorithm. #[cfg(feature = "argon2")] #[inline] -pub fn argon2_pwd(pwd: &[u8], salt: &[u8]) -> crate::Result<[u8; 32]> { +pub fn argon2_pwd( + blocks: &mut Vector, + pwd: &[u8], + salt: &[u8], +) -> crate::Result<[u8; N]> { use argon2::{Algorithm, Argon2, Params, Version}; - const OUT_LEN: usize = 32; - const PARAMS: Params = { + let params = const { + let output_len = Some(N); let Ok(elem) = Params::new( Params::DEFAULT_M_COST, Params::DEFAULT_T_COST, Params::DEFAULT_P_COST, - Some(OUT_LEN), + output_len, ) else { panic!(); }; elem }; - let mut out = [0; OUT_LEN]; - Argon2::new(Algorithm::Argon2id, Version::V0x13, PARAMS).hash_password_into_with_memory( - pwd, - salt, - &mut out, - &mut [argon2::Block::new(); PARAMS.block_count()], - )?; + blocks.expand(BufferMode::Len(params.block_count()), argon2::Block::new())?; + let mut out = [0; N]; + Argon2::new(Algorithm::Argon2id, Version::V0x13, params) + .hash_password_into_with_memory(pwd, salt, &mut out, blocks)?; Ok(out) } @@ -119,6 +122,7 @@ pub fn into_rslt(opt: Option) -> crate::Result { /// Similar to `collect_seq` of `serde` but expects a `Result`. #[cfg(feature = "serde")] +#[inline] pub fn serde_collect_seq_rslt(ser: S, into_iter: I) -> Result<(), E> where E: From, @@ -161,6 +165,7 @@ pub async fn sleep(duration: Duration) -> crate::Result<()> { /// A tracing register with optioned parameters. #[cfg(feature = "_tracing-tree")] +#[inline] pub fn tracing_tree_init( fallback_opt: Option<&str>, ) -> Result<(), tracing_subscriber::util::TryInitError> { @@ -249,62 +254,71 @@ pub(crate) fn _conservative_size_hint_len(size_hint: (usize, Option)) -> } } -#[inline] -pub(crate) fn _interspace( - write: &mut W, - mut iter: impl Iterator, - mut cb: impl for<'args> FnMut(&mut W, T) -> Result<(), E>, - mut interspace: impl FnMut(&mut W) -> Result<(), E>, -) -> Result<(), E> -where - E: From, - W: Write, -{ - if let Some(elem) = iter.next() { - cb(write, elem)?; - } - for elem in iter { - interspace(write)?; - cb(write, elem)?; - } - Ok(()) -} - #[cfg(feature = "std")] +#[inline] pub(crate) fn _number_or_available_parallelism(n: Option) -> crate::Result { Ok(if let Some(elem) = n { elem } else { usize::from(std::thread::available_parallelism()?) }) } #[cfg(feature = "foldhash")] +#[inline] pub(crate) fn _random_state(mut rng: impl Rng) -> foldhash::fast::FixedState { let [a, b, c, d, e, f, g, h] = rng.u8_8(); foldhash::fast::FixedState::with_seed(u64::from_ne_bytes([a, b, c, d, e, f, g, h])) } -pub(crate) async fn _read_until( +#[inline] +pub(crate) async fn _read_header( buffer: &mut [u8], read: &mut usize, - start: usize, stream_reader: &mut SR, ) -> crate::Result<[u8; LEN]> where [u8; LEN]: Default, SR: StreamReader, { - let until = start.wrapping_add(LEN); - for _ in 0..LEN { - let has_enough_data = *read >= until; - if has_enough_data { + loop { + let (lhs, rhs) = buffer.split_at_mut_checked(*read).unwrap_or_default(); + if let Some(slice) = lhs.get(BEGIN..BEGIN.wrapping_add(LEN)) { + return Ok(slice.try_into().unwrap_or_default()); + } + let local_read = stream_reader.read(rhs).await?; + if local_read == 0 { + return Err(crate::Error::ClosedConnection); + } + *read = read.wrapping_add(local_read); + } +} + +#[inline] +pub(crate) async fn _read_payload( + (header_len, payload_len): (usize, usize), + network_buffer: &mut PartitionedFilledBuffer, + read: &mut usize, + stream: &mut S, +) -> crate::Result<()> +where + S: StreamReader, +{ + let frame_len = header_len.wrapping_add(payload_len); + network_buffer._reserve(frame_len)?; + loop { + if *read >= frame_len { break; } - let actual_buffer = buffer.get_mut(*read..).unwrap_or_default(); - let local_read = stream_reader.read(actual_buffer).await?; + let local_buffer = network_buffer._following_rest_mut().get_mut(*read..).unwrap_or_default(); + let local_read = stream.read(local_buffer).await?; if local_read == 0 { - return Err(crate::Error::UnexpectedStreamReadEOF); + return Err(crate::Error::ClosedConnection); } *read = read.wrapping_add(local_read); } - Ok(buffer.get(start..until).and_then(|el| el.try_into().ok()).unwrap_or_else(_unlikely_dflt)) + network_buffer._set_indices( + network_buffer._current_end_idx().wrapping_add(header_len), + payload_len, + read.wrapping_sub(frame_len), + )?; + Ok(()) } #[cold] @@ -338,6 +352,7 @@ pub(crate) const fn _unreachable() -> ! { panic!("Entered in a branch that should be impossible, which is likely a programming error"); } +#[inline] pub(crate) fn _usize_range_from_u32_range(range: Range) -> Range { *Usize::from(range.start)..*Usize::from(range.end) } diff --git a/wtx/src/misc/blocks_queue.rs b/wtx/src/misc/blocks_deque.rs similarity index 67% rename from wtx/src/misc/blocks_queue.rs rename to wtx/src/misc/blocks_deque.rs index 961d9941..3bab8a7e 100644 --- a/wtx/src/misc/blocks_queue.rs +++ b/wtx/src/misc/blocks_deque.rs @@ -25,11 +25,16 @@ macro_rules! get_mut { } } +mod block; +mod blocks_deque_builder; +mod metadata; #[cfg(test)] mod tests; -use crate::misc::Queue; -use core::{ops::Range, ptr}; +use crate::misc::Deque; +pub use block::Block; +pub use blocks_deque_builder::BlocksDequeBuilder; +use core::ptr; /// [`Block`] composed by references. type BlockRef<'bq, D, M> = Block<&'bq [D], &'bq M>; @@ -38,7 +43,7 @@ type BlockMut<'bq, D, M> = Block<&'bq mut [D], &'bq mut M>; /// Errors of [`BlocksQueue`]. #[derive(Debug)] -pub enum BlocksQueueError { +pub enum BlocksDequeError { #[doc = doc_single_elem_cap_overflow!()] PushOverflow, #[doc = doc_reserve_overflow!()] @@ -47,40 +52,39 @@ pub enum BlocksQueueError { WithCapacityOverflow, } -/// A circular buffer where elements are added in one-way blocks that will never intersect -/// boundaries. +/// A circular buffer where elements are added in blocks that will never intersect boundaries. #[derive(Debug)] -pub struct BlocksQueue { - data: Queue, - metadata: Queue>, +pub struct BlocksDeque { + data: Deque, + metadata: Deque>, } -impl BlocksQueue { +impl BlocksDeque { /// Creates a new empty instance. #[inline] pub const fn new() -> Self { - Self { data: Queue::new(), metadata: Queue::new() } + Self { data: Deque::new(), metadata: Deque::new() } } /// Constructs a new, empty instance with at least the specified capacity. #[inline] - pub fn with_capacity(blocks: usize, elements: usize) -> Result { + pub fn with_capacity(blocks: usize, elements: usize) -> Result { Ok(Self { - data: Queue::with_capacity(elements) - .map_err(|_err| BlocksQueueError::WithCapacityOverflow)?, - metadata: Queue::with_capacity(blocks) - .map_err(|_err| BlocksQueueError::WithCapacityOverflow)?, + data: Deque::with_capacity(elements) + .map_err(|_err| BlocksDequeError::WithCapacityOverflow)?, + metadata: Deque::with_capacity(blocks) + .map_err(|_err| BlocksDequeError::WithCapacityOverflow)?, }) } /// Constructs a new, empty instance with the exact specified capacity. #[inline] - pub fn with_exact_capacity(blocks: usize, elements: usize) -> Result { + pub fn with_exact_capacity(blocks: usize, elements: usize) -> Result { Ok(Self { - data: Queue::with_exact_capacity(elements) - .map_err(|_err| BlocksQueueError::WithCapacityOverflow)?, - metadata: Queue::with_exact_capacity(blocks) - .map_err(|_err| BlocksQueueError::WithCapacityOverflow)?, + data: Deque::with_exact_capacity(elements) + .map_err(|_err| BlocksDequeError::WithCapacityOverflow)?, + metadata: Deque::with_exact_capacity(blocks) + .map_err(|_err| BlocksDequeError::WithCapacityOverflow)?, }) } @@ -102,6 +106,18 @@ impl BlocksQueue { self.metadata.len() } + /// See [`BlocksDequeBuilder`]. + #[inline] + pub fn builder_back(&mut self) -> BlocksDequeBuilder<'_, D, M, true> { + BlocksDequeBuilder::new(self) + } + + /// See [`BlocksDequeBuilder`]. + #[inline] + pub fn builder_front(&mut self) -> BlocksDequeBuilder<'_, D, M, false> { + BlocksDequeBuilder::new(self) + } + /// Clears the queue, removing all values. #[inline] pub fn clear(&mut self) { @@ -151,12 +167,6 @@ impl BlocksQueue { .map(move |elem| do_get!(BlockMut, elem, data.as_ptr_mut(), slice_from_raw_parts_mut, &mut)) } - /// Returns the last block. - #[inline] - pub fn last(&self) -> Option> { - self.get(self.data.len().checked_sub(1)?) - } - /// Removes the last element from the queue and returns it, or `None` if it is empty. #[inline] pub fn pop_back(&mut self) -> Option { @@ -173,9 +183,37 @@ impl BlocksQueue { Some(metadata.misc) } - /// Prepends an block to the queue. + /// Appends a block to the end of the queue. #[inline] - pub fn push_front<'data, I>(&mut self, data: I, misc: M) -> Result<(), BlocksQueueError> + pub fn push_back_from_copyable_data<'data, I>( + &mut self, + data: I, + misc: M, + ) -> Result<(), BlocksDequeError> + where + D: Copy + 'data, + I: IntoIterator, + I::IntoIter: Clone, + { + let total_data_len = self + .data + .extend_back_from_copyable_slices(data) + .map_err(|_err| BlocksDequeError::PushOverflow)?; + let begin = self.data.tail().wrapping_sub(total_data_len); + self + .metadata + .push_back(metadata::Metadata { begin, len: total_data_len, misc }) + .map_err(|_err| BlocksDequeError::PushOverflow)?; + Ok(()) + } + + /// Prepends a block to the queue. + #[inline] + pub fn push_front_from_coyable_data<'data, I>( + &mut self, + data: I, + misc: M, + ) -> Result<(), BlocksDequeError> where D: Copy + 'data, I: IntoIterator, @@ -184,24 +222,25 @@ impl BlocksQueue { let (total_data_len, head_shift) = self .data .extend_front_from_copyable_slices(data) - .map_err(|_err| BlocksQueueError::PushOverflow)?; + .map_err(|_err| BlocksDequeError::PushOverflow)?; self .metadata - .push_front(Metadata { begin: self.data.head(), len: total_data_len, misc }) - .map_err(|_err| BlocksQueueError::PushOverflow)?; + .push_front(metadata::Metadata { begin: self.data.head(), len: total_data_len, misc }) + .map_err(|_err| BlocksDequeError::PushOverflow)?; self.adjust_metadata(head_shift, 1); Ok(()) } /// Reserves capacity for at least additional more elements to be inserted in the given queue. #[inline(always)] - pub fn reserve_front(&mut self, blocks: usize, elements: usize) -> Result<(), BlocksQueueError> { - let _ = self.metadata.reserve_front(blocks).map_err(|_er| BlocksQueueError::ReserveOverflow)?; - let n = self.data.reserve_front(elements).map_err(|_err| BlocksQueueError::ReserveOverflow)?; + pub fn reserve_front(&mut self, blocks: usize, elements: usize) -> Result<(), BlocksDequeError> { + let _ = self.metadata.reserve_front(blocks).map_err(|_er| BlocksDequeError::ReserveOverflow)?; + let n = self.data.reserve_front(elements).map_err(|_err| BlocksDequeError::ReserveOverflow)?; self.adjust_metadata(n, 0); Ok(()) } + // Only used in front operations #[inline] fn adjust_metadata(&mut self, head_shift: usize, skip: usize) { if head_shift > 0 { @@ -212,27 +251,9 @@ impl BlocksQueue { } } -impl Default for BlocksQueue { +impl Default for BlocksDeque { #[inline] fn default() -> Self { Self::new() } } - -/// Block -#[derive(Debug, PartialEq)] -pub struct Block { - /// Opaque data - pub data: D, - /// Miscellaneous - pub misc: M, - /// Range - pub range: Range, -} - -#[derive(Clone, Copy, Debug)] -struct Metadata { - begin: usize, - len: usize, - misc: M, -} diff --git a/wtx/src/misc/blocks_deque/block.rs b/wtx/src/misc/blocks_deque/block.rs new file mode 100644 index 00000000..dfa0ea2c --- /dev/null +++ b/wtx/src/misc/blocks_deque/block.rs @@ -0,0 +1,12 @@ +use core::ops::Range; + +/// Block +#[derive(Debug, PartialEq)] +pub struct Block { + /// Opaque data + pub data: D, + /// Miscellaneous + pub misc: M, + /// Range + pub range: Range, +} diff --git a/wtx/src/misc/blocks_deque/blocks_deque_builder.rs b/wtx/src/misc/blocks_deque/blocks_deque_builder.rs new file mode 100644 index 00000000..b1e8963d --- /dev/null +++ b/wtx/src/misc/blocks_deque/blocks_deque_builder.rs @@ -0,0 +1,92 @@ +use core::slice; + +use crate::misc::{blocks_deque::metadata::Metadata, BlocksDeque, BlocksDequeError, BufferMode}; + +/// Allows the construction of a single block through the insertion of indivial elements. +#[derive(Debug)] +pub struct BlocksDequeBuilder<'db, D, M, const IS_BACK: bool> { + bd: &'db mut BlocksDeque, + inserted: usize, + was_built: bool, +} + +impl<'db, D, M, const IS_BACK: bool> BlocksDequeBuilder<'db, D, M, IS_BACK> { + #[inline] + pub(crate) fn new(bd: &'db mut BlocksDeque) -> Self { + Self { bd, inserted: 0, was_built: false } + } + + /// Finishes the construction of the block + #[inline] + pub fn build(mut self, misc: M) -> Result<(), BlocksDequeError> { + self.was_built = true; + let rslt = if IS_BACK { + let begin = self.bd.data.tail().wrapping_sub(self.inserted); + let metadata = Metadata { begin, len: self.inserted, misc }; + self.bd.metadata.push_back(metadata) + } else { + let metadata = Metadata { begin: self.bd.data.head(), len: self.inserted, misc }; + self.bd.metadata.push_front(metadata) + }; + rslt.map_err(|_err| BlocksDequeError::PushOverflow) + } + + /// Appends or prepends elements so that the current length is equal to `bp`. + #[inline] + pub fn expand(&mut self, bm: BufferMode, value: D) -> Result<&mut Self, BlocksDequeError> + where + D: Clone, + { + let additional = if IS_BACK { + let rslt = self.bd.data.expand_back(bm, value); + rslt.map_err(|_err| BlocksDequeError::PushOverflow)? + } else { + let rslt = self.bd.data.expand_front(bm, value); + let (additional, head_shift) = rslt.map_err(|_err| BlocksDequeError::PushOverflow)?; + self.bd.adjust_metadata(head_shift, 1); + additional + }; + self.inserted = self.inserted.wrapping_add(additional); + Ok(self) + } + + /// The elements inserted so far by this builder + #[inline] + pub fn inserted_elements(&mut self) -> &mut [D] { + let ptr = self.bd.data.as_ptr_mut(); + let shifted_ptr = if IS_BACK { + let begin = self.bd.data.tail().wrapping_sub(self.inserted); + // SAFETY: We are in a "back-only" environment so the tail index will never be less than + // the number of inserted elements + unsafe { ptr.add(begin) } + } else { + let begin = self.bd.data.head(); + // SAFETY: We are in a "front-only" environment so there will always be `inserted` elements + // after the starting head. + unsafe { ptr.add(begin) } + }; + // SAFETY: the above checks ensure valid memory + unsafe { slice::from_raw_parts_mut(shifted_ptr, self.inserted) } + } +} + +// The `was_built` parameter is used to enforce a valid instance in case of an external error. +// +// ``` +// builder.expand(...); +// some_fallible_operation(...)?; +// builder.build(); +// ``` +impl<'db, D, M, const IS_BACK: bool> Drop for BlocksDequeBuilder<'db, D, M, IS_BACK> { + #[inline] + fn drop(&mut self) { + if !self.was_built { + let previous_len = self.bd.elements_len().wrapping_sub(self.inserted); + if IS_BACK { + self.bd.data.truncate_back(previous_len); + } else { + self.bd.data.truncate_front(previous_len); + } + } + } +} diff --git a/wtx/src/misc/blocks_deque/metadata.rs b/wtx/src/misc/blocks_deque/metadata.rs new file mode 100644 index 00000000..0291efb0 --- /dev/null +++ b/wtx/src/misc/blocks_deque/metadata.rs @@ -0,0 +1,6 @@ +#[derive(Clone, Copy, Debug)] +pub(crate) struct Metadata { + pub(crate) begin: usize, + pub(crate) len: usize, + pub(crate) misc: M, +} diff --git a/wtx/src/misc/blocks_queue/tests.rs b/wtx/src/misc/blocks_deque/tests.rs similarity index 81% rename from wtx/src/misc/blocks_queue/tests.rs rename to wtx/src/misc/blocks_deque/tests.rs index 57eead99..a3f388e6 100644 --- a/wtx/src/misc/blocks_queue/tests.rs +++ b/wtx/src/misc/blocks_deque/tests.rs @@ -5,7 +5,7 @@ // RO = Right Occupied // T = Tail (Exclusive) -use crate::misc::{blocks_queue::BlockRef, BlocksQueue}; +use crate::misc::{blocks_deque::BlockRef, BlocksDeque}; // [. . . . . . . .]: Empty - (LF=8, LO=0,RF=0, RO=0) - (H=0, T=0) // [. . . . . . . H]: Push front - (LF=7, LO=0, RF=0, RO=1) - (H=7, T=8) @@ -27,25 +27,25 @@ use crate::misc::{blocks_queue::BlockRef, BlocksQueue}; // [. . . . . . . .]: Pop back - (LF=8, LO=0, RF=0, RO=0) - (H=0, T=0) #[test] fn pop_back() { - let mut bq = BlocksQueue::with_exact_capacity(4, 8).unwrap(); + let mut bq = BlocksDeque::with_exact_capacity(4, 8).unwrap(); check_state(&bq, 0, 0, &[], &[]); - bq.push_front([&[1][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[1][..]], ()).unwrap(); check_state(&bq, 1, 1, &[1], &[]); - bq.push_front([&[2, 3][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[2, 3][..]], ()).unwrap(); check_state(&bq, 2, 3, &[2, 3, 1], &[]); - bq.push_front([&[4, 5], &[6][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[4, 5], &[6][..]], ()).unwrap(); check_state(&bq, 3, 6, &[4, 5, 6, 2, 3, 1], &[]); - bq.push_front([&[7, 8][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[7, 8][..]], ()).unwrap(); check_state(&bq, 4, 8, &[7, 8, 4, 5, 6, 2, 3, 1], &[]); let _ = bq.pop_back(); check_state(&bq, 3, 7, &[7, 8, 4, 5, 6, 2, 3], &[]); - bq.push_front([&[9][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[9][..]], ()).unwrap(); check_state(&bq, 4, 8, &[9], &[7, 8, 4, 5, 6, 2, 3]); let _ = bq.pop_back(); @@ -54,10 +54,10 @@ fn pop_back() { let _ = bq.pop_back(); check_state(&bq, 2, 3, &[9], &[7, 8]); - bq.push_front([&[10], &[11, 12][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[10], &[11, 12][..]], ()).unwrap(); check_state(&bq, 3, 6, &[10, 11, 12, 9], &[7, 8]); - bq.push_front([&[13, 14][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[13, 14][..]], ()).unwrap(); check_state(&bq, 4, 8, &[13, 14, 10, 11, 12, 9], &[7, 8]); let _ = bq.pop_back(); @@ -69,10 +69,10 @@ fn pop_back() { let _ = bq.pop_back(); check_state(&bq, 1, 2, &[13, 14], &[]); - bq.push_front([&[15][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[15][..]], ()).unwrap(); check_state(&bq, 2, 3, &[15, 13, 14], &[]); - bq.push_front([&[16][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[16][..]], ()).unwrap(); check_state(&bq, 3, 4, &[16, 15, 13, 14], &[]); let _ = bq.pop_back(); @@ -92,13 +92,13 @@ fn pop_back() { // [. . . . . . . .]: Pop back - (LF=8, LO=0, RF=0, RO=0) - (H=0, T=0) #[test] fn pop_front() { - let mut bq = BlocksQueue::with_exact_capacity(2, 8).unwrap(); + let mut bq = BlocksDeque::with_exact_capacity(2, 8).unwrap(); check_state(&bq, 0, 0, &[], &[]); - bq.push_front([&[1, 2, 3][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[1, 2, 3][..]], ()).unwrap(); check_state(&bq, 1, 3, &[1, 2, 3], &[]); - bq.push_front([&[4, 5], &[6, 7, 8][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[4, 5], &[6, 7, 8][..]], ()).unwrap(); check_state(&bq, 2, 8, &[4, 5, 6, 7, 8, 1, 2, 3], &[]); let _ = bq.pop_front(); @@ -112,14 +112,14 @@ fn pop_front() { // [H * * *]: Push front - (LF=0, LO=0, RF=0, RO=4) - (H=0, T=4) #[test] fn push_reserve_and_push() { - let mut bq = BlocksQueue::new(); + let mut bq = BlocksDeque::new(); bq.reserve_front(1, 4).unwrap(); - bq.push_front([&[0, 1, 2, 3][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[0, 1, 2, 3][..]], ()).unwrap(); check_state(&bq, 1, 4, &[0, 1, 2, 3], &[]); assert_eq!(bq.get(0), Some(BlockRef { data: &[0, 1, 2, 3], misc: &(), range: 0..4 })); assert_eq!(bq.get(1), None); bq.reserve_front(1, 6).unwrap(); - bq.push_front([&[4, 5, 6, 7, 8, 9][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[4, 5, 6, 7, 8, 9][..]], ()).unwrap(); check_state(&bq, 2, 10, &[4, 5, 6, 7, 8, 9, 0, 1, 2, 3], &[]); assert_eq!(bq.get(0), Some(BlockRef { data: &[4, 5, 6, 7, 8, 9], misc: &(), range: 0..6 })); assert_eq!(bq.get(1), Some(BlockRef { data: &[0, 1, 2, 3], misc: &(), range: 6..10 })); @@ -157,11 +157,11 @@ fn wrap_pop_front() { // [. . H * * * * * ]: Push front - (LF=2, LO=0, RF=0, RO=6) // [. . H * . . . . ]: Pop back - (LF=2, LO=0, RF=4, RO=0) // [. . . H * * * * ]: Push front - (LF=3, LO=0, RF=0, RO=5) -fn wrap_initial() -> BlocksQueue { - let mut bq = BlocksQueue::with_exact_capacity(6, 8).unwrap(); +fn wrap_initial() -> BlocksDeque { + let mut bq = BlocksDeque::with_exact_capacity(6, 8).unwrap(); check_state(&bq, 0, 0, &[], &[]); for _ in 0..6 { - bq.push_front([&[0][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[0][..]], ()).unwrap(); } check_state(&bq, 6, 6, &[0, 0, 0, 0, 0, 0], &[]); for idx in 0..6 { @@ -174,7 +174,7 @@ fn wrap_initial() -> BlocksQueue { check_state(&bq, 2, 2, &[0, 0], &[]); assert_eq!(bq.get(0).unwrap().data, &[0]); assert_eq!(bq.get(1).unwrap().data, &[0]); - bq.push_front([&[1, 2, 3][..]], ()).unwrap(); + bq.push_front_from_coyable_data([&[1, 2, 3][..]], ()).unwrap(); check_state(&bq, 3, 5, &[1, 2, 3, 0, 0], &[]); assert_eq!(bq.get(0).unwrap().data, &[1, 2, 3]); assert_eq!(bq.get(1).unwrap().data, &[0]); @@ -184,7 +184,7 @@ fn wrap_initial() -> BlocksQueue { #[track_caller] fn check_state( - bq: &BlocksQueue, + bq: &BlocksDeque, blocks_len: usize, elements_len: usize, front: &[i32], diff --git a/wtx/src/misc/queue.rs b/wtx/src/misc/deque.rs similarity index 60% rename from wtx/src/misc/queue.rs rename to wtx/src/misc/deque.rs index d793a5f0..6963e90e 100644 --- a/wtx/src/misc/queue.rs +++ b/wtx/src/misc/deque.rs @@ -1,3 +1,38 @@ +// 1. Valid instances +// +// In a double ended queue it is possible to store elements in 2 logical ways. +// +// 1.1. Contiguous +// +// No boundary intersections +// +// | | | A | B | C | | | | | +// +// 1.2. Wrapping +// +// The order doesn't matter, front elements will always stay at the right-hand-side +// and back elements will always stay at the left-hand-side. +// +// 1.2.1 Pushing an element to the back of the queue. +// +// | | | | | | | | A | B | +// ------------------------------------- +// | C | | | | | | | A | B | +// +// 1.2.2 Prepending an element to the front of the queue. +// +// | A | B | | | | | | | | +// ------------------------------------- +// | B | C | | | | | | | A | +// +// 2. Invalid instances +// +// It is impossible to exist a wrapping non-contiguous queue like in the following examples. +// +// | B | C | D | | | | A | | | +// +// | | A | B | | C | | | | | + macro_rules! as_slices { ($empty:expr, $ptr:ident, $slice:ident, $this:expr, $($ref:tt)*) => {{ let capacity = $this.data.capacity(); @@ -22,21 +57,21 @@ macro_rules! as_slices { }} } -#[cfg(all(feature = "_proptest", test))] -mod proptest; +#[cfg(kani)] +mod kani; #[cfg(test)] mod tests; -use crate::misc::Vector; +use crate::misc::{BufferMode, Vector}; use core::{ fmt::{Debug, Formatter}, mem::needs_drop, ptr, slice, }; -/// Errors of [Queue]. +/// Errors of [Deque]. #[derive(Debug)] -pub enum QueueError { +pub enum DequeueError { #[doc = doc_single_elem_cap_overflow!()] ExtendFromSliceOverflow, #[doc = doc_single_elem_cap_overflow!()] @@ -47,14 +82,25 @@ pub enum QueueError { WithCapacityOverflow, } -/// A circular buffer. -pub struct Queue { +/// A double-ended queue implemented with a growable ring buffer. +// +// # Illustration +// +// | | | A | B | C | D | | | | | +// | | |--> data.capacity() +// | | +// | |------------------> tail +// | +// |----------------------------------> head +// +// The vector length is a shorcut for the sum of head of tail elements. +pub struct Deque { data: Vector, head: usize, tail: usize, } -impl Queue { +impl Deque { const NEEDS_DROP: bool = needs_drop::(); /// Creates a new empty instance. @@ -65,9 +111,9 @@ impl Queue { /// Constructs a new, empty instance with at least the specified capacity. #[inline] - pub fn with_capacity(cap: usize) -> Result { + pub fn with_capacity(cap: usize) -> Result { Ok(Self { - data: Vector::with_capacity(cap).map_err(|_err| QueueError::WithCapacityOverflow)?, + data: Vector::with_capacity(cap).map_err(|_err| DequeueError::WithCapacityOverflow)?, head: 0, tail: 0, }) @@ -75,9 +121,9 @@ impl Queue { /// Constructs a new, empty instance with at least the specified capacity. #[inline] - pub fn with_exact_capacity(cap: usize) -> Result { + pub fn with_exact_capacity(cap: usize) -> Result { Ok(Self { - data: Vector::with_capacity(cap).map_err(|_err| QueueError::WithCapacityOverflow)?, + data: Vector::with_capacity(cap).map_err(|_err| DequeueError::WithCapacityOverflow)?, head: 0, tail: 0, }) @@ -100,7 +146,7 @@ impl Queue { /// Returns a pair of slices which contain, in order, the contents of the queue. /// /// ```rust - /// let mut queue = wtx::misc::Queue::with_capacity(8).unwrap(); + /// let mut queue = wtx::misc::Deque::with_capacity(8).unwrap(); /// queue.push_front(3).unwrap(); /// queue.push_back(1).unwrap(); /// queue.push_back(2).unwrap(); @@ -114,7 +160,7 @@ impl Queue { /// Mutable version of [`Self::as_slices`]. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_front(3); /// queue.push_back(1); /// queue.push_back(2); @@ -140,10 +186,102 @@ impl Queue { *tail = 0; } + /// Appends elements to the back of the instance so that the current length is equal to `bp`. + /// + /// Does nothing if the calculated length is equal or less than the current length. + /// + /// ```rust + /// let mut queue = wtx::misc::Deque::new(); + /// queue.expand_back(wtx::misc::BufferMode::Len(4), 1); + /// assert_eq!(queue.as_slices(), (&[1, 1, 1, 1][..], &[][..])); + /// ``` + #[inline(always)] + pub fn expand_back(&mut self, bm: BufferMode, value: T) -> Result + where + T: Clone, + { + let len = self.data.len(); + let Some((additional, new_len)) = bm.params(len) else { + return Ok(0); + }; + let rr = self.prolong_back(additional)?; + // SAFETY: Elements were allocated + unsafe { + self.expand(additional, rr.begin, new_len, value); + } + Ok(additional) + } + + /// Prepends elements to the front of the instance so that the current length is equal to `bp`. + /// + /// Does nothing if the calculated length is equal or less than the current length. + /// + /// ```rust + /// let mut queue = wtx::misc::Deque::new(); + /// queue.expand_front(wtx::misc::BufferMode::Len(4), 1); + /// assert_eq!(queue.as_slices(), (&[1, 1, 1, 1][..], &[][..])); + /// ``` + #[inline(always)] + pub fn expand_front(&mut self, bp: BufferMode, value: T) -> Result<(usize, usize), DequeueError> + where + T: Clone, + { + let len = self.data.len(); + let Some((additional, new_len)) = bp.params(len) else { + return Ok((0, 0)); + }; + let rr = self.prolong_front(additional)?; + // SAFETY: Elements were allocated + unsafe { + self.expand(additional, rr.begin, new_len, value); + } + Ok((additional, rr.head_shift)) + } + + /// Appends all elements of the iterator. + /// + /// ```rust + /// let mut queue = wtx::misc::Deque::new(); + /// queue.extend_back_from_iter([1, 2]); + /// assert_eq!(queue.len(), 2); + /// ``` + #[inline] + pub fn extend_back_from_iter( + &mut self, + ii: impl IntoIterator, + ) -> Result<(), DequeueError> { + let iter = ii.into_iter(); + let _ = self.reserve_back(iter.size_hint().0)?; + for elem in iter { + self.push_back(elem)?; + } + Ok(()) + } + + /// Prepends all elements of the iterator. + /// + /// ```rust + /// let mut queue = wtx::misc::Deque::new(); + /// queue.extend_front_from_iter([1, 2]); + /// assert_eq!(queue.len(), 2); + /// ``` + #[inline] + pub fn extend_front_from_iter( + &mut self, + ii: impl IntoIterator, + ) -> Result<(), DequeueError> { + let iter = ii.into_iter(); + let _ = self.reserve_front(iter.size_hint().0)?; + for elem in iter { + self.push_front(elem)?; + } + Ok(()) + } + /// Provides a reference to the element at the given index. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_back(1); /// queue.push_back(3); /// assert_eq!(queue.get(0), Some(&1)); @@ -153,7 +291,7 @@ impl Queue { if idx >= self.data.len() { return None; } - idx = wrap_add(self.data.capacity(), self.head, idx); + idx = wrap_add_idx(self.data.capacity(), self.head, idx); // SAFETY: `idx` points to valid memory let rslt = unsafe { self.data.as_ptr().add(idx) }; // SAFETY: `idx` points to valid memory @@ -163,7 +301,7 @@ impl Queue { /// Mutable version of [`Self::get`]. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_back(1); /// queue.push_back(3); /// assert_eq!(queue.get_mut(0), Some(&mut 1)); @@ -173,7 +311,7 @@ impl Queue { if idx >= self.data.len() { return None; } - idx = wrap_add(self.data.capacity(), self.head, idx); + idx = wrap_add_idx(self.data.capacity(), self.head, idx); // SAFETY: `idx` points to valid memory let rslt = unsafe { self.data.as_ptr_mut().add(idx) }; // SAFETY: `idx` points to valid memory @@ -183,7 +321,7 @@ impl Queue { /// Returns a front-to-back iterator. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_back(1); /// queue.push_front(3); /// let mut iter = queue.iter(); @@ -200,7 +338,7 @@ impl Queue { /// Mutable version of [`Self::iter`]. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_back(1); /// queue.push_front(3); /// let mut iter = queue.iter_mut(); @@ -229,7 +367,7 @@ impl Queue { /// Removes the last element from the queue and returns it, or `None` if it is empty. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_back(1); /// queue.push_back(3); /// queue.pop_back(); @@ -238,7 +376,7 @@ impl Queue { #[inline] pub fn pop_back(&mut self) -> Option { let new_len = self.data.len().checked_sub(1)?; - let curr_tail = wrap_sub(self.data.capacity(), self.tail, 1); + let curr_tail = wrap_sub_idx(self.data.capacity(), self.tail, 1); self.tail = curr_tail; // SAFETY: is within bounds unsafe { @@ -253,7 +391,7 @@ impl Queue { /// Removes the first element and returns it, or [`Option::None`] if the queue is empty. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_back(1); /// queue.push_back(3); /// queue.pop_front(); @@ -263,7 +401,7 @@ impl Queue { pub fn pop_front(&mut self) -> Option { let new_len = self.data.len().checked_sub(1)?; let prev_head = self.head; - self.head = wrap_add(self.data.capacity(), prev_head, 1); + self.head = wrap_add_idx(self.data.capacity(), prev_head, 1); // SAFETY: is within bounds unsafe { self.data.set_len(new_len); @@ -277,17 +415,17 @@ impl Queue { /// Appends an element to the back of the queue. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_back(1); /// queue.push_back(3); /// assert_eq!(queue.as_slices(), (&[1, 3][..], &[][..])); /// ``` #[inline] - pub fn push_back(&mut self, value: T) -> Result<(), QueueError> { - let _ = self.reserve_back(1).map_err(|_err| QueueError::PushFrontOverflow)?; + pub fn push_back(&mut self, value: T) -> Result<(), DequeueError> { + let _ = self.reserve_back(1).map_err(|_err| DequeueError::PushFrontOverflow)?; let len = self.data.len(); let tail = self.tail; - self.tail = wrap_add(self.data.capacity(), tail, 1); + self.tail = wrap_add_idx(self.data.capacity(), tail, 1); // SAFETY: `idx` is within bounds let dst = unsafe { self.data.as_ptr_mut().add(tail) }; // SAFETY: `dst` points to valid memory @@ -304,16 +442,16 @@ impl Queue { /// Prepends an element to the queue. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_front(1); /// queue.push_front(3); /// assert_eq!(queue.as_slices(), (&[3, 1][..], &[][..])); /// ``` #[inline] - pub fn push_front(&mut self, value: T) -> Result<(), QueueError> { - let _ = self.reserve_front(1).map_err(|_err| QueueError::PushFrontOverflow)?; + pub fn push_front(&mut self, value: T) -> Result<(), DequeueError> { + let _ = self.reserve_front(1).map_err(|_err| DequeueError::PushFrontOverflow)?; let len = self.data.len(); - self.head = wrap_sub(self.data.capacity(), self.head, 1); + self.head = wrap_sub_idx(self.data.capacity(), self.head, 1); // SAFETY: `self.head` points to valid memory let dst = unsafe { self.data.as_ptr_mut().add(self.head) }; // SAFETY: `dst` points to valid memory @@ -330,23 +468,23 @@ impl Queue { /// Reserves capacity for at least additional more elements to be inserted at the back of the /// queue. #[inline(always)] - pub fn reserve_back(&mut self, additional: usize) -> Result { - let tuple = reserve::<_, true>(additional, &mut self.data, &mut self.head, &mut self.tail)?; - Ok(tuple.2) + pub fn reserve_back(&mut self, additional: usize) -> Result { + let rr = reserve::<_, true>(additional, &mut self.data, &mut self.head, &mut self.tail)?; + Ok(rr.head_shift) } /// Reserves capacity for at least additional more elements to be inserted at the front of the /// queue. #[inline(always)] - pub fn reserve_front(&mut self, additional: usize) -> Result { - let tuple = reserve::<_, false>(additional, &mut self.data, &mut self.head, &mut self.tail)?; - Ok(tuple.2) + pub fn reserve_front(&mut self, additional: usize) -> Result { + let rr = reserve::<_, false>(additional, &mut self.data, &mut self.head, &mut self.tail)?; + Ok(rr.head_shift) } /// Shortens the queue, keeping the first `new_len` elements. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_front(1); /// queue.push_front(3); /// queue.push_back(5); @@ -405,7 +543,7 @@ impl Queue { /// Shortens the queue, keeping the last `new_len` elements. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_front(1); /// queue.push_front(3); /// queue.push_back(5); @@ -460,17 +598,110 @@ impl Queue { self.data.set_len(new_len); } } + + #[inline] + pub(crate) fn head(&self) -> usize { + self.head + } + + #[inline] + pub(crate) fn tail(&self) -> usize { + self.tail + } + + #[inline] + unsafe fn expand(&mut self, additional: usize, begin: usize, new_len: usize, value: T) + where + T: Clone, + { + // SAFETY: it is up to the caller to pass valid elements and enough allocated capacity + let ptr = unsafe { self.data.as_ptr_mut().add(begin) }; + // SAFETY: it is up to the caller to pass valid elements and enough allocated capacity + unsafe { + slice::from_raw_parts_mut(ptr, additional).fill(value); + } + // SAFETY: it is up to the caller to pass valid elements and enough allocated capacity + unsafe { + self.data.set_len(new_len); + } + } + + #[inline] + fn prolong_back(&mut self, additional: usize) -> Result { + let rr = reserve::<_, true>(additional, &mut self.data, &mut self.head, &mut self.tail)?; + self.tail = rr.begin.wrapping_add(additional); + Ok(rr) + } + + #[inline] + fn prolong_front(&mut self, additional: usize) -> Result { + let rr = reserve::<_, false>(additional, &mut self.data, &mut self.head, &mut self.tail)?; + self.head = rr.begin; + Ok(rr) + } + + #[inline] + fn slices_len<'iter>(iter: impl Iterator) -> usize + where + T: 'iter, + { + let mut len: usize = 0; + for other in iter { + len = len.wrapping_add(other.len()); + } + len + } } -impl Queue +impl Deque where T: Copy, { + /// Iterates over the `others` slices, copies each element, and then appends + /// them to this instance. `others` are traversed in-order. + /// + /// ```rust + /// let mut queue = wtx::misc::Deque::new(); + /// queue.push_back(4); + /// queue.extend_back_from_copyable_slices([&[2, 3][..]]); + /// queue.extend_back_from_copyable_slices([&[0, 1][..], &[1][..]]); + /// assert_eq!(queue.as_slices(), (&[4, 2, 3, 0, 1, 1][..], &[][..])); + /// ``` + #[inline] + pub fn extend_back_from_copyable_slices<'iter, I>( + &mut self, + others: I, + ) -> Result + where + I: IntoIterator, + I::IntoIter: Clone, + T: 'iter, + { + let iter = others.into_iter(); + let others_len = Self::slices_len(iter.clone()); + let rr = self.prolong_back(others_len)?; + let mut shift = rr.begin; + for other in iter { + // SAFETY: `self.head` points to valid memory + let dst = unsafe { self.data.as_ptr_mut().add(shift) }; + // SAFETY: `dst` points to valid memory + unsafe { + ptr::copy_nonoverlapping(other.as_ptr(), dst, other.len()); + } + shift = shift.wrapping_add(other.len()); + } + // SAFETY: is within bounds + unsafe { + self.data.set_len(self.data.len().wrapping_add(others_len)); + } + Ok(others_len) + } + /// Iterates over the `others` slices, copies each element, and then prepends - /// it to this vector. The `others` slices are traversed in-order. + /// them to this instance. `others` are traversed in-order. /// /// ```rust - /// let mut queue = wtx::misc::Queue::new(); + /// let mut queue = wtx::misc::Deque::new(); /// queue.push_front(4); /// queue.extend_front_from_copyable_slices([&[2, 3][..]]); /// queue.extend_front_from_copyable_slices([&[0, 1][..], &[1][..]]); @@ -480,62 +711,65 @@ where pub fn extend_front_from_copyable_slices<'iter, I>( &mut self, others: I, - ) -> Result<(usize, usize), QueueError> + ) -> Result<(usize, usize), DequeueError> where I: IntoIterator, I::IntoIter: Clone, T: 'iter, { - let mut others_len: usize = 0; let iter = others.into_iter(); - for other in iter.clone() { - let Some(curr_len) = others_len.checked_add(other.len()) else { - return Err(QueueError::ExtendFromSliceOverflow); - }; - others_len = curr_len; - } - let tuple = reserve::<_, false>(others_len, &mut self.data, &mut self.head, &mut self.tail)?; - let mut head = tuple.0; - self.head = head; + let others_len = Self::slices_len(iter.clone()); + let rr = self.prolong_front(others_len)?; + let mut shift = rr.begin; for other in iter { // SAFETY: `self.head` points to valid memory - let dst = unsafe { self.data.as_ptr_mut().add(head) }; + let dst = unsafe { self.data.as_ptr_mut().add(shift) }; // SAFETY: `dst` points to valid memory unsafe { ptr::copy_nonoverlapping(other.as_ptr(), dst, other.len()); } - head = head.wrapping_add(other.len()); + shift = shift.wrapping_add(other.len()); } // SAFETY: is within bounds unsafe { self.data.set_len(self.data.len().wrapping_add(others_len)); } - Ok((others_len, tuple.2)) - } - - pub(crate) fn head(&self) -> usize { - self.head + Ok((others_len, rr.head_shift)) } } -impl Debug for Queue +impl Debug for Deque where T: Debug, { #[inline] fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { let (front, back) = self.as_slices(); - f.debug_struct("Queue").field("front", &front).field("back", &back).finish() + f.debug_struct("Deque").field("front", &front).field("back", &back).finish() } } -impl Default for Queue { +impl Default for Deque { #[inline] fn default() -> Self { Self::new() } } +struct ReserveRslt { + /// Starting indexwhere the `additional` number of elements can be inserted. + begin: usize, + /// The number os places the head must be shift. + head_shift: usize, +} + +impl ReserveRslt { + #[inline] + fn new(begin: usize, head_shift: usize) -> Self { + Self { begin, head_shift } + } +} + #[inline] unsafe fn drop_elements(len: usize, offset: usize, ptr: *mut T) { // SAFETY: It is up to the caller to provide a valid pointer with a valid index @@ -579,35 +813,35 @@ fn is_wrapping(head: usize, len: usize, tail: usize) -> bool { } } -/// Returns the starting and ending index where the `additional` number of elements -/// can be inserted. +/// Allocates `additional` capacity for the contiguous insertion of back or front elements. This +/// also means that the free capacity of intersections is not considered. #[inline(always)] fn reserve( additional: usize, data: &mut Vector, head: &mut usize, tail: &mut usize, -) -> Result<(usize, usize, usize), QueueError> { +) -> Result { let len = data.len(); let prev_cap = data.capacity(); - data.reserve(additional).map_err(|_err| QueueError::ReserveOverflow)?; + data.reserve(additional).map_err(|_err| DequeueError::ReserveOverflow)?; let curr_cap = data.capacity(); let prev_head = prev_cap.min(*head); let prev_tail = prev_cap.min(*tail); if len == 0 { return Ok(if IS_BACK { - (0, additional, 0) + ReserveRslt::new(0, 0) } else { - (curr_cap.wrapping_sub(additional), curr_cap, 0) + ReserveRslt::new(curr_cap.wrapping_sub(additional), 0) }); } if is_wrapping(prev_head, len, prev_tail) { let free_slots = prev_head.wrapping_sub(prev_tail); if free_slots >= additional { return Ok(if IS_BACK { - (prev_tail, prev_tail.wrapping_add(additional), 0) + ReserveRslt::new(prev_tail, 0) } else { - (prev_head.wrapping_sub(additional), prev_head, 0) + ReserveRslt::new(prev_head.wrapping_sub(additional), 0) }); } let front_len = prev_cap.wrapping_sub(prev_head); @@ -621,20 +855,20 @@ fn reserve( ptr::copy(src, dst, front_len); } *head = curr_head; - if IS_BACK { - Ok((prev_tail, prev_tail.wrapping_add(additional), 0)) + Ok(if IS_BACK { + ReserveRslt::new(prev_tail, 0) } else { - Ok((curr_head.wrapping_sub(additional), curr_head, curr_cap.wrapping_sub(prev_cap))) - } + ReserveRslt::new(curr_head.wrapping_sub(additional), curr_cap.wrapping_sub(prev_cap)) + }) } else { let left_free = prev_head; let right_free = curr_cap.wrapping_sub(prev_tail); if IS_BACK { if right_free >= additional { - return Ok((prev_tail, prev_tail.wrapping_add(additional), 0)); + return Ok(ReserveRslt::new(prev_tail, 0)); } if right_free == 0 && left_free >= additional { - return Ok((0, additional, 0)); + return Ok(ReserveRslt::new(0, 0)); } // SAFETY: `prev_head` is equal or less than the current capacity let src = unsafe { data.as_ptr_mut().add(prev_head) }; @@ -645,13 +879,13 @@ fn reserve( let curr_tail = len; *head = 0; *tail = curr_tail; - Ok((curr_tail, curr_tail.wrapping_add(additional), 0)) + Ok(ReserveRslt::new(curr_tail, 0)) } else { if left_free >= additional { - return Ok((prev_head.wrapping_sub(additional), prev_head, 0)); + return Ok(ReserveRslt::new(prev_head.wrapping_sub(additional), 0)); } if left_free == 0 && right_free >= additional { - return Ok((curr_cap.wrapping_sub(additional), curr_cap, 0)); + return Ok(ReserveRslt::new(curr_cap.wrapping_sub(additional), 0)); } let curr_head = curr_cap.wrapping_sub(len); // SAFETY: `prev_head` is equal or less than the current capacity @@ -664,14 +898,14 @@ fn reserve( } *head = curr_head; *tail = curr_cap; - Ok((curr_head.wrapping_sub(additional), curr_head, right_free)) + Ok(ReserveRslt::new(curr_head.wrapping_sub(additional), right_free)) } } } #[inline] -fn wrap_add(capacity: usize, idx: usize, value: usize) -> usize { - wrap_idx(idx.wrapping_add(value), capacity) +fn wrap_add_idx(capacity: usize, idx: usize, offset: usize) -> usize { + wrap_idx(idx.wrapping_add(offset), capacity) } #[inline] @@ -680,10 +914,6 @@ fn wrap_idx(idx: usize, cap: usize) -> usize { } #[inline] -fn wrap_sub(capacity: usize, idx: usize, value: usize) -> usize { - #[inline] - fn wrap_idx(idx: usize, cap: usize) -> usize { - idx.checked_sub(cap).unwrap_or(idx) - } - wrap_idx(idx.wrapping_sub(value).wrapping_add(capacity), capacity) +fn wrap_sub_idx(capacity: usize, idx: usize, offset: usize) -> usize { + wrap_idx(idx.wrapping_sub(offset).wrapping_add(capacity), capacity) } diff --git a/wtx/src/misc/queue/proptest.rs b/wtx/src/misc/deque/kani.rs similarity index 87% rename from wtx/src/misc/queue/proptest.rs rename to wtx/src/misc/deque/kani.rs index 0616fd1c..5e62de6f 100644 --- a/wtx/src/misc/queue/proptest.rs +++ b/wtx/src/misc/deque/kani.rs @@ -1,9 +1,10 @@ -use crate::misc::Queue; -use alloc::{collections::VecDeque, vec::Vec}; +use crate::misc::Deque; +use alloc::collections::VecDeque; -#[test_strategy::proptest] -fn queue(bytes: Vec) { - let mut queue = Queue::with_capacity(bytes.len()).unwrap(); +#[kani::proof] +fn queue() { + let bytes = kani::vec::any_vec::(); + let mut queue = Deque::with_capacity(bytes.len()).unwrap(); let mut vec_deque = VecDeque::with_capacity(bytes.len()); for byte in bytes.iter().copied() { diff --git a/wtx/src/misc/queue/tests.rs b/wtx/src/misc/deque/tests.rs similarity index 81% rename from wtx/src/misc/queue/tests.rs rename to wtx/src/misc/deque/tests.rs index 2c541d49..4c6dfb06 100644 --- a/wtx/src/misc/queue/tests.rs +++ b/wtx/src/misc/deque/tests.rs @@ -1,8 +1,8 @@ -use crate::misc::{queue::is_wrapping, Queue}; +use crate::misc::{deque::is_wrapping, Deque}; #[test] fn as_slices() { - let mut queue = Queue::with_capacity(4).unwrap(); + let mut queue = Deque::with_capacity(4).unwrap(); queue.push_front(1).unwrap(); queue.push_back(5).unwrap(); assert_eq!(queue.as_slices(), (&[1][..], &[5][..])); @@ -10,7 +10,7 @@ fn as_slices() { #[test] fn clear() { - let mut queue = Queue::with_capacity(1).unwrap(); + let mut queue = Deque::with_capacity(1).unwrap(); assert_eq!(queue.len(), 0); queue.push_front(1).unwrap(); assert_eq!(queue.len(), 1); @@ -20,7 +20,7 @@ fn clear() { #[test] fn get() { - let mut queue = Queue::with_capacity(1).unwrap(); + let mut queue = Deque::with_capacity(1).unwrap(); assert_eq!(queue.get(0), None); assert_eq!(queue.get_mut(0), None); queue.push_front(1).unwrap(); @@ -43,7 +43,7 @@ fn heads_tails_and_slices() { fn impossible_instances() { // . . . . H (4-5) { - let mut queue = Queue::with_exact_capacity(5).unwrap(); + let mut queue = Deque::with_exact_capacity(5).unwrap(); queue.push_back(1).unwrap(); queue.push_back(1).unwrap(); queue.push_back(1).unwrap(); @@ -57,7 +57,7 @@ fn impossible_instances() { } // H * * * T (0-5) { - let mut queue = Queue::with_exact_capacity(5).unwrap(); + let mut queue = Deque::with_exact_capacity(5).unwrap(); queue.push_back(1).unwrap(); queue.push_back(2).unwrap(); queue.push_back(3).unwrap(); @@ -124,7 +124,7 @@ fn pop_front() { #[test] fn push_front() { - let mut queue = Queue::with_capacity(1).unwrap(); + let mut queue = Deque::with_capacity(1).unwrap(); assert_eq!(queue.len(), 0); queue.push_front(1).unwrap(); assert_eq!(queue.len(), 1); @@ -132,7 +132,7 @@ fn push_front() { #[test] fn push_when_full() { - let mut bq = Queue::with_capacity(5).unwrap(); + let mut bq = Deque::with_capacity(5).unwrap(); bq.push_front(0).unwrap(); bq.push_front(1).unwrap(); bq.push_front(2).unwrap(); @@ -147,7 +147,7 @@ fn push_when_full() { #[test] fn reserve() { - let mut queue = Queue::::new(); + let mut queue = Deque::::new(); assert_eq!(queue.capacity(), 0); let _ = queue.reserve_back(10).unwrap(); assert!(queue.capacity() >= 10); @@ -156,29 +156,29 @@ fn reserve() { } fn instances( - single_begin: impl FnOnce(&mut Queue) -> (usize, usize, &'static [i32], &'static [i32]), - single_end: impl FnOnce(&mut Queue) -> (usize, usize, &'static [i32], &'static [i32]), - single_both_sides: impl FnOnce(&mut Queue) -> (usize, usize, &'static [i32], &'static [i32]), - full_begin: impl FnOnce(&mut Queue) -> (usize, usize, &'static [i32], &'static [i32]), - full_end: impl FnOnce(&mut Queue) -> (usize, usize, &'static [i32], &'static [i32]), + single_begin: impl FnOnce(&mut Deque) -> (usize, usize, &'static [i32], &'static [i32]), + single_end: impl FnOnce(&mut Deque) -> (usize, usize, &'static [i32], &'static [i32]), + single_both_sides: impl FnOnce(&mut Deque) -> (usize, usize, &'static [i32], &'static [i32]), + full_begin: impl FnOnce(&mut Deque) -> (usize, usize, &'static [i32], &'static [i32]), + full_end: impl FnOnce(&mut Deque) -> (usize, usize, &'static [i32], &'static [i32]), ) { // H . . . . (0-1) { - let mut queue = Queue::with_exact_capacity(5).unwrap(); + let mut queue = Deque::with_exact_capacity(5).unwrap(); queue.push_back(1).unwrap(); let (head, tail, front, back) = single_begin(&mut queue); verify_instance(&queue, head, tail, front, back); } // . . . . H (4-0) { - let mut queue = Queue::with_exact_capacity(5).unwrap(); + let mut queue = Deque::with_exact_capacity(5).unwrap(); queue.push_front(1).unwrap(); let (head, tail, front, back) = single_end(&mut queue); verify_instance(&queue, head, tail, front, back); } // T . . . H (4-1) { - let mut queue = Queue::with_exact_capacity(5).unwrap(); + let mut queue = Deque::with_exact_capacity(5).unwrap(); queue.push_back(2).unwrap(); queue.push_front(1).unwrap(); let (head, tail, front, back) = single_both_sides(&mut queue); @@ -187,7 +187,7 @@ fn instances( } // H * * * T (0-0) { - let mut queue = Queue::with_exact_capacity(5).unwrap(); + let mut queue = Deque::with_exact_capacity(5).unwrap(); queue.push_front(5).unwrap(); queue.push_front(4).unwrap(); queue.push_front(3).unwrap(); @@ -202,7 +202,7 @@ fn instances( } // * * * T H (4-4) { - let mut queue = Queue::with_exact_capacity(5).unwrap(); + let mut queue = Deque::with_exact_capacity(5).unwrap(); queue.push_front(1).unwrap(); queue.push_back(2).unwrap(); queue.push_back(3).unwrap(); @@ -214,7 +214,7 @@ fn instances( } #[track_caller] -fn verify_instance(queue: &Queue, head: usize, tail: usize, front: &[i32], back: &[i32]) { +fn verify_instance(queue: &Deque, head: usize, tail: usize, front: &[i32], back: &[i32]) { assert_eq!((queue.head, queue.tail, queue.as_slices()), (head, tail, (front, back))); assert_eq!(queue.len(), front.len() + back.len()); if is_wrapping(queue.head, queue.data.len(), queue.tail) { diff --git a/wtx/src/misc/filled_buffer.rs b/wtx/src/misc/filled_buffer.rs index d629e3b0..3a51f9e8 100644 --- a/wtx/src/misc/filled_buffer.rs +++ b/wtx/src/misc/filled_buffer.rs @@ -185,12 +185,13 @@ impl std::io::Write for FilledBuffer { } } -#[cfg(all(feature = "_proptest", test))] -mod proptest { +#[cfg(kani)] +mod kani { use crate::misc::FilledBuffer; - #[test_strategy::proptest] - fn reserve_is_allocation(reserve: u8) { + #[kani::proof] + fn reserve_is_allocation() { + let reserve: u8 = kani::any(); let mut vec = FilledBuffer::_new(); vec._reserve(reserve.into()).unwrap(); assert!(vec._capacity() >= reserve.into()); diff --git a/wtx/src/misc/interspace.rs b/wtx/src/misc/interspace.rs new file mode 100644 index 00000000..20f180f6 --- /dev/null +++ b/wtx/src/misc/interspace.rs @@ -0,0 +1,119 @@ +// FIXME(stable): iter_intersperse + +use core::iter::{Fuse, FusedIterator}; + +/// An iterator adapter that places a separator between all elements. +#[derive(Clone, Debug)] +pub struct Intersperse +where + I: Iterator, + I::Item: Clone, +{ + started: bool, + separator: I::Item, + next_item: Option, + iter: Fuse, +} + +impl Intersperse +where + I: Iterator, + I::Item: Clone, +{ + /// Creates a new iterator which places a copy of separator between adjacent items of + /// the original iterator. + #[inline] + pub fn new(iter: I, separator: I::Item) -> Self { + Self { started: false, separator, next_item: None, iter: iter.fuse() } + } +} + +impl FusedIterator for Intersperse +where + I: FusedIterator, + I::Item: Clone, +{ +} + +impl Iterator for Intersperse +where + I: Iterator, + I::Item: Clone, +{ + type Item = I::Item; + + #[inline] + fn fold(self, init: B, f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + let separator = self.separator; + intersperse_fold(self.iter, init, f, move || separator.clone(), self.started, self.next_item) + } + + #[inline] + fn next(&mut self) -> Option { + if self.started { + if let Some(v) = self.next_item.take() { + Some(v) + } else { + let next_item = self.iter.next(); + next_item.is_some().then(|| { + self.next_item = next_item; + self.separator.clone() + }) + } + } else { + self.started = true; + self.iter.next() + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + intersperse_size_hint(&self.iter, self.started, self.next_item.is_some()) + } +} + +#[inline] +fn intersperse_fold( + mut iter: I, + init: B, + mut f: F, + mut separator: G, + started: bool, + mut next_item: Option, +) -> B +where + I: Iterator, + F: FnMut(B, I::Item) -> B, + G: FnMut() -> I::Item, +{ + let mut accum = init; + + let first = if started { next_item.take() } else { iter.next() }; + if let Some(x) = first { + accum = f(accum, x); + } + + iter.fold(accum, |mut elem, x| { + elem = f(elem, separator()); + elem = f(elem, x); + elem + }) +} + +#[inline] +fn intersperse_size_hint(iter: &I, started: bool, next_is_some: bool) -> (usize, Option) +where + I: Iterator, +{ + let (lo, hi) = iter.size_hint(); + ( + lo.saturating_sub((!started).into()).saturating_add(next_is_some.into()).saturating_add(lo), + hi.and_then(|elem| { + elem.saturating_sub((!started).into()).saturating_add(next_is_some.into()).checked_add(elem) + }), + ) +} diff --git a/wtx/src/misc/mem_transfer.rs b/wtx/src/misc/mem_transfer.rs index 2200daff..17a3bb56 100644 --- a/wtx/src/misc/mem_transfer.rs +++ b/wtx/src/misc/mem_transfer.rs @@ -58,22 +58,18 @@ where unsafe { slice.get_unchecked_mut(..new_len) } } -#[cfg(all(feature = "_proptest", test))] -mod proptest { - use crate::misc::{Vector, _shift_copyable_chunks}; - use core::ops::Range; +#[cfg(kani)] +mod kani { + use crate::misc::_shift_copyable_chunks; + use alloc::vec::Vec; - #[test_strategy::proptest] - fn shift_bytes(mut data: Vector, range: Range) { - let mut begin: usize = range.start.into(); - let mut end: usize = range.end.into(); - let mut data_clone = data.clone(); - begin = begin.min(data.len()); - end = end.min(data.len()); - let rslt = _shift_copyable_chunks(0, &mut data, [begin..end]); - data_clone.rotate_left(begin); - data_clone.truncate(rslt.len()); - assert_eq!(rslt, data_clone.as_ref()); + #[kani::proof] + fn shift_bytes() { + let begin = kani::any(); + let tuples = kani::vec::any_vec::<(usize, usize), 128>(); + let ranges: Vec<_> = tuples.into_iter().map(|el| el.0..el.1).collect(); + let mut slice = kani::vec::any_vec::(); + let _ = _shift_copyable_chunks(begin, &mut slice, ranges.into_iter()); } } diff --git a/wtx/src/misc/optimization.rs b/wtx/src/misc/optimization.rs index 02e4c11a..a2dae00a 100644 --- a/wtx/src/misc/optimization.rs +++ b/wtx/src/misc/optimization.rs @@ -44,7 +44,7 @@ pub fn bytes_rsplit1(bytes: &[u8], elem: u8) -> impl Iterator { /// Internally uses `memchr` if the feature is active. #[inline] -pub fn bytes_split1(bytes: &[u8], elem: u8) -> impl Iterator { +pub fn bytes_split1(bytes: &[u8], elem: u8) -> impl Clone + Iterator { #[cfg(feature = "memchr")] return memchr::memchr_iter(elem, bytes).chain(core::iter::once(bytes.len())).scan( 0, diff --git a/wtx/src/misc/stream/tokio.rs b/wtx/src/misc/stream/tokio.rs index cbd5138b..5aa476c3 100644 --- a/wtx/src/misc/stream/tokio.rs +++ b/wtx/src/misc/stream/tokio.rs @@ -31,6 +31,14 @@ impl StreamReader for TcpStream { } } +#[cfg(unix)] +impl StreamReader for tokio::net::UnixStream { + #[inline] + async fn read(&mut self, bytes: &mut [u8]) -> crate::Result { + Ok(::read(self, bytes).await?) + } +} + impl StreamWriter for OwnedWriteHalf { #[inline] async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { @@ -75,3 +83,18 @@ impl StreamWriter for TcpStream { Ok(()) } } + +#[cfg(unix)] +impl StreamWriter for tokio::net::UnixStream { + #[inline] + async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> { + ::write_all(self, bytes).await?; + Ok(()) + } + + #[inline] + async fn write_all_vectored(&mut self, bytes: &[&[u8]]) -> crate::Result<()> { + _local_write_all_vectored!(bytes, self, |io_slices| self.write_vectored(io_slices).await); + Ok(()) + } +} diff --git a/wtx/src/misc/tuple_impls.rs b/wtx/src/misc/tuple_impls.rs index a477566f..8405abd5 100644 --- a/wtx/src/misc/tuple_impls.rs +++ b/wtx/src/misc/tuple_impls.rs @@ -46,11 +46,12 @@ macro_rules! impl_0_16 { mod http_server_framework { use crate::{ http::{ - HttpError, Request, ReqResBuffer, Response, StatusCode, - server_framework::{ConnAux, StreamAux, ReqMiddleware, ResMiddleware, PathManagement, PathParams} + HttpError, Response, Request, ReqResBuffer, StatusCode, + server_framework::{ConnAux, Middleware, StreamAux, PathManagement, PathParams} }, misc::{ArrayVector, Vector} }; + use core::ops::ControlFlow; $( impl<$($T,)*> ConnAux for ($($T,)*) @@ -77,34 +78,53 @@ macro_rules! impl_0_16 { } } - impl<$($T,)* CA, ERR, SA> ReqMiddleware for ($($T,)*) + impl<$($T,)* CA, ERR, SA> Middleware for ($($T,)*) where - $($T: ReqMiddleware,)* + $($T: Middleware,)* ERR: From { + type Aux = ($($T::Aux,)*); + #[inline] - async fn apply_req_middleware(&self, _conn_aux: &mut CA, _req: &mut Request, _stream_aux: &mut SA) -> Result<(), ERR> { - $( self.$N.apply_req_middleware(_conn_aux, _req, _stream_aux).await?; )* - Ok(()) + fn aux(&self) -> Self::Aux { + ($(self.$N.aux(),)*) } - } - impl<$($T,)* CA, ERR, SA> ResMiddleware for ($($T,)*) - where - $($T: ResMiddleware,)* - ERR: From - { #[inline] - async fn apply_res_middleware(&self, _conn_aux: &mut CA, mut _res: Response<&mut ReqResBuffer>, _stream_aux: &mut SA) -> Result<(), ERR> { + async fn req( + &self, + _conn_aux: &mut CA, + _mx_aux: &mut Self::Aux, + _req: &mut Request, + _stream_aux: &mut SA, + ) -> Result, ERR> { + $({ + if let ControlFlow::Break(status_code) = self.$N.req(_conn_aux, &mut _mx_aux.$N, _req, _stream_aux).await? { + return Ok(ControlFlow::Break(status_code)); + } + })* + Ok(ControlFlow::Continue(())) + } + + #[inline] + async fn res( + &self, + _conn_aux: &mut CA, + _mx_aux: &mut Self::Aux, + _res: Response<&mut ReqResBuffer>, + _stream_aux: &mut SA, + ) -> Result, ERR> { $({ let local_res = Response { rrd: &mut *_res.rrd, status_code: _res.status_code, version: _res.version, }; - self.$N.apply_res_middleware(_conn_aux, local_res, _stream_aux).await?; + if let ControlFlow::Break(status_code) = self.$N.res(_conn_aux, &mut _mx_aux.$N, local_res, _stream_aux).await? { + return Ok(ControlFlow::Break(status_code)); + } })* - Ok(()) + Ok(ControlFlow::Continue(())) } } diff --git a/wtx/src/misc/uri.rs b/wtx/src/misc/uri.rs index 996d2908..80112cc9 100644 --- a/wtx/src/misc/uri.rs +++ b/wtx/src/misc/uri.rs @@ -19,6 +19,14 @@ pub type UriString = Uri; /// ```txt /// foo://user:password@hostname:80/path?query=value#hash /// ``` +// +// foo:// | user:password@hostname:80 | /path |?query=value#hash +// | | | +// | | |-> query_start +// | | +// | |---------> href_start +// | +// |-------------------------------------> authority_start #[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] pub struct Uri where @@ -112,6 +120,17 @@ where (self.hostname(), self.port().unwrap_or_default()) } + /// Returns the number of characters. + /// + /// ```rust + /// let uri = wtx::misc::Uri::new("foo://user:password@hostname:80/path?query=value#hash"); + /// assert_eq!(uri.len(), 53); + /// ``` + #[inline] + pub fn len(&self) -> usize { + self.uri.lease().len() + } + /// /// /// ```rust diff --git a/wtx/src/misc/vector.rs b/wtx/src/misc/vector.rs index fcfbff03..71818ede 100644 --- a/wtx/src/misc/vector.rs +++ b/wtx/src/misc/vector.rs @@ -46,8 +46,7 @@ impl From for u8 { impl core::error::Error for VectorError {} /// A wrapper around the std's vector. -#[cfg_attr(feature = "test-strategy", derive(test_strategy::Arbitrary))] -#[cfg_attr(feature = "test-strategy", arbitrary(bound(T: proptest::arbitrary::Arbitrary + 'static)))] +//#[cfg_attr(kani, derive(kani::Arbitrary))] #[derive(Clone, Eq, PartialEq)] #[repr(transparent)] pub struct Vector { @@ -195,7 +194,7 @@ impl Vector { self.data.drain(range) } - /// Clones and appends all elements in the iterator. + /// Appends all elements of the iterator. /// /// ```rust /// let mut vec = wtx::misc::Vector::new(); @@ -711,24 +710,35 @@ mod cl_aux { } } -#[cfg(all(feature = "_proptest", test))] -mod _proptest { +#[cfg(kani)] +mod kani { use crate::misc::Vector; - use alloc::vec::Vec; - #[test_strategy::proptest] - fn insert(elem: u8, idx: usize, mut vec: Vec) { + #[kani::proof] + fn extend_from_iter() { + let mut from = Vector::from_vec(kani::vec::any_vec::()); + let to = kani::vec::any_vec::(); + from.extend_from_iter(to.into_iter()).unwrap(); + } + + #[kani::proof] + fn insert() { + let elem = kani::any(); + let idx = kani::any(); + let mut vec = kani::vec::any_vec::(); let mut vector = Vector::from_vec(vec.clone()); if idx > vec.len() { - return Ok(()); + return; } vec.insert(idx, elem); vector.insert(idx, elem).unwrap(); assert_eq!(vec.as_slice(), vector.as_slice()); } - #[test_strategy::proptest] - fn push(elem: u8, mut vec: Vec) { + #[kani::proof] + fn push() { + let elem = kani::any(); + let mut vec = kani::vec::any_vec::(); let mut vector = Vector::from_vec(vec.clone()); vec.push(elem); vector.push(elem).unwrap(); diff --git a/wtx/src/pool/resource_manager.rs b/wtx/src/pool/resource_manager.rs index 615744c3..fd239068 100644 --- a/wtx/src/pool/resource_manager.rs +++ b/wtx/src/pool/resource_manager.rs @@ -102,6 +102,7 @@ pub(crate) mod database { pub struct PostgresRM { _certs: Option<&'static [u8]>, error: PhantomData E>, + max_stmts: usize, rng: Xorshift64Sync, stream: PhantomData, uri: String, @@ -121,7 +122,7 @@ pub(crate) mod database { use crate::{ database::{ client::postgres::{Executor, ExecutorBuffer}, - Executor as _, + Executor as _, DEFAULT_MAX_STMTS, }, misc::{simple_seed, Xorshift64Sync}, pool::{PostgresRM, ResourceManager}, @@ -137,6 +138,7 @@ pub(crate) mod database { Self { _certs: None, error: PhantomData, + max_stmts: DEFAULT_MAX_STMTS, rng: Xorshift64Sync::from(simple_seed()), stream: PhantomData, uri, @@ -156,10 +158,9 @@ pub(crate) mod database { #[inline] async fn create(&self, _: &Self::CreateAux) -> Result { executor!(&self.uri, |config, uri| { - let eb = ExecutorBuffer::with_default_params(&mut &self.rng)?; Executor::connect( &config, - eb, + ExecutorBuffer::new(self.max_stmts, &mut &self.rng), &mut &self.rng, TcpStream::connect(uri.hostname_with_implied_port()).await.map_err(Into::into)?, ) @@ -177,7 +178,7 @@ pub(crate) mod database { _: &Self::RecycleAux, resource: &mut Self::Resource, ) -> Result<(), Self::Error> { - let mut buffer = ExecutorBuffer::_empty(); + let mut buffer = ExecutorBuffer::new(self.max_stmts, &mut &self.rng); mem::swap(&mut buffer, &mut resource.eb); *resource = executor!(&self.uri, |config, uri| { Executor::connect( @@ -197,7 +198,7 @@ pub(crate) mod database { use crate::{ database::{ client::postgres::{Executor, ExecutorBuffer}, - Executor as _, + Executor as _, DEFAULT_MAX_STMTS, }, misc::{simple_seed, TokioRustlsConnector, Xorshift64Sync}, pool::{PostgresRM, ResourceManager}, @@ -214,6 +215,7 @@ pub(crate) mod database { Self { _certs: certs, error: PhantomData, + max_stmts: DEFAULT_MAX_STMTS, rng: Xorshift64Sync::from(simple_seed()), stream: PhantomData, uri, @@ -235,7 +237,7 @@ pub(crate) mod database { executor!(&self.uri, |config, uri| { Executor::connect_encrypted( &config, - ExecutorBuffer::with_default_params(&mut &self.rng)?, + ExecutorBuffer::new(self.max_stmts, &mut &self.rng), TcpStream::connect(uri.hostname_with_implied_port()).await.map_err(Into::into)?, &mut &self.rng, |stream| async { @@ -260,12 +262,12 @@ pub(crate) mod database { _: &Self::RecycleAux, resource: &mut Self::Resource, ) -> Result<(), Self::Error> { - let mut buffer = ExecutorBuffer::_empty(); + let mut buffer = ExecutorBuffer::new(self.max_stmts, &mut &self.rng); mem::swap(&mut buffer, &mut resource.eb); *resource = executor!(&self.uri, |config, uri| { Executor::connect_encrypted( &config, - ExecutorBuffer::with_default_params(&mut &self.rng)?, + buffer, TcpStream::connect(uri.hostname_with_implied_port()).await.map_err(Into::into)?, &mut &self.rng, |stream| async { diff --git a/wtx/src/pool/simple_pool.rs b/wtx/src/pool/simple_pool.rs index d8f1a568..d6dec128 100644 --- a/wtx/src/pool/simple_pool.rs +++ b/wtx/src/pool/simple_pool.rs @@ -89,6 +89,15 @@ where } } +#[cfg(feature = "http-server-framework")] +impl crate::http::server_framework::ConnAux for SimplePool { + type Init = Self; + + #[inline] + fn conn_aux(init: Self::Init) -> crate::Result { + Ok(init) + } +} #[cfg(feature = "http-server-framework")] impl crate::http::server_framework::StreamAux for SimplePool { type Init = Self; diff --git a/wtx/src/web_socket.rs b/wtx/src/web_socket.rs index 69c896e7..004e2efd 100644 --- a/wtx/src/web_socket.rs +++ b/wtx/src/web_socket.rs @@ -91,15 +91,13 @@ where { /// Creates a new instance from a stream that supposedly has already completed the handshake. #[inline] - pub fn new( + pub const fn new( nc: NC, no_masking: bool, rng: Xorshift64, stream: S, - mut wsb: WSB, + wsb: WSB, ) -> crate::Result { - wsb.lease_mut().network_buffer._clear_if_following_is_empty(); - wsb.lease_mut().network_buffer._reserve(MAX_HEADER_LEN_USIZE)?; Ok(Self { connection_state: ConnectionState::Open, curr_payload: PayloadTy::None, diff --git a/wtx/src/web_socket/read_frame_info.rs b/wtx/src/web_socket/read_frame_info.rs index c5283adf..ea72c98b 100644 --- a/wtx/src/web_socket/read_frame_info.rs +++ b/wtx/src/web_socket/read_frame_info.rs @@ -1,5 +1,5 @@ use crate::{ - misc::{PartitionedFilledBuffer, Stream, _read_until}, + misc::{PartitionedFilledBuffer, Stream, _read_header}, web_socket::{ compression::NegotiatedCompression, misc::{has_masked_frame, op_code}, @@ -85,26 +85,37 @@ impl ReadFrameInfo { S: Stream, { let buffer = network_buffer._following_rest_mut(); - let first_two = _read_until::<2, S>(buffer, read, 0, stream).await?; + let first_two = _read_header::<0, 2, S>(buffer, read, stream).await?; let tuple = Self::manage_first_two_bytes(first_two, nc)?; let (fin, length_code, masked, op_code, should_decompress) = tuple; - let (mut header_len, payload_len) = match length_code { + let mut mask = None; + let (header_len, payload_len) = match length_code { 126 => { - let payload_len = _read_until::<2, S>(buffer, read, 2, stream).await?; - (4u8, u16::from_be_bytes(payload_len).into()) + let payload_len = _read_header::<2, 2, S>(buffer, read, stream).await?; + if Self::manage_mask::(masked, no_masking)? { + mask = Some(_read_header::<4, 4, S>(buffer, read, stream).await?); + (8, u16::from_be_bytes(payload_len).into()) + } else { + (4, u16::from_be_bytes(payload_len).into()) + } } 127 => { - let payload_len = _read_until::<8, S>(buffer, read, 2, stream).await?; - (10, u64::from_be_bytes(payload_len).try_into()?) + let payload_len = _read_header::<2, 8, S>(buffer, read, stream).await?; + if Self::manage_mask::(masked, no_masking)? { + mask = Some(_read_header::<10, 4, S>(buffer, read, stream).await?); + (14, u64::from_be_bytes(payload_len).try_into()?) + } else { + (10, u64::from_be_bytes(payload_len).try_into()?) + } + } + _ => { + if Self::manage_mask::(masked, no_masking)? { + mask = Some(_read_header::<2, 4, S>(buffer, read, stream).await?); + (6, length_code.into()) + } else { + (2, length_code.into()) + } } - _ => (2, length_code.into()), - }; - let mask = if Self::manage_mask::(masked, no_masking)? { - let rslt = _read_until::<4, S>(buffer, read, header_len.into(), stream).await?; - header_len = header_len.wrapping_add(4); - Some(rslt) - } else { - None }; Self::manage_final_params(fin, op_code, max_payload_len, payload_len)?; Ok(ReadFrameInfo { fin, header_len, mask, op_code, payload_len, should_decompress }) diff --git a/wtx/src/web_socket/unmask.rs b/wtx/src/web_socket/unmask.rs index 15e78e5c..3802799a 100644 --- a/wtx/src/web_socket/unmask.rs +++ b/wtx/src/web_socket/unmask.rs @@ -1,43 +1,52 @@ #[doc = _internal_doc!()] #[inline] pub(crate) fn unmask(bytes: &mut [u8], mut mask: [u8; 4]) { - let (is_128, unmask_chunks_slice) = _simd!( - 512 => (false, _unmask_chunks_slice_512), - 256 => (false, _unmask_chunks_slice_256), - 128 => (true, _unmask_chunks_slice_128), - _ => (false, _unmask_chunks_slice_fallback) + let unmask_chunks_slice = _simd!( + fallback => _unmask_chunks_slice_fallback, + 128 => _unmask_chunks_slice_128, + 256 => _unmask_chunks_slice_256, + 512 => _unmask_chunks_slice_512 ); - // SAFETY: Changing a sequence of `u8` should be fine let (prefix, chunks, suffix) = unsafe { bytes.align_to_mut() }; unmask_u8_slice(prefix, mask, 0); mask.rotate_left(prefix.len() % 4); unmask_chunks_slice(chunks, mask); - unmask_u8_slice(suffix, mask, if is_128 { (chunks.len() % 2).wrapping_mul(2) } else { 0 }); + unmask_u8_slice(suffix, mask, 0); } #[inline] -fn _unmask_chunks_slice_512(bytes: &mut [u64], [a, b, c, d]: [u8; 4]) { - let mask = u64::from_be_bytes([d, c, b, a, d, c, b, a]); - for elem in bytes { - *elem ^= mask; +fn _unmask_chunks_slice_512(slice: &mut [[u8; 64]], [a, b, c, d]: [u8; 4]) { + let mask = [ + a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, + a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, + ]; + for array in slice { + for (array_elem, mask_elem) in array.iter_mut().zip(mask) { + *array_elem ^= mask_elem; + } } } #[inline] -fn _unmask_chunks_slice_256(bytes: &mut [u32], [a, b, c, d]: [u8; 4]) { - let mask = u32::from_be_bytes([d, c, b, a]); - for elem in bytes { - *elem ^= mask; +fn _unmask_chunks_slice_256(slice: &mut [[u8; 32]], [a, b, c, d]: [u8; 4]) { + let mask = [ + a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d, + ]; + for array in slice { + for (array_elem, mask_elem) in array.iter_mut().zip(mask) { + *array_elem ^= mask_elem; + } } } -#[expect(clippy::indexing_slicing, reason = "index will always be in-bounds")] #[inline] -fn _unmask_chunks_slice_128(bytes: &mut [u16], [a, b, c, d]: [u8; 4]) { - let mask = [u16::from_be_bytes([b, a]), u16::from_be_bytes([d, c])]; - for (idx, elem) in bytes.iter_mut().enumerate() { - *elem ^= mask[idx & 1]; +fn _unmask_chunks_slice_128(slice: &mut [[u8; 16]], [a, b, c, d]: [u8; 4]) { + let mask = [a, b, c, d, a, b, c, d, a, b, c, d, a, b, c, d]; + for array in slice { + for (array_elem, mask_elem) in array.iter_mut().zip(mask) { + *array_elem ^= mask_elem; + } } } @@ -65,12 +74,14 @@ mod bench { } } -#[cfg(all(feature = "_proptest", test))] -mod proptest { +#[cfg(kani)] +mod kani { use crate::misc::Vector; - #[test_strategy::proptest] - fn unmask(mut payload: Vector, mask: [u8; 4]) { + #[kani::proof] + fn unmask() { + let mask = kani::any(); + let mut payload = Vector::from(kani::vec::any_vec::()); payload.fill(0); crate::web_socket::unmask::unmask(&mut payload, mask); let expected = Vector::from_iter((0..payload.len()).map(|idx| mask[idx & 3])).unwrap(); diff --git a/wtx/src/web_socket/web_socket_reader.rs b/wtx/src/web_socket/web_socket_reader.rs index 05c6fc99..d09c2f27 100644 --- a/wtx/src/web_socket/web_socket_reader.rs +++ b/wtx/src/web_socket/web_socket_reader.rs @@ -7,11 +7,13 @@ use crate::{ misc::{ from_utf8_basic, from_utf8_ext, BufferMode, CompletionErr, ConnectionState, ExtUtf8Error, FnMutFut, IncompleteUtf8Char, LeaseMut, PartitionedFilledBuffer, Rng, Stream, Vector, + _read_payload, }, web_socket::{ compression::NegotiatedCompression, fill_with_close_code, payload_ty::PayloadTy, read_frame_info::ReadFrameInfo, unmask::unmask, web_socket_writer::manage_normal_frame, CloseCode, Frame, FrameMut, OpCode, WebSocketError, MAX_CONTROL_PAYLOAD_LEN, + MAX_HEADER_LEN_USIZE, }, }; @@ -232,9 +234,7 @@ where RNG: Rng, S: Stream, { - reader_buffer_second.clear(); let first_rfi = loop { - network_buffer._clear_if_following_is_empty(); reader_buffer_first.clear(); let rfi = fetch_frame_from_stream::<_, _, IS_CLIENT>( max_payload_len, @@ -282,6 +282,7 @@ where return Ok((Frame::new(true, rfi.op_code, borrow_checker, nc.rsv1()), payload_ty)); } }; + reader_buffer_second.clear(); if first_rfi.should_decompress { read_continuation_frames::<_, _, _, IS_CLIENT>( connection_state, @@ -332,7 +333,6 @@ fn copy_from_arbitrary_nb_to_rb1( let current_mut = network_buffer._current_mut(); unmask_nb::(current_mut, no_masking, rfi)?; reader_buffer_first.extend_from_copyable_slice(current_mut)?; - network_buffer._clear_if_following_is_empty(); Ok(()) } @@ -431,6 +431,8 @@ where NC: NegotiatedCompression, S: Stream, { + network_buffer._clear_if_following_is_empty(); + network_buffer._reserve(MAX_HEADER_LEN_USIZE)?; let mut read = network_buffer._following_len(); let rfi = ReadFrameInfo::from_stream::<_, _, IS_CLIENT>( max_payload_len, @@ -441,40 +443,11 @@ where stream, ) .await?; - let frame_len = rfi.payload_len.wrapping_add(rfi.header_len.into()); - fetch_payload_from_stream(frame_len, network_buffer, &mut read, stream).await?; - network_buffer._set_indices( - network_buffer._current_end_idx().wrapping_add(rfi.header_len.into()), - rfi.payload_len, - read.wrapping_sub(frame_len), - )?; + let header_len = rfi.header_len.into(); + _read_payload((header_len, rfi.payload_len), network_buffer, &mut read, stream).await?; Ok(rfi) } -#[inline] -async fn fetch_payload_from_stream( - frame_len: usize, - network_buffer: &mut PartitionedFilledBuffer, - read: &mut usize, - stream: &mut S, -) -> crate::Result<()> -where - S: Stream, -{ - network_buffer._reserve(frame_len)?; - for _ in 0..=frame_len { - if *read >= frame_len { - return Ok(()); - } - *read = read.wrapping_add( - stream - .read(network_buffer._following_rest_mut().get_mut(*read..).unwrap_or_default()) - .await?, - ); - } - Err(crate::Error::UnexpectedBufferState) -} - #[inline] async fn read_continuation_frames( connection_state: &mut ConnectionState,