diff --git a/src/lib.rs b/src/lib.rs index 2b99aa0..ace5c68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,6 +131,15 @@ pub enum CheckSupportMethod { Head, } +/// Which URL should be used for subsequent range requests? +pub enum RangeRequestUrlSource { + /// Use the initial request URL + Request, + + /// Use the initial response URL + Response, +} + fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result { response .error_for_status() @@ -143,6 +152,7 @@ impl AsyncHttpRangeReader { client: impl Into, url: reqwest::Url, check_method: CheckSupportMethod, + range_request_url_source: RangeRequestUrlSource, extra_headers: HeaderMap, ) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> { let client = client.into(); @@ -156,7 +166,11 @@ impl AsyncHttpRangeReader { ) .await?; let response_headers = response.headers().clone(); - let self_ = Self::from_tail_response(client, response, extra_headers).await?; + let url = match range_request_url_source { + RangeRequestUrlSource::Request => url, + RangeRequestUrlSource::Response => response.url().clone(), + }; + let self_ = Self::from_tail_response(client, response, url, extra_headers).await?; Ok((self_, response_headers)) } CheckSupportMethod::Head => { @@ -164,7 +178,11 @@ impl AsyncHttpRangeReader { Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default()) .await?; let response_headers = response.headers().clone(); - let self_ = Self::from_head_response(client, response, extra_headers).await?; + let url = match range_request_url_source { + RangeRequestUrlSource::Request => url, + RangeRequestUrlSource::Response => response.url().clone(), + }; + let self_ = Self::from_head_response(client, response, url, extra_headers).await?; Ok((self_, response_headers)) } } @@ -200,6 +218,7 @@ impl AsyncHttpRangeReader { pub async fn from_tail_response( client: impl Into, tail_request_response: Response, + url: Url, extra_headers: HeaderMap, ) -> Result { let client = client.into(); @@ -245,7 +264,7 @@ impl AsyncHttpRangeReader { let (state_tx, state_rx) = watch::channel(StreamerState::default()); tokio::spawn(run_streamer( client, - tail_request_response.url().clone(), + url, extra_headers, Some((tail_request_response, start)), memory_map, @@ -300,6 +319,7 @@ impl AsyncHttpRangeReader { pub async fn from_head_response( client: impl Into, head_response: Response, + url: Url, extra_headers: HeaderMap, ) -> Result { let client = client.into(); @@ -345,7 +365,7 @@ impl AsyncHttpRangeReader { let (state_tx, state_rx) = watch::channel(StreamerState::default()); tokio::spawn(run_streamer( client, - head_response.url().clone(), + url, extra_headers, None, memory_map, @@ -688,6 +708,7 @@ mod test { Client::new(), server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), check_method, + RangeRequestUrlSource::Response, HeaderMap::default(), ) .await @@ -783,6 +804,57 @@ mod test { Client::new(), server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), check_method, + RangeRequestUrlSource::Response, + HeaderMap::default(), + ) + .await + .expect("bla"); + + // Also open a simple file reader + let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda")) + .await + .unwrap(); + + // Read until the end and make sure that the contents matches + let mut range_read = vec![0; 64 * 1024]; + let mut file_read = vec![0; 64 * 1024]; + loop { + // Read with the async reader + let range_read_bytes = range.read(&mut range_read).await.unwrap(); + + // Read directly from the file + let file_read_bytes = file + .read_exact(&mut file_read[0..range_read_bytes]) + .await + .unwrap(); + + assert_eq!(range_read_bytes, file_read_bytes); + assert_eq!( + range_read[0..range_read_bytes], + file_read[0..file_read_bytes] + ); + + if file_read_bytes == 0 && range_read_bytes == 0 { + break; + } + } + } + + #[rstest] + #[case(RangeRequestUrlSource::Request)] + #[case(RangeRequestUrlSource::Response)] + #[tokio::test] + async fn async_range_reader_url_source(#[case] url_source: RangeRequestUrlSource) { + // Spawn a static file server + let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data"); + let server = StaticDirectoryServer::new(&path); + + // Construct an AsyncRangeReader + let (mut range, _) = AsyncHttpRangeReader::new( + Client::new(), + server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), + CheckSupportMethod::Head, + url_source, HeaderMap::default(), ) .await @@ -825,6 +897,7 @@ mod test { Client::new(), server.url().join("not-found").unwrap(), CheckSupportMethod::Head, + RangeRequestUrlSource::Response, HeaderMap::default(), ) .await