diff --git a/Cargo.lock b/Cargo.lock index f275d392..9e3c6a08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2272,7 +2272,6 @@ dependencies = [ "rama-http-types", "rama-utils", "rand", - "serde", "serde_json", "slab", "smallvec", @@ -2307,6 +2306,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", + "smallvec", "sync_wrapper", "tokio", "tokio-test", diff --git a/rama-cli/src/cmd/fp/data.rs b/rama-cli/src/cmd/fp/data.rs index e51a89ca..516f340e 100644 --- a/rama-cli/src/cmd/fp/data.rs +++ b/rama-cli/src/cmd/fp/data.rs @@ -2,10 +2,10 @@ use super::State; use rama::{ error::{BoxError, ErrorContext}, http::{ - core::h2::{PseudoHeader, PseudoHeaderOrder}, - dep::http::request::Parts, + dep::http::{request::Parts, Extensions}, headers::Forwarded, - Request, + proto::{h1::Http1HeaderMap, h2::PseudoHeaderOrder}, + HeaderMap, }, net::{http::RequestContext, stream::SocketInfo}, tls::types::{ @@ -187,27 +187,24 @@ pub(super) async fn get_request_info( #[derive(Debug, Clone, Serialize)] pub(super) struct HttpInfo { pub(super) headers: Vec<(String, String)>, - pub(super) pseudo_headers: Option>, + pub(super) pseudo_headers: Option>, } -pub(super) fn get_http_info(req: &Request) -> HttpInfo { - // TODO: get in correct order - // TODO: get in correct case - let headers = req - .headers() - .iter() +pub(super) fn get_http_info(headers: HeaderMap, ext: &mut Extensions) -> HttpInfo { + let headers: Vec<_> = Http1HeaderMap::new(headers, Some(ext)) + .into_iter() .map(|(name, value)| { ( - name.as_str().to_owned(), - value.to_str().map(|v| v.to_owned()).unwrap_or_default(), + name.to_string(), + std::str::from_utf8(value.as_bytes()) + .map(|s| s.to_owned()) + .unwrap_or_else(|_| format!("0x{:x?}", value.as_bytes())), ) }) .collect(); - - let pseudo_headers: Option> = req - .extensions() + let pseudo_headers: Option> = ext .get::() - .map(|o| o.iter().collect()); + .map(|o| o.iter().map(|p| p.to_string()).collect()); HttpInfo { headers, diff --git a/rama-cli/src/cmd/fp/endpoints.rs b/rama-cli/src/cmd/fp/endpoints.rs index e90eca31..d169ab37 100644 --- a/rama-cli/src/cmd/fp/endpoints.rs +++ b/rama-cli/src/cmd/fp/endpoints.rs @@ -79,9 +79,7 @@ pub(super) async fn get_report( mut ctx: Context>, req: Request, ) -> Result { - let http_info = get_http_info(&req); - - let (parts, _) = req.into_parts(); + let (mut parts, _) = req.into_parts(); let user_agent_info = get_user_agent_info(&ctx).await; @@ -95,6 +93,8 @@ pub(super) async fn get_report( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + let http_info = get_http_info(parts.headers, &mut parts.extensions); + let head = r#""#.to_owned(); let mut tables = vec![ @@ -110,14 +110,7 @@ pub(super) async fn get_report( if let Some(pseudo) = http_info.pseudo_headers { tables.push(Table { title: "🚗 H2 Pseudo Headers".to_owned(), - rows: vec![( - "order".to_owned(), - pseudo - .into_iter() - .map(|h| h.as_str()) - .collect::>() - .join(", "), - )], + rows: vec![("order".to_owned(), pseudo.join(", "))], }); } @@ -174,9 +167,7 @@ pub(super) async fn get_api_fetch_number( mut ctx: Context>, req: Request, ) -> Result, Response> { - let http_info = get_http_info(&req); - - let (parts, _) = req.into_parts(); + let (mut parts, _) = req.into_parts(); let user_agent_info = get_user_agent_info(&ctx).await; @@ -190,6 +181,8 @@ pub(super) async fn get_api_fetch_number( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + let http_info = get_http_info(parts.headers, &mut parts.extensions); + let tls_info = get_tls_display_info(&ctx); Ok(Json(json!({ @@ -208,9 +201,7 @@ pub(super) async fn post_api_fetch_number( mut ctx: Context>, req: Request, ) -> Result, Response> { - let http_info = get_http_info(&req); - - let (parts, _) = req.into_parts(); + let (mut parts, _) = req.into_parts(); let user_agent_info = get_user_agent_info(&ctx).await; @@ -224,6 +215,8 @@ pub(super) async fn post_api_fetch_number( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + let http_info = get_http_info(parts.headers, &mut parts.extensions); + let tls_info = get_tls_display_info(&ctx); Ok(Json(json!({ @@ -241,9 +234,7 @@ pub(super) async fn get_api_xml_http_request_number( mut ctx: Context>, req: Request, ) -> Result, Response> { - let http_info = get_http_info(&req); - - let (parts, _) = req.into_parts(); + let (mut parts, _) = req.into_parts(); let user_agent_info = get_user_agent_info(&ctx).await; @@ -257,6 +248,8 @@ pub(super) async fn get_api_xml_http_request_number( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + let http_info = get_http_info(parts.headers, &mut parts.extensions); + Ok(Json(json!({ "number": ctx.state().counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel), "fp": { @@ -272,9 +265,7 @@ pub(super) async fn post_api_xml_http_request_number( mut ctx: Context>, req: Request, ) -> Result, Response> { - let http_info = get_http_info(&req); - - let (parts, _) = req.into_parts(); + let (mut parts, _) = req.into_parts(); let user_agent_info = get_user_agent_info(&ctx).await; @@ -288,6 +279,8 @@ pub(super) async fn post_api_xml_http_request_number( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + let http_info = get_http_info(parts.headers, &mut parts.extensions); + let tls_info = get_tls_display_info(&ctx); Ok(Json(json!({ @@ -308,10 +301,7 @@ pub(super) async fn post_api_xml_http_request_number( pub(super) async fn form(mut ctx: Context>, req: Request) -> Result { // TODO: get TLS Info (for https access only) // TODO: support HTTP1, HTTP2 and AUTO (for now we are only doing auto) - - let http_info = get_http_info(&req); - - let (parts, _) = req.into_parts(); + let (mut parts, _) = req.into_parts(); let user_agent_info = get_user_agent_info(&ctx).await; @@ -325,6 +315,8 @@ pub(super) async fn form(mut ctx: Context>, req: Request) -> Result🏠 Back to Home..."##); @@ -354,6 +346,13 @@ pub(super) async fn form(mut ctx: Context>, req: Request) -> Result()` will return a map with: -/// -/// ```text -/// HeaderCaseMap({ -/// "x-bread": ["x-Bread", "X-BREAD", "x-bread"], -/// }) -/// ``` -/// -/// [`preserve_header_case`]: /client/struct.Client.html#method.preserve_header_case -#[derive(Clone, Debug)] -pub(crate) struct HeaderCaseMap(HeaderMap); - -impl HeaderCaseMap { - /// Returns a view of all spellings associated with that header name, - /// in the order they were found. - pub(crate) fn get_all<'a>( - &'a self, - name: &HeaderName, - ) -> impl Iterator + 'a> + 'a { - self.get_all_internal(name) - } - - /// Returns a view of all spellings associated with that header name, - /// in the order they were found. - pub(crate) fn get_all_internal(&self, name: &HeaderName) -> ValueIter<'_, Bytes> { - self.0.get_all(name).into_iter() - } - - pub(crate) fn default() -> Self { - Self(Default::default()) - } - - #[allow(dead_code)] - pub(crate) fn insert(&mut self, name: HeaderName, orig: Bytes) { - self.0.insert(name, orig); - } - - pub(crate) fn append(&mut self, name: N, orig: Bytes) - where - N: IntoHeaderName, - { - self.0.append(name, orig); - } -} - -#[derive(Clone, Debug, Default)] -/// Hashmap -pub struct OriginalHeaderOrder { - /// Stores how many entries a Headername maps to. This is used - /// for accounting. - num_entries: HashMap, - /// Stores the ordering of the headers. ex: `vec[i] = (headerName, idx)`, - /// The vector is ordered such that the ith element - /// represents the ith header that came in off the line. - /// The `HeaderName` and `idx` are then used elsewhere to index into - /// the multi map that stores the header values. - entry_order: Vec<(HeaderName, usize)>, -} - -impl OriginalHeaderOrder { - pub fn insert(&mut self, name: HeaderName) { - if !self.num_entries.contains_key(&name) { - let idx = 0; - self.num_entries.insert(name.clone(), 1); - self.entry_order.push((name, idx)); - } - // Replacing an already existing element does not - // change ordering, so we only care if its the first - // header name encountered - } - - pub fn append(&mut self, name: N) - where - N: IntoHeaderName + Into + Clone, - { - let name: HeaderName = name.into(); - let idx; - if self.num_entries.contains_key(&name) { - idx = self.num_entries[&name]; - *self.num_entries.get_mut(&name).unwrap() += 1; - } else { - idx = 0; - self.num_entries.insert(name.clone(), 1); - } - self.entry_order.push((name, idx)); - } - - /// This returns an iterator that provides header names and indexes - /// in the original order received. - /// - /// # Examples - /// - /// ``` - /// use rama_http_core::ext::OriginalHeaderOrder; - /// use rama_http_types::header::{HeaderName, HeaderValue, HeaderMap}; - /// - /// let mut h_order = OriginalHeaderOrder::default(); - /// let mut h_map = HeaderMap::new(); - /// - /// let name1 = HeaderName::try_from("Set-CookiE").expect("valid Set-CookiE header name"); - /// let value1 = HeaderValue::from_static("a=b"); - /// h_map.append(name1.clone(), value1); - /// h_order.append(name1); - /// - /// let name2 = HeaderName::try_from("Content-Encoding").expect("valid Content-Encoding header name"); - /// let value2 = HeaderValue::from_static("gzip"); - /// h_map.append(name2.clone(), value2); - /// h_order.append(name2); - /// - /// let name3 = HeaderName::try_from("SET-COOKIE").expect("valid SET-COOKIE header name"); - /// let value3 = HeaderValue::from_static("c=d"); - /// h_map.append(name3.clone(), value3); - /// h_order.append(name3); - /// - /// let mut iter = h_order.get_in_order(); - /// - /// let (name, idx) = iter.next().unwrap(); - /// assert_eq!("a=b", h_map.get_all(name).iter().nth(*idx).expect("get set-cookie header value")); - /// - /// let (name, idx) = iter.next().unwrap(); - /// assert_eq!("gzip", h_map.get_all(name).iter().nth(*idx).expect("get content-encoding header value")); - /// - /// let (name, idx) = iter.next().unwrap(); - /// assert_eq!("c=d", h_map.get_all(name).iter().nth(*idx).expect("get SET-COOKIE header value")); - /// ``` - pub fn get_in_order(&self) -> impl Iterator { - self.entry_order.iter() - } -} diff --git a/rama-http-core/src/h2/client.rs b/rama-http-core/src/h2/client.rs index 44229e26..bee457b5 100644 --- a/rama-http-core/src/h2/client.rs +++ b/rama-http-core/src/h2/client.rs @@ -138,10 +138,11 @@ use crate::h2::codec::{Codec, SendError, UserError}; use crate::h2::ext::Protocol; use crate::h2::frame::{Headers, Pseudo, Reason, Settings, StreamId}; use crate::h2::proto::{self, Error}; -use crate::h2::{FlowControl, PingPong, PseudoHeaderOrder, RecvStream, SendStream}; +use crate::h2::{FlowControl, PingPong, RecvStream, SendStream}; use bytes::{Buf, Bytes}; use rama_http_types::dep::http::{request, uri}; +use rama_http_types::proto::h2::PseudoHeaderOrder; use rama_http_types::{HeaderMap, Method, Request, Response, Version}; use std::fmt; use std::future::Future; diff --git a/rama-http-core/src/h2/frame/headers.rs b/rama-http-core/src/h2/frame/headers.rs index 27b63479..4b85f37f 100644 --- a/rama-http-core/src/h2/frame/headers.rs +++ b/rama-http-core/src/h2/frame/headers.rs @@ -4,17 +4,15 @@ use crate::h2::frame::{Error, Frame, Head, Kind}; use crate::h2::hpack::{self, BytesStr}; use rama_http_types::dep::http::uri; +use rama_http_types::proto::h2::{PseudoHeader, PseudoHeaderOrder, PseudoHeaderOrderIter}; use rama_http_types::{ header, HeaderMap, HeaderName, HeaderValue, Method, Request, StatusCode, Uri, }; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use serde::{de::Error as _, Deserialize, Serialize}; -use smallvec::SmallVec; use std::fmt; use std::io::Cursor; -use std::str::FromStr; type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>; @@ -104,155 +102,6 @@ impl PartialEq for Pseudo { impl Eq for Pseudo {} -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] -#[repr(u8)] -/// Defined in function of being able to communicate the used or desired -/// order in which the pseudo headers are in the h2 request. -/// -/// Used mainly in [`PseudoHeaderOrder`]. -pub enum PseudoHeader { - Method = 0b1000_0000, - Scheme = 0b0100_0000, - Authority = 0b0010_0000, - Path = 0b0001_0000, - Protocol = 0b0000_1000, - Status = 0b0000_0100, -} - -impl PseudoHeader { - pub fn as_str(&self) -> &'static str { - match self { - PseudoHeader::Method => ":method", - PseudoHeader::Scheme => ":scheme", - PseudoHeader::Authority => ":authority", - PseudoHeader::Path => ":path", - PseudoHeader::Protocol => ":protocol", - PseudoHeader::Status => ":status", - } - } -} - -impl fmt::Display for PseudoHeader { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.as_str()) - } -} - -rama_utils::macros::error::static_str_error! { - #[doc = "pseudo header string is invalid"] - pub struct InvalidPseudoHeaderStr; -} - -impl FromStr for PseudoHeader { - type Err = InvalidPseudoHeaderStr; - - fn from_str(s: &str) -> Result { - let s = s.trim(); - let s = s.strip_prefix(':').unwrap_or(s); - - if s.eq_ignore_ascii_case("method") { - Ok(Self::Method) - } else if s.eq_ignore_ascii_case("scheme") { - Ok(Self::Scheme) - } else if s.eq_ignore_ascii_case("authority") { - Ok(Self::Authority) - } else if s.eq_ignore_ascii_case("path") { - Ok(Self::Path) - } else if s.eq_ignore_ascii_case("protocol") { - Ok(Self::Protocol) - } else if s.eq_ignore_ascii_case("status") { - Ok(Self::Status) - } else { - Err(InvalidPseudoHeaderStr) - } - } -} - -impl Serialize for PseudoHeader { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.as_str().serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for PseudoHeader { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s = <&'de str>::deserialize(deserializer)?; - s.parse().map_err(D::Error::custom) - } -} - -const PSEUDO_HEADERS_STACK_SIZE: usize = 5; - -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct PseudoHeaderOrder { - headers: SmallVec<[PseudoHeader; PSEUDO_HEADERS_STACK_SIZE]>, - mask: u8, -} - -impl PseudoHeaderOrder { - pub fn new() -> Self { - Self::default() - } - - pub fn push(&mut self, header: PseudoHeader) { - if self.mask & (header as u8) == 0 { - self.mask |= header as u8; - self.headers.push(header); - } else { - tracing::trace!("ignore duplicate psuedo header: {header:?}") - } - } - - pub fn extend(&mut self, iter: impl IntoIterator) { - for header in iter { - self.push(header); - } - } - - pub fn iter(&self) -> PseudoHeaderOrderIter { - self.clone().into_iter() - } - - pub fn is_empty(&self) -> bool { - self.headers.is_empty() - } - - pub fn len(&self) -> usize { - self.headers.len() - } -} - -impl IntoIterator for PseudoHeaderOrder { - type Item = PseudoHeader; - type IntoIter = PseudoHeaderOrderIter; - - fn into_iter(self) -> Self::IntoIter { - let PseudoHeaderOrder { mut headers, .. } = self; - headers.reverse(); - PseudoHeaderOrderIter { headers } - } -} - -#[derive(Debug)] -/// Iterator over a copy of [`PseudoHeaderOrder`]. -pub struct PseudoHeaderOrderIter { - headers: SmallVec<[PseudoHeader; PSEUDO_HEADERS_STACK_SIZE]>, -} - -impl Iterator for PseudoHeaderOrderIter { - type Item = PseudoHeader; - - fn next(&mut self) -> Option { - self.headers.pop() - } -} - #[derive(Debug)] struct Iter { /// Pseudo headers diff --git a/rama-http-core/src/h2/frame/mod.rs b/rama-http-core/src/h2/frame/mod.rs index 2c98f1ea..a9e04d9a 100644 --- a/rama-http-core/src/h2/frame/mod.rs +++ b/rama-http-core/src/h2/frame/mod.rs @@ -52,8 +52,7 @@ pub use self::data::Data; pub use self::go_away::GoAway; pub use self::head::{Head, Kind}; pub use self::headers::{ - parse_u64, Continuation, Headers, InvalidPseudoHeaderStr, Pseudo, PseudoHeader, - PseudoHeaderOrder, PseudoHeaderOrderIter, PushPromise, PushPromiseHeaderError, + parse_u64, Continuation, Headers, Pseudo, PushPromise, PushPromiseHeaderError, }; pub use self::ping::Ping; pub use self::priority::{Priority, StreamDependency}; diff --git a/rama-http-core/src/h2/mod.rs b/rama-http-core/src/h2/mod.rs index 47d85cf9..43c7291d 100644 --- a/rama-http-core/src/h2/mod.rs +++ b/rama-http-core/src/h2/mod.rs @@ -118,9 +118,6 @@ mod frame; #[allow(missing_docs)] pub mod frame; -#[doc(inline)] -pub use frame::{InvalidPseudoHeaderStr, PseudoHeader, PseudoHeaderOrder, PseudoHeaderOrderIter}; - pub mod client; pub mod ext; pub mod server; diff --git a/rama-http-core/src/h2/server.rs b/rama-http-core/src/h2/server.rs index 9b90ad7c..8627dfec 100644 --- a/rama-http-core/src/h2/server.rs +++ b/rama-http-core/src/h2/server.rs @@ -118,9 +118,10 @@ use crate::h2::codec::{Codec, UserError}; use crate::h2::frame::{self, Pseudo, PushPromiseHeaderError, Reason, Settings, StreamId}; use crate::h2::proto::{self, Config, Error, Prioritized}; -use crate::h2::{FlowControl, PingPong, PseudoHeaderOrder, RecvStream, SendStream}; +use crate::h2::{FlowControl, PingPong, RecvStream, SendStream}; use bytes::{Buf, Bytes}; +use rama_http_types::proto::h2::PseudoHeaderOrder; use rama_http_types::{HeaderMap, Method, Request, Response}; use std::future::Future; use std::pin::Pin; diff --git a/rama-http-core/src/proto/h1/conn.rs b/rama-http-core/src/proto/h1/conn.rs index d9432967..89d703fb 100644 --- a/rama-http-core/src/proto/h1/conn.rs +++ b/rama-http-core/src/proto/h1/conn.rs @@ -19,6 +19,7 @@ use super::io::Buffered; use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; use crate::body::DecodedLength; use crate::headers; +use crate::proto::h1::EncodeHead; use crate::proto::{BodyLength, MessageHead}; const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; @@ -47,7 +48,6 @@ where io: Buffered::new(io), state: State { allow_half_close: false, - cached_headers: None, error: None, keep_alive: KA::Busy, method: None, @@ -199,7 +199,6 @@ where let msg = match self.io.parse::( cx, ParseContext { - cached_headers: &mut self.state.cached_headers, req_method: &mut self.state.method, h1_parser_config: self.state.h1_parser_config.clone(), h1_max_headers: self.state.h1_max_headers, @@ -566,7 +565,12 @@ where let buf = self.io.headers_buf(); match super::role::encode_headers::( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body, keep_alive: self.state.wants_keep_alive(), req_method: &mut self.state.method, @@ -575,13 +579,7 @@ where }, buf, ) { - Ok(encoder) => { - debug_assert!(self.state.cached_headers.is_none()); - debug_assert!(head.headers.is_empty()); - self.state.cached_headers = Some(head.headers); - - Some(encoder) - } + Ok(encoder) => Some(encoder), Err(err) => { self.state.error = Some(err); self.state.writing = Writing::Closed; @@ -750,9 +748,6 @@ where return Err(crate::Error::new_version_h2()); } if let Some(msg) = T::on_error(&err) { - // Drop the cached headers so as to not trigger a debug - // assert in `write_head`... - self.state.cached_headers.take(); self.write_head(msg, None); self.state.error = Some(err); return Ok(()); @@ -848,8 +843,6 @@ impl Unpin for Conn {} struct State { allow_half_close: bool, - /// Re-usable HeaderMap to reduce allocating new ones. - cached_headers: Option, /// If an error occurs when there wasn't a direct way to return it /// back to the user, this is set. error: Option, diff --git a/rama-http-core/src/proto/h1/io.rs b/rama-http-core/src/proto/h1/io.rs index 696552f7..be10bf2f 100644 --- a/rama-http-core/src/proto/h1/io.rs +++ b/rama-http-core/src/proto/h1/io.rs @@ -176,7 +176,6 @@ where &mut self.read_buf, self.partial_len, ParseContext { - cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, h1_parser_config: parse_ctx.h1_parser_config.clone(), h1_max_headers: parse_ctx.h1_max_headers, @@ -687,7 +686,6 @@ mod tests { // Rather, this `poll_fn` will wrap the `Poll` result. std::future::poll_fn(|cx| { let parse_ctx = ParseContext { - cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, diff --git a/rama-http-core/src/proto/h1/mod.rs b/rama-http-core/src/proto/h1/mod.rs index a6d6c5ee..a13d81c3 100644 --- a/rama-http-core/src/proto/h1/mod.rs +++ b/rama-http-core/src/proto/h1/mod.rs @@ -1,6 +1,7 @@ use bytes::BytesMut; use httparse::ParserConfig; -use rama_http_types::{HeaderMap, Method}; +use rama_http_types::dep::http; +use rama_http_types::{HeaderMap, Method, Version}; use crate::body::DecodedLength; use crate::proto::{BodyLength, MessageHead}; @@ -63,16 +64,26 @@ pub(crate) struct ParsedMessage { } pub(crate) struct ParseContext<'a> { - cached_headers: &'a mut Option, req_method: &'a mut Option, h1_parser_config: ParserConfig, h1_max_headers: Option, h09_responses: bool, } +struct EncodeHead<'a, S> { + /// HTTP version of the message. + pub(crate) version: Version, + /// Subject (request line or status line) of Incoming message. + pub(crate) subject: S, + /// Headers of the Incoming message. + pub(crate) headers: HeaderMap, + /// Extensions. + extensions: &'a mut http::Extensions, +} + /// Passed to Http1Transaction::encode pub(crate) struct Encode<'a, T> { - head: &'a mut MessageHead, + head: EncodeHead<'a, T>, body: Option, keep_alive: bool, req_method: &'a mut Option, diff --git a/rama-http-core/src/proto/h1/role.rs b/rama-http-core/src/proto/h1/role.rs index ed925337..865ec4c2 100644 --- a/rama-http-core/src/proto/h1/role.rs +++ b/rama-http-core/src/proto/h1/role.rs @@ -6,8 +6,8 @@ use bytes::Bytes; use bytes::BytesMut; use rama_http_types::dep::http; use rama_http_types::header::Entry; -use rama_http_types::header::ValueIter; -use rama_http_types::header::{self, HeaderMap, HeaderName, HeaderValue}; +use rama_http_types::header::{self, HeaderMap, HeaderValue}; +use rama_http_types::proto::h1::{Http1HeaderMap, Http1HeaderName}; use rama_http_types::{Method, StatusCode, Version}; use smallvec::{smallvec, smallvec_inline, SmallVec}; use tracing::{debug, error, trace, trace_span, warn}; @@ -15,30 +15,18 @@ use tracing::{debug, error, trace, trace_span, warn}; use crate::body::DecodedLength; use crate::common::date; use crate::error::Parse; -use crate::ext::HeaderCaseMap; -use crate::ext::OriginalHeaderOrder; use crate::headers; use crate::proto::h1::{ Encode, Encoder, Http1Transaction, ParseContext, ParseResult, ParsedMessage, }; -use crate::proto::RequestHead; use crate::proto::{BodyLength, MessageHead, RequestLine}; +use super::EncodeHead; + pub(crate) const DEFAULT_MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific const MAX_URI_LEN: usize = (u16::MAX - 1) as usize; -macro_rules! header_name { - ($bytes:expr) => {{ - { - match HeaderName::from_bytes($bytes) { - Ok(name) => name, - Err(e) => maybe_panic!(e), - } - } - }}; -} - macro_rules! header_value { ($bytes:expr) => {{ { @@ -47,18 +35,6 @@ macro_rules! header_value { }}; } -macro_rules! maybe_panic { - ($($arg:tt)*) => ({ - let _err = ($($arg)*); - if cfg!(debug_assertions) { - panic!("{:?}", _err); - } else { - error!("Internal rama_http_core error, please report {:?}", _err); - return Err(Parse::Internal) - } - }) -} - pub(super) fn parse_headers( bytes: &mut BytesMut, prev_len: Option, @@ -220,20 +196,19 @@ impl Http1Transaction for Server { let mut is_te_chunked = false; let mut wants_upgrade = subject.0 == Method::CONNECT; - let mut header_case_map = HeaderCaseMap::default(); - let mut header_order = OriginalHeaderOrder::default(); - - let mut headers = ctx.cached_headers.take().unwrap_or_default(); - - headers.reserve(headers_len); + let mut headers = Http1HeaderMap::with_capacity(headers_len); for header in &headers_indices[..headers_len] { // SAFETY: array is valid up to `headers_len` let header = unsafe { header.assume_init_ref() }; - let name = header_name!(&slice[header.name.0..header.name.1]); + let name = Http1HeaderName::try_copy_from_slice(&slice[header.name.0..header.name.1]) + .inspect_err(|err| { + tracing::debug!("invalid http1 header: {err:?}"); + }) + .map_err(|_| crate::error::Parse::Internal)?; let value = header_value!(slice.slice(header.value.0..header.value.1)); - match name { + match *name.header_name() { header::TRANSFER_ENCODING => { // https://tools.ietf.org/html/rfc7230#section-3.3.3 // If Transfer-Encoding header is present, and 'chunked' is @@ -295,9 +270,6 @@ impl Http1Transaction for Server { _ => (), } - header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); - header_order.append(&name); - headers.append(name, value); } @@ -308,8 +280,7 @@ impl Http1Transaction for Server { let mut extensions = http::Extensions::default(); - extensions.insert(header_case_map); - extensions.insert(header_order); + let headers = headers.consume(&mut extensions); *ctx.req_method = Some(subject.0.clone()); @@ -350,7 +321,6 @@ impl Http1Transaction for Server { (Ok(()), true) } else if msg.head.subject.is_informational() { warn!("response with 1xx status code not supported"); - *msg.head = MessageHead::default(); msg.head.subject = StatusCode::INTERNAL_SERVER_ERROR; msg.body = None; (Err(crate::Error::new_user_unsupported_status_code()), true) @@ -404,28 +374,9 @@ impl Http1Transaction for Server { extend(dst, b"\r\n"); } - let orig_headers; - let extensions = std::mem::take(&mut msg.head.extensions); - let orig_headers = match extensions.get::() { - None if msg.title_case_headers => { - orig_headers = HeaderCaseMap::default(); - Some(&orig_headers) - } - orig_headers => orig_headers, - }; - let encoder = if let Some(orig_headers) = orig_headers { - Self::encode_headers_with_original_case( - msg, - dst, - is_last, - orig_len, - wrote_len, - orig_headers, - )? - } else { - Self::encode_headers_with_lower_case(msg, dst, is_last, orig_len, wrote_len)? - }; - + let mut extensions = std::mem::take(msg.head.extensions); + let encoder = + Self::encode_h1_headers(msg, &mut extensions, dst, is_last, orig_len, wrote_len)?; ret.map(|()| encoder) } @@ -485,119 +436,68 @@ impl Server { Server::can_have_content_length(method, status) && method != &Some(Method::HEAD) } - fn encode_headers_with_lower_case( - msg: Encode<'_, StatusCode>, - dst: &mut Vec, - is_last: bool, - orig_len: usize, - wrote_len: bool, - ) -> crate::Result { - struct LowercaseWriter; - - impl HeaderNameWriter for LowercaseWriter { - #[inline] - fn write_full_header_line( - &mut self, - dst: &mut Vec, - line: &str, - _: (HeaderName, &str), - ) { - extend(dst, line.as_bytes()) - } - - #[inline] - fn write_header_name_with_colon( - &mut self, - dst: &mut Vec, - name_with_colon: &str, - _: HeaderName, - ) { - extend(dst, name_with_colon.as_bytes()) - } - - #[inline] - fn write_header_name(&mut self, dst: &mut Vec, name: &HeaderName) { - extend(dst, name.as_str().as_bytes()) - } - } - - Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, LowercaseWriter) - } - #[cold] #[inline(never)] - fn encode_headers_with_original_case( + fn encode_h1_headers( msg: Encode<'_, StatusCode>, + ext: &mut http::Extensions, dst: &mut Vec, is_last: bool, orig_len: usize, wrote_len: bool, - orig_headers: &HeaderCaseMap, ) -> crate::Result { - struct OrigCaseWriter<'map> { - map: &'map HeaderCaseMap, - current: Option<(HeaderName, ValueIter<'map, Bytes>)>, + struct OrigCaseWriter { title_case_headers: bool, } - impl HeaderNameWriter for OrigCaseWriter<'_> { + impl HeaderNameWriter for OrigCaseWriter { #[inline] fn write_full_header_line( &mut self, dst: &mut Vec, - _: &str, - (name, rest): (HeaderName, &str), + (name, rest): (Http1HeaderName, &str), ) { self.write_header_name(dst, &name); extend(dst, rest.as_bytes()); } #[inline] - fn write_header_name_with_colon( - &mut self, - dst: &mut Vec, - _: &str, - name: HeaderName, - ) { - self.write_header_name(dst, &name); + fn write_header_name_with_colon(&mut self, dst: &mut Vec, name: &Http1HeaderName) { + self.write_header_name(dst, name); extend(dst, b": "); } #[inline] - fn write_header_name(&mut self, dst: &mut Vec, name: &HeaderName) { - let Self { - map, - ref mut current, - title_case_headers, - } = *self; - if current.as_ref().map_or(true, |(last, _)| last != name) { - *current = None; - } - let (_, values) = - current.get_or_insert_with(|| (name.clone(), map.get_all_internal(name))); + fn write_header_name(&mut self, dst: &mut Vec, name: &Http1HeaderName) { + let Self { title_case_headers } = *self; - if let Some(orig_name) = values.next() { - extend(dst, orig_name); - } else if title_case_headers { - title_case(dst, name.as_str().as_bytes()); + if title_case_headers { + title_case(dst, name.as_bytes()); } else { - extend(dst, name.as_str().as_bytes()); + extend(dst, name.as_bytes()); } } } let header_name_writer = OrigCaseWriter { - map: orig_headers, - current: None, title_case_headers: msg.title_case_headers, }; - Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, header_name_writer) + Self::encode_headers( + msg, + ext, + dst, + is_last, + orig_len, + wrote_len, + header_name_writer, + ) } #[inline] fn encode_headers( msg: Encode<'_, StatusCode>, + ext: &mut http::Extensions, dst: &mut Vec, mut is_last: bool, orig_len: usize, @@ -617,7 +517,6 @@ impl Server { let mut encoder = Encoder::length(0); let mut allowed_trailer_fields: Option> = None; let mut wrote_date = false; - let mut cur_name = None; let mut is_name_written = false; let mut must_write_chunked = false; let mut prev_con_len = None; @@ -641,14 +540,13 @@ impl Server { }}; } - 'headers: for (opt_name, value) in msg.head.headers.drain() { - if let Some(n) = opt_name { - cur_name = Some(n); - handle_is_name_written!(); - is_name_written = false; - } - let name = cur_name.as_ref().expect("current header name"); - match *name { + let h1_headers = Http1HeaderMap::new(msg.head.headers, Some(ext)); + + 'headers: for (name, value) in h1_headers { + handle_is_name_written!(); + is_name_written = false; + + match *name.header_name() { header::CONTENT_LENGTH => { if wrote_len && !is_name_written { warn!("unexpected content-length found, canceling"); @@ -680,11 +578,7 @@ impl Server { if !is_name_written { encoder = Encoder::length(known_len); - header_name_writer.write_header_name_with_colon( - dst, - "content-length: ", - header::CONTENT_LENGTH, - ); + header_name_writer.write_header_name_with_colon(dst, &name); extend(dst, value.as_bytes()); wrote_len = true; is_name_written = true; @@ -712,11 +606,7 @@ impl Server { } else { // we haven't written content-length yet! encoder = Encoder::length(len); - header_name_writer.write_header_name_with_colon( - dst, - "content-length: ", - header::CONTENT_LENGTH, - ); + header_name_writer.write_header_name_with_colon(dst, &name); extend(dst, value.as_bytes()); wrote_len = true; is_name_written = true; @@ -771,11 +661,7 @@ impl Server { if !is_name_written { encoder = Encoder::chunked(); is_name_written = true; - header_name_writer.write_header_name_with_colon( - dst, - "transfer-encoding: ", - header::TRANSFER_ENCODING, - ); + header_name_writer.write_header_name_with_colon(dst, &name); extend(dst, value.as_bytes()); } else { extend(dst, b", "); @@ -789,11 +675,7 @@ impl Server { } if !is_name_written { is_name_written = true; - header_name_writer.write_header_name_with_colon( - dst, - "connection: ", - header::CONNECTION, - ); + header_name_writer.write_header_name_with_colon(dst, &name); extend(dst, value.as_bytes()); } else { extend(dst, b", "); @@ -814,11 +696,7 @@ impl Server { if !is_name_written { is_name_written = true; - header_name_writer.write_header_name_with_colon( - dst, - "trailer: ", - header::TRAILER, - ); + header_name_writer.write_header_name_with_colon(dst, &name); extend(dst, value.as_bytes()); } else { extend(dst, b", "); @@ -847,8 +725,7 @@ impl Server { "{:?} set is_name_written and didn't continue loop", name, ); - header_name_writer.write_header_name(dst, name); - extend(dst, b": "); + header_name_writer.write_header_name_with_colon(dst, &name); extend(dst, value.as_bytes()); extend(dst, b"\r\n"); } @@ -865,8 +742,7 @@ impl Server { } else { header_name_writer.write_full_header_line( dst, - "transfer-encoding: chunked\r\n", - (header::TRANSFER_ENCODING, ": chunked\r\n"), + (header::TRANSFER_ENCODING.into(), ": chunked\r\n"), ); Encoder::chunked() } @@ -876,11 +752,8 @@ impl Server { msg.req_method, msg.head.subject, ) { - header_name_writer.write_full_header_line( - dst, - "content-length: 0\r\n", - (header::CONTENT_LENGTH, ": 0\r\n"), - ) + header_name_writer + .write_full_header_line(dst, (header::CONTENT_LENGTH.into(), ": 0\r\n")) } Encoder::length(0) } @@ -888,11 +761,8 @@ impl Server { if !Server::can_have_content_length(msg.req_method, msg.head.subject) { Encoder::length(0) } else { - header_name_writer.write_header_name_with_colon( - dst, - "content-length: ", - header::CONTENT_LENGTH, - ); + header_name_writer + .write_header_name_with_colon(dst, &header::CONTENT_LENGTH.into()); extend(dst, ::itoa::Buffer::new().format(len).as_bytes()); extend(dst, b"\r\n"); Encoder::length(len) @@ -914,7 +784,7 @@ impl Server { // don't force the write if disabled if !wrote_date && msg.date_header { dst.reserve(date::DATE_VALUE_LENGTH + 8); - header_name_writer.write_header_name_with_colon(dst, "date: ", header::DATE); + header_name_writer.write_header_name_with_colon(dst, &header::DATE.into()); date::extend(dst); extend(dst, b"\r\n\r\n"); } else { @@ -944,16 +814,10 @@ trait HeaderNameWriter { fn write_full_header_line( &mut self, dst: &mut Vec, - line: &str, - name_value_pair: (HeaderName, &str), - ); - fn write_header_name_with_colon( - &mut self, - dst: &mut Vec, - name_with_colon: &str, - name: HeaderName, + name_value_pair: (Http1HeaderName, &str), ); - fn write_header_name(&mut self, dst: &mut Vec, name: &HeaderName); + fn write_header_name_with_colon(&mut self, dst: &mut Vec, name: &Http1HeaderName); + fn write_header_name(&mut self, dst: &mut Vec, name: &Http1HeaderName); } impl Http1Transaction for Client { @@ -1034,21 +898,22 @@ impl Http1Transaction for Client { let slice = slice.freeze(); - let mut headers = ctx.cached_headers.take().unwrap_or_default(); - let mut keep_alive = version == Version::HTTP_11; - let mut header_case_map = HeaderCaseMap::default(); - let mut header_order = OriginalHeaderOrder::default(); + let mut headers = Http1HeaderMap::with_capacity(headers_len); - headers.reserve(headers_len); for header in &headers_indices[..headers_len] { // SAFETY: array is valid up to `headers_len` let header = unsafe { header.assume_init_ref() }; - let name = header_name!(&slice[header.name.0..header.name.1]); + let name = + Http1HeaderName::try_copy_from_slice(&slice[header.name.0..header.name.1]) + .inspect_err(|err| { + tracing::debug!("invalid http1 header: {err:?}"); + }) + .map_err(|_| crate::error::Parse::Internal)?; let value = header_value!(slice.slice(header.value.0..header.value.1)); - if let header::CONNECTION = name { + if header::CONNECTION == name.header_name() { // keep_alive was previously set to default for Version if keep_alive { // HTTP/1.1 @@ -1059,16 +924,12 @@ impl Http1Transaction for Client { } } - header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); - header_order.append(&name); - headers.append(name, value); } let mut extensions = http::Extensions::default(); - extensions.insert(header_case_map); - extensions.insert(header_order); + let headers = headers.consume(&mut extensions); if let Some(reason) = reason { // Safety: httparse ensures that only valid reason phrase bytes are present in this @@ -1103,7 +964,7 @@ impl Http1Transaction for Client { } } - fn encode(msg: Encode<'_, Self::Outgoing>, dst: &mut Vec) -> crate::Result { + fn encode(mut msg: Encode<'_, Self::Outgoing>, dst: &mut Vec) -> crate::Result { trace!( "Client::encode method={:?}, body={:?}", msg.head.subject.0, @@ -1112,7 +973,7 @@ impl Http1Transaction for Client { *msg.req_method = Some(msg.head.subject.0.clone()); - let body = Client::set_length(msg.head, msg.body); + let body = Client::set_length(&mut msg.head, msg.body); let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; dst.reserve(init_cap); @@ -1133,21 +994,14 @@ impl Http1Transaction for Client { } extend(dst, b"\r\n"); - if let Some(orig_headers) = msg.head.extensions.get::() { - write_headers_original_case( - &msg.head.headers, - orig_headers, - dst, - msg.title_case_headers, - ); - } else if msg.title_case_headers { - write_headers_title_case(&msg.head.headers, dst); - } else { - write_headers(&msg.head.headers, dst); - } + write_h1_headers( + msg.head.headers, + msg.title_case_headers, + msg.head.extensions, + dst, + ); extend(dst, b"\r\n"); - msg.head.headers.clear(); //TODO: remove when switching to drain() Ok(body) } @@ -1229,7 +1083,7 @@ impl Client { Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) } } - fn set_length(head: &mut RequestHead, body: Option) -> Encoder { + fn set_length(head: &mut EncodeHead<'_, RequestLine>, body: Option) -> Encoder { let body = if let Some(body) = body { body } else { @@ -1513,37 +1367,27 @@ pub(crate) fn write_headers(headers: &HeaderMap, dst: &mut Vec) { } #[cold] -fn write_headers_original_case( - headers: &HeaderMap, - orig_case: &HeaderCaseMap, - dst: &mut Vec, +fn write_h1_headers( + headers: HeaderMap, title_case_headers: bool, + ext: &mut http::Extensions, + dst: &mut Vec, ) { - // For each header name/value pair, there may be a value in the casemap - // that corresponds to the HeaderValue. So, we iterator all the keys, - // and for each one, try to pair the originally cased name with the value. - // - // TODO: consider adding http::HeaderMap::entries() iterator - for name in headers.keys() { - let mut names = orig_case.get_all(name); - - for value in headers.get_all(name) { - if let Some(orig_name) = names.next() { - extend(dst, orig_name.as_ref()); - } else if title_case_headers { - title_case(dst, name.as_str().as_bytes()); - } else { - extend(dst, name.as_str().as_bytes()); - } + let h1_headers = Http1HeaderMap::new(headers, Some(ext)); + for (name, value) in h1_headers { + if title_case_headers { + title_case(dst, name.as_bytes()); + } else { + extend(dst, name.as_bytes()); + } - // Wanted for curl test cases that send `X-Custom-Header:\r\n` - if value.is_empty() { - extend(dst, b":\r\n"); - } else { - extend(dst, b": "); - extend(dst, value.as_bytes()); - extend(dst, b"\r\n"); - } + // Wanted for curl test cases that send `X-Custom-Header:\r\n` + if value.is_empty() { + extend(dst, b":\r\n"); + } else { + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); } } } @@ -1571,6 +1415,7 @@ fn extend(dst: &mut Vec, data: &[u8]) { #[cfg(test)] mod tests { use bytes::BytesMut; + use rama_http_types::proto::h1::headers::original::OriginalHttp1Headers; use super::*; @@ -1581,7 +1426,6 @@ mod tests { let msg = Server::parse( &mut raw, ParseContext { - cached_headers: &mut None, req_method: &mut method, h1_parser_config: Default::default(), h1_max_headers: None, @@ -1603,7 +1447,6 @@ mod tests { fn test_parse_response() { let mut raw = BytesMut::from("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"); let ctx = ParseContext { - cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, @@ -1621,7 +1464,6 @@ mod tests { fn test_parse_request_errors() { let mut raw = BytesMut::from("GET htt:p// HTTP/1.1\r\nHost: ramaproxy.org\r\n\r\n"); let ctx = ParseContext { - cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, @@ -1636,7 +1478,6 @@ mod tests { fn test_parse_response_h09_allowed() { let mut raw = BytesMut::from(H09_RESPONSE); let ctx = ParseContext { - cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, @@ -1653,7 +1494,6 @@ mod tests { fn test_parse_response_h09_rejected() { let mut raw = BytesMut::from(H09_RESPONSE); let ctx = ParseContext { - cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, @@ -1674,7 +1514,6 @@ mod tests { let mut h1_parser_config = ParserConfig::default(); h1_parser_config.allow_spaces_after_header_name_in_responses(true); let ctx = ParseContext { - cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config, h1_max_headers: None, @@ -1692,7 +1531,6 @@ mod tests { fn test_parse_reject_response_with_spaces_before_colons() { let mut raw = BytesMut::from(RESPONSE_WITH_WHITESPACE_BETWEEN_HEADER_NAME_AND_COLON); let ctx = ParseContext { - cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, @@ -1706,30 +1544,21 @@ mod tests { let mut raw = BytesMut::from("GET / HTTP/1.1\r\nHost: ramaproxy.org\r\nX-PASTA: noodles\r\n\r\n"); let ctx = ParseContext { - cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, h09_responses: false, }; let parsed_message = Server::parse(&mut raw, ctx).unwrap().unwrap(); - let orig_headers = parsed_message + let mut orig_headers = parsed_message .head .extensions - .get::() - .unwrap(); - assert_eq!( - orig_headers - .get_all_internal(&HeaderName::from_static("host")) - .collect::>(), - vec![&Bytes::from("Host")] - ); - assert_eq!( - orig_headers - .get_all_internal(&HeaderName::from_static("x-pasta")) - .collect::>(), - vec![&Bytes::from("X-PASTA")] - ); + .get::() + .unwrap() + .clone() + .into_iter(); + assert_eq!("Host", orig_headers.next().unwrap().as_str()); + assert_eq!("X-PASTA", orig_headers.next().unwrap().as_str()); } #[test] @@ -1739,7 +1568,6 @@ mod tests { Server::parse( &mut bytes, ParseContext { - cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, @@ -1755,7 +1583,6 @@ mod tests { Server::parse( &mut bytes, ParseContext { - cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: None, @@ -1980,7 +1807,6 @@ mod tests { assert!(Client::parse( &mut bytes, ParseContext { - cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, @@ -1996,7 +1822,6 @@ mod tests { Client::parse( &mut bytes, ParseContext { - cached_headers: &mut None, req_method: &mut Some(m), h1_parser_config: Default::default(), h1_max_headers: None, @@ -2012,7 +1837,6 @@ mod tests { Client::parse( &mut bytes, ParseContext { - cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, @@ -2314,7 +2138,12 @@ mod tests { let mut vec = Vec::new(); Client::encode( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body: Some(BodyLength::Known(10)), keep_alive: true, req_method: &mut None, @@ -2331,7 +2160,7 @@ mod tests { #[test] fn test_client_request_encode_orig_case() { use crate::proto::BodyLength; - use rama_http_types::header::{HeaderValue, CONTENT_LENGTH}; + use rama_http_types::header::HeaderValue; let mut head = MessageHead::default(); head.headers @@ -2339,14 +2168,19 @@ mod tests { head.headers .insert("content-type", HeaderValue::from_static("application/json")); - let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + let mut orig_headers = OriginalHttp1Headers::default(); + orig_headers.push("CONTENT-LENGTH".parse().unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); Client::encode( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body: Some(BodyLength::Known(10)), keep_alive: true, req_method: &mut None, @@ -2366,7 +2200,7 @@ mod tests { #[test] fn test_client_request_encode_orig_and_title_case() { use crate::proto::BodyLength; - use rama_http_types::header::{HeaderValue, CONTENT_LENGTH}; + use rama_http_types::header::HeaderValue; let mut head = MessageHead::default(); head.headers @@ -2374,14 +2208,19 @@ mod tests { head.headers .insert("content-type", HeaderValue::from_static("application/json")); - let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + let mut orig_headers = OriginalHttp1Headers::default(); + orig_headers.push("CONTENT-LENGTH".parse().unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); Client::encode( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body: Some(BodyLength::Known(10)), keep_alive: true, req_method: &mut None, @@ -2406,7 +2245,12 @@ mod tests { let mut vec = Vec::new(); let encoder = Server::encode( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body: None, keep_alive: true, req_method: &mut Some(Method::CONNECT), @@ -2436,7 +2280,12 @@ mod tests { let mut vec = Vec::new(); Server::encode( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body: Some(BodyLength::Known(10)), keep_alive: true, req_method: &mut None, @@ -2456,7 +2305,7 @@ mod tests { #[test] fn test_server_response_encode_orig_case() { use crate::proto::BodyLength; - use rama_http_types::header::{HeaderValue, CONTENT_LENGTH}; + use rama_http_types::header::HeaderValue; let mut head = MessageHead::default(); head.headers @@ -2464,14 +2313,19 @@ mod tests { head.headers .insert("content-type", HeaderValue::from_static("application/json")); - let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + let mut orig_headers = OriginalHttp1Headers::default(); + orig_headers.push("CONTENT-LENGTH".parse().unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); Server::encode( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body: Some(BodyLength::Known(10)), keep_alive: true, req_method: &mut None, @@ -2491,7 +2345,7 @@ mod tests { #[test] fn test_server_response_encode_orig_and_title_case() { use crate::proto::BodyLength; - use rama_http_types::header::{HeaderValue, CONTENT_LENGTH}; + use rama_http_types::header::HeaderValue; let mut head = MessageHead::default(); head.headers @@ -2499,14 +2353,19 @@ mod tests { head.headers .insert("content-type", HeaderValue::from_static("application/json")); - let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + let mut orig_headers = OriginalHttp1Headers::default(); + orig_headers.push("CONTENT-LENGTH".parse().unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); Server::encode( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body: Some(BodyLength::Known(10)), keep_alive: true, req_method: &mut None, @@ -2527,7 +2386,7 @@ mod tests { #[test] fn test_disabled_date_header() { use crate::proto::BodyLength; - use rama_http_types::header::{HeaderValue, CONTENT_LENGTH}; + use rama_http_types::header::HeaderValue; let mut head = MessageHead::default(); head.headers @@ -2535,14 +2394,19 @@ mod tests { head.headers .insert("content-type", HeaderValue::from_static("application/json")); - let mut orig_headers = HeaderCaseMap::default(); - orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + let mut orig_headers = OriginalHttp1Headers::default(); + orig_headers.push("CONTENT-LENGTH".parse().unwrap()); head.extensions.insert(orig_headers); let mut vec = Vec::new(); Server::encode( Encode { - head: &mut head, + head: EncodeHead { + version: head.version, + subject: head.subject, + headers: head.headers, + extensions: &mut head.extensions, + }, body: Some(BodyLength::Known(10)), keep_alive: true, req_method: &mut None, @@ -2565,7 +2429,6 @@ mod tests { let parsed = Client::parse( &mut bytes, ParseContext { - cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), h1_max_headers: None, @@ -2603,7 +2466,6 @@ mod tests { let result = Server::parse( &mut bytes, ParseContext { - cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: max_headers, @@ -2622,7 +2484,6 @@ mod tests { let result = Client::parse( &mut bytes, ParseContext { - cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), h1_max_headers: max_headers, @@ -2724,11 +2585,15 @@ mod tests { let mut headers = HeaderMap::new(); let name = http::header::HeaderName::from_static("x-empty"); headers.insert(&name, "".parse().expect("parse empty")); - let mut orig_cases = HeaderCaseMap::default(); - orig_cases.insert(name, Bytes::from_static(b"X-EmptY")); + let mut orig_cases = OriginalHttp1Headers::default(); + orig_cases.push("X-EmptY".parse().unwrap()); + + let mut ext = http::Extensions::new(); + ext.insert(orig_cases); let mut dst = Vec::new(); - super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); + + super::write_h1_headers(headers, false, &mut ext, &mut dst); assert_eq!( dst, b"X-EmptY:\r\n", @@ -2743,12 +2608,16 @@ mod tests { headers.insert(&name, "a".parse().unwrap()); headers.append(&name, "b".parse().unwrap()); - let mut orig_cases = HeaderCaseMap::default(); - orig_cases.insert(name.clone(), Bytes::from_static(b"X-Empty")); - orig_cases.append(name, Bytes::from_static(b"X-EMPTY")); + let mut orig_cases = OriginalHttp1Headers::default(); + orig_cases.push("X-Empty".parse().unwrap()); + orig_cases.push("X-EMPTY".parse().unwrap()); + + let mut ext = http::Extensions::new(); + ext.insert(orig_cases); let mut dst = Vec::new(); - super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); + + super::write_h1_headers(headers, false, &mut ext, &mut dst); assert_eq!(dst, b"X-Empty: a\r\nX-EMPTY: b\r\n"); } diff --git a/rama-http-types/Cargo.toml b/rama-http-types/Cargo.toml index a007a7ed..60ede9a8 100644 --- a/rama-http-types/Cargo.toml +++ b/rama-http-types/Cargo.toml @@ -31,6 +31,7 @@ rama-utils = { version = "0.2.0-alpha.5", path = "../rama-utils" } serde = { workspace = true, features = ["derive"] } serde_html_form = { workspace = true } serde_json = { workspace = true } +smallvec = { workspace = true } sync_wrapper = { workspace = true } tracing = { workspace = true } diff --git a/rama-http-types/src/lib.rs b/rama-http-types/src/lib.rs index ea9d6faf..cf60b2e7 100644 --- a/rama-http-types/src/lib.rs +++ b/rama-http-types/src/lib.rs @@ -33,6 +33,8 @@ pub type Request = http::Request; pub mod response; pub use response::{IntoResponse, IntoResponseParts, Response}; +pub mod proto; + pub mod headers; pub mod dep { diff --git a/rama-http-types/src/proto/h1/headers/map.rs b/rama-http-types/src/proto/h1/headers/map.rs new file mode 100644 index 00000000..47f17393 --- /dev/null +++ b/rama-http-types/src/proto/h1/headers/map.rs @@ -0,0 +1,494 @@ +use std::collections::{self, HashMap}; + +use super::{ + name::{IntoHttp1HeaderName, IntoSealed as _, TryIntoHttp1HeaderName}, + original::{self, OriginalHttp1Headers}, + Http1HeaderName, +}; + +use crate::{ + dep::http::Extensions, + header::{self, InvalidHeaderName}, + HeaderMap, HeaderName, HeaderValue, +}; + +#[derive(Debug, Clone, Default)] +pub struct Http1HeaderMap { + headers: HeaderMap, + original_headers: OriginalHttp1Headers, +} + +impl Http1HeaderMap { + pub fn with_capacity(size: usize) -> Self { + Self { + headers: HeaderMap::with_capacity(size), + original_headers: OriginalHttp1Headers::with_capacity(size), + } + } + + pub fn new(headers: HeaderMap, ext: Option<&mut Extensions>) -> Self { + let original_headers = ext.and_then(|ext| ext.remove()).unwrap_or_default(); + Self { + headers, + original_headers, + } + } + + pub fn into_headers(self) -> HeaderMap { + self.headers + } + + /// use [`Self::into_headers`] if you do not care about + /// the original headers. + pub fn consume(self, ext: &mut Extensions) -> HeaderMap { + ext.insert(self.original_headers); + self.headers + } + + pub fn append(&mut self, name: impl IntoHttp1HeaderName, value: HeaderValue) { + let original_header = name.into_http1_header_name(); + let header_name = original_header.header_name(); + self.headers.append(header_name, value); + self.original_headers.push(original_header); + } + + pub fn try_append( + &mut self, + name: impl TryIntoHttp1HeaderName, + value: HeaderValue, + ) -> Result<(), InvalidHeaderName> { + let original_header = name.try_into_http1_header_name()?; + let header_name = original_header.header_name(); + self.headers.append(header_name, value); + self.original_headers.push(original_header); + Ok(()) + } +} + +impl From for Http1HeaderMap { + fn from(value: HeaderMap) -> Self { + Self { + headers: value, + ..Default::default() + } + } +} + +impl From for HeaderMap { + fn from(value: Http1HeaderMap) -> Self { + value.headers + } +} + +impl FromIterator<(N, HeaderValue)> for Http1HeaderMap { + fn from_iter>(iter: T) -> Self { + let mut map: Self = Default::default(); + for (name, value) in iter { + map.append(name, value); + } + map + } +} + +impl IntoIterator for Http1HeaderMap { + type Item = (Http1HeaderName, HeaderValue); + type IntoIter = Http1HeaderMapIntoIter; + + fn into_iter(self) -> Self::IntoIter { + if self.original_headers.is_empty() { + return Http1HeaderMapIntoIter { + state: Http1HeaderMapIntoIterState::Rem( + HeaderMapValueRemover::from(self.headers).into_iter(), + ), + }; + } + + Http1HeaderMapIntoIter { + state: Http1HeaderMapIntoIterState::Original { + original_iter: self.original_headers.into_iter(), + headers: self.headers.into(), + }, + } + } +} + +#[derive(Debug)] +pub struct Http1HeaderMapIntoIter { + state: Http1HeaderMapIntoIterState, +} + +#[derive(Debug)] +enum Http1HeaderMapIntoIterState { + Original { + original_iter: original::IntoIter, + headers: HeaderMapValueRemover, + }, + Rem(HeaderMapValueRemoverIntoIter), + Empty, +} + +impl Iterator for Http1HeaderMapIntoIter { + type Item = (Http1HeaderName, HeaderValue); + + fn next(&mut self) -> Option { + match std::mem::replace(&mut self.state, Http1HeaderMapIntoIterState::Empty) { + Http1HeaderMapIntoIterState::Original { + mut original_iter, + mut headers, + } => loop { + match original_iter.next() { + Some(http1_header_name) => { + if let Some(value) = headers.remove(http1_header_name.header_name()) { + let next = Some((http1_header_name, value)); + self.state = Http1HeaderMapIntoIterState::Original { + original_iter, + headers, + }; + return next; + } + } + None => { + let mut it = headers.into_iter(); + let next = it.next(); + self.state = Http1HeaderMapIntoIterState::Rem(it); + return next; + } + } + }, + Http1HeaderMapIntoIterState::Rem(mut it) => { + let next = it.next()?; + self.state = Http1HeaderMapIntoIterState::Rem(it); + Some(next) + } + Http1HeaderMapIntoIterState::Empty => None, + } + } +} + +#[derive(Debug)] +struct HeaderMapValueRemover { + header_map: HeaderMap, + removed_values: Option>>, +} + +impl From for HeaderMapValueRemover { + fn from(value: HeaderMap) -> Self { + Self { + header_map: value, + removed_values: None, + } + } +} + +impl HeaderMapValueRemover { + fn remove(&mut self, header: &HeaderName) -> Option { + match self.header_map.entry(header) { + header::Entry::Occupied(occupied_entry) => { + let (k, mut values) = occupied_entry.remove_entry_mult(); + match values.next() { + Some(v) => { + let values: Vec<_> = values.collect(); + if !values.is_empty() { + self.removed_values + .get_or_insert_with(Default::default) + .insert(k, values.into_iter()); + } + Some(v) + } + None => None, + } + } + header::Entry::Vacant(_) => self + .removed_values + .as_mut() + .and_then(|m| m.get_mut(header)) + .and_then(|i| i.next()), + } + } +} + +impl IntoIterator for HeaderMapValueRemover { + type Item = (Http1HeaderName, HeaderValue); + type IntoIter = HeaderMapValueRemoverIntoIter; + + fn into_iter(self) -> Self::IntoIter { + let removed_headers = self.removed_values.map(|r| r.into_iter()); + let remaining_headers = self.header_map.into_iter().peekable(); + HeaderMapValueRemoverIntoIter { + cached_header_name: None, + cached_headers: None, + removed_headers, + remaining_headers, + } + } +} + +#[derive(Debug)] +struct HeaderMapValueRemoverIntoIter { + cached_header_name: Option, + cached_headers: Option>>, + removed_headers: + Option>>, + remaining_headers: std::iter::Peekable>, +} + +impl Iterator for HeaderMapValueRemoverIntoIter { + type Item = (Http1HeaderName, HeaderValue); + + fn next(&mut self) -> Option { + if let Some(mut it) = self.cached_headers.take() { + if let Some(value) = it.next() { + match if it.peek().is_some() { + self.cached_headers = Some(it); + self.cached_header_name.clone() + } else { + self.cached_header_name.take() + } { + Some(name) => { + return Some((name.into_http1_header_name(), value)); + } + None => { + if cfg!(debug_assertions) { + panic!("no http header name found for multi-value header"); + } + } + } + } + } + + if let Some(removed_headers) = self.removed_headers.as_mut() { + for removed_header in removed_headers { + let mut cached_headers = removed_header.1.peekable(); + if cached_headers.peek().is_some() { + self.cached_header_name = Some(removed_header.0); + self.cached_headers = Some(cached_headers); + return self.next(); + } + } + } + + loop { + let header = self.remaining_headers.next()?; + match (header.0, self.cached_header_name.take()) { + (Some(name), _) | (None, Some(name)) => { + if self + .remaining_headers + .peek() + .map(|h| h.0.is_none()) + .unwrap_or_default() + { + self.cached_header_name = Some(name.clone()); + } + return Some((name.into_http1_header_name(), header.1)); + } + (None, None) => { + if cfg!(debug_assertions) { + panic!("no http header name found for multi-value header"); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default() { + let mut drain = Http1HeaderMap::default().into_iter(); + assert!(drain.next().is_none()); + } + + macro_rules! _add_extra_headers { + ( + $map:expr, + {} + ) => { + { + let extra: Option> = None; + extra + } + }; + ( + $map:expr, + { + $($name:literal: $value:literal),* + $(,)? + } + ) => { + { + let mut extra = vec![]; + $( + $map.append($name.to_lowercase().parse::().unwrap(), $value.parse().unwrap()); + extra.push(format!("{}: {}", $name.to_lowercase(), $value)); + )* + Some(extra) + } + }; + } + + macro_rules! test_req { + ({$( + $name:literal: $value:literal + ),* $(,)?}, $extra_headers:tt) => { + { + let mut map = Http1HeaderMap::default(); + + $( + map.try_append( + $name, + HeaderValue::from_str($value).unwrap() + ).unwrap(); + )* + + let extra_headers = _add_extra_headers!(&mut map.headers, $extra_headers); + + let mut drain = map.into_iter(); + + let mut next = || { + drain.next().map(|(name, value)| { + let s = format!( + "{}: {}", + name, + String::from_utf8_lossy(value.as_bytes()), + ); + s + }) + }; + + $( + assert_eq!(Some(format!("{}: {}", $name, $value)), next()); + )* + + if let Some(extra_headers) = extra_headers { + for extra in extra_headers { + assert_eq!(Some(extra), next()) + } + } + + assert_eq!(None, next()) + } + }; + } + + #[test] + fn test_only_extra_1() { + test_req!({}, { + "content-type": "application/json", + }) + } + + #[test] + fn test_only_extra_2() { + test_req!({}, { + "content-type": "application/json", + "content-length": "9", + }) + } + + #[test] + fn test_only_extra_2_manual_core_type() { + let mut map = HeaderMap::new(); + map.insert(header::CONTENT_LENGTH, "123".parse().unwrap()); + map.insert(header::CONTENT_TYPE, "json".parse().unwrap()); + + let mut iter = map.into_iter().peekable(); + let _ = iter.peek(); + assert_eq!( + iter.next(), + Some((Some(header::CONTENT_LENGTH), "123".parse().unwrap())) + ); + let _ = iter.peek(); + assert_eq!( + iter.next(), + Some((Some(header::CONTENT_TYPE), "json".parse().unwrap())) + ); + assert!(iter.next().is_none()); + } + + #[test] + fn test_only_extra_2_manual_dummy_wrapper() { + let mut map = HeaderMap::new(); + map.insert(header::CONTENT_LENGTH, "123".parse().unwrap()); + map.insert(header::CONTENT_TYPE, "json".parse().unwrap()); + + let map: Http1HeaderMap = map.into(); + + let mut iter = map.into_iter(); + assert_eq!( + iter.next(), + Some(( + header::CONTENT_LENGTH.into_http1_header_name(), + "123".parse().unwrap() + )) + ); + assert_eq!( + iter.next(), + Some(( + header::CONTENT_TYPE.into_http1_header_name(), + "json".parse().unwrap() + )) + ); + assert!(iter.next().is_none()); + } + + #[test] + fn test_happy_case_perfect() { + test_req!({ + "User-Agent": "curl/7.16.3", + "Host": "curl/7.16.3", + "Accept-Language": "en-us", + "Connection": "Keep-Alive", + "Content-Type": "application/json", + "X-FOO": "BaR", + }, {}) + } + + #[test] + fn test_happy_case_perfect_extra_headers() { + test_req!({ + "User-Agent": "curl/7.16.3", + "Host": "curl/7.16.3", + "Accept-Language": "en-us", + "Connection": "Keep-Alive", + "Content-Type": "application/json", + "X-FOO": "BaR", + }, { + "x-Hello": "world", + }) + } + + #[test] + fn test_happy_case_with_repetition() { + test_req!({ + "User-Agent": "curl/7.16.3", + "Host": "curl/7.16.3", + "Accept-Language": "en-us", + "Connection": "Keep-Alive", + "Accept-LANGuage": "NL-be", + "Content-Type": "application/json", + "Cookie": "a=1", + "Cookie": "b=2", + "X-FOO": "BaR", + }, {}) + } + + #[test] + fn test_happy_case_with_repetition_and_extra() { + test_req!({ + "User-Agent": "curl/7.16.3", + "Host": "curl/7.16.3", + "Accept-Language": "en-us", + "Connection": "Keep-Alive", + "Accept-LANGuage": "NL-be", + "Content-Type": "application/json", + "Cookie": "a=1", + "Cookie": "b=2", + "X-FOO": "BaR", + }, { + "x-Hello": "world", + }) + } +} diff --git a/rama-http-types/src/proto/h1/headers/mod.rs b/rama-http-types/src/proto/h1/headers/mod.rs new file mode 100644 index 00000000..97be3ce3 --- /dev/null +++ b/rama-http-types/src/proto/h1/headers/mod.rs @@ -0,0 +1,15 @@ +//! types and functionality to preserve +//! http1* header casing and order. +//! +//! This is especially important for proxies and clients... +//! because out there... are wild servers that care +//! about header casing for reasons... You can think +//! of that what you want, but they do and we have to deal with it. + +mod name; +pub use name::{Http1HeaderName, IntoHttp1HeaderName, TryIntoHttp1HeaderName}; + +pub mod original; + +mod map; +pub use map::Http1HeaderMap; diff --git a/rama-http-types/src/proto/h1/headers/name.rs b/rama-http-types/src/proto/h1/headers/name.rs new file mode 100644 index 00000000..b9294b7f --- /dev/null +++ b/rama-http-types/src/proto/h1/headers/name.rs @@ -0,0 +1,177 @@ +use bytes::Bytes; +use serde::{de::Error, Deserialize, Serialize}; +use std::{fmt, str::FromStr}; + +use crate::{header::InvalidHeaderName, HeaderName}; + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Http1HeaderName { + name: HeaderName, + raw: Option, +} + +impl From for Http1HeaderName { + #[inline] + fn from(value: HeaderName) -> Self { + value.into_http1_header_name() + } +} + +impl From for HeaderName { + fn from(value: Http1HeaderName) -> Self { + value.name + } +} + +impl FromStr for Http1HeaderName { + type Err = InvalidHeaderName; + + #[inline] + fn from_str(s: &str) -> Result { + Http1HeaderName::try_copy_from_str(s) + } +} + +impl Serialize for Http1HeaderName { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for Http1HeaderName { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = <&'de str>::deserialize(deserializer)?; + Self::try_copy_from_str(s).map_err(D::Error::custom) + } +} + +impl fmt::Display for Http1HeaderName { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl Http1HeaderName { + #[inline] + pub fn try_copy_from_slice(b: &[u8]) -> Result { + let bytes = Bytes::copy_from_slice(b); + bytes.try_into_http1_header_name() + } + + #[inline] + pub fn try_copy_from_str(s: &str) -> Result { + let bytes = Bytes::copy_from_slice(s.as_bytes()); + bytes.try_into_http1_header_name() + } + + pub fn as_bytes(&self) -> &[u8] { + if let Some(ref raw) = self.raw { + return raw.as_ref(); + } + self.name.as_ref() + } + + pub fn as_str(&self) -> &str { + self.raw + .as_deref() + .and_then(|b| std::str::from_utf8(b).ok()) + .unwrap_or_else(|| self.name.as_str()) + } + + pub fn header_name(&self) -> &HeaderName { + &self.name + } +} + +pub trait TryIntoHttp1HeaderName: try_into::Sealed {} + +impl TryIntoHttp1HeaderName for T {} + +mod try_into { + use super::*; + + pub trait Sealed { + #[doc(hidden)] + fn try_into_http1_header_name(self) -> Result; + } + + impl Sealed for T { + fn try_into_http1_header_name(self) -> Result { + Ok(self.into_http1_header_name()) + } + } + + impl Sealed for Bytes { + fn try_into_http1_header_name(self) -> Result { + let b: &[u8] = self.as_ref(); + let name = b.try_into()?; + Ok(Http1HeaderName { + name, + raw: Some(self), + }) + } + } + + macro_rules! from_owned_into_bytes { + ($($t:ty),+ $(,)?) => { + $( + impl Sealed for $t { + #[inline] + fn try_into_http1_header_name(self) -> Result { + let bytes = Bytes::from(self); + bytes.try_into_http1_header_name() + } + } + )+ + }; + } + + from_owned_into_bytes! { + &'static [u8], + &'static str, + String, + Vec, + } +} + +#[allow(unused)] +pub(crate) use try_into::Sealed as TryIntoSealed; + +pub trait IntoHttp1HeaderName: into::Sealed {} + +impl IntoHttp1HeaderName for T {} + +mod into { + use super::*; + + pub trait Sealed { + #[doc(hidden)] + fn into_http1_header_name(self) -> Http1HeaderName; + } + + impl Sealed for Http1HeaderName { + fn into_http1_header_name(self) -> Http1HeaderName { + self + } + } + + impl Sealed for HeaderName { + fn into_http1_header_name(self) -> Http1HeaderName { + Http1HeaderName { + name: self, + raw: None, + } + } + } +} + +#[allow(unused)] +pub(crate) use into::Sealed as IntoSealed; diff --git a/rama-http-types/src/proto/h1/headers/original.rs b/rama-http-types/src/proto/h1/headers/original.rs new file mode 100644 index 00000000..db9f6346 --- /dev/null +++ b/rama-http-types/src/proto/h1/headers/original.rs @@ -0,0 +1,84 @@ +//! Original order and case tracking for h1 tracking... +//! +//! If somebody reads this that designs protocols please +//! ensure that your protocol in no way can have deterministic +//! ordering or makes use of capitals... *sigh* what a painful design mistake + +use super::{Http1HeaderName, IntoHttp1HeaderName}; + +#[derive(Debug, Clone)] +// Keeps track of the order and casing +// of the inserted header names, usually used in combination +// with [`crate::proto::h1::Http1HeaderMap`]. +pub struct OriginalHttp1Headers { + /// ordered by insert order + ordered_headers: Vec, +} + +impl OriginalHttp1Headers { + pub fn push(&mut self, name: Http1HeaderName) { + self.ordered_headers.push(name); + } + + #[inline] + pub fn len(&self) -> usize { + self.ordered_headers.len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.ordered_headers.is_empty() + } +} + +impl OriginalHttp1Headers { + #[inline] + pub fn with_capacity(size: usize) -> Self { + Self { + ordered_headers: Vec::with_capacity(size), + } + } +} + +impl Default for OriginalHttp1Headers { + #[inline] + fn default() -> Self { + Self::with_capacity(12) + } +} + +impl IntoIterator for OriginalHttp1Headers { + type Item = Http1HeaderName; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter { + headers_iter: self.ordered_headers.into_iter(), + } + } +} + +impl FromIterator for OriginalHttp1Headers { + fn from_iter>(iter: T) -> Self { + OriginalHttp1Headers { + ordered_headers: iter + .into_iter() + .map(|it| it.into_http1_header_name()) + .collect(), + } + } +} + +#[derive(Debug)] +pub struct IntoIter { + headers_iter: std::vec::IntoIter, +} + +impl Iterator for IntoIter { + type Item = Http1HeaderName; + + #[inline] + fn next(&mut self) -> Option { + self.headers_iter.next() + } +} diff --git a/rama-http-types/src/proto/h1/mod.rs b/rama-http-types/src/proto/h1/mod.rs new file mode 100644 index 00000000..460dadf7 --- /dev/null +++ b/rama-http-types/src/proto/h1/mod.rs @@ -0,0 +1,4 @@ +//! high-level h1 proto types and functionality + +pub mod headers; +pub use headers::{Http1HeaderMap, Http1HeaderName, IntoHttp1HeaderName, TryIntoHttp1HeaderName}; diff --git a/rama-http-types/src/proto/h2/mod.rs b/rama-http-types/src/proto/h2/mod.rs new file mode 100644 index 00000000..65543113 --- /dev/null +++ b/rama-http-types/src/proto/h2/mod.rs @@ -0,0 +1,6 @@ +//! high-level h2 proto types and functionality + +mod pseudo_header; +pub use pseudo_header::{ + InvalidPseudoHeaderStr, PseudoHeader, PseudoHeaderOrder, PseudoHeaderOrderIter, +}; diff --git a/rama-http-types/src/proto/h2/pseudo_header.rs b/rama-http-types/src/proto/h2/pseudo_header.rs new file mode 100644 index 00000000..44b43a87 --- /dev/null +++ b/rama-http-types/src/proto/h2/pseudo_header.rs @@ -0,0 +1,163 @@ +use serde::{de::Error, Deserialize, Serialize}; +use smallvec::SmallVec; +use std::{fmt, str::FromStr}; + +#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(u8)] +/// Defined in function of being able to communicate the used or desired +/// order in which the pseudo headers are in the h2 request. +/// +/// Used mainly in [`PseudoHeaderOrder`]. +pub enum PseudoHeader { + Method = 0b1000_0000, + Scheme = 0b0100_0000, + Authority = 0b0010_0000, + Path = 0b0001_0000, + Protocol = 0b0000_1000, + Status = 0b0000_0100, +} + +impl PseudoHeader { + pub fn as_str(&self) -> &'static str { + match self { + PseudoHeader::Method => ":method", + PseudoHeader::Scheme => ":scheme", + PseudoHeader::Authority => ":authority", + PseudoHeader::Path => ":path", + PseudoHeader::Protocol => ":protocol", + PseudoHeader::Status => ":status", + } + } +} + +impl fmt::Display for PseudoHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +rama_utils::macros::error::static_str_error! { + #[doc = "pseudo header string is invalid"] + pub struct InvalidPseudoHeaderStr; +} + +impl FromStr for PseudoHeader { + type Err = InvalidPseudoHeaderStr; + + fn from_str(s: &str) -> Result { + let s = s.trim(); + let s = s.strip_prefix(':').unwrap_or(s); + + if s.eq_ignore_ascii_case("method") { + Ok(Self::Method) + } else if s.eq_ignore_ascii_case("scheme") { + Ok(Self::Scheme) + } else if s.eq_ignore_ascii_case("authority") { + Ok(Self::Authority) + } else if s.eq_ignore_ascii_case("path") { + Ok(Self::Path) + } else if s.eq_ignore_ascii_case("protocol") { + Ok(Self::Protocol) + } else if s.eq_ignore_ascii_case("status") { + Ok(Self::Status) + } else { + Err(InvalidPseudoHeaderStr) + } + } +} + +impl Serialize for PseudoHeader { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.as_str().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for PseudoHeader { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = <&'de str>::deserialize(deserializer)?; + s.parse().map_err(D::Error::custom) + } +} + +const PSEUDO_HEADERS_STACK_SIZE: usize = 5; + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct PseudoHeaderOrder { + headers: SmallVec<[PseudoHeader; PSEUDO_HEADERS_STACK_SIZE]>, + mask: u8, +} + +impl PseudoHeaderOrder { + pub fn new() -> Self { + Self::default() + } + + pub fn push(&mut self, header: PseudoHeader) { + if self.mask & (header as u8) == 0 { + self.mask |= header as u8; + self.headers.push(header); + } else { + tracing::trace!("ignore duplicate psuedo header: {header:?}") + } + } + + pub fn extend(&mut self, iter: impl IntoIterator) { + for header in iter { + self.push(header); + } + } + + pub fn iter(&self) -> PseudoHeaderOrderIter { + self.clone().into_iter() + } + + pub fn is_empty(&self) -> bool { + self.headers.is_empty() + } + + pub fn len(&self) -> usize { + self.headers.len() + } +} + +impl IntoIterator for PseudoHeaderOrder { + type Item = PseudoHeader; + type IntoIter = PseudoHeaderOrderIter; + + fn into_iter(self) -> Self::IntoIter { + let PseudoHeaderOrder { mut headers, .. } = self; + headers.reverse(); + PseudoHeaderOrderIter { headers } + } +} + +#[derive(Debug)] +/// Iterator over a copy of [`PseudoHeaderOrder`]. +pub struct PseudoHeaderOrderIter { + headers: SmallVec<[PseudoHeader; PSEUDO_HEADERS_STACK_SIZE]>, +} + +impl Iterator for PseudoHeaderOrderIter { + type Item = PseudoHeader; + + fn next(&mut self) -> Option { + self.headers.pop() + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.headers.len())) + } + + fn count(self) -> usize + where + Self: Sized, + { + self.headers.len() + } +} diff --git a/rama-http-types/src/proto/mod.rs b/rama-http-types/src/proto/mod.rs new file mode 100644 index 00000000..b77f1f2d --- /dev/null +++ b/rama-http-types/src/proto/mod.rs @@ -0,0 +1,7 @@ +//! High level pertaining to the HTTP message protocol. +//! +//! For low-level proto details you can refer to the `proto` module +//! in the `rama-http-core` crate. + +pub mod h1; +pub mod h2; diff --git a/rama-http/src/io/request.rs b/rama-http/src/io/request.rs index b9ea1026..4d9998ba 100644 --- a/rama-http/src/io/request.rs +++ b/rama-http/src/io/request.rs @@ -4,6 +4,10 @@ use crate::{ }; use bytes::Bytes; use rama_core::error::BoxError; +use rama_http_types::proto::{ + h1::Http1HeaderMap, + h2::{PseudoHeader, PseudoHeaderOrder}, +}; use tokio::io::{AsyncWrite, AsyncWriteExt}; /// Write an HTTP request to a writer in std http format. @@ -17,7 +21,7 @@ where W: AsyncWrite + Unpin + Send + Sync + 'static, B: http_body::Body> + Send + Sync + 'static, { - let (parts, body) = req.into_parts(); + let (mut parts, body) = req.into_parts(); if write_headers { w.write_all( @@ -36,8 +40,51 @@ where ) .await?; - for (key, value) in parts.headers.iter() { - w.write_all(format!("{}: {}\r\n", key, value.to_str()?).as_bytes()) + if let Some(pseudo_headers) = parts.extensions.get::() { + for header in pseudo_headers.iter() { + match header { + PseudoHeader::Method => { + w.write_all(format!("[{}: {}]\r\n", header, parts.method).as_bytes()) + .await?; + } + PseudoHeader::Scheme => { + w.write_all( + format!( + "[{}: {}]\r\n", + header, + parts.uri.scheme_str().unwrap_or("?") + ) + .as_bytes(), + ) + .await?; + } + PseudoHeader::Authority => { + w.write_all( + format!( + "[{}: {}]\r\n", + header, + parts.uri.authority().map(|a| a.as_str()).unwrap_or("?") + ) + .as_bytes(), + ) + .await?; + } + PseudoHeader::Path => { + w.write_all(format!("[{}: {}]\r\n", header, parts.uri.path()).as_bytes()) + .await?; + } + PseudoHeader::Protocol => (), // TODO: move ext h2 protocol out of h2 proto core once we need this info + PseudoHeader::Status => (), // not expected in request + } + } + } + + let header_map = Http1HeaderMap::new(parts.headers, Some(&mut parts.extensions)); + // put a clone of this data back into parts as we don't really want to consume it, just trace it + parts.headers = header_map.clone().consume(&mut parts.extensions); + + for (name, value) in header_map { + w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes()) .await?; } } diff --git a/rama-http/src/io/response.rs b/rama-http/src/io/response.rs index 25d2d8c5..14a8bb10 100644 --- a/rama-http/src/io/response.rs +++ b/rama-http/src/io/response.rs @@ -4,6 +4,10 @@ use crate::{ }; use bytes::Bytes; use rama_core::error::BoxError; +use rama_http_types::proto::{ + h1::Http1HeaderMap, + h2::{PseudoHeader, PseudoHeaderOrder}, +}; use tokio::io::{AsyncWrite, AsyncWriteExt}; /// Write an HTTP response to a writer in std http format. @@ -17,7 +21,7 @@ where W: AsyncWrite + Unpin + Send + Sync + 'static, B: http_body::Body> + Send + Sync + 'static, { - let (parts, body) = res.into_parts(); + let (mut parts, body) = res.into_parts(); if write_headers { w.write_all( @@ -35,8 +39,40 @@ where ) .await?; - for (key, value) in parts.headers.iter() { - w.write_all(format!("{}: {}\r\n", key, value.to_str()?).as_bytes()) + if let Some(pseudo_headers) = parts.extensions.get::() { + for header in pseudo_headers.iter() { + match header { + PseudoHeader::Method + | PseudoHeader::Scheme + | PseudoHeader::Authority + | PseudoHeader::Path + | PseudoHeader::Protocol => (), // not expected in response + PseudoHeader::Status => { + w.write_all( + format!( + "[{}: {} {}]\r\n", + header, + parts.status.as_u16(), + parts + .status + .canonical_reason() + .map(|r| format!(" {}", r)) + .unwrap_or_default(), + ) + .as_bytes(), + ) + .await?; + } + } + } + } + + let header_map = Http1HeaderMap::new(parts.headers, Some(&mut parts.extensions)); + // put a clone of this data back into parts as we don't really want to consume it, just trace it + parts.headers = header_map.clone().consume(&mut parts.extensions); + + for (name, value) in header_map { + w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes()) .await?; } } diff --git a/rama-http/src/lib.rs b/rama-http/src/lib.rs index ba529831..5163b360 100644 --- a/rama-http/src/lib.rs +++ b/rama-http/src/lib.rs @@ -19,7 +19,7 @@ #[doc(inline)] pub use ::rama_http_types::{ - header, + header, proto, response::{self, IntoResponse, Response}, Body, BodyDataStream, BodyExtractExt, BodyLimit, HeaderMap, HeaderName, HeaderValue, Method, Request, Scheme, StatusCode, Uri, Version, diff --git a/src/cli/service/echo.rs b/src/cli/service/echo.rs index 8d45e849..77b711fc 100644 --- a/src/cli/service/echo.rs +++ b/src/cli/service/echo.rs @@ -7,8 +7,8 @@ use crate::{ cli::ForwardKind, - combinators::Either7, - error::BoxError, + combinators::{Either3, Either7}, + error::{BoxError, OpaqueError}, http::{ dep::http_body_util::BodyExt, headers::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp}, @@ -18,6 +18,8 @@ use crate::{ trace::TraceLayer, ua::{UserAgent, UserAgentClassifierLayer}, }, + proto::h1::Http1HeaderMap, + proto::h2::PseudoHeaderOrder, response::Json, server::HttpServer, IntoResponse, Request, Response, Version, @@ -30,8 +32,6 @@ use crate::{ rt::Executor, Context, Layer, Service, }; -use rama_core::{combinators::Either3, error::OpaqueError}; -use rama_http_core::{ext::OriginalHeaderOrder, h2::PseudoHeaderOrder}; use serde_json::json; use std::{convert::Infallible, time::Duration}; use tokio::net::TcpStream; @@ -322,44 +322,24 @@ impl Service<(), Request> for EchoService { let authority = request_context.authority.to_string(); let scheme = request_context.protocol.to_string(); - // TODO: get in correct order - // TODO: get in correct case - - // TODO: get cleaner API + also original casing - let headers: Vec<_> = match req.extensions().get::() { - Some(original) => original - .get_in_order() - .map(|(name, idx)| { - let value = req - .headers() - .get_all(name) - .iter() - .nth(*idx) - .and_then(|v| v.to_str().ok()) - .unwrap_or_default() - .to_owned(); - let name = name.as_str().to_owned(); - (name, value) - }) - .collect(), - None => req - .headers() - .iter() - .map(|(name, value)| { - ( - name.as_str().to_owned(), - value.to_str().map(|v| v.to_owned()).unwrap_or_default(), - ) - }) - .collect(), - }; - let pseudo_headers: Option> = req .extensions() .get::() .map(|o| o.iter().collect()); - let (parts, body) = req.into_parts(); + let (mut parts, body) = req.into_parts(); + + let headers: Vec<_> = Http1HeaderMap::new(parts.headers, Some(&mut parts.extensions)) + .into_iter() + .map(|(name, value)| { + ( + name, + std::str::from_utf8(value.as_bytes()) + .map(|s| s.to_owned()) + .unwrap_or_else(|_| format!("0x{:x?}", value.as_bytes())), + ) + }) + .collect(); let body = body.collect().await.unwrap().to_bytes(); let body = hex::encode(body.as_ref()); diff --git a/src/http.rs b/src/http.rs index 89ba5842..4e5ef3ea 100644 --- a/src/http.rs +++ b/src/http.rs @@ -5,7 +5,7 @@ #[doc(inline)] pub use ::rama_http::{ - dep, header, headers, io, matcher, + dep, header, headers, io, matcher, proto, response::{self, IntoResponse, Response}, service, Body, BodyDataStream, BodyExtractExt, BodyLimit, HeaderMap, HeaderName, HeaderValue, Method, Request, Scheme, StatusCode, Uri, Version,