From 061f9745b4e2432489e9de1481f786b128dd2648 Mon Sep 17 00:00:00 2001
From: ysaito1001
Date: Thu, 5 Dec 2024 13:18:49 -0600
Subject: [PATCH 1/4] Follow up on PR#75 (#78)

* Incorporate post-merge review feedback from PR#75

This commit addresses
https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/75#discussion_r1853090148
https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/75#discussion_r1856992994

* Add test for canceling upload object via MPU

This commit addresses
https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/75#discussion_r1857001515

* Fix memory leak in test detected by `LeakSanitizer`

* Simulate the flow of CreateMPU -> upload cancellation -> AbortMPU

This commit addresses
https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/78#discussion_r1863896824

* Avoid a single letter variable name
---
 aws-s3-transfer-manager/src/config.rs         |  2 +-
 .../src/operation/upload/handle.rs            |  3 +
 .../src/operation/upload_objects/handle.rs    |  4 +-
 .../src/operation/upload_objects/worker.rs    | 98 ++++++++++++++++++-
 aws-s3-transfer-manager/tests/upload_test.rs  | 10 --
 5 files changed, 102 insertions(+), 15 deletions(-)

diff --git a/aws-s3-transfer-manager/src/config.rs b/aws-s3-transfer-manager/src/config.rs
index a728959..66b9cd8 100644
--- a/aws-s3-transfer-manager/src/config.rs
+++ b/aws-s3-transfer-manager/src/config.rs
@@ -10,7 +10,7 @@ use std::cmp;
 pub(crate) mod loader;
 
 /// Minimum upload part size in bytes
-const MIN_MULTIPART_PART_SIZE_BYTES: u64 = 5 * ByteUnit::Mebibyte.as_bytes_u64();
+pub(crate) const MIN_MULTIPART_PART_SIZE_BYTES: u64 = 5 * ByteUnit::Mebibyte.as_bytes_u64();
 
 /// Configuration for a [`Client`](crate::client::Client)
 #[derive(Debug, Clone)]
diff --git a/aws-s3-transfer-manager/src/operation/upload/handle.rs b/aws-s3-transfer-manager/src/operation/upload/handle.rs
index 74d2d02..1bda28e 100644
--- a/aws-s3-transfer-manager/src/operation/upload/handle.rs
+++ b/aws-s3-transfer-manager/src/operation/upload/handle.rs
@@ -44,6 +44,9 @@ pub(crate) enum UploadType {
 /// It first calls `.abort_all` on the tasks it owns, and then invokes `AbortMultipartUpload`
 /// to abort any in-progress multipart uploads. Errors encountered during `AbortMultipartUpload`
 /// are logged, but do not affect the overall cancellation flow.
+///
+/// In either case, if the upload operation has already been completed before the handle is dropped
+/// or aborted, the uploaded object will not be deleted from S3.
 #[derive(Debug)]
 #[non_exhaustive]
 pub struct UploadHandle {
diff --git a/aws-s3-transfer-manager/src/operation/upload_objects/handle.rs b/aws-s3-transfer-manager/src/operation/upload_objects/handle.rs
index 05a0f6d..319f9f5 100644
--- a/aws-s3-transfer-manager/src/operation/upload_objects/handle.rs
+++ b/aws-s3-transfer-manager/src/operation/upload_objects/handle.rs
@@ -44,8 +44,8 @@ impl UploadObjectsHandle {
     /// Consume the handle and wait for the upload to complete
     ///
     /// When the `FailedTransferPolicy` is set to [`FailedTransferPolicy::Abort`], this method
-    /// will return an error if any of the spawned tasks encounter one. The other tasks will
-    /// be canceled, but their cancellations will not be reported as errors by this method;
+    /// will return the first error if any of the spawned tasks encounter one. The other tasks
+    /// will be canceled, but their cancellations will not be reported as errors by this method;
     /// they will be logged as errors, instead.
     ///
     /// If the `FailedTransferPolicy` is set to [`FailedTransferPolicy::Continue`], the
diff --git a/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs b/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs
index ee2a448..10d75bc 100644
--- a/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs
+++ b/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs
@@ -323,19 +323,27 @@ fn handle_failed_upload(
 
 #[cfg(test)]
 mod tests {
-    use aws_sdk_s3::operation::put_object::PutObjectOutput;
+    use std::sync::{Arc, Barrier};
+
+    use aws_sdk_s3::operation::{
+        abort_multipart_upload::AbortMultipartUploadOutput,
+        create_multipart_upload::CreateMultipartUploadOutput, put_object::PutObjectOutput,
+        upload_part::UploadPartOutput,
+    };
     use aws_smithy_mocks_experimental::{mock, RuleMode};
     use bytes::Bytes;
     use test_common::mock_client_with_stubbed_http_client;
 
     use crate::{
         client::Handle,
+        config::MIN_MULTIPART_PART_SIZE_BYTES,
         io::InputStream,
         operation::upload_objects::{
             worker::{upload_single_obj, UploadObjectJob},
             UploadObjectsContext, UploadObjectsInputBuilder,
         },
         runtime::scheduler::Scheduler,
+        types::PartSize,
         DEFAULT_CONCURRENCY,
     };
 
@@ -700,7 +708,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_cancel_single_upload() {
+    async fn test_cancel_single_upload_via_put_object() {
         let bucket = "doesnotmatter";
         let put_object = mock!(aws_sdk_s3::Client::put_object)
             .match_requests(move |input| input.bucket() == Some(bucket))
@@ -730,4 +738,90 @@ mod tests {
 
         assert_eq!(&crate::error::ErrorKind::OperationCancelled, err.kind());
     }
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn test_cancel_single_upload_via_multipart_upload() {
+        let bucket = "test-bucket";
+        let key = "test-key";
+        let upload_id: String = "test-upload-id".to_owned();
+
+        let wait_till_create_mpu = Arc::new(Barrier::new(2));
+        let (resume_upload_single_obj_tx, resume_upload_single_obj_rx) =
+            tokio::sync::watch::channel(());
+        let resume_upload_single_obj_tx = Arc::new(resume_upload_single_obj_tx);
+
+        let create_mpu = mock!(aws_sdk_s3::Client::create_multipart_upload).then_output({
+            let wait_till_create_mpu = wait_till_create_mpu.clone();
+            let upload_id = upload_id.clone();
+            move || {
+                // This ensures that a cancellation signal won't be sent until `create_multipart_upload` has started.
+                wait_till_create_mpu.wait();
+
+                // This increases the reliability of the test, ensuring that the cancellation signal has been sent
+                // and that `upload_single_obj` can now resume.
+                while !resume_upload_single_obj_rx.has_changed().unwrap() {
+                    std::thread::sleep(std::time::Duration::from_millis(100));
+                }
+
+                CreateMultipartUploadOutput::builder()
+                    .upload_id(upload_id.clone())
+                    .build()
+            }
+        });
+        let upload_part = mock!(aws_sdk_s3::Client::upload_part)
+            .then_output(|| UploadPartOutput::builder().build());
+        let abort_mpu = mock!(aws_sdk_s3::Client::abort_multipart_upload)
+            .match_requests({
+                let upload_id = upload_id.clone();
+                move |input| {
+                    input.upload_id.as_ref() == Some(&upload_id)
+                        && input.bucket() == Some(bucket)
+                        && input.key() == Some(key)
+                }
+            })
+            .then_output(|| AbortMultipartUploadOutput::builder().build());
+        let s3_client = mock_client_with_stubbed_http_client!(
+            aws_sdk_s3,
+            RuleMode::MatchAny,
+            &[create_mpu, upload_part, abort_mpu]
+        );
+        let config = crate::Config::builder()
+            .set_multipart_threshold(PartSize::Target(MIN_MULTIPART_PART_SIZE_BYTES))
+            .client(s3_client)
+            .build();
+
+        let scheduler = Scheduler::new(DEFAULT_CONCURRENCY);
+
+        let handle = std::sync::Arc::new(Handle { config, scheduler });
+        let input = UploadObjectsInputBuilder::default()
+            .source("doesnotmatter")
+            .bucket(bucket)
+            .build()
+            .unwrap();
+
+        // specify the size of the contents so it triggers multipart upload
+        let contents = vec![0; MIN_MULTIPART_PART_SIZE_BYTES as usize];
+        let ctx = UploadObjectsContext::new(handle, input);
+        let job = UploadObjectJob {
+            object: InputStream::from(Bytes::copy_from_slice(contents.as_slice())),
+            key: key.to_owned(),
+        };
+
+        tokio::task::spawn({
+            let ctx = ctx.clone();
+            let resume_upload_single_obj_tx = resume_upload_single_obj_tx.clone();
+            async move {
+                wait_till_create_mpu.wait();
+                // The upload operation has reached a point where a `CreateMultipartUploadOutput` is being prepared,
+                // which means that cancellation can now be triggered.
+                ctx.state.cancel_tx.send(true).unwrap();
+                // Tell `upload_single_obj` that it can now proceed.
+                resume_upload_single_obj_tx.send(()).unwrap();
+            }
+        });
+
+        let err = upload_single_obj(&ctx, job).await.unwrap_err();
+
+        assert_eq!(&crate::error::ErrorKind::OperationCancelled, err.kind());
+    }
 }
diff --git a/aws-s3-transfer-manager/tests/upload_test.rs b/aws-s3-transfer-manager/tests/upload_test.rs
index 2450ed8..ae65797 100644
--- a/aws-s3-transfer-manager/tests/upload_test.rs
+++ b/aws-s3-transfer-manager/tests/upload_test.rs
@@ -12,7 +12,6 @@ use aws_sdk_s3::operation::complete_multipart_upload::CompleteMultipartUploadOut
 use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadOutput;
 use aws_sdk_s3::operation::upload_part::UploadPartOutput;
 use aws_smithy_mocks_experimental::{mock, RuleMode};
-use aws_smithy_runtime::client::http::test_util::infallible_client_fn;
 use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
 use bytes::Bytes;
 use pin_project_lite::pin_project;
@@ -116,15 +115,6 @@ fn mock_s3_client_for_multipart_upload() -> aws_sdk_s3::Client {
 async fn test_many_uploads_no_deadlock() {
     let (_guard, _rx) = capture_test_logs();
     let client = mock_s3_client_for_multipart_upload();
-    let client = aws_sdk_s3::Client::from_conf(
-        client
-            .config()
-            .to_builder()
-            .http_client(infallible_client_fn(|_req| {
-                http_02x::Response::builder().status(200).body("").unwrap()
-            }))
-            .build(),
-    );
     let config = aws_s3_transfer_manager::Config::builder()
         .client(client)
         .build();

From 540285ab90cab70039502c88ddda83f21a7af68f Mon Sep 17 00:00:00 2001
From: Waqar Ahmed Khan
Date: Mon, 9 Dec 2024 09:48:37 -0800
Subject: [PATCH 2/4] Abort Download Object (#80)

---
 aws-s3-transfer-manager/examples/cp.rs        |  2 +-
 .../src/operation/download.rs                 | 53 +++++++++----
 .../src/operation/download/body.rs            |  7 +-
 .../src/operation/download/handle.rs          | 21 +++---
 .../src/operation/download/service.rs         | 67 +++++++++++------
 .../src/operation/download_objects/worker.rs  | 13 +---
 .../tests/download_test.rs                    | 48 +++++++++----
 7 files changed, 141 insertions(+), 70 deletions(-)

diff --git a/aws-s3-transfer-manager/examples/cp.rs b/aws-s3-transfer-manager/examples/cp.rs
index 3652ebb..eb1a241 100644
--- a/aws-s3-transfer-manager/examples/cp.rs
+++ b/aws-s3-transfer-manager/examples/cp.rs
@@ -10,7 +10,7 @@ use std::time;
 use aws_s3_transfer_manager::io::InputStream;
 use aws_s3_transfer_manager::metrics::unit::ByteUnit;
 use aws_s3_transfer_manager::metrics::Throughput;
-use aws_s3_transfer_manager::operation::download::body::Body;
+use aws_s3_transfer_manager::operation::download::Body;
 use aws_s3_transfer_manager::types::{ConcurrencySetting, PartSize};
 use aws_sdk_s3::error::DisplayErrorContext;
 use bytes::Buf;
diff --git a/aws-s3-transfer-manager/src/operation/download.rs b/aws-s3-transfer-manager/src/operation/download.rs
index df03fc3..13ed087 100644
--- a/aws-s3-transfer-manager/src/operation/download.rs
+++ b/aws-s3-transfer-manager/src/operation/download.rs
@@ -9,11 +9,13 @@ use aws_sdk_s3::error::DisplayErrorContext;
 
 /// Request type for downloading a single object from Amazon S3
 pub use input::{DownloadInput, DownloadInputBuilder};
 
-/// Abstractions for response bodies and consuming data streams.
-pub mod body;
 /// Operation builders
 pub mod builders;
 
+/// Abstractions for responses and consuming data streams.
+mod body;
+pub use body::{Body, ChunkOutput};
+
 mod discovery;
 
 mod handle;
@@ -33,15 +35,14 @@ use crate::error;
 use crate::io::AggregatedBytes;
 use crate::runtime::scheduler::OwnedWorkPermit;
 use aws_smithy_types::byte_stream::ByteStream;
-use body::{Body, ChunkOutput};
 use discovery::discover_obj;
 use service::distribute_work;
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
-use tokio::sync::{mpsc, oneshot, Mutex, OnceCell};
+use tokio::sync::{mpsc, oneshot, watch, Mutex, OnceCell};
 use tokio::task::{self, JoinSet};
 
-use super::TransferContext;
+use super::{CancelNotificationReceiver, CancelNotificationSender, TransferContext};
 
 /// Operation struct for single object download
 #[derive(Clone, Default, Debug)]
@@ -101,7 +102,7 @@ async fn send_discovery(
     object_meta_tx: oneshot::Sender<ObjectMetadata>,
     input: DownloadInput,
     use_current_span_as_parent_for_tasks: bool,
-) -> Result<(), crate::error::Error> {
+) {
     // create span to serve as parent of spawned child tasks.
     let parent_span_for_tasks = tracing::debug_span!(
         parent: if use_current_span_as_parent_for_tasks { tracing::Span::current().id() } else { None } ,
@@ -115,13 +116,37 @@ async fn send_discovery(
     }
 
     // acquire a permit for discovery
-    let permit = ctx.handle.scheduler.acquire_permit().await?;
+    let permit = ctx.handle.scheduler.acquire_permit().await;
+    let permit = match permit {
+        Ok(permit) => permit,
+        Err(err) => {
+            if comp_tx.send(Err(err)).await.is_err() {
+                tracing::debug!("Download handle for key({:?}) has been dropped, aborting during the discovery phase", input.key);
+            }
+            return;
+        }
+    };
 
     // make initial discovery about the object size, metadata, possibly first chunk
-    let mut discovery = discover_obj(&ctx, &input).await?;
-    // FIXME - This will fail if the handle is dropped at this point. We should handle
-    // the cancellation gracefully here.
-    let _ = object_meta_tx.send(discovery.object_meta);
+    let discovery = discover_obj(&ctx, &input).await;
+    let mut discovery = match discovery {
+        Ok(discovery) => discovery,
+        Err(err) => {
+            if comp_tx.send(Err(err)).await.is_err() {
+                tracing::debug!("Download handle for key({:?}) has been dropped, aborting during the discovery phase", input.key);
+            }
+            return;
+        }
+    };
+
+    if object_meta_tx.send(discovery.object_meta).is_err() {
+        tracing::debug!(
+            "Download handle for key({:?}) has been dropped, aborting during the discovery phase",
+            input.key
+        );
+        return;
+    }
+
     let initial_chunk = discovery.initial_chunk.take();
 
     let mut tasks = tasks.lock().await;
@@ -148,7 +173,6 @@ async fn send_discovery(
             parent_span_for_tasks,
         );
     }
-    Ok(())
 }
 
 /// Handle possibly sending the first chunk of data received through discovery. Returns
@@ -200,14 +224,19 @@ fn handle_discovery_chunk(
 #[derive(Debug)]
 pub(crate) struct DownloadState {
     current_seq: AtomicU64,
+    cancel_tx: CancelNotificationSender,
+    cancel_rx: CancelNotificationReceiver,
 }
 
 type DownloadContext = TransferContext<DownloadState>;
 
 impl DownloadContext {
     fn new(handle: Arc<Handle>) -> Self {
+        let (cancel_tx, cancel_rx) = watch::channel(false);
         let state = Arc::new(DownloadState {
             current_seq: AtomicU64::new(0),
+            cancel_tx,
+            cancel_rx,
         });
         TransferContext { handle, state }
     }
diff --git a/aws-s3-transfer-manager/src/operation/download/body.rs b/aws-s3-transfer-manager/src/operation/download/body.rs
index e5348c9..aa416a9 100644
--- a/aws-s3-transfer-manager/src/operation/download/body.rs
+++ b/aws-s3-transfer-manager/src/operation/download/body.rs
@@ -12,7 +12,7 @@ use crate::io::AggregatedBytes;
 
 use super::chunk_meta::ChunkMetadata;
 
-/// Stream of binary data representing an Amazon S3 Object's contents.
+/// Stream of [ChunkOutput] representing an Amazon S3 Object's contents and metadata.
 ///
 /// Wraps potentially multiple streams of binary data into a single coherent stream.
 /// The data on this stream is sequenced into the correct order.
@@ -81,7 +81,10 @@ impl Body {
             match self.inner.next().await {
                 None => break,
                 Some(Ok(chunk)) => self.sequencer.push(chunk),
-                Some(Err(err)) => return Some(Err(err)),
+                Some(Err(err)) => {
+                    self.close();
+                    return Some(Err(err));
+                }
             }
         }
 
diff --git a/aws-s3-transfer-manager/src/operation/download/handle.rs b/aws-s3-transfer-manager/src/operation/download/handle.rs
index 63689f3..1d1f98a 100644
--- a/aws-s3-transfer-manager/src/operation/download/handle.rs
+++ b/aws-s3-transfer-manager/src/operation/download/handle.rs
@@ -23,11 +23,11 @@ pub struct DownloadHandle {
     /// Object metadata.
     pub(crate) object_meta: OnceCell<ObjectMetadata>,
 
-    /// The object content
+    /// The object content, in chunks, and the metadata for each chunk
     pub(crate) body: Body,
 
     /// Discovery task
-    pub(crate) discovery: task::JoinHandle<Result<(), crate::error::Error>>,
+    pub(crate) discovery: task::JoinHandle<()>,
 
     /// All child tasks (ranged GetObject) spawned for this download
     pub(crate) tasks: Arc<Mutex<task::JoinSet<()>>>,
@@ -53,7 +53,7 @@ impl DownloadHandle {
         Ok(meta)
     }
 
-    /// Object content
+    /// The object content, in chunks, and the metadata for each chunk
     pub fn body(&self) -> &Body {
         &self.body
     }
@@ -63,19 +63,16 @@ impl DownloadHandle {
         &mut self.body
     }
 
-    /// Consume the handle and wait for download transfer to complete
-    #[tracing::instrument(skip_all, level = "debug", name = "join-download")]
-    pub async fn join(mut self) -> Result<(), crate::error::Error> {
+    /// Abort the download and cancel any in-progress work.
+    pub async fn abort(mut self) {
         self.body.close();
-
-        self.discovery.await??;
+        self.discovery.abort();
+        let _ = self.discovery.await;
         // It's safe to grab the lock here because discovery is already complete, and we will never
         // lock tasks again after discovery to spawn more tasks.
         let mut tasks = self.tasks.lock().await;
-        while let Some(join_result) = tasks.join_next().await {
-            join_result?;
-        }
-        Ok(())
+        tasks.abort_all();
+        while (tasks.join_next().await).is_some() {}
     }
 }
diff --git a/aws-s3-transfer-manager/src/operation/download/service.rs b/aws-s3-transfer-manager/src/operation/download/service.rs
index ef20736..d3ec31f 100644
--- a/aws-s3-transfer-manager/src/operation/download/service.rs
+++ b/aws-s3-transfer-manager/src/operation/download/service.rs
@@ -3,6 +3,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 use crate::error;
+use crate::error::ErrorKind;
 use crate::http::header;
 use crate::io::AggregatedBytes;
 use crate::middleware::limit::concurrency::ConcurrencyLimitLayer;
@@ -69,25 +70,33 @@ async fn download_specific_chunk(
     );
 
     let op = input.into_sdk_operation(ctx.client());
-    let mut resp = op
-        .send()
-        // no instrument() here because parent span shows duration of send + collect
-        .await
-        .map_err(error::from_kind(error::ErrorKind::ChunkFailed))?;
-
-    let body = mem::replace(&mut resp.body, ByteStream::new(SdkBody::taken()));
-    let body = AggregatedBytes::from_byte_stream(body)
-        .instrument(tracing::debug_span!(
-            "collect-body-from-download-chunk",
-            seq
-        ))
-        .await?;
-
-    Ok(ChunkOutput {
-        seq,
-        data: body,
-        metadata: resp.into(),
-    })
+    let mut cancel_rx = ctx.state.cancel_rx.clone();
+    tokio::select! {
+        _ = cancel_rx.changed() => {
+            tracing::debug!("Received cancellation signal, exiting and not downloading chunk#{seq}");
+            Err(error::operation_cancelled())
+        },
+        resp = op.send() => {
+            match resp {
+                Err(err) => Err(error::from_kind(error::ErrorKind::ChunkFailed)(err)),
+                Ok(mut resp) => {
+                    let body = mem::replace(&mut resp.body, ByteStream::new(SdkBody::taken()));
+                    let body = AggregatedBytes::from_byte_stream(body)
+                        .instrument(tracing::debug_span!(
+                            "collect-body-from-download-chunk",
+                            seq
+                        ))
+                        .await?;
+
+                    Ok(ChunkOutput {
+                        seq,
+                        data: body,
+                        metadata: resp.into(),
+                    })
+                },
+            }
+        }
+    }
 }
 
 /// Create a new tower::Service for downloading individual chunks of an object from S3
@@ -139,12 +148,30 @@ pub(super) fn distribute_work(
             let svc = svc.clone();
             let comp_tx = comp_tx.clone();
+            let cancel_tx = ctx.state.cancel_tx.clone();
 
             let task = async move {
-                // TODO: If downloading a chunk fails, do we want to abort the download?
                 let resp = svc.oneshot(req).await;
+                // If any chunk fails, send a cancel notification to kill any other in-flight chunks
+                if let Err(err) = &resp {
+                    if *err.kind() == ErrorKind::OperationCancelled {
+                        // Ignore any OperationCancelled errors.
+                        return;
+                    }
+                    if cancel_tx.send(true).is_err() {
+                        tracing::debug!(
+                            "all receiver ends have dropped, unable to send a cancellation signal"
+                        );
+                    }
+                }
+
                 if let Err(err) = comp_tx.send(resp).await {
                     tracing::debug!(error = ?err, "chunk send failed, channel closed");
+                    if cancel_tx.send(true).is_err() {
+                        tracing::debug!(
+                            "all receiver ends have dropped, unable to send a cancellation signal"
+                        );
+                    }
                 }
             };
             tasks.spawn(task.instrument(parent_span_for_tasks.clone()));
diff --git a/aws-s3-transfer-manager/src/operation/download_objects/worker.rs b/aws-s3-transfer-manager/src/operation/download_objects/worker.rs
index 35b938e..fa670ff 100644
--- a/aws-s3-transfer-manager/src/operation/download_objects/worker.rs
+++ b/aws-s3-transfer-manager/src/operation/download_objects/worker.rs
@@ -2,6 +2,7 @@
  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
  * SPDX-License-Identifier: Apache-2.0
  */
+use crate::operation::download::Body;
 use async_channel::{Receiver, Sender};
 use path_clean::PathClean;
 use std::borrow::Cow;
@@ -12,7 +13,6 @@
 use tokio::fs;
 use tokio::io::AsyncWriteExt;
 
 use crate::error::{self, ErrorKind};
-use crate::operation::download::body::Body;
 use crate::operation::download::{DownloadInput, DownloadInputBuilder};
 use crate::operation::DEFAULT_DELIMITER;
 use crate::types::{DownloadFilter, FailedDownload, FailedTransferPolicy};
@@ -191,13 +191,8 @@ async fn download_single_obj(
         .has_changed()
         .expect("the channel should be open as it is owned by `DownloadObjectsState`")
     {
-        /*
-         * TODO(single download cleanup): Comment in the following lines of code once single download has been cleaned up.
-         * Note that it may not be called `.abort()` depending on the outcome of the cleanup.
-         *
-         * handle.abort().await;
-         * return Err(error::operation_cancelled());
-         */
+        handle.abort().await;
+        return Err(error::operation_cancelled());
     }
 
     let _ = handle.object_meta().await?;
@@ -214,8 +209,6 @@
         }
     }
 
-    handle.join().await?;
-
     Ok(())
 }
 
diff --git a/aws-s3-transfer-manager/tests/download_test.rs b/aws-s3-transfer-manager/tests/download_test.rs
index db69353..cc57d96 100644
--- a/aws-s3-transfer-manager/tests/download_test.rs
+++ b/aws-s3-transfer-manager/tests/download_test.rs
@@ -5,7 +5,7 @@
 use aws_config::Region;
 use aws_s3_transfer_manager::{
-    error::BoxError,
+    error::{BoxError, Error},
     operation::download::DownloadHandle,
     types::{ConcurrencySetting, PartSize},
 };
@@ -45,14 +45,22 @@ fn dummy_expected_request() -> http_02x::Request<SdkBody> {
 }
 
 /// drain/consume the body
-async fn drain(handle: &mut DownloadHandle) -> Result<Bytes, BoxError> {
+async fn drain(handle: &mut DownloadHandle) -> Result<Bytes, Error> {
     let body = handle.body_mut();
     let mut data = BytesMut::new();
+    let mut error: Option<Error> = None;
     while let Some(chunk) = body.next().await {
-        let chunk = chunk?.data.into_bytes();
-        data.put(chunk);
+        match chunk {
+            Ok(chunk) => data.put(chunk.data.into_bytes()),
+            Err(err) => {
+                error.get_or_insert(err);
+            }
+        }
     }
 
+    if let Some(error) = error {
+        return Err(error);
+    }
     Ok(data.into())
 }
@@ -148,11 +156,9 @@ async fn test_download_ranges() {
         requests[2].headers().get("Range"),
         Some("bytes=10485760-12582911")
     );
-
-    handle.join().await.unwrap();
 }
 
-/// Test body not consumed which should not prevent the handle from being joined
+/// Test body not consumed which should not prevent the handle from being dropped
 #[tokio::test]
 async fn test_body_not_consumed() {
     let data = rand_data(12 * MEBIBYTE);
@@ -160,14 +166,33 @@ async fn test_body_not_consumed() {
 
     let (tm, _) = simple_test_tm(&data, part_size);
 
-    let handle = tm
+    let mut handle = tm
         .download()
         .bucket("test-bucket")
         .key("test-object")
         .initiate()
         .unwrap();
 
-    handle.join().await.unwrap();
+    let _ = handle.body_mut().next().await;
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn test_abort_download() {
+    let data = rand_data(25 * MEBIBYTE);
+    let part_size = MEBIBYTE;
+
+    let (tm, http_client) = simple_test_tm(&data, part_size);
+
+    let handle = tm
+        .download()
+        .bucket("test-bucket")
+        .key("test-object")
+        .initiate()
+        .unwrap();
+    let _ = handle.object_meta().await;
+    handle.abort().await;
+
+    let requests = http_client.actual_requests().collect::<Vec<_>>();
+    assert!(requests.len() < data.len() / part_size);
 }
 
 pin_project! {
@@ -282,7 +307,6 @@ async fn test_retry_failed_chunk() {
     assert_eq!(data.len(), body.len());
     let requests = http_client.actual_requests().collect::<Vec<_>>();
     assert_eq!(3, requests.len());
-    handle.join().await.unwrap();
 }
 
 const ERROR_RESPONSE: &str = r#"
 <Error>
     <Code>PreconditionFailed</Code>
     <Message>At least one of the pre-conditions you specified did not hold</Message>
     <Condition>If-Match</Condition>
 </Error>
 "#;
 
 /// Test non retryable SdkError
 #[tokio::test]
 async fn test_non_retryable_error() {
-    let data = rand_data(12 * MEBIBYTE);
+    let data = rand_data(20 * MEBIBYTE);
     let part_size = 8 * MEBIBYTE;
 
     let http_client = StaticReplayClient::new(vec![
@@ -334,7 +358,6 @@ async fn test_non_retryable_error() {
 
     let _ = drain(&mut handle).await.unwrap_err();
 
-    handle.join().await.unwrap();
     let requests = http_client.actual_requests().collect::<Vec<_>>();
     assert_eq!(2, requests.len());
 }
@@ -394,7 +417,6 @@ async fn test_retry_max_attempts() {
         .unwrap();
 
     let _ = drain(&mut handle).await.unwrap_err();
-    handle.join().await.unwrap();
     let requests = http_client.actual_requests().collect::<Vec<_>>();
     assert_eq!(4, requests.len());
 }

From d18bc7a1ce817d78adca8c490e4b29668ebd82ad Mon Sep 17 00:00:00 2001
From: Dengke Tang <815825145@qq.com>
Date: Tue, 10 Dec 2024 13:37:22 -0800
Subject: [PATCH 3/4] Add user agent (#79)

- Add an interceptor when we load the config to add the transfer manager
  metric to the user agent
- Add framework metadata to the transfer manager config and pass it down
  to the SDK S3 client as the framework name in the user agent.
- Only added for the `ConfigLoader`; we can decide later what to do when
  users pass in their own client directly via config.
---
 aws-s3-transfer-manager/Cargo.toml            |   5 +-
 aws-s3-transfer-manager/external-types.toml   |   2 +-
 aws-s3-transfer-manager/src/config.rs         |  24 ++++
 aws-s3-transfer-manager/src/config/loader.rs  | 131 ++++++++++++++++++-
 4 files changed, 156 insertions(+), 6 deletions(-)

diff --git a/aws-s3-transfer-manager/Cargo.toml b/aws-s3-transfer-manager/Cargo.toml
index b2d2da4..a0a074a 100644
--- a/aws-s3-transfer-manager/Cargo.toml
+++ b/aws-s3-transfer-manager/Cargo.toml
@@ -15,7 +15,8 @@ aws-config = { version = "1.5.6", features = ["behavior-version-latest"] }
 aws-sdk-s3 = { version = "1.51.0", features = ["behavior-version-latest"] }
 aws-smithy-async = "1.2.1"
 aws-smithy-experimental = { version = "0.1.3", features = ["crypto-aws-lc"] }
-aws-smithy-runtime-api = "1.7.1"
+aws-smithy-runtime-api = "1.7.3"
+aws-runtime = "1.4.4"
 aws-smithy-types = "1.2.6"
 aws-types = "1.3.3"
 blocking = "1.6.0"
@@ -32,7 +33,7 @@ walkdir = "2"
 [dev-dependencies]
 aws-sdk-s3 = { version = "1.51.0", features = ["behavior-version-latest", "test-util"] }
 aws-smithy-mocks-experimental = "0.2.1"
-aws-smithy-runtime = { version = "1.7.1", features = ["client", "connector-hyper-0-14-x", "test-util", "wire-mock"] }
+aws-smithy-runtime = { version = "1.7.4", features = ["client", "connector-hyper-0-14-x", "test-util", "wire-mock"] }
 clap = { version = "4.5.7", default-features = false, features = ["derive", "std", "help"] }
 console-subscriber = "0.4.0"
 http-02x = { package = "http", version = "0.2.9" }
diff --git a/aws-s3-transfer-manager/external-types.toml b/aws-s3-transfer-manager/external-types.toml
index e372912..aa3c1d1 100644
--- a/aws-s3-transfer-manager/external-types.toml
+++ b/aws-s3-transfer-manager/external-types.toml
@@ -8,5 +8,5 @@ allowed_external_types = [
     "bytes::bytes::Bytes",
     "bytes::buf::buf_impl::Buf",
     "aws_types::request_id::RequestId",
-    "aws_types::request_id::RequestIdExt"
+    "aws_types::request_id::RequestIdExt",
 ]
diff --git a/aws-s3-transfer-manager/src/config.rs b/aws-s3-transfer-manager/src/config.rs
index 66b9cd8..2197c32 100644
--- a/aws-s3-transfer-manager/src/config.rs
+++ b/aws-s3-transfer-manager/src/config.rs
@@ -3,6 +3,8 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
+use aws_runtime::user_agent::FrameworkMetadata;
+
 use crate::metrics::unit::ByteUnit;
 use crate::types::{ConcurrencySetting, PartSize};
 use std::cmp;
@@ -18,6 +20,7 @@ pub struct Config {
     multipart_threshold: PartSize,
     target_part_size: PartSize,
     concurrency: ConcurrencySetting,
+    framework_metadata: Option<FrameworkMetadata>,
     client: aws_sdk_s3::client::Client,
 }
 
@@ -43,6 +46,12 @@ impl Config {
         &self.concurrency
     }
 
+    /// Returns the framework metadata setting when using transfer manager.
+    #[doc(hidden)]
+    pub fn framework_metadata(&self) -> Option<&FrameworkMetadata> {
+        self.framework_metadata.as_ref()
+    }
+
     /// The Amazon S3 client instance that will be used to send requests to S3.
     pub fn client(&self) -> &aws_sdk_s3::Client {
         &self.client
@@ -55,6 +64,7 @@ pub struct Builder {
     multipart_threshold_part_size: PartSize,
     target_part_size: PartSize,
     concurrency: ConcurrencySetting,
+    framework_metadata: Option<FrameworkMetadata>,
     client: Option<aws_sdk_s3::Client>,
 }
 
@@ -122,8 +132,21 @@ impl Builder {
         self
     }
 
+    /// Sets the framework metadata for the transfer manager.
+    ///
+    /// This _optional_ name is used to identify the framework using transfer manager in the user agent that
+    /// gets sent along with requests.
+    #[doc(hidden)]
+    pub fn framework_metadata(mut self, framework_metadata: Option<FrameworkMetadata>) -> Self {
+        self.framework_metadata = framework_metadata;
+        self
+    }
+
     /// Set an explicit S3 client to use.
     pub fn client(mut self, client: aws_sdk_s3::Client) -> Self {
+        // TODO - decide the approach here:
+        // - Convert the client to a builder so it can be modified based on other transfer manager configs
+        // - Instead of taking the client, take sdk-config/s3-config/builder?
         self.client = Some(client);
         self
     }
@@ -134,6 +157,7 @@ impl Builder {
             multipart_threshold: self.multipart_threshold_part_size,
             target_part_size: self.target_part_size,
             concurrency: self.concurrency,
+            framework_metadata: self.framework_metadata,
             client: self.client.expect("client set"),
         }
     }
diff --git a/aws-s3-transfer-manager/src/config/loader.rs b/aws-s3-transfer-manager/src/config/loader.rs
index 56b76c3..02212f9 100644
--- a/aws-s3-transfer-manager/src/config/loader.rs
+++ b/aws-s3-transfer-manager/src/config/loader.rs
@@ -3,6 +3,12 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
+use aws_config::BehaviorVersion;
+use aws_runtime::sdk_feature::AwsSdkFeature;
+use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, FrameworkMetadata};
+use aws_sdk_s3::config::{Intercept, IntoShared};
+use aws_types::os_shim_internal::Env;
+
 use crate::config::Builder;
 use crate::{
     http,
@@ -10,6 +16,39 @@ use crate::{
     Config,
 };
 
+#[derive(Debug)]
+struct S3TransferManagerInterceptor {
+    framework_metadata: Option<FrameworkMetadata>,
+}
+
+impl Intercept for S3TransferManagerInterceptor {
+    fn name(&self) -> &'static str {
+        "S3TransferManager"
+    }
+
+    fn read_before_execution(
+        &self,
+        _ctx: &aws_sdk_s3::config::interceptors::BeforeSerializationInterceptorContextRef<'_>,
+        cfg: &mut aws_sdk_s3::config::ConfigBag,
+    ) -> Result<(), aws_sdk_s3::error::BoxError> {
+        // Assume this interceptor is only added to clients constructed by the loader.
+        // In that case, no user agent has been sent before this interceptor runs.
+        // Create our own user agent with the S3Transfer feature and the user's
+        // passed-in framework_metadata, if any.
+        cfg.interceptor_state()
+            .store_append(AwsSdkFeature::S3Transfer);
+        let api_metadata = cfg.load::<ApiMetadata>().unwrap();
+        // TODO: maybe APP Name someday
+        let mut ua = AwsUserAgent::new_from_environment(Env::real(), api_metadata.clone());
+        if let Some(framework_metadata) = self.framework_metadata.clone() {
+            ua = ua.with_framework_metadata(framework_metadata);
+        }
+
+        cfg.interceptor_state().store_put(ua);
+
+        Ok(())
+    }
+}
+
 /// Load transfer manager [`Config`] from the environment.
 #[derive(Default, Debug)]
 pub struct ConfigLoader {
@@ -52,17 +91,103 @@ impl ConfigLoader {
         self
     }
 
+    /// Sets the framework metadata for the transfer manager.
+    ///
+    /// This _optional_ name is used to identify the framework using transfer manager in the user agent that
+    /// gets sent along with requests.
+    #[doc(hidden)]
+    pub fn framework_metadata(mut self, framework_metadata: Option<FrameworkMetadata>) -> Self {
+        self.builder = self.builder.framework_metadata(framework_metadata);
+        self
+    }
+
     /// Load the default configuration
     ///
     /// If fields have been overridden during builder construction, the override values will be
     /// used. Otherwise, the default values for each field will be provided.
     pub async fn load(self) -> Config {
-        let shared_config = aws_config::from_env()
+        let shared_config = aws_config::defaults(BehaviorVersion::latest())
             .http_client(http::default_client())
             .load()
             .await;
-        let s3_client = aws_sdk_s3::Client::new(&shared_config);
-        let builder = self.builder.client(s3_client);
+
+        let mut sdk_client_builder = aws_sdk_s3::config::Builder::from(&shared_config);
+
+        let interceptor = S3TransferManagerInterceptor {
+            framework_metadata: self.builder.framework_metadata.clone(),
+        };
+        sdk_client_builder.push_interceptor(S3TransferManagerInterceptor::into_shared(interceptor));
+        let builder = self
+            .builder
+            .client(aws_sdk_s3::Client::from_conf(sdk_client_builder.build()));
         builder.build()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::borrow::Cow;
+
+    use crate::types::{ConcurrencySetting, PartSize};
+    use aws_config::Region;
+    use aws_runtime::user_agent::FrameworkMetadata;
+    use aws_sdk_s3::config::Intercept;
+    use aws_smithy_runtime::client::http::test_util::capture_request;
+
+    #[tokio::test]
+    async fn load_with_interceptor() {
+        let config = crate::from_env()
+            .concurrency(ConcurrencySetting::Explicit(123))
+            .part_size(PartSize::Target(8))
+            .load()
+            .await;
+        let sdk_s3_config = config.client().config();
+        let tm_interceptor_exists = sdk_s3_config
+            .interceptors()
+            .any(|item| item.name() == "S3TransferManager");
+        assert!(tm_interceptor_exists);
+    }
+
+    #[tokio::test]
+    async fn load_with_interceptor_and_framework_metadata() {
+        let (http_client, captured_request) = capture_request(None);
+        let config = crate::from_env()
+            .concurrency(ConcurrencySetting::Explicit(123))
+            .part_size(PartSize::Target(8))
+            .framework_metadata(Some(
+                FrameworkMetadata::new("some-framework", Some(Cow::Borrowed("1.3"))).unwrap(),
+            ))
+            .load()
+            .await;
+        // Inject the capturing HTTP client into the S3 config so we can capture requests made by the transfer manager.
+        let sdk_s3_config = config
+            .client()
+            .config()
+            .to_builder()
+            .http_client(http_client)
+            .region(Region::from_static("us-west-2"))
+            .with_test_defaults()
+            .build();
+
+        let capture_request_config = crate::Config::builder()
+            .client(aws_sdk_s3::Client::from_conf(sdk_s3_config))
+            .concurrency(ConcurrencySetting::Explicit(123))
+            .part_size(PartSize::Target(8))
+            .build();
+
+        let transfer_manager = crate::Client::new(capture_request_config);
+
+        let mut handle = transfer_manager
+            .download()
+            .bucket("foo")
+            .key("bar")
+            .initiate()
+            .unwrap();
+        // Expect to fail
+        let _ = handle.body_mut().next().await;
+        // Check that the captured request contains the expected framework metadata in the user agent.
+        let expected_req = captured_request.expect_request();
+        let user_agent = expected_req.headers().get("x-amz-user-agent").unwrap();
+        assert!(user_agent.contains("lib/some-framework/1.3"));
+    }
+}

From 6e646645e8cc79291b637d71b49b2f45d4ad1025 Mon Sep 17 00:00:00 2001
From: Dengke Tang <815825145@qq.com>
Date: Fri, 13 Dec 2024 15:53:28 -0800
Subject: [PATCH 4/4] support if-match for get objects (#82)

Support detecting that the object changed during download.
---
 .../src/operation/download.rs                 |  7 +-
 .../tests/download_test.rs                    | 93 +++++++++++++++++++
 2 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/aws-s3-transfer-manager/src/operation/download.rs b/aws-s3-transfer-manager/src/operation/download.rs
index 13ed087..24bb978 100644
--- a/aws-s3-transfer-manager/src/operation/download.rs
+++ b/aws-s3-transfer-manager/src/operation/download.rs
@@ -100,7 +100,7 @@ async fn send_discovery(
     ctx: DownloadContext,
     comp_tx: mpsc::Sender<Result<ChunkOutput, error::Error>>,
     object_meta_tx: oneshot::Sender<ObjectMetadata>,
-    input: DownloadInput,
+    mut input: DownloadInput,
     use_current_span_as_parent_for_tasks: bool,
 ) {
     // create span to serve as parent of spawned child tasks.
@@ -139,6 +139,11 @@ async fn send_discovery(
         }
     };
 
+    // Add if_match to the rest of the requests using the etag
+    // we got from discovery to ensure the object stays the same
+    // during the download process.
+    input.if_match.clone_from(&discovery.object_meta.e_tag);
+
     if object_meta_tx.send(discovery.object_meta).is_err() {
         tracing::debug!(
             "Download handle for key({:?}) has been dropped, aborting during the discovery phase",
diff --git a/aws-s3-transfer-manager/tests/download_test.rs b/aws-s3-transfer-manager/tests/download_test.rs
index cc57d96..d69c136 100644
--- a/aws-s3-transfer-manager/tests/download_test.rs
+++ b/aws-s3-transfer-manager/tests/download_test.rs
@@ -90,6 +90,7 @@ fn simple_object_connector(data: &Bytes, part_size: usize) -> StaticReplayClient
                         "Content-Range",
                         format!("bytes {start}-{end}/{}", data.len()),
                     )
+                    .header("ETag", "my-etag")
                     .body(SdkBody::from(chunk))
                     .unwrap(),
             )
@@ -420,3 +421,95 @@ async fn test_retry_max_attempts() {
     let requests = http_client.actual_requests().collect::<Vec<_>>();
     assert_eq!(4, requests.len());
 }
+
+/// Test that the if_match header is added correctly based on the response from the server.
+#[tokio::test]
+async fn test_download_if_match() {
+    let data = rand_data(12 * MEBIBYTE);
+    let part_size = 5 * MEBIBYTE;
+
+    let (tm, http_client) = simple_test_tm(&data, part_size);
+
+    let mut handle = tm
+        .download()
+        .bucket("test-bucket")
+        .key("test-object")
+        .initiate()
+        .unwrap();
+
+    let _ = drain(&mut handle).await.unwrap();
+
+    let requests = http_client.actual_requests().collect::<Vec<_>>();
+    assert_eq!(3, requests.len());
+
+    // The first request is to discover the object metadata and should not have any if-match
+    assert_eq!(requests[0].headers().get("If-Match"), None);
+    // All the following requests should have the if-match header
+    assert_eq!(requests[1].headers().get("If-Match"), Some("my-etag"));
+    assert_eq!(requests[2].headers().get("If-Match"), Some("my-etag"));
+}
+
+const OBJECT_MODIFIED_RESPONSE: &str = r#"
+<Error>
+    <Code>PreconditionFailed</Code>
+    <Message>At least one of the pre-conditions you specified did not hold</Message>
+    <Condition>If-Match</Condition>
+</Error>
+"#;
+
+/// Test the failure mode when the object is modified during the download.
+#[tokio::test]
+async fn test_download_object_modified() {
+    let data = rand_data(12 * MEBIBYTE);
+    let part_size = 5 * MEBIBYTE;
+
+    // Create a static replay client (http connector) to mock the S3 response when the
+    // object is modified during the download.
+    //
+    // Assumptions:
+    // 1. The first request, for discovery, succeeds with an ETag
+    // 2. Subsequent requests fail, mocking the object changing during the download.
+    let events = data
+        .chunks(part_size)
+        .enumerate()
+        .map(|(idx, chunk)| {
+            let start = idx * part_size;
+            let end = std::cmp::min(start + part_size, data.len()) - 1;
+            let mut response = http_02x::Response::builder()
+                .status(206)
+                .header("Content-Length", format!("{}", end - start + 1))
+                .header(
+                    "Content-Range",
+                    format!("bytes {start}-{end}/{}", data.len()),
+                )
+                .header("ETag", "my-etag")
+                .body(SdkBody::from(chunk))
+                .unwrap();
+            if idx > 0 {
+                response = http_02x::Response::builder()
+                    .status(412)
+                    .header("Date", "Thu, 12 Jan 2023 00:04:21 GMT")
+                    .body(SdkBody::from(OBJECT_MODIFIED_RESPONSE))
+                    .unwrap();
+            }
+            ReplayEvent::new(
+                // NOTE: Rather than try to recreate all the expected requests we just put in placeholders and
+                // make our own assertions against the captured requests.
+                dummy_expected_request(),
+                response,
+            )
+        })
+        .collect();
+
+    let http_client = StaticReplayClient::new(events);
+    let tm = test_tm(http_client.clone(), part_size);
+
+    let mut handle = tm
+        .download()
+        .bucket("test-bucket")
+        .key("test-object")
+        .initiate()
+        .unwrap();
+
+    let error = drain(&mut handle).await.unwrap_err();
+    assert!(format!("{:?}", error).contains("PreconditionFailed"));
+}
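
Taken together, the series yields the downloader flow below. This is a minimal,
hypothetical sketch, not part of the patches: the bucket/key names and the
`#[tokio::main]` scaffolding are illustrative, and it assumes only the public API
exercised by the tests above (`from_env`, `Client::new`, `download()`, the chunked
`Body`, and `DownloadHandle::abort`).

use aws_s3_transfer_manager::types::{ConcurrencySetting, PartSize};

#[tokio::main]
async fn main() {
    // `from_env` wires in the S3TransferManager user-agent interceptor (PATCH 3/4).
    let config = aws_s3_transfer_manager::from_env()
        .concurrency(ConcurrencySetting::Explicit(8))
        .part_size(PartSize::Target(8 * 1024 * 1024))
        .load()
        .await;
    let tm = aws_s3_transfer_manager::Client::new(config);

    // Initiate a download. Discovery issues the first GetObject and, per PATCH 4/4,
    // pins the discovered ETag so every later ranged GetObject carries If-Match.
    let mut handle = tm
        .download()
        .bucket("amzn-s3-demo-bucket") // hypothetical bucket
        .key("some-object")            // hypothetical key
        .initiate()
        .unwrap();

    // Chunks arrive in sequence order; each ChunkOutput carries data plus metadata.
    if let Some(chunk) = handle.body_mut().next().await {
        let chunk = chunk.expect("first chunk downloads");
        println!("received {} bytes", chunk.data.into_bytes().len());
    }

    // Abort (PATCH 2/4): closes the body, stops the discovery task, and aborts the
    // in-flight ranged GetObject tasks before awaiting them.
    handle.abort().await;
}

The cancellation design mirrors the upload side: per-transfer state holds a
`tokio::sync::watch` channel, each chunk task `select!`s between sending the ranged
GetObject and `cancel_rx.changed()`, and the first failed chunk (or a dropped
handle) sends `true` so the remaining in-flight tasks bail out instead of racing
to completion.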