Skip to content

Commit

Permalink
make CLI http client preserve header order/casing
Browse files Browse the repository at this point in the history
  • Loading branch information
GlenDC committed Jan 5, 2025
1 parent eea4e0b commit cbf911f
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 36 deletions.
38 changes: 19 additions & 19 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 14 additions & 2 deletions rama-http/src/io/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,20 @@ where
parts.headers = header_map.clone().consume(&mut parts.extensions);

for (name, value) in header_map {
w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes())
.await?;
match parts.version {
http::Version::HTTP_2 | http::Version::HTTP_3 => {
// write lower-case for H2/H3
w.write_all(
format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?)
.as_bytes(),
)
.await?;
}
_ => {
w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes())
.await?;
}
}
}
}

Expand Down
16 changes: 14 additions & 2 deletions rama-http/src/io/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,20 @@ where
parts.headers = header_map.clone().consume(&mut parts.extensions);

for (name, value) in header_map {
w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes())
.await?;
match parts.version {
http::Version::HTTP_2 | http::Version::HTTP_3 => {
// write lower-case for H2/H3
w.write_all(
format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?)
.as_bytes(),
)
.await?;
}
_ => {
w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes())
.await?;
}
}
}
}

Expand Down
54 changes: 41 additions & 13 deletions src/cli/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
Body, Method, Request, Uri,
},
};
use rama_http::proto::h1::{headers::original::OriginalHttp1Headers, Http1HeaderName};
use rama_utils::macros::match_ignore_ascii_case_str;
use serde_json::Value;
use std::collections::HashMap;
Expand Down Expand Up @@ -65,7 +66,7 @@ impl RequestArgsBuilder {
method: None,
url: arg,
query: HashMap::new(),
headers: HashMap::new(),
headers: Vec::new(),
body: HashMap::new(),
})
}
Expand All @@ -78,7 +79,7 @@ impl RequestArgsBuilder {
method: method.clone(),
url: arg,
query: HashMap::new(),
headers: HashMap::new(),
headers: Vec::new(),
body: HashMap::new(),
}),
BuilderState::Data {
Expand Down Expand Up @@ -191,15 +192,24 @@ impl RequestArgsBuilder {
}
}
}

let mut header_order = OriginalHttp1Headers::with_capacity(headers.len());
for (name, value) in headers {
req = req.header(name, value);
let header_name = Http1HeaderName::try_copy_from_str(name.as_str())
.context("convert string into Http1HeaderName")?;
req = req.header(header_name.clone(), value);
header_order.push(header_name);
}

if body.is_empty() {
return req
let mut req = req
.body(Body::empty())
.map_err(OpaqueError::from_std)
.context("create request without body");
.context("create request without body")?;

req.extensions_mut().insert(header_order);

return Ok(req);
}

let ct = content_type.unwrap_or_else(|| {
Expand All @@ -217,7 +227,9 @@ impl RequestArgsBuilder {

let req = if req.headers_ref().is_none() {
let req = req.header(CONTENT_TYPE, ct.header_value());
header_order.push(CONTENT_TYPE.into());
if ct == ContentType::Json {
header_order.push(ACCEPT.into());
req.header(ACCEPT, ct.header_value())
} else {
req
Expand All @@ -226,36 +238,44 @@ impl RequestArgsBuilder {
let headers = req.headers_mut().unwrap();

if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) {
header_order.push(CONTENT_TYPE.into());
entry.insert(ct.header_value());
}

if ct == ContentType::Json {
if let Entry::Vacant(entry) = headers.entry(ACCEPT) {
header_order.push(ACCEPT.into());
entry.insert(ct.header_value());
}
}

req
};

match ct {
let mut req = match ct {
ContentType::Json => {
let body = serde_json::to_string(&body)
.map_err(OpaqueError::from_std)
.context("serialize form body")?;
header_order.push(CONTENT_LENGTH.into());
req.header(CONTENT_LENGTH, body.len().to_string())
.body(Body::from(body))
}
ContentType::Form => {
let body = serde_html_form::to_string(&body)
.map_err(OpaqueError::from_std)
.context("serialize json body")?;
header_order.push(CONTENT_LENGTH.into());
req.header(CONTENT_LENGTH, body.len().to_string())
.body(Body::from(body))
}
}
.map_err(OpaqueError::from_std)
.context("create request with body")
.context("create request with body")?;

req.extensions_mut().insert(header_order);

Ok(req)
}
}
}
Expand All @@ -264,7 +284,7 @@ impl RequestArgsBuilder {
fn parse_arg_as_data(
arg: String,
query: &mut HashMap<String, Vec<String>>,
headers: &mut HashMap<String, String>,
headers: &mut Vec<(String, String)>,
body: &mut HashMap<String, Value>,
) -> Result<(), String> {
let mut state = DataParseArgState::None;
Expand Down Expand Up @@ -307,7 +327,7 @@ fn parse_arg_as_data(
} else {
// :
let value = &value[1..];
headers.insert(name.to_owned(), value.to_owned());
headers.push((name.to_owned(), value.to_owned()));
}
break;
}
Expand Down Expand Up @@ -395,7 +415,7 @@ enum BuilderState {
method: Option<Method>,
url: String,
query: HashMap<String, Vec<String>>,
headers: HashMap<String, String>,
headers: Vec<(String, String)>,
body: HashMap<String, Value>,
},
Error {
Expand Down Expand Up @@ -457,21 +477,29 @@ mod tests {
for (args, expected_request_str) in [
(vec![":8080"], "GET / HTTP/1.1\r\n\r\n"),
(vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"),
(
vec![
"example.com/bar",
"FOO:bar",
"AnSweR:42",
],
"GET /bar HTTP/1.1\r\nFOO: bar\r\nAnSweR: 42\r\n\r\n",
),
(
vec![
"example.com/foo",
"c=d",
"Content-Type:application/x-www-form-urlencoded",
],
"POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
"POST /foo HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
),
(
vec![
"example.com/foo",
"a=b",
"Content-Type:application/json",
],
"POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
"POST /foo HTTP/1.1\r\nContent-Type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
),
(
vec![
Expand Down Expand Up @@ -503,7 +531,7 @@ mod tests {
":3000",
"Cookie:foo=bar",
],
"GET / HTTP/1.1\r\ncookie: foo=bar\r\n\r\n",
"GET / HTTP/1.1\r\nCookie: foo=bar\r\n\r\n",
),
(
vec![
Expand Down

0 comments on commit cbf911f

Please sign in to comment.