diff --git a/.gitmodules b/.gitmodules index 65fcd3bd..47ebad0b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,6 @@ [submodule "submodules/h3"] path = submodules/h3 url = git@github.com:junkurihara/h3.git -[submodule "submodules/quinn"] - path = submodules/quinn - url = git@github.com:junkurihara/quinn.git -[submodule "submodules/s2n-quic"] - path = submodules/s2n-quic - url = git@github.com:junkurihara/s2n-quic.git [submodule "submodules/rusty-http-cache-semantics"] path = submodules/rusty-http-cache-semantics url = git@github.com:junkurihara/rusty-http-cache-semantics.git diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 36c53b18..1848e5e3 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -12,7 +12,7 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-quinn", "cache"] +default = ["http3-s2n", "cache"] http3-quinn = ["rpxy-lib/http3-quinn"] http3-s2n = ["rpxy-lib/http3-s2n"] cache = ["rpxy-lib/cache"] diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 7f10e603..b4b475dc 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -12,7 +12,7 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-quinn", "sticky-cookie", "cache"] +default = ["http3-s2n", "sticky-cookie", "cache"] http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] sticky-cookie = ["base64", "sha2", "chrono"] @@ -25,7 +25,7 @@ rustc-hash = "1.1.0" bytes = "1.5.0" derive_builder = "0.12.0" futures = { version = "0.3.29", features = ["alloc", "async-await"] } -tokio = { version = "1.33.0", default-features = false, features = [ +tokio = { version = "1.34.0", default-features = false, features = [ "net", "rt-multi-thread", "time", @@ -41,12 +41,10 @@ anyhow = "1.0.75" thiserror = "1.0.50" # http and tls -hyper = { version = "0.14.27", default-features = false, features = [ - "server", - "http1", - "http2", - "stream", -] } +http = "1.0.0" +http-body-util = "0.1.0" +hyper = { version = "1.0.1", default-features = false } +hyper-util = { version = "0.1.0", features = ["full"] } hyper-rustls = { version = "0.24.2", default-features = false, features = [ "tokio-runtime", "webpki-tokio", @@ -54,7 +52,7 @@ hyper-rustls = { version = "0.24.2", default-features = false, features = [ "http2", ] } tokio-rustls = { version = "0.24.1", features = ["early-data"] } -rustls = { version = "0.21.8", default-features = false } +rustls = { version = "0.21.9", default-features = false } webpki = "0.22.4" x509-parser = "0.15.1" @@ -62,18 +60,16 @@ x509-parser = "0.15.1" tracing = { version = "0.1.40" } # http/3 -# quinn = { version = "0.9.3", optional = true } -quinn = { path = "../submodules/quinn/quinn", optional = true } # Tentative to support rustls-0.21 +quinn = { version = "0.10.2", optional = true } h3 = { path = "../submodules/h3/h3/", optional = true } -# h3-quinn = { path = "./h3/h3-quinn/", optional = true } -h3-quinn = { path = "../submodules/h3-quinn/", optional = true } # Tentative to support rustls-0.21 -# for UDP socket wit SO_REUSEADDR when h3 with quinn -socket2 = { version = "0.5.5", features = ["all"], optional = true } -s2n-quic = { path = "../submodules/s2n-quic/quic/s2n-quic/", default-features = false, features = [ +h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } +s2n-quic = { version = "1.31.0", default-features = false, features = [ "provider-tls-rustls", ], optional = true } -s2n-quic-h3 = { path = "../submodules/s2n-quic/quic/s2n-quic-h3/", optional = true } -s2n-quic-rustls = { path = "../submodules/s2n-quic/quic/s2n-quic-rustls/", optional = true } +s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } +s2n-quic-rustls = { version = "0.31.0", optional = true } +# for UDP socket wit SO_REUSEADDR when h3 with quinn +socket2 = { version = "0.5.5", features = ["all"], optional = true } # cache http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } @@ -90,3 +86,4 @@ sha2 = { version = "0.10.8", default-features = false, optional = true } [dev-dependencies] +# http and tls diff --git a/rpxy-lib/src/globals.rs b/rpxy-lib/src/globals.rs index d1c01306..02605a60 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -33,6 +33,9 @@ where /// Shared context - Async task runtime handler pub runtime_handle: tokio::runtime::Handle, + + /// Shared context - Notify object to stop async tasks + pub term_notify: Option>, } /// Configuration parameters for proxy transport and request handlers diff --git a/rpxy-lib/src/handler/error.rs b/rpxy-lib/src/handler/error.rs new file mode 100644 index 00000000..8fb9d79d --- /dev/null +++ b/rpxy-lib/src/handler/error.rs @@ -0,0 +1,16 @@ +use http::StatusCode; +use thiserror::Error; + +pub type HttpResult = std::result::Result; + +/// Describes things that can go wrong in the handler +#[derive(Debug, Error)] +pub enum HttpError {} + +impl From for StatusCode { + fn from(e: HttpError) -> StatusCode { + match e { + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} diff --git a/rpxy-lib/src/handler/handler_main.rs b/rpxy-lib/src/handler/handler_main.rs index 8b13dc75..2720c2fe 100644 --- a/rpxy-lib/src/handler/handler_main.rs +++ b/rpxy-lib/src/handler/handler_main.rs @@ -1,9 +1,10 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy use super::{ - forwarder::{ForwardRequest, Forwarder}, + error::*, + // forwarder::{ForwardRequest, Forwarder}, utils_headers::*, utils_request::*, - utils_synth_response::*, + // utils_synth_response::*, HandlerContext, }; use crate::{ @@ -16,365 +17,368 @@ use crate::{ utils::ServerNameBytesExp, }; use derive_builder::Builder; -use hyper::{ - client::connect::Connect, +use http::{ header::{self, HeaderValue}, - http::uri::Scheme, - Body, Request, Response, StatusCode, Uri, Version, + uri::Scheme, + Request, Response, StatusCode, Uri, Version, }; +use hyper::body::Incoming; +use hyper_util::client::legacy::connect::Connect; use std::{net::SocketAddr, sync::Arc}; use tokio::{io::copy_bidirectional, time::timeout}; #[derive(Clone, Builder)] /// HTTP message handler for requests from clients and responses from backend applications, /// responsible to manipulate and forward messages to upstream backends and downstream clients. -pub struct HttpMessageHandler +// pub struct HttpMessageHandler +pub struct HttpMessageHandler where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone, { - forwarder: Arc>, + // forwarder: Arc>, globals: Arc>, } -impl HttpMessageHandler +impl HttpMessageHandler where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone, { - /// Return with an arbitrary status code of error and log message - fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result> { - log_data.status_code(&status_code).output(); - http_error(status_code) - } + // /// Return with an arbitrary status code of error and log message + // fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result> { + // log_data.status_code(&status_code).output(); + // http_error(status_code) + // } /// Handle incoming request message from a client pub async fn handle_request( &self, - mut req: Request, + mut req: Request, client_addr: SocketAddr, // アクセス制御用 listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, - ) -> Result> { + ) -> Result>> { //////// let mut log_data = MessageLog::from(&req); log_data.client_addr(&client_addr); ////// - // Here we start to handle with server_name - let server_name = if let Ok(v) = req.parse_host() { - ServerNameBytesExp::from(v) - } else { - return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); - }; - // check consistency of between TLS SNI and HOST/Request URI Line. - #[allow(clippy::collapsible_if)] - if tls_enabled && self.globals.proxy_config.sni_consistency { - if server_name != tls_server_name.unwrap_or_default() { - return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data); - } - } - // Find backend application for given server_name, and drop if incoming request is invalid as request. - let backend = match self.globals.backends.apps.get(&server_name) { - Some(be) => be, - None => { - let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else { - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); - }; - debug!("Serving by default app"); - self.globals.backends.apps.get(default_server_name).unwrap() - } - }; + // // Here we start to handle with server_name + // let server_name = if let Ok(v) = req.parse_host() { + // ServerNameBytesExp::from(v) + // } else { + // return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); + // }; + // // check consistency of between TLS SNI and HOST/Request URI Line. + // #[allow(clippy::collapsible_if)] + // if tls_enabled && self.globals.proxy_config.sni_consistency { + // if server_name != tls_server_name.unwrap_or_default() { + // return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data); + // } + // } + // // Find backend application for given server_name, and drop if incoming request is invalid as request. + // let backend = match self.globals.backends.apps.get(&server_name) { + // Some(be) => be, + // None => { + // let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else { + // return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + // }; + // debug!("Serving by default app"); + // self.globals.backends.apps.get(default_server_name).unwrap() + // } + // }; - // Redirect to https if !tls_enabled and redirect_to_https is true - if !tls_enabled && backend.https_redirection.unwrap_or(false) { - debug!("Redirect to secure connection: {}", &backend.server_name); - log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output(); - return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req); - } + // // Redirect to https if !tls_enabled and redirect_to_https is true + // if !tls_enabled && backend.https_redirection.unwrap_or(false) { + // debug!("Redirect to secure connection: {}", &backend.server_name); + // log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output(); + // return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req); + // } - // Find reverse proxy for given path and choose one of upstream host - // Longest prefix match - let path = req.uri().path(); - let Some(upstream_group) = backend.reverse_proxy.get(path) else { - return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data) - }; + // // Find reverse proxy for given path and choose one of upstream host + // // Longest prefix match + // let path = req.uri().path(); + // let Some(upstream_group) = backend.reverse_proxy.get(path) else { + // return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data); + // }; - // Upgrade in request header - let upgrade_in_request = extract_upgrade(req.headers()); - let request_upgraded = req.extensions_mut().remove::(); + // // Upgrade in request header + // let upgrade_in_request = extract_upgrade(req.headers()); + // let request_upgraded = req.extensions_mut().remove::(); - // Build request from destination information - let _context = match self.generate_request_forwarded( - &client_addr, - &listen_addr, - &mut req, - &upgrade_in_request, - upstream_group, - tls_enabled, - ) { - Err(e) => { - error!("Failed to generate destination uri for reverse proxy: {}", e); - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); - } - Ok(v) => v, - }; - debug!("Request to be forwarded: {:?}", req); - log_data.xff(&req.headers().get("x-forwarded-for")); - log_data.upstream(req.uri()); - ////// + // // Build request from destination information + // let _context = match self.generate_request_forwarded( + // &client_addr, + // &listen_addr, + // &mut req, + // &upgrade_in_request, + // upstream_group, + // tls_enabled, + // ) { + // Err(e) => { + // error!("Failed to generate destination uri for reverse proxy: {}", e); + // return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + // } + // Ok(v) => v, + // }; + // debug!("Request to be forwarded: {:?}", req); + // log_data.xff(&req.headers().get("x-forwarded-for")); + // log_data.upstream(req.uri()); + // ////// - // Forward request to a chosen backend - let mut res_backend = { - let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { - return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); - }; - match result { - Ok(res) => res, - Err(e) => { - error!("Failed to get response from backend: {}", e); - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); - } - } - }; + // // Forward request to a chosen backend + // let mut res_backend = { + // let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { + // return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); + // }; + // match result { + // Ok(res) => res, + // Err(e) => { + // error!("Failed to get response from backend: {}", e); + // return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + // } + // } + // }; - // Process reverse proxy context generated during the forwarding request generation. - #[cfg(feature = "sticky-cookie")] - if let Some(context_from_lb) = _context.context_lb { - let res_headers = res_backend.headers_mut(); - if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { - error!("Failed to append context to the response given from backend: {}", e); - return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); - } - } + // // Process reverse proxy context generated during the forwarding request generation. + // #[cfg(feature = "sticky-cookie")] + // if let Some(context_from_lb) = _context.context_lb { + // let res_headers = res_backend.headers_mut(); + // if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { + // error!("Failed to append context to the response given from backend: {}", e); + // return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); + // } + // } - if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { - // Generate response to client - if self.generate_response_forwarded(&mut res_backend, backend).is_err() { - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - } - log_data.status_code(&res_backend.status()).output(); - return Ok(res_backend); - } + // if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { + // // Generate response to client + // if self.generate_response_forwarded(&mut res_backend, backend).is_err() { + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // } + // log_data.status_code(&res_backend.status()).output(); + // return Ok(res_backend); + // } - // Handle StatusCode::SWITCHING_PROTOCOLS in response - let upgrade_in_response = extract_upgrade(res_backend.headers()); - let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) - { - u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() - } else { - false - }; - if !should_upgrade { - error!( - "Backend tried to switch to protocol {:?} when {:?} was requested", - upgrade_in_response, upgrade_in_request - ); - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - } - let Some(request_upgraded) = request_upgraded else { - error!("Request does not have an upgrade extension"); - return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); - }; - let Some(onupgrade) = res_backend.extensions_mut().remove::() else { - error!("Response does not have an upgrade extension"); - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - }; + // // Handle StatusCode::SWITCHING_PROTOCOLS in response + // let upgrade_in_response = extract_upgrade(res_backend.headers()); + // let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) + // { + // u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() + // } else { + // false + // }; + // if !should_upgrade { + // error!( + // "Backend tried to switch to protocol {:?} when {:?} was requested", + // upgrade_in_response, upgrade_in_request + // ); + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // } + // let Some(request_upgraded) = request_upgraded else { + // error!("Request does not have an upgrade extension"); + // return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); + // }; + // let Some(onupgrade) = res_backend.extensions_mut().remove::() else { + // error!("Response does not have an upgrade extension"); + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // }; - self.globals.runtime_handle.spawn(async move { - let mut response_upgraded = onupgrade.await.map_err(|e| { - error!("Failed to upgrade response: {}", e); - RpxyError::Hyper(e) - })?; - let mut request_upgraded = request_upgraded.await.map_err(|e| { - error!("Failed to upgrade request: {}", e); - RpxyError::Hyper(e) - })?; - copy_bidirectional(&mut response_upgraded, &mut request_upgraded) - .await - .map_err(|e| { - error!("Coping between upgraded connections failed: {}", e); - RpxyError::Io(e) - })?; - Ok(()) as Result<()> - }); - log_data.status_code(&res_backend.status()).output(); - Ok(res_backend) + // self.globals.runtime_handle.spawn(async move { + // let mut response_upgraded = onupgrade.await.map_err(|e| { + // error!("Failed to upgrade response: {}", e); + // RpxyError::Hyper(e) + // })?; + // let mut request_upgraded = request_upgraded.await.map_err(|e| { + // error!("Failed to upgrade request: {}", e); + // RpxyError::Hyper(e) + // })?; + // copy_bidirectional(&mut response_upgraded, &mut request_upgraded) + // .await + // .map_err(|e| { + // error!("Coping between upgraded connections failed: {}", e); + // RpxyError::Io(e) + // })?; + // Ok(()) as Result<()> + // }); + // log_data.status_code(&res_backend.status()).output(); + // Ok(res_backend) + todo!() } //////////////////////////////////////////////////// // Functions to generate messages //////////////////////////////////////////////////// - /// Manipulate a response message sent from a backend application to forward downstream to a client. - fn generate_response_forwarded(&self, response: &mut Response, chosen_backend: &Backend) -> Result<()> - where - B: core::fmt::Debug, - { - let headers = response.headers_mut(); - remove_connection_header(headers); - remove_hop_header(headers); - add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; + // /// Manipulate a response message sent from a backend application to forward downstream to a client. + // fn generate_response_forwarded(&self, response: &mut Response, chosen_backend: &Backend) -> Result<()> + // where + // B: core::fmt::Debug, + // { + // let headers = response.headers_mut(); + // remove_connection_header(headers); + // remove_hop_header(headers); + // add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; - #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] - { - // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled - // TODO: This is a workaround for avoiding a client authentication in HTTP/3 - if self.globals.proxy_config.http3 - && chosen_backend - .crypto_source - .as_ref() - .is_some_and(|v| !v.is_mutual_tls()) - { - if let Some(port) = self.globals.proxy_config.https_port { - add_header_entry_overwrite_if_exist( - headers, - header::ALT_SVC.as_str(), - format!( - "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", - port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age - ), - )?; - } - } else { - // remove alt-svc to disallow requests via http3 - headers.remove(header::ALT_SVC.as_str()); - } - } - #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] - { - if let Some(port) = self.globals.proxy_config.https_port { - headers.remove(header::ALT_SVC.as_str()); - } - } + // #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + // { + // // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled + // // TODO: This is a workaround for avoiding a client authentication in HTTP/3 + // if self.globals.proxy_config.http3 + // && chosen_backend + // .crypto_source + // .as_ref() + // .is_some_and(|v| !v.is_mutual_tls()) + // { + // if let Some(port) = self.globals.proxy_config.https_port { + // add_header_entry_overwrite_if_exist( + // headers, + // header::ALT_SVC.as_str(), + // format!( + // "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", + // port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age + // ), + // )?; + // } + // } else { + // // remove alt-svc to disallow requests via http3 + // headers.remove(header::ALT_SVC.as_str()); + // } + // } + // #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] + // { + // if let Some(port) = self.globals.proxy_config.https_port { + // headers.remove(header::ALT_SVC.as_str()); + // } + // } - Ok(()) - } + // Ok(()) + // } - #[allow(clippy::too_many_arguments)] - /// Manipulate a request message sent from a client to forward upstream to a backend application - fn generate_request_forwarded( - &self, - client_addr: &SocketAddr, - listen_addr: &SocketAddr, - req: &mut Request, - upgrade: &Option, - upstream_group: &UpstreamGroup, - tls_enabled: bool, - ) -> Result { - debug!("Generate request to be forwarded"); + // #[allow(clippy::too_many_arguments)] + // /// Manipulate a request message sent from a client to forward upstream to a backend application + // fn generate_request_forwarded( + // &self, + // client_addr: &SocketAddr, + // listen_addr: &SocketAddr, + // req: &mut Request, + // upgrade: &Option, + // upstream_group: &UpstreamGroup, + // tls_enabled: bool, + // ) -> Result { + // debug!("Generate request to be forwarded"); - // Add te: trailer if contained in original request - let contains_te_trailers = { - if let Some(te) = req.headers().get(header::TE) { - te.as_bytes() - .split(|v| v == &b',' || v == &b' ') - .any(|x| x == "trailers".as_bytes()) - } else { - false - } - }; + // // Add te: trailer if contained in original request + // let contains_te_trailers = { + // if let Some(te) = req.headers().get(header::TE) { + // te.as_bytes() + // .split(|v| v == &b',' || v == &b' ') + // .any(|x| x == "trailers".as_bytes()) + // } else { + // false + // } + // }; - let uri = req.uri().to_string(); - let headers = req.headers_mut(); - // delete headers specified in header.connection - remove_connection_header(headers); - // delete hop headers including header.connection - remove_hop_header(headers); - // X-Forwarded-For - add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; + // let uri = req.uri().to_string(); + // let headers = req.headers_mut(); + // // delete headers specified in header.connection + // remove_connection_header(headers); + // // delete hop headers including header.connection + // remove_hop_header(headers); + // // X-Forwarded-For + // add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; - // Add te: trailer if te_trailer - if contains_te_trailers { - headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); - } + // // Add te: trailer if te_trailer + // if contains_te_trailers { + // headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); + // } - // add "host" header of original server_name if not exist (default) - if req.headers().get(header::HOST).is_none() { - let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); - req - .headers_mut() - .insert(header::HOST, HeaderValue::from_str(&org_host)?); - }; + // // add "host" header of original server_name if not exist (default) + // if req.headers().get(header::HOST).is_none() { + // let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); + // req + // .headers_mut() + // .insert(header::HOST, HeaderValue::from_str(&org_host)?); + // }; - ///////////////////////////////////////////// - // Fix unique upstream destination since there could be multiple ones. - #[cfg(feature = "sticky-cookie")] - let (upstream_chosen_opt, context_from_lb) = { - let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { - takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? - } else { - None - }; - upstream_group.get(&context_to_lb) - }; - #[cfg(not(feature = "sticky-cookie"))] - let (upstream_chosen_opt, _) = upstream_group.get(&None); + // ///////////////////////////////////////////// + // // Fix unique upstream destination since there could be multiple ones. + // #[cfg(feature = "sticky-cookie")] + // let (upstream_chosen_opt, context_from_lb) = { + // let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { + // takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? + // } else { + // None + // }; + // upstream_group.get(&context_to_lb) + // }; + // #[cfg(not(feature = "sticky-cookie"))] + // let (upstream_chosen_opt, _) = upstream_group.get(&None); - let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; - let context = HandlerContext { - #[cfg(feature = "sticky-cookie")] - context_lb: context_from_lb, - #[cfg(not(feature = "sticky-cookie"))] - context_lb: None, - }; - ///////////////////////////////////////////// + // let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; + // let context = HandlerContext { + // #[cfg(feature = "sticky-cookie")] + // context_lb: context_from_lb, + // #[cfg(not(feature = "sticky-cookie"))] + // context_lb: None, + // }; + // ///////////////////////////////////////////// - // apply upstream-specific headers given in upstream_option - let headers = req.headers_mut(); - apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?; + // // apply upstream-specific headers given in upstream_option + // let headers = req.headers_mut(); + // apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?; - // update uri in request - if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) { - return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is broken")); - }; - let new_uri = Uri::builder() - .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) - .authority(upstream_chosen.uri.authority().unwrap().as_str()); - let org_pq = match req.uri().path_and_query() { - Some(pq) => pq.to_string(), - None => "/".to_string(), - } - .into_bytes(); + // // update uri in request + // if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) { + // return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is broken")); + // }; + // let new_uri = Uri::builder() + // .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) + // .authority(upstream_chosen.uri.authority().unwrap().as_str()); + // let org_pq = match req.uri().path_and_query() { + // Some(pq) => pq.to_string(), + // None => "/".to_string(), + // } + // .into_bytes(); - // replace some parts of path if opt_replace_path is enabled for chosen upstream - let new_pq = match &upstream_group.replace_path { - Some(new_path) => { - let matched_path: &[u8] = upstream_group.path.as_ref(); - if matched_path.is_empty() || org_pq.len() < matched_path.len() { - return Err(RpxyError::Handler("Upstream uri `path and query` is broken")); - }; - let mut new_pq = Vec::::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); - new_pq.extend_from_slice(new_path.as_ref()); - new_pq.extend_from_slice(&org_pq[matched_path.len()..]); - new_pq - } - None => org_pq, - }; - *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; + // // replace some parts of path if opt_replace_path is enabled for chosen upstream + // let new_pq = match &upstream_group.replace_path { + // Some(new_path) => { + // let matched_path: &[u8] = upstream_group.path.as_ref(); + // if matched_path.is_empty() || org_pq.len() < matched_path.len() { + // return Err(RpxyError::Handler("Upstream uri `path and query` is broken")); + // }; + // let mut new_pq = Vec::::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); + // new_pq.extend_from_slice(new_path.as_ref()); + // new_pq.extend_from_slice(&org_pq[matched_path.len()..]); + // new_pq + // } + // None => org_pq, + // }; + // *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; - // upgrade - if let Some(v) = upgrade { - req.headers_mut().insert(header::UPGRADE, v.parse()?); - req - .headers_mut() - .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?); - } + // // upgrade + // if let Some(v) = upgrade { + // req.headers_mut().insert(header::UPGRADE, v.parse()?); + // req + // .headers_mut() + // .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?); + // } - // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3 - if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { - // Change version to http/1.1 when destination scheme is http - debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); - *req.version_mut() = Version::HTTP_11; - } else if req.version() == Version::HTTP_3 { - // HTTP/3 is always https - debug!("HTTP/3 is currently unsupported for request to upstream."); - *req.version_mut() = Version::HTTP_2; - } + // // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3 + // if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { + // // Change version to http/1.1 when destination scheme is http + // debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); + // *req.version_mut() = Version::HTTP_11; + // } else if req.version() == Version::HTTP_3 { + // // HTTP/3 is always https + // debug!("HTTP/3 is currently unsupported for request to upstream."); + // *req.version_mut() = Version::HTTP_2; + // } - apply_upstream_options_to_request_line(req, upstream_group)?; + // apply_upstream_options_to_request_line(req, upstream_group)?; - Ok(context) - } + // Ok(context) + // } } diff --git a/rpxy-lib/src/handler/mod.rs b/rpxy-lib/src/handler/mod.rs index 84e02261..2ae5aba6 100644 --- a/rpxy-lib/src/handler/mod.rs +++ b/rpxy-lib/src/handler/mod.rs @@ -1,17 +1,15 @@ #[cfg(feature = "cache")] -mod cache; -mod forwarder; +// mod cache; +mod error; +// mod forwarder; mod handler_main; mod utils_headers; mod utils_request; -mod utils_synth_response; +// mod utils_synth_response; #[cfg(feature = "sticky-cookie")] use crate::backend::LbContext; -pub use { - forwarder::Forwarder, - handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}, -}; +pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; #[allow(dead_code)] #[derive(Debug)] diff --git a/rpxy-lib/src/hyper_executor.rs b/rpxy-lib/src/hyper_executor.rs new file mode 100644 index 00000000..152bbe92 --- /dev/null +++ b/rpxy-lib/src/hyper_executor.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; +use tokio::runtime::Handle; + +use crate::{globals::Globals, CryptoSource}; + +#[derive(Clone)] +/// Executor for hyper +pub struct LocalExecutor { + runtime_handle: Handle, +} + +impl LocalExecutor { + pub fn new(runtime_handle: Handle) -> Self { + LocalExecutor { runtime_handle } + } +} + +impl hyper::rt::Executor for LocalExecutor +where + F: std::future::Future + Send + 'static, + F::Output: Send, +{ + fn execute(&self, fut: F) { + self.runtime_handle.spawn(fut); + } +} + +/// build connection builder shared with proxy instances +pub(crate) fn build_http_server(globals: &Arc>) -> ConnectionBuilder +where + T: CryptoSource, +{ + let executor = LocalExecutor::new(globals.runtime_handle.clone()); + let mut http_server = server::conn::auto::Builder::new(executor); + http_server + .http1() + .keep_alive(globals.proxy_config.keepalive) + .pipeline_flush(true); + http_server + .http2() + .max_concurrent_streams(globals.proxy_config.max_concurrent_streams); + http_server +} diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index fd242c53..7f7ade27 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -4,20 +4,16 @@ mod constants; mod error; mod globals; mod handler; +mod hyper_executor; mod log; mod proxy; mod utils; -use crate::{ - error::*, - globals::Globals, - handler::{Forwarder, HttpMessageHandlerBuilder}, - log::*, - proxy::ProxyBuilder, -}; +use crate::{error::*, globals::Globals, handler::HttpMessageHandlerBuilder, log::*, proxy::ProxyBuilder}; use futures::future::select_all; +use hyper_executor::build_http_server; // use hyper_trust_dns::TrustDnsResolver; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; pub use crate::{ certs::{CertsAndKeys, CryptoSource}, @@ -76,16 +72,19 @@ where backends: app_config_list.clone().try_into()?, request_count: Default::default(), runtime_handle: runtime_handle.clone(), + term_notify: term_notify.clone(), }); // build message handler including a request forwarder let msg_handler = Arc::new( HttpMessageHandlerBuilder::default() - .forwarder(Arc::new(Forwarder::new(&globals).await)) + // .forwarder(Arc::new(Forwarder::new(&globals).await)) .globals(globals.clone()) .build()?, ); + let http_server = Arc::new(build_http_server(&globals)); + let addresses = globals.proxy_config.listen_sockets.clone(); let futures = select_all(addresses.into_iter().map(|addr| { let mut tls_enabled = false; @@ -97,16 +96,17 @@ where .globals(globals.clone()) .listening_on(addr) .tls_enabled(tls_enabled) + .http_server(http_server.clone()) .msg_handler(msg_handler.clone()) .build() .unwrap(); - globals.runtime_handle.spawn(proxy.start(term_notify.clone())) + globals.runtime_handle.spawn(async move { proxy.start().await }) })); // wait for all future if let (Ok(Err(e)), _, _) = futures.await { - error!("Some proxy services are down: {:?}", e); + error!("Some proxy services are down: {}", e); }; Ok(()) diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 0551b626..c89c3942 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -10,4 +10,33 @@ mod proxy_quic_s2n; mod proxy_tls; mod socket; +use crate::error::*; +use http::{Response, StatusCode}; +use http_body_util::{combinators, BodyExt, Either, Empty}; +use hyper::body::{Bytes, Incoming}; + pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError}; + +/// Type for synthetic boxed body +type BoxBody = combinators::BoxBody; +/// Type for either passthrough body or synthetic body +type EitherBody = Either; + +/// helper function to build http response with passthrough body +fn passthrough_response(response: Response) -> Result> { + Ok(response.map(EitherBody::Left)) +} + +/// build http response with status code of 4xx and 5xx +fn synthetic_error_response(status_code: StatusCode) -> Result> { + let res = Response::builder() + .status(status_code) + .body(EitherBody::Right(BoxBody::new(empty()))) + .unwrap(); + Ok(res) +} + +/// helper function to build a empty body +fn empty() -> BoxBody { + Empty::::new().map_err(|never| match never {}).boxed() +} diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index fd07521e..699938b7 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -1,17 +1,21 @@ use super::Proxy; use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; use bytes::{Buf, Bytes}; +use futures::Stream; #[cfg(feature = "http3-quinn")] use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; -use hyper::{client::connect::Connect, Body, Request, Response}; +use http::{Request, Response}; +use http_body_util::{BodyExt, BodyStream, StreamBody}; +use hyper::body::{Body, Incoming}; +use hyper_util::client::legacy::connect::Connect; #[cfg(feature = "http3-s2n")] use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; use std::net::SocketAddr; use tokio::time::{timeout, Duration}; -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn connection_serve_h3( @@ -89,18 +93,36 @@ where S: BidiStream + Send + 'static, >::RecvStream: Send, { + println!("stream_serve_h3"); let (req_parts, _) = req.into_parts(); // split stream and async body handling let (mut send_stream, mut recv_stream) = stream.split(); - // generate streamed body with trailers using channel - let (body_sender, req_body) = Body::channel(); + // let max_body_size = self.globals.proxy_config.h3_request_max_body_size; + // // let max = body_stream.size_hint().upper().unwrap_or(u64::MAX); + // // if max > max_body_size as u64 { + // // return Err(HttpError::TooLargeRequestBody); + // // } + // let new_req = Request::from_parts(req_parts, body_stream); + + //////////////////// + // TODO: TODO: TODO: TODO: + // TODO: Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside. + // Thus, we need to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of + // Either as body. + // Also, the downstream from the backend handler could be Incoming, but will be wrapped as Either as well due to H3. + // Result, E> type includes E as HttpError to generate the status code and related Response. + // Thus to handle synthetic error messages in BoxBody, the serve() function outputs Response, BoxBody>>>. + //////////////////// + + // // generate streamed body with trailers using channel + // let (body_sender, req_body) = Incoming::channel(); // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. let max_body_size = self.globals.proxy_config.h3_request_max_body_size; self.globals.runtime_handle.spawn(async move { - let mut sender = body_sender; + // let mut sender = body_sender; let mut size = 0usize; while let Some(mut body) = recv_stream.recv_data().await? { debug!("HTTP/3 incoming request body: remaining {}", body.remaining()); @@ -113,51 +135,52 @@ where return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); } // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes - sender.send_data(body.copy_to_bytes(body.remaining())).await?; + // sender.send_data(body.copy_to_bytes(body.remaining())).await?; } // trailers: use inner for work around. (directly get trailer) let trailers = recv_stream.as_mut().recv_trailers().await?; if trailers.is_some() { debug!("HTTP/3 incoming request trailers"); - sender.send_trailers(trailers.unwrap()).await?; + // sender.send_trailers(trailers.unwrap()).await?; } Ok(()) }); - let new_req: Request = Request::from_parts(req_parts, req_body); - let res = self - .msg_handler - .clone() - .handle_request( - new_req, - client_addr, - self.listening_on, - self.tls_enabled, - Some(tls_server_name), - ) - .await?; + // let new_req: Request = Request::from_parts(req_parts, req_body); + // let res = self + // .msg_handler + // .clone() + // .handle_request( + // new_req, + // client_addr, + // self.listening_on, + // self.tls_enabled, + // Some(tls_server_name), + // ) + // .await?; - let (new_res_parts, new_body) = res.into_parts(); - let new_res = Response::from_parts(new_res_parts, ()); + // let (new_res_parts, new_body) = res.into_parts(); + // let new_res = Response::from_parts(new_res_parts, ()); - match send_stream.send_response(new_res).await { - Ok(_) => { - debug!("HTTP/3 response to connection successful"); - // aggregate body without copying - let mut body_data = hyper::body::aggregate(new_body).await?; + // match send_stream.send_response(new_res).await { + // Ok(_) => { + // debug!("HTTP/3 response to connection successful"); + // // aggregate body without copying + // let body_data = new_body.collect().await?.aggregate(); - // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes - send_stream - .send_data(body_data.copy_to_bytes(body_data.remaining())) - .await?; + // // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes + // send_stream + // .send_data(body_data.copy_to_bytes(body_data.remaining())) + // .await?; - // TODO: needs handling trailer? should be included in body from handler. - } - Err(err) => { - error!("Unable to send response to connection peer: {:?}", err); - } - } - Ok(send_stream.finish().await?) + // // TODO: needs handling trailer? should be included in body from handler. + // } + // Err(err) => { + // error!("Unable to send response to connection peer: {:?}", err); + // } + // } + // Ok(send_stream.finish().await?) + todo!() } } diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index bd52ea95..ec1008a3 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -1,78 +1,70 @@ -use super::socket::bind_tcp_socket; +use super::{passthrough_response, socket::bind_tcp_socket, synthetic_error_response, EitherBody}; use crate::{ - certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp, + certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, hyper_executor::LocalExecutor, log::*, + utils::ServerNameBytesExp, }; use derive_builder::{self, Builder}; -use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; -use std::{net::SocketAddr, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - runtime::Handle, - sync::Notify, - time::{timeout, Duration}, +use http::{Request, StatusCode}; +use hyper::{ + body::Incoming, + rt::{Read, Write}, + service::service_fn, }; - -#[derive(Clone)] -pub struct LocalExecutor { - runtime_handle: Handle, -} - -impl LocalExecutor { - fn new(runtime_handle: Handle) -> Self { - LocalExecutor { runtime_handle } - } -} - -impl hyper::rt::Executor for LocalExecutor -where - F: std::future::Future + Send + 'static, - F::Output: Send, -{ - fn execute(&self, fut: F) { - self.runtime_handle.spawn(fut); - } -} +use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::time::{timeout, Duration}; #[derive(Clone, Builder)] -pub struct Proxy +/// Proxy main object +pub struct Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub listening_on: SocketAddr, pub tls_enabled: bool, // TCP待受がTLSかどうか - pub msg_handler: Arc>, + /// hyper server receiving http request + pub http_server: Arc>, + // pub msg_handler: Arc>, + pub msg_handler: Arc>, pub globals: Arc>, } -impl Proxy +/// Wrapper function to handle request +async fn serve_request( + req: Request, + // handler: Arc>, + handler: Arc>, + client_addr: SocketAddr, + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, +) -> Result> where - T: Connect + Clone + Sync + Send + 'static, - U: CryptoSource + Clone + Sync + Send, + U: CryptoSource + Clone + Sync + Send + 'static, { - /// Wrapper function to handle request - async fn serve( - handler: Arc>, - req: Request, - client_addr: SocketAddr, - listen_addr: SocketAddr, - tls_enabled: bool, - tls_server_name: Option, - ) -> Result> { - handler - .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) - .await + match handler + .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) + .await? + { + Ok(res) => passthrough_response(res), + Err(e) => synthetic_error_response(StatusCode::from(e)), } +} +impl Proxy +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone + Sync + Send, +{ /// Serves requests from clients - pub(super) fn client_serve( - self, + pub(super) fn serve_connection( + &self, stream: I, - server: Http, peer_addr: SocketAddr, tls_server_name: Option, ) where - I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + I: Read + Write + Send + Unpin + 'static, { let request_count = self.globals.request_count.clone(); if request_count.increment() > self.globals.proxy_config.max_clients { @@ -81,24 +73,27 @@ where } debug!("Request incoming: current # {}", request_count.current()); + let server_clone = self.http_server.clone(); + let msg_handler_clone = self.msg_handler.clone(); + let timeout_sec = self.globals.proxy_config.proxy_timeout; + let tls_enabled = self.tls_enabled; + let listening_on = self.listening_on; self.globals.runtime_handle.clone().spawn(async move { timeout( - self.globals.proxy_config.proxy_timeout + Duration::from_secs(1), - server - .serve_connection( - stream, - service_fn(move |req: Request| { - Self::serve( - self.msg_handler.clone(), - req, - peer_addr, - self.listening_on, - self.tls_enabled, - tls_server_name.clone(), - ) - }), - ) - .with_upgrades(), + timeout_sec + Duration::from_secs(1), + server_clone.serve_connection_with_upgrades( + stream, + service_fn(move |req: Request| { + serve_request( + req, + msg_handler_clone.clone(), + peer_addr, + listening_on, + tls_enabled, + tls_server_name.clone(), + ) + }), + ), ) .await .ok(); @@ -109,13 +104,13 @@ where } /// Start without TLS (HTTP cleartext) - async fn start_without_tls(self, server: Http) -> Result<()> { + async fn start_without_tls(&self) -> Result<()> { let listener_service = async { let tcp_socket = bind_tcp_socket(&self.listening_on)?; let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; info!("Start TCP proxy serving with HTTP request for configured host names"); - while let Ok((stream, _client_addr)) = tcp_listener.accept().await { - self.clone().client_serve(stream, server.clone(), _client_addr, None); + while let Ok((stream, client_addr)) = tcp_listener.accept().await { + self.serve_connection(TokioIo::new(stream), client_addr, None); } Ok(()) as Result<()> }; @@ -124,32 +119,23 @@ where } /// Entrypoint for HTTP/1.1 and HTTP/2 servers - pub async fn start(self, term_notify: Option>) -> Result<()> { - let mut server = Http::new(); - server.http1_keep_alive(self.globals.proxy_config.keepalive); - server.http2_max_concurrent_streams(self.globals.proxy_config.max_concurrent_streams); - server.pipeline_flush(true); - let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); - let server = server.with_executor(executor); - - let listening_on = self.listening_on; - + pub async fn start(&self) -> Result<()> { let proxy_service = async { if self.tls_enabled { - self.start_with_tls(server).await + self.start_with_tls().await } else { - self.start_without_tls(server).await + self.start_without_tls().await } }; - match term_notify { + match &self.globals.term_notify { Some(term) => { tokio::select! { _ = proxy_service => { warn!("Proxy service got down"); } _ = term.notified() => { - info!("Proxy service listening on {} receives term signal", listening_on); + info!("Proxy service listening on {} receives term signal", self.listening_on); } } } @@ -159,8 +145,6 @@ where } } - // proxy_service.await?; - Ok(()) } } diff --git a/rpxy-lib/src/proxy/proxy_quic_quinn.rs b/rpxy-lib/src/proxy/proxy_quic_quinn.rs index fb08420b..1828e5f7 100644 --- a/rpxy-lib/src/proxy/proxy_quic_quinn.rs +++ b/rpxy-lib/src/proxy/proxy_quic_quinn.rs @@ -5,14 +5,14 @@ use super::{ }; use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; use hot_reload::ReloaderReceiver; -use hyper::client::connect::Connect; +use hyper_util::client::legacy::connect::Connect; use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; use rustls::ServerConfig; use std::sync::Arc; -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn listener_service_h3( diff --git a/rpxy-lib/src/proxy/proxy_quic_s2n.rs b/rpxy-lib/src/proxy/proxy_quic_s2n.rs index e0c41a5f..d1d15807 100644 --- a/rpxy-lib/src/proxy/proxy_quic_s2n.rs +++ b/rpxy-lib/src/proxy/proxy_quic_s2n.rs @@ -4,13 +4,13 @@ use super::{ }; use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; use hot_reload::ReloaderReceiver; -use hyper::client::connect::Connect; +use hyper_util::client::legacy::connect::Connect; use s2n_quic::provider; use std::sync::Arc; -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn listener_service_h3( @@ -29,7 +29,7 @@ where // event loop loop { tokio::select! { - v = self.serve_connection(&server_crypto) => { + v = self.listener_service_h3_inner(&server_crypto) => { if let Err(e) = v { error!("Quic connection event loop illegally shutdown [s2n-quic] {e}"); break; @@ -64,7 +64,7 @@ where }) } - async fn serve_connection(&self, server_crypto: &Option>) -> Result<()> { + async fn listener_service_h3_inner(&self, server_crypto: &Option>) -> Result<()> { // setup UDP socket let io = provider::io::tokio::Builder::default() .with_receive_address(self.listening_on)? @@ -110,9 +110,9 @@ where while let Some(new_conn) = server.accept().await { debug!("New QUIC connection established"); let Ok(Some(new_server_name)) = new_conn.server_name() else { - warn!("HTTP/3 no SNI is given"); - continue; - }; + warn!("HTTP/3 no SNI is given"); + continue; + }; debug!("HTTP/3 connection incoming (SNI {:?})", new_server_name); let self_clone = self.clone(); diff --git a/rpxy-lib/src/proxy/proxy_tls.rs b/rpxy-lib/src/proxy/proxy_tls.rs index 7c5d601b..6ed62126 100644 --- a/rpxy-lib/src/proxy/proxy_tls.rs +++ b/rpxy-lib/src/proxy/proxy_tls.rs @@ -1,25 +1,21 @@ use super::{ crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap}, - proxy_main::{LocalExecutor, Proxy}, + proxy_main::Proxy, socket::bind_tcp_socket, }; use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName}; use hot_reload::{ReloaderReceiver, ReloaderService}; -use hyper::{client::connect::Connect, server::conn::Http}; +use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; use std::sync::Arc; use tokio::time::{timeout, Duration}; -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { // TCP Listener Service, i.e., http/2 and http/1.1 - async fn listener_service( - &self, - server: Http, - mut server_crypto_rx: ReloaderReceiver, - ) -> Result<()> { + async fn listener_service(&self, mut server_crypto_rx: ReloaderReceiver) -> Result<()> { let tcp_socket = bind_tcp_socket(&self.listening_on)?; let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; info!("Start TCP proxy serving with HTTPS request for configured host names"); @@ -33,7 +29,6 @@ where } let (raw_stream, client_addr) = tcp_cnx.unwrap(); let sc_map_inner = server_crypto_map.clone(); - let server_clone = server.clone(); let self_inner = self.clone(); // spawns async handshake to avoid blocking thread by sequential handshake. @@ -55,30 +50,27 @@ where return Err(RpxyError::Proxy(format!("No TLS serving app for {:?}", server_name.unwrap()))); } let stream = match start.into_stream(server_crypto.unwrap().clone()).await { - Ok(s) => s, + Ok(s) => TokioIo::new(s), Err(e) => { return Err(RpxyError::Proxy(format!("Failed to handshake TLS: {e}"))); } }; - self_inner.client_serve(stream, server_clone, client_addr, server_name_in_bytes); + self_inner.serve_connection(stream, client_addr, server_name_in_bytes); Ok(()) }; self.globals.runtime_handle.spawn( async move { // timeout is introduced to avoid get stuck here. - match timeout( + let Ok(v) = timeout( Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), handshake_fut - ).await { - Ok(a) => { - if let Err(e) = a { - error!("{}", e); - } - }, - Err(e) => { - error!("Timeout to handshake TLS: {}", e); - } + ).await else { + error!("Timeout to handshake TLS"); + return; }; + if let Err(e) = v { + error!("{}", e); + } }); } _ = server_crypto_rx.changed() => { @@ -99,7 +91,7 @@ where Ok(()) as Result<()> } - pub async fn start_with_tls(self, server: Http) -> Result<()> { + pub async fn start_with_tls(&self) -> Result<()> { let (cert_reloader_service, cert_reloader_rx) = ReloaderService::, ServerCryptoBase>::new( &self.globals.clone(), CERTS_WATCH_DELAY_SECS, @@ -114,7 +106,7 @@ where _= cert_reloader_service.start() => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server, cert_reloader_rx) => { + _ = self.listener_service(cert_reloader_rx) => { error!("TCP proxy service for TLS exited"); }, else => { @@ -131,7 +123,7 @@ where _= cert_reloader_service.start() => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server, cert_reloader_rx.clone()) => { + _ = self.listener_service(cert_reloader_rx.clone()) => { error!("TCP proxy service for TLS exited"); }, _= self.listener_service_h3(cert_reloader_rx) => { @@ -148,7 +140,7 @@ where _= cert_reloader_service.start() => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server, cert_reloader_rx) => { + _ = self.listener_service(cert_reloader_rx) => { error!("TCP proxy service for TLS exited"); }, else => { diff --git a/submodules/h3 b/submodules/h3 index b86df122..5c161952 160000 --- a/submodules/h3 +++ b/submodules/h3 @@ -1 +1 @@ -Subproject commit b86df1220775d13b89cead99e787944b55991b1e +Subproject commit 5c161952b02e663f31f9b83829bafa7a047b6627 diff --git a/submodules/h3-quinn/Cargo.toml b/submodules/h3-quinn/Cargo.toml deleted file mode 100644 index abbb21e4..00000000 --- a/submodules/h3-quinn/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "h3-quinn" -version = "0.0.4" -rust-version = "1.63" -authors = ["Jean-Christophe BEGUE "] -edition = "2018" -documentation = "https://docs.rs/h3-quinn" -repository = "https://github.com/hyperium/h3" -readme = "../README.md" -description = "QUIC transport implementation based on Quinn." -keywords = ["http3", "quic", "h3"] -categories = ["network-programming", "web-programming"] -license = "MIT" - -[dependencies] -h3 = { version = "0.0.3", path = "../h3/h3" } -bytes = "1" -quinn = { path = "../quinn/quinn/", default-features = false, features = [ - "futures-io", -] } -quinn-proto = { path = "../quinn/quinn-proto/", default-features = false } -tokio-util = { version = "0.7.9" } -futures = { version = "0.3.28" } -tokio = { version = "1.33.0", features = ["io-util"], default-features = false } diff --git a/submodules/h3-quinn/src/lib.rs b/submodules/h3-quinn/src/lib.rs deleted file mode 100644 index 78696dec..00000000 --- a/submodules/h3-quinn/src/lib.rs +++ /dev/null @@ -1,740 +0,0 @@ -//! QUIC Transport implementation with Quinn -//! -//! This module implements QUIC traits with Quinn. -#![deny(missing_docs)] - -use std::{ - convert::TryInto, - fmt::{self, Display}, - future::Future, - pin::Pin, - sync::Arc, - task::{self, Poll}, -}; - -use bytes::{Buf, Bytes, BytesMut}; - -use futures::{ - ready, - stream::{self, BoxStream}, - StreamExt, -}; -use quinn::ReadDatagram; -pub use quinn::{ - self, crypto::Session, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError, -}; - -use h3::{ - ext::Datagram, - quic::{self, Error, StreamId, WriteBuf}, -}; -use tokio_util::sync::ReusableBoxFuture; - -/// A QUIC connection backed by Quinn -/// -/// Implements a [`quic::Connection`] backed by a [`quinn::Connection`]. -pub struct Connection { - conn: quinn::Connection, - incoming_bi: BoxStream<'static, as Future>::Output>, - opening_bi: Option as Future>::Output>>, - incoming_uni: BoxStream<'static, as Future>::Output>, - opening_uni: Option as Future>::Output>>, - datagrams: BoxStream<'static, as Future>::Output>, -} - -impl Connection { - /// Create a [`Connection`] from a [`quinn::NewConnection`] - pub fn new(conn: quinn::Connection) -> Self { - Self { - conn: conn.clone(), - incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async { - Some((conn.accept_bi().await, conn)) - })), - opening_bi: None, - incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async { - Some((conn.accept_uni().await, conn)) - })), - opening_uni: None, - datagrams: Box::pin(stream::unfold(conn, |conn| async { - Some((conn.read_datagram().await, conn)) - })), - } - } -} - -/// The error type for [`Connection`] -/// -/// Wraps reasons a Quinn connection might be lost. -#[derive(Debug)] -pub struct ConnectionError(quinn::ConnectionError); - -impl std::error::Error for ConnectionError {} - -impl fmt::Display for ConnectionError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl Error for ConnectionError { - fn is_timeout(&self) -> bool { - matches!(self.0, quinn::ConnectionError::TimedOut) - } - - fn err_code(&self) -> Option { - match self.0 { - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }) => Some(error_code.into_inner()), - _ => None, - } - } -} - -impl From for ConnectionError { - fn from(e: quinn::ConnectionError) -> Self { - Self(e) - } -} - -/// Types of errors when sending a datagram. -#[derive(Debug)] -pub enum SendDatagramError { - /// Datagrams are not supported by the peer - UnsupportedByPeer, - /// Datagrams are locally disabled - Disabled, - /// The datagram was too large to be sent. - TooLarge, - /// Network error - ConnectionLost(Box), -} - -impl fmt::Display for SendDatagramError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"), - SendDatagramError::Disabled => write!(f, "datagram support disabled"), - SendDatagramError::TooLarge => write!(f, "datagram too large"), - SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"), - } - } -} - -impl std::error::Error for SendDatagramError {} - -impl Error for SendDatagramError { - fn is_timeout(&self) -> bool { - false - } - - fn err_code(&self) -> Option { - match self { - Self::ConnectionLost(err) => err.err_code(), - _ => None, - } - } -} - -impl From for SendDatagramError { - fn from(value: quinn::SendDatagramError) -> Self { - match value { - quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer, - quinn::SendDatagramError::Disabled => Self::Disabled, - quinn::SendDatagramError::TooLarge => Self::TooLarge, - quinn::SendDatagramError::ConnectionLost(err) => { - Self::ConnectionLost(ConnectionError::from(err).into()) - } - } - } -} - -impl quic::Connection for Connection -where - B: Buf, -{ - type SendStream = SendStream; - type RecvStream = RecvStream; - type BidiStream = BidiStream; - type OpenStreams = OpenStreams; - type Error = ConnectionError; - - fn poll_accept_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) { - Some(x) => x?, - None => return Poll::Ready(Ok(None)), - }; - Poll::Ready(Ok(Some(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - }))) - } - - fn poll_accept_recv( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) { - Some(x) => x?, - None => return Poll::Ready(Ok(None)), - }; - Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) - } - - fn poll_open_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_bi.is_none() { - self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.clone().open_bi().await, conn)) - }))); - } - - let (send, recv) = - ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - })) - } - - fn poll_open_send( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_uni.is_none() { - self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_uni().await, conn)) - }))); - } - - let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::SendStream::new(send))) - } - - fn opener(&self) -> Self::OpenStreams { - OpenStreams { - conn: self.conn.clone(), - opening_bi: None, - opening_uni: None, - } - } - - fn close(&mut self, code: h3::error::Code, reason: &[u8]) { - self.conn.close( - VarInt::from_u64(code.value()).expect("error code VarInt"), - reason, - ); - } -} - -impl quic::SendDatagramExt for Connection -where - B: Buf, -{ - type Error = SendDatagramError; - - fn send_datagram(&mut self, data: Datagram) -> Result<(), SendDatagramError> { - // TODO investigate static buffer from known max datagram size - let mut buf = BytesMut::new(); - data.encode(&mut buf); - self.conn.send_datagram(buf.freeze())?; - - Ok(()) - } -} - -impl quic::RecvDatagramExt for Connection { - type Buf = Bytes; - - type Error = ConnectionError; - - #[inline] - fn poll_accept_datagram( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - match ready!(self.datagrams.poll_next_unpin(cx)) { - Some(Ok(x)) => Poll::Ready(Ok(Some(x))), - Some(Err(e)) => Poll::Ready(Err(e.into())), - None => Poll::Ready(Ok(None)), - } - } -} - -/// Stream opener backed by a Quinn connection -/// -/// Implements [`quic::OpenStreams`] using [`quinn::Connection`], -/// [`quinn::OpenBi`], [`quinn::OpenUni`]. -pub struct OpenStreams { - conn: quinn::Connection, - opening_bi: Option as Future>::Output>>, - opening_uni: Option as Future>::Output>>, -} - -impl quic::OpenStreams for OpenStreams -where - B: Buf, -{ - type RecvStream = RecvStream; - type SendStream = SendStream; - type BidiStream = BidiStream; - type Error = ConnectionError; - - fn poll_open_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_bi.is_none() { - self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_bi().await, conn)) - }))); - } - - let (send, recv) = - ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - })) - } - - fn poll_open_send( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_uni.is_none() { - self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_uni().await, conn)) - }))); - } - - let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::SendStream::new(send))) - } - - fn close(&mut self, code: h3::error::Code, reason: &[u8]) { - self.conn.close( - VarInt::from_u64(code.value()).expect("error code VarInt"), - reason, - ); - } -} - -impl Clone for OpenStreams { - fn clone(&self) -> Self { - Self { - conn: self.conn.clone(), - opening_bi: None, - opening_uni: None, - } - } -} - -/// Quinn-backed bidirectional stream -/// -/// Implements [`quic::BidiStream`] which allows the stream to be split -/// into two structs each implementing one direction. -pub struct BidiStream -where - B: Buf, -{ - send: SendStream, - recv: RecvStream, -} - -impl quic::BidiStream for BidiStream -where - B: Buf, -{ - type SendStream = SendStream; - type RecvStream = RecvStream; - - fn split(self) -> (Self::SendStream, Self::RecvStream) { - (self.send, self.recv) - } -} - -impl quic::RecvStream for BidiStream { - type Buf = Bytes; - type Error = ReadError; - - fn poll_data( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - self.recv.poll_data(cx) - } - - fn stop_sending(&mut self, error_code: u64) { - self.recv.stop_sending(error_code) - } - - fn recv_id(&self) -> StreamId { - self.recv.recv_id() - } -} - -impl quic::SendStream for BidiStream -where - B: Buf, -{ - type Error = SendStreamError; - - fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.send.poll_ready(cx) - } - - fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.send.poll_finish(cx) - } - - fn reset(&mut self, reset_code: u64) { - self.send.reset(reset_code) - } - - fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { - self.send.send_data(data) - } - - fn send_id(&self) -> StreamId { - self.send.send_id() - } -} -impl quic::SendStreamUnframed for BidiStream -where - B: Buf, -{ - fn poll_send( - &mut self, - cx: &mut task::Context<'_>, - buf: &mut D, - ) -> Poll> { - self.send.poll_send(cx, buf) - } -} - -/// Quinn-backed receive stream -/// -/// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`]. -pub struct RecvStream { - stream: Option, - read_chunk_fut: ReadChunkFuture, -} - -type ReadChunkFuture = ReusableBoxFuture< - 'static, - ( - quinn::RecvStream, - Result, quinn::ReadError>, - ), ->; - -impl RecvStream { - fn new(stream: quinn::RecvStream) -> Self { - Self { - stream: Some(stream), - // Should only allocate once the first time it's used - read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }), - } - } -} - -impl quic::RecvStream for RecvStream { - type Buf = Bytes; - type Error = ReadError; - - fn poll_data( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - if let Some(mut stream) = self.stream.take() { - self.read_chunk_fut.set(async move { - let chunk = stream.read_chunk(usize::MAX, true).await; - (stream, chunk) - }) - }; - - let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx)); - self.stream = Some(stream); - Poll::Ready(Ok(chunk?.map(|c| c.bytes))) - } - - fn stop_sending(&mut self, error_code: u64) { - self.stream - .as_mut() - .unwrap() - .stop(VarInt::from_u64(error_code).expect("invalid error_code")) - .ok(); - } - - fn recv_id(&self) -> StreamId { - self.stream - .as_ref() - .unwrap() - .id() - .0 - .try_into() - .expect("invalid stream id") - } -} - -/// The error type for [`RecvStream`] -/// -/// Wraps errors that occur when reading from a receive stream. -#[derive(Debug)] -pub struct ReadError(quinn::ReadError); - -impl From for std::io::Error { - fn from(value: ReadError) -> Self { - value.0.into() - } -} - -impl std::error::Error for ReadError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.0.source() - } -} - -impl fmt::Display for ReadError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl From for Arc { - fn from(e: ReadError) -> Self { - Arc::new(e) - } -} - -impl From for ReadError { - fn from(e: quinn::ReadError) -> Self { - Self(e) - } -} - -impl Error for ReadError { - fn is_timeout(&self) -> bool { - matches!( - self.0, - quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut) - ) - } - - fn err_code(&self) -> Option { - match self.0 { - quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( - quinn_proto::ApplicationClose { error_code, .. }, - )) => Some(error_code.into_inner()), - quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), - _ => None, - } - } -} - -/// Quinn-backed send stream -/// -/// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`]. -pub struct SendStream { - stream: Option, - writing: Option>, - write_fut: WriteFuture, -} - -type WriteFuture = - ReusableBoxFuture<'static, (quinn::SendStream, Result)>; - -impl SendStream -where - B: Buf, -{ - fn new(stream: quinn::SendStream) -> SendStream { - Self { - stream: Some(stream), - writing: None, - write_fut: ReusableBoxFuture::new(async { unreachable!() }), - } - } -} - -impl quic::SendStream for SendStream -where - B: Buf, -{ - type Error = SendStreamError; - - fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - if let Some(ref mut data) = self.writing { - while data.has_remaining() { - if let Some(mut stream) = self.stream.take() { - let chunk = data.chunk().to_owned(); // FIXME - avoid copy - self.write_fut.set(async move { - let ret = stream.write(&chunk).await; - (stream, ret) - }); - } - - let (stream, res) = ready!(self.write_fut.poll(cx)); - self.stream = Some(stream); - match res { - Ok(cnt) => data.advance(cnt), - Err(err) => { - return Poll::Ready(Err(SendStreamError::Write(err))); - } - } - } - } - self.writing = None; - Poll::Ready(Ok(())) - } - - fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.stream - .as_mut() - .unwrap() - .poll_finish(cx) - .map_err(Into::into) - } - - fn reset(&mut self, reset_code: u64) { - let _ = self - .stream - .as_mut() - .unwrap() - .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX)); - } - - fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { - if self.writing.is_some() { - return Err(Self::Error::NotReady); - } - self.writing = Some(data.into()); - Ok(()) - } - - fn send_id(&self) -> StreamId { - self.stream - .as_ref() - .unwrap() - .id() - .0 - .try_into() - .expect("invalid stream id") - } -} - -impl quic::SendStreamUnframed for SendStream -where - B: Buf, -{ - fn poll_send( - &mut self, - cx: &mut task::Context<'_>, - buf: &mut D, - ) -> Poll> { - if self.writing.is_some() { - // This signifies a bug in implementation - panic!("poll_send called while send stream is not ready") - } - - let s = Pin::new(self.stream.as_mut().unwrap()); - - let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk())); - match res { - Ok(written) => { - buf.advance(written); - Poll::Ready(Ok(written)) - } - Err(err) => { - // We are forced to use AsyncWrite for now because we cannot store - // the result of a call to: - // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result. - // - // This is why we have to unpack the error from io::Error instead of having it - // returned directly. This should not panic as long as quinn's AsyncWrite impl - // doesn't change. - let err = err - .into_inner() - .expect("write stream returned an empty error") - .downcast::() - .expect("write stream returned an error which type is not WriteError"); - - Poll::Ready(Err(SendStreamError::Write(*err))) - } - } - } -} - -/// The error type for [`SendStream`] -/// -/// Wraps errors that can happen writing to or polling a send stream. -#[derive(Debug)] -pub enum SendStreamError { - /// Errors when writing, wrapping a [`quinn::WriteError`] - Write(WriteError), - /// Error when the stream is not ready, because it is still sending - /// data from a previous call - NotReady, -} - -impl From for std::io::Error { - fn from(value: SendStreamError) -> Self { - match value { - SendStreamError::Write(err) => err.into(), - SendStreamError::NotReady => { - std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") - } - } - } -} - -impl std::error::Error for SendStreamError {} - -impl Display for SendStreamError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -impl From for SendStreamError { - fn from(e: WriteError) -> Self { - Self::Write(e) - } -} - -impl Error for SendStreamError { - fn is_timeout(&self) -> bool { - matches!( - self, - Self::Write(quinn::WriteError::ConnectionLost( - quinn::ConnectionError::TimedOut - )) - ) - } - - fn err_code(&self) -> Option { - match self { - Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), - Self::Write(quinn::WriteError::ConnectionLost( - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }), - )) => Some(error_code.into_inner()), - _ => None, - } - } -} - -impl From for Arc { - fn from(e: SendStreamError) -> Self { - Arc::new(e) - } -} diff --git a/submodules/quinn b/submodules/quinn deleted file mode 160000 index 6d80efee..00000000 --- a/submodules/quinn +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d80efeeae60b96ff330ae6a70e8cc9291fcc615 diff --git a/submodules/s2n-quic b/submodules/s2n-quic deleted file mode 160000 index 30027eea..00000000 --- a/submodules/s2n-quic +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 30027eeacc7b620da62fc4825b94afd57ab0c7be diff --git a/submodules/s2n-quic-h3/Cargo.toml b/submodules/s2n-quic-h3/Cargo.toml new file mode 100644 index 00000000..fecfd10c --- /dev/null +++ b/submodules/s2n-quic-h3/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "s2n-quic-h3" +# this in an unpublished internal crate so the version should not be changed +version = "0.1.0" +authors = ["AWS s2n"] +edition = "2021" +rust-version = "1.63" +license = "Apache-2.0" +# this contains an http3 implementation for testing purposes and should not be published +publish = false + +[dependencies] +bytes = { version = "1", default-features = false } +futures = { version = "0.3", default-features = false } +h3 = { path = "../h3/h3/" } +s2n-quic = "1.31.0" +s2n-quic-core = "0.31.0" diff --git a/submodules/s2n-quic-h3/README.md b/submodules/s2n-quic-h3/README.md new file mode 100644 index 00000000..aed94754 --- /dev/null +++ b/submodules/s2n-quic-h3/README.md @@ -0,0 +1,10 @@ +# s2n-quic-h3 + +This is an internal crate used by [s2n-quic](https://github.com/aws/s2n-quic) written as a proof of concept for implementing HTTP3 on top of s2n-quic. The API is not currently stable and should not be used directly. + +## License + +This project is licensed under the [Apache-2.0 License][license-url]. + +[license-badge]: https://img.shields.io/badge/license-apache-blue.svg +[license-url]: https://aws.amazon.com/apache-2-0/ diff --git a/submodules/s2n-quic-h3/src/lib.rs b/submodules/s2n-quic-h3/src/lib.rs new file mode 100644 index 00000000..c85f197f --- /dev/null +++ b/submodules/s2n-quic-h3/src/lib.rs @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +mod s2n_quic; + +pub use self::s2n_quic::*; +pub use h3; diff --git a/submodules/s2n-quic-h3/src/s2n_quic.rs b/submodules/s2n-quic-h3/src/s2n_quic.rs new file mode 100644 index 00000000..dffa19b2 --- /dev/null +++ b/submodules/s2n-quic-h3/src/s2n_quic.rs @@ -0,0 +1,506 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use bytes::{Buf, Bytes}; +use futures::ready; +use h3::quic::{self, Error, StreamId, WriteBuf}; +use s2n_quic::stream::{BidirectionalStream, ReceiveStream}; +use s2n_quic_core::varint::VarInt; +use std::{ + convert::TryInto, + fmt::{self, Display}, + sync::Arc, + task::{self, Poll}, +}; + +pub struct Connection { + conn: s2n_quic::connection::Handle, + bidi_acceptor: s2n_quic::connection::BidirectionalStreamAcceptor, + recv_acceptor: s2n_quic::connection::ReceiveStreamAcceptor, +} + +impl Connection { + pub fn new(new_conn: s2n_quic::Connection) -> Self { + let (handle, acceptor) = new_conn.split(); + let (bidi, recv) = acceptor.split(); + + Self { + conn: handle, + bidi_acceptor: bidi, + recv_acceptor: recv, + } + } +} + +#[derive(Debug)] +pub struct ConnectionError(s2n_quic::connection::Error); + +impl std::error::Error for ConnectionError {} + +impl fmt::Display for ConnectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Error for ConnectionError { + fn is_timeout(&self) -> bool { + matches!(self.0, s2n_quic::connection::Error::IdleTimerExpired { .. }) + } + + fn err_code(&self) -> Option { + match self.0 { + s2n_quic::connection::Error::Application { error, .. } => Some(error.into()), + _ => None, + } + } +} + +impl From for ConnectionError { + fn from(e: s2n_quic::connection::Error) -> Self { + Self(e) + } +} + +impl quic::Connection for Connection +where + B: Buf, +{ + type BidiStream = BidiStream; + type SendStream = SendStream; + type RecvStream = RecvStream; + type OpenStreams = OpenStreams; + type Error = ConnectionError; + + fn poll_accept_recv( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + let recv = match ready!(self.recv_acceptor.poll_accept_receive_stream(cx))? { + Some(x) => x, + None => return Poll::Ready(Ok(None)), + }; + Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) + } + + fn poll_accept_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + let (recv, send) = match ready!(self.bidi_acceptor.poll_accept_bidirectional_stream(cx))? { + Some(x) => x.split(), + None => return Poll::Ready(Ok(None)), + }; + Poll::Ready(Ok(Some(Self::BidiStream { + send: Self::SendStream::new(send), + recv: Self::RecvStream::new(recv), + }))) + } + + fn poll_open_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?; + Ok(stream.into()).into() + } + + fn poll_open_send( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_send_stream(cx))?; + Ok(stream.into()).into() + } + + fn opener(&self) -> Self::OpenStreams { + OpenStreams { + conn: self.conn.clone(), + } + } + + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + self.conn.close( + code.value() + .try_into() + .expect("s2n-quic supports error codes up to 2^62-1"), + ); + } +} + +pub struct OpenStreams { + conn: s2n_quic::connection::Handle, +} + +impl quic::OpenStreams for OpenStreams +where + B: Buf, +{ + type BidiStream = BidiStream; + type SendStream = SendStream; + type RecvStream = RecvStream; + type Error = ConnectionError; + + fn poll_open_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?; + Ok(stream.into()).into() + } + + fn poll_open_send( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_send_stream(cx))?; + Ok(stream.into()).into() + } + + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + self.conn.close( + code.value() + .try_into() + .unwrap_or_else(|_| VarInt::MAX.into()), + ); + } +} + +impl Clone for OpenStreams { + fn clone(&self) -> Self { + Self { + conn: self.conn.clone(), + } + } +} + +pub struct BidiStream +where + B: Buf, +{ + send: SendStream, + recv: RecvStream, +} + +impl quic::BidiStream for BidiStream +where + B: Buf, +{ + type SendStream = SendStream; + type RecvStream = RecvStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + (self.send, self.recv) + } +} + +impl quic::RecvStream for BidiStream +where + B: Buf, +{ + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + self.recv.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.recv.stop_sending(error_code) + } + + fn recv_id(&self) -> StreamId { + self.recv.stream.id().try_into().expect("invalid stream id") + } +} + +impl quic::SendStream for BidiStream +where + B: Buf, +{ + type Error = SendStreamError; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.send.poll_ready(cx) + } + + fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.send.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.send.reset(reset_code) + } + + fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { + self.send.send_data(data) + } + + fn send_id(&self) -> StreamId { + self.send.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for BidiStream +where + B: Buf, +{ + fn from(bidi: BidirectionalStream) -> Self { + let (recv, send) = bidi.split(); + BidiStream { + send: send.into(), + recv: recv.into(), + } + } +} + +pub struct RecvStream { + stream: s2n_quic::stream::ReceiveStream, +} + +impl RecvStream { + fn new(stream: s2n_quic::stream::ReceiveStream) -> Self { + Self { stream } + } +} + +impl quic::RecvStream for RecvStream { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + let buf = ready!(self.stream.poll_receive(cx))?; + Ok(buf).into() + } + + fn stop_sending(&mut self, error_code: u64) { + let _ = self.stream.stop_sending( + s2n_quic::application::Error::new(error_code) + .expect("s2n-quic supports error codes up to 2^62-1"), + ); + } + + fn recv_id(&self) -> StreamId { + self.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for RecvStream { + fn from(recv: ReceiveStream) -> Self { + RecvStream::new(recv) + } +} + +#[derive(Debug)] +pub struct ReadError(s2n_quic::stream::Error); + +impl std::error::Error for ReadError {} + +impl fmt::Display for ReadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl From for Arc { + fn from(e: ReadError) -> Self { + Arc::new(e) + } +} + +impl From for ReadError { + fn from(e: s2n_quic::stream::Error) -> Self { + Self(e) + } +} + +impl Error for ReadError { + fn is_timeout(&self) -> bool { + matches!( + self.0, + s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::IdleTimerExpired { .. }, + .. + } + ) + } + + fn err_code(&self) -> Option { + match self.0 { + s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::Application { error, .. }, + .. + } => Some(error.into()), + s2n_quic::stream::Error::StreamReset { error, .. } => Some(error.into()), + _ => None, + } + } +} + +pub struct SendStream { + stream: s2n_quic::stream::SendStream, + chunk: Option, + buf: Option>, // TODO: Replace with buf: PhantomData + // after https://github.com/hyperium/h3/issues/78 is resolved +} + +impl SendStream +where + B: Buf, +{ + fn new(stream: s2n_quic::stream::SendStream) -> SendStream { + Self { + stream, + chunk: None, + buf: Default::default(), + } + } +} + +impl quic::SendStream for SendStream +where + B: Buf, +{ + type Error = SendStreamError; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + loop { + // try to flush the current chunk if we have one + if let Some(chunk) = self.chunk.as_mut() { + ready!(self.stream.poll_send(chunk, cx))?; + + // s2n-quic will take the whole chunk on send, even if it exceeds the limits + debug_assert!(chunk.is_empty()); + self.chunk = None; + } + + // try to take the next chunk from the WriteBuf + if let Some(ref mut data) = self.buf { + let len = data.chunk().len(); + + // if the write buf is empty, then clear it and break + if len == 0 { + self.buf = None; + break; + } + + // copy the first chunk from WriteBuf and prepare it to flush + let chunk = data.copy_to_bytes(len); + self.chunk = Some(chunk); + + // loop back around to flush the chunk + continue; + } + + // if we didn't have either a chunk or WriteBuf, then we're ready + break; + } + + Poll::Ready(Ok(())) + + // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved + // self.available_bytes = ready!(self.stream.poll_send_ready(cx))?; + // Poll::Ready(Ok(())) + } + + fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { + if self.buf.is_some() { + return Err(Self::Error::NotReady); + } + self.buf = Some(data.into()); + Ok(()) + + // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved + // let mut data = data.into(); + // while self.available_bytes > 0 && data.has_remaining() { + // let len = data.chunk().len(); + // let chunk = data.copy_to_bytes(len); + // self.stream.send_data(chunk)?; + // self.available_bytes = self.available_bytes.saturating_sub(len); + // } + // Ok(()) + } + + fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { + // ensure all chunks are flushed to the QUIC stream before finishing + ready!(self.poll_ready(cx))?; + self.stream.finish()?; + Ok(()).into() + } + + fn reset(&mut self, reset_code: u64) { + let _ = self + .stream + .reset(reset_code.try_into().unwrap_or_else(|_| VarInt::MAX.into())); + } + + fn send_id(&self) -> StreamId { + self.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for SendStream +where + B: Buf, +{ + fn from(send: s2n_quic::stream::SendStream) -> Self { + SendStream::new(send) + } +} + +#[derive(Debug)] +pub enum SendStreamError { + Write(s2n_quic::stream::Error), + NotReady, +} + +impl std::error::Error for SendStreamError {} + +impl Display for SendStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl From for SendStreamError { + fn from(e: s2n_quic::stream::Error) -> Self { + Self::Write(e) + } +} + +impl Error for SendStreamError { + fn is_timeout(&self) -> bool { + matches!( + self, + Self::Write(s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::IdleTimerExpired { .. }, + .. + }) + ) + } + + fn err_code(&self) -> Option { + match self { + Self::Write(s2n_quic::stream::Error::StreamReset { error, .. }) => { + Some((*error).into()) + } + Self::Write(s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::Application { error, .. }, + .. + }) => Some((*error).into()), + _ => None, + } + } +} + +impl From for Arc { + fn from(e: SendStreamError) -> Self { + Arc::new(e) + } +}