diff --git a/crates/rattler_installs_packages/src/artifacts/snapshots/rattler_installs_packages__artifacts__sdist__tests__check_direct_url_json_with_commit_for_remote_git.snap.new b/crates/rattler_installs_packages/src/artifacts/snapshots/rattler_installs_packages__artifacts__sdist__tests__check_direct_url_json_with_commit_for_remote_git.snap.new deleted file mode 100644 index 6e8cd12d..00000000 --- a/crates/rattler_installs_packages/src/artifacts/snapshots/rattler_installs_packages__artifacts__sdist__tests__check_direct_url_json_with_commit_for_remote_git.snap.new +++ /dev/null @@ -1,31 +0,0 @@ ---- -source: crates/rattler_installs_packages/src/artifacts/sdist.rs -assertion_line: 1214 -expression: direct_url_json ---- -Some( - DirectUrlJson { - url: Url { - scheme: "https", - cannot_be_a_base: false, - username: "", - password: None, - host: Some( - Domain( - "github.com", - ), - ), - port: None, - path: "/mahmoud/boltons.git", - query: None, - fragment: None, - }, - source: Vcs { - vcs: Git, - requested_revision: Some( - "47c8046492d4db49f163bb977d20d5942e4ddb25", - ), - commit_id: "47c8046492d4db49f163bb977d20d5942e4ddb25", - }, - }, -) diff --git a/crates/rattler_installs_packages/src/index/file_store.rs b/crates/rattler_installs_packages/src/index/file_store.rs index 8ff8c5b0..63b098b0 100644 --- a/crates/rattler_installs_packages/src/index/file_store.rs +++ b/crates/rattler_installs_packages/src/index/file_store.rs @@ -148,6 +148,7 @@ impl FileStore { /// /// Internally the writer writes to a temporary file that is persisted to the final location to /// ensure that the final path is never corrupted. +#[derive(Debug)] pub struct LockedWriter<'a> { path: &'a Path, f: tempfile::NamedTempFile, diff --git a/crates/rattler_installs_packages/src/index/http.rs b/crates/rattler_installs_packages/src/index/http.rs index 80c37c2a..0f0ea053 100644 --- a/crates/rattler_installs_packages/src/index/http.rs +++ b/crates/rattler_installs_packages/src/index/http.rs @@ -11,6 +11,8 @@ use reqwest::{header::HeaderMap, Method}; use reqwest_middleware::ClientWithMiddleware; use serde::{Deserialize, Serialize}; use std::io; +use std::io::BufReader; +use std::io::BufWriter; use std::io::{Read, Seek, SeekFrom, Write}; use std::str::FromStr; use std::sync::Arc; @@ -19,6 +21,9 @@ use thiserror::Error; use tokio_util::compat::FuturesAsyncReadCompatExt; use url::Url; +const CURRENT_VERSION: u8 = 1; +const CACHE_BOM: &str = "RIP"; + // Attached to HTTP responses, to make testing easier #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum CacheStatus { @@ -104,10 +109,9 @@ impl Http { let key = key_for_request(&url, method, &headers); let lock = self.http_cache.lock(&key.as_slice()).await?; - if let Some((old_policy, final_url, old_body)) = lock - .reader() - .and_then(|reader| read_cache(reader.detach_unlocked()).ok()) - { + if let Some((old_policy, final_url, old_body)) = lock.reader().and_then(|reader| { + read_cache(reader.detach_unlocked(), CACHE_BOM, CURRENT_VERSION).ok() + }) { match old_policy.before_request(&request, SystemTime::now()) { BeforeRequest::Fresh(parts) => { tracing::debug!(url=%url, "is fresh"); @@ -138,12 +142,11 @@ impl Http { // Determine what to do based on the response headers. match old_policy.after_response(&request, &response, SystemTime::now()) { - AfterResponse::NotModified(new_policy, new_parts) => { + AfterResponse::NotModified(_, new_parts) => { tracing::debug!(url=%url, "stale, but not modified"); - let new_body = fill_cache(&new_policy, &final_url, old_body, lock)?; Ok(make_response( new_parts, - StreamingOrLocal::Local(Box::new(new_body)), + StreamingOrLocal::Local(Box::new(old_body)), CacheStatus::StaleButValidated, final_url, )) @@ -189,7 +192,6 @@ impl Http { let new_policy = CachePolicy::new(&request, &response); let (parts, body) = response.into_parts(); - let new_body = if new_policy.is_storable() { let new_body = fill_cache_async(&new_policy, &final_url, body, lock).await?; StreamingOrLocal::Local(Box::new(new_body)) @@ -248,17 +250,27 @@ fn key_for_request(url: &Url, method: Method, headers: &HeaderMap) -> Vec { } /// Read a HTTP cached value from a readable stream. -fn read_cache(mut f: R) -> std::io::Result<(CachePolicy, Url, impl ReadAndSeek)> +fn read_cache( + mut f: R, + bom_key: &str, + version: u8, +) -> std::io::Result<(CachePolicy, Url, impl ReadAndSeek)> where R: Read + Seek, { - let data: CacheData = ciborium::de::from_reader(&mut f) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let mut buff_reader = BufReader::new(&mut f); + verify_cache_bom_and_version(&mut buff_reader, bom_key, version)?; - let start = f.stream_position()?; + let mut struct_size_buffer = [0; 8]; + buff_reader.read_exact(&mut struct_size_buffer).unwrap(); + + let data: CacheData = ciborium::de::from_reader(buff_reader).unwrap(); + let start = u64::from_le_bytes(struct_size_buffer); let end = f.seek(SeekFrom::End(0))?; + let mut body = SeekSlice::new(f, start, end)?; body.rewind()?; + Ok((data.policy, data.url, body)) } @@ -268,28 +280,44 @@ struct CacheData { url: Url, } -/// Fill the cache with the -fn fill_cache( - policy: &CachePolicy, - url: &Url, - mut body: R, - handle: FileLock, -) -> Result { - let mut cache_writer = handle.begin()?; - ciborium::ser::into_writer( - &CacheData { - policy: policy.clone(), - url: url.clone(), - }, - &mut cache_writer, - ) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - let body_start = cache_writer.stream_position()?; - std::io::copy(&mut body, &mut cache_writer)?; - drop(body); - let body_end = cache_writer.stream_position()?; - let cache_entry = cache_writer.commit()?.detach_unlocked(); - SeekSlice::new(cache_entry, body_start, body_end) +/// Write cache BOM and metadata and return it's current position after writing +/// BOM and metadata of cache is represented by: +/// [BOM]--[VERSION]--[SIZE_OF_HEADERS_STRUCT] +fn write_cache_bom_and_metadata( + writer: &mut W, + bom_key: &str, + version: u8, +) -> Result { + writer.write_all(bom_key.as_bytes())?; + writer.write_all(&[version])?; + writer.stream_position() +} + +/// Verify that cache BOM and metadata is the same and up-to-date +fn verify_cache_bom_and_version( + reader: &mut R, + bom_key: &str, + version: u8, +) -> Result<(), std::io::Error> { + // Read and verify the byte order mark and version + let mut bom_and_version = [0u8; 4]; + reader.read_exact(&mut bom_and_version)?; + + if &bom_and_version[0..3] != bom_key.as_bytes() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Invalid byte order mark", + )); + } + + if bom_and_version[3] != version { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Incombatible version", + )); + } + + Ok(()) } /// Fill the cache with the @@ -299,27 +327,58 @@ async fn fill_cache_async( mut body: impl Stream> + Send + Unpin, handle: FileLock, ) -> Result { - let mut cache_writer = handle.begin()?; + let cache_writer = handle.begin()?; + let mut buf_cache_writer = BufWriter::new(cache_writer); + + let bom_written_position = + write_cache_bom_and_metadata(&mut buf_cache_writer, CACHE_BOM, CURRENT_VERSION).unwrap(); + + // We need to save the struct size because we keep cache: + // headers_struct + body + // + // When reading using `BufReader` and serializing using `ciborium`, + // we don't know what was the final position of the struct and we + // can't slice and return only the body, because we do not know where to start. + // To overcome this, we record struct size at the start of cache, together with BOM + // which we later will use to seek at it and return the body. + // Example of stored cache: + // [BOM][VERSION][HEADERS_STRUCT_SIZE][HEADERS][BODY] + + let struct_size = [0; 8]; + buf_cache_writer.write_all(&struct_size).unwrap(); + ciborium::ser::into_writer( &CacheData { policy: policy.clone(), url: url.clone(), }, - &mut cache_writer, + &mut buf_cache_writer, ) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - let body_start = cache_writer.stream_position()?; + let body_start = buf_cache_writer.stream_position()?; + + buf_cache_writer + .seek(SeekFrom::Start(bom_written_position)) + .unwrap(); + + let body_le_bytes = body_start.to_le_bytes(); + buf_cache_writer + .write_all(body_le_bytes.as_slice()) + .unwrap(); + + buf_cache_writer.seek(SeekFrom::Start(body_start)).unwrap(); while let Some(bytes) = body.next().await { - cache_writer.write_all( + buf_cache_writer.write_all( bytes .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))? .as_ref(), )?; } - let body_end = cache_writer.stream_position()?; - let cache_entry = cache_writer.commit()?.detach_unlocked(); + let body_end = buf_cache_writer.stream_position()?; + let cache_entry = buf_cache_writer.into_inner()?.commit()?.detach_unlocked(); + SeekSlice::new(cache_entry, body_start, body_end) } @@ -370,3 +429,80 @@ fn body_to_streaming_or_local( .compat(), )) } + +#[cfg(test)] +mod tests { + use crate::index::{ + file_store::FileStore, + http::{write_cache_bom_and_metadata, CACHE_BOM, CURRENT_VERSION}, + }; + use http::{header::CACHE_CONTROL, HeaderMap, HeaderValue, Method}; + use reqwest::Client; + use reqwest_middleware::ClientWithMiddleware; + + use std::{fs, io::BufWriter, sync::Arc}; + use tempfile::TempDir; + + use super::{key_for_request, read_cache, CacheMode, Http}; + + fn get_http_client() -> (Arc, TempDir) { + let tempdir = tempfile::tempdir().unwrap(); + let client = ClientWithMiddleware::from(Client::new()); + + let tmp = tempdir.path().join("http"); + fs::create_dir_all(tmp).unwrap(); + + let http = Http::new( + client, + FileStore::new(&tempdir.path().join("http")).unwrap(), + ); + + (Arc::new(http), tempdir) + } + + #[tokio::test(flavor = "multi_thread")] + pub async fn test_cache_is_correct_written_and_read_when_requesting_pypi_boltons() { + let url = url::Url::parse("https://pypi.org/simple/boltons").unwrap(); + + let url_clone = url.clone(); + + let (client_arc, _tmpdir) = get_http_client(); + + let mut headers = HeaderMap::new(); + headers.insert(CACHE_CONTROL, HeaderValue::from_static("max-age=0")); + + // let's make a request and validate that the cache is saved + client_arc + .request(url, Method::GET, headers.clone(), CacheMode::Default) + .await + .unwrap(); + + let key = key_for_request(&url_clone, Method::GET, &headers); + { + let lock = client_arc.http_cache.lock(&key.as_slice()).await.unwrap(); + + let res = lock.reader().and_then(|reader| { + read_cache(reader.detach_unlocked(), CACHE_BOM, CURRENT_VERSION).ok() + }); + + assert!(res.is_some()); + } + + let lock = client_arc.http_cache.lock(&key.as_slice()).await.unwrap(); + + let mut buf_writer = BufWriter::new(lock.begin().unwrap()); + + // let's "corrupt" cache and change it's version metadata predenting that it's older or different cache + let stream_position = + write_cache_bom_and_metadata(&mut buf_writer, CACHE_BOM, CURRENT_VERSION + 1).unwrap(); + + assert!(stream_position > 0); + + let new_reader = buf_writer.into_inner().unwrap().commit().unwrap(); + + // read_cache should return Err because we expect that BOM differ + let read_again = read_cache(new_reader, CACHE_BOM, CURRENT_VERSION); + + assert!(read_again.is_err()); + } +}