diff --git a/crates/rattler-bin/src/commands/create.rs b/crates/rattler-bin/src/commands/create.rs index b568b5645..f8d10babd 100644 --- a/crates/rattler-bin/src/commands/create.rs +++ b/crates/rattler-bin/src/commands/create.rs @@ -1,7 +1,3 @@ -use rattler::{ - repo_data::fetch::{terminal_progress, MultiRequestRepoDataBuilder}, - solver::SolverProblem, -}; use rattler_conda_types::{Channel, ChannelConfig, MatchSpec}; #[derive(Debug, clap::Parser)] @@ -17,7 +13,7 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { let channel_config = ChannelConfig::default(); // Parse the match specs - let specs = opt + let _specs = opt .specs .iter() .map(|spec| MatchSpec::from_str(spec, &channel_config)) @@ -31,38 +27,38 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { .map_err(|e| anyhow::anyhow!("could not create cache directory: {}", e))?; // Get the channels to download - let channels = opt + let _channels = opt .channels .unwrap_or_else(|| vec![String::from("conda-forge")]) .into_iter() .map(|channel_str| Channel::from_str(&channel_str, &channel_config)) .collect::, _>>()?; - // Download all repo data from the channels and create an index - let repo_data_per_source = MultiRequestRepoDataBuilder::default() - .set_cache_dir(&cache_dir) - .set_listener(terminal_progress()) - .set_fail_fast(false) - .add_channels(channels) - .request() - .await; - - // Error out if fetching one of the sources resulted in an error. - let repo_data = repo_data_per_source - .into_iter() - .map(|(channel, _, result)| result.map(|data| (channel, data))) - .collect::, _>>()?; - - let solver_problem = SolverProblem { - channels: repo_data - .iter() - .map(|(channel, repodata)| (channel.base_url().to_string(), repodata)) - .collect(), - specs, - }; - - let result = solver_problem.solve()?; - println!("{:#?}", result); + // // Download all repo data from the channels and create an index + // let repo_data_per_source = MultiRequestRepoDataBuilder::default() + // .set_cache_dir(&cache_dir) + // .set_listener(terminal_progress()) + // .set_fail_fast(false) + // .add_channels(channels) + // .request() + // .await; + // + // // Error out if fetching one of the sources resulted in an error. 
+ // let repo_data = repo_data_per_source + // .into_iter() + // .map(|(channel, _, result)| result.map(|data| (channel, data))) + // .collect::, _>>()?; + // + // let solver_problem = SolverProblem { + // channels: repo_data + // .iter() + // .map(|(channel, repodata)| (channel.base_url().to_string(), repodata)) + // .collect(), + // specs, + // }; + // + // let result = solver_problem.solve()?; + // println!("{:#?}", result); Ok(()) } diff --git a/crates/rattler/Cargo.toml b/crates/rattler/Cargo.toml index d64b3140b..c89d097bc 100644 --- a/crates/rattler/Cargo.toml +++ b/crates/rattler/Cargo.toml @@ -13,15 +13,14 @@ rustls-tls = ['reqwest/rustls-tls'] [dependencies] anyhow = "1.0.44" apple-codesign = "0.22.0" -async-compression = { version = "0.3.12", features = ["gzip", "futures-bufread", "tokio", "bzip2"] } +async-compression = { version = "0.3.12", features = ["gzip", "tokio", "bzip2", "zstd"] } bytes = "1.1.0" +chrono = { version = "0.4.23", default-features = false, features = ["std", "serde", "alloc"] } digest = "0.10.6" dirs = "4.0.0" -extendhash = "1.0.9" futures = "0.3.17" fxhash = "0.2.1" hex = "0.4.3" -indicatif = { version = "0.17.1", features = ["improved_unicode"] } itertools = "0.10.3" libc = "0.2" libz-sys = { version = "1.1.0", default-features = false, features = ["static"] } @@ -31,6 +30,7 @@ nom = "7.1.0" once_cell = "1.8.0" pin-project-lite = "0.2.9" rattler_conda_types = { version = "0.1.0", path = "../rattler_conda_types" } +rattler_digest = { version = "0.1.0", path = "../rattler_digest" } rattler_package_streaming = { version = "0.1.0", path = "../rattler_package_streaming", features = ["reqwest", "tokio"] } regex = "1.5.4" reqwest = { version = "0.11.6", default-features = false, features = ["stream", "json", "gzip"] } @@ -50,14 +50,9 @@ uuid = { version = "1.3.0", features = ["v4", "fast-rng"] } [dev-dependencies] assert_matches = "1.5.0" -axum = "0.6.2" -insta = { version = "1.16.0", features = ["yaml"] } -proptest = "1.0.0" rand = "0.8.4" rstest = "0.16.0" -tokio-test = "0.4.2" -tower-http = { version = "0.3.5", features = ["fs", "compression-gzip"] } -tracing-test = "0.2.4" +tracing-test = { version = "0.2.4" } [build-dependencies] cc = "1" diff --git a/crates/rattler/src/install/link.rs b/crates/rattler/src/install/link.rs index e2703dbd0..85623b8c4 100644 --- a/crates/rattler/src/install/link.rs +++ b/crates/rattler/src/install/link.rs @@ -1,7 +1,7 @@ -use crate::utils::{parse_sha256_from_hex, Sha256HashingWriter}; use apple_codesign::{SigningSettings, UnifiedSigner}; use rattler_conda_types::package::{FileMode, PathType, PathsEntry}; use rattler_conda_types::Platform; +use rattler_digest::{parse_digest_from_hex, HashingWriter}; use std::fs::Permissions; use std::io::Write; use std::path::Path; @@ -81,7 +81,7 @@ pub fn link_file( // Open the destination file let destination = std::fs::File::create(&destination_path) .map_err(LinkFileError::FailedToOpenDestinationFile)?; - let mut destination_writer = Sha256HashingWriter::new(destination); + let mut destination_writer = HashingWriter::<_, sha2::Sha256>::new(destination); // Replace the prefix placeholder in the file with the new placeholder copy_and_replace_placholders( @@ -111,7 +111,7 @@ pub fn link_file( let original_hash = path_json_entry .sha256 .as_deref() - .and_then(parse_sha256_from_hex); + .and_then(parse_digest_from_hex::); let content_changed = original_hash != Some(current_hash); // If the binary changed it requires resigning. 
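The hunks above swap the crate-local `Sha256HashingWriter` and `parse_sha256_from_hex` helpers for their generic counterparts in the new `rattler_digest` crate. A minimal sketch of how the two replacements fit together, assuming only the `HashingWriter`/`parse_digest_from_hex` API introduced later in this diff (the `write_and_verify` function and `expected_hex` parameter are illustrative names, not part of the change):

```rust
use rattler_digest::{parse_digest_from_hex, HashingWriter};
use sha2::Sha256;
use std::io::Write;

/// Write `data` to `destination` while hashing it on the fly, then compare the
/// resulting digest against an expected hex-encoded SHA256 string.
fn write_and_verify(
    destination: std::fs::File,
    data: &[u8],
    expected_hex: &str,
) -> std::io::Result<bool> {
    // The second type parameter selects the digest, mirroring the
    // `HashingWriter::<_, sha2::Sha256>` call in `link_file` above.
    let mut writer = HashingWriter::<_, Sha256>::new(destination);
    writer.write_all(data)?;

    // `finalize` hands back the wrapped writer together with the digest.
    let (_file, actual) = writer.finalize();

    // `parse_digest_from_hex` turns the stored hex string into the same
    // digest type, so the two can be compared directly.
    Ok(parse_digest_from_hex::<Sha256>(expected_hex) == Some(actual))
}
```

This is the same pattern `link_file` uses: hash while writing so the file never has to be re-read just to know its checksum.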
diff --git a/crates/rattler/src/lib.rs b/crates/rattler/src/lib.rs index aab63b414..f8e0baff7 100644 --- a/crates/rattler/src/lib.rs +++ b/crates/rattler/src/lib.rs @@ -9,14 +9,13 @@ //! interfacing with many other languages (WASM, Javascript, Python, C, etc) and is therefor a good //! candidate for a reimplementation. +use std::path::PathBuf; + pub mod install; pub mod package_cache; -pub mod repo_data; pub mod solver; pub mod validation; -pub(crate) mod utils; - /// A helper function that returns a [`Channel`] instance that points to an empty channel on disk /// that is bundled with this repository. #[cfg(any(doctest, test))] @@ -30,10 +29,14 @@ pub fn empty_channel() -> rattler_conda_types::Channel { .unwrap() } -#[cfg(test)] -use std::path::{Path, PathBuf}; - #[cfg(test)] pub(crate) fn get_test_data_dir() -> PathBuf { - Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data") + std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data") +} + +/// Returns the default cache directory used by rattler. +pub fn default_cache_dir() -> anyhow::Result { + Ok(dirs::cache_dir() + .ok_or_else(|| anyhow::anyhow!("could not determine cache directory for current platform"))? + .join("rattler/cache")) } diff --git a/crates/rattler/src/repo_data/fetch/mod.rs b/crates/rattler/src/repo_data/fetch/mod.rs deleted file mode 100644 index 53f688289..000000000 --- a/crates/rattler/src/repo_data/fetch/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -//! The modules defines functionality to download channel [`rattler_conda_types::RepoData`] from -//! several different type of sources, cache the results, do this for several sources in parallel, -//! and provide adequate progress information to a user. - -mod multi_request; -mod progress; -mod request; - -pub use multi_request::{MultiRequestRepoDataBuilder, MultiRequestRepoDataListener}; -pub use progress::terminal_progress; -pub use request::{ - DoneState, DownloadingState, RepoDataRequestState, RequestRepoDataBuilder, - RequestRepoDataError, RequestRepoDataListener, -}; diff --git a/crates/rattler/src/repo_data/fetch/multi_request.rs b/crates/rattler/src/repo_data/fetch/multi_request.rs deleted file mode 100644 index ee7073368..000000000 --- a/crates/rattler/src/repo_data/fetch/multi_request.rs +++ /dev/null @@ -1,251 +0,0 @@ -//! Defines the [`MultiRequestRepoDataBuilder`] struct. This struct enables async fetching channel -//! repodata from multiple source in parallel. - -use crate::repo_data::fetch::request::{ - RepoDataRequestState, RequestRepoDataBuilder, RequestRepoDataError, RequestRepoDataListener, -}; -use crate::utils::default_cache_dir; -use futures::{stream::FuturesUnordered, StreamExt}; -use rattler_conda_types::{Channel, Platform, RepoData}; -use std::path::PathBuf; - -/// The `MultiRequestRepoDataBuilder` handles fetching data from multiple conda channels and -/// for multiple platforms. Internally it dispatches all requests to [`RequestRepoDataBuilder`]s -/// which ensure that only the latest changes are fetched. -/// -/// A `MultiRequestRepoDataBuilder` also provides very explicit user feedback through the -/// [`MultiRequestRepoDataBuilder::set_listener`] method. An example of its usage can be found in -/// the [`super::terminal_progress`] which disables multiple CLI progress bars while the requests -/// are being performed. 
-/// -/// ```rust,no_run -/// # use std::path::PathBuf; -/// # use rattler::{repo_data::fetch::MultiRequestRepoDataBuilder}; -/// # use rattler_conda_types::{Channel, Platform, ChannelConfig}; -/// # tokio_test::block_on(async { -/// let _repo_data = MultiRequestRepoDataBuilder::default() -/// .add_channel(Channel::from_str("conda-forge", &ChannelConfig::default()).unwrap()) -/// .request() -/// .await; -/// # }) -/// ``` -pub struct MultiRequestRepoDataBuilder { - /// All the source to fetch - sources: Vec<(Channel, Platform)>, - - /// The directory to store the cache - cache_dir: Option, - - /// An optional [`reqwest::Client`] that is used to perform the request. When performing - /// multiple requests its useful to reuse a single client. - http_client: Option, - - /// True to fail as soon as one of the queries fails. If this is set to false the other queries - /// continue. Defaults to `true`. - fail_fast: bool, - - /// An optional listener - listener: Option, -} - -impl Default for MultiRequestRepoDataBuilder { - fn default() -> Self { - Self { - sources: vec![], - cache_dir: None, - http_client: None, - fail_fast: true, - listener: None, - } - } -} - -/// A listener function that is called for a request source ([`Channel`] and [`Platform`]) when a -/// state change of the request occurred. -pub type MultiRequestRepoDataListener = - Box; - -impl MultiRequestRepoDataBuilder { - /// Adds the specific platform of the given channel to the list of sources to fetch. - pub fn add_channel_and_platform(mut self, channel: Channel, platform: Platform) -> Self { - self.sources.push((channel, platform)); - self - } - - /// Adds the specified channel to the list of source to fetch. The platforms specified in the - /// channel or the default platforms are added as defined by [`Channel::platforms_or_default`]. - pub fn add_channel(mut self, channel: Channel) -> Self { - for platform in channel.platforms_or_default() { - self.sources.push((channel.clone(), *platform)); - } - self - } - - /// Adds multiple channels to the request builder. For each channel the platforms or the default - /// set of platforms are added (see: [`Channel::platforms_or_default`]). - pub fn add_channels(mut self, channels: impl IntoIterator) -> Self { - for channel in channels.into_iter() { - for platform in channel.platforms_or_default() { - self.sources.push((channel.clone(), *platform)); - } - } - self - } - - /// Sets a default cache directory that will be used for caching requests. - pub fn set_default_cache_dir(self) -> anyhow::Result { - let cache_dir = default_cache_dir()?; - std::fs::create_dir_all(&cache_dir)?; - Ok(self.set_cache_dir(cache_dir)) - } - - /// Sets the directory that will be used for caching requests. - pub fn set_cache_dir(mut self, cache_dir: impl Into) -> Self { - self.cache_dir = Some(cache_dir.into()); - self - } - - /// Sets the [`reqwest::Client`] that is used to perform HTTP requests. If this is not called - /// a new client is created for this entire instance. The created client is shared for all - /// requests. When performing multiple requests its more efficient to reuse a single client - /// across multiple requests. - pub fn set_http_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self - } - - /// Adds a state listener to the builder. This is invoked every time the state of the request - /// changes. See the [`RepoDataRequestState`] for more information. 
- pub fn set_listener(mut self, listener: MultiRequestRepoDataListener) -> Self { - self.listener = Some(listener); - self - } - - /// Sets a boolean indicating whether or not to stop processing the rest of the queries if one - /// of them fails. By default this is `true`. - pub fn set_fail_fast(mut self, fail_fast: bool) -> Self { - self.fail_fast = fail_fast; - self - } - - /// Asynchronously fetches repodata information from the sources added to this instance. A - /// vector is returned that contains the state of each source. The returned state is returned in - /// the same order as they were added. - pub async fn request( - mut self, - ) -> Vec<(Channel, Platform, Result)> { - // Construct an http client for the requests if none has been specified by the user. - let http_client = self.http_client.unwrap_or_else(reqwest::Client::new); - - // Channel that will receive events from the different sources. Each source spawns a new - // future, this channel ensures that all events arrive in the same place so we can handle - // them. - let (state_sender, mut state_receiver) = tokio::sync::mpsc::unbounded_channel(); - - // Construct a query for every source - let mut futures = FuturesUnordered::new(); - let mut results = Vec::with_capacity(self.sources.len()); - for (idx, (channel, platform)) in self.sources.into_iter().enumerate() { - // Create a result for this source that is initially cancelled. If we return from this - // function before the a result is computed this is the correct response. - results.push(( - channel.clone(), - platform, - Err(RequestRepoDataError::Cancelled), - )); - - // If there is a listener active for this instance, construct a listener for this - // specific source request that funnels all state changes to a channel. - let listener = if self.listener.is_some() { - // Construct a closure that captures the index of the current source. State changes - // for the current source request are added to an unbounded channel which is - // processed on the main task. - let sender = state_sender.clone(); - let mut request_listener: RequestRepoDataListener = - Box::new(move |request_state| { - // Silently ignore send errors. It probably means the receiving end was - // dropped, which is perfectly fine. - let _ = sender.send((idx, request_state)); - }); - - // Notify the listener immediately about a pending state. This is done on the main - // task to ensure that the listener is notified about all the sources in the correct - // order. Since the source requests are spawned they may run on a background thread - // where potentially the order of the source is lost. Firing an initial state change - // here ensures that the listener is notified of all the sources in the same order - // they were added to this instance. - request_listener(RepoDataRequestState::Pending); - - Some(request_listener) - } else { - None - }; - - // Construct a `RequestRepoDataBuilder` for this source that will perform the actual - // request. - let source_request = RequestRepoDataBuilder { - channel, - platform, - cache_dir: self.cache_dir.clone(), - http_client: Some(http_client.clone()), - listener, - }; - - // Spawn a future that will await the request. This future is "spawned" which means - // it is executed on a different thread. The JoinHandle is pushed to the `futures` - // collection which allows us the asynchronously wait for all results in parallel. 
- let request_future = tokio::spawn(async move { (idx, source_request.request().await) }); - futures.push(request_future); - } - - // Drop the event sender, this will ensure that only RequestRepoDataBuilder listeners could - // have a sender. Once all requests have finished they will drop their sender handle, which - // will eventually close all senders and therefor the receiver. If this wouldn't be the case - // the select below would wait indefinitely until it received an event. - drop(state_sender); - - // Loop over two streams until they both complete. The `select!` macro selects the first - // future that becomes ready from the two sources. - // - // 1. The `state_receiver` is a channel that contains `RepoDataRequestState`s from each - // source as it executes. This only contains data if this instance has a listener. - // 2. The `futures` is an `UnorderedFutures` collection that yields results from individual - // source requests as they become available. - loop { - tokio::select! { - Some((idx, state_change)) = state_receiver.recv() => { - let listener = self - .listener - .as_mut() - .expect("there must be a listener at this point"); - let channel = results[idx].0.clone(); - let platform = results[idx].1; - listener(channel, platform, state_change); - }, - Some(result) = futures.next() => match result { - Ok((idx, result)) => { - // Store the result in the results container. This overwrites the value that - // is currently already there. The initial value is a Cancelled result. - results[idx].2 = result; - - // If the result contains an error and we want to fail fast, break right - // away, this will drop the rest of the futures, cancelling them. - if results[idx].2.is_err() && self.fail_fast { - break; - } - }, - Err(err) => { - // If a panic occurred in the source request we want to propagate it here. - if let Ok(reason) = err.try_into_panic() { - std::panic::resume_unwind(reason); - } - break; - } - }, - else => break, - } - } - - results - } -} diff --git a/crates/rattler/src/repo_data/fetch/progress.rs b/crates/rattler/src/repo_data/fetch/progress.rs deleted file mode 100644 index 0e4c5862c..000000000 --- a/crates/rattler/src/repo_data/fetch/progress.rs +++ /dev/null @@ -1,113 +0,0 @@ -//! Defines some useful [`MultiRequestRepoDataListener`]s. - -use std::{collections::HashMap, time::Duration}; - -use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressFinish, ProgressStyle}; - -use super::{DoneState, DownloadingState, MultiRequestRepoDataListener, RepoDataRequestState}; -use rattler_conda_types::{Channel, Platform}; - -/// Returns a listener to use with the [`super::MultiRequestRepoDataBuilder`] that will show the -/// progress as several progress bars. -/// -/// ```rust,no_run -/// # use rattler::{repo_data::fetch::{ terminal_progress, MultiRequestRepoDataBuilder}}; -/// # use rattler_conda_types::{Channel, ChannelConfig}; -/// # tokio_test::block_on(async { -/// let _ = MultiRequestRepoDataBuilder::default() -/// .add_channel(Channel::from_str("conda-forge", &ChannelConfig::default()).unwrap()) -/// .set_listener(terminal_progress()) -/// .request() -/// .await; -/// # }); -/// ``` -pub fn terminal_progress() -> MultiRequestRepoDataListener { - let multi_progress = MultiProgress::with_draw_target(ProgressDrawTarget::stderr_with_hz(10)); - let mut progress_bars = HashMap::<(Channel, Platform), ProgressBar>::new(); - - // Construct a closure that captures the above variables. 
This closure will be called multiple - // times during the lifetime of the request to notify of any state changes. The code below - // will update a progressbar to reflect the state changes. - Box::new(move |channel, platform, state| { - // Find the progress bar that is associates with the given channel and platform. Or if no - // such progress bar exists yet, create it. - let progress_bar = - progress_bars - .entry((channel, platform)) - .or_insert_with_key(|(channel, platform)| { - let progress_bar = multi_progress.add( - ProgressBar::new(1) - .with_finish(ProgressFinish::AndLeave) - .with_prefix(format!( - "{}/{}", - channel - .name - .as_ref() - .map(String::from) - .unwrap_or_else(|| channel.canonical_name()), - platform - )) - .with_style(default_progress_style()), - ); - progress_bar.enable_steady_tick(Duration::from_millis(100)); - progress_bar - }); - - match state { - RepoDataRequestState::Pending => {} - RepoDataRequestState::Downloading(DownloadingState { bytes, total }) => { - progress_bar.set_length(total.unwrap_or(bytes) as u64); - progress_bar.set_position(bytes as u64); - } - RepoDataRequestState::Deserializing => { - progress_bar.set_style(deserializing_progress_style()); - progress_bar.set_message("Deserializing..") - } - RepoDataRequestState::Done(DoneState { - cache_miss: changed, - }) => { - progress_bar.set_style(finished_progress_style()); - if changed { - progress_bar.set_message("Done!"); - } else { - progress_bar.set_message("No changes!"); - } - progress_bar.finish() - } - RepoDataRequestState::Error(_) => { - progress_bar.set_style(errored_progress_style()); - progress_bar.finish_with_message("Error"); - } - } - }) -} - -/// Returns the style to use for a progressbar that is currently in progress. -fn default_progress_style() -> ProgressStyle { - ProgressStyle::default_bar() - .template("{spinner:.green} {prefix:20!} [{elapsed_precise}] [{bar:.bright.yellow/dim.white}] {bytes:>8} @ {bytes_per_sec:8}").unwrap() - .progress_chars("━━╾─") -} - -/// Returns the style to use for a progressbar that is in Deserializing state. -fn deserializing_progress_style() -> ProgressStyle { - ProgressStyle::default_bar() - .template("{spinner:.green} {prefix:20!} [{elapsed_precise}] [{bar:.bright.green/dim.white}] {wide_msg}").unwrap() - .progress_chars("━━╾─") -} - -/// Returns the style to use for a progressbar that is finished. -fn finished_progress_style() -> ProgressStyle { - ProgressStyle::default_bar() - .template(" {prefix:20!} [{elapsed_precise}] {msg:.bold}") - .unwrap() - .progress_chars("━━╾─") -} - -/// Returns the style to use for a progressbar that is in error state. -fn errored_progress_style() -> ProgressStyle { - ProgressStyle::default_bar() - .template(" {prefix:20!} [{elapsed_precise}] {msg:.bold.red}") - .unwrap() - .progress_chars("━━╾─") -} diff --git a/crates/rattler/src/repo_data/fetch/request/file.rs b/crates/rattler/src/repo_data/fetch/request/file.rs deleted file mode 100644 index 553c4c771..000000000 --- a/crates/rattler/src/repo_data/fetch/request/file.rs +++ /dev/null @@ -1,55 +0,0 @@ -//! Defines the [`fetch_repodata`] function which reads repodata information from disk. - -use crate::repo_data::fetch::{DoneState, RepoDataRequestState, RequestRepoDataError}; -use rattler_conda_types::RepoData; -use std::{ - fs::OpenOptions, - io::{BufReader, Read}, - path::Path, -}; - -/// Read [`RepoData`] from disk. No caching is performed since the data already resides on disk -/// anyway. 
-/// -/// The `listener` parameter allows following the progress of the request through its various -/// stages. See [`RepoDataRequestState`] for the various stages a request can go through. As reading -/// repodata can take several seconds the `listener` can be used to show some visual feedback to the -/// user. -pub async fn fetch_repodata( - path: &Path, - listener: &mut impl FnMut(RepoDataRequestState), -) -> Result<(RepoData, DoneState), RequestRepoDataError> { - // Read the entire file to memory. This does probably cost a lot more memory, but - // deserialization is much (~10x) faster. Since this might take some time, we run this in a - // in a separate background task to ensure we don't unnecessarily block the current thread. - let path = path.to_owned(); - let bytes = tokio::task::spawn_blocking(move || -> Result, RequestRepoDataError> { - let file = OpenOptions::new().read(true).write(true).open(&path)?; - let mut bytes = Vec::new(); - BufReader::new(file).read_to_end(&mut bytes)?; - Ok(bytes) - }) - .await??; - - // Now that we have all the data in memory we can deserialize the content using `serde`. Since - // repodata information can be quite huge we run the deserialization in a separate background - // task to ensure we don't block the current thread. - listener(RepoDataRequestState::Deserializing); - let repodata = tokio::task::spawn_blocking(move || serde_json::from_slice(&bytes)).await??; - - // No cache is used, so there was definitely a cache miss. - Ok((repodata, DoneState { cache_miss: true })) -} - -#[cfg(test)] -mod test { - use super::fetch_repodata; - use std::path::PathBuf; - - #[tokio::test] - async fn test_fetch_file() { - let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let subdir_path = manifest_dir.join("resources/channels/empty/noarch/repodata.json"); - let _ = fetch_repodata(&subdir_path, &mut |_| {}).await.unwrap(); - } -} diff --git a/crates/rattler/src/repo_data/fetch/request/http.rs b/crates/rattler/src/repo_data/fetch/request/http.rs deleted file mode 100644 index 85882a6c7..000000000 --- a/crates/rattler/src/repo_data/fetch/request/http.rs +++ /dev/null @@ -1,437 +0,0 @@ -//! Defines the [`fetch_repodata`] function which downloads and caches repodata requests over http. - -use std::{ - fs::File, - io::{self, BufReader, BufWriter, ErrorKind, Read, Write}, - path::Path, -}; - -use bytes::Bytes; -use futures::{Stream, TryFutureExt, TryStreamExt}; -use reqwest::header::{ - HeaderMap, HeaderValue, ETAG, IF_MODIFIED_SINCE, IF_NONE_MATCH, LAST_MODIFIED, -}; -use reqwest::StatusCode; -use serde_with::{serde_as, DisplayFromStr}; -use tempfile::NamedTempFile; -use tokio::io::AsyncReadExt; -use tokio_util::io::StreamReader; -use url::Url; - -use crate::{ - repo_data::fetch::{DoneState, DownloadingState, RepoDataRequestState, RequestRepoDataError}, - utils::{url_to_cache_filename, AsyncEncoding, Encoding}, -}; -use rattler_conda_types::RepoData; - -/// Information stored along the repodata json that defines some caching properties. -#[serde_as] -#[derive(serde::Serialize, serde::Deserialize, Debug, Eq, PartialEq, Clone)] -struct RepoDataMetadata { - #[serde(rename = "_url")] - #[serde_as(as = "DisplayFromStr")] - url: Url, - - #[serde(rename = "_etag")] - #[serde(skip_serializing_if = "Option::is_none")] - etag: Option, - - #[serde(rename = "_last_modified")] - #[serde(skip_serializing_if = "Option::is_none")] - last_modified: Option, -} - -/// Downloads the repodata from the specified Url. The Url must point to a "repodata.json" file. 
-/// -/// Requests can be cached by specifying a `cache_dir`. If the cache_dir is specified it will be -/// searched for a valid cache entry. If there is a cache hit, information from it will be send to -/// the remote. Only when there is new information on the server the repodata is downloaded, -/// otherwise it is fetched from the local cache. If no `cache_dir` is specified the repodata is -/// always completely downloaded from the server. -/// -/// The `listener` parameter allows following the progress of the request through its various -/// stages. See [`RepoDataRequestState`] for the various stages a request can go through. As a -/// downloading repodata can take several seconds the `listener` can be used to show some visual -/// feedback to the user. -pub async fn fetch_repodata( - url: Url, - client: reqwest::Client, - cache_dir: Option<&Path>, - listener: &mut impl FnMut(RepoDataRequestState), -) -> Result<(RepoData, DoneState), RequestRepoDataError> { - // If a cache directory has been set for this this request try looking up a cached entry and - // read the metadata from it. If any error occurs during the loading of the cache we simply - // ignore it and continue without a cache. - let (metadata, cache_data) = if let Some(cache_dir) = cache_dir { - let cache_path = cache_dir - .join(url_to_cache_filename(&url)) - .with_extension("json"); - match read_cache_file(&cache_path) { - Ok((metadata, cache_data)) => (Some(metadata), Some(cache_data)), - _ => (None, None), - } - } else { - (None, None) - }; - - let mut headers = HeaderMap::default(); - - // We can handle g-zip encoding which is often used. We could also set this option on the - // client, but that will disable all download progress messages by `reqwest` because the - // gzipped data is decoded on the fly and the size of the decompressed body is unknown. - // However, we don't really care about the decompressed size but rather we'd like to know - // the number of raw bytes that are actually downloaded. - // - // To do this we manually set the request header to accept gzip encoding and we use the - // [`AsyncEncoding`] trait to perform the decoding on the fly. - headers.insert( - reqwest::header::ACCEPT_ENCODING, - HeaderValue::from_static("gzip"), - ); - - // Add headers that provide our caching behavior. We record the ETag that was previously send by - // the server as well as the last-modified header. - if let Some(metadata) = metadata { - if metadata.url == url { - if let Some(etag) = metadata - .etag - .and_then(|etag| HeaderValue::from_str(&etag).ok()) - { - headers.insert(IF_NONE_MATCH, etag); - } - if let Some(last_modified) = metadata - .last_modified - .and_then(|etag| HeaderValue::from_str(&etag).ok()) - { - headers.insert(IF_MODIFIED_SINCE, last_modified); - } - } - } - - // Construct a request to the server and dispatch it. - let response = client - .get(url.clone()) - .headers(headers) - .send() - .await? - .error_for_status()?; - - // If the server replied with a NOT_MODIFIED status it means that the ETag or the last modified - // date we send along actually matches whats already on the server or the contents didnt change - // since the last time we queried the data. This means we can use the cached data. - if response.status() == StatusCode::NOT_MODIFIED { - // Now that we have all the data in memory we can deserialize the content using `serde`. - // Since repodata information can be quite huge we run the deserialization in a separate - // background task to ensure we don't block the current thread. 
- listener(RepoDataRequestState::Deserializing); - let repodata = tokio::task::spawn_blocking(move || { - serde_json::from_slice(cache_data.unwrap().as_slice()) - }) - .await??; - return Ok((repodata, DoneState { cache_miss: false })); - } - - // Determine the length of the response in bytes and notify the listener that a download is - // starting. The response may be compressed. Decompression happens below. - let content_size = response.content_length().map(|len| len as usize); - listener( - DownloadingState { - bytes: 0, - total: content_size, - } - .into(), - ); - - // Get the ETag from the response (if any). This can be used to cache the result during a next - // request. - let etag = response - .headers() - .get(ETAG) - .and_then(|header| header.to_str().ok()) - .map(ToOwned::to_owned); - - // Get the last modified time. This can also be used to cache the result during a next request. - let last_modified = response - .headers() - .get(LAST_MODIFIED) - .and_then(|header| header.to_str().ok()) - .map(ToOwned::to_owned); - - // Get the request as a stream of bytes. Download progress is added through the - // [`add_download_progress_listener`] function, and decompression happens through the - // [`AsyncEncoding::decode`] function. - let encoding = Encoding::from(&response); - let bytes_stream = - add_download_progress_listener(response.bytes_stream(), listener, content_size); - let mut decoded_byte_stream = - StreamReader::new(bytes_stream.map_err(|e| io::Error::new(ErrorKind::Other, e))) - .decode(encoding); - - // The above code didn't actually perform any downloading. This code allocates memory to read - // the downloaded information to. The [`AsyncReadExt::read_to_end`] function than actually - // downloads all the bytes. The bytes are decompressed on the fly. - // - // By now, we know that the data we read from cache is out of date, so we can reuse the memory - // allocated for it, although we do clear it out first. If we dont have any pre-allocated cache - // data, we allocate a new block of memory. - // - // We don't know what the decompressed size of the bytes will be but a good guess is simply the - // size of the response body. If we don't know the size of the body we start with 1MB. - let mut data = cache_data - .map(|mut data| { - data.clear(); - data - }) - .unwrap_or_else(|| Vec::with_capacity(content_size.unwrap_or(1_073_741_824) as usize)); - decoded_byte_stream.read_to_end(&mut data).await?; - let bytes = Bytes::from(data); - - // Explicitly drop the byte stream, this is required to ensure that we can safely use the - // mutable listener that was captured by the download progress. - drop(decoded_byte_stream); - - // If there is a cache directory write to the cache - let caching_future = cache_repodata_response( - cache_dir, - RepoDataMetadata { - url, - etag, - last_modified, - }, - bytes.clone(), - ); - - // Now that we have all the data in memory we can deserialize the content using `serde`. Since - // repodata information can be quite huge we run the deserialization in a separate background - // task to ensure we don't block the current thread. - listener(RepoDataRequestState::Deserializing); - let deserializing_future = tokio::task::spawn_blocking(move || serde_json::from_slice(&bytes)) - .map_err(RequestRepoDataError::from) - .and_then(|serde_result| async { serde_result.map_err(RequestRepoDataError::from) }); - - // Await the result of caching and deserializing. This either returns immediately if any error - // occurs or until both futures complete successfully. 
- let (_, repodata) = tokio::try_join!(caching_future, deserializing_future)?; - - // If we get here, we have successfully downloaded (and potentially cached) the complete - // repodata from the server. - Ok((repodata, DoneState { cache_miss: true })) -} - -/// Called to asynchronously cache the response from a HTTP request to the specified cache -/// directory. If the cache directory is `None` nothing happens. -async fn cache_repodata_response( - cache_dir: Option<&Path>, - metadata: RepoDataMetadata, - bytes: Bytes, -) -> Result<(), RequestRepoDataError> { - // Early out if the cache directory is empty - let cache_dir = if let Some(cache_dir) = cache_dir { - cache_dir.to_owned() - } else { - return Ok(()); - }; - - // File system operations can be blocking which is why we do it on a separate thread through the - // call to `spawn_blocking`. This ensures that any blocking operations are not run on the main - // task. - tokio::task::spawn_blocking(move || { - std::fs::create_dir_all(&cache_dir)?; - let cache_path = cache_dir - .join(url_to_cache_filename(&metadata.url)) - .with_extension("json"); - let cache_file = create_cache_file(metadata, &bytes)?; - cache_file.persist(&cache_path)?; - Ok(()) - }) - .await? -} - -/// Writes the bytes encoded as JSON object to a file together with the specified metadata. -/// -/// This function concatenates the metadata json and the `raw_bytes` together by removing the -/// trailing `}` of the metadata json, adding a `,` and removing the preceding `{` from the raw -/// bytes. If any of these characters cannot be located in the data the function panics. -/// -/// On success a [`NamedTempFile`] is returned which contains the resulting json. Its up to the -/// caller to either persist this file. See [`NamedTempFile::persist`] for more information. -fn create_cache_file(metadata: RepoDataMetadata, raw_bytes: &[u8]) -> io::Result { - // Convert the metadata to json - let metadata_json = - serde_json::to_string(&metadata).expect("converting metadata to json shouldn't fail"); - - // Open the cache file. - let mut temp_file = NamedTempFile::new()?; - let mut writer = BufWriter::new(temp_file.as_file_mut()); - - // Strip the trailing closing '}' so we can append the rest of the json. - let stripped_metadata_json = metadata_json - .strip_suffix('}') - .expect("expected metadata to end with a '}'"); - - // Strip the preceding opening '{' from the raw data. - let stripped_raw_bytes = raw_bytes - .strip_prefix(b"{") - .expect("expected the repodata to be preceded by an opening '{'"); - - // Write the contents of the metadata, followed by the contents of the raw bytes. - writer.write_all(stripped_metadata_json.as_bytes())?; - writer.write_all(",".as_bytes())?; - writer.write_all(stripped_raw_bytes)?; - - // Drop the writer so we can return the temp file - drop(writer); - - Ok(temp_file) -} - -/// Reads a cache file and return the contents of it as well as the metadata read from the bytes. -/// -/// A repodata cache file contains the original json read from the remote as well as extra -/// information called metadata (see [`RepoDatametadata`]) which is injected after the data is -/// received from the remote. The metadata can be used to determine if the data stored in the cache -/// is actually current and doesnt need to be updated. 
-fn read_cache_file(cache_path: &Path) -> anyhow::Result<(RepoDataMetadata, Vec)> { - // Read the contents of the entire cache file to memory - let mut reader = BufReader::new(File::open(cache_path)?); - let mut cache_data = Vec::new(); - reader.read_to_end(&mut cache_data)?; - - // Parse the metadata from the data - let metadata: RepoDataMetadata = serde_json::from_slice(&cache_data)?; - - Ok((metadata, cache_data)) -} - -/// Modifies the input stream to emit download information to the specified listener on the fly. -fn add_download_progress_listener<'s, E>( - stream: impl Stream> + 's, - listener: &'s mut impl FnMut(RepoDataRequestState), - content_length: Option, -) -> impl Stream> + 's { - let mut bytes_downloaded = 0; - stream.inspect_ok(move |bytes| { - bytes_downloaded += bytes.len(); - listener( - DownloadingState { - bytes: bytes_downloaded, - total: content_length, - } - .into(), - ); - }) -} - -#[cfg(test)] -mod test { - use std::fs::File; - use std::io::BufReader; - use std::path::PathBuf; - use std::str::FromStr; - use tempfile::TempDir; - use url::Url; - - use super::{create_cache_file, fetch_repodata, read_cache_file, RepoDataMetadata}; - use crate::repo_data::fetch::request::REPODATA_CHANNEL_PATH; - use crate::utils::simple_channel_server::SimpleChannelServer; - use rattler_conda_types::{Channel, ChannelConfig, Platform}; - - #[tokio::test] - async fn test_fetch_http() { - let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let channel_path = manifest_dir.join("resources/channels/empty"); - - let server = SimpleChannelServer::new(channel_path); - let url = server.url().to_string(); - let channel = Channel::from_str(url, &ChannelConfig::default()).unwrap(); - - let _result = fetch_repodata( - channel - .platform_url(Platform::NoArch) - .join(REPODATA_CHANNEL_PATH) - .unwrap(), - reqwest::Client::default(), - None, - &mut |_| {}, - ) - .await - .unwrap(); - } - - #[tokio::test] - async fn test_http_fetch_cache() { - let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let channel_path = manifest_dir.join("resources/channels/empty"); - - let server = SimpleChannelServer::new(channel_path); - let url = server.url().to_string(); - let channel = Channel::from_str(url, &ChannelConfig::default()).unwrap(); - - // Create a temporary directory to store the cache in - let cache_dir = TempDir::new().unwrap(); - - // Fetch the repodata from the server - let (repodata, done_state) = fetch_repodata( - channel - .platform_url(Platform::NoArch) - .join(REPODATA_CHANNEL_PATH) - .unwrap(), - reqwest::Client::default(), - Some(cache_dir.path()), - &mut |_| {}, - ) - .await - .unwrap(); - assert!(done_state.cache_miss); - - // Fetch the repodata again, and check that the result has been cached - let (repodata_with_cache, cached_done_state) = fetch_repodata( - channel - .platform_url(Platform::NoArch) - .join(REPODATA_CHANNEL_PATH) - .unwrap(), - reqwest::Client::default(), - Some(cache_dir.path()), - &mut |_| {}, - ) - .await - .unwrap(); - - assert!(!cached_done_state.cache_miss); - assert_eq!(repodata, repodata_with_cache); - } - - #[test] - fn test_cache_in_cache_out() { - #[derive(Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq)] - struct Foo { - data: String, - } - - let data = Foo { - data: String::from("Hello, world!"), - }; - let data_bytes = serde_json::to_vec(&data).unwrap(); - - let metadata = RepoDataMetadata { - url: Url::from_str("https://google.com").unwrap(), - etag: Some(String::from("THIS IS NOT REALLY AN ETAG")), - last_modified: 
Some(String::from("this is a last modified data or something")), - }; - - // Create a cached file - let cache_file = create_cache_file(metadata.clone(), &data_bytes).unwrap(); - - // The cache file still contains valid json - let _: serde_json::Value = - serde_json::from_reader(BufReader::new(File::open(cache_file.path()).unwrap())) - .expect("cache file doesnt contain valid json"); - - // Read the cached file again - let (result_metadata, result_bytes) = read_cache_file(cache_file.path()).unwrap(); - - // See if the data from the cache matches that what we wrote to it. - assert_eq!(data, serde_json::from_slice::(&result_bytes).unwrap()); - assert_eq!(metadata, result_metadata); - } -} diff --git a/crates/rattler/src/repo_data/fetch/request/mod.rs b/crates/rattler/src/repo_data/fetch/request/mod.rs deleted file mode 100644 index b7111a0ad..000000000 --- a/crates/rattler/src/repo_data/fetch/request/mod.rs +++ /dev/null @@ -1,252 +0,0 @@ -//! Defines a builder struct ([`RequestRepoDataBuilder`]) to construct a request to download channel -//! [`RepoData`]. The request allows for all known types of source and provides adequate local -//! caching. -//! -//! The `RequestRepoDataBuilder` only fetches a single repodata source see -//! [`super::MultiRequestRepoDataBuilder`] for the ability to download from multiple sources in -//! parallel. - -mod file; -mod http; - -use crate::utils::default_cache_dir; -use rattler_conda_types::{Channel, Platform, RepoData}; -use std::{io, path::PathBuf}; -use tempfile::PersistError; -use tokio::task::JoinError; - -const REPODATA_CHANNEL_PATH: &str = "repodata.json"; - -/// An error that may occur when trying the fetch repository data. -#[derive(Debug, thiserror::Error)] -pub enum RequestRepoDataError { - #[error("error deserializing repository data: {0}")] - DeserializeError(#[from] serde_json::Error), - - #[error("error downloading data: {0}")] - TransportError(#[from] reqwest::Error), - - #[error("{0}")] - IoError(#[from] io::Error), - - #[error("unsupported scheme'")] - UnsupportedScheme, - - #[error("unable to persist temporary file: {0}")] - PersistError(#[from] PersistError), - - #[error("invalid path")] - InvalidPath, - - #[error("the operation was cancelled")] - Cancelled, -} - -impl From for RequestRepoDataError { - fn from(err: JoinError) -> Self { - match err.try_into_panic() { - Ok(panic) => std::panic::resume_unwind(panic), - Err(_) => RequestRepoDataError::Cancelled, - } - } -} - -/// When a request is processed it goes through several stages, this enum list those stages in -/// order. -#[derive(Debug, Clone)] -pub enum RepoDataRequestState { - /// The initial state - Pending, - - /// The request is downloading from a remote server - Downloading(DownloadingState), - - /// The request is being deserialized - Deserializing, - - /// The request has finished processing - Done(DoneState), - - /// An error has occurred during downloading - Error(String), -} - -/// State information of a request when the information is being downloaded. -#[derive(Debug, Clone)] -pub struct DownloadingState { - /// The number of bytes downloaded - pub bytes: usize, - - /// The total number of bytes to download. `None` if the total size is unknown. This can happen - /// if the server does not supply a `Content-Length` header. - pub total: Option, -} - -impl From for RepoDataRequestState { - fn from(state: DownloadingState) -> Self { - RepoDataRequestState::Downloading(state) - } -} - -/// State information of a request when the request has finished. 
-#[derive(Debug, Clone)] -pub struct DoneState { - /// True if the data was fetched straight from the source and didn't come a cache. - pub cache_miss: bool, -} - -impl From for RepoDataRequestState { - fn from(state: DoneState) -> Self { - RepoDataRequestState::Done(state) - } -} - -/// The `RequestRepoDataBuilder` struct allows downloading of repodata from various sources and with -/// proper caching. Repodata fetch can become complex due to the size of some of the repodata. -/// Especially downloading only changes required to update a cached version can become quite -/// complex. This struct handles all the intricacies required to efficiently fetch up-to-date -/// repodata. -/// -/// This struct uses a builder pattern which allows a user to setup certain settings other than the -/// default before actually performing the fetch. -/// -/// In its simplest form you simply construct a `RequestRepoDataBuilder` through the -/// [`RequestRepoDataBuilder::new`] function, and asynchronously fetch the repodata with -/// [`RequestRepoDataBuilder::request`]. -/// -/// ```rust,no_run -/// # use std::path::PathBuf; -/// # use rattler::{repo_data::fetch::RequestRepoDataBuilder}; -/// # use rattler_conda_types::{Channel, Platform, ChannelConfig}; -/// # tokio_test::block_on(async { -/// let channel = Channel::from_str("conda-forge", &ChannelConfig::default()).unwrap(); -/// let _repo_data = RequestRepoDataBuilder::new(channel, Platform::NoArch) -/// .request() -/// .await -/// .unwrap(); -/// # }) -/// ``` -/// -/// The `RequestRepoDataBuilder` only fetches a single repodata source, see -/// [`super::MultiRequestRepoDataBuilder`] for the ability to download from multiple sources in -/// parallel. -pub struct RequestRepoDataBuilder { - /// The channel to download from - pub(super) channel: Channel, - - /// The platform within the channel (also sometimes called the subdir) - pub(super) platform: Platform, - - /// The directory to store the cache - pub(super) cache_dir: Option, - - /// An optional [`reqwest::Client`] that is used to perform the request. When performing - /// multiple requests its useful to reuse a single client. - pub(super) http_client: Option, - - /// An optional listener - pub(super) listener: Option, -} - -/// A listener function that is called when a state change of the request occurred. -pub type RequestRepoDataListener = Box; - -impl RequestRepoDataBuilder { - /// Constructs a new builder to request repodata for the given channel and platform. - pub fn new(channel: Channel, platform: Platform) -> Self { - Self { - channel, - platform, - cache_dir: None, - http_client: None, - listener: None, - } - } - - /// Sets a default cache directory that will be used for caching requests. - pub fn set_default_cache_dir(self) -> anyhow::Result { - let cache_dir = default_cache_dir()?; - std::fs::create_dir_all(&cache_dir)?; - Ok(self.set_cache_dir(cache_dir)) - } - - /// Sets the directory that will be used for caching requests. - pub fn set_cache_dir(mut self, cache_dir: impl Into) -> Self { - self.cache_dir = Some(cache_dir.into()); - self - } - - /// Sets the [`reqwest::Client`] that is used to perform HTTP requests. If this is not called - /// a new client is created for each request. When performing multiple requests its more - /// efficient to reuse a single client across multiple requests. - pub fn set_http_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self - } - - /// Adds a state listener to the builder. 
This is invoked every time the state of the request - /// changes. See the [`RepoDataRequestState`] for more information. - pub fn set_listener(mut self, listener: RequestRepoDataListener) -> Self { - self.listener = Some(listener); - self - } - - /// Consumes self and starts an async request to fetch the repodata. - pub async fn request(self) -> Result { - // Get the url to the subdirectory index. Note that the subdirectory is the platform name. - let platform_url = self - .channel - .platform_url(self.platform) - .join(REPODATA_CHANNEL_PATH) - .expect("repodata.json is a valid json path"); - - // Construct a new listener function that wraps the optional listener. This allows us to - // call the listener from anywhere without having to check if there actually is a listener. - let mut listener = self.listener; - let mut listener = move |state| { - if let Some(listener) = listener.as_deref_mut() { - listener(state) - } - }; - - // Perform the actual request. This is done in an anonymous function to ensure that any - // try's do not propagate straight to the outer function. We catch any errors and notify - // the listener. - let borrowed_listener = &mut listener; - let result = (move || async move { - match platform_url.scheme() { - "https" | "http" => { - // Download the repodata from the subdirectory url - let http_client = self.http_client.unwrap_or_else(reqwest::Client::new); - http::fetch_repodata( - platform_url, - http_client, - self.cache_dir.as_deref(), - borrowed_listener, - ) - .await - } - "file" => { - let path = platform_url - .to_file_path() - .map_err(|_| RequestRepoDataError::InvalidPath)?; - file::fetch_repodata(&path, borrowed_listener).await - } - _ => Err(RequestRepoDataError::UnsupportedScheme), - } - })() - .await; - - // Update the listener accordingly - match result { - Ok((repodata, done_state)) => { - listener(done_state.into()); - Ok(repodata) - } - Err(e) => { - listener(RepoDataRequestState::Error(format!("{}", &e))); - Err(e) - } - } - } -} diff --git a/crates/rattler/src/repo_data/mod.rs b/crates/rattler/src/repo_data/mod.rs deleted file mode 100644 index 0e40308b0..000000000 --- a/crates/rattler/src/repo_data/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod fetch; diff --git a/crates/rattler/src/solver/libsolv/mod.rs b/crates/rattler/src/solver/libsolv/mod.rs index c64c3d5ed..5d0edbba4 100644 --- a/crates/rattler/src/solver/libsolv/mod.rs +++ b/crates/rattler/src/solver/libsolv/mod.rs @@ -34,7 +34,7 @@ mod test { format!( "{}/{}", env!("CARGO_MANIFEST_DIR"), - "resources/channels/conda-forge/linux-64/repodata.json" + "../../test-data/channels/conda-forge/linux-64/repodata.json" ) } @@ -42,7 +42,7 @@ mod test { format!( "{}/{}", env!("CARGO_MANIFEST_DIR"), - "resources/channels/conda-forge/noarch/repodata.json" + "../../test-data/channels/conda-forge/noarch/repodata.json" ) } diff --git a/crates/rattler/src/utils/hash.rs b/crates/rattler/src/utils/hash.rs deleted file mode 100644 index 191900494..000000000 --- a/crates/rattler/src/utils/hash.rs +++ /dev/null @@ -1,89 +0,0 @@ -use digest::{Digest, Output}; -use sha2::Sha256; -use std::fs::File; -use std::io::Write; -use std::path::Path; - -/// Compute the SHA256 hash of the file at the specified location. 
-pub fn compute_file_sha256(path: &Path) -> Result, std::io::Error> { - // Open the file for reading - let mut file = File::open(path)?; - - // Determine the hash of the file on disk - let mut hasher = Sha256::new(); - std::io::copy(&mut file, &mut hasher)?; - - Ok(hasher.finalize()) -} - -/// Parses a SHA256 hex string to a digest. -pub fn parse_sha256_from_hex(str: &str) -> Option> { - let mut sha256 = >::default(); - match hex::decode_to_slice(str, &mut sha256) { - Ok(_) => Some(sha256), - Err(_) => None, - } -} - -/// A simple object that provides a [`Write`] implementation that also immediately hashes the bytes -/// written to it. -pub struct HashingWriter { - writer: W, - hasher: D, -} - -pub type Sha256HashingWriter = HashingWriter; - -impl HashingWriter { - /// Constructs a new instance from a writer and a new (empty) hasher. - pub fn new(writer: W) -> Self { - Self { - writer, - hasher: Default::default(), - } - } -} - -impl HashingWriter { - pub fn finalize(self) -> (W, Output) { - (self.writer, self.hasher.finalize()) - } -} - -impl Write for HashingWriter { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let bytes = self.writer.write(buf)?; - self.hasher.update(&buf[..bytes]); - Ok(bytes) - } - - fn flush(&mut self) -> std::io::Result<()> { - self.writer.flush() - } -} - -#[cfg(test)] -mod test { - use rstest::rstest; - - #[rstest] - #[case( - "1234567890", - "c775e7b757ede630cd0aa1113bd102661ab38829ca52a6422ab782862f268646" - )] - #[case( - "Hello, world!", - "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3" - )] - fn test_compute_file_sha256(#[case] input: &str, #[case] expected_hash: &str) { - // Write a known value to a temporary file and verify that the compute hash matches what we would - // expect. - - let temp_dir = tempfile::tempdir().unwrap(); - let file_path = temp_dir.path().join("test"); - std::fs::write(&file_path, input).unwrap(); - let hash = super::compute_file_sha256(&file_path).unwrap(); - - assert_eq!(format!("{hash:x}"), expected_hash) - } -} diff --git a/crates/rattler/src/validation.rs b/crates/rattler/src/validation.rs index 85fd768f4..b3d76137d 100644 --- a/crates/rattler/src/validation.rs +++ b/crates/rattler/src/validation.rs @@ -10,8 +10,8 @@ //! `paths.json` file is missing these deprecated files are used instead to reconstruct a //! [`PathsJson`] object. See [`PathsJson::from_deprecated_package_directory`] for more information. -use crate::{utils, utils::parse_sha256_from_hex}; use rattler_conda_types::package::{PackageFile, PathType, PathsEntry, PathsJson}; +use rattler_digest::{compute_file_digest, parse_digest_from_hex}; use std::{ fs::Metadata, io::ErrorKind, @@ -152,10 +152,10 @@ fn validate_package_hard_link_entry( // Check the SHA256 hash of the file if let Some(hash_str) = entry.sha256.as_deref() { // Determine the hash of the file on disk - let hash = utils::compute_file_sha256(&path)?; + let hash = compute_file_digest::(&path)?; // Convert the hash to bytes. 
- let expected_hash = parse_sha256_from_hex(hash_str).ok_or_else(|| { + let expected_hash = parse_digest_from_hex::(hash_str).ok_or_else(|| { PackageEntryValidationError::HashMismatch(hash_str.to_owned(), format!("{:x}", hash)) })?; diff --git a/crates/rattler_digest/Cargo.toml b/crates/rattler_digest/Cargo.toml new file mode 100644 index 000000000..c6df3f78d --- /dev/null +++ b/crates/rattler_digest/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "rattler_digest" +version = "0.1.0" +edition = "2021" +authors = ["Bas Zalmstra "] +description = "An simple crate used by rattler crates to compute different hashes from different sources" +categories = ["conda"] +homepage = "https://github.com/mamba-org/rattler" +repository = "https://github.com/mamba-org/rattler" +license = "BSD-3-Clause" + +[dependencies] +digest = "0.10.6" +tokio = { version = "1.12.0", features = ["io-util"], optional = true } +hex = "0.4.3" + +[features] +tokio = ["dep:tokio"] + +[dev-dependencies] +sha2 = "0.10.6" +rstest = "0.16.0" +tempfile = "3.3.0" +md-5 = "0.10.5" diff --git a/crates/rattler_digest/src/lib.rs b/crates/rattler_digest/src/lib.rs new file mode 100644 index 000000000..607181790 --- /dev/null +++ b/crates/rattler_digest/src/lib.rs @@ -0,0 +1,144 @@ +#![deny(missing_docs)] + +//! A module that provides utility functions for computing hashes using the +//! [RustCrypto/hashes](https://github.com/RustCrypto/hashes) library. +//! +//! This module provides several functions that wrap around the hashing algorithms provided by the +//! RustCrypto library. These functions allow you to easily compute the hash of a file, or a stream +//! of bytes using a variety of hashing algorithms. +//! +//! By utilizing the [`Digest`] trait, any hashing algorithm that implements that trait can be used +//! with the functions provided in this crate. +//! +//! # Examples +//! +//! ```no_run +//! use rattler_digest::{compute_bytes_digest, compute_file_digest}; +//! use sha2::Sha256; +//! use md5::Md5; +//! +//! // Compute the MD5 hash of a string +//! let md5_result = compute_bytes_digest::("Hello, world!"); +//! println!("MD5 hash: {:x}", md5_result); +//! +//! // Compute the SHA256 hash of a file +//! let sha256_result = compute_file_digest::("somefile.txt").unwrap(); +//! println!("SHA256 hash: {:x}", sha256_result); +//! ``` +//! +//! # Available functions +//! +//! - [`compute_file_digest`]: Computes the hash of a file on disk. +//! - [`parse_digest_from_hex`]: Given a hex representation of a digest, parses it to bytes. +//! - [`HashingWriter`]: An object that wraps a writable object and implements [`Write`] and +//! [`::tokio::io::AsyncWrite`]. It forwards the data to the wrapped object but also computes the hash of the +//! content on the fly. +//! +//! For more information on the hashing algorithms provided by the +//! [RustCrypto/hashes](https://github.com/RustCrypto/hashes) library, see the documentation for +//! that library. + +#[cfg(feature = "tokio")] +mod tokio; + +pub use digest; + +use digest::{Digest, Output}; +use std::{fs::File, io::Write, path::Path}; + +/// Compute a hash of the file at the specified location. +pub fn compute_file_digest( + path: impl AsRef, +) -> Result, std::io::Error> { + // Open the file for reading + let mut file = File::open(path)?; + + // Determine the hash of the file on disk + let mut hasher = D::default(); + std::io::copy(&mut file, &mut hasher)?; + + Ok(hasher.finalize()) +} + +/// Compute a hash of the specified bytes. 
+pub fn compute_bytes_digest(path: impl AsRef<[u8]>) -> Output { + let mut hasher = D::default(); + hasher.update(path); + hasher.finalize() +} + +/// Parses a hash hex string to a digest. +pub fn parse_digest_from_hex(str: &str) -> Option> { + let mut hash = >::default(); + match hex::decode_to_slice(str, &mut hash) { + Ok(_) => Some(hash), + Err(_) => None, + } +} + +/// A simple object that provides a [`Write`] implementation that also immediately hashes the bytes +/// written to it. Call [`HashingWriter::finalize`] to retrieve both the original `impl Write` +/// object as well as the hash. +/// +/// If the `tokio` feature is enabled this object also implements [`::tokio::io::AsyncWrite`] which +/// allows you to use it in an async context as well. +pub struct HashingWriter { + writer: W, + hasher: D, +} + +impl HashingWriter { + /// Constructs a new instance from a writer and a new (empty) hasher. + pub fn new(writer: W) -> Self { + Self { + writer, + hasher: Default::default(), + } + } +} + +impl HashingWriter { + /// Consumes this instance and returns the original writer and the hash of all bytes written to + /// this instance. + pub fn finalize(self) -> (W, Output) { + (self.writer, self.hasher.finalize()) + } +} + +impl Write for HashingWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let bytes = self.writer.write(buf)?; + self.hasher.update(&buf[..bytes]); + Ok(bytes) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.writer.flush() + } +} + +#[cfg(test)] +mod test { + use rstest::rstest; + + #[rstest] + #[case( + "1234567890", + "c775e7b757ede630cd0aa1113bd102661ab38829ca52a6422ab782862f268646" + )] + #[case( + "Hello, world!", + "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3" + )] + fn test_compute_file_sha256(#[case] input: &str, #[case] expected_hash: &str) { + // Write a known value to a temporary file and verify that the compute hash matches what we would + // expect. + + let temp_dir = tempfile::tempdir().unwrap(); + let file_path = temp_dir.path().join("test"); + std::fs::write(&file_path, input).unwrap(); + let hash = super::compute_file_digest::(&file_path).unwrap(); + + assert_eq!(format!("{hash:x}"), expected_hash) + } +} diff --git a/crates/rattler_digest/src/tokio.rs b/crates/rattler_digest/src/tokio.rs new file mode 100644 index 000000000..76d0fc82d --- /dev/null +++ b/crates/rattler_digest/src/tokio.rs @@ -0,0 +1,42 @@ +use super::HashingWriter; +use digest::Digest; +use std::{ + io::Error, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::AsyncWrite; + +impl AsyncWrite for HashingWriter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // pin-project the writer + let (writer, hasher) = unsafe { + let this = self.get_unchecked_mut(); + (Pin::new_unchecked(&mut this.writer), &mut this.hasher) + }; + + match writer.poll_write(cx, buf) { + Poll::Ready(Ok(bytes)) => { + hasher.update(&buf[..bytes]); + Poll::Ready(Ok(bytes)) + } + other => other, + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // This is okay because `writer` is pinned when `self` is. + let writer = unsafe { self.map_unchecked_mut(|s| &mut s.writer) }; + writer.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // This is okay because `writer` is pinned when `self` is. 
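To make the `rattler_digest` API introduced above concrete, here is a small usage sketch that is not part of the diff. It only exercises items shown above (`compute_bytes_digest`, `parse_digest_from_hex`, `HashingWriter`) together with the `sha2` crate that the new Cargo.toml lists as a dev-dependency; the expected hex string is the same SHA256 value used in the crate's own test.

```rust
// Not part of the diff: a usage sketch for the rattler_digest API defined above.
use rattler_digest::{compute_bytes_digest, parse_digest_from_hex, HashingWriter};
use sha2::Sha256;
use std::io::Write;

fn main() -> std::io::Result<()> {
    // Hash an in-memory buffer in one call.
    let digest = compute_bytes_digest::<Sha256>(b"Hello, world!");
    let hex = format!("{digest:x}");
    assert_eq!(
        hex,
        "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3"
    );

    // Round-trip the hex representation back into a digest.
    let parsed = parse_digest_from_hex::<Sha256>(&hex).expect("valid hex of the right length");
    assert_eq!(parsed, digest);

    // Hash bytes while forwarding them to any `Write` implementation.
    let mut writer = HashingWriter::<Vec<u8>, Sha256>::new(Vec::new());
    writer.write_all(b"Hello, world!")?;
    let (buffer, streamed) = writer.finalize();
    assert_eq!(&buffer[..], &b"Hello, world!"[..]);
    assert_eq!(streamed, digest);

    Ok(())
}
```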
+ let writer = unsafe { self.map_unchecked_mut(|s| &mut s.writer) }; + writer.poll_flush(cx) + } +} diff --git a/crates/rattler_package_streaming/Cargo.toml b/crates/rattler_package_streaming/Cargo.toml index 462cb526f..d13cec0f9 100644 --- a/crates/rattler_package_streaming/Cargo.toml +++ b/crates/rattler_package_streaming/Cargo.toml @@ -10,7 +10,6 @@ repository = "https://github.com/mamba-org/rattler" license = "BSD-3-Clause" [dependencies] -async-trait = "0.1.59" thiserror = "1.0.37" tar = { version = "0.4.38" } bzip2 = { version = "0.4" } diff --git a/crates/rattler_repodata_gateway/Cargo.toml b/crates/rattler_repodata_gateway/Cargo.toml new file mode 100644 index 000000000..fd2be22e9 --- /dev/null +++ b/crates/rattler_repodata_gateway/Cargo.toml @@ -0,0 +1,46 @@ +[package] +name = "rattler_repodata_gateway" +version = "0.1.0" +edition = "2021" +authors = ["Bas Zalmstra "] +description = "A crate to interact with Conda repodata" +categories = ["conda"] +homepage = "https://github.com/mamba-org/rattler" +repository = "https://github.com/mamba-org/rattler" +license = "BSD-3-Clause" + +[dependencies] +async-compression = { version = "0.3.12", features = ["gzip", "tokio", "bzip2", "zstd"] } +blake2 = "0.10.6" +cache_control = "0.2.0" +chrono = { version = "0.4.23", default-features = false, features = ["std", "serde", "alloc", "clock"] } +humansize = "2.1.3" +futures = "0.3.17" +reqwest = { version = "0.11.6", default-features = false, features = ["stream"] } +tokio-util = { version = "0.7.3", features = ["codec", "io"] } +tempfile = "3.3.0" +tracing = "0.1.29" +thiserror = "1.0.30" +url = { version = "2.2.2", features = ["serde"] } +tokio = { version = "1.12.0", features = ["rt", "io-util"] } +anyhow = "1.0.44" +serde = { version = "1.0.130", features = ["derive"] } +serde_json = { version = "1.0.68" } +pin-project-lite = "0.2.9" +md-5 = "0.10.5" +rattler_digest = { version = "0.1.0", path = "../rattler_digest", features = ["tokio"] } + +[target.'cfg(unix)'.dependencies] +libc = "0.2" + +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.45.0", features = ["Win32_Storage_FileSystem", "Win32_Foundation", "Win32_System_IO"] } + +[dev-dependencies] +hex-literal = "0.3.4" +tower-http = { version = "0.3.5", features = ["fs", "compression-gzip", "trace"] } +tracing-test = { version = "0.2.4" } +insta = { version = "1.16.0", features = ["yaml"] } +axum = "0.6.2" +assert_matches = "1.5.0" +tokio = { version = "1.12.0", features = ["macros"] } diff --git a/crates/rattler_repodata_gateway/src/fetch/cache/cache_headers.rs b/crates/rattler_repodata_gateway/src/fetch/cache/cache_headers.rs new file mode 100644 index 000000000..9e33f3f17 --- /dev/null +++ b/crates/rattler_repodata_gateway/src/fetch/cache/cache_headers.rs @@ -0,0 +1,80 @@ +use reqwest::{ + header, + header::{HeaderMap, HeaderValue}, + Response, +}; +use serde::{Deserialize, Serialize}; + +/// Extracted HTTP response headers that enable caching the repodata.json files. 
+#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CacheHeaders { + /// The ETag HTTP cache header + #[serde(default, skip_serializing_if = "Option::is_none")] + pub etag: Option, + + /// The Last-Modified HTTP cache header + #[serde(default, skip_serializing_if = "Option::is_none", rename = "mod")] + pub last_modified: Option, + + /// The cache control configuration + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cache_control: Option, +} + +impl From<&Response> for CacheHeaders { + fn from(response: &Response) -> Self { + // Get the ETag from the response (if any). This can be used to cache the result during a + // next request. + let etag = response + .headers() + .get(header::ETAG) + .and_then(|header| header.to_str().ok()) + .map(ToOwned::to_owned); + + // Get the last modified time. This can also be used to cache the result during a next + // request. + let last_modified = response + .headers() + .get(header::LAST_MODIFIED) + .and_then(|header| header.to_str().ok()) + .map(ToOwned::to_owned); + + // Get the cache-control headers so we possibly perform local caching. + let cache_control = response + .headers() + .get(header::CACHE_CONTROL) + .and_then(|header| header.to_str().ok()) + .map(ToOwned::to_owned); + + Self { + etag, + last_modified, + cache_control, + } + } +} + +impl CacheHeaders { + /// Adds the headers to the specified request to short-circuit if the content is still up to + /// date. + pub fn add_to_request(&self, headers: &mut HeaderMap) { + // If previously there was an etag header, add the If-None-Match header so the server only sends + // us new data if the etag is not longer valid. + if let Some(etag) = self + .etag + .as_deref() + .and_then(|etag| HeaderValue::from_str(etag).ok()) + { + headers.insert(header::IF_NONE_MATCH, etag); + } + // If a previous request contains a Last-Modified header, add the If-Modified-Since header to let + // the server send us new data if the contents has been modified since that date. + if let Some(last_modified) = self + .last_modified + .as_deref() + .and_then(|last_modifed| HeaderValue::from_str(last_modifed).ok()) + { + headers.insert(header::IF_MODIFIED_SINCE, last_modified); + } + } +} diff --git a/crates/rattler_repodata_gateway/src/fetch/cache/mod.rs b/crates/rattler_repodata_gateway/src/fetch/cache/mod.rs new file mode 100644 index 000000000..23d77816b --- /dev/null +++ b/crates/rattler_repodata_gateway/src/fetch/cache/mod.rs @@ -0,0 +1,169 @@ +mod cache_headers; + +pub use cache_headers::CacheHeaders; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::{fs::File, io::Read, path::Path, str::FromStr, time::SystemTime}; +use url::Url; + +/// Representation of the `.state.json` file alongside a `repodata.json` file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RepoDataState { + /// The URL from where the repodata was downloaded. This is the URL of the `repodata.json`, + /// `repodata.json.zst`, or another variant. This is different from the subdir url which does + /// NOT include the final filename. + pub url: Url, + + /// The HTTP cache headers send along with the last response. + #[serde(flatten)] + pub cache_headers: CacheHeaders, + + /// The timestamp of the repodata.json on disk + #[serde( + deserialize_with = "duration_from_nanos", + serialize_with = "duration_to_nanos", + rename = "mtime_ns" + )] + pub cache_last_modified: SystemTime, + + /// The size of the repodata.json file on disk. 
+ #[serde(rename = "size")] + pub cache_size: u64, + + /// The blake2 hash of the file + #[serde( + default, + skip_serializing_if = "Option::is_none", + deserialize_with = "deserialize_blake2_hash", + serialize_with = "serialize_blake2_hash" + )] + pub blake2_hash: Option>, + + /// Whether or not zst is available for the subdirectory + pub has_zst: Option>, + + /// Whether a bz2 compressed version is available for the subdirectory + pub has_bz2: Option>, + + /// Whether or not JLAP is available for the subdirectory + pub has_jlap: Option>, +} + +impl RepoDataState { + /// Reads and parses a file from disk. + pub fn from_path(path: &Path) -> Result { + let content = { + let mut file = File::open(path)?; + let mut content = Default::default(); + file.read_to_string(&mut content)?; + content + }; + Ok(Self::from_str(&content)?) + } + + /// Save the cache state to the specified file. + pub fn to_path(&self, path: &Path) -> Result<(), std::io::Error> { + let file = File::create(path)?; + Ok(serde_json::to_writer_pretty(file, self)?) + } +} + +impl FromStr for RepoDataState { + type Err = serde_json::Error; + + fn from_str(s: &str) -> Result { + serde_json::from_str(s) + } +} + +/// Represents a value and when the value was last checked. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Expiring { + pub value: T, + + // #[serde(with = "chrono::serde::ts_seconds")] + pub last_checked: chrono::DateTime, +} + +impl Expiring { + pub fn value(&self, expiration: chrono::Duration) -> Option<&T> { + if chrono::Utc::now().signed_duration_since(self.last_checked) >= expiration { + None + } else { + Some(&self.value) + } + } +} + +/// Deserializes a [`SystemTime`] by parsing an integer and converting that as a nanosecond based unix +/// epoch timestamp to a [`SystemTime`]. +fn duration_from_nanos<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + use serde::de::Error; + SystemTime::UNIX_EPOCH + .checked_add(std::time::Duration::from_nanos(Deserialize::deserialize( + deserializer, + )?)) + .ok_or_else(|| D::Error::custom("the time cannot be represented internally")) +} + +/// Serializes a [`SystemTime`] by converting it to a nanosecond based unix epoch timestamp. +fn duration_to_nanos(time: &SystemTime, s: S) -> Result { + use serde::ser::Error; + time.duration_since(SystemTime::UNIX_EPOCH) + .map_err(|_| S::Error::custom("duration cannot be computed for file time"))? + .as_nanos() + .serialize(s) +} + +fn deserialize_blake2_hash<'de, D>( + deserializer: D, +) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + use serde::de::Error; + match Option::<&'de str>::deserialize(deserializer)? 
{ + Some(str) => Ok(Some( + rattler_digest::parse_digest_from_hex::(str) + .ok_or_else(|| D::Error::custom("failed to parse blake2 hash"))?, + )), + None => Ok(None), + } +} + +fn serialize_blake2_hash( + time: &Option>, + s: S, +) -> Result { + match time.as_ref() { + None => s.serialize_none(), + Some(hash) => format!("{:x}", hash).serialize(s), + } +} + +#[cfg(test)] +mod test { + use super::RepoDataState; + use std::str::FromStr; + + #[test] + pub fn test_parse_repo_data_state() { + insta::assert_yaml_snapshot!(RepoDataState::from_str( + r#"{ + "cache_control": "public, max-age=1200", + "etag": "\"bec332621e00fc4ad87ba185171bcf46\"", + "has_zst": { + "last_checked": "2023-02-13T14:08:50Z", + "value": true + }, + "mod": "Mon, 13 Feb 2023 13:49:56 GMT", + "mtime_ns": 1676297333020928000, + "size": 156627374, + "url": "https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst" + }"#, + ) + .unwrap()); + } +} diff --git a/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler__repo_data__cache__test__parse_repo_data_state.snap b/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler__repo_data__cache__test__parse_repo_data_state.snap new file mode 100644 index 000000000..cd1ee9164 --- /dev/null +++ b/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler__repo_data__cache__test__parse_repo_data_state.snap @@ -0,0 +1,16 @@ +--- +source: crates/rattler/src/repo_data/cache/mod.rs +expression: "RepoDataState::from_str(r#\"{\n \"cache_control\": \"public, max-age=1200\",\n \"etag\": \"\\\"bec332621e00fc4ad87ba185171bcf46\\\"\",\n \"has_zst\": {\n \"last_checked\": \"2023-02-13T14:08:50Z\",\n \"value\": true\n },\n \"mod\": \"Mon, 13 Feb 2023 13:49:56 GMT\",\n \"mtime_ns\": 1676297333020928000,\n \"size\": 156627374,\n \"url\": \"https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst\"\n }\"#).unwrap()" +--- +url: "https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst" +etag: "\"bec332621e00fc4ad87ba185171bcf46\"" +mod: "Mon, 13 Feb 2023 13:49:56 GMT" +cache_control: "public, max-age=1200" +mtime_ns: 1676297333020928000 +size: 156627374 +has_zst: + value: true + last_checked: "2023-02-13T14:08:50Z" +has_bz2: ~ +has_jlap: ~ + diff --git a/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler_repodata_cache__cache__test__parse_repo_data_state.snap b/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler_repodata_cache__cache__test__parse_repo_data_state.snap new file mode 100644 index 000000000..3baf52d5e --- /dev/null +++ b/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler_repodata_cache__cache__test__parse_repo_data_state.snap @@ -0,0 +1,16 @@ +--- +source: crates/rattler_repodata_cache/src/cache/mod.rs +expression: "RepoDataState::from_str(r#\"{\n \"cache_control\": \"public, max-age=1200\",\n \"etag\": \"\\\"bec332621e00fc4ad87ba185171bcf46\\\"\",\n \"has_zst\": {\n \"last_checked\": \"2023-02-13T14:08:50Z\",\n \"value\": true\n },\n \"mod\": \"Mon, 13 Feb 2023 13:49:56 GMT\",\n \"mtime_ns\": 1676297333020928000,\n \"size\": 156627374,\n \"url\": \"https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst\"\n }\"#).unwrap()" +--- +url: "https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst" +etag: "\"bec332621e00fc4ad87ba185171bcf46\"" +mod: "Mon, 13 Feb 2023 13:49:56 GMT" +cache_control: "public, max-age=1200" +mtime_ns: 1676297333020928000 +size: 156627374 +has_zst: + value: true + last_checked: "2023-02-13T14:08:50Z" +has_bz2: ~ +has_jlap: ~ + diff --git 
a/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler_repodata_gateway__fetch__cache__test__parse_repo_data_state.snap b/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler_repodata_gateway__fetch__cache__test__parse_repo_data_state.snap new file mode 100644 index 000000000..12c6ecb23 --- /dev/null +++ b/crates/rattler_repodata_gateway/src/fetch/cache/snapshots/rattler_repodata_gateway__fetch__cache__test__parse_repo_data_state.snap @@ -0,0 +1,16 @@ +--- +source: crates/rattler_repodata_gateway/src/fetch/cache/mod.rs +expression: "RepoDataState::from_str(r#\"{\n \"cache_control\": \"public, max-age=1200\",\n \"etag\": \"\\\"bec332621e00fc4ad87ba185171bcf46\\\"\",\n \"has_zst\": {\n \"last_checked\": \"2023-02-13T14:08:50Z\",\n \"value\": true\n },\n \"mod\": \"Mon, 13 Feb 2023 13:49:56 GMT\",\n \"mtime_ns\": 1676297333020928000,\n \"size\": 156627374,\n \"url\": \"https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst\"\n }\"#).unwrap()" +--- +url: "https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst" +etag: "\"bec332621e00fc4ad87ba185171bcf46\"" +mod: "Mon, 13 Feb 2023 13:49:56 GMT" +cache_control: "public, max-age=1200" +mtime_ns: 1676297333020928000 +size: 156627374 +has_zst: + value: true + last_checked: "2023-02-13T14:08:50Z" +has_bz2: ~ +has_jlap: ~ + diff --git a/crates/rattler_repodata_gateway/src/fetch/mod.rs b/crates/rattler_repodata_gateway/src/fetch/mod.rs new file mode 100644 index 000000000..69ab8f9a6 --- /dev/null +++ b/crates/rattler_repodata_gateway/src/fetch/mod.rs @@ -0,0 +1,1124 @@ +//! This module provides functionality to download and cache `repodata.json` from a remote location. + +use crate::utils::{AsyncEncoding, Encoding, LockedFile}; +use cache::{CacheHeaders, Expiring, RepoDataState}; +use cache_control::{Cachability, CacheControl}; +use futures::{future::ready, FutureExt, TryStreamExt}; +use humansize::{SizeFormatter, DECIMAL}; +use rattler_digest::{compute_file_digest, HashingWriter}; +use reqwest::{ + header::{HeaderMap, HeaderValue}, + Client, Response, StatusCode, +}; +use std::{ + io::ErrorKind, + path::{Path, PathBuf}, + time::SystemTime, +}; +use tempfile::NamedTempFile; +use tokio_util::io::StreamReader; +use tracing::instrument; +use url::Url; + +mod cache; + +#[allow(missing_docs)] +#[derive(Debug, thiserror::Error)] +pub enum FetchRepoDataError { + #[error("failed to acquire a lock on the repodata cache")] + FailedToAcquireLock(#[source] anyhow::Error), + + #[error(transparent)] + HttpError(#[from] reqwest::Error), + + #[error(transparent)] + FailedToDownloadRepoData(std::io::Error), + + #[error("failed to create temporary file for repodata.json")] + FailedToCreateTemporaryFile(#[source] std::io::Error), + + #[error("failed to persist temporary repodata.json file")] + FailedToPersistTemporaryFile(#[from] tempfile::PersistError), + + #[error("failed to get metadata from repodata.json file")] + FailedToGetMetadata(#[source] std::io::Error), + + #[error("failed to write cache state")] + FailedToWriteCacheState(#[source] std::io::Error), + + #[error("there is no cache available")] + NoCacheAvailable, + + #[error("the operation was cancelled")] + Cancelled, +} + +impl From for FetchRepoDataError { + fn from(err: tokio::task::JoinError) -> Self { + // Rethrow any panic + if let Ok(panic) = err.try_into_panic() { + std::panic::resume_unwind(panic); + } + + // Otherwise it the operation has been cancelled + FetchRepoDataError::Cancelled + } +} + +/// Defines how to use the repodata cache. 
+#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum CacheAction { + /// Use the cache if its up to date or fetch from the URL if there is no valid cached value. + CacheOrFetch, + + /// Only use the cache, but error out if the cache is not up to date + UseCacheOnly, + + /// Only use the cache, ignore whether or not it is up to date. + ForceCacheOnly, + + /// Do not use the cache even if there is an up to date entry. + NoCache, +} + +impl Default for CacheAction { + fn default() -> Self { + CacheAction::CacheOrFetch + } +} + +/// Additional knobs that allow you to tweak the behavior of [`fetch_repo_data`]. +#[derive(Default)] +pub struct FetchRepoDataOptions { + /// How to use the cache. By default it will cache and reuse downloaded repodata.json (if the + /// server allows it). + pub cache_action: CacheAction, + + /// A function that is called during downloading of the repodata.json to report progress. + pub download_progress: Option>, +} + +/// A struct that provides information about download progress. +#[derive(Debug, Clone)] +pub struct DownloadProgress { + /// The number of bytes already downloaded + pub bytes: u64, + + /// The total number of bytes to download. Or `None` if this is not known. This can happen + /// if the server does not supply a `Content-Length` header. + pub total: Option, +} + +/// The result of [`fetch_repo_data`]. +#[derive(Debug)] +pub struct CachedRepoData { + /// A lockfile that guards access to any of the repodata.json file or its cache. + pub lock_file: LockedFile, + + /// The path to the uncompressed repodata.json file. + pub repo_data_json_path: PathBuf, + + /// The cache data. + pub cache_state: RepoDataState, + + /// How the cache was used for this request. + pub cache_result: CacheResult, +} + +/// Indicates whether or not the repodata.json cache was up-to-date or not. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CacheResult { + /// The cache was hit, the data on disk was already valid. + CacheHit, + + /// The cache was hit, we did have to check with the server, but no data was downloaded. + CacheHitAfterFetch, + + /// The cache was present but it was outdated. + CacheOutdated, + + /// There was no cache available + CacheNotPresent, +} + +/// Fetch the repodata.json file for the given subdirectory. The result is cached on disk using the +/// HTTP cache headers returned from the server. +/// +/// The successful result of this function also returns a lockfile which ensures that both the state +/// and the repodata that is pointed to remain in sync. However, not releasing the lockfile (by +/// dropping it) could block other threads and processes, it is therefor advices to release it as +/// quickly as possible. +/// +/// This method implements several different methods to download the repodata.json file from the +/// remote: +/// +/// * If a `repodata.json.zst` file is available in the same directory that file is downloaded +/// and decompressed. +/// * If a `repodata.json.bz2` file is available in the same directory that file is downloaded +/// and decompressed. +/// * Otherwise the regular `repodata.json` file is downloaded. +/// +/// The checks to see if a `.zst` and/or `.bz2` file exist are performed by doing a HEAD request to +/// the respective URLs. The result of these are cached. 
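The doc comment above describes the download and caching strategy; as a companion, here is a hypothetical caller (not part of the diff) showing how `FetchRepoDataOptions`, the progress callback, and the returned `CachedRepoData` fit together. The subdir URL and the helper name are made up for illustration.

```rust
// Not part of the diff: a hypothetical caller of `fetch_repo_data` as defined below.
use rattler_repodata_gateway::fetch::{
    fetch_repo_data, DownloadProgress, FetchRepoDataError, FetchRepoDataOptions,
};
use reqwest::Client;
use std::path::Path;
use url::Url;

async fn refresh_noarch(cache_dir: &Path) -> Result<(), FetchRepoDataError> {
    let subdir_url = Url::parse("https://conda.anaconda.org/conda-forge/noarch/")
        .expect("hard-coded URL is valid");

    let options = FetchRepoDataOptions {
        // Report raw (possibly compressed) bytes as they arrive over the network.
        download_progress: Some(Box::new(|progress: DownloadProgress| {
            println!("downloaded {} of {:?} bytes", progress.bytes, progress.total);
        })),
        ..Default::default()
    };

    let cached = fetch_repo_data(subdir_url, Client::default(), cache_dir, options).await?;
    println!(
        "repodata at {} ({:?})",
        cached.repo_data_json_path.display(),
        cached.cache_result
    );

    // `cached.lock_file` guards the cache entry; it is released when `cached` is dropped,
    // so keep it alive only as long as the repodata.json file is being read.
    Ok(())
}
```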
+#[instrument(err, skip_all, fields(subdir_url, cache_path = %cache_path.display()))] +pub async fn fetch_repo_data( + subdir_url: Url, + client: Client, + cache_path: &Path, + options: FetchRepoDataOptions, +) -> Result { + let subdir_url = normalize_subdir_url(subdir_url); + + // Compute the cache key from the url + let cache_key = crate::utils::url_to_cache_filename(&subdir_url); + let repo_data_json_path = cache_path.join(format!("{}.json", cache_key)); + let cache_state_path = cache_path.join(format!("{}.state.json", cache_key)); + + // Lock all files that have to do with that cache key + let lock_file_path = cache_path.join(format!("{}.lock", &cache_key)); + let lock_file = + tokio::task::spawn_blocking(move || LockedFile::open_rw(lock_file_path, "repodata cache")) + .await? + .map_err(FetchRepoDataError::FailedToAcquireLock)?; + + // Validate the current state of the cache + let cache_state = if options.cache_action != CacheAction::NoCache { + let owned_subdir_url = subdir_url.clone(); + let owned_cache_path = cache_path.to_owned(); + let cache_state = tokio::task::spawn_blocking(move || { + validate_cached_state(&owned_cache_path, &owned_subdir_url) + }) + .await?; + match (cache_state, options.cache_action) { + (ValidatedCacheState::UpToDate(cache_state), _) + | (ValidatedCacheState::OutOfDate(cache_state), CacheAction::ForceCacheOnly) => { + // Cache is up to date or we dont care about whether or not its up to date, + // so just immediately return what we have. + return Ok(CachedRepoData { + lock_file, + repo_data_json_path, + cache_state, + cache_result: CacheResult::CacheHit, + }); + } + (ValidatedCacheState::OutOfDate(_), CacheAction::UseCacheOnly) => { + // The cache is out of date but we also cant fetch new data + return Err(FetchRepoDataError::NoCacheAvailable); + } + (ValidatedCacheState::OutOfDate(cache_state), _) => { + // The cache is out of date but we can still refresh the data + Some(cache_state) + } + ( + ValidatedCacheState::Mismatched(_), + CacheAction::UseCacheOnly | CacheAction::ForceCacheOnly, + ) => { + // The cache doesnt match the repodata.json that is on disk. This means the cache is + // not usable. + return Err(FetchRepoDataError::NoCacheAvailable); + } + (ValidatedCacheState::Mismatched(cache_state), _) => { + // The cache doesnt match the data that is on disk. but it might contain some other + // interesting cached data as well... + Some(cache_state) + } + ( + ValidatedCacheState::InvalidOrMissing, + CacheAction::UseCacheOnly | CacheAction::ForceCacheOnly, + ) => { + // No cache available at all, and we cant refresh the data. + return Err(FetchRepoDataError::NoCacheAvailable); + } + (ValidatedCacheState::InvalidOrMissing, _) => { + // No cache available but we can update it! + None + } + } + } else { + None + }; + + // Determine the availability of variants based on the cache or by querying the remote. + let VariantAvailability { + has_zst: cached_zst_available, + has_bz2: cached_bz2_available, + } = check_variant_availability(&client, &subdir_url, cache_state.as_ref()).await; + + // Now that the caches have been refreshed determine whether or not we can use one of the + // variants. We dont check the expiration here since we just refreshed it. 
+ let has_zst = cached_zst_available + .as_ref() + .map(|state| state.value) + .unwrap_or(false); + let has_bz2 = cached_bz2_available + .as_ref() + .map(|state| state.value) + .unwrap_or(false); + + // Determine which variant to download + let repo_data_url = if has_zst { + subdir_url.join("repodata.json.zst").unwrap() + } else if has_bz2 { + subdir_url.join("repodata.json.bz2").unwrap() + } else { + subdir_url.join("repodata.json").unwrap() + }; + + // Construct the HTTP request + tracing::debug!("fetching '{}'", &repo_data_url); + let request_builder = client.get(repo_data_url.clone()); + + let mut headers = HeaderMap::default(); + + // We can handle g-zip encoding which is often used. We could also set this option on the + // client, but that will disable all download progress messages by `reqwest` because the + // gzipped data is decoded on the fly and the size of the decompressed body is unknown. + // However, we don't really care about the decompressed size but rather we'd like to know + // the number of raw bytes that are actually downloaded. + // + // To do this we manually set the request header to accept gzip encoding and we use the + // [`AsyncEncoding`] trait to perform the decoding on the fly. + headers.insert( + reqwest::header::ACCEPT_ENCODING, + HeaderValue::from_static("gzip"), + ); + + // Add previous cache headers if we have them + if let Some(cache_headers) = cache_state.as_ref().map(|state| &state.cache_headers) { + cache_headers.add_to_request(&mut headers) + } + + // Send the request and wait for a reply + let response = request_builder + .headers(headers) + .send() + .await? + .error_for_status()?; + + // If the content didn't change, simply return whatever we have on disk. + if response.status() == StatusCode::NOT_MODIFIED { + tracing::debug!("repodata was unmodified"); + + // Update the cache on disk with any new findings. + let cache_state = RepoDataState { + url: repo_data_url, + has_zst: cached_zst_available, + has_bz2: cached_bz2_available, + .. cache_state.expect("we must have had a cache, otherwise we wouldn't know the previous state of the cache") + }; + + let cache_state = tokio::task::spawn_blocking(move || { + cache_state + .to_path(&cache_state_path) + .map(|_| cache_state) + .map_err(FetchRepoDataError::FailedToWriteCacheState) + }) + .await??; + + return Ok(CachedRepoData { + lock_file, + repo_data_json_path, + cache_state, + cache_result: CacheResult::CacheHitAfterFetch, + }); + } + + // Get cache headers from the response + let cache_headers = CacheHeaders::from(&response); + + // Stream the content to a temporary file + let (temp_file, blake2_hash) = stream_and_decode_to_file( + response, + if has_zst { + Encoding::Zst + } else if has_bz2 { + Encoding::Bz2 + } else { + Encoding::Passthrough + }, + cache_path, + options.download_progress, + ) + .await?; + + // Persist the file to its final destination + let repo_data_destination_path = repo_data_json_path.clone(); + let repo_data_json_metadata = tokio::task::spawn_blocking(move || { + let file = temp_file + .persist(repo_data_destination_path) + .map_err(FetchRepoDataError::FailedToPersistTemporaryFile)?; + + // Determine the last modified date and size of the repodata.json file. We store these values in + // the cache to link the cache to the corresponding repodata.json file. + file.metadata() + .map_err(FetchRepoDataError::FailedToGetMetadata) + }) + .await??; + + // Update the cache on disk. 
+ let had_cache = cache_state.is_some(); + let new_cache_state = RepoDataState { + url: repo_data_url, + cache_headers, + cache_last_modified: repo_data_json_metadata + .modified() + .map_err(FetchRepoDataError::FailedToGetMetadata)?, + cache_size: repo_data_json_metadata.len(), + blake2_hash: Some(blake2_hash), + has_zst: cached_zst_available, + has_bz2: cached_bz2_available, + // We dont do anything with JLAP so just copy over the value. + has_jlap: cache_state.and_then(|state| state.has_jlap), + }; + + let new_cache_state = tokio::task::spawn_blocking(move || { + new_cache_state + .to_path(&cache_state_path) + .map(|_| new_cache_state) + .map_err(FetchRepoDataError::FailedToWriteCacheState) + }) + .await??; + + Ok(CachedRepoData { + lock_file, + repo_data_json_path, + cache_state: new_cache_state, + cache_result: if had_cache { + CacheResult::CacheOutdated + } else { + CacheResult::CacheNotPresent + }, + }) +} + +/// Streams and decodes the response to a new temporary file in the given directory. While writing +/// to disk it also computes the BLAKE2 hash of the file. +#[instrument(skip_all)] +async fn stream_and_decode_to_file( + response: Response, + content_encoding: Encoding, + temp_dir: &Path, + mut progress: Option>, +) -> Result<(NamedTempFile, blake2::digest::Output), FetchRepoDataError> { + // Determine the length of the response in bytes and notify the listener that a download is + // starting. The response may be compressed. Decompression happens below. + let content_size = response.content_length(); + if let Some(progress) = progress.as_mut() { + progress(DownloadProgress { + bytes: 0, + total: content_size, + }) + } + + // Determine the encoding of the response + let transfer_encoding = Encoding::from(&response); + + // Convert the response into a byte stream + let bytes_stream = response + .bytes_stream() + .map_err(|e| std::io::Error::new(ErrorKind::Other, e)); + + // Listen in on the bytes as they come from the response. Progress is tracked here instead of + // after decoding because that doesnt properly represent the number of bytes that are being + // transferred over the network. + let mut total_bytes = 0; + let total_bytes_mut = &mut total_bytes; + let bytes_stream = bytes_stream.inspect_ok(move |bytes| { + *total_bytes_mut += bytes.len() as u64; + if let Some(progress) = progress.as_mut() { + progress(DownloadProgress { + bytes: *total_bytes_mut, + total: content_size, + }) + } + }); + + // Create a new stream from the byte stream that decodes the bytes using the transfer encoding + // on the fly. + let decoded_byte_stream = StreamReader::new(bytes_stream).decode(transfer_encoding); + + // Create yet another stream that decodes the bytes yet again but this time using the content + // encoding. + let mut decoded_repo_data_json_bytes = + tokio::io::BufReader::new(decoded_byte_stream).decode(content_encoding); + + tracing::trace!( + "decoding repodata (content: {:?}, transfer: {:?})", + content_encoding, + transfer_encoding + ); + + // Construct a temporary file + let temp_file = + NamedTempFile::new_in(temp_dir).map_err(FetchRepoDataError::FailedToCreateTemporaryFile)?; + + // Clone the file handle and create a hashing writer so we can compute a hash while the content + // is being written to disk. + let file = tokio::fs::File::from_std(temp_file.as_file().try_clone().unwrap()); + let mut hashing_file_writer = HashingWriter::<_, blake2::Blake2s256>::new(file); + + // Decode, hash and write the data to the file. 
+ let bytes = tokio::io::copy(&mut decoded_repo_data_json_bytes, &mut hashing_file_writer) + .await + .map_err(FetchRepoDataError::FailedToDownloadRepoData)?; + + // Finalize the hash + let (_, hash) = hashing_file_writer.finalize(); + + tracing::debug!( + "downloaded {}, decoded that into {}, BLAKE2 hash: {:x}", + SizeFormatter::new(total_bytes, DECIMAL), + SizeFormatter::new(bytes, DECIMAL), + hash + ); + + Ok((temp_file, hash)) +} + +/// Describes the availability of certain `repodata.json`. +#[derive(Debug)] +struct VariantAvailability { + has_zst: Option>, + has_bz2: Option>, +} + +/// Determine the availability of `repodata.json` variants (like a `.zst` or `.bz2`) by checking +/// a cache or the internet. +async fn check_variant_availability( + client: &Client, + subdir_url: &Url, + cache_state: Option<&RepoDataState>, +) -> VariantAvailability { + // Determine from the cache which variant are available. This is currently cached for a maximum + // of 14 days. + let expiration_duration = chrono::Duration::days(14); + let has_zst = cache_state + .and_then(|state| state.has_zst.as_ref()) + .and_then(|value| value.value(expiration_duration)) + .copied(); + let has_bz2 = cache_state + .and_then(|state| state.has_bz2.as_ref()) + .and_then(|value| value.value(expiration_duration)) + .copied(); + + // Create a future to possibly refresh the zst state. + let zst_repodata_url = subdir_url.join("repodata.json.zst").unwrap(); + let bz2_repodata_url = subdir_url.join("repodata.json.bz2").unwrap(); + let zst_future = match has_zst { + Some(_) => { + // The last cached value was value so we simply copy that + ready(cache_state.and_then(|state| state.has_zst.clone())).left_future() + } + None => async { + Some(Expiring { + value: check_valid_download_target(&zst_repodata_url, client).await, + last_checked: chrono::Utc::now(), + }) + } + .right_future(), + }; + + // Create a future to determine if bz2 is available. We only check this if we dont already know that + // zst is available because if thats available we're going to use that anyway. + let bz2_future = if has_zst != Some(true) { + // If the zst variant might not be available we need to check whether bz2 is available. + async { + match has_bz2 { + Some(_) => { + // The last cached value was value so we simply copy that. + cache_state.and_then(|state| state.has_bz2.clone()) + } + None => Some(Expiring { + value: check_valid_download_target(&bz2_repodata_url, client).await, + last_checked: chrono::Utc::now(), + }), + } + } + .left_future() + } else { + // If we already know that zst is available we simply copy the availability value from the last + // time we checked. + ready(cache_state.and_then(|state| state.has_zst.clone())).right_future() + }; + + // TODO: Implement JLAP + + // Await both futures so they happen concurrently. Note that a request might not actually happen if + // the cache is still valid. + let (has_zst, has_bz2) = futures::join!(zst_future, bz2_future); + + VariantAvailability { has_zst, has_bz2 } +} + +/// Performs a HEAD request on the given URL to see if it is available. +async fn check_valid_download_target(url: &Url, client: &Client) -> bool { + tracing::debug!("checking availability of '{url}'"); + + // Otherwise, perform a HEAD request to determine whether the url seems valid. 
+ match client.head(url.clone()).send().await { + Ok(response) => { + if response.status().is_success() { + tracing::debug!("'{url}' seems to be available"); + true + } else { + tracing::debug!("'{url}' seems to be unavailable"); + false + } + } + Err(e) => { + tracing::warn!( + "failed to perform HEAD request on '{url}': {e}. Assuming its unavailable.." + ); + false + } + } +} + +// Ensures that the URL contains a trailing slash. This is important for the [`Url::join`] function. +fn normalize_subdir_url(url: Url) -> Url { + let mut path = url.path(); + path = path.trim_end_matches('/'); + let mut url = url.clone(); + url.set_path(&format!("{path}/")); + url +} + +/// A value returned from [`validate_cached_state`] which indicates the state of a repodata.json cache. +enum ValidatedCacheState { + /// There is no cache, the cache could not be parsed, or the cache does not reference the same + /// request. We can completely ignore any cached data. + InvalidOrMissing, + + /// The cache does not match the repodata.json file that is on disk. This usually indicates that the + /// repodata.json was modified without updating the cache. + Mismatched(RepoDataState), + + /// The cache could be read and corresponds to the repodata.json file that is on disk but the cached + /// data is (partially) out of date. + OutOfDate(RepoDataState), + + /// The cache is up to date. + UpToDate(RepoDataState), +} + +/// Tries to determine if the cache state for the repodata.json for the given `subdir_url` is +/// considered to be up-to-date. +/// +/// This functions reads multiple files from the `cache_path`, it is left up to the user to ensure +/// that these files stay synchronized during the execution of this function. +fn validate_cached_state(cache_path: &Path, subdir_url: &Url) -> ValidatedCacheState { + let cache_key = crate::utils::url_to_cache_filename(subdir_url); + let repo_data_json_path = cache_path.join(format!("{}.json", cache_key)); + let cache_state_path = cache_path.join(format!("{}.state.json", cache_key)); + + // Check if we have cached repodata.json file + let json_metadata = match std::fs::metadata(&repo_data_json_path) { + Err(e) if e.kind() == ErrorKind::NotFound => return ValidatedCacheState::InvalidOrMissing, + Err(e) => { + tracing::warn!( + "failed to get metadata of repodata.json file '{}': {e}. Ignoring cached files...", + repo_data_json_path.display() + ); + return ValidatedCacheState::InvalidOrMissing; + } + Ok(metadata) => metadata, + }; + + // Try to read the repodata state cache + let cache_state = match RepoDataState::from_path(&cache_state_path) { + Err(e) if e.kind() == ErrorKind::NotFound => { + // Ignore, the cache just doesnt exist + tracing::info!("repodata cache state is missing. Ignoring cached files..."); + return ValidatedCacheState::InvalidOrMissing; + } + Err(e) => { + // An error occured while reading the cached state. + tracing::warn!( + "invalid repodata cache state '{}': {e}. Ignoring cached files...", + cache_state_path.display() + ); + return ValidatedCacheState::InvalidOrMissing; + } + Ok(state) => state, + }; + + // Do the URLs match? + let cached_subdir_url = if cache_state.url.path().ends_with('/') { + cache_state.url.clone() + } else { + let path = cache_state.url.path(); + let (subdir_path, _) = path.rsplit_once('/').unwrap_or(("", path)); + let mut url = cache_state.url.clone(); + url.set_path(&format!("{subdir_path}/")); + url + }; + if &cached_subdir_url != subdir_url { + tracing::warn!( + "cache state refers to a different repodata.json url. 
Ignoring cached files..." + ); + return ValidatedCacheState::InvalidOrMissing; + } + + // Determine last modified date of the repodata.json file. + let cache_last_modified = match json_metadata.modified() { + Err(_) => { + tracing::warn!("could not determine last modified date of repodata.json file. Ignoring cached files..."); + return ValidatedCacheState::Mismatched(cache_state); + } + Ok(last_modified) => last_modified, + }; + + // Make sure that the repodata state cache refers to the repodata that exists on disk. + // + // Check the blake hash of the repodata.json file if we have a similar hash in the state. + if let Some(cached_hash) = cache_state.blake2_hash.as_ref() { + match compute_file_digest::(&repo_data_json_path) { + Err(e) => { + tracing::warn!( + "could not compute BLAKE2 hash of repodata.json file: {e}. Ignoring cached files..." + ); + return ValidatedCacheState::Mismatched(cache_state); + } + Ok(hash) => { + if &hash != cached_hash { + tracing::warn!( + "BLAKE2 hash of repodata.json does not match cache state. Ignoring cached files..." + ); + return ValidatedCacheState::Mismatched(cache_state); + } + } + } + } else { + // The state cache records the size and last modified date of the original file. If those do + // not match, the repodata.json file has been modified. + if json_metadata.len() != cache_state.cache_size + || Some(cache_last_modified) != json_metadata.modified().ok() + { + tracing::warn!("repodata cache state mismatches the existing repodatajson file. Ignoring cached files..."); + return ValidatedCacheState::Mismatched(cache_state); + } + } + + // Determine the age of the cache + let cache_age = match SystemTime::now().duration_since(cache_last_modified) { + Ok(duration) => duration, + Err(e) => { + tracing::warn!("failed to determine cache age: {e}. Ignoring cached files..."); + return ValidatedCacheState::Mismatched(cache_state); + } + }; + + // Parse the cache control header, and determine if the cache is out of date or not. + match cache_state.cache_headers.cache_control.as_deref() { + Some(cache_control) => match CacheControl::from_value(cache_control) { + None => { + tracing::warn!( + "could not parse cache_control from repodata cache state. Ignoring cached files..." + ); + return ValidatedCacheState::Mismatched(cache_state); + } + Some(CacheControl { + cachability: Some(Cachability::Public), + max_age: Some(duration), + .. + }) => { + if duration > cache_age { + tracing::info!("Cache is out of date. Assuming out of date..."); + return ValidatedCacheState::OutOfDate(cache_state); + } + } + Some(_) => { + tracing::info!( + "Unsupported cache-control value '{}'. Assuming out of date...", + cache_control + ); + return ValidatedCacheState::OutOfDate(cache_state); + } + }, + None => { + tracing::warn!("previous cache state does not contain cache_control header. Assuming out of date..."); + return ValidatedCacheState::OutOfDate(cache_state); + } + } + + // Well then! If we get here, it means the cache must be up to date! 
+ ValidatedCacheState::UpToDate(cache_state) +} + +#[cfg(test)] +mod test { + use super::{ + fetch_repo_data, CacheResult, CachedRepoData, DownloadProgress, FetchRepoDataOptions, + }; + use crate::utils::simple_channel_server::SimpleChannelServer; + use crate::utils::Encoding; + use assert_matches::assert_matches; + use hex_literal::hex; + use reqwest::Client; + use std::path::Path; + use tempfile::TempDir; + use tokio::io::AsyncWriteExt; + use url::Url; + + async fn write_encoded( + mut input: &[u8], + destination: &Path, + encoding: Encoding, + ) -> Result<(), std::io::Error> { + // Open the file for writing + let mut file = tokio::fs::File::create(destination).await.unwrap(); + + match encoding { + Encoding::Passthrough => { + tokio::io::copy(&mut input, &mut file).await?; + } + Encoding::GZip => { + let mut encoder = async_compression::tokio::write::GzipEncoder::new(file); + tokio::io::copy(&mut input, &mut encoder).await?; + encoder.shutdown().await?; + } + Encoding::Bz2 => { + let mut encoder = async_compression::tokio::write::BzEncoder::new(file); + tokio::io::copy(&mut input, &mut encoder).await?; + encoder.shutdown().await?; + } + Encoding::Zst => { + let mut encoder = async_compression::tokio::write::ZstdEncoder::new(file); + tokio::io::copy(&mut input, &mut encoder).await?; + encoder.shutdown().await?; + } + } + + Ok(()) + } + + #[test] + pub fn test_normalize_url() { + assert_eq!( + super::normalize_subdir_url(Url::parse("http://localhost/channels/empty").unwrap()), + Url::parse("http://localhost/channels/empty/").unwrap(), + ); + assert_eq!( + super::normalize_subdir_url(Url::parse("http://localhost/channels/empty/").unwrap()), + Url::parse("http://localhost/channels/empty/").unwrap(), + ); + } + + const FAKE_REPO_DATA: &str = r#"{ + "packages.conda": { + "asttokens-2.2.1-pyhd8ed1ab_0.conda": { + "arch": null, + "build": "pyhd8ed1ab_0", + "build_number": 0, + "build_string": "pyhd8ed1ab_0", + "constrains": [], + "depends": [ + "python >=3.5", + "six" + ], + "fn": "asttokens-2.2.1-pyhd8ed1ab_0.conda", + "license": "Apache-2.0", + "license_family": "Apache", + "md5": "bf7f54dd0f25c3f06ecb82a07341841a", + "name": "asttokens", + "noarch": "python", + "platform": null, + "sha256": "7ed530efddd47a96c11197906b4008405b90e3bc2f4e0df722a36e0e6103fd9c", + "size": 27831, + "subdir": "noarch", + "timestamp": 1670264089059, + "track_features": "", + "url": "https://conda.anaconda.org/conda-forge/noarch/asttokens-2.2.1-pyhd8ed1ab_0.conda", + "version": "2.2.1" + } + } + } + "#; + + #[tracing_test::traced_test] + #[tokio::test] + pub async fn test_fetch_repo_data() { + // Create a directory with some repodata. + let subdir_path = TempDir::new().unwrap(); + std::fs::write(subdir_path.path().join("repodata.json"), FAKE_REPO_DATA).unwrap(); + let server = SimpleChannelServer::new(subdir_path.path()); + + // Download the data from the channel with an empty cache. + let cache_dir = TempDir::new().unwrap(); + let result = fetch_repo_data( + server.url(), + Client::default(), + cache_dir.path(), + Default::default(), + ) + .await + .unwrap(); + + assert_eq!( + result.cache_state.blake2_hash.unwrap()[..], + hex!("791749939c9d6e26801bbcd525b908da15d42d3249f01efaca1ed1133f38bb87")[..] + ); + assert_eq!( + std::fs::read_to_string(result.repo_data_json_path).unwrap(), + FAKE_REPO_DATA + ); + } + + #[tracing_test::traced_test] + #[tokio::test] + pub async fn test_cache_works() { + // Create a directory with some repodata. 
+ let subdir_path = TempDir::new().unwrap(); + std::fs::write(subdir_path.path().join("repodata.json"), FAKE_REPO_DATA).unwrap(); + let server = SimpleChannelServer::new(subdir_path.path()); + + // Download the data from the channel with an empty cache. + let cache_dir = TempDir::new().unwrap(); + let CachedRepoData { cache_result, .. } = fetch_repo_data( + server.url(), + Client::default(), + cache_dir.path(), + Default::default(), + ) + .await + .unwrap(); + + assert_matches!(cache_result, CacheResult::CacheNotPresent); + + // Download the data from the channel with a filled cache. + let CachedRepoData { cache_result, .. } = fetch_repo_data( + server.url(), + Client::default(), + cache_dir.path(), + Default::default(), + ) + .await + .unwrap(); + + assert_matches!( + cache_result, + CacheResult::CacheHit | CacheResult::CacheHitAfterFetch + ); + + // I know this is terrible but without the sleep rust is too blazingly fast and the server + // doesnt think the file was actually updated.. This is because the time send by the server + // has seconds precision. + tokio::time::sleep(std::time::Duration::from_millis(1500)).await; + + // Update the original repodata.json file + std::fs::write(subdir_path.path().join("repodata.json"), FAKE_REPO_DATA).unwrap(); + + // Download the data from the channel with a filled cache. + let CachedRepoData { cache_result, .. } = fetch_repo_data( + server.url(), + Client::default(), + cache_dir.path(), + Default::default(), + ) + .await + .unwrap(); + + assert_matches!(cache_result, CacheResult::CacheOutdated); + } + + #[tracing_test::traced_test] + #[tokio::test] + pub async fn test_zst_works() { + let subdir_path = TempDir::new().unwrap(); + write_encoded( + FAKE_REPO_DATA.as_bytes(), + &subdir_path.path().join("repodata.json.zst"), + Encoding::Zst, + ) + .await + .unwrap(); + + let server = SimpleChannelServer::new(subdir_path.path()); + + // Download the data from the channel with an empty cache. + let cache_dir = TempDir::new().unwrap(); + let result = fetch_repo_data( + server.url(), + Client::default(), + cache_dir.path(), + Default::default(), + ) + .await + .unwrap(); + + assert_eq!( + std::fs::read_to_string(result.repo_data_json_path).unwrap(), + FAKE_REPO_DATA + ); + assert_matches!( + result.cache_state.has_zst, Some(super::Expiring { + value, .. + }) if value + ); + assert_matches!( + result.cache_state.has_bz2, Some(super::Expiring { + value, .. + }) if !value + ); + } + + #[tracing_test::traced_test] + #[tokio::test] + pub async fn test_bz2_works() { + let subdir_path = TempDir::new().unwrap(); + write_encoded( + FAKE_REPO_DATA.as_bytes(), + &subdir_path.path().join("repodata.json.bz2"), + Encoding::Bz2, + ) + .await + .unwrap(); + + let server = SimpleChannelServer::new(subdir_path.path()); + + // Download the data from the channel with an empty cache. + let cache_dir = TempDir::new().unwrap(); + let result = fetch_repo_data( + server.url(), + Client::default(), + cache_dir.path(), + Default::default(), + ) + .await + .unwrap(); + + assert_eq!( + std::fs::read_to_string(result.repo_data_json_path).unwrap(), + FAKE_REPO_DATA + ); + assert_matches!( + result.cache_state.has_zst, Some(super::Expiring { + value, .. + }) if !value + ); + assert_matches!( + result.cache_state.has_bz2, Some(super::Expiring { + value, .. 
+ }) if value + ); + } + + #[tracing_test::traced_test] + #[tokio::test] + pub async fn test_zst_is_preferred() { + let subdir_path = TempDir::new().unwrap(); + write_encoded( + FAKE_REPO_DATA.as_bytes(), + &subdir_path.path().join("repodata.json.bz2"), + Encoding::Bz2, + ) + .await + .unwrap(); + write_encoded( + FAKE_REPO_DATA.as_bytes(), + &subdir_path.path().join("repodata.json.zst"), + Encoding::Zst, + ) + .await + .unwrap(); + + let server = SimpleChannelServer::new(subdir_path.path()); + + // Download the data from the channel with an empty cache. + let cache_dir = TempDir::new().unwrap(); + let result = fetch_repo_data( + server.url(), + Client::default(), + cache_dir.path(), + Default::default(), + ) + .await + .unwrap(); + + assert_eq!( + std::fs::read_to_string(result.repo_data_json_path).unwrap(), + FAKE_REPO_DATA + ); + assert!(result.cache_state.url.path().ends_with("repodata.json.zst")); + assert_matches!( + result.cache_state.has_zst, Some(super::Expiring { + value, .. + }) if value + ); + assert_matches!( + result.cache_state.has_bz2, Some(super::Expiring { + value, .. + }) if value + ); + } + + #[tracing_test::traced_test] + #[tokio::test] + pub async fn test_gzip_transfer_encoding() { + // Create a directory with some repodata. + let subdir_path = TempDir::new().unwrap(); + write_encoded( + FAKE_REPO_DATA.as_ref(), + &subdir_path.path().join("repodata.json.gz"), + Encoding::GZip, + ) + .await + .unwrap(); + + // The server is configured in such a way that if file `a` is requested but a file called + // `a.gz` is available it will stream the `a.gz` file and report that its a `gzip` encoded + // stream. + let server = SimpleChannelServer::new(subdir_path.path()); + + // Download the data from the channel + let cache_dir = TempDir::new().unwrap(); + let result = fetch_repo_data( + server.url(), + Client::builder().no_gzip().build().unwrap(), + cache_dir.path(), + Default::default(), + ) + .await + .unwrap(); + + assert_eq!( + std::fs::read_to_string(result.repo_data_json_path).unwrap(), + FAKE_REPO_DATA + ); + } + + #[tracing_test::traced_test] + #[tokio::test] + pub async fn test_progress() { + use std::cell::Cell; + use std::sync::Arc; + + // Create a directory with some repodata. + let subdir_path = TempDir::new().unwrap(); + std::fs::write(subdir_path.path().join("repodata.json"), FAKE_REPO_DATA).unwrap(); + let server = SimpleChannelServer::new(subdir_path.path()); + + let last_download_progress = Arc::new(Cell::new(0)); + let last_download_progress_captured = last_download_progress.clone(); + let download_progress = move |progress: DownloadProgress| { + last_download_progress_captured.set(progress.bytes); + assert_eq!(progress.total, Some(1110)); + }; + + // Download the data from the channel with an empty cache. 
+ let cache_dir = TempDir::new().unwrap(); + let _result = fetch_repo_data( + server.url(), + Client::default(), + cache_dir.path(), + FetchRepoDataOptions { + download_progress: Some(Box::new(download_progress)), + ..Default::default() + }, + ) + .await + .unwrap(); + + assert_eq!(last_download_progress.get(), 1110); + } +} diff --git a/crates/rattler_repodata_gateway/src/fetch/snapshots/rattler__repo_data__test__fetch_repo_data.snap b/crates/rattler_repodata_gateway/src/fetch/snapshots/rattler__repo_data__test__fetch_repo_data.snap new file mode 100644 index 000000000..3543fa6a7 --- /dev/null +++ b/crates/rattler_repodata_gateway/src/fetch/snapshots/rattler__repo_data__test__fetch_repo_data.snap @@ -0,0 +1,16 @@ +--- +source: crates/rattler/src/repo_data/mod.rs +expression: result.cache_state +--- +url: "http://localhost:63643/noarch/" +mod: "Wed, 21 Sep 2022 19:03:45 GMT" +mtime_ns: 1676557004445475300 +cache_size: 209 +has_zst: + value: false + last_checked: "2023-02-16T14:16:44.441641600Z" +has_bz2: + value: false + last_checked: "2023-02-16T14:16:44.440955Z" +has_jlap: ~ + diff --git a/crates/rattler_repodata_gateway/src/fetch/snapshots/rattler__repo_data__test__parse_repo_data_state.snap b/crates/rattler_repodata_gateway/src/fetch/snapshots/rattler__repo_data__test__parse_repo_data_state.snap new file mode 100644 index 000000000..53a4cad93 --- /dev/null +++ b/crates/rattler_repodata_gateway/src/fetch/snapshots/rattler__repo_data__test__parse_repo_data_state.snap @@ -0,0 +1,16 @@ +--- +source: crates/rattler/src/repo_data/mod.rs +expression: "RepoDataState::from_str(r#\"{\n \"cache_control\": \"public, max-age=1200\",\n \"etag\": \"\\\"bec332621e00fc4ad87ba185171bcf46\\\"\",\n \"has_zst\": {\n \"last_checked\": \"2023-02-13T14:08:50Z\",\n \"value\": true\n },\n \"mod\": \"Mon, 13 Feb 2023 13:49:56 GMT\",\n \"mtime_ns\": 1676297333020928000,\n \"size\": 156627374,\n \"url\": \"https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst\"\n }\"#).unwrap()" +--- +url: "https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst" +etag: "\"bec332621e00fc4ad87ba185171bcf46\"" +mod: "Mon, 13 Feb 2023 13:49:56 GMT" +cache_control: "public, max-age=1200" +mtime_ns: 1676297333020928000 +size: 156627374 +has_zst: + value: true + last_checked: "2023-02-13T14:08:50Z" +has_bz2: ~ +has_jlap: ~ + diff --git a/crates/rattler_repodata_gateway/src/lib.rs b/crates/rattler_repodata_gateway/src/lib.rs new file mode 100644 index 000000000..3e170d6f1 --- /dev/null +++ b/crates/rattler_repodata_gateway/src/lib.rs @@ -0,0 +1,11 @@ +#![deny(missing_docs)] + +//! `rattler_repodata_gateway` is a crate that provides functionality to interact with Conda +//! repodata. It currently provides functionality to download and cache `repodata.json` files +//! through the [`fetch::fetch_repo_data`] function. +//! +//! In the future this crate will also provide more high-level functionality to query information +//! about specific packages from different sources. 
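As a companion to the crate-level docs, a second hypothetical snippet (not part of the diff) shows the `CacheAction` knob for offline use; the function name is illustrative.

```rust
// Not part of the diff: restricting fetch_repo_data to an already-populated cache.
use rattler_repodata_gateway::fetch::{
    fetch_repo_data, CacheAction, FetchRepoDataError, FetchRepoDataOptions,
};
use reqwest::Client;
use std::path::{Path, PathBuf};
use url::Url;

async fn offline_repodata(
    subdir_url: Url,
    cache_dir: &Path,
) -> Result<PathBuf, FetchRepoDataError> {
    let cached = fetch_repo_data(
        subdir_url,
        Client::default(),
        cache_dir,
        FetchRepoDataOptions {
            // Never touch the network; fails with `NoCacheAvailable` if nothing usable is cached.
            cache_action: CacheAction::ForceCacheOnly,
            ..Default::default()
        },
    )
    .await?;

    // The lock in `cached.lock_file` is released here when `cached` is dropped.
    Ok(cached.repo_data_json_path)
}
```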
+ +pub mod fetch; +mod utils; diff --git a/crates/rattler/src/utils/encoding.rs b/crates/rattler_repodata_gateway/src/utils/encoding.rs similarity index 78% rename from crates/rattler/src/utils/encoding.rs rename to crates/rattler_repodata_gateway/src/utils/encoding.rs index 6b366055e..4156bacd1 100644 --- a/crates/rattler/src/utils/encoding.rs +++ b/crates/rattler_repodata_gateway/src/utils/encoding.rs @@ -8,6 +8,8 @@ use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; pub enum Encoding { Passthrough, GZip, + Bz2, + Zst, } impl<'a> From<&'a reqwest::Response> for Encoding { @@ -25,6 +27,8 @@ pin_project! { pub enum Decoder { Passthrough { #[pin] inner: T }, GZip { #[pin] inner: async_compression::tokio::bufread::GzipDecoder }, + Bz2 { #[pin] inner: async_compression::tokio::bufread::BzDecoder }, + Zst { #[pin] inner: async_compression::tokio::bufread::ZstdDecoder }, } } @@ -37,6 +41,8 @@ impl AsyncRead for Decoder { match self.project() { DecoderProj::Passthrough { inner } => inner.poll_read(cx, buf), DecoderProj::GZip { inner } => inner.poll_read(cx, buf), + DecoderProj::Bz2 { inner } => inner.poll_read(cx, buf), + DecoderProj::Zst { inner } => inner.poll_read(cx, buf), } } } @@ -53,6 +59,12 @@ impl AsyncEncoding for T { Encoding::GZip => Decoder::GZip { inner: async_compression::tokio::bufread::GzipDecoder::new(self), }, + Encoding::Bz2 => Decoder::Bz2 { + inner: async_compression::tokio::bufread::BzDecoder::new(self), + }, + Encoding::Zst => Decoder::Zst { + inner: async_compression::tokio::bufread::ZstdDecoder::new(self), + }, } } } diff --git a/crates/rattler_repodata_gateway/src/utils/flock.rs b/crates/rattler_repodata_gateway/src/utils/flock.rs new file mode 100644 index 000000000..211233ad2 --- /dev/null +++ b/crates/rattler_repodata_gateway/src/utils/flock.rs @@ -0,0 +1,424 @@ +/// Implementation of file locks taken from: +/// https://github.com/rust-lang/cargo/blob/39c13e67a5962466cc7253d41bc1099bbcb224c3/src/cargo/util/flock.rs +/// +/// Under MIT license: +/// +/// Permission is hereby granted, free of charge, to any +/// person obtaining a copy of this software and associated +/// documentation files (the "Software"), to deal in the +/// Software without restriction, including without +/// limitation the rights to use, copy, modify, merge, +/// publish, distribute, sublicense, and/or sell copies of +/// the Software, and to permit persons to whom the Software +/// is furnished to do so, subject to the following +/// conditions: +/// +/// The above copyright notice and this permission notice +/// shall be included in all copies or substantial portions +/// of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +/// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +/// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +/// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +/// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +/// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +/// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +/// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +/// DEALINGS IN THE SOFTWARE. 
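Before the locking implementation itself, a short crate-internal sketch (not part of the diff) of how this lock type is meant to be used; it mirrors the `LockedFile::open_rw` call in `fetch_repo_data` above, and the helper name is made up.

```rust
// Not part of the diff: mirrors how fetch_repo_data guards a repodata cache entry.
// This only compiles inside the crate, since `utils` is a private module.
use crate::utils::LockedFile;
use std::path::Path;

fn lock_cache_entry(cache_dir: &Path, cache_key: &str) -> anyhow::Result<LockedFile> {
    // Creates the lock file (and missing parent directories) if needed, then blocks until
    // an exclusive lock is held; the message is logged if another process holds the lock.
    LockedFile::open_rw(cache_dir.join(format!("{cache_key}.lock")), "repodata cache")
}
```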
+use std::fs::{File, OpenOptions};
+use std::io;
+use std::io::{Read, Seek, SeekFrom, Write};
+use std::path::{Path, PathBuf};
+
+use anyhow::Context as _;
+use sys::*;
+
+#[derive(Debug)]
+pub struct LockedFile {
+    f: Option<File>,
+    path: PathBuf,
+    state: State,
+}
+
+#[derive(PartialEq, Debug)]
+enum State {
+    Unlocked,
+    Shared,
+    Exclusive,
+}
+
+impl LockedFile {
+    /// Returns the underlying file handle of this lock.
+    pub fn file(&self) -> &File {
+        self.f.as_ref().unwrap()
+    }
+
+    /// Returns the underlying path that this lock points to.
+    ///
+    /// Note that special care must be taken to ensure that the path is not
+    /// referenced outside the lifetime of this lock.
+    pub fn path(&self) -> &Path {
+        assert_ne!(self.state, State::Unlocked);
+        &self.path
+    }
+
+    /// Returns the parent path containing this file
+    pub fn parent(&self) -> &Path {
+        assert_ne!(self.state, State::Unlocked);
+        self.path.parent().unwrap()
+    }
+
+    /// Opens exclusive access to a file, returning the locked version of a
+    /// file.
+    ///
+    /// This function will create a file at `path` if it doesn't already exist
+    /// (including intermediate directories), and then it will acquire an
+    /// exclusive lock on `path`. If the process must block waiting for the
+    /// lock, the `msg` is printed to tracing.
+    ///
+    /// The returned file can be accessed to look at the path and also has
+    /// read/write access to the underlying file.
+    pub fn open_rw<P>(path: P, msg: &str) -> anyhow::Result<LockedFile>
+    where
+        P: AsRef<Path>,
+    {
+        Self::open(
+            path.as_ref(),
+            OpenOptions::new().read(true).write(true).create(true),
+            State::Exclusive,
+            msg,
+        )
+    }
+
+    /// Opens shared access to a file, returning the locked version of a file.
+    ///
+    /// This function will fail if `path` doesn't already exist, but if it does
+    /// then it will acquire a shared lock on `path`. If the process must block
+    /// waiting for the lock, the `msg` is printed to tracing.
+    ///
+    /// The returned file can be accessed to look at the path and also has read
+    /// access to the underlying file. Any writes to the file will return an
+    /// error.
+    pub fn open_ro<P>(path: P, msg: &str) -> anyhow::Result<LockedFile>
+    where
+        P: AsRef<Path>,
+    {
+        Self::open(
+            path.as_ref(),
+            OpenOptions::new().read(true),
+            State::Shared,
+            msg,
+        )
+    }
+
+    fn open(
+        path: &Path,
+        opts: &OpenOptions,
+        state: State,
+        msg: &str,
+    ) -> anyhow::Result<LockedFile> {
+        // If we want an exclusive lock then if we fail because of NotFound it's
+        // likely because an intermediate directory didn't exist, so try to
+        // create the directory and then continue.
+        let f = opts
+            .open(path)
+            .or_else(|e| {
+                if e.kind() == io::ErrorKind::NotFound && state == State::Exclusive {
+                    std::fs::create_dir_all(path.parent().unwrap())?;
+                    Ok(opts.open(path)?)
+                } else {
+                    Err(anyhow::Error::from(e))
+                }
+            })
+            .with_context(|| format!("failed to open: {}", path.display()))?;
+        match state {
+            State::Exclusive => {
+                acquire(msg, path, &|| try_lock_exclusive(&f), &|| {
+                    lock_exclusive(&f)
+                })?;
+            }
+            State::Shared => {
+                acquire(msg, path, &|| try_lock_shared(&f), &|| lock_shared(&f))?;
+            }
+            State::Unlocked => {}
+        }
+        Ok(LockedFile {
+            f: Some(f),
+            path: path.to_owned(),
+            state,
+        })
+    }
+}
+
+impl Read for LockedFile {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        self.file().read(buf)
+    }
+}
+
+impl Seek for LockedFile {
+    fn seek(&mut self, to: SeekFrom) -> io::Result<u64> {
+        self.file().seek(to)
+    }
+}
+
+impl Write for LockedFile {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.file().write(buf)
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        self.file().flush()
+    }
+}
+
+impl Drop for LockedFile {
+    fn drop(&mut self) {
+        if self.state != State::Unlocked {
+            if let Some(f) = self.f.take() {
+                let _ = unlock(&f);
+            }
+        }
+    }
+}
+
+/// Acquires a lock on a file in a "nice" manner.
+///
+/// Almost all long-running blocking actions in Cargo have a status message
+/// associated with them as we're not sure how long they'll take. Whenever a
+/// conflicted file lock happens, this is the case (we're not sure when the lock
+/// will be released).
+///
+/// This function will acquire the lock on a `path`, printing out a nice message
+/// to the console if we have to wait for it. It will first attempt to use `try`
+/// to acquire a lock on the crate, and in the case of contention it will emit a
+/// status message based on `msg` to tracing, and then use `block` to
+/// block waiting to acquire a lock.
+///
+/// Returns an error if the lock could not be acquired or if any error other
+/// than a contention error happens.
+fn acquire(
+    msg: &str,
+    path: &Path,
+    lock_try: &dyn Fn() -> io::Result<()>,
+    lock_block: &dyn Fn() -> io::Result<()>,
+) -> anyhow::Result<()> {
+    // File locking on Unix is currently implemented via `flock`, which is known
+    // to be broken on NFS. We could in theory just ignore errors that happen on
+    // NFS, but apparently the failure mode [1] for `flock` on NFS is **blocking
+    // forever**, even if the "non-blocking" flag is passed!
+    //
+    // As a result, we just skip all file locks entirely on NFS mounts. That
+    // should avoid calling any `flock` functions at all, and it wouldn't work
+    // there anyway.
+    //
+    // [1]: https://github.com/rust-lang/cargo/issues/2615
+    if is_on_nfs_mount(path) {
+        return Ok(());
+    }
+
+    match lock_try() {
+        Ok(()) => return Ok(()),
+
+        // In addition to ignoring NFS which is commonly not working we also
+        // just ignore locking on filesystems that look like they don't
+        // implement file locking.
+        Err(e) if error_unsupported(&e) => return Ok(()),
+
+        Err(e) => {
+            if !error_contended(&e) {
+                let e = anyhow::Error::from(e);
+                let cx = format!("failed to lock file: {}", path.display());
+                return Err(e.context(cx));
+            }
+        }
+    }
+
+    tracing::info!("waiting for file lock on {}", msg);
+
+    lock_block().with_context(|| format!("failed to lock file: {}", path.display()))?;
+    return Ok(());
+
+    #[cfg(all(target_os = "linux", not(target_env = "musl")))]
+    fn is_on_nfs_mount(path: &Path) -> bool {
+        use std::ffi::CString;
+        use std::mem;
+        use std::os::unix::prelude::*;
+
+        let path = match CString::new(path.as_os_str().as_bytes()) {
+            Ok(path) => path,
+            Err(_) => return false,
+        };
+
+        unsafe {
+            let mut buf: libc::statfs = mem::zeroed();
+            let r = libc::statfs(path.as_ptr(), &mut buf);
+
+            r == 0 && buf.f_type as u32 == libc::NFS_SUPER_MAGIC as u32
+        }
+    }
+
+    #[cfg(any(not(target_os = "linux"), target_env = "musl"))]
+    fn is_on_nfs_mount(_path: &Path) -> bool {
+        false
+    }
+}
+
+#[cfg(unix)]
+mod sys {
+    use std::fs::File;
+    use std::io::{Error, Result};
+    use std::os::unix::io::AsRawFd;
+
+    pub(super) fn lock_shared(file: &File) -> Result<()> {
+        flock(file, libc::LOCK_SH)
+    }
+
+    pub(super) fn lock_exclusive(file: &File) -> Result<()> {
+        flock(file, libc::LOCK_EX)
+    }
+
+    pub(super) fn try_lock_shared(file: &File) -> Result<()> {
+        flock(file, libc::LOCK_SH | libc::LOCK_NB)
+    }
+
+    pub(super) fn try_lock_exclusive(file: &File) -> Result<()> {
+        flock(file, libc::LOCK_EX | libc::LOCK_NB)
+    }
+
+    pub(super) fn unlock(file: &File) -> Result<()> {
+        flock(file, libc::LOCK_UN)
+    }
+
+    pub(super) fn error_contended(err: &Error) -> bool {
+        err.raw_os_error().map_or(false, |x| x == libc::EWOULDBLOCK)
+    }
+
+    pub(super) fn error_unsupported(err: &Error) -> bool {
+        match err.raw_os_error() {
+            // Unfortunately, depending on the target, these may or may not be the same.
+            // For targets in which they are the same, the duplicate pattern causes a warning.
+            #[allow(unreachable_patterns)]
+            Some(libc::ENOTSUP | libc::EOPNOTSUPP) => true,
+            Some(libc::ENOSYS) => true,
+            _ => false,
+        }
+    }
+
+    #[cfg(not(target_os = "solaris"))]
+    fn flock(file: &File, flag: libc::c_int) -> Result<()> {
+        let ret = unsafe { libc::flock(file.as_raw_fd(), flag) };
+        if ret < 0 {
+            Err(Error::last_os_error())
+        } else {
+            Ok(())
+        }
+    }
+
+    #[cfg(target_os = "solaris")]
+    fn flock(file: &File, flag: libc::c_int) -> Result<()> {
+        // Solaris lacks flock(), so try to emulate using fcntl()
+        let mut flock = libc::flock {
+            l_type: 0,
+            l_whence: 0,
+            l_start: 0,
+            l_len: 0,
+            l_sysid: 0,
+            l_pid: 0,
+            l_pad: [0, 0, 0, 0],
+        };
+        flock.l_type = if flag & libc::LOCK_UN != 0 {
+            libc::F_UNLCK
+        } else if flag & libc::LOCK_EX != 0 {
+            libc::F_WRLCK
+        } else if flag & libc::LOCK_SH != 0 {
+            libc::F_RDLCK
+        } else {
+            panic!("unexpected flock() operation")
+        };
+
+        let mut cmd = libc::F_SETLKW;
+        if (flag & libc::LOCK_NB) != 0 {
+            cmd = libc::F_SETLK;
+        }
+
+        let ret = unsafe { libc::fcntl(file.as_raw_fd(), cmd, &flock) };
+
+        if ret < 0 {
+            Err(Error::last_os_error())
+        } else {
+            Ok(())
+        }
+    }
+}
+
+#[cfg(windows)]
+mod sys {
+    use std::fs::File;
+    use std::io::{Error, Result};
+    use std::mem;
+    use std::os::windows::io::AsRawHandle;
+
+    use windows_sys::Win32::Foundation::HANDLE;
+    use windows_sys::Win32::Foundation::{ERROR_INVALID_FUNCTION, ERROR_LOCK_VIOLATION};
+    use windows_sys::Win32::Storage::FileSystem::{
+        LockFileEx, UnlockFile, LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY,
+    };
+
+    pub(super) fn lock_shared(file: &File) -> Result<()> {
+        lock_file(file, 0)
+    }
+
+    pub(super) fn lock_exclusive(file: &File) -> Result<()> {
+        lock_file(file, LOCKFILE_EXCLUSIVE_LOCK)
+    }
+
+    pub(super) fn try_lock_shared(file: &File) -> Result<()> {
+        lock_file(file, LOCKFILE_FAIL_IMMEDIATELY)
+    }
+
+    pub(super) fn try_lock_exclusive(file: &File) -> Result<()> {
+        lock_file(file, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY)
+    }
+
+    pub(super) fn error_contended(err: &Error) -> bool {
+        err.raw_os_error()
+            .map_or(false, |x| x == ERROR_LOCK_VIOLATION as i32)
+    }
+
+    pub(super) fn error_unsupported(err: &Error) -> bool {
+        err.raw_os_error()
+            .map_or(false, |x| x == ERROR_INVALID_FUNCTION as i32)
+    }
+
+    pub(super) fn unlock(file: &File) -> Result<()> {
+        unsafe {
+            let ret = UnlockFile(file.as_raw_handle() as HANDLE, 0, 0, !0, !0);
+            if ret == 0 {
+                Err(Error::last_os_error())
+            } else {
+                Ok(())
+            }
+        }
+    }
+
+    fn lock_file(file: &File, flags: u32) -> Result<()> {
+        unsafe {
+            let mut overlapped = mem::zeroed();
+            let ret = LockFileEx(
+                file.as_raw_handle() as HANDLE,
+                flags,
+                0,
+                !0,
+                !0,
+                &mut overlapped,
+            );
+            if ret == 0 {
+                Err(Error::last_os_error())
+            } else {
+                Ok(())
+            }
+        }
+    }
+}
diff --git a/crates/rattler/src/utils/mod.rs b/crates/rattler_repodata_gateway/src/utils/mod.rs
similarity index 73%
rename from crates/rattler/src/utils/mod.rs
rename to crates/rattler_repodata_gateway/src/utils/mod.rs
index 53843d1c6..543dedc1b 100644
--- a/crates/rattler/src/utils/mod.rs
+++ b/crates/rattler_repodata_gateway/src/utils/mod.rs
@@ -1,23 +1,17 @@
 pub use encoding::{AsyncEncoding, Encoding};
-pub use hash::{compute_file_sha256, parse_sha256_from_hex, HashingWriter, Sha256HashingWriter};
-use std::{fmt::Write, path::PathBuf};
+pub use flock::LockedFile;
+use std::fmt::Write;
 use url::Url;
 
 mod encoding;
-mod hash;
 
 #[cfg(test)]
 pub(crate) mod simple_channel_server;
 
-/// Returns the default cache directory used by rattler.
-pub fn default_cache_dir() -> anyhow::Result<PathBuf> {
-    Ok(dirs::cache_dir()
-        .ok_or_else(|| anyhow::anyhow!("could not determine cache directory for current platform"))?
-        .join("rattler/cache"))
-}
+mod flock;
 
 /// Convert a URL to a cache filename
-pub fn url_to_cache_filename(url: &Url) -> String {
+pub(crate) fn url_to_cache_filename(url: &Url) -> String {
     // Start Rant:
     // This function mimics behavior from Mamba which itself mimics this behavior from Conda.
     // However, I find this function absolutely ridiculous, it contains all sort of weird edge
@@ -35,7 +29,7 @@ pub fn url_to_cache_filename(url: &Url) -> String {
     let url_str = url_str.strip_suffix("/repodata.json").unwrap_or(&url_str);
 
     // Compute the MD5 hash of the resulting URL string
-    let hash = extendhash::md5::compute_hash(url_str.as_bytes());
+    let hash = rattler_digest::compute_bytes_digest::<rattler_digest::Md5>(url_str);
 
     // Convert the hash to an MD5 hash.
     let mut result = String::with_capacity(8);
diff --git a/crates/rattler/src/utils/simple_channel_server.rs b/crates/rattler_repodata_gateway/src/utils/simple_channel_server.rs
similarity index 100%
rename from crates/rattler/src/utils/simple_channel_server.rs
rename to crates/rattler_repodata_gateway/src/utils/simple_channel_server.rs
diff --git a/crates/rattler/resources/channels/conda-forge/linux-64/repodata.json b/test-data/channels/conda-forge/linux-64/repodata.json
similarity index 100%
rename from crates/rattler/resources/channels/conda-forge/linux-64/repodata.json
rename to test-data/channels/conda-forge/linux-64/repodata.json
diff --git a/crates/rattler/resources/channels/conda-forge/noarch/repodata.json b/test-data/channels/conda-forge/noarch/repodata.json
similarity index 100%
rename from crates/rattler/resources/channels/conda-forge/noarch/repodata.json
rename to test-data/channels/conda-forge/noarch/repodata.json
diff --git a/crates/rattler/resources/channels/conda-forge/noarch/repodata.json.gz b/test-data/channels/conda-forge/noarch/repodata.json.gz
similarity index 100%
rename from crates/rattler/resources/channels/conda-forge/noarch/repodata.json.gz
rename to test-data/channels/conda-forge/noarch/repodata.json.gz
diff --git a/test-data/channels/empty/noarch/repodata.json b/test-data/channels/empty/noarch/repodata.json
new file mode 100644
index 000000000..9325baf77
--- /dev/null
+++ b/test-data/channels/empty/noarch/repodata.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6914ea088069b467fb0f1b70c41fbba31bdcdb442471d3d9f927bf90181ef653
+size 198
diff --git a/crates/rattler/resources/channels/empty/noarch/repodata.json.gz b/test-data/channels/empty/noarch/repodata.json.gz
similarity index 100%
rename from crates/rattler/resources/channels/empty/noarch/repodata.json.gz
rename to test-data/channels/empty/noarch/repodata.json.gz
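// Illustrative sketch (not part of the patch above): the cache-file naming scheme behind
// `url_to_cache_filename` hashes the cleaned-up channel URL with MD5 and keeps only the first
// 8 hex characters. The helper below is a standalone approximation using the `md-5` and `hex`
// crates directly; the patched code goes through `rattler_digest` instead, and the real
// function handles more URL clean-up edge cases than shown here.
fn sketch_cache_filename(url: &str) -> String {
    use md5::{Digest, Md5};

    // Mirror the one clean-up step visible in the hunk above: strip a trailing
    // `/repodata.json` so the subdir URL and the repodata file URL hash identically.
    let url = url.strip_suffix("/repodata.json").unwrap_or(url);

    // MD5 the remaining string and keep the first 8 hex characters, as conda/mamba do.
    let digest = Md5::digest(url.as_bytes());
    hex::encode(digest)[..8].to_string()
}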