Skip to content

Commit

Permalink
feat: use ClientWithMiddleware instead of Client (#4)
Browse files Browse the repository at this point in the history
The work was mostly done by Vlad Ivanov in conda/rattler#488

Co-authored-by: Vlad Ivanov <[email protected]>
  • Loading branch information
baszalmstra and vlad-ivanov-name authored Feb 1, 2024
1 parent d9d7d07 commit 30045ff
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 53 deletions.
19 changes: 1 addition & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ itertools = "0.12.1"
bisection = "0.1.0"
memmap2 = "0.9.0"
reqwest = { version = "0.11.22", default-features = false, features = ["stream"] }
reqwest-middleware = "0.2.4"
tokio = { version = "1.33.0", default-features = false }
tokio-stream = { version = "0.1.14", features = ["sync"] }
tokio-util = "0.7.9"
Expand All @@ -29,21 +30,3 @@ async_zip = { version = "0.0.15", default-features = false, features = ["tokio"]
assert_matches = "1.5.0"
rstest = { version = "0.18.2" }
url = { version = "2.4.1" }

# The profile that 'cargo dist' will build with
[profile.dist]
inherits = "release"
lto = "thin"

# Config for 'cargo dist'
[workspace.metadata.dist]
# The preferred cargo-dist version to use in CI (Cargo.toml SemVer syntax)
cargo-dist-version = "0.3.1"
# CI backends to support
ci = ["github"]
# The installers to generate for each app
installers = []
# Target platforms to build apps for (Rust target-triple syntax)
targets = ["x86_64-unknown-linux-gnu", "aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-pc-windows-msvc"]
# Publish jobs to run in CI
pr-run-mode = "plan"
20 changes: 17 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
use std::sync::Arc;

/// Error type used for [`crate::AsyncHttpRangeReader`]
#[derive(Clone, Debug, thiserror::Error)]
pub enum AsyncHttpRangeReaderError {
/// The server does not support range requests
#[error("range requests are not supported")]
HttpRangeRequestUnsupported,

/// Other HTTP error
#[error(transparent)]
HttpError(#[from] Arc<reqwest::Error>),
HttpError(#[from] Arc<reqwest_middleware::Error>),

/// An error occurred during transport
#[error("an error occurred during transport: {0}")]
TransportError(#[source] Arc<reqwest::Error>),
TransportError(#[source] Arc<reqwest_middleware::Error>),

/// An IO error occurred
#[error("io error occurred: {0}")]
IoError(#[source] Arc<std::io::Error>),

/// Content-Range header is missing from response
#[error("content-range header is missing from response")]
ContentRangeMissing,

/// Content-Length header is missing from response
#[error("content-length header is missing from response")]
ContentLengthMissing,

/// Memory mapping the file failed
#[error("memory mapping the file failed")]
MemoryMapError(#[source] Arc<std::io::Error>),
}
Expand All @@ -30,8 +38,14 @@ impl From<std::io::Error> for AsyncHttpRangeReaderError {
}
}

impl From<reqwest_middleware::Error> for AsyncHttpRangeReaderError {
fn from(err: reqwest_middleware::Error) -> Self {
AsyncHttpRangeReaderError::TransportError(Arc::new(err))
}
}

impl From<reqwest::Error> for AsyncHttpRangeReaderError {
fn from(err: reqwest::Error) -> Self {
AsyncHttpRangeReaderError::TransportError(Arc::new(err))
AsyncHttpRangeReaderError::TransportError(Arc::new(err.into()))
}
}
52 changes: 33 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@
//! The primary use-case for this library is to be able to sparsely stream a zip archive over HTTP
//! but its designed in a generic fashion.
mod sparse_range;

mod error;
#[cfg(test)]
mod static_directory_server;
mod sparse_range;

use futures::{FutureExt, Stream, StreamExt};
use http_content_range::{ContentRange, ContentRangeBytes};
use memmap2::MmapMut;
use reqwest::header::HeaderMap;
use reqwest::{Client, Response, Url};
use reqwest::{Response, Url};
use sparse_range::SparseRange;
use std::{
io::{self, ErrorKind, SeekFrom},
Expand Down Expand Up @@ -133,13 +130,20 @@ pub enum CheckSupportMethod {
Head,
}

fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
response
.error_for_status()
.map_err(reqwest_middleware::Error::Reqwest)
}

impl AsyncHttpRangeReader {
/// Construct a new `AsyncHttpRangeReader`.
pub async fn new(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
url: reqwest::Url,
check_method: CheckSupportMethod,
) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
let client = client.into();
match check_method {
CheckSupportMethod::NegativeRangeRequest(initial_chunk_size) => {
let response = Self::initial_tail_request(
Expand Down Expand Up @@ -168,11 +172,12 @@ impl AsyncHttpRangeReader {
/// requests. This will return a number of bytes from the end of the stream. Use the
/// `initial_chunk_size` parameter to define how many bytes should be requested from the end.
pub async fn initial_tail_request(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
url: reqwest::Url,
initial_chunk_size: u64,
extra_headers: HeaderMap,
) -> Result<Response, AsyncHttpRangeReaderError> {
let client = client.into();
let tail_response = client
.get(url)
.header(
Expand All @@ -182,7 +187,7 @@ impl AsyncHttpRangeReader {
.headers(extra_headers)
.send()
.await
.and_then(Response::error_for_status)
.and_then(error_for_status)
.map_err(Arc::new)
.map_err(AsyncHttpRangeReaderError::HttpError)?;
Ok(tail_response)
Expand All @@ -191,17 +196,19 @@ impl AsyncHttpRangeReader {
/// Initialize the reader from [`AsyncHttpRangeReader::initial_tail_request`] (or a user
/// provided response that also has a range of bytes from the end as body)
pub async fn from_tail_response(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
tail_request_response: Response,
) -> Result<Self, AsyncHttpRangeReaderError> {
let client = client.into();

// Get the size of the file from this initial request
let content_range = ContentRange::parse(
tail_request_response
.headers()
.get(reqwest::header::CONTENT_RANGE)
.ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
.to_str()
.map_err(|_| AsyncHttpRangeReaderError::ContentRangeMissing)?,
.map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?,
);
let (start, finish, complete_length) = match content_range {
ContentRange::Bytes(ContentRangeBytes {
Expand Down Expand Up @@ -266,17 +273,19 @@ impl AsyncHttpRangeReader {
/// Send an initial range request to determine if the remote accepts range
/// requests and get the content length
pub async fn initial_head_request(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
url: reqwest::Url,
extra_headers: HeaderMap,
) -> Result<Response, AsyncHttpRangeReaderError> {
let client = client.into();

// Perform a HEAD request to get the content-length.
let head_response = client
.head(url.clone())
.headers(extra_headers)
.send()
.await
.and_then(Response::error_for_status)
.and_then(error_for_status)
.map_err(Arc::new)
.map_err(AsyncHttpRangeReaderError::HttpError)?;
Ok(head_response)
Expand All @@ -285,9 +294,11 @@ impl AsyncHttpRangeReader {
/// Initialize the reader from [`AsyncHttpRangeReader::initial_head_request`] (or a user
/// provided response the)
pub async fn from_head_response(
client: reqwest::Client,
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
head_response: Response,
) -> Result<Self, AsyncHttpRangeReaderError> {
let client = client.into();

// Are range requests supported?
if head_response
.headers()
Expand All @@ -303,9 +314,9 @@ impl AsyncHttpRangeReader {
.get(reqwest::header::CONTENT_LENGTH)
.ok_or(AsyncHttpRangeReaderError::ContentLengthMissing)?
.to_str()
.map_err(|_| AsyncHttpRangeReaderError::ContentLengthMissing)?
.map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?
.parse()
.map_err(|_| AsyncHttpRangeReaderError::ContentLengthMissing)?;
.map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?;

// Allocate a memory map to hold the data
let memory_map = memmap2::MmapOptions::new()
Expand Down Expand Up @@ -363,8 +374,8 @@ impl AsyncHttpRangeReader {
inner.streamer_state.requested_ranges.clone()
}

// Prefetches a range of bytes from the remote. When specifying a large range this can
// drastically reduce the number of requests required to the server.
/// Prefetches a range of bytes from the remote. When specifying a large range this can
/// drastically reduce the number of requests required to the server.
pub async fn prefetch(&mut self, bytes: Range<u64>) {
let inner = self.inner.get_mut();

Expand Down Expand Up @@ -393,7 +404,7 @@ impl AsyncHttpRangeReader {
/// become available.
#[tracing::instrument(name = "fetch_ranges", skip_all, fields(url))]
async fn run_streamer(
client: Client,
client: reqwest_middleware::ClientWithMiddleware,
url: Url,
initial_tail_response: Option<(Response, u64)>,
mut memory_map: MmapMut,
Expand Down Expand Up @@ -453,7 +464,7 @@ async fn run_streamer(
.send()
.instrument(span)
.await
.and_then(Response::error_for_status)
.and_then(error_for_status)
.map_err(|e| std::io::Error::new(ErrorKind::Other, e))
{
Err(e) => {
Expand Down Expand Up @@ -619,6 +630,9 @@ impl AsyncRead for AsyncHttpRangeReader {
}
}

#[cfg(test)]
mod static_directory_server;

#[cfg(test)]
mod test {
use super::*;
Expand Down
22 changes: 9 additions & 13 deletions src/sparse_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ impl SparseRange {
// Compute the bounds of covered range taking into account existing covered ranges.
let start = left_slice
.first()
.map(|&left_bound| left_bound.min(range_start))
.unwrap_or(range_start);
.map_or(range_start, |&left_bound| left_bound.min(range_start));

// Get the ranges that are missing
let mut bound = start;
Expand All @@ -79,8 +78,7 @@ impl SparseRange {

let end = right_slice
.last()
.map(|&right_bound| right_bound.max(range_end))
.unwrap_or(range_end);
.map_or(range_end, |&right_bound| right_bound.max(range_end));

bound > end
}
Expand All @@ -93,7 +91,7 @@ impl SparseRange {
}

/// Find the ranges that are uncovered for the specified range together with what the
/// SparseRange would look like if we covered that range.
/// [`SparseRange`] would look like if we covered that range.
pub fn cover(&self, range: Range<u64>) -> Option<(SparseRange, Vec<RangeInclusive<u64>>)> {
let range_start = range.start;
let range_end = range.end - 1;
Expand All @@ -109,12 +107,10 @@ impl SparseRange {
// Compute the bounds of covered range taking into account existing covered ranges.
let start = left_slice
.first()
.map(|&left_bound| left_bound.min(range_start))
.unwrap_or(range_start);
.map_or(range_start, |&left_bound| left_bound.min(range_start));
let end = right_slice
.last()
.map(|&right_bound| right_bound.max(range_end))
.unwrap_or(range_end);
.map_or(range_end, |&right_bound| right_bound.max(range_end));

// Get the ranges that are missing
let mut ranges = Vec::new();
Expand All @@ -126,10 +122,12 @@ impl SparseRange {
bound = right_bound + 1;
}
if bound <= end {
ranges.push(bound..=end)
ranges.push(bound..=end);
}

if !ranges.is_empty() {
if ranges.is_empty() {
None
} else {
let mut new_left = self.left.clone();
new_left.splice(left_index..right_index, [start]);
let mut new_right = self.right.clone();
Expand All @@ -141,8 +139,6 @@ impl SparseRange {
},
ranges,
))
} else {
None
}
}
}
Expand Down

0 comments on commit 30045ff

Please sign in to comment.