From 3a4cd945e8905833df91809db4cf10f6c51eb8c3 Mon Sep 17 00:00:00 2001 From: rkrishn7 Date: Sun, 31 Mar 2024 19:20:42 -0700 Subject: [PATCH] replace hyper for axum --- .gitignore | 2 + Cargo.lock | 337 ++++++++++++++++++++++++++++++----------- src/kiwi/Cargo.toml | 7 +- src/kiwi/src/config.rs | 16 ++ src/kiwi/src/lib.rs | 1 + src/kiwi/src/main.rs | 11 +- src/kiwi/src/tls.rs | 81 ++++++++++ src/kiwi/src/ws.rs | 157 +++++++++++-------- 8 files changed, 447 insertions(+), 165 deletions(-) create mode 100644 src/kiwi/src/tls.rs diff --git a/.gitignore b/.gitignore index 1d5d251..f93b48f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ /target /*.wasm /*kiwi.yml +/*.crt +/*.key diff --git a/Cargo.lock b/Cargo.lock index 6436495..4e1ad4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "ambient-authority" version = "0.0.2" @@ -141,37 +150,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] -name = "axum" -version = "0.7.4" +name = "aws-lc-rs" +version = "1.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1236b4b292f6c4d6dc34604bb5120d85c3fe1d1aa596bd5cc52ca054d13e7b9e" +checksum = "df33e4a55b03f8780ba55041bc7be91a2a8ec8c03517b0379d2d6c96d2c30d95" dependencies = [ - "async-trait", - "axum-core", - "bytes", - "futures-util", - "http 1.0.0", - "http-body 1.0.0", - "http-body-util", - "hyper 1.2.0", - "hyper-util", - "itoa", - "matchit", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "serde_json", - "serde_path_to_error", - "serde_urlencoded", - "sync_wrapper", - "tokio", - "tower", - "tower-layer", - "tower-service", - "tracing", + "aws-lc-sys", + "mirai-annotations", + "paste", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37ede3d6e360a48436fee127cb81710834407b1ec0c48a001cc29dec9005f73e" +dependencies = [ + "bindgen", + "cmake", + "dunce", + "fs_extra", + "libc", + "paste", ] [[package]] @@ -192,30 +193,6 @@ dependencies = [ "sync_wrapper", "tower-layer", "tower-service", - "tracing", -] - -[[package]] -name = "axum-server" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ad46c3ec4e12f4a4b6835e173ba21c25e484c9d02b49770bf006ce5367c036" -dependencies = [ - "arc-swap", - "bytes", - "futures-util", - "http 1.0.0", - "http-body 1.0.0", - "http-body-util", - "hyper 1.2.0", - "hyper-util", - "pin-project-lite", - "rustls", - "rustls-pemfile 2.1.1", - "tokio", - "tokio-rustls", - "tower", - "tower-service", ] [[package]] @@ -254,6 +231,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +dependencies = [ + "bitflags 2.4.1", + "cexpr", + "clang-sys", + "itertools", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.48", + "which", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -378,6 +378,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -390,6 +399,17 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "clang-sys" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.4.16" @@ -707,6 +727,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "dunce" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" + [[package]] name = "either" version = "1.9.0" @@ -837,6 +863,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "fsevent-sys" version = "4.1.0" @@ -989,6 +1021,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "h2" version = "0.3.24" @@ -1069,6 +1107,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "http" version = "0.2.12" @@ -1356,8 +1403,6 @@ dependencies = [ "arc-swap", "async-stream", "async-trait", - "axum", - "axum-server", "base64 0.21.5", "bytes", "clap", @@ -1377,12 +1422,14 @@ dependencies = [ "rdkafka", "reqwest", "ringbuf", + "rustls-pemfile 2.1.1", "serde", "serde_json", "serde_yaml", "tempfile", "thiserror", "tokio", + "tokio-rustls 0.26.0", "tokio-stream", "tracing", "tracing-subscriber", @@ -1438,6 +1485,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "leb128" version = "0.2.5" @@ -1450,6 +1503,16 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libloading" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" +dependencies = [ + "cfg-if", + "windows-targets 0.52.0", +] + [[package]] name = "libredox" version = "0.0.1" @@ -1510,12 +1573,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" -[[package]] -name = "matchit" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" - [[package]] name = "maybe-owned" version = "0.3.4" @@ -1552,6 +1609,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -1573,6 +1636,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mirai-annotations" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" + [[package]] name = "nanoid" version = "0.4.0" @@ -1612,6 +1681,16 @@ dependencies = [ "libc", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "notify" version = "6.1.1" @@ -1819,6 +1898,16 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "prettyplease" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7" +dependencies = [ + "proc-macro2", + "syn 2.0.48", +] + [[package]] name = "proc-macro-crate" version = "1.3.1" @@ -1971,6 +2060,35 @@ dependencies = [ "smallvec", ] +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + [[package]] name = "reqwest" version = "0.11.26" @@ -2069,10 +2187,25 @@ checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4d6d8ad9f2492485e13453acbb291dd08f64441b6609c491f1c2cd2c6b4fe1" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "rustls-pki-types", + "rustls-webpki 0.102.2", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -2108,6 +2241,18 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.102.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -2214,16 +2359,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_path_to_error" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" -dependencies = [ - "itoa", - "serde", -] - [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2289,6 +2424,12 @@ dependencies = [ "dirs", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -2548,7 +2689,18 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls", + "rustls 0.21.10", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls 0.23.4", + "rustls-pki-types", "tokio", ] @@ -2604,22 +2756,6 @@ dependencies = [ "winnow", ] -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tokio", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "tower-layer" version = "0.3.2" @@ -2638,7 +2774,6 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -3299,9 +3434,9 @@ dependencies = [ "http-body 1.0.0", "http-body-util", "hyper 1.2.0", - "rustls", + "rustls 0.21.10", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tracing", "wasmtime", "wasmtime-wasi", @@ -3390,6 +3525,18 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "wiggle" version = "18.0.2" @@ -3800,6 +3947,12 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" + [[package]] name = "zstd" version = "0.11.2+zstd.1.5.2" diff --git a/src/kiwi/Cargo.toml b/src/kiwi/Cargo.toml index 0ed6b67..c90de0a 100644 --- a/src/kiwi/Cargo.toml +++ b/src/kiwi/Cargo.toml @@ -38,9 +38,12 @@ http = "1.0.0" notify = "6.1.1" arc-swap = "1.7.0" fastwebsockets = { version = "0.7.0", features = ["with_axum", "upgrade"] } -axum = "0.7.4" http-body-util = "0.1.1" -axum-server = { version = "0.6.0", features = ["tls-rustls"] } +hyper = "1.2.0" +hyper-util = { version = "0.1.3", features = ["server", "http1", "http2"] } +tokio-rustls = "0.26.0" +rustls-pemfile = "2.1.1" +bytes = "1.5.0" [dev-dependencies] tempfile = "3" diff --git a/src/kiwi/src/config.rs b/src/kiwi/src/config.rs index 2379046..8b6788d 100644 --- a/src/kiwi/src/config.rs +++ b/src/kiwi/src/config.rs @@ -96,6 +96,14 @@ pub struct Hooks { pub struct Server { pub address: String, pub tls: Option, + #[serde(default = "Server::default_healthcheck_enabled")] + pub healthcheck: bool, +} + +impl Server { + fn default_healthcheck_enabled() -> bool { + true + } } /// TLS configuration @@ -675,6 +683,7 @@ mod tests { server: Server { address: "127.0.0.1:8000".into(), tls: None, + healthcheck: false, }, kafka: None, subscriber: Subscriber::default(), @@ -701,6 +710,7 @@ mod tests { server: Server { address: "127.0.0.1:8000".into(), tls: None, + healthcheck: false, }, kafka: None, subscriber: Subscriber::default(), @@ -717,6 +727,7 @@ mod tests { server: Server { address: "127.0.0.1:8000".into(), tls: None, + healthcheck: false, }, kafka: Some(Kafka { group_id_prefix: "kiwi-".into(), @@ -752,6 +763,7 @@ mod tests { server: Server { address: "127.0.0.1:8000".into(), tls: None, + healthcheck: false, }, kafka: None, subscriber: Subscriber::default(), @@ -803,6 +815,7 @@ mod tests { server: Server { address: "127.0.0.1:8000".into(), tls: None, + healthcheck: false, }, kafka: None, subscriber: Subscriber::default(), @@ -859,6 +872,7 @@ mod tests { server: Server { address: "127.0.0.1:8000".into(), tls: None, + healthcheck: false, }, kafka: None, subscriber: Subscriber::default(), @@ -890,6 +904,7 @@ mod tests { server: Server { address: "127.0.0.1:8000".into(), tls: None, + healthcheck: false, }, kafka: None, subscriber: Subscriber::default(), @@ -918,6 +933,7 @@ mod tests { server: Server { address: "127.0.0.1:8000".into(), tls: None, + healthcheck: false, }, kafka: None, subscriber: Subscriber::default(), diff --git a/src/kiwi/src/lib.rs b/src/kiwi/src/lib.rs index 2ff0862..f351d47 100644 --- a/src/kiwi/src/lib.rs +++ b/src/kiwi/src/lib.rs @@ -4,5 +4,6 @@ pub mod hook; pub mod protocol; pub mod source; pub mod subscription; +pub mod tls; pub mod util; pub mod ws; diff --git a/src/kiwi/src/main.rs b/src/kiwi/src/main.rs index 8c53d5a..1b1319c 100644 --- a/src/kiwi/src/main.rs +++ b/src/kiwi/src/main.rs @@ -87,19 +87,22 @@ async fn main() -> anyhow::Result<()> { tokio::select! { _ = term.recv() => { tracing::info!("Received SIGTERM, shutting down"); + Ok(()) } _ = tokio::signal::ctrl_c() => { tracing::info!("Received SIGINT, shutting down"); + Ok(()) } - _ = kiwi::ws::serve( + res = kiwi::ws::serve( &listen_addr, sources, intercept, authenticate, config.subscriber, config.server.tls, - ) => {} + config.server.healthcheck, + ) => { + res + } } - - Ok(()) } diff --git a/src/kiwi/src/tls.rs b/src/kiwi/src/tls.rs new file mode 100644 index 0000000..3d110f8 --- /dev/null +++ b/src/kiwi/src/tls.rs @@ -0,0 +1,81 @@ +use std::fs::File; +use std::io::BufReader; +use std::path::Path; +use std::sync::Arc; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use tokio_rustls::server::TlsStream; +use tokio_rustls::{rustls, TlsAcceptor}; + +fn load_certs(path: impl AsRef) -> std::io::Result>> { + rustls_pemfile::certs(&mut BufReader::new(File::open(path)?)).collect() +} + +fn load_key(path: impl AsRef) -> anyhow::Result>> { + Ok(rustls_pemfile::private_key(&mut BufReader::new( + File::open(path)?, + ))?) +} + +pub fn tls_acceptor(cert: impl AsRef, key: impl AsRef) -> anyhow::Result { + let key = load_key(key)?.expect("no key found"); + let certs = load_certs(cert)?; + + let config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key)?; + + Ok(TlsAcceptor::from(Arc::new(config))) +} + +pub enum MaybeTlsStream { + Plain(S), + Tls(TlsStream), +} + +impl AsyncRead for MaybeTlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for MaybeTlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx), + } + } +} diff --git a/src/kiwi/src/ws.rs b/src/kiwi/src/ws.rs index 5efb375..8859f9e 100644 --- a/src/kiwi/src/ws.rs +++ b/src/kiwi/src/ws.rs @@ -2,14 +2,13 @@ use std::collections::BTreeMap; use std::sync::Mutex; use std::{net::SocketAddr, sync::Arc}; +use anyhow::Context; use arc_swap::ArcSwapOption; -use axum::body::Body; -use axum::extract::{ConnectInfo, Request, State}; -use axum::{response::IntoResponse, routing::get, Router}; -use axum_server::tls_rustls::RustlsConfig; +use bytes::Bytes; use fastwebsockets::{upgrade, CloseCode, FragmentCollector, Frame, Payload, WebSocketError}; -use http::{Response, StatusCode}; +use http::{Request, Response, StatusCode}; use http_body_util::Empty; +use hyper::service::service_fn; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::connection::ConnectionManager; @@ -20,16 +19,10 @@ use crate::hook::intercept::types::{AuthCtx, ConnectionCtx, WebSocketConnectionC use crate::hook::intercept::types::Intercept; use crate::protocol::{Command, Message, ProtocolError}; use crate::source::{Source, SourceId}; +use crate::tls::{tls_acceptor, MaybeTlsStream}; type Sources = Arc>>>; -struct AppState { - sources: Sources, - intercept: Arc>, - authenticate: Arc>, - subscriber_config: crate::config::Subscriber, -} - /// Starts a WebSocket server with the specified configuration pub async fn serve( listen_addr: &SocketAddr, @@ -38,56 +31,85 @@ pub async fn serve( authenticate: Arc>, subscriber_config: crate::config::Subscriber, tls_config: Option, + healthcheck: bool, ) -> anyhow::Result<()> where I: Intercept + Send + Sync + 'static, A: Authenticate + Send + Sync + Unpin + 'static, { - let app = make_app(sources, intercept, authenticate, subscriber_config); - let svc = app.into_make_service_with_connect_info::(); - - tracing::info!("Server listening on: {}", listen_addr); - - if let Some(tls) = tls_config { - let config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?; - - axum_server::bind_rustls(*listen_addr, config) - .serve(svc) - .await?; + let acceptor = if let Some(tls) = tls_config { + Some(tls_acceptor(&tls.cert, &tls.key).context("Failed to build TLS acceptor")?) } else { - axum_server::bind(*listen_addr).serve(svc).await?; + None }; + let listener = tokio::net::TcpListener::bind(listen_addr).await?; + tracing::info!("Server listening on: {listen_addr}"); - Ok(()) -} + loop { + let (stream, addr) = listener.accept().await?; + tracing::debug!(addr = ?addr, "Accepted connection"); + let acceptor = acceptor.clone(); + let authenticate = Arc::clone(&authenticate); + let intercept = Arc::clone(&intercept); + let sources = Arc::clone(&sources); + let subscriber_config = subscriber_config.clone(); + + tokio::spawn(async move { + let io = if let Some(acceptor) = acceptor { + match acceptor.accept(stream).await { + Ok(stream) => hyper_util::rt::TokioIo::new(MaybeTlsStream::Tls(stream)), + Err(e) => { + tracing::error!(addr = ?addr, "Failed to accept TLS connection: {}", e); + return; + } + } + } else { + hyper_util::rt::TokioIo::new(MaybeTlsStream::Plain(stream)) + }; + + let builder = + hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + let conn_fut = builder.serve_connection_with_upgrades( + io, + service_fn(move |req: Request| { + let authenticate = Arc::clone(&authenticate); + let sources = Arc::clone(&sources); + let intercept = Arc::clone(&intercept); + let subscriber_config = subscriber_config.clone(); + + async move { + if healthcheck && req.uri().path() == "/health" { + return Response::builder() + .status(StatusCode::OK) + .body(Empty::new()); + } -fn make_app( - sources: Sources, - intercept: Arc>, - authenticate: Arc>, - subscriber_config: crate::config::Subscriber, -) -> Router -where - I: Intercept + Send + Sync + 'static, - A: Authenticate + Send + Sync + Unpin + 'static, -{ - let state = AppState { - sources, - intercept, - authenticate, - subscriber_config, - }; + let response = handle_ws( + sources, + intercept, + authenticate, + subscriber_config, + addr, + req, + ) + .await; + + Ok(response) + } + }), + ); - Router::new() - .route("/", get(ws_handler)) - .route("/health", get(healthcheck)) - .with_state(Arc::new(state)) + if let Err(e) = conn_fut.await { + tracing::error!(addr = ?addr, "Error occurred while serving connection: {}", e); + } + }); + } } #[tracing::instrument(skip_all)] async fn load_auth_ctx( authenticate: Arc>, - request: Request, + request: Request, ) -> Result, ()> where A: Authenticate + Send + Sync + Unpin + 'static, @@ -114,18 +136,21 @@ where } } -async fn ws_handler( - ConnectInfo(addr): ConnectInfo, - State(state): State>>, - mut request: Request, -) -> impl IntoResponse +async fn handle_ws( + sources: Sources, + intercept: Arc>, + authenticate: Arc>, + subscriber_config: crate::config::Subscriber, + addr: SocketAddr, + mut request: Request, +) -> Response> where I: Intercept + Send + Sync + 'static, A: Authenticate + Send + Sync + Unpin + 'static, { - let (response, fut) = upgrade::upgrade(&mut request).expect("failed to build upgrade response"); + let (response, fut) = upgrade::upgrade(&mut request).expect("Failed to upgrade connection"); - let authenticate = Arc::clone(&state.authenticate); + let authenticate = Arc::clone(&authenticate); let auth_ctx = if let Ok(auth_ctx) = load_auth_ctx(authenticate, request).await { auth_ctx @@ -136,9 +161,6 @@ where .unwrap(); }; - let intercept = Arc::clone(&state.intercept); - let sources = Arc::clone(&state.sources); - let subscriber_config = state.subscriber_config.clone(); let connection_ctx = ConnectionCtx::WebSocket(WebSocketConnectionCtx { addr }); tokio::spawn(async move { @@ -159,16 +181,12 @@ where ); } - tracing::debug!(connection = ?connection_ctx, "WebSocket connection terminated"); + tracing::debug!(connection = ?connection_ctx, "WebSocket connection terminated normally"); }); response } -async fn healthcheck() -> impl IntoResponse { - "OK" -} - async fn handle_client( fut: upgrade::UpgradeFut, sources: Sources, @@ -209,16 +227,16 @@ where tokio::select! { biased; - Some(cmd) = recv_cmd(&mut ws) => { - match cmd { - Ok(cmd) => { + maybe_cmd = recv_cmd(&mut ws) => { + match maybe_cmd { + Some(Ok(cmd)) => { if cmd_tx.send(cmd).is_err() { // If the send failed, the channel is closed thus we should // terminate the connection break; } } - Err(e) => { + Some(Err(e)) => { let (close_code, reason) = match e { RecvError::WebSocket(e) => { match e { @@ -242,6 +260,10 @@ where ws.write_frame(frame).await?; break; } + None => { + // The connection has been closed + break; + } } }, msg = msg_rx.recv() => { @@ -298,6 +320,7 @@ where fastwebsockets::OpCode::Binary => Some(Err(RecvError::Protocol( ProtocolError::UnsupportedCommandForm, ))), - _ => None, + fastwebsockets::OpCode::Close => None, + _ => panic!("Received unexpected opcode"), } }