Skip to content

Commit

Permalink
initial attempt to keep h2 (field) headers also ordered
Browse files Browse the repository at this point in the history
  • Loading branch information
GlenDC committed Jan 3, 2025
1 parent 45c0fb7 commit 393ab89
Show file tree
Hide file tree
Showing 17 changed files with 211 additions and 131 deletions.
9 changes: 7 additions & 2 deletions rama-http-core/examples/h2_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
use rama_error::BoxError;
use rama_http_core::h2::client;
use rama_http_types::{HeaderMap, Request};
use rama_http_types::{
proto::h1::headers::original::OriginalHttp1Headers, HeaderMap, HeaderName, Request,
};

use tokio::net::TcpStream;

Expand All @@ -48,10 +50,13 @@ pub async fn main() -> Result<(), BoxError> {
let mut trailers = HeaderMap::new();
trailers.insert("zomg", "hello".parse().unwrap());

let mut trailer_order = OriginalHttp1Headers::new();
trailer_order.push(HeaderName::from_static("zomg").into());

let (response, mut stream) = client.send_request(request, false).unwrap();

// send trailers
stream.send_trailers(trailers).unwrap();
stream.send_trailers(trailers, trailer_order).unwrap();

// Spawn a task to run the conn...
tokio::spawn(async move {
Expand Down
14 changes: 11 additions & 3 deletions rama-http-core/src/h2/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ use crate::h2::{FlowControl, PingPong, RecvStream, SendStream};

use bytes::{Buf, Bytes};
use rama_http_types::dep::http::{request, uri};
use rama_http_types::proto::h1::headers::original::OriginalHttp1Headers;
use rama_http_types::proto::h2::PseudoHeaderOrder;
use rama_http_types::{HeaderMap, Method, Request, Response, Version};
use std::fmt;
Expand Down Expand Up @@ -498,7 +499,7 @@ where
/// header::HeaderName::from_bytes(b"my-trailer").unwrap(),
/// header::HeaderValue::from_bytes(b"hello").unwrap());
///
/// send_stream.send_trailers(trailers).unwrap();
/// send_stream.send_trailers(trailers, Default::default()).unwrap();
///
/// let response = response.await.unwrap();
/// // Process the response
Expand Down Expand Up @@ -1614,6 +1615,8 @@ impl Peer {
}
}

let header_order: OriginalHttp1Headers = extensions.remove().unwrap_or_default();

if pseudo.scheme.is_none() {
// If the scheme is not set, then there are a two options.
//
Expand Down Expand Up @@ -1643,7 +1646,7 @@ impl Peer {
}

// Create the HEADERS frame
let mut frame = Headers::new(id, pseudo, headers);
let mut frame = Headers::new(id, pseudo, headers, header_order);

if end_of_stream {
frame.set_end_stream()
Expand Down Expand Up @@ -1671,6 +1674,7 @@ impl proto::Peer for Peer {
fn convert_poll_message(
pseudo: Pseudo,
fields: HeaderMap,
field_order: OriginalHttp1Headers,
stream_id: StreamId,
) -> Result<Self::Poll, Error> {
let mut b = Response::builder();
Expand All @@ -1691,7 +1695,11 @@ impl proto::Peer for Peer {
};

if !pseudo.order.is_empty() {
response.extensions_mut().insert(pseudo.order.clone());
response.extensions_mut().insert(pseudo.order);
}

if !field_order.is_empty() {
response.extensions_mut().insert(field_order);
}

*response.headers_mut() = fields;
Expand Down
141 changes: 64 additions & 77 deletions rama-http-core/src/h2/frame/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use crate::h2::frame::{Error, Frame, Head, Kind};
use crate::h2::hpack::{self, BytesStr};

use rama_http_types::dep::http::uri;
use rama_http_types::proto::h1::headers::original::OriginalHttp1Headers;
use rama_http_types::proto::h1::headers::Http1HeaderMapIntoIter;
use rama_http_types::proto::h1::Http1HeaderMap;
use rama_http_types::proto::h2::{PseudoHeader, PseudoHeaderOrder, PseudoHeaderOrderIter};
use rama_http_types::{
header, HeaderMap, HeaderName, HeaderValue, Method, Request, StatusCode, Uri,
};
use rama_http_types::{header, HeaderMap, HeaderName, Method, Request, StatusCode, Uri};

use bytes::{Buf, BufMut, Bytes, BytesMut};

Expand Down Expand Up @@ -111,17 +112,20 @@ struct Iter {
pseudo_order: PseudoHeaderOrderIter,

/// Header fields
fields: header::IntoIter<HeaderValue>,
fields: Http1HeaderMapIntoIter,
}

#[derive(Debug, PartialEq, Eq)]
#[derive(Debug)]
struct HeaderBlock {
/// The decoded header fields
fields: HeaderMap,

/// Precomputed size of all of our header fields, for perf reasons
field_size: usize,

/// Keeps track of header fields
field_order: OriginalHttp1Headers,

/// Set to true if decoding went over the max header list size.
is_over_size: bool,

Expand All @@ -130,6 +134,24 @@ struct HeaderBlock {
pseudo: Pseudo,
}

impl PartialEq for HeaderBlock {
fn eq(&self, other: &Self) -> bool {
(
&self.fields,
&self.field_size,
&self.is_over_size,
&self.pseudo,
) == (
&other.fields,
&other.field_size,
&other.is_over_size,
&other.pseudo,
)
}
}

impl Eq for HeaderBlock {}

#[derive(Debug)]
struct EncodingHeaderBlock {
hpack: Bytes,
Expand All @@ -145,21 +167,31 @@ const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;

impl Headers {
/// Create a new HEADERS frame
pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
pub fn new(
stream_id: StreamId,
pseudo: Pseudo,
fields: HeaderMap,
field_order: OriginalHttp1Headers,
) -> Self {
Headers {
stream_id,
stream_dep: None,
header_block: HeaderBlock {
field_size: calculate_headermap_size(&fields),
fields,
field_order,
is_over_size: false,
pseudo,
},
flags: HeadersFlag::default(),
}
}

pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
pub fn trailers(
stream_id: StreamId,
fields: HeaderMap,
field_order: OriginalHttp1Headers,
) -> Self {
let mut flags = HeadersFlag::default();
flags.set_end_stream();

Expand All @@ -169,6 +201,7 @@ impl Headers {
header_block: HeaderBlock {
field_size: calculate_headermap_size(&fields),
fields,
field_order,
is_over_size: false,
pseudo: Pseudo::default(),
},
Expand Down Expand Up @@ -234,6 +267,7 @@ impl Headers {
header_block: HeaderBlock {
fields: HeaderMap::new(),
field_size: 0,
field_order: OriginalHttp1Headers::new(),
is_over_size: false,
pseudo: Pseudo::default(),
},
Expand Down Expand Up @@ -276,8 +310,12 @@ impl Headers {
self.header_block.is_over_size
}

pub fn into_parts(self) -> (Pseudo, HeaderMap) {
(self.header_block.pseudo, self.header_block.fields)
pub fn into_parts(self) -> (Pseudo, HeaderMap, OriginalHttp1Headers) {
(
self.header_block.pseudo,
self.header_block.fields,
self.header_block.field_order,
)
}

#[cfg(feature = "unstable")]
Expand Down Expand Up @@ -388,12 +426,14 @@ impl PushPromise {
promised_id: StreamId,
pseudo: Pseudo,
fields: HeaderMap,
field_order: OriginalHttp1Headers,
) -> Self {
PushPromise {
flags: PushPromiseFlag::default(),
header_block: HeaderBlock {
field_size: calculate_headermap_size(&fields),
fields,
field_order,
is_over_size: false,
pseudo,
},
Expand Down Expand Up @@ -434,8 +474,8 @@ impl PushPromise {
}

#[cfg(feature = "unstable")]
pub fn into_fields(self) -> HeaderMap {
self.header_block.fields
pub fn into_fields(self) -> (HeaderMap, OriginalHttp1Headers) {
(self.header_block.fields, self.header_block.field_order)
}

/// Loads the push promise frame but doesn't actually do HPACK decoding.
Expand Down Expand Up @@ -484,6 +524,7 @@ impl PushPromise {
header_block: HeaderBlock {
fields: HeaderMap::new(),
field_size: 0,
field_order: OriginalHttp1Headers::new(),
is_over_size: false,
pseudo: Pseudo::default(),
},
Expand Down Expand Up @@ -545,8 +586,12 @@ impl PushPromise {
}

/// Consume `self`, returning the parts of the frame
pub fn into_parts(self) -> (Pseudo, HeaderMap) {
(self.header_block.pseudo, self.header_block.fields)
pub fn into_parts(self) -> (Pseudo, HeaderMap, OriginalHttp1Headers) {
(
self.header_block.pseudo,
self.header_block.fields,
self.header_block.field_order,
)
}
}

Expand Down Expand Up @@ -796,7 +841,10 @@ impl Iterator for Iter {

self.fields
.next()
.map(|(name, value)| hpack::Header::Field { name, value })
.map(|(name, value)| hpack::Header::Field {
name: Some(name.into()),
value,
})
}
}

Expand Down Expand Up @@ -974,6 +1022,7 @@ impl HeaderBlock {
if headers_size < max_header_list_size {
self.field_size +=
decoded_header_size(name.as_str().len(), value.len());
self.field_order.push(name.clone().into());
self.fields.append(name, value);
} else if !self.is_over_size {
tracing::trace!("load_hpack; header list size over max");
Expand Down Expand Up @@ -1026,7 +1075,7 @@ impl HeaderBlock {
let headers = Iter {
pseudo_order: self.pseudo.order.iter(),
pseudo: Some(self.pseudo),
fields: self.fields.into_iter(),
fields: Http1HeaderMap::from_parts(self.fields, self.field_order).into_iter(),
};

encoder.encode(headers, &mut hpack);
Expand Down Expand Up @@ -1076,68 +1125,6 @@ fn decoded_header_size(name: usize, value: usize) -> usize {
#[cfg(test)]
mod test {
use super::*;
use crate::h2::frame;
use crate::h2::hpack::{huffman, Encoder};

#[test]
fn test_nameless_header_at_resume() {
let mut encoder = Encoder::default();
let mut dst = BytesMut::new();

let headers = Headers::new(
StreamId::ZERO,
Default::default(),
HeaderMap::from_iter(vec![
(
HeaderName::from_static("hello"),
HeaderValue::from_static("world"),
),
(
HeaderName::from_static("hello"),
HeaderValue::from_static("zomg"),
),
(
HeaderName::from_static("hello"),
HeaderValue::from_static("sup"),
),
]),
);

let continuation = headers
.encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
.unwrap();

assert_eq!(17, dst.len());
assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
assert_eq!("hello", huff_decode(&dst[11..15]));
assert_eq!(0x80 | 4, dst[15]);

let mut world = dst[16..17].to_owned();

dst.clear();

assert!(continuation
.encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
.is_none());

world.extend_from_slice(&dst[9..12]);
assert_eq!("world", huff_decode(&world));

assert_eq!(24, dst.len());
assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

// // Next is not indexed
assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
assert_eq!("zomg", huff_decode(&dst[15..18]));
assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
assert_eq!("sup", huff_decode(&dst[21..]));
}

fn huff_decode(src: &[u8]) -> BytesMut {
let mut buf = BytesMut::new();
huffman::decode(src, &mut buf).unwrap()
}

#[test]
fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
Expand Down
7 changes: 5 additions & 2 deletions rama-http-core/src/h2/proto/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::h2::error::Reason;
use crate::h2::frame::{Pseudo, StreamId};
use crate::h2::proto::{Error, Open};

use rama_http_types::proto::h1::headers::original::OriginalHttp1Headers;
use rama_http_types::{HeaderMap, Request, Response};

use std::fmt;
Expand All @@ -19,6 +20,7 @@ pub(crate) trait Peer {
fn convert_poll_message(
pseudo: Pseudo,
fields: HeaderMap,
field_order: OriginalHttp1Headers,
stream_id: StreamId,
) -> Result<Self::Poll, Error>;

Expand Down Expand Up @@ -61,13 +63,14 @@ impl Dyn {
&self,
pseudo: Pseudo,
fields: HeaderMap,
field_order: OriginalHttp1Headers,
stream_id: StreamId,
) -> Result<PollMessage, Error> {
if self.is_server() {
crate::h2::server::Peer::convert_poll_message(pseudo, fields, stream_id)
crate::h2::server::Peer::convert_poll_message(pseudo, fields, field_order, stream_id)
.map(PollMessage::Server)
} else {
crate::h2::client::Peer::convert_poll_message(pseudo, fields, stream_id)
crate::h2::client::Peer::convert_poll_message(pseudo, fields, field_order, stream_id)
.map(PollMessage::Client)
}
}
Expand Down
Loading

0 comments on commit 393ab89

Please sign in to comment.