Skip to content

Commit

Permalink
Allow the request URL to be used for subsequent responses
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb committed Apr 8, 2024
1 parent da28771 commit 3b72d58
Showing 1 changed file with 77 additions and 4 deletions.
81 changes: 77 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ pub enum CheckSupportMethod {
Head,
}

/// Which URL should be used for subsequent range requests?
pub enum RangeRequestUrlSource {
/// Use the initial request URL
Request,

/// Use the initial response URL
Response,
}

fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
response
.error_for_status()
Expand All @@ -143,6 +152,7 @@ impl AsyncHttpRangeReader {
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
url: reqwest::Url,
check_method: CheckSupportMethod,
range_request_url_source: RangeRequestUrlSource,
extra_headers: HeaderMap,
) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
let client = client.into();
Expand All @@ -156,15 +166,23 @@ impl AsyncHttpRangeReader {
)
.await?;
let response_headers = response.headers().clone();
let self_ = Self::from_tail_response(client, response, extra_headers).await?;
let url = match range_request_url_source {
RangeRequestUrlSource::Request => url,
RangeRequestUrlSource::Response => response.url().clone(),
};
let self_ = Self::from_tail_response(client, response, url, extra_headers).await?;
Ok((self_, response_headers))
}
CheckSupportMethod::Head => {
let response =
Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default())
.await?;
let response_headers = response.headers().clone();
let self_ = Self::from_head_response(client, response, extra_headers).await?;
let url = match range_request_url_source {
RangeRequestUrlSource::Request => url,
RangeRequestUrlSource::Response => response.url().clone(),
};
let self_ = Self::from_head_response(client, response, url, extra_headers).await?;
Ok((self_, response_headers))
}
}
Expand Down Expand Up @@ -200,6 +218,7 @@ impl AsyncHttpRangeReader {
pub async fn from_tail_response(
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
tail_request_response: Response,
url: Url,
extra_headers: HeaderMap,
) -> Result<Self, AsyncHttpRangeReaderError> {
let client = client.into();
Expand Down Expand Up @@ -245,7 +264,7 @@ impl AsyncHttpRangeReader {
let (state_tx, state_rx) = watch::channel(StreamerState::default());
tokio::spawn(run_streamer(
client,
tail_request_response.url().clone(),
url,
extra_headers,
Some((tail_request_response, start)),
memory_map,
Expand Down Expand Up @@ -300,6 +319,7 @@ impl AsyncHttpRangeReader {
pub async fn from_head_response(
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
head_response: Response,
url: Url,
extra_headers: HeaderMap,
) -> Result<Self, AsyncHttpRangeReaderError> {
let client = client.into();
Expand Down Expand Up @@ -345,7 +365,7 @@ impl AsyncHttpRangeReader {
let (state_tx, state_rx) = watch::channel(StreamerState::default());
tokio::spawn(run_streamer(
client,
head_response.url().clone(),
url,
extra_headers,
None,
memory_map,
Expand Down Expand Up @@ -688,6 +708,7 @@ mod test {
Client::new(),
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
check_method,
RangeRequestUrlSource::Response,
HeaderMap::default(),
)
.await
Expand Down Expand Up @@ -783,6 +804,57 @@ mod test {
Client::new(),
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
check_method,
RangeRequestUrlSource::Response,
HeaderMap::default(),
)
.await
.expect("bla");

// Also open a simple file reader
let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda"))
.await
.unwrap();

// Read until the end and make sure that the contents matches
let mut range_read = vec![0; 64 * 1024];
let mut file_read = vec![0; 64 * 1024];
loop {
// Read with the async reader
let range_read_bytes = range.read(&mut range_read).await.unwrap();

// Read directly from the file
let file_read_bytes = file
.read_exact(&mut file_read[0..range_read_bytes])
.await
.unwrap();

assert_eq!(range_read_bytes, file_read_bytes);
assert_eq!(
range_read[0..range_read_bytes],
file_read[0..file_read_bytes]
);

if file_read_bytes == 0 && range_read_bytes == 0 {
break;
}
}
}

#[rstest]
#[case(RangeRequestUrlSource::Request)]
#[case(RangeRequestUrlSource::Response)]
#[tokio::test]
async fn async_range_reader_url_source(#[case] url_source: RangeRequestUrlSource) {
// Spawn a static file server
let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
let server = StaticDirectoryServer::new(&path);

// Construct an AsyncRangeReader
let (mut range, _) = AsyncHttpRangeReader::new(
Client::new(),
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
CheckSupportMethod::Head,
url_source,
HeaderMap::default(),
)
.await
Expand Down Expand Up @@ -825,6 +897,7 @@ mod test {
Client::new(),
server.url().join("not-found").unwrap(),
CheckSupportMethod::Head,
RangeRequestUrlSource::Response,
HeaderMap::default(),
)
.await
Expand Down

0 comments on commit 3b72d58

Please sign in to comment.