
download objects operation #45

Merged · 16 commits · Aug 26, 2024
3 changes: 3 additions & 0 deletions aws-s3-transfer-manager/Cargo.toml
@@ -19,6 +19,7 @@ aws-types = "1.3.3"
bytes = "1"
# FIXME - upgrade to hyper 1.x
hyper = { version = "0.14.29", features = ["client"] }
path-clean = "1.0.1"
tokio = { version = "1.38.0", features = ["rt-multi-thread", "io-util", "sync", "fs", "macros"] }
tracing = "0.1"

@@ -29,6 +30,8 @@ clap = { version = "4.5.7", default-features = false, features = ["derive", "std
console-subscriber = "0.3.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tempfile = "3.10.1"
fastrand = "2.1.0"
walkdir = "2"

[target.'cfg(not(target_env = "msvc"))'.dev-dependencies]
jemallocator = "0.5.4"
62 changes: 50 additions & 12 deletions aws-s3-transfer-manager/examples/cp.rs
@@ -10,6 +10,8 @@ use std::{mem, time};
use aws_s3_transfer_manager::io::InputStream;
use aws_s3_transfer_manager::operation::download::body::Body;
use aws_s3_transfer_manager::types::{ConcurrencySetting, PartSize};
use aws_sdk_s3::config::StalledStreamProtectionConfig;
use aws_sdk_s3::error::DisplayErrorContext;
use aws_types::SdkConfig;
use bytes::Buf;
use clap::{CommandFactory, Parser};
@@ -102,11 +104,8 @@ struct S3Uri(String);
impl S3Uri {
/// Split the URI into its component parts '(bucket, key)'
fn parts(&self) -> (&str, &str) {
self.0
.strip_prefix("s3://")
.expect("valid s3 uri prefix")
.split_once('/')
.expect("invalid s3 uri, missing '/' between bucket and key")
let bucket = self.0.strip_prefix("s3://").expect("valid s3 uri prefix");
bucket.split_once('/').unwrap_or((bucket, ""))
}
}
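
For reference, a minimal standalone sketch of the updated parts() behavior, copied out of context: a URI with no key portion now yields an empty key instead of panicking.

fn parts(uri: &str) -> (&str, &str) {
    // Same logic as the updated example: strip the scheme, then split on the
    // first '/'; a URI with no key portion yields an empty key.
    let rest = uri.strip_prefix("s3://").expect("valid s3 uri prefix");
    rest.split_once('/').unwrap_or((rest, ""))
}

fn main() {
    assert_eq!(parts("s3://my-bucket/prefix/key.txt"), ("my-bucket", "prefix/key.txt"));
    assert_eq!(parts("s3://my-bucket"), ("my-bucket", ""));
}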

@@ -116,8 +115,44 @@ fn invalid_arg(message: &str) -> ! {
.exit()
}

async fn do_recursive_download(
args: Args,
tm: aws_s3_transfer_manager::Client,
) -> Result<(), BoxError> {
let (bucket, key_prefix) = args.source.expect_s3().parts();
let dest = args.dest.expect_local();
fs::create_dir_all(dest).await?;

let start = time::Instant::now();
let handle = tm
.download_objects()
.bucket(bucket)
.key_prefix(key_prefix)
.destination(dest)
.send()
.await?;

let output = handle.join().await?;
tracing::info!("download output: {output:?}");

let elapsed = start.elapsed();
let transfer_size_bytes = output.total_bytes_transferred();
let transfer_size_megabytes = transfer_size_bytes as f64 / ONE_MEGABYTE as f64;
let transfer_size_megabits = transfer_size_megabytes * 8f64;

println!(
"downloaded {} objects totalling {transfer_size_bytes} bytes ({transfer_size_megabytes} MB) in {elapsed:?}; Mb/s: {}",
output.objects_downloaded(),
transfer_size_megabits / elapsed.as_secs_f64(),
);
Ok(())
}

async fn do_download(args: Args) -> Result<(), BoxError> {
let config = aws_config::from_env().load().await;
let config = aws_config::from_env()
.stalled_stream_protection(StalledStreamProtectionConfig::disabled())
.load()
.await;
warmup(&config).await?;

let s3_client = aws_sdk_s3::Client::new(&config);
@@ -130,12 +165,11 @@ async fn do_download(args: Args) -> Result<(), BoxError> {

let tm = aws_s3_transfer_manager::Client::new(tm_config);

let (bucket, key) = args.source.expect_s3().parts();

if args.recursive {
todo!("implement recursive download")
return do_recursive_download(args, tm).await;
}

let (bucket, key) = args.source.expect_s3().parts();
let dest = fs::File::create(args.dest.expect_local()).await?;
println!("dest file opened, starting download");

@@ -228,11 +262,15 @@ async fn main() -> Result<(), BoxError> {
}

use TransferUri::*;
match (&args.source, &args.dest) {
(Local(_), S3(_)) => do_upload(args).await?,
let result = match (&args.source, &args.dest) {
(Local(_), S3(_)) => do_upload(args).await,
(Local(_), Local(_)) => invalid_arg("local to local transfer not supported"),
(S3(_), Local(_)) => do_download(args).await?,
(S3(_), Local(_)) => do_download(args).await,
(S3(_), S3(_)) => invalid_arg("s3 to s3 transfer not supported"),
};

if let Err(ref err) = result {
tracing::error!("transfer failed: {}", DisplayErrorContext(err.as_ref()));
}

Ok(())
27 changes: 27 additions & 0 deletions aws-s3-transfer-manager/src/operation.rs
@@ -3,6 +3,8 @@
* SPDX-License-Identifier: Apache-2.0
*/

use std::sync::Arc;

/// Types for single object upload operation
pub mod upload;

@@ -11,3 +13,28 @@ pub mod download;

/// Types for multiple object download operation
pub mod download_objects;

/// Container for maintaining context required to carry out a single operation/transfer.
///
/// `State` is any additional operation-specific state required for the operation.
#[derive(Debug)]
pub(crate) struct TransferContext<State> {
handle: Arc<crate::client::Handle>,
state: Arc<State>,
}

impl<State> TransferContext<State> {
/// The S3 client to use for SDK operations
pub(crate) fn client(&self) -> &aws_sdk_s3::Client {
self.handle.config.client()
}
}

impl<State> Clone for TransferContext<State> {
fn clone(&self) -> Self {
Self {
handle: self.handle.clone(),
state: self.state.clone(),
}
}
}
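
As a standalone illustration of this pattern, an operation defines its own state type and the generic container wraps it; the types below are simplified stand-ins (Handle substitutes for crate::client::Handle, and DemoState is hypothetical), while the real wiring appears in the download_objects changes further down.

use std::sync::Arc;

#[derive(Debug)]
struct Handle; // stand-in for crate::client::Handle

#[derive(Debug)]
struct TransferContext<State> {
    handle: Arc<Handle>,
    state: Arc<State>,
}

// Hypothetical operation-specific state, mirroring how DownloadObjectsState
// is defined later in this PR.
#[derive(Debug)]
struct DemoState {
    objects_seen: u64,
}

type DemoContext = TransferContext<DemoState>;

fn main() {
    let ctx = DemoContext {
        handle: Arc::new(Handle),
        state: Arc::new(DemoState { objects_seen: 0 }),
    };
    println!("{ctx:?}");
}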
2 changes: 1 addition & 1 deletion aws-s3-transfer-manager/src/operation/download.rs
@@ -92,7 +92,7 @@ impl Download {
// have the correct metadata w.r.t. content-length and maybe others for the whole object.
object_meta: discovery.meta,
body: Body::new(comp_rx),
_tasks: tasks,
tasks,
};

Ok(handle)
10 changes: 9 additions & 1 deletion aws-s3-transfer-manager/src/operation/download/handle.rs
@@ -17,7 +17,7 @@ pub struct DownloadHandle {
pub body: Body,

/// All child tasks spawned for this download
pub(crate) _tasks: task::JoinSet<()>,
pub(crate) tasks: task::JoinSet<()>,
}

impl DownloadHandle {
@@ -30,4 +30,12 @@ impl DownloadHandle {
pub fn body(&self) -> &Body {
&self.body
}

/// Consume the handle and wait for download transfer to complete
pub async fn join(mut self) -> Result<(), crate::error::Error> {
while let Some(join_result) = self.tasks.join_next().await {
join_result?;
}
Ok(())
}
}
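
The added join() is the standard JoinSet draining pattern. A minimal standalone sketch of that pattern, using plain JoinError rather than the crate's error type:

use tokio::task::JoinSet;

#[tokio::main]
async fn main() -> Result<(), tokio::task::JoinError> {
    let mut tasks: JoinSet<()> = JoinSet::new();
    for i in 0..3 {
        tasks.spawn(async move {
            println!("task {i} finished");
        });
    }
    // Drain the set; `?` surfaces the first panicked or cancelled task as an
    // error, which is what DownloadHandle::join does via the crate error type.
    while let Some(join_result) = tasks.join_next().await {
        join_result?;
    }
    Ok(())
}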
1 change: 1 addition & 0 deletions aws-s3-transfer-manager/src/operation/download/worker.rs
@@ -111,6 +111,7 @@ pub(super) async fn distribute_work(
chunk_req
);
let chunk_size = chunk_req.size();
// FIXME - restructure like download_objects for tasks to yield errors that can be propagated on join()
tx.send(chunk_req).await.expect("channel open");

seq += 1;
76 changes: 72 additions & 4 deletions aws-s3-transfer-manager/src/operation/download_objects.rs
@@ -16,7 +16,18 @@ pub use output::{DownloadObjectsOutput, DownloadObjectsOutputBuilder};
mod handle;
pub use handle::DownloadObjectsHandle;

use std::sync::Arc;
mod list_objects;
mod worker;

use std::path::Path;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex};
use tokio::{fs, task::JoinSet};
use tracing::Instrument;

use crate::{error, types::FailedDownloadTransfer};

use super::TransferContext;

/// Operation struct for downloading multiple objects from Amazon S3
#[derive(Clone, Default, Debug)]
@@ -25,9 +36,66 @@ pub(crate) struct DownloadObjects;
impl DownloadObjects {
/// Execute a single `DownloadObjects` transfer operation
pub(crate) async fn orchestrate(
_handle: Arc<crate::client::Handle>,
_input: crate::operation::download_objects::DownloadObjectsInput,
handle: Arc<crate::client::Handle>,
input: crate::operation::download_objects::DownloadObjectsInput,
) -> Result<DownloadObjectsHandle, crate::error::Error> {
unimplemented!()
// validate that the destination exists and return an error if it's not a directory
let destination = input.destination().expect("destination set");
validate_destination(destination).await?;

let concurrency = handle.num_workers();
let ctx = DownloadObjectsContext::new(handle.clone(), input);

// spawn all work into the same JoinSet such that when the set is dropped all tasks are cancelled.
let mut tasks = JoinSet::new();
let (work_tx, work_rx) = async_channel::bounded(concurrency);

// spawn worker to discover/distribute work
tasks.spawn(worker::discover_objects(ctx.clone(), work_tx));

for i in 0..concurrency {
let worker = worker::download_objects(ctx.clone(), work_rx.clone())
.instrument(tracing::debug_span!("object-downloader", worker = i));
tasks.spawn(worker);
}

let handle = DownloadObjectsHandle { tasks, ctx };
Ok(handle)
}
}

async fn validate_destination(path: &Path) -> Result<(), error::Error> {
let file = fs::File::open(path).await?;
let meta = file.metadata().await?;

if !meta.is_dir() {
return Err(error::invalid_input(format!(
"destination is not a directory: {path:?}"
)));
}

Ok(())
}

/// DownloadObjects operation specific state
#[derive(Debug)]
pub(crate) struct DownloadObjectsState {
input: DownloadObjectsInput,
failed_downloads: Mutex<Option<Vec<FailedDownloadTransfer>>>,
successful_downloads: AtomicU64,
total_bytes_transferred: AtomicU64,
}

type DownloadObjectsContext = TransferContext<DownloadObjectsState>;

impl DownloadObjectsContext {
fn new(handle: Arc<crate::client::Handle>, input: DownloadObjectsInput) -> Self {
let state = Arc::new(DownloadObjectsState {
input,
failed_downloads: Mutex::new(None),
successful_downloads: AtomicU64::default(),
total_bytes_transferred: AtomicU64::default(),
});
TransferContext { handle, state }
}
}
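
The orchestration above amounts to a bounded fan-out: one discovery task feeds a bounded async_channel, a fixed number of workers drain it, and everything lives in a single JoinSet so dropping the handle cancels the whole transfer. A minimal standalone sketch of that shape (illustrative only, no S3 involved; assumes the tokio and async-channel crates):

use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let concurrency = 4;
    let (work_tx, work_rx) = async_channel::bounded::<u32>(concurrency);

    let mut tasks = JoinSet::new();

    // discovery task: produce work, then close the channel by dropping the sender
    tasks.spawn(async move {
        for job in 0..16u32 {
            work_tx.send(job).await.expect("channel open");
        }
    });

    // worker tasks: drain the channel until all senders are gone
    for worker in 0..concurrency {
        let rx = work_rx.clone();
        tasks.spawn(async move {
            while let Ok(job) = rx.recv().await {
                println!("worker {worker} handled job {job}");
            }
        });
    }

    while let Some(result) = tasks.join_next().await {
        result.expect("task panicked");
    }
}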
@@ -109,7 +109,10 @@ impl DownloadObjectsFluentBuilder {
}

/// Filter unwanted S3 objects from being downloaded as part of the transfer.
pub fn filter(mut self, input: impl Fn(&aws_sdk_s3::types::Object) -> bool + 'static) -> Self {
pub fn filter(
mut self,
input: impl Fn(&aws_sdk_s3::types::Object) -> bool + Send + Sync + 'static,
) -> Self {
self.inner = self.inner.filter(input);
self
}
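
For context, a hedged sketch of how a caller might pass a filter once this bound is in place. The helper name, key prefix, destination parameter type, and error type are assumptions for illustration; the builder chain mirrors the cp.rs example earlier in this PR.

// Hypothetical helper: only download objects whose keys end in ".json".
async fn download_json_objects(
    client: &aws_s3_transfer_manager::Client,
    bucket: &str,
    dest: &std::path::Path,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
    let handle = client
        .download_objects()
        .bucket(bucket)
        .key_prefix("logs/")
        .destination(dest)
        .filter(|obj: &aws_sdk_s3::types::Object| {
            // Object::key() is Option<&str>; skip objects without a key
            obj.key().map(|k| k.ends_with(".json")).unwrap_or(false)
        })
        .send()
        .await?;

    let output = handle.join().await?;
    println!("downloaded {} objects", output.objects_downloaded());
    Ok(())
}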
@@ -3,21 +3,44 @@
* SPDX-License-Identifier: Apache-2.0
*/

use std::sync::atomic::Ordering;

use tokio::task;

use super::DownloadObjectsOutput;
use super::{DownloadObjectsContext, DownloadObjectsOutput};

/// Handle for `DownloadObjects` transfer operation
#[derive(Debug)]
#[non_exhaustive]
pub struct DownloadObjectsHandle {
/// All child tasks spawned for this download
pub(crate) _tasks: task::JoinSet<()>,
pub(crate) tasks: task::JoinSet<Result<(), crate::error::Error>>,
/// The context used to drive the download to completion
pub(crate) ctx: DownloadObjectsContext,
}

impl DownloadObjectsHandle {
/// Consume the handle and wait for download transfer to complete
pub async fn join(self) -> Result<DownloadObjectsOutput, crate::error::Error> {
unimplemented!()
pub async fn join(mut self) -> Result<DownloadObjectsOutput, crate::error::Error> {
// join all tasks
while let Some(join_result) = self.tasks.join_next().await {
join_result??;
}

let failed_downloads = self.ctx.state.failed_downloads.lock().unwrap().take();
let successful_downloads = self.ctx.state.successful_downloads.load(Ordering::SeqCst);
let total_bytes_transferred = self
.ctx
.state
.total_bytes_transferred
.load(Ordering::SeqCst);
Review comment (Contributor) on lines +32 to +36: consider making an accessor for this.
let output = DownloadObjectsOutput::builder()
.objects_downloaded(successful_downloads)
.set_failed_transfers(failed_downloads)
.total_bytes_transferred(total_bytes_transferred)
.build();

Ok(output)
}
}
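
The output assembly above depends on state the worker tasks mutate concurrently. A standalone sketch of that aggregation pattern, with counters in AtomicU64 and failures collected lazily in a Mutex<Option<Vec<_>>> that join() takes once (types simplified, no S3 involved):

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use tokio::task::JoinSet;

#[derive(Debug, Default)]
struct State {
    successful: AtomicU64,
    total_bytes: AtomicU64,
    failed: Mutex<Option<Vec<String>>>,
}

#[tokio::main]
async fn main() {
    let state = Arc::new(State::default());
    let mut tasks = JoinSet::new();

    for i in 0..4u64 {
        let state = state.clone();
        tasks.spawn(async move {
            if i == 3 {
                // record a failure lazily, mirroring Mutex<Option<Vec<_>>> above
                let mut guard = state.failed.lock().unwrap();
                guard.get_or_insert_with(Vec::new).push(format!("object-{i}"));
            } else {
                state.successful.fetch_add(1, Ordering::SeqCst);
                state.total_bytes.fetch_add(1024, Ordering::SeqCst);
            }
        });
    }

    while let Some(res) = tasks.join_next().await {
        res.expect("task panicked");
    }

    // take() the failures once and load the counters to build the final summary
    let failed = state.failed.lock().unwrap().take();
    println!(
        "ok={} bytes={} failed={failed:?}",
        state.successful.load(Ordering::SeqCst),
        state.total_bytes.load(Ordering::SeqCst),
    );
}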