diff --git a/aws-s3-transfer-manager/src/error.rs b/aws-s3-transfer-manager/src/error.rs index 7931655..7b30347 100644 --- a/aws-s3-transfer-manager/src/error.rs +++ b/aws-s3-transfer-manager/src/error.rs @@ -160,3 +160,10 @@ where Error::new(kind, value) } } + +static CANCELLATION_ERROR: &str = + "at least one operation has been aborted, cancelling all ongoing requests"; + +pub(crate) fn operation_cancelled() -> Error { + Error::new(ErrorKind::OperationCancelled, CANCELLATION_ERROR) +} diff --git a/aws-s3-transfer-manager/src/operation.rs b/aws-s3-transfer-manager/src/operation.rs index ecf0a9b..8a8a656 100644 --- a/aws-s3-transfer-manager/src/operation.rs +++ b/aws-s3-transfer-manager/src/operation.rs @@ -22,6 +22,10 @@ pub mod upload_objects; // The default delimiter of the S3 object key pub(crate) const DEFAULT_DELIMITER: &str = "/"; +// Type aliases for the channel ends used to send/receive a cancellation notification +pub(crate) type CancelNotificationSender = tokio::sync::watch::Sender<bool>; +pub(crate) type CancelNotificationReceiver = tokio::sync::watch::Receiver<bool>; + /// Container for maintaining context required to carry out a single operation/transfer. /// /// `State` is whatever additional operation specific state is required for the operation. diff --git a/aws-s3-transfer-manager/src/operation/download/discovery.rs b/aws-s3-transfer-manager/src/operation/download/discovery.rs index ed72136..893bfba 100644 --- a/aws-s3-transfer-manager/src/operation/download/discovery.rs +++ b/aws-s3-transfer-manager/src/operation/download/discovery.rs @@ -196,9 +196,10 @@ mod tests { use aws_sdk_s3::operation::get_object::GetObjectOutput; use aws_sdk_s3::operation::head_object::HeadObjectOutput; use aws_sdk_s3::Client; - use aws_smithy_mocks_experimental::{mock, mock_client}; + use aws_smithy_mocks_experimental::mock; use aws_smithy_types::byte_stream::ByteStream; use bytes::Buf; + use test_common::mock_client_with_stubbed_http_client; use super::ObjectDiscovery; @@ -247,7 +248,7 @@ mod tests { async fn get_discovery_from_head(range: Option<ByteRange>) -> ObjectDiscovery { let head_obj_rule = mock!(Client::head_object) .then_output(|| HeadObjectOutput::builder().content_length(500).build()); - let client = mock_client!(aws_sdk_s3, &[&head_obj_rule]); + let client = mock_client_with_stubbed_http_client!(aws_sdk_s3, &[&head_obj_rule]); let ctx = DownloadContext::new(test_handle(client, 5 * ByteUnit::Mebibyte.as_bytes_u64())); @@ -296,7 +297,7 @@ mod tests { .body(ByteStream::from_static(bytes)) .build() }); - let client = mock_client!(aws_sdk_s3, &[&get_obj_rule]); + let client = mock_client_with_stubbed_http_client!(aws_sdk_s3, &[&get_obj_rule]); let ctx = DownloadContext::new(test_handle(client, target_part_size)); @@ -332,7 +333,7 @@ mod tests { .body(ByteStream::from_static(bytes)) .build() }); - let client = mock_client!(aws_sdk_s3, &[&get_obj_rule]); + let client = mock_client_with_stubbed_http_client!(aws_sdk_s3, &[&get_obj_rule]); let ctx = DownloadContext::new(test_handle(client, target_part_size)); @@ -367,7 +368,7 @@ mod tests { .body(ByteStream::from_static(bytes)) .build() }); - let client = mock_client!(aws_sdk_s3, &[&get_obj_rule]); + let client = mock_client_with_stubbed_http_client!(aws_sdk_s3, &[&get_obj_rule]); let ctx = DownloadContext::new(test_handle(client, target_part_size)); diff --git a/aws-s3-transfer-manager/src/operation/download_objects.rs b/aws-s3-transfer-manager/src/operation/download_objects.rs index 1c0e332..c7439a7 100644 --- a/aws-s3-transfer-manager/src/operation/download_objects.rs +++
b/aws-s3-transfer-manager/src/operation/download_objects.rs @@ -15,7 +15,7 @@ pub use output::{DownloadObjectsOutput, DownloadObjectsOutputBuilder}; mod handle; pub use handle::DownloadObjectsHandle; -use tokio::fs; +use tokio::{fs, sync::watch}; mod list_objects; mod worker; @@ -27,7 +27,9 @@ use tracing::Instrument; use crate::types::FailedDownload; -use super::{validate_target_is_dir, TransferContext}; +use super::{ + validate_target_is_dir, CancelNotificationReceiver, CancelNotificationSender, TransferContext, +}; /// Operation struct for downloading multiple objects from Amazon S3 #[derive(Clone, Default, Debug)] @@ -83,6 +85,8 @@ pub(crate) struct DownloadObjectsState { // TODO - Determine if `input` should be separated from this struct // https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/67#discussion_r1821661603 input: DownloadObjectsInput, + cancel_tx: CancelNotificationSender, + cancel_rx: CancelNotificationReceiver, failed_downloads: Mutex<Vec<FailedDownload>>, successful_downloads: AtomicU64, total_bytes_transferred: AtomicU64, @@ -92,8 +96,11 @@ type DownloadObjectsContext = TransferContext<DownloadObjectsState>; impl DownloadObjectsContext { fn new(handle: Arc<crate::client::Handle>, input: DownloadObjectsInput) -> Self { + let (cancel_tx, cancel_rx) = watch::channel(false); let state = Arc::new(DownloadObjectsState { input, + cancel_tx, + cancel_rx, failed_downloads: Mutex::new(Vec::new()), successful_downloads: AtomicU64::default(), total_bytes_transferred: AtomicU64::default(), diff --git a/aws-s3-transfer-manager/src/operation/download_objects/handle.rs b/aws-s3-transfer-manager/src/operation/download_objects/handle.rs index 966f5c3..b7783ec 100644 --- a/aws-s3-transfer-manager/src/operation/download_objects/handle.rs +++ b/aws-s3-transfer-manager/src/operation/download_objects/handle.rs @@ -5,6 +5,8 @@ use tokio::task; +use crate::{error::ErrorKind, types::FailedTransferPolicy}; + use super::{DownloadObjectsContext, DownloadObjectsOutput}; /// Handle for `DownloadObjects` transfer operation @@ -19,14 +21,57 @@ pub struct DownloadObjectsHandle { impl DownloadObjectsHandle { /// Consume the handle and wait for download transfer to complete + /// + /// When the `FailedTransferPolicy` is set to [`FailedTransferPolicy::Abort`], this method + /// will return the first error if any of the spawned tasks encounter one. The other tasks + /// will be canceled, but their cancellations will not be reported as errors by this method; + /// they will be logged as errors, instead. + /// + /// If the `FailedTransferPolicy` is set to [`FailedTransferPolicy::Continue`], the + /// [`DownloadObjectsOutput`] will include a detailed breakdown, including the number of + /// successful downloads and the number of failed ones. + /// + // TODO(aws-sdk-rust#1159) - Consider if we want to return all other errors encountered during cancellation.
#[tracing::instrument(skip_all, level = "debug", name = "join-download-objects")] pub async fn join(mut self) -> Result<DownloadObjectsOutput, crate::error::Error> { - // TODO - Consider implementing more sophisticated error handling such as canceling in-progress transfers + let mut first_error_to_report = None; // join all tasks while let Some(join_result) = self.tasks.join_next().await { - join_result??; + let result = join_result.expect("task completed"); + if let Err(e) = result { + match self.ctx.state.input.failure_policy() { + FailedTransferPolicy::Abort + if first_error_to_report.is_none() + && e.kind() != &ErrorKind::OperationCancelled => + { + first_error_to_report = Some(e); + } + FailedTransferPolicy::Continue => { + tracing::warn!("encountered but dismissed error when the failure policy is `Continue`: {e}") + } + _ => {} + } + } + } + + if let Some(e) = first_error_to_report { + Err(e) + } else { + Ok(DownloadObjectsOutput::from(self.ctx.state.as_ref())) + } + } + + /// Aborts all tasks owned by the handle. + pub async fn abort(&mut self) -> Result<(), crate::error::Error> { + if self.ctx.state.input.failure_policy() == &FailedTransferPolicy::Abort { + if self.ctx.state.cancel_tx.send(true).is_err() { + tracing::debug!( + "all receiver ends have been dropped, unable to send a cancellation signal" + ); + } + while (self.tasks.join_next().await).is_some() {} } - Ok(DownloadObjectsOutput::from(self.ctx.state.as_ref())) + Ok(()) } } diff --git a/aws-s3-transfer-manager/src/operation/download_objects/list_objects.rs b/aws-s3-transfer-manager/src/operation/download_objects/list_objects.rs index ca5a5c7..b43792a 100644 --- a/aws-s3-transfer-manager/src/operation/download_objects/list_objects.rs +++ b/aws-s3-transfer-manager/src/operation/download_objects/list_objects.rs @@ -205,7 +205,8 @@ mod tests { operation::list_objects_v2::ListObjectsV2Output, types::{CommonPrefix, Object}, }; - use aws_smithy_mocks_experimental::{mock, mock_client}; + use aws_smithy_mocks_experimental::mock; + use test_common::mock_client_with_stubbed_http_client; use crate::operation::download_objects::{DownloadObjectsContext, DownloadObjectsInput}; @@ -332,7 +333,10 @@ mod tests { .then_output(|| list_resp(None, "pre1", None, vec!["pre1/k7", "pre1/k8"])); let resp5 = mock!(aws_sdk_s3::Client::list_objects_v2) .then_output(|| list_resp(None, "pre2", None, vec!["pre2/k9", "pre2/k10"])); - let client = mock_client!(aws_sdk_s3, &[&resp1, &resp2, &resp3, &resp4, &resp5]); + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + &[&resp1, &resp2, &resp3, &resp4, &resp5] + ); let config = crate::Config::builder().client(client).build(); let client = crate::Client::new(config); diff --git a/aws-s3-transfer-manager/src/operation/download_objects/worker.rs b/aws-s3-transfer-manager/src/operation/download_objects/worker.rs index ec3882d..35b938e 100644 --- a/aws-s3-transfer-manager/src/operation/download_objects/worker.rs +++ b/aws-s3-transfer-manager/src/operation/download_objects/worker.rs @@ -11,7 +11,7 @@ use std::sync::atomic::Ordering; use tokio::fs; use tokio::io::AsyncWriteExt; -use crate::error; +use crate::error::{self, ErrorKind}; use crate::operation::download::body::Body; use crate::operation::download::{DownloadInput, DownloadInputBuilder}; use crate::operation::DEFAULT_DELIMITER; @@ -49,18 +49,37 @@ pub(super) async fn discover_objects( let default_filter = &DownloadFilter::default(); let filter = ctx.state.input.filter().unwrap_or(default_filter); - while let Some(obj_result) = stream.next().await { - let object = obj_result?; - if
!(filter.predicate)(&object) { - // TODO(SEP) - The S3 Transfer Manager MAY add validation to handle the case for the objects whose - // keys differ only by case in case-insensitive filesystems such as Windows. For example, throw - // validation exception if a user attempts to download a bucket that contains "foobar" and "FOOBAR" in Windows. - tracing::debug!("skipping object due to filter: {:?}", object); - continue; - } + let mut cancel_rx = ctx.state.cancel_rx.clone(); - let job = DownloadObjectJob { object }; - work_tx.send(job).await.expect("channel valid"); + loop { + tokio::select! { + _ = cancel_rx.changed() => { + tracing::error!("received cancellation signal, exiting and not listing new objects"); + return Err(error::operation_cancelled()); + } + obj_result = stream.next() => { + match obj_result { + None => break, + Some(obj_result) => { + let object = obj_result?; + + if !(filter.predicate)(&object) { + // TODO(SEP) - The S3 Transfer Manager MAY add validation to handle the case for the objects whose + // keys differ only by case in case-insensitive filesystems such as Windows. For example, throw + // validation exception if a user attempts to download a bucket that contains "foobar" and "FOOBAR" in Windows. + tracing::debug!("skipping object due to filter: {:?}", object); + continue; + } + + let job = DownloadObjectJob { object }; + if work_tx.send(job).await.is_err() { + tracing::error!("all receiver ends have been dropped, unable to send a job!"); + break; + } + } + } + } + } } Ok(()) @@ -71,54 +90,74 @@ pub(super) async fn download_objects( ctx: DownloadObjectsContext, work_rx: Receiver<DownloadObjectJob>, ) -> Result<(), error::Error> { - while let Ok(job) = work_rx.recv().await { - tracing::debug!( - "worker recv'd request for key {:?} ({:?} bytes)", - job.object.key, - job.object.size() - ); - - let dl_result = download_single_obj(&ctx, &job).await; - match dl_result { - Ok(_) => { - ctx.state - .successful_downloads - .fetch_add(1, Ordering::SeqCst); - - let bytes_transferred: u64 = job - .object - .size() - .unwrap_or_default() - .try_into() - .unwrap_or_default(); - - ctx.state - .total_bytes_transferred - .fetch_add(bytes_transferred, Ordering::SeqCst); - - tracing::debug!("worker finished downloading key {:?}", job.object.key); + let mut cancel_rx = ctx.state.cancel_rx.clone(); + loop { + tokio::select! { + _ = cancel_rx.changed() => { + tracing::error!("received cancellation signal, exiting and not downloading a new object"); + return Err(error::operation_cancelled()); } - Err(err) => { - tracing::debug!( - "worker failed to download key {:?}: {}", - job.object.key, - err - ); - match ctx.state.input.failure_policy() { - // TODO - this will abort this worker, the rest of the workers will be aborted - // when the handle is joined and the error is propagated and the task set is - // dropped. This _may_ be later/too passive and we might consider aborting all - // the tasks on error rather than relying on join and then drop.
- FailedTransferPolicy::Abort => return Err(err), - FailedTransferPolicy::Continue => { - let mut failures = ctx.state.failed_downloads.lock().unwrap(); - - let failed_transfer = FailedDownload { - input: job.input(&ctx), - error: err, - }; - - failures.push(failed_transfer); + job = work_rx.recv() => { + match job { + Err(_) => break, + Ok(job) => { + tracing::debug!( + "worker recv'd request for key {:?} ({:?} bytes)", + job.object.key, + job.object.size() + ); + + let dl_result = download_single_obj(&ctx, &job).await; + match dl_result { + Ok(_) => { + ctx.state + .successful_downloads + .fetch_add(1, Ordering::SeqCst); + + let bytes_transferred: u64 = job + .object + .size() + .unwrap_or_default() + .try_into() + .unwrap_or_default(); + + ctx.state + .total_bytes_transferred + .fetch_add(bytes_transferred, Ordering::SeqCst); + + tracing::debug!("worker finished downloading key {:?}", job.object.key); + } + Err(err) => { + tracing::debug!( + "worker failed to download key {:?}: {}", + job.object.key, + err + ); + match ctx.state.input.failure_policy() { + FailedTransferPolicy::Abort => { + // Sending a cancellation signal during graceful shutdown would be redundant. + if err.kind() != &ErrorKind::OperationCancelled + && ctx.state.cancel_tx.send(true).is_err() + { + tracing::warn!( + "all receiver ends have been dropped, unable to send a cancellation signal" + ); + } + return Err(err); + } + FailedTransferPolicy::Continue => { + let mut failures = ctx.state.failed_downloads.lock().unwrap(); + + let failed_transfer = FailedDownload { + input: job.input(&ctx), + error: err, + }; + + failures.push(failed_transfer); + } + } + } + } } } } @@ -142,6 +181,25 @@ async fn download_single_obj( let key_path = local_key_path(root_dir, key.as_str(), prefix, delim)?; let mut handle = crate::operation::download::Download::orchestrate(ctx.handle.clone(), input, true)?; + + // The cancellation process would work fine without this if statement. + // It's here so we can save a single download operation that would otherwise + // be wasted if the system is already in graceful shutdown mode. + if ctx + .state + .cancel_rx + .has_changed() + .expect("the channel should be open as it is owned by `DownloadObjectsState`") + { + /* + * TODO(single download cleanup): Uncomment the following lines of code once single download has been cleaned up. + * Note that it may not be called `.abort()` depending on the outcome of the cleanup.
+ * + * handle.abort().await; + * return Err(error::operation_cancelled()); + */ + } + let _ = handle.object_meta().await?; let mut body = mem::replace(&mut handle.body, Body::empty()); @@ -233,14 +291,14 @@ fn validate_path(root_dir: &Path, local_path: &Path, key: &str) -> Result<(), er #[cfg(test)] mod tests { - use aws_sdk_s3::operation::list_objects_v2::ListObjectsV2Output; - use aws_smithy_mocks_experimental::{mock, mock_client}; + use super::*; + + use crate::operation::download_objects::{DownloadObjectsContext, DownloadObjectsInput}; - use crate::operation::download_objects::{ - worker::discover_objects, DownloadObjectsContext, DownloadObjectsInput, - }; + use aws_sdk_s3::operation::list_objects_v2::ListObjectsV2Output; + use aws_smithy_mocks_experimental::mock; + use test_common::mock_client_with_stubbed_http_client; - use super::{local_key_path, strip_key_prefix}; use std::path::PathBuf; struct ObjectKeyPathTest { @@ -448,7 +506,7 @@ mod tests { .build() }); - let s3_client = mock_client!(aws_sdk_s3, &[&list_objects_rule]); + let s3_client = mock_client_with_stubbed_http_client!(aws_sdk_s3, &[&list_objects_rule]); let config = crate::Config::builder().client(s3_client).build(); let client = crate::Client::new(config); let input = DownloadObjectsInput::builder() @@ -505,7 +563,7 @@ mod tests { .build() }); - let s3_client = mock_client!(aws_sdk_s3, &[&list_objects_rule]); + let s3_client = mock_client_with_stubbed_http_client!(aws_sdk_s3, &[&list_objects_rule]); let config = crate::Config::builder().client(s3_client).build(); let client = crate::Client::new(config); let input = DownloadObjectsInput::builder() diff --git a/aws-s3-transfer-manager/src/operation/upload.rs b/aws-s3-transfer-manager/src/operation/upload.rs index b57f778..ccc08a1 100644 --- a/aws-s3-transfer-manager/src/operation/upload.rs +++ b/aws-s3-transfer-manager/src/operation/upload.rs @@ -227,10 +227,11 @@ mod test { use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadOutput; use aws_sdk_s3::operation::put_object::PutObjectOutput; use aws_sdk_s3::operation::upload_part::UploadPartOutput; - use aws_smithy_mocks_experimental::{mock, mock_client, RuleMode}; + use aws_smithy_mocks_experimental::{mock, RuleMode}; use bytes::Bytes; use std::ops::Deref; use std::sync::Arc; + use test_common::mock_client_with_stubbed_http_client; #[tokio::test] async fn test_basic_mpu() { @@ -274,7 +275,7 @@ mod test { .build() }); - let client = mock_client!( + let client = mock_client_with_stubbed_http_client!( aws_sdk_s3, RuleMode::Sequential, &[&create_mpu, &upload_1, &upload_2, &complete_mpu] @@ -314,7 +315,8 @@ mod test { .build() }); - let client = mock_client!(aws_sdk_s3, RuleMode::Sequential, &[&put_object]); + let client = + mock_client_with_stubbed_http_client!(aws_sdk_s3, RuleMode::Sequential, &[&put_object]); let tm_config = crate::Config::builder() .concurrency(ConcurrencySetting::Explicit(1)) diff --git a/aws-s3-transfer-manager/src/operation/upload_objects.rs b/aws-s3-transfer-manager/src/operation/upload_objects.rs index 60fb3ff..13ae61c 100644 --- a/aws-s3-transfer-manager/src/operation/upload_objects.rs +++ b/aws-s3-transfer-manager/src/operation/upload_objects.rs @@ -16,17 +16,16 @@ pub use handle::UploadObjectsHandle; mod output; pub use output::{UploadObjectsOutput, UploadObjectsOutputBuilder}; -use tokio::{ - sync::watch::{self, Receiver, Sender}, - task::JoinSet, -}; +use tokio::{sync::watch, task::JoinSet}; use tracing::Instrument; mod worker; use crate::{error, types::FailedUpload}; -use 
super::{validate_target_is_dir, TransferContext}; +use super::{ + validate_target_is_dir, CancelNotificationReceiver, CancelNotificationSender, TransferContext, +}; /// Operation struct for uploading multiple objects to Amazon S3 #[derive(Clone, Default, Debug)] @@ -85,8 +84,8 @@ pub(crate) struct UploadObjectsState { // TODO - Determine if `input` should be separated from this struct // https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/67#discussion_r1821661603 input: UploadObjectsInput, - cancel_tx: Sender<bool>, - cancel_rx: Receiver<bool>, + cancel_tx: CancelNotificationSender, + cancel_rx: CancelNotificationReceiver, failed_uploads: Mutex<Vec<FailedUpload>>, successful_uploads: AtomicU64, total_bytes_transferred: AtomicU64, @@ -95,8 +94,8 @@ pub(crate) struct UploadObjectsState { impl UploadObjectsState { pub(crate) fn new( input: UploadObjectsInput, - cancel_tx: Sender<bool>, - cancel_rx: Receiver<bool>, + cancel_tx: CancelNotificationSender, + cancel_rx: CancelNotificationReceiver, ) -> Self { Self { input, diff --git a/aws-s3-transfer-manager/src/operation/upload_objects/handle.rs b/aws-s3-transfer-manager/src/operation/upload_objects/handle.rs index 3c563cd..319f9f5 100644 --- a/aws-s3-transfer-manager/src/operation/upload_objects/handle.rs +++ b/aws-s3-transfer-manager/src/operation/upload_objects/handle.rs @@ -49,7 +49,7 @@ impl UploadObjectsHandle { /// they will be logged as errors, instead. /// /// If the `FailedTransferPolicy` is set to [`FailedTransferPolicy::Continue`], the - /// [`UploadObjectsOutput`] will include a detailed breakdown, such as the number of + /// [`UploadObjectsOutput`] will include a detailed breakdown, including the number of /// successful uploads and the number of failed ones. /// // TODO(aws-sdk-rust#1159) - Consider if we want to return failed `AbortMultipartUpload` during cancellation. @@ -90,7 +90,7 @@ impl UploadObjectsHandle { pub async fn abort(&mut self) -> Result<(), crate::error::Error> { if self.ctx.state.input.failure_policy() == &FailedTransferPolicy::Abort { if self.ctx.state.cancel_tx.send(true).is_err() { - tracing::warn!( + tracing::debug!( "all receiver ends have been dropped, unable to send a cancellation signal" ); } diff --git a/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs b/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs index 2c8d083..977da82 100644 --- a/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs +++ b/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs @@ -22,9 +22,6 @@ use crate::operation::DEFAULT_DELIMITER; use crate::types::{FailedTransferPolicy, FailedUpload, UploadFilter}; use crate::{error, types::UploadFilterItem}; -const CANCELLATION_ERROR: &str = - "at least one operation has been aborted, cancelling all ongoing requests"; - #[derive(Debug)] pub(super) struct UploadObjectJob { key: String, @@ -61,7 +58,7 @@ pub(super) async fn list_directory_contents( tokio::select! { _ = cancel_rx.changed() => { tracing::error!("received cancellation signal, exiting and not yielding new directory contents"); - return Err(crate::error::Error::new(ErrorKind::OperationCancelled, CANCELLATION_ERROR.to_owned())); + return Err(error::operation_cancelled()); } entry = walker.next() => { match entry { @@ -194,7 +191,7 @@ pub(super) async fn upload_objects( tokio::select!
{ _ = cancel_rx.changed() => { tracing::error!("received cancellation signal, exiting and ignoring any future work"); - return Err(crate::error::Error::new(ErrorKind::OperationCancelled, CANCELLATION_ERROR.to_owned())); + return Err(error::operation_cancelled()); } job = list_directory_rx.recv() => { match job { @@ -274,10 +271,7 @@ async fn upload_single_obj( DisplayErrorContext(&e) ); } - Err(crate::error::Error::new( - ErrorKind::OperationCancelled, - CANCELLATION_ERROR.to_owned(), - )) + Err(error::operation_cancelled()) } else { handle.join().await?; Ok(bytes_transferred) @@ -333,8 +327,9 @@ mod tests { abort_multipart_upload::AbortMultipartUploadOutput, create_multipart_upload::CreateMultipartUploadOutput, put_object::PutObjectOutput, }; - use aws_smithy_mocks_experimental::{mock, mock_client, RuleMode}; + use aws_smithy_mocks_experimental::{mock, RuleMode}; use bytes::Bytes; + use test_common::mock_client_with_stubbed_http_client; use crate::{ client::Handle, @@ -716,7 +711,8 @@ mod tests { .match_requests(move |input| input.bucket() == Some(bucket)) .then_output(|| PutObjectOutput::builder().build()); - let s3_client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); + let s3_client = + mock_client_with_stubbed_http_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); let config = crate::Config::builder().client(s3_client).build(); let scheduler = Scheduler::new(DEFAULT_CONCURRENCY); @@ -764,7 +760,11 @@ mod tests { } }) .then_output(|| AbortMultipartUploadOutput::builder().build()); - let s3_client = mock_client!(aws_sdk_s3, RuleMode::Sequential, &[create_mpu, abort_mpu]); + let s3_client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::Sequential, + &[create_mpu, abort_mpu] + ); let config = crate::Config::builder() .set_multipart_threshold(PartSize::Target(MIN_MULTIPART_PART_SIZE_BYTES)) .client(s3_client) diff --git a/aws-s3-transfer-manager/test-common/Cargo.toml b/aws-s3-transfer-manager/test-common/Cargo.toml index 0510045..7b93b32 100644 --- a/aws-s3-transfer-manager/test-common/Cargo.toml +++ b/aws-s3-transfer-manager/test-common/Cargo.toml @@ -6,4 +6,7 @@ license = "Apache-2.0" publish = false [dependencies] +aws-sdk-s3 = { version = "1.51.0", features = ["behavior-version-latest", "test-util"] } +aws-smithy-mocks-experimental = "0.2.1" +http-02x = { package = "http", version = "0.2.9" } tempfile = "3.12.0" \ No newline at end of file diff --git a/aws-s3-transfer-manager/test-common/src/lib.rs b/aws-s3-transfer-manager/test-common/src/lib.rs index 4c79d7b..61f66cc 100644 --- a/aws-s3-transfer-manager/test-common/src/lib.rs +++ b/aws-s3-transfer-manager/test-common/src/lib.rs @@ -3,25 +3,21 @@ * SPDX-License-Identifier: Apache-2.0 */ -#![cfg(target_family = "unix")] - -use std::{fs, io::Write}; -use tempfile::{tempdir, TempDir}; - /// Create a directory structure rooted at `recursion_root`, containing files with sizes /// specified in `files` /// /// For testing purposes, certain directories (and all files within them) can be made /// inaccessible by providing `inaccessible_dir_relative_paths`, which should be relative /// to `recursion_root`. 
+#[cfg(target_family = "unix")] pub fn create_test_dir( recursion_root: Option<&str>, files: Vec<(&str, usize)>, inaccessible_dir_relative_paths: &[&str], -) -> TempDir { +) -> tempfile::TempDir { let temp_dir = match recursion_root { - Some(root) => TempDir::with_prefix(root).unwrap(), - None => tempdir().unwrap(), + Some(root) => tempfile::TempDir::with_prefix(root).unwrap(), + None => tempfile::tempdir().unwrap(), }; // Create the directory structure and files @@ -30,21 +26,51 @@ pub fn create_test_dir( let parent = full_path.parent().unwrap(); // Create the parent directories if they don't exist - fs::create_dir_all(parent).unwrap(); + std::fs::create_dir_all(parent).unwrap(); // Create the file with the specified size - let mut file = fs::File::create(&full_path).unwrap(); - file.write_all(&vec![0; size]).unwrap(); // Writing `size` byte + let mut file = std::fs::File::create(&full_path).unwrap(); + std::io::Write::write_all(&mut file, &vec![0; size]).unwrap(); // Writing `size` bytes } // Set the directories in `inaccessible_dir_relative_paths` to be inaccessible, // which will in turn render the files within those directories inaccessible for dir_relative_path in inaccessible_dir_relative_paths { let dir_path = temp_dir.path().join(*dir_relative_path); - let mut permissions = fs::metadata(&dir_path).unwrap().permissions(); + let mut permissions = std::fs::metadata(&dir_path).unwrap().permissions(); std::os::unix::fs::PermissionsExt::set_mode(&mut permissions, 0o000); // No permissions for anyone - fs::set_permissions(dir_path, permissions).unwrap(); + std::fs::set_permissions(dir_path, permissions).unwrap(); } temp_dir } + +/// A macro to generate a mock S3 client with the underlying HTTP client stubbed out +/// +/// This macro wraps [`mock_client`](aws_smithy_mocks_experimental::mock_client) to work around the issue +/// where the inner macro, when used alone, does not stub the HTTP client, causing real HTTP requests to be sent. +// TODO(https://github.com/smithy-lang/smithy-rs/issues/3926): Once resolved, remove this macro and have the callers use the upstream version instead. +#[macro_export] +macro_rules!
mock_client_with_stubbed_http_client { + ($aws_crate: ident, $rules: expr) => { + mock_client_with_stubbed_http_client!( + $aws_crate, + aws_smithy_mocks_experimental::RuleMode::Sequential, + $rules + ) + }; + ($aws_crate: ident, $rule_mode: expr, $rules: expr) => {{ + let client = aws_smithy_mocks_experimental::mock_client!($aws_crate, $rule_mode, $rules); + $aws_crate::client::Client::from_conf( + client + .config() + .to_builder() + .http_client( + aws_smithy_runtime::client::http::test_util::infallible_client_fn(|_req| { + http_02x::Response::builder().status(200).body("").unwrap() + }), + ) + .build(), + ) + }}; +} diff --git a/aws-s3-transfer-manager/tests/download_objects_test.rs b/aws-s3-transfer-manager/tests/download_objects_test.rs index ba0508f..3e5e80c 100644 --- a/aws-s3-transfer-manager/tests/download_objects_test.rs +++ b/aws-s3-transfer-manager/tests/download_objects_test.rs @@ -4,16 +4,25 @@ */ #![cfg(target_family = "unix")] -use aws_s3_transfer_manager::types::FailedTransferPolicy; +use aws_s3_transfer_manager::{error::ErrorKind, types::FailedTransferPolicy}; use aws_sdk_s3::{ - error::DisplayErrorContext, - operation::{get_object::GetObjectOutput, list_objects_v2::ListObjectsV2Output}, + error::{DisplayErrorContext, SdkError}, + operation::{ + get_object::GetObjectOutput, + list_objects_v2::{ListObjectsV2Error, ListObjectsV2Output}, + }, primitives::ByteStream, }; -use aws_smithy_mocks_experimental::{mock, mock_client, Rule, RuleMode}; -use aws_smithy_runtime_api::{client::orchestrator::HttpResponse, http::StatusCode}; +use aws_smithy_mocks_experimental::{mock, Rule, RuleMode}; +use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs; +use aws_smithy_runtime_api::{ + client::orchestrator::HttpResponse, + http::{Response, StatusCode}, +}; use bytes::Bytes; -use std::{io, iter, path::Path, sync::Arc}; +use std::{error::Error as _, io, iter, path::Path, sync::Arc}; +use test_common::mock_client_with_stubbed_http_client; +use tokio::sync::watch; use walkdir::WalkDir; #[derive(Debug, Clone)] @@ -63,23 +72,28 @@ impl MockObject { } } -fn get_object_error_http_resp() -> HttpResponse { +fn error_http_resp() -> HttpResponse { HttpResponse::new(StatusCode::try_from(500).unwrap(), Bytes::new().into()) } /// Get the mock rule for this object when `get_object` API is invoked for the corresponding key fn get_object_rule(mobj: &MockObject) -> Rule { - let share1 = Arc::new(mobj.clone()); - let share2 = share1.clone(); + let mock_obj = Arc::new(mobj.clone()); if mobj.error_on_get { mock!(aws_sdk_s3::Client::get_object) - .match_requests(move |r| r.key() == share1.object.key()) - .then_http_response(get_object_error_http_resp) + .match_requests({ + let mock_obj = mock_obj.clone(); + move |r| r.key() == mock_obj.object.key() + }) + .then_http_response(error_http_resp) } else { mock!(aws_sdk_s3::Client::get_object) - .match_requests(move |r| r.key() == share1.object.key()) - .then_output(move || share2.get_object_output()) + .match_requests({ + let mock_obj = mock_obj.clone(); + move |r| r.key() == mock_obj.object.key() + }) + .then_output(move || mock_obj.get_object_output()) } } @@ -98,21 +112,27 @@ impl MockBucket { MockBucketBuilder::default() } - /// Return the mock rules representing this bucket. This includes - /// the `ListObjectsV2` call as well as all of the `GetObject` calls. - fn rules(&self) -> Vec<Rule> { + /// Configure the mock behavior for listing `objects` stored in this `MockBucket`.
+ fn list_objects_rule(&self) -> Rule { let contents = self.objects.iter().map(|m| m.object.clone()).collect(); let list_output = ListObjectsV2Output::builder() .set_contents(Some(contents)) .build(); - let list_rule = - mock!(aws_sdk_s3::Client::list_objects_v2).then_output(move || list_output.clone()); + mock!(aws_sdk_s3::Client::list_objects_v2).then_output(move || list_output.clone()) + } - let mut rules: Vec<Rule> = self.objects.iter().map(get_object_rule).collect(); + /// Configure the mock behavior of `GetObject` for `objects` stored in this `MockBucket`. + fn get_object_rules(&self) -> Vec<Rule> { + self.objects.iter().map(get_object_rule).collect() + } - rules.push(list_rule); + /// Return the mock rules representing this bucket. This includes + /// the `ListObjectsV2` call as well as all of the `GetObject` calls. + fn rules(&self) -> Vec<Rule> { + let mut rules = self.get_object_rules(); + rules.push(self.list_objects_rule()); rules } } @@ -173,7 +193,11 @@ async fn test_strip_prefix_in_destination_path() { .key_with_size("abc/def/ghi/xyz.txt", 5) .build(); - let client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, bucket.rules().as_slice()); + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + bucket.rules().as_slice() + ); let config = aws_s3_transfer_manager::Config::builder() .client(client) .build(); @@ -212,7 +236,11 @@ async fn test_object_with_prefix_included() { .key_with_size("abcd", 5) .build(); - let client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, bucket.rules().as_slice()); + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + bucket.rules().as_slice() + ); let config = aws_s3_transfer_manager::Config::builder() .client(client) .build(); @@ -251,7 +279,11 @@ async fn test_failed_download_policy_continue() { .key_with_error("key3") .build(); - let client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, bucket.rules().as_slice()); + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + bucket.rules().as_slice() + ); let config = aws_s3_transfer_manager::Config::builder() .client(client) .build(); @@ -309,7 +341,11 @@ async fn test_recursively_downloads() { builder.build() }; - let client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, bucket.rules().as_slice()); + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + bucket.rules().as_slice() + ); let config = aws_s3_transfer_manager::Config::builder() .client(client) .build(); @@ -345,7 +381,11 @@ async fn test_delimiter() { .key_with_size("2023|1|1.png", 5) .build(); - let client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, bucket.rules().as_slice()); + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + bucket.rules().as_slice() + ); let config = aws_s3_transfer_manager::Config::builder() .client(client) .build(); @@ -383,7 +423,11 @@ async fn test_delimiter() { async fn test_destination_dir_not_valid() { let bucket = MockBucket::builder().key_with_size("image.png", 12).build(); - let client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, bucket.rules().as_slice()); + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + bucket.rules().as_slice() + ); let config = aws_s3_transfer_manager::Config::builder() .client(client) .build(); @@ -403,3 +447,177 @@ async fn test_destination_dir_not_valid() { let err_str = format!("{}", DisplayErrorContext(err)); assert!(err_str.contains("target is not a directory")); } + +#[tokio::test] +async fn
test_abort_on_handle_should_terminate_tasks_gracefully() { + let (_guard, rx) = capture_test_logs(); + + let bucket = MockBucket::builder() + .key_with_size("key1", 12) + .key_with_error("key2") + .key_with_size("key3", 7) + .build(); + + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + bucket.rules().as_slice() + ); + + let config = aws_s3_transfer_manager::Config::builder() + .client(client) + .build(); + let tm = aws_s3_transfer_manager::Client::new(config); + + let dest = tempfile::tempdir().unwrap(); + + let mut handle = tm + .download_objects() + .bucket("test-bucket") + .destination(dest.path()) + .send() + .await + .unwrap(); + + handle.abort().await.unwrap(); + + assert!(rx.contents().contains("received cancellation signal")); +} + +#[tokio::test] +async fn test_failed_list_objects_should_cancel_the_operation() { + let (_guard, rx) = capture_test_logs(); + + let bucket = MockBucket::builder() + .key_with_size("key1", 12) + .key_with_error("key2") + .key_with_size("key3", 7) + .build(); + + let mut rules = bucket.get_object_rules(); + rules.push(mock!(aws_sdk_s3::Client::list_objects_v2).then_http_response(error_http_resp)); + let client = + mock_client_with_stubbed_http_client!(aws_sdk_s3, RuleMode::MatchAny, rules.as_slice()); + + let config = aws_s3_transfer_manager::Config::builder() + .client(client) + .build(); + let tm = aws_s3_transfer_manager::Client::new(config); + + let dest = tempfile::tempdir().unwrap(); + + let handle = tm + .download_objects() + .bucket("test-bucket") + .destination(dest.path()) + .send() + .await + .unwrap(); + + let err = handle.join().await.unwrap_err(); + assert_eq!(&ErrorKind::ChildOperationFailed, err.kind()); + let service_error = err + .source() + .unwrap() + .downcast_ref::<SdkError<ListObjectsV2Error, Response>>() + .expect("should downcast to `SdkError`"); + assert!(service_error + .raw_response() + .unwrap() + .status() + .is_server_error()); + + // `ListObjectsV2` didn't list a single object and exited, so no one received a cancellation signal. + // Configuring the mock behavior of `ListObjectsV2` so it fails to list halfway through is more interesting + // for testing, but can make the test more complex.
+ assert!(!rx.contents().contains("received cancellation signal")); +} + +#[tokio::test] +async fn test_failed_get_object_should_cancel_the_operation() { + let (_guard, rx) = capture_test_logs(); + + let bucket = MockBucket::builder() + .key_with_size("key1", 12) + .key_with_error("key2") + .key_with_size("key3", 7) + .build(); + + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + bucket.rules().as_slice() + ); + + let config = aws_s3_transfer_manager::Config::builder() + .client(client) + .build(); + let tm = aws_s3_transfer_manager::Client::new(config); + + let dest = tempfile::tempdir().unwrap(); + + let handle = tm + .download_objects() + .bucket("test-bucket") + .destination(dest.path()) + .send() + .await + .unwrap(); + + let err = handle.join().await.unwrap_err(); + assert_eq!(&ErrorKind::ObjectNotDiscoverable, err.kind()); + + let logs = rx.contents(); + assert!( + logs.contains("received cancellation signal") + || logs.contains("req channel closed, worker finished") + ); +} + +#[tokio::test] +async fn test_drop_download_objects_handle() { + let bucket = MockBucket::builder() + .key_with_size("key1", 12) + .key_with_error("key2") + .key_with_size("key3", 7) + .build(); + + let (watch_tx, watch_rx) = watch::channel(()); + + let rule = mock!(aws_sdk_s3::Client::get_object).then_output(move || { + watch_tx.send(()).unwrap(); + GetObjectOutput::builder().build() + }); + + let s3_client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::MatchAny, + vec![rule, bucket.list_objects_rule()].as_slice() + ); + let config = aws_s3_transfer_manager::Config::builder() + .client(s3_client) + .build(); + let tm = aws_s3_transfer_manager::Client::new(config); + + let dest = tempfile::tempdir().unwrap(); + + let handle = tm + .download_objects() + .bucket("test-bucket") + .destination(dest.path()) + .send() + .await + .unwrap(); + + // Wait until execution reaches the point just before returning `GetObjectOutput`, + // as dropping `handle` immediately after creation may not be interesting for testing. + while !watch_rx.has_changed().unwrap() { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + // Give the spawned tasks a moment to make progress before the handle is dropped.
+ tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // should not panic + drop(handle) +} diff --git a/aws-s3-transfer-manager/tests/upload_objects_test.rs b/aws-s3-transfer-manager/tests/upload_objects_test.rs index dea17d5..de91305 100644 --- a/aws-s3-transfer-manager/tests/upload_objects_test.rs +++ b/aws-s3-transfer-manager/tests/upload_objects_test.rs @@ -25,11 +25,11 @@ use aws_sdk_s3::{ }, Client, }; -use aws_smithy_mocks_experimental::{mock, mock_client, RuleMode}; +use aws_smithy_mocks_experimental::{mock, RuleMode}; use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs; use aws_smithy_runtime_api::http::StatusCode; use aws_smithy_types::body::SdkBody; -use test_common::create_test_dir; +use test_common::{create_test_dir, mock_client_with_stubbed_http_client}; use tokio::{fs::symlink, sync::watch}; // Create an S3 client with mock behavior configured for `PutObject` @@ -38,7 +38,7 @@ fn mock_s3_client_for_put_object(bucket_name: String) -> Client { .match_requests(move |input| input.bucket() == Some(&bucket_name)) .then_output(|| PutObjectOutput::builder().build()); - mock_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]) + mock_client_with_stubbed_http_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]) } // Create an S3 client with mock behavior configured for `MultipartUpload` @@ -74,7 +74,7 @@ fn mock_s3_client_for_multipart_upload(bucket_name: String) -> Client { }) .then_output(|| CompleteMultipartUploadOutput::builder().build()); - mock_client!( + mock_client_with_stubbed_http_client!( aws_sdk_s3, RuleMode::MatchAny, &[create_mpu, upload_part, complete_mpu] ) @@ -323,7 +323,8 @@ async fn test_server_error_should_be_recorded_as_such_in_failed_transfers() { .then_http_response(|| { HttpResponse::new(StatusCode::try_from(500).unwrap(), SdkBody::empty()) }); - let s3_client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); + let s3_client = + mock_client_with_stubbed_http_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); let config = aws_s3_transfer_manager::Config::builder() .client(s3_client) .build(); @@ -435,7 +436,8 @@ async fn test_abort_on_handle_should_terminate_tasks_gracefully() { } }); - let s3_client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); + let s3_client = + mock_client_with_stubbed_http_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); let config = aws_s3_transfer_manager::Config::builder() .client(s3_client) .build(); @@ -481,12 +483,13 @@ async fn test_failed_child_operation_should_cause_ongoing_requests_to_be_cancell if already_accessed { HttpResponse::new(StatusCode::try_from(200).unwrap(), SdkBody::empty()) } else { - // Force the first call to PubObject to fail, triggering operation cancellation for all subsequent PubObject calls. + // Force the first call to PutObject to fail, triggering operation cancellation for all subsequent PutObject calls.
HttpResponse::new(StatusCode::try_from(500).unwrap(), SdkBody::empty()) } }); - let s3_client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); + let s3_client = + mock_client_with_stubbed_http_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); let config = aws_s3_transfer_manager::Config::builder() .client(s3_client) .build(); @@ -536,12 +539,11 @@ async fn test_drop_upload_objects_handle() { .then_output({ move || { watch_tx.send(()).unwrap(); - // sleep for some time so that the main thread proceeds with `drop(handle)` - std::thread::sleep(std::time::Duration::from_millis(100)); PutObjectOutput::builder().build() } }); - let s3_client = mock_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); + let s3_client = + mock_client_with_stubbed_http_client!(aws_sdk_s3, RuleMode::MatchAny, &[put_object]); let config = aws_s3_transfer_manager::Config::builder() .client(s3_client) .build(); @@ -562,6 +564,9 @@ async fn test_drop_upload_objects_handle() { tokio::time::sleep(std::time::Duration::from_millis(100)).await; } + // Give the spawned tasks a moment to make progress before the handle is dropped. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + // should not panic drop(handle) } diff --git a/aws-s3-transfer-manager/tests/upload_test.rs b/aws-s3-transfer-manager/tests/upload_test.rs index acde4eb..ae65797 100644 --- a/aws-s3-transfer-manager/tests/upload_test.rs +++ b/aws-s3-transfer-manager/tests/upload_test.rs @@ -11,11 +11,11 @@ use aws_s3_transfer_manager::io::{InputStream, PartData, PartStream, SizeHint, S use aws_sdk_s3::operation::complete_multipart_upload::CompleteMultipartUploadOutput; use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadOutput; use aws_sdk_s3::operation::upload_part::UploadPartOutput; -use aws_smithy_mocks_experimental::{mock, mock_client, RuleMode}; -use aws_smithy_runtime::client::http::test_util::infallible_client_fn; +use aws_smithy_mocks_experimental::{mock, RuleMode}; use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs; use bytes::Bytes; use pin_project_lite::pin_project; +use test_common::mock_client_with_stubbed_http_client; use tokio::sync::mpsc; /// number of simultaneous uploads to create @@ -93,7 +93,7 @@ fn mock_s3_client_for_multipart_upload() -> aws_sdk_s3::Client { }) .then_output(|| CompleteMultipartUploadOutput::builder().build()); - mock_client!( + mock_client_with_stubbed_http_client!( aws_sdk_s3, RuleMode::MatchAny, &[create_mpu, upload_part, complete_mpu] ) @@ -115,15 +115,6 @@ async fn test_many_uploads_no_deadlock() { let (_guard, _rx) = capture_test_logs(); let client = mock_s3_client_for_multipart_upload(); - let client = aws_sdk_s3::Client::from_conf( - client - .config() - .to_builder() - .http_client(infallible_client_fn(|_req| { - http_02x::Response::builder().status(200).body("").unwrap() - })) - .build(), - ); let config = aws_s3_transfer_manager::Config::builder() .client(client) .build();
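
Reviewer note: the mechanism this change threads through both `upload_objects` and `download_objects` is a single `tokio::sync::watch` channel created as `watch::channel(false)`; every worker clones the receiver and races `cancel_rx.changed()` against its normal work inside `tokio::select!`. The sketch below is a minimal, self-contained illustration of that pattern — the names (`worker`, `Job`), the `mpsc` work queue, and the error string are illustrative only, not code from this crate:

```rust
use tokio::sync::{mpsc, watch};

#[derive(Debug)]
struct Job(u32);

// Analogue of the worker loops in this PR: race the cancellation watch
// channel against the work queue and bail out on the first signal.
async fn worker(
    mut cancel_rx: watch::Receiver<bool>,
    mut work_rx: mpsc::Receiver<Job>,
) -> Result<(), &'static str> {
    loop {
        tokio::select! {
            _ = cancel_rx.changed() => {
                // Mirrors `error::operation_cancelled()` above.
                return Err("operation cancelled");
            }
            job = work_rx.recv() => {
                match job {
                    None => break, // all senders dropped: graceful shutdown
                    Some(job) => println!("processing {job:?}"),
                }
            }
        }
    }
    Ok(())
}

#[tokio::main]
async fn main() {
    let (cancel_tx, cancel_rx) = watch::channel(false);
    let (work_tx, work_rx) = mpsc::channel(8);
    let task = tokio::spawn(worker(cancel_rx, work_rx));

    work_tx.send(Job(1)).await.unwrap();
    // Any task that hits a fatal error sends `true`, waking every receiver.
    cancel_tx.send(true).unwrap();

    assert_eq!(task.await.unwrap(), Err("operation cancelled"));
}
```

A `watch` channel fits this job because each cloned receiver observes only the latest value and `changed()` wakes all waiters at once, giving the broadcast-style "cancel everything" semantics that `abort()` and the `Abort` failure policy rely on, without the per-message buffering of an `mpsc` or `broadcast` channel.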