feat: add optimisation for http cache #216

Merged · 22 commits · Feb 13, 2024

1 change: 1 addition & 0 deletions crates/rattler_installs_packages/src/index/file_store.rs
@@ -148,6 +148,7 @@ impl FileStore {
///
/// Internally the writer writes to a temporary file that is persisted to the final location to
/// ensure that the final path is never corrupted.
#[derive(Debug)]
pub struct LockedWriter<'a> {
path: &'a Path,
f: tempfile::NamedTempFile,
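A likely reason for the new `#[derive(Debug)]`: later in this PR the locked writer gets wrapped in a `std::io::BufWriter`, and calling `.unwrap()` on `BufWriter::into_inner()` (as the new test does) only compiles when the wrapped writer implements `Debug`. A minimal sketch of that constraint, with `Cursor` standing in for `LockedWriter`:

use std::io::{BufWriter, Cursor, Write};

fn main() {
    // Cursor stands in for LockedWriter: both implement Write + Debug.
    let mut writer = BufWriter::new(Cursor::new(Vec::new()));
    writer.write_all(b"cached bytes").unwrap();

    // into_inner() returns Result<_, IntoInnerError<BufWriter<_>>>; that error
    // type is only Debug when the wrapped writer is Debug, and unwrap() needs
    // a Debug error -- hence the derive on LockedWriter.
    let cursor = writer.into_inner().unwrap();
    assert_eq!(cursor.into_inner(), b"cached bytes".to_vec());
}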
216 changes: 176 additions & 40 deletions crates/rattler_installs_packages/src/index/http.rs
@@ -11,6 +11,8 @@ use reqwest::{header::HeaderMap, Method};
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
use std::io;
use std::io::BufReader;
use std::io::BufWriter;
use std::io::{Read, Seek, SeekFrom, Write};
use std::str::FromStr;
use std::sync::Arc;
@@ -19,6 +21,9 @@ use thiserror::Error;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use url::Url;

const CURRENT_VERSION: u8 = 1;
const CACHE_BOM: &str = "RIP";

// Attached to HTTP responses, to make testing easier
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CacheStatus {
@@ -104,10 +109,9 @@ impl Http {
let key = key_for_request(&url, method, &headers);
let lock = self.http_cache.lock(&key.as_slice()).await?;

if let Some((old_policy, final_url, old_body)) = lock
.reader()
.and_then(|reader| read_cache(reader.detach_unlocked()).ok())
{
if let Some((old_policy, final_url, old_body)) = lock.reader().and_then(|reader| {
read_cache(reader.detach_unlocked(), CACHE_BOM, CURRENT_VERSION).ok()
}) {
match old_policy.before_request(&request, SystemTime::now()) {
BeforeRequest::Fresh(parts) => {
tracing::debug!(url=%url, "is fresh");
@@ -138,12 +142,11 @@ impl Http {

// Determine what to do based on the response headers.
match old_policy.after_response(&request, &response, SystemTime::now()) {
AfterResponse::NotModified(new_policy, new_parts) => {
AfterResponse::NotModified(_, new_parts) => {
tracing::debug!(url=%url, "stale, but not modified");
let new_body = fill_cache(&new_policy, &final_url, old_body, lock)?;
Ok(make_response(
new_parts,
StreamingOrLocal::Local(Box::new(new_body)),
StreamingOrLocal::Local(Box::new(old_body)),
CacheStatus::StaleButValidated,
final_url,
))
@@ -189,7 +192,6 @@ impl Http {

let new_policy = CachePolicy::new(&request, &response);
let (parts, body) = response.into_parts();

let new_body = if new_policy.is_storable() {
let new_body = fill_cache_async(&new_policy, &final_url, body, lock).await?;
StreamingOrLocal::Local(Box::new(new_body))
@@ -248,17 +250,27 @@ fn key_for_request(url: &Url, method: Method, headers: &HeaderMap) -> Vec<u8> {
}

/// Read an HTTP cached value from a readable stream.
fn read_cache<R>(mut f: R) -> std::io::Result<(CachePolicy, Url, impl ReadAndSeek)>
fn read_cache<R>(
mut f: R,
bom_key: &str,
version: u8,
) -> std::io::Result<(CachePolicy, Url, impl ReadAndSeek)>
where
R: Read + Seek,
{
let data: CacheData = ciborium::de::from_reader(&mut f)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let mut buff_reader = BufReader::new(&mut f);
verify_cache_bom_and_version(&mut buff_reader, bom_key, version)?;

let start = f.stream_position()?;
let mut struct_size_buffer = [0; 8];
buff_reader.read_exact(&mut struct_size_buffer).unwrap();

let data: CacheData = ciborium::de::from_reader(buff_reader).unwrap();
let start = u64::from_le_bytes(struct_size_buffer);
let end = f.seek(SeekFrom::End(0))?;

let mut body = SeekSlice::new(f, start, end)?;
body.rewind()?;

Ok((data.policy, data.url, body))
}

@@ -268,28 +280,44 @@ struct CacheData {
url: Url,
}

/// Fill the cache with the
fn fill_cache<R: Read>(
policy: &CachePolicy,
url: &Url,
mut body: R,
handle: FileLock,
) -> Result<impl Read + Seek, std::io::Error> {
let mut cache_writer = handle.begin()?;
ciborium::ser::into_writer(
&CacheData {
policy: policy.clone(),
url: url.clone(),
},
&mut cache_writer,
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let body_start = cache_writer.stream_position()?;
std::io::copy(&mut body, &mut cache_writer)?;
drop(body);
let body_end = cache_writer.stream_position()?;
let cache_entry = cache_writer.commit()?.detach_unlocked();
SeekSlice::new(cache_entry, body_start, body_end)
/// Write the cache BOM and version and return the writer's position after writing them.
/// The start of a cache entry is laid out as:
/// [BOM]--[VERSION]--[BODY_START_OFFSET]
fn write_cache_bom_and_metadata<W: Write + Seek>(
writer: &mut W,
bom_key: &str,
version: u8,
) -> Result<u64, std::io::Error> {
writer.write_all(bom_key.as_bytes())?;
writer.write_all(&[version])?;
writer.stream_position()
}

/// Verify that the cache BOM and version match the expected values
fn verify_cache_bom_and_version<R: Read + Seek>(
reader: &mut R,
bom_key: &str,
version: u8,
) -> Result<(), std::io::Error> {
// Read and verify the byte order mark and version
let mut bom_and_version = [0u8; 4];
reader.read_exact(&mut bom_and_version)?;

if &bom_and_version[0..3] != bom_key.as_bytes() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Invalid byte order mark",
));
}

if bom_and_version[3] != version {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Incombatible version",
));
}

Ok(())
}
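A minimal, self-contained sketch of the prefix round trip these two helpers implement; the constants mirror `CACHE_BOM` and `CURRENT_VERSION`, and the helper bodies are restated here so the example runs on its own:

use std::io::{Cursor, Error, ErrorKind, Read, Seek, Write};

const CACHE_BOM: &str = "RIP";
const CURRENT_VERSION: u8 = 1;

// Mirrors write_cache_bom_and_metadata: a 3-byte BOM followed by a 1-byte version.
fn write_prefix<W: Write + Seek>(w: &mut W) -> std::io::Result<u64> {
    w.write_all(CACHE_BOM.as_bytes())?;
    w.write_all(&[CURRENT_VERSION])?;
    w.stream_position()
}

// Mirrors verify_cache_bom_and_version: reject entries with a different BOM or version.
fn verify_prefix<R: Read>(r: &mut R) -> std::io::Result<()> {
    let mut prefix = [0u8; 4];
    r.read_exact(&mut prefix)?;
    if &prefix[0..3] != CACHE_BOM.as_bytes() {
        return Err(Error::new(ErrorKind::InvalidData, "Invalid byte order mark"));
    }
    if prefix[3] != CURRENT_VERSION {
        return Err(Error::new(ErrorKind::InvalidData, "Incompatible version"));
    }
    Ok(())
}

fn main() -> std::io::Result<()> {
    let mut entry = Cursor::new(Vec::new());
    assert_eq!(write_prefix(&mut entry)?, 4); // 3 BOM bytes + 1 version byte
    entry.set_position(0);
    verify_prefix(&mut entry)?; // a fresh entry verifies fine

    // An entry written with a different version is rejected.
    let mut stale = Cursor::new(vec![b'R', b'I', b'P', CURRENT_VERSION + 1]);
    assert!(verify_prefix(&mut stale).is_err());
    Ok(())
}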

/// Fill the cache with the response body and return a reader over the cached body.
@@ -299,27 +327,58 @@ async fn fill_cache_async(
mut body: impl Stream<Item = reqwest::Result<Bytes>> + Send + Unpin,
handle: FileLock,
) -> Result<impl Read + Seek, std::io::Error> {
let mut cache_writer = handle.begin()?;
let cache_writer = handle.begin()?;
let mut buf_cache_writer = BufWriter::new(cache_writer);

let bom_written_position =
write_cache_bom_and_metadata(&mut buf_cache_writer, CACHE_BOM, CURRENT_VERSION).unwrap();

// We need to record where the body starts, because a cache entry stores the
// serialized headers struct followed by the body.
//
// When deserializing with `ciborium` through a `BufReader` we cannot tell at
// which position the struct ended, so we could not slice out just the body
// on the read side. To overcome this, we reserve an 8-byte field right after
// the BOM and version and later fill it with the body start offset, which
// the reader uses to seek directly to the body.
// Layout of a stored cache entry:
// [BOM][VERSION][BODY_START_OFFSET][HEADERS][BODY]

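// Reserve 8 bytes; they are overwritten below with the little-endian body start offset.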
let struct_size = [0; 8];
buf_cache_writer.write_all(&struct_size).unwrap();

ciborium::ser::into_writer(
&CacheData {
policy: policy.clone(),
url: url.clone(),
},
&mut cache_writer,
&mut buf_cache_writer,
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let body_start = cache_writer.stream_position()?;
let body_start = buf_cache_writer.stream_position()?;

buf_cache_writer
.seek(SeekFrom::Start(bom_written_position))
.unwrap();

let body_le_bytes = body_start.to_le_bytes();
buf_cache_writer
.write_all(body_le_bytes.as_slice())
.unwrap();

buf_cache_writer.seek(SeekFrom::Start(body_start)).unwrap();

while let Some(bytes) = body.next().await {
cache_writer.write_all(
buf_cache_writer.write_all(
bytes
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
.as_ref(),
)?;
}

let body_end = cache_writer.stream_position()?;
let cache_entry = cache_writer.commit()?.detach_unlocked();
let body_end = buf_cache_writer.stream_position()?;
let cache_entry = buf_cache_writer.into_inner()?.commit()?.detach_unlocked();

SeekSlice::new(cache_entry, body_start, body_end)
}
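The offset backfill above is interleaved with the async body streaming, which makes it harder to follow; below is a stripped-down synchronous sketch of the same pattern, with literal byte strings standing in for the ciborium-encoded headers and the HTTP body:

use std::io::{Cursor, Read, Seek, SeekFrom, Write};

fn main() -> std::io::Result<()> {
    let mut w = Cursor::new(Vec::new());

    // Fixed prefix: BOM + version, as written by write_cache_bom_and_metadata.
    w.write_all(b"RIP")?;
    w.write_all(&[1u8])?;
    let offset_field = w.stream_position()?; // where the 8-byte offset field lives

    // Reserve the offset field, then write the variable-length header blob.
    w.write_all(&[0u8; 8])?;
    w.write_all(b"{headers}")?; // stand-in for the serialized CacheData

    // Only now do we know where the body starts; backfill the reserved field.
    let body_start = w.stream_position()?;
    w.seek(SeekFrom::Start(offset_field))?;
    w.write_all(&body_start.to_le_bytes())?;
    w.seek(SeekFrom::Start(body_start))?;
    w.write_all(b"<body bytes>")?;

    // Read side: skip the prefix, read the offset, and jump straight to the body.
    let mut r = Cursor::new(w.into_inner());
    r.seek(SeekFrom::Start(4))?;
    let mut offset = [0u8; 8];
    r.read_exact(&mut offset)?;
    r.seek(SeekFrom::Start(u64::from_le_bytes(offset)))?;
    let mut body = String::new();
    r.read_to_string(&mut body)?;
    assert_eq!(body, "<body bytes>");
    Ok(())
}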

@@ -370,3 +429,80 @@ fn body_to_streaming_or_local(
.compat(),
))
}

#[cfg(test)]
mod tests {
use crate::index::{
file_store::FileStore,
http::{write_cache_bom_and_metadata, CACHE_BOM, CURRENT_VERSION},
};
use http::{header::CACHE_CONTROL, HeaderMap, HeaderValue, Method};
use reqwest::Client;
use reqwest_middleware::ClientWithMiddleware;

use std::{fs, io::BufWriter, sync::Arc};
use tempfile::TempDir;

use super::{key_for_request, read_cache, CacheMode, Http};

fn get_http_client() -> (Arc<Http>, TempDir) {
let tempdir = tempfile::tempdir().unwrap();
let client = ClientWithMiddleware::from(Client::new());

let tmp = tempdir.path().join("http");
fs::create_dir_all(tmp).unwrap();

let http = Http::new(
client,
FileStore::new(&tempdir.path().join("http")).unwrap(),
);

(Arc::new(http), tempdir)
}

#[tokio::test(flavor = "multi_thread")]
pub async fn test_cache_is_correct_written_and_read_when_requesting_pypi_boltons() {
let url = url::Url::parse("https://pypi.org/simple/boltons").unwrap();

let url_clone = url.clone();

let (client_arc, _tmpdir) = get_http_client();

let mut headers = HeaderMap::new();
headers.insert(CACHE_CONTROL, HeaderValue::from_static("max-age=0"));

// let's make a request and validate that the cache is saved
client_arc
.request(url, Method::GET, headers.clone(), CacheMode::Default)
.await
.unwrap();

let key = key_for_request(&url_clone, Method::GET, &headers);
{
let lock = client_arc.http_cache.lock(&key.as_slice()).await.unwrap();

let res = lock.reader().and_then(|reader| {
read_cache(reader.detach_unlocked(), CACHE_BOM, CURRENT_VERSION).ok()
});

assert!(res.is_some());
}

let lock = client_arc.http_cache.lock(&key.as_slice()).await.unwrap();

let mut buf_writer = BufWriter::new(lock.begin().unwrap());

// let's "corrupt" cache and change it's version metadata predenting that it's older or different cache
let stream_position =
write_cache_bom_and_metadata(&mut buf_writer, CACHE_BOM, CURRENT_VERSION + 1).unwrap();

assert!(stream_position > 0);

let new_reader = buf_writer.into_inner().unwrap().commit().unwrap();

// read_cache should return Err because the stored version no longer matches CURRENT_VERSION
let read_again = read_cache(new_reader, CACHE_BOM, CURRENT_VERSION);

assert!(read_again.is_err());
}
}