Skip to content

Commit

Permalink
Add BufWriter for Adapative Put / Multipart Upload (apache#5431)
Browse files Browse the repository at this point in the history
* Add BufWriter

* Review feedback
  • Loading branch information
tustvold authored Feb 27, 2024
1 parent 37cf8a6 commit ef5c45c
Showing 1 changed file with 161 additions and 2 deletions.
163 changes: 161 additions & 2 deletions object_store/src/buffered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Utilities for performing tokio-style buffered IO
use crate::path::Path;
use crate::{ObjectMeta, ObjectStore};
use crate::{MultipartId, ObjectMeta, ObjectStore};
use bytes::Bytes;
use futures::future::{BoxFuture, FutureExt};
use futures::ready;
Expand All @@ -27,7 +27,7 @@ use std::io::{Error, ErrorKind, SeekFrom};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, ReadBuf};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, AsyncWriteExt, ReadBuf};

/// The default buffer size used by [`BufReader`]
pub const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024;
Expand Down Expand Up @@ -205,6 +205,138 @@ impl AsyncBufRead for BufReader {
}
}

/// An async buffered writer compatible with the tokio IO traits
///
/// Up to `capacity` bytes will be buffered in memory, and flushed on shutdown
/// using [`ObjectStore::put`]. If `capacity` is exceeded, data will instead be
/// streamed using [`ObjectStore::put_multipart`]
pub struct BufWriter {
capacity: usize,
state: BufWriterState,
multipart_id: Option<MultipartId>,
store: Arc<dyn ObjectStore>,
}

impl std::fmt::Debug for BufWriter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BufWriter")
.field("capacity", &self.capacity)
.field("multipart_id", &self.multipart_id)
.finish()
}
}

type MultipartResult = (MultipartId, Box<dyn AsyncWrite + Send + Unpin>);

enum BufWriterState {
/// Buffer up to capacity bytes
Buffer(Path, Vec<u8>),
/// [`ObjectStore::put_multipart`]
Prepare(BoxFuture<'static, std::io::Result<MultipartResult>>),
/// Write to a multipart upload
Write(Box<dyn AsyncWrite + Send + Unpin>),
/// [`ObjectStore::put`]
Put(BoxFuture<'static, std::io::Result<()>>),
}

impl BufWriter {
/// Create a new [`BufWriter`] from the provided [`ObjectStore`] and [`Path`]
pub fn new(store: Arc<dyn ObjectStore>, path: Path) -> Self {
Self::with_capacity(store, path, 10 * 1024 * 1024)
}

/// Create a new [`BufWriter`] from the provided [`ObjectStore`], [`Path`] and `capacity`
pub fn with_capacity(store: Arc<dyn ObjectStore>, path: Path, capacity: usize) -> Self {
Self {
capacity,
store,
state: BufWriterState::Buffer(path, Vec::new()),
multipart_id: None,
}
}

/// Returns the [`MultipartId`] if multipart upload
pub fn multipart_id(&self) -> Option<&MultipartId> {
self.multipart_id.as_ref()
}
}

impl AsyncWrite for BufWriter {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
let cap = self.capacity;
loop {
return match &mut self.state {
BufWriterState::Write(write) => Pin::new(write).poll_write(cx, buf),
BufWriterState::Put(_) => panic!("Already shut down"),
BufWriterState::Prepare(f) => {
let (id, w) = ready!(f.poll_unpin(cx)?);
self.state = BufWriterState::Write(w);
self.multipart_id = Some(id);
continue;
}
BufWriterState::Buffer(path, b) => {
if b.len().saturating_add(buf.len()) >= cap {
let buffer = std::mem::take(b);
let path = std::mem::take(path);
let store = Arc::clone(&self.store);
self.state = BufWriterState::Prepare(Box::pin(async move {
let (id, mut writer) = store.put_multipart(&path).await?;
writer.write_all(&buffer).await?;
Ok((id, writer))
}));
continue;
}
b.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
};
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
loop {
return match &mut self.state {
BufWriterState::Buffer(_, _) => Poll::Ready(Ok(())),
BufWriterState::Write(write) => Pin::new(write).poll_flush(cx),
BufWriterState::Put(_) => panic!("Already shut down"),
BufWriterState::Prepare(f) => {
let (id, w) = ready!(f.poll_unpin(cx)?);
self.state = BufWriterState::Write(w);
self.multipart_id = Some(id);
continue;
}
};
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
loop {
match &mut self.state {
BufWriterState::Prepare(f) => {
let (id, w) = ready!(f.poll_unpin(cx)?);
self.state = BufWriterState::Write(w);
self.multipart_id = Some(id);
}
BufWriterState::Buffer(p, b) => {
let buf = std::mem::take(b);
let path = std::mem::take(p);
let store = Arc::clone(&self.store);
self.state = BufWriterState::Put(Box::pin(async move {
store.put(&path, buf.into()).await?;
Ok(())
}));
}
BufWriterState::Put(f) => return f.poll_unpin(cx),
BufWriterState::Write(w) => return Pin::new(w).poll_shutdown(cx),
}
}
}
}

/// Port of standardised function as requires Rust 1.66
///
/// <https://github.com/rust-lang/rust/pull/87601/files#diff-b9390ee807a1dae3c3128dce36df56748ad8d23c6e361c0ebba4d744bf6efdb9R1533>
Expand Down Expand Up @@ -300,4 +432,31 @@ mod tests {
assert!(buffer.is_empty());
}
}

#[tokio::test]
async fn test_buf_writer() {
let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
let path = Path::from("file.txt");

// Test put
let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30);
writer.write_all(&[0; 20]).await.unwrap();
writer.flush().await.unwrap();
writer.write_all(&[0; 5]).await.unwrap();
assert!(writer.multipart_id().is_none());
writer.shutdown().await.unwrap();
assert!(writer.multipart_id().is_none());
assert_eq!(store.head(&path).await.unwrap().size, 25);

// Test multipart
let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30);
writer.write_all(&[0; 20]).await.unwrap();
writer.flush().await.unwrap();
writer.write_all(&[0; 20]).await.unwrap();
assert!(writer.multipart_id().is_some());
writer.shutdown().await.unwrap();
assert!(writer.multipart_id().is_some());

assert_eq!(store.head(&path).await.unwrap().size, 40);
}
}

0 comments on commit ef5c45c

Please sign in to comment.