Skip to content

Commit

Permalink
download objects operation (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd authored Aug 26, 2024
1 parent d74f1dd commit 3684958
Show file tree
Hide file tree
Showing 15 changed files with 1,618 additions and 53 deletions.
3 changes: 3 additions & 0 deletions aws-s3-transfer-manager/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ aws-types = "1.3.3"
bytes = "1"
# FIXME - upgrade to hyper 1.x
hyper = { version = "0.14.29", features = ["client"] }
path-clean = "1.0.1"
tokio = { version = "1.38.0", features = ["rt-multi-thread", "io-util", "sync", "fs", "macros"] }
tracing = "0.1"

Expand All @@ -29,6 +30,8 @@ clap = { version = "4.5.7", default-features = false, features = ["derive", "std
console-subscriber = "0.3.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tempfile = "3.10.1"
fastrand = "2.1.0"
walkdir = "2"

[target.'cfg(not(target_env = "msvc"))'.dev-dependencies]
jemallocator = "0.5.4"
62 changes: 50 additions & 12 deletions aws-s3-transfer-manager/examples/cp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use std::{mem, time};
use aws_s3_transfer_manager::io::InputStream;
use aws_s3_transfer_manager::operation::download::body::Body;
use aws_s3_transfer_manager::types::{ConcurrencySetting, PartSize};
use aws_sdk_s3::config::StalledStreamProtectionConfig;
use aws_sdk_s3::error::DisplayErrorContext;
use aws_types::SdkConfig;
use bytes::Buf;
use clap::{CommandFactory, Parser};
Expand Down Expand Up @@ -102,11 +104,8 @@ struct S3Uri(String);
impl S3Uri {
/// Split the URI into it's component parts '(bucket, key)'
fn parts(&self) -> (&str, &str) {
self.0
.strip_prefix("s3://")
.expect("valid s3 uri prefix")
.split_once('/')
.expect("invalid s3 uri, missing '/' between bucket and key")
let bucket = self.0.strip_prefix("s3://").expect("valid s3 uri prefix");
bucket.split_once('/').unwrap_or((bucket, ""))
}
}

Expand All @@ -116,8 +115,44 @@ fn invalid_arg(message: &str) -> ! {
.exit()
}

async fn do_recursive_download(
args: Args,
tm: aws_s3_transfer_manager::Client,
) -> Result<(), BoxError> {
let (bucket, key_prefix) = args.source.expect_s3().parts();
let dest = args.dest.expect_local();
fs::create_dir_all(dest).await?;

let start = time::Instant::now();
let handle = tm
.download_objects()
.bucket(bucket)
.key_prefix(key_prefix)
.destination(dest)
.send()
.await?;

let output = handle.join().await?;
tracing::info!("download output: {output:?}");

let elapsed = start.elapsed();
let transfer_size_bytes = output.total_bytes_transferred();
let transfer_size_megabytes = transfer_size_bytes as f64 / ONE_MEGABYTE as f64;
let transfer_size_megabits = transfer_size_megabytes * 8f64;

println!(
"downloaded {} objects totalling {transfer_size_bytes} bytes ({transfer_size_megabytes} MB) in {elapsed:?}; Mb/s: {}",
output.objects_downloaded(),
transfer_size_megabits / elapsed.as_secs_f64(),
);
Ok(())
}

async fn do_download(args: Args) -> Result<(), BoxError> {
let config = aws_config::from_env().load().await;
let config = aws_config::from_env()
.stalled_stream_protection(StalledStreamProtectionConfig::disabled())
.load()
.await;
warmup(&config).await?;

let s3_client = aws_sdk_s3::Client::new(&config);
Expand All @@ -130,12 +165,11 @@ async fn do_download(args: Args) -> Result<(), BoxError> {

let tm = aws_s3_transfer_manager::Client::new(tm_config);

let (bucket, key) = args.source.expect_s3().parts();

if args.recursive {
todo!("implement recursive download")
return do_recursive_download(args, tm).await;
}

let (bucket, key) = args.source.expect_s3().parts();
let dest = fs::File::create(args.dest.expect_local()).await?;
println!("dest file opened, starting download");

Expand Down Expand Up @@ -228,11 +262,15 @@ async fn main() -> Result<(), BoxError> {
}

use TransferUri::*;
match (&args.source, &args.dest) {
(Local(_), S3(_)) => do_upload(args).await?,
let result = match (&args.source, &args.dest) {
(Local(_), S3(_)) => do_upload(args).await,
(Local(_), Local(_)) => invalid_arg("local to local transfer not supported"),
(S3(_), Local(_)) => do_download(args).await?,
(S3(_), Local(_)) => do_download(args).await,
(S3(_), S3(_)) => invalid_arg("s3 to s3 transfer not supported"),
};

if let Err(ref err) = result {
tracing::error!("transfer failed: {}", DisplayErrorContext(err.as_ref()));
}

Ok(())
Expand Down
27 changes: 27 additions & 0 deletions aws-s3-transfer-manager/src/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
* SPDX-License-Identifier: Apache-2.0
*/

use std::sync::Arc;

/// Types for single object upload operation
pub mod upload;

Expand All @@ -14,3 +16,28 @@ pub mod download_objects;

/// Types for multiple object upload operation
pub mod upload_objects;

/// Container for maintaining context required to carry out a single operation/transfer.
///
/// `State` is whatever additional operation specific state is required for the operation.
#[derive(Debug)]
pub(crate) struct TransferContext<State> {
handle: Arc<crate::client::Handle>,
state: Arc<State>,
}

impl<State> TransferContext<State> {
/// The S3 client to use for SDK operations
pub(crate) fn client(&self) -> &aws_sdk_s3::Client {
self.handle.config.client()
}
}

impl<State> Clone for TransferContext<State> {
fn clone(&self) -> Self {
Self {
handle: self.handle.clone(),
state: self.state.clone(),
}
}
}
2 changes: 1 addition & 1 deletion aws-s3-transfer-manager/src/operation/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl Download {
// have the correct metadata w.r.t. content-length and maybe others for the whole object.
object_meta: discovery.meta,
body: Body::new(comp_rx),
_tasks: tasks,
tasks,
};

Ok(handle)
Expand Down
10 changes: 9 additions & 1 deletion aws-s3-transfer-manager/src/operation/download/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct DownloadHandle {
pub body: Body,

/// All child tasks spawned for this download
pub(crate) _tasks: task::JoinSet<()>,
pub(crate) tasks: task::JoinSet<()>,
}

impl DownloadHandle {
Expand All @@ -30,4 +30,12 @@ impl DownloadHandle {
pub fn body(&self) -> &Body {
&self.body
}

/// Consume the handle and wait for download transfer to complete
pub async fn join(mut self) -> Result<(), crate::error::Error> {
while let Some(join_result) = self.tasks.join_next().await {
join_result?;
}
Ok(())
}
}
1 change: 1 addition & 0 deletions aws-s3-transfer-manager/src/operation/download/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ pub(super) async fn distribute_work(
chunk_req
);
let chunk_size = chunk_req.size();
// FIXME - restructure like download_objects for tasks to yield errors that can be propagated on join()
tx.send(chunk_req).await.expect("channel open");

seq += 1;
Expand Down
75 changes: 71 additions & 4 deletions aws-s3-transfer-manager/src/operation/download_objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@ pub use output::{DownloadObjectsOutput, DownloadObjectsOutputBuilder};
mod handle;
pub use handle::DownloadObjectsHandle;

use std::sync::Arc;
mod list_objects;
mod worker;

use std::path::Path;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex};
use tokio::{fs, task::JoinSet};
use tracing::Instrument;

use crate::{error, types::FailedDownloadTransfer};

use super::TransferContext;

/// Operation struct for downloading multiple objects from Amazon S3
#[derive(Clone, Default, Debug)]
Expand All @@ -25,9 +36,65 @@ pub(crate) struct DownloadObjects;
impl DownloadObjects {
/// Execute a single `DownloadObjects` transfer operation
pub(crate) async fn orchestrate(
_handle: Arc<crate::client::Handle>,
_input: crate::operation::download_objects::DownloadObjectsInput,
handle: Arc<crate::client::Handle>,
input: crate::operation::download_objects::DownloadObjectsInput,
) -> Result<DownloadObjectsHandle, crate::error::Error> {
unimplemented!()
// validate existence of destination and return error if it's not a directory
let destination = input.destination().expect("destination set");
validate_destination(destination).await?;

let concurrency = handle.num_workers();
let ctx = DownloadObjectsContext::new(handle.clone(), input);

// spawn all work into the same JoinSet such that when the set is dropped all tasks are cancelled.
let mut tasks = JoinSet::new();
let (work_tx, work_rx) = async_channel::bounded(concurrency);

// spawn worker to discover/distribute work
tasks.spawn(worker::discover_objects(ctx.clone(), work_tx));

for i in 0..concurrency {
let worker = worker::download_objects(ctx.clone(), work_rx.clone())
.instrument(tracing::debug_span!("object-downloader", worker = i));
tasks.spawn(worker);
}

let handle = DownloadObjectsHandle { tasks, ctx };
Ok(handle)
}
}

async fn validate_destination(path: &Path) -> Result<(), error::Error> {
let meta = fs::metadata(path).await?;

if !meta.is_dir() {
return Err(error::invalid_input(format!(
"destination is not a directory: {path:?}"
)));
}

Ok(())
}

/// DownloadObjects operation specific state
#[derive(Debug)]
pub(crate) struct DownloadObjectsState {
input: DownloadObjectsInput,
failed_downloads: Mutex<Option<Vec<FailedDownloadTransfer>>>,
successful_downloads: AtomicU64,
total_bytes_transferred: AtomicU64,
}

type DownloadObjectsContext = TransferContext<DownloadObjectsState>;

impl DownloadObjectsContext {
fn new(handle: Arc<crate::client::Handle>, input: DownloadObjectsInput) -> Self {
let state = Arc::new(DownloadObjectsState {
input,
failed_downloads: Mutex::new(None),
successful_downloads: AtomicU64::default(),
total_bytes_transferred: AtomicU64::default(),
});
TransferContext { handle, state }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ impl DownloadObjectsFluentBuilder {
}

/// Filter unwanted S3 objects from being downloaded as part of the transfer.
pub fn filter(mut self, input: impl Fn(&aws_sdk_s3::types::Object) -> bool + 'static) -> Self {
pub fn filter(
mut self,
input: impl Fn(&aws_sdk_s3::types::Object) -> bool + Send + Sync + 'static,
) -> Self {
self.inner = self.inner.filter(input);
self
}
Expand Down
31 changes: 27 additions & 4 deletions aws-s3-transfer-manager/src/operation/download_objects/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,44 @@
* SPDX-License-Identifier: Apache-2.0
*/

use std::sync::atomic::Ordering;

use tokio::task;

use super::DownloadObjectsOutput;
use super::{DownloadObjectsContext, DownloadObjectsOutput};

/// Handle for `DownloadObjects` transfer operation
#[derive(Debug)]
#[non_exhaustive]
pub struct DownloadObjectsHandle {
/// All child tasks spawned for this download
pub(crate) _tasks: task::JoinSet<()>,
pub(crate) tasks: task::JoinSet<Result<(), crate::error::Error>>,
/// The context used to drive an upload to completion
pub(crate) ctx: DownloadObjectsContext,
}

impl DownloadObjectsHandle {
/// Consume the handle and wait for download transfer to complete
pub async fn join(self) -> Result<DownloadObjectsOutput, crate::error::Error> {
unimplemented!()
pub async fn join(mut self) -> Result<DownloadObjectsOutput, crate::error::Error> {
// join all tasks
while let Some(join_result) = self.tasks.join_next().await {
join_result??;
}

let failed_downloads = self.ctx.state.failed_downloads.lock().unwrap().take();
let successful_downloads = self.ctx.state.successful_downloads.load(Ordering::SeqCst);
let total_bytes_transferred = self
.ctx
.state
.total_bytes_transferred
.load(Ordering::SeqCst);

let output = DownloadObjectsOutput::builder()
.objects_downloaded(successful_downloads)
.set_failed_transfers(failed_downloads)
.total_bytes_transferred(total_bytes_transferred)
.build();

Ok(output)
}
}
Loading

0 comments on commit 3684958

Please sign in to comment.