Skip to content

Commit

Permalink
feat: use ClientWithMiddleware instead of Client
Browse files Browse the repository at this point in the history
The work was mostly done by Vlad Ivanov in conda/rattler#488

Co-authored-by: Vlad Ivanov <[email protected]>
  • Loading branch information
baszalmstra and vlad-ivanov-name committed Feb 1, 2024
1 parent d9d7d07 commit d2e4937
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 67 deletions.
19 changes: 1 addition & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ itertools = "0.12.1"
bisection = "0.1.0"
memmap2 = "0.9.0"
reqwest = { version = "0.11.22", default-features = false, features = ["stream"] }
reqwest-middleware = "0.2.4"
tokio = { version = "1.33.0", default-features = false }
tokio-stream = { version = "0.1.14", features = ["sync"] }
tokio-util = "0.7.9"
Expand All @@ -29,21 +30,3 @@ async_zip = { version = "0.0.15", default-features = false, features = ["tokio"]
assert_matches = "1.5.0"
rstest = { version = "0.18.2" }
url = { version = "2.4.1" }

# The profile that 'cargo dist' will build with
[profile.dist]
inherits = "release"
lto = "thin"

# Config for 'cargo dist'
[workspace.metadata.dist]
# The preferred cargo-dist version to use in CI (Cargo.toml SemVer syntax)
cargo-dist-version = "0.3.1"
# CI backends to support
ci = ["github"]
# The installers to generate for each app
installers = []
# Target platforms to build apps for (Rust target-triple syntax)
targets = ["x86_64-unknown-linux-gnu", "aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-pc-windows-msvc"]
# Publish jobs to run in CI
pr-run-mode = "plan"
20 changes: 17 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
use std::sync::Arc;

/// Error type used for [`crate::AsyncHttpRangeReader`]
#[derive(Clone, Debug, thiserror::Error)]
pub enum AsyncHttpRangeReaderError {
/// The server does not support range requests
#[error("range requests are not supported")]
HttpRangeRequestUnsupported,

/// Other HTTP error
#[error(transparent)]
HttpError(#[from] Arc<reqwest::Error>),
HttpError(#[from] Arc<reqwest_middleware::Error>),

/// An error occurred during transport
#[error("an error occurred during transport: {0}")]
TransportError(#[source] Arc<reqwest::Error>),
TransportError(#[source] Arc<reqwest_middleware::Error>),

/// An IO error occurred
#[error("io error occurred: {0}")]
IoError(#[source] Arc<std::io::Error>),

/// Content-Range header is missing from response
#[error("content-range header is missing from response")]
ContentRangeMissing,

/// Content-Length header is missing from response
#[error("content-length header is missing from response")]
ContentLengthMissing,

/// Memory mapping the file failed
#[error("memory mapping the file failed")]
MemoryMapError(#[source] Arc<std::io::Error>),
}
Expand All @@ -30,8 +38,14 @@ impl From<std::io::Error> for AsyncHttpRangeReaderError {
}
}

/// Converts an error produced by the middleware-enabled HTTP client into an
/// [`AsyncHttpRangeReaderError`].
///
/// The incoming error is placed behind an [`Arc`] and reported through the
/// `TransportError` variant.
impl From<reqwest_middleware::Error> for AsyncHttpRangeReaderError {
    fn from(err: reqwest_middleware::Error) -> Self {
        let shared = Arc::new(err);
        Self::TransportError(shared)
    }
}

impl From<reqwest::Error> for AsyncHttpRangeReaderError {
fn from(err: reqwest::Error) -> Self {
AsyncHttpRangeReaderError::TransportError(Arc::new(err))
AsyncHttpRangeReaderError::TransportError(Arc::new(err.into()))
}
}
80 changes: 47 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@
//! The primary use-case for this library is to be able to sparsely stream a zip archive over HTTP
//! but it's designed in a generic fashion.
mod sparse_range;

mod error;
#[cfg(test)]
mod static_directory_server;
mod sparse_range;

use futures::{FutureExt, Stream, StreamExt};
use http_content_range::{ContentRange, ContentRangeBytes};
use memmap2::MmapMut;
use reqwest::header::HeaderMap;
use reqwest::{Client, Response, Url};
use reqwest::{Response, Url};
use sparse_range::SparseRange;
use std::{
io::{self, ErrorKind, SeekFrom},
Expand Down Expand Up @@ -74,7 +71,7 @@ pub use error::AsyncHttpRangeReaderError;
/// if response.status() == reqwest::StatusCode::NOT_MODIFIED {
/// Ok(None)
/// } else {
/// let reader = AsyncHttpRangeReader::from_head_response(client, response).await?;
/// let reader = AsyncHttpRangeReader::from_head_response(client.into(), response).await?;
/// Ok(Some(reader))
/// }
/// }
Expand Down Expand Up @@ -133,13 +130,20 @@ pub enum CheckSupportMethod {
Head,
}

fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
response
.error_for_status()
.map_err(reqwest_middleware::Error::Reqwest)
}

impl AsyncHttpRangeReader {
/// Construct a new `AsyncHttpRangeReader`.
pub async fn new(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
url: reqwest::Url,
check_method: CheckSupportMethod,
) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
let client = client.into();
match check_method {
CheckSupportMethod::NegativeRangeRequest(initial_chunk_size) => {
let response = Self::initial_tail_request(
Expand All @@ -148,7 +152,7 @@ impl AsyncHttpRangeReader {
initial_chunk_size,

Check warning on line 152 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/async_http_range_reader/async_http_range_reader/src/lib.rs
HeaderMap::default(),
)
.await?;
.await?;
let response_headers = response.headers().clone();
let self_ = Self::from_tail_response(client, response).await?;
Ok((self_, response_headers))
Expand All @@ -168,11 +172,12 @@ impl AsyncHttpRangeReader {
/// requests. This will return a number of bytes from the end of the stream. Use the
/// `initial_chunk_size` parameter to define how many bytes should be requested from the end.
pub async fn initial_tail_request(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
url: reqwest::Url,
initial_chunk_size: u64,
extra_headers: HeaderMap,
) -> Result<Response, AsyncHttpRangeReaderError> {
let client = client.into();
let tail_response = client
.get(url)
.header(
Expand All @@ -182,7 +187,7 @@ impl AsyncHttpRangeReader {
.headers(extra_headers)
.send()
.await
.and_then(Response::error_for_status)
.and_then(error_for_status)
.map_err(Arc::new)
.map_err(AsyncHttpRangeReaderError::HttpError)?;
Ok(tail_response)
Expand All @@ -191,24 +196,26 @@ impl AsyncHttpRangeReader {
/// Initialize the reader from [`AsyncHttpRangeReader::initial_tail_request`] (or a user
/// provided response that also has a range of bytes from the end as body)
pub async fn from_tail_response(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
tail_request_response: Response,
) -> Result<Self, AsyncHttpRangeReaderError> {
let client = client.into();

// Get the size of the file from this initial request
let content_range = ContentRange::parse(
tail_request_response
.headers()
.get(reqwest::header::CONTENT_RANGE)
.ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
.to_str()
.map_err(|_| AsyncHttpRangeReaderError::ContentRangeMissing)?,
.map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?,
);

Check warning on line 212 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/async_http_range_reader/async_http_range_reader/src/lib.rs
let (start, finish, complete_length) = match content_range {
ContentRange::Bytes(ContentRangeBytes {
first_byte,
last_byte,
complete_length,
}) => (first_byte, last_byte, complete_length),
first_byte,
last_byte,
complete_length,
}) => (first_byte, last_byte, complete_length),
_ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported),
};

Expand Down Expand Up @@ -266,17 +273,19 @@ impl AsyncHttpRangeReader {
/// Send an initial HEAD request to determine if the remote accepts range
/// requests and get the content length
pub async fn initial_head_request(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
url: reqwest::Url,
extra_headers: HeaderMap,
) -> Result<Response, AsyncHttpRangeReaderError> {
let client = client.into();

// Perform a HEAD request to get the content-length.
let head_response = client
.head(url.clone())
.headers(extra_headers)
.send()
.await
.and_then(Response::error_for_status)
.and_then(error_for_status)
.map_err(Arc::new)
.map_err(AsyncHttpRangeReaderError::HttpError)?;
Ok(head_response)
Expand All @@ -285,9 +294,11 @@ impl AsyncHttpRangeReader {
/// Initialize the reader from [`AsyncHttpRangeReader::initial_head_request`] (or a user
/// provided response to a HEAD request)
pub async fn from_head_response(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
head_response: Response,
) -> Result<Self, AsyncHttpRangeReaderError> {
let client = client.into();

// Are range requests supported?
if head_response
.headers()
Expand All @@ -303,9 +314,9 @@ impl AsyncHttpRangeReader {
.get(reqwest::header::CONTENT_LENGTH)
.ok_or(AsyncHttpRangeReaderError::ContentLengthMissing)?
.to_str()
.map_err(|_| AsyncHttpRangeReaderError::ContentLengthMissing)?
.map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?
.parse()
.map_err(|_| AsyncHttpRangeReaderError::ContentLengthMissing)?;
.map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?;

// Allocate a memory map to hold the data
let memory_map = memmap2::MmapOptions::new()
Expand Down Expand Up @@ -363,8 +374,8 @@ impl AsyncHttpRangeReader {
inner.streamer_state.requested_ranges.clone()
}

// Prefetches a range of bytes from the remote. When specifying a large range this can
// drastically reduce the number of requests required to the server.
/// Prefetches a range of bytes from the remote. When specifying a large range this can
/// drastically reduce the number of requests required to the server.
pub async fn prefetch(&mut self, bytes: Range<u64>) {
let inner = self.inner.get_mut();

Expand Down Expand Up @@ -393,7 +404,7 @@ impl AsyncHttpRangeReader {
/// become available.
#[tracing::instrument(name = "fetch_ranges", skip_all, fields(url))]
async fn run_streamer(
client: Client,
client: reqwest_middleware::ClientWithMiddleware,
url: Url,
initial_tail_response: Option<(Response, u64)>,
mut memory_map: MmapMut,
Expand All @@ -416,7 +427,7 @@ async fn run_streamer(
&mut state_tx,

Check warning on line 427 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/async_http_range_reader/async_http_range_reader/src/lib.rs
&mut state,
)
.await
.await
{
return;
}
Expand Down Expand Up @@ -453,7 +464,7 @@ async fn run_streamer(
.send()
.instrument(span)
.await
.and_then(Response::error_for_status)
.and_then(error_for_status)
.map_err(|e| std::io::Error::new(ErrorKind::Other, e))
{
Err(e) => {
Expand All @@ -471,7 +482,7 @@ async fn run_streamer(
&mut state_tx,

Check warning on line 482 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/async_http_range_reader/async_http_range_reader/src/lib.rs
&mut state,
)
.await
.await
{
break 'outer;
}
Expand Down Expand Up @@ -619,6 +630,9 @@ impl AsyncRead for AsyncHttpRangeReader {
}
}

#[cfg(test)]
mod static_directory_server;

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -659,8 +673,8 @@ mod test {
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),

Check warning on line 673 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/async_http_range_reader/async_http_range_reader/src/lib.rs
check_method,
)
.await
.expect("Could not download range - did you run `git lfs pull`?");
.await
.expect("Could not download range - did you run `git lfs pull`?");

// Make sure we have read the last couple of bytes
range.prefetch(range.len() - 8192..range.len()).await;
Expand Down Expand Up @@ -753,8 +767,8 @@ mod test {
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),

Check warning on line 767 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/async_http_range_reader/async_http_range_reader/src/lib.rs
check_method,
)
.await
.expect("bla");
.await
.expect("bla");

// Also open a simple file reader
let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda"))
Expand Down Expand Up @@ -794,8 +808,8 @@ mod test {
server.url().join("not-found").unwrap(),

Check warning on line 808 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Format and Lint

Diff in /home/runner/work/async_http_range_reader/async_http_range_reader/src/lib.rs
CheckSupportMethod::Head,
)
.await
.expect_err("expected an error");
.await
.expect_err("expected an error");

assert_matches!(
err, AsyncHttpRangeReaderError::HttpError(err) if err.status() == Some(StatusCode::NOT_FOUND)
Expand Down
Loading

0 comments on commit d2e4937

Please sign in to comment.