diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7fa57a93..86f5d1f7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,9 +15,9 @@ env: # Change to specific Rust release to pin rust_stable: stable rust_nightly: nightly-2024-07-07 - rust_clippy: '1.79' + rust_clippy: '1.81' # When updating this, also update relevant docs - rust_min: '1.79' + rust_min: '1.81' defaults: @@ -55,7 +55,7 @@ jobs: - docs - minrust steps: - - run: exit 0 + - run: exit 0 test-hll: name: Test S3 transfer manager HLL @@ -133,7 +133,7 @@ jobs: run: | cargo doc --lib --no-deps --all-features --document-private-items env: - RUSTFLAGS: --cfg docsrs + RUSTFLAGS: --cfg docsrs RUSTDOCFLAGS: --cfg docsrs minrust: @@ -246,7 +246,7 @@ jobs: - name: Install cargo-hack uses: taiki-e/install-action@cargo-hack - uses: Swatinem/rust-cache@v2 - - name: check --feature-powerset + - name: check --feature-powerset run: cargo hack check --all --feature-powerset # TODO - get cross check working diff --git a/aws-s3-transfer-manager/examples/cp.rs b/aws-s3-transfer-manager/examples/cp.rs index eb1a2412..838ad5cb 100644 --- a/aws-s3-transfer-manager/examples/cp.rs +++ b/aws-s3-transfer-manager/examples/cp.rs @@ -252,8 +252,7 @@ async fn do_upload(args: Args) -> Result<(), BoxError> { .bucket(bucket) .key(key) .body(stream) - .send() - .await?; + .initiate()?; let _resp = handle.join().await?; let elapsed = start.elapsed(); diff --git a/aws-s3-transfer-manager/src/client.rs b/aws-s3-transfer-manager/src/client.rs index 29e14885..a482e6e2 100644 --- a/aws-s3-transfer-manager/src/client.rs +++ b/aws-s3-transfer-manager/src/client.rs @@ -100,10 +100,9 @@ impl Client { /// .bucket("my-bucket") /// .key("my-key") /// .body(stream) - /// .send() - /// .await?; + /// .initiate()?; /// - /// // send() may return before the transfer is complete. + /// // initiate() will return before the transfer is complete. /// // Call the `join()` method on the returned handle to drive the transfer to completion. /// // The handle can also be used to get progress, pause, or cancel the transfer, etc. /// let response = handle.join().await?; diff --git a/aws-s3-transfer-manager/src/operation/download/builders.rs b/aws-s3-transfer-manager/src/operation/download/builders.rs index 8bd1b413..0347083b 100644 --- a/aws-s3-transfer-manager/src/operation/download/builders.rs +++ b/aws-s3-transfer-manager/src/operation/download/builders.rs @@ -524,7 +524,10 @@ impl DownloadFluentBuilder { impl crate::operation::download::input::DownloadInputBuilder { /// Initiate a download transfer for a single object with this input using the given client. - pub fn send_with(self, client: &crate::Client) -> Result { + pub fn initiate_with( + self, + client: &crate::Client, + ) -> Result { let mut fluent_builder = client.download(); fluent_builder.inner = self; fluent_builder.initiate() diff --git a/aws-s3-transfer-manager/src/operation/upload.rs b/aws-s3-transfer-manager/src/operation/upload.rs index ccc08a1f..d82292ff 100644 --- a/aws-s3-transfer-manager/src/operation/upload.rs +++ b/aws-s3-transfer-manager/src/operation/upload.rs @@ -14,9 +14,9 @@ mod service; use crate::error; use crate::io::InputStream; -use aws_smithy_types::byte_stream::ByteStream; use context::UploadContext; pub use handle::UploadHandle; +use handle::{MultipartUploadData, UploadType}; /// Request type for uploads to Amazon S3 pub use input::{UploadInput, UploadInputBuilder}; /// Response type for uploads to Amazon S3 @@ -36,57 +36,58 @@ pub(crate) struct Upload; impl Upload { /// Execute a single `Upload` transfer operation - pub(crate) async fn orchestrate( + pub(crate) fn orchestrate( handle: Arc, mut input: crate::operation::upload::UploadInput, ) -> Result { - let min_mpu_threshold = handle.mpu_threshold_bytes(); - let stream = input.take_body(); - let ctx = new_context(handle, input); - - // MPU has max of 10K parts which requires us to know the upper bound on the content length (today anyway) - // While true for file-based workloads, the upper `size_hint` might not be equal to the actual bytes transferred. - let content_length = stream - .size_hint() - .upper() - .ok_or_else(crate::io::error::Error::upper_bound_size_hint_required)?; + let ctx = new_context(handle.clone(), input); + Ok(UploadHandle::new( + ctx.clone(), + tokio::spawn(try_start_upload(handle.clone(), stream, ctx)), + )) + } +} +async fn try_start_upload( + handle: Arc, + stream: InputStream, + ctx: UploadContext, +) -> Result { + let min_mpu_threshold = handle.mpu_threshold_bytes(); + + // MPU has max of 10K parts which requires us to know the upper bound on the content length (today anyway) + // While true for file-based workloads, the upper `size_hint` might not be equal to the actual bytes transferred. + let content_length = stream + .size_hint() + .upper() + .ok_or_else(crate::io::error::Error::upper_bound_size_hint_required)?; + + let upload_type = if content_length < min_mpu_threshold && !stream.is_mpu_only() { + tracing::trace!("upload request content size hint ({content_length}) less than min part size threshold ({min_mpu_threshold}); sending as single PutObject request"); + UploadType::PutObject(tokio::spawn(put_object( + ctx.clone(), + stream, + content_length, + ))) + } else { + // TODO - to upload a 0 byte object via MPU you have to send [CreateMultipartUpload, UploadPart(part=1, 0 bytes), CompleteMultipartUpload] + // we should add tests for this and hide this edge case from the user (e.g. send an empty part when a custom PartStream returns `None` immediately) // FIXME - investigate what it would take to allow non mpu uploads for `PartStream` implementations - let handle = if content_length < min_mpu_threshold && !stream.is_mpu_only() { - tracing::trace!("upload request content size hint ({content_length}) less than min part size threshold ({min_mpu_threshold}); sending as single PutObject request"); - try_start_put_object(ctx, stream, content_length).await? - } else { - // TODO - to upload a 0 byte object via MPU you have to send [CreateMultipartUpload, UploadPart(part=1, 0 bytes), CompleteMultipartUpload] - // we should add tests for this and hide this edge case from the user (e.g. send an empty part when a custom PartStream returns `None` immediately) - try_start_mpu_upload(ctx, stream, content_length).await? - }; - - Ok(handle) - } + try_start_mpu_upload(ctx, stream, content_length).await? + }; + Ok(upload_type) } -async fn try_start_put_object( +async fn put_object( ctx: UploadContext, stream: InputStream, content_length: u64, -) -> Result { - let byte_stream = stream.into_byte_stream().await?; +) -> Result { + let body = stream.into_byte_stream().await?; let content_length: i64 = content_length.try_into().map_err(|_| { error::invalid_input(format!("content_length:{} is invalid.", content_length)) })?; - - Ok(UploadHandle::new_put_object( - ctx.clone(), - tokio::spawn(put_object(ctx.clone(), byte_stream, content_length)), - )) -} - -async fn put_object( - ctx: UploadContext, - body: ByteStream, - content_length: i64, -) -> Result { // FIXME - This affects performance in cases with a lot of small files workloads. We need a way to schedule // more work for a lot of small files. let _permit = ctx.handle.scheduler.acquire_permit().await?; @@ -147,7 +148,7 @@ async fn try_start_mpu_upload( ctx: UploadContext, stream: InputStream, content_length: u64, -) -> Result { +) -> Result { let part_size = cmp::max( ctx.handle.upload_part_size_bytes(), content_length / MAX_PARTS, @@ -159,18 +160,22 @@ async fn try_start_mpu_upload( "multipart upload started with upload id: {:?}", mpu.upload_id ); - - let mut handle = UploadHandle::new_multipart(ctx); - handle.set_response(mpu); - distribute_work(&mut handle, stream, part_size)?; - Ok(handle) + let upload_id = mpu.upload_id.clone().expect("upload_id is present"); + let mut mpu_data = MultipartUploadData { + upload_part_tasks: Default::default(), + read_body_tasks: Default::default(), + response: Some(mpu), + upload_id: upload_id.clone(), + }; + + distribute_work(&mut mpu_data, ctx, stream, part_size)?; + Ok(UploadType::MultipartUpload(mpu_data)) } fn new_context(handle: Arc, req: UploadInput) -> UploadContext { UploadContext { handle, request: Arc::new(req), - upload_id: None, } } @@ -223,6 +228,7 @@ mod test { use crate::io::InputStream; use crate::operation::upload::UploadInput; use crate::types::{ConcurrencySetting, PartSize}; + use aws_sdk_s3::operation::abort_multipart_upload::AbortMultipartUploadOutput; use aws_sdk_s3::operation::complete_multipart_upload::CompleteMultipartUploadOutput; use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadOutput; use aws_sdk_s3::operation::put_object::PutObjectOutput; @@ -231,6 +237,7 @@ mod test { use bytes::Bytes; use std::ops::Deref; use std::sync::Arc; + use std::sync::Barrier; use test_common::mock_client_with_stubbed_http_client; #[tokio::test] @@ -295,7 +302,7 @@ mod test { .key("test-key") .body(stream); - let handle = request.send_with(&tm).await.unwrap(); + let handle = request.initiate_with(&tm).unwrap(); let resp = handle.join().await.unwrap(); assert_eq!(expected_upload_id.deref(), resp.upload_id.unwrap().deref()); @@ -329,9 +336,70 @@ mod test { .bucket("test-bucket") .key("test-key") .body(stream); - let handle = request.send_with(&tm).await.unwrap(); + let handle = request.initiate_with(&tm).unwrap(); let resp = handle.join().await.unwrap(); assert_eq!(resp.upload_id(), None); assert_eq!(expected_e_tag.deref(), resp.e_tag().unwrap()); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_abort_multipart_upload() { + let expected_upload_id = Arc::new("test-upload".to_owned()); + let body = Bytes::from_static(b"every adolescent dog goes bonkers early"); + let stream = InputStream::from(body); + let bucket = "test-bucket"; + let key = "test-key"; + let wait_till_create_mpu = Arc::new(Barrier::new(2)); + + let upload_id = expected_upload_id.clone(); + let create_mpu = + mock!(aws_sdk_s3::Client::create_multipart_upload).then_output(move || { + CreateMultipartUploadOutput::builder() + .upload_id(upload_id.as_ref().to_owned()) + .build() + }); + + let upload_part = mock!(aws_sdk_s3::Client::upload_part).then_output({ + let wait_till_create_mpu = wait_till_create_mpu.clone(); + move || { + wait_till_create_mpu.wait(); + UploadPartOutput::builder().build() + } + }); + + let abort_mpu = mock!(aws_sdk_s3::Client::abort_multipart_upload) + .match_requests({ + let upload_id: Arc = expected_upload_id.clone(); + move |input| { + input.upload_id.as_ref() == Some(&upload_id) + && input.bucket() == Some(bucket) + && input.key() == Some(key) + } + }) + .then_output(|| AbortMultipartUploadOutput::builder().build()); + + let client = mock_client_with_stubbed_http_client!( + aws_sdk_s3, + RuleMode::Sequential, + &[create_mpu, upload_part, abort_mpu] + ); + + let tm_config = crate::Config::builder() + .concurrency(ConcurrencySetting::Explicit(1)) + .set_multipart_threshold(PartSize::Target(10)) + .set_target_part_size(PartSize::Target(5 * 1024 * 1024)) + .client(client) + .build(); + + let tm = crate::Client::new(tm_config); + + let request = UploadInput::builder() + .bucket("test-bucket") + .key("test-key") + .body(stream); + let handle = request.initiate_with(&tm).unwrap(); + wait_till_create_mpu.wait(); + let abort = handle.abort().await.unwrap(); + assert_eq!(abort.upload_id().unwrap(), expected_upload_id.deref()); + } } diff --git a/aws-s3-transfer-manager/src/operation/upload/builders.rs b/aws-s3-transfer-manager/src/operation/upload/builders.rs index deebe71d..b65c821b 100644 --- a/aws-s3-transfer-manager/src/operation/upload/builders.rs +++ b/aws-s3-transfer-manager/src/operation/upload/builders.rs @@ -29,10 +29,10 @@ impl UploadFluentBuilder { bucket = self.inner.bucket.as_deref().unwrap_or_default(), key = self.inner.key.as_deref().unwrap_or_default(), ))] - // TODO: Make it consistent with download by renaming it to initiate and making it synchronous - pub async fn send(self) -> Result { + + pub fn initiate(self) -> Result { let input = self.inner.build()?; - crate::operation::upload::Upload::orchestrate(self.handle, input).await + crate::operation::upload::Upload::orchestrate(self.handle, input) } ///

The canned ACL to apply to the object. For more information, see Canned ACL in the Amazon S3 User Guide.

@@ -911,12 +911,12 @@ impl UploadFluentBuilder { impl crate::operation::upload::input::UploadInputBuilder { /// Initiate an upload transfer for a single object with this input using the given client. - pub async fn send_with( + pub fn initiate_with( self, client: &crate::Client, ) -> Result { let mut fluent_builder = client.upload(); fluent_builder.inner = self; - fluent_builder.send().await + fluent_builder.initiate() } } diff --git a/aws-s3-transfer-manager/src/operation/upload/context.rs b/aws-s3-transfer-manager/src/operation/upload/context.rs index 4c3c4c9a..1dc0cab1 100644 --- a/aws-s3-transfer-manager/src/operation/upload/context.rs +++ b/aws-s3-transfer-manager/src/operation/upload/context.rs @@ -12,8 +12,6 @@ use std::sync::Arc; pub(crate) struct UploadContext { /// reference to client handle used to do actual work pub(crate) handle: Arc, - /// the multipart upload ID - pub(crate) upload_id: Option, /// the original request (NOTE: the body will have been taken for processing, only the other fields remain) pub(crate) request: Arc, } @@ -28,14 +26,4 @@ impl UploadContext { pub(crate) fn request(&self) -> &UploadInput { self.request.deref() } - - /// Set the upload ID if the transfer will be done using a multipart upload - pub(crate) fn set_upload_id(&mut self, upload_id: String) { - self.upload_id = Some(upload_id) - } - - /// Check if this transfer is using multipart upload - pub(crate) fn is_multipart_upload(&self) -> bool { - self.upload_id.is_some() - } } diff --git a/aws-s3-transfer-manager/src/operation/upload/handle.rs b/aws-s3-transfer-manager/src/operation/upload/handle.rs index 1bda28ed..878f19ec 100644 --- a/aws-s3-transfer-manager/src/operation/upload/handle.rs +++ b/aws-s3-transfer-manager/src/operation/upload/handle.rs @@ -16,15 +16,21 @@ use tracing::Instrument; #[derive(Debug)] pub(crate) enum UploadType { - MultipartUpload { - /// All child multipart upload tasks spawned for this upload - upload_part_tasks: Arc>>>, - /// All child read body tasks spawned for this upload - read_body_tasks: task::JoinSet>, - }, - PutObject { - put_object_task: JoinHandle>, - }, + MultipartUpload(MultipartUploadData), + PutObject(JoinHandle>), +} + +#[derive(Debug)] +pub(crate) struct MultipartUploadData { + /// All child multipart upload tasks spawned for this upload + pub(crate) upload_part_tasks: + Arc>>>, + /// All child read body tasks spawned for this upload + pub(crate) read_body_tasks: task::JoinSet>, + /// The response that will eventually be yielded to the caller. + pub(crate) response: Option, + /// the multipart upload ID + pub(crate) upload_id: String, } /// Response type for a single upload object request. @@ -50,137 +56,111 @@ pub(crate) enum UploadType { #[derive(Debug)] #[non_exhaustive] pub struct UploadHandle { - pub(crate) upload_type: UploadType, + /// Initial task which determines the upload type + initiate_task: JoinHandle>, /// The context used to drive an upload to completion pub(crate) ctx: UploadContext, - /// The response that will eventually be yielded to the caller. - response: Option, } impl UploadHandle { - /// Create a new multipart upload handle with the given request context - pub(crate) fn new_multipart(ctx: UploadContext) -> Self { - Self { - upload_type: UploadType::MultipartUpload { - upload_part_tasks: Arc::new(Mutex::new(task::JoinSet::new())), - read_body_tasks: task::JoinSet::new(), - }, - ctx, - response: None, - } - } - - /// Create a new put_object upload handle with the given request context - pub(crate) fn new_put_object( + pub(crate) fn new( ctx: UploadContext, - put_object_task: JoinHandle>, + initiate_task: JoinHandle>, ) -> Self { - Self { - upload_type: UploadType::PutObject { put_object_task }, - ctx, - response: None, - } - } - - /// Set the initial response builder once available - /// - /// This is usually after `CreateMultipartUpload` is initiated (or - /// `PutObject` is invoked for uploads less than the required MPU threshold). - pub(crate) fn set_response(&mut self, builder: UploadOutputBuilder) { - if builder.upload_id.is_some() { - let upload_id = builder.upload_id.clone().expect("upload ID present"); - self.ctx.set_upload_id(upload_id); - } - - self.response = Some(builder); + Self { initiate_task, ctx } } /// Consume the handle and wait for upload to complete #[tracing::instrument(skip_all, level = "debug", name = "join-upload")] pub async fn join(self) -> Result { + // TODO: We won't send completeMPU until customers join the future. This can create a + // bottleneck where we have many uploads not making the completeMPU call, waiting for the join + // to happen, and then everyone tries to do completeMPU at the same time. We should investigate doing + // this without waiting for join to happen. complete_upload(self).await } /// Abort the upload and cancel any in-progress part uploads. #[tracing::instrument(skip_all, level = "debug", name = "abort-upload")] - pub async fn abort(&mut self) -> Result { + pub async fn abort(self) -> Result { // TODO(aws-sdk-rust#1159) - handle already completed upload - match &mut self.upload_type { - UploadType::PutObject { put_object_task } => { - put_object_task.abort(); - let _ = put_object_task.await?; - } - UploadType::MultipartUpload { - upload_part_tasks, - read_body_tasks, - } => { - // cancel in-progress read_body tasks - read_body_tasks.abort_all(); - while (read_body_tasks.join_next().await).is_some() {} - - // cancel in-progress upload tasks - let mut tasks = upload_part_tasks.lock().await; - tasks.abort_all(); - - // join all tasks - while (tasks.join_next().await).is_some() {} + self.initiate_task.abort(); + if let Ok(Ok(upload_type)) = self.initiate_task.await { + match upload_type { + UploadType::PutObject(put_object_task) => { + put_object_task.abort(); + let _ = put_object_task.await?; + Ok(AbortedUpload::default()) + } + UploadType::MultipartUpload(mpu_ctx) => { + abort_multipart_upload(self.ctx.clone(), mpu_ctx).await + } } - }; - - if !self.ctx.is_multipart_upload() { - return Ok(AbortedUpload::default()); - } - - let abort_policy = self - .ctx - .request - .failed_multipart_upload_policy - .clone() - .unwrap_or_default(); - - match abort_policy { - FailedMultipartUploadPolicy::AbortUpload => abort_upload(self).await, - FailedMultipartUploadPolicy::Retain => Ok(AbortedUpload::default()), + } else { + // Nothing to abort since initiate task was not successful. + Ok(AbortedUpload::default()) } } } -async fn abort_upload(handle: &UploadHandle) -> Result { - let abort_mpu_resp = handle - .ctx - .client() - .abort_multipart_upload() - .set_bucket(handle.ctx.request.bucket.clone()) - .set_key(handle.ctx.request.key.clone()) - .set_upload_id(handle.ctx.upload_id.clone()) - .set_request_payer(handle.ctx.request.request_payer.clone()) - .set_expected_bucket_owner(handle.ctx.request.expected_bucket_owner.clone()) - .send() - .instrument(tracing::debug_span!("send-abort-multipart-upload")) - .await?; +/// Abort the multipart upload and cancel any in-progress part uploads. +async fn abort_multipart_upload( + ctx: UploadContext, + mut mpu_data: MultipartUploadData, +) -> Result { + // cancel in-progress read_body tasks + mpu_data.read_body_tasks.abort_all(); + while (mpu_data.read_body_tasks.join_next().await).is_some() {} + + // cancel in-progress upload tasks + let mut tasks = mpu_data.upload_part_tasks.lock().await; + tasks.abort_all(); + + // join all tasks + while (tasks.join_next().await).is_some() {} + + let abort_policy = ctx + .request + .failed_multipart_upload_policy + .clone() + .unwrap_or_default(); + match abort_policy { + FailedMultipartUploadPolicy::Retain => Ok(AbortedUpload::default()), + FailedMultipartUploadPolicy::AbortUpload => { + let abort_mpu_resp = ctx + .client() + .abort_multipart_upload() + .set_bucket(ctx.request.bucket.clone()) + .set_key(ctx.request.key.clone()) + .set_upload_id(Some(mpu_data.upload_id.clone())) + .set_request_payer(ctx.request.request_payer.clone()) + .set_expected_bucket_owner(ctx.request.expected_bucket_owner.clone()) + .send() + .instrument(tracing::debug_span!("send-abort-multipart-upload")) + .await?; - let aborted_upload = AbortedUpload { - upload_id: handle.ctx.upload_id.clone(), - request_charged: abort_mpu_resp.request_charged, - }; + let aborted_upload = AbortedUpload { + upload_id: Some(mpu_data.upload_id), + request_charged: abort_mpu_resp.request_charged, + }; - Ok(aborted_upload) + Ok(aborted_upload) + } + } } -async fn complete_upload(mut handle: UploadHandle) -> Result { - match &mut handle.upload_type { - UploadType::PutObject { put_object_task } => put_object_task.await?, - UploadType::MultipartUpload { - upload_part_tasks, - read_body_tasks, - } => { - while let Some(join_result) = read_body_tasks.join_next().await { +async fn complete_upload(handle: UploadHandle) -> Result { + let upload_type = handle.initiate_task.await??; + match upload_type { + UploadType::PutObject(put_object_task) => put_object_task.await?, + UploadType::MultipartUpload(mut mpu_data) => { + while let Some(join_result) = mpu_data.read_body_tasks.join_next().await { if let Err(err) = join_result.expect("task completed") { tracing::error!( "multipart upload failed while trying to read the body, aborting" ); // TODO(aws-sdk-rust#1159) - if cancelling causes an error we want to propagate that in the returned error somehow? - if let Err(err) = handle.abort().await { + if let Err(err) = abort_multipart_upload(handle.ctx, mpu_data).await { tracing::error!("failed to abort upload: {}", DisplayErrorContext(err)) }; return Err(err); @@ -189,7 +169,7 @@ async fn complete_upload(mut handle: UploadHandle) -> Result Result Result Result Result Result<(), error::Error> { @@ -106,50 +108,43 @@ pub(super) fn distribute_work( .part_size(part_size.try_into().expect("valid part size")) .build(), ); - match &mut handle.upload_type { - UploadType::PutObject { .. } => { - unreachable!("distribute_work must not be called for PutObject.") - } - UploadType::MultipartUpload { - upload_part_tasks, - read_body_tasks, - } => { - // group all spawned tasks together - let parent_span_for_all_tasks = tracing::debug_span!( - parent: None, "upload-tasks", // TODO: for upload_objects, parent should be upload-objects-tasks - bucket = handle.ctx.request.bucket().unwrap_or_default(), - key = handle.ctx.request.key().unwrap_or_default(), - ); - parent_span_for_all_tasks.follows_from(tracing::Span::current()); + // group all spawned tasks together + let parent_span_for_all_tasks = tracing::debug_span!( + parent: None, "upload-tasks", // TODO: for upload_objects, parent should be upload-objects-tasks + bucket = ctx.request.bucket().unwrap_or_default(), + key = ctx.request.key().unwrap_or_default(), + ); + parent_span_for_all_tasks.follows_from(tracing::Span::current()); - // it looks nice to group all read-workers under single span - let parent_span_for_read_tasks = tracing::debug_span!( - parent: parent_span_for_all_tasks.clone(), - "upload-read-tasks" - ); + // it looks nice to group all read-workers under single span + let parent_span_for_read_tasks = tracing::debug_span!( + parent: parent_span_for_all_tasks.clone(), + "upload-read-tasks" + ); - // it looks nice to group all upload tasks together under single span - let parent_span_for_upload_tasks = tracing::debug_span!( - parent: parent_span_for_all_tasks, - "upload-net-tasks" - ); + // it looks nice to group all upload tasks together under single span + let parent_span_for_upload_tasks = tracing::debug_span!( + parent: parent_span_for_all_tasks, + "upload-net-tasks" + ); - let svc = upload_part_service(&handle.ctx); - let n_workers = handle.ctx.handle.num_workers(); - for _ in 0..n_workers { - let worker = read_body( - part_reader.clone(), - handle.ctx.clone(), - svc.clone(), - upload_part_tasks.clone(), - parent_span_for_upload_tasks.clone(), - ); - read_body_tasks.spawn(worker.instrument(parent_span_for_read_tasks.clone())); - } - tracing::trace!("work distributed for uploading parts"); - Ok(()) - } + let svc = upload_part_service(&ctx); + let n_workers = ctx.handle.num_workers(); + for _ in 0..n_workers { + let worker = read_body( + part_reader.clone(), + ctx.clone(), + mpu_data.upload_id.clone(), + svc.clone(), + mpu_data.upload_part_tasks.clone(), + parent_span_for_upload_tasks.clone(), + ); + mpu_data + .read_body_tasks + .spawn(worker.instrument(parent_span_for_read_tasks.clone())); } + tracing::trace!("work distributed for uploading parts"); + Ok(()) } /// Worker function that pulls part data from the `part_reader` and spawns tasks to upload each part until the reader @@ -157,6 +152,7 @@ pub(super) fn distribute_work( pub(super) async fn read_body( part_reader: Arc, ctx: UploadContext, + upload_id: String, svc: impl Service + Clone + Send @@ -172,6 +168,7 @@ pub(super) async fn read_body( let req = UploadPartRequest { ctx: ctx.clone(), part_data, + upload_id: upload_id.clone(), }; let svc = svc.clone(); let task = svc.oneshot(req); diff --git a/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs b/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs index 10d75bc7..2e33b6de 100644 --- a/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs +++ b/aws-s3-transfer-manager/src/operation/upload_objects/worker.rs @@ -253,8 +253,7 @@ async fn upload_single_obj( .build() .expect("valid input"); - let mut handle = - crate::operation::upload::Upload::orchestrate(ctx.handle.clone(), input).await?; + let handle = crate::operation::upload::Upload::orchestrate(ctx.handle.clone(), input)?; // The cancellation process would work fine without this if statement. // It's here so we can save a single upload operation that would otherwise @@ -323,27 +322,19 @@ fn handle_failed_upload( #[cfg(test)] mod tests { - use std::sync::{Arc, Barrier}; - - use aws_sdk_s3::operation::{ - abort_multipart_upload::AbortMultipartUploadOutput, - create_multipart_upload::CreateMultipartUploadOutput, put_object::PutObjectOutput, - upload_part::UploadPartOutput, - }; + use aws_sdk_s3::operation::put_object::PutObjectOutput; use aws_smithy_mocks_experimental::{mock, RuleMode}; use bytes::Bytes; use test_common::mock_client_with_stubbed_http_client; use crate::{ client::Handle, - config::MIN_MULTIPART_PART_SIZE_BYTES, io::InputStream, operation::upload_objects::{ worker::{upload_single_obj, UploadObjectJob}, UploadObjectsContext, UploadObjectsInputBuilder, }, runtime::scheduler::Scheduler, - types::PartSize, DEFAULT_CONCURRENCY, }; @@ -738,90 +729,4 @@ mod tests { assert_eq!(&crate::error::ErrorKind::OperationCancelled, err.kind()); } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_cancel_single_upload_via_multipart_upload() { - let bucket = "test-bucket"; - let key = "test-key"; - let upload_id: String = "test-upload-id".to_owned(); - - let wait_till_create_mpu = Arc::new(Barrier::new(2)); - let (resume_upload_single_obj_tx, resume_upload_single_obj_rx) = - tokio::sync::watch::channel(()); - let resume_upload_single_obj_tx = Arc::new(resume_upload_single_obj_tx); - - let create_mpu = mock!(aws_sdk_s3::Client::create_multipart_upload).then_output({ - let wait_till_create_mpu = wait_till_create_mpu.clone(); - let upload_id = upload_id.clone(); - move || { - // This ensures that a cancellation signal won't be sent until `create_multipart_upload`. - wait_till_create_mpu.wait(); - - // This increases the reliability of the test, ensuring that the cancellation signal has been sent - // and that `upload_single_obj` can now resume. - while !resume_upload_single_obj_rx.has_changed().unwrap() { - std::thread::sleep(std::time::Duration::from_millis(100)); - } - - CreateMultipartUploadOutput::builder() - .upload_id(upload_id.clone()) - .build() - } - }); - let upload_part = mock!(aws_sdk_s3::Client::upload_part) - .then_output(|| UploadPartOutput::builder().build()); - let abort_mpu = mock!(aws_sdk_s3::Client::abort_multipart_upload) - .match_requests({ - let upload_id = upload_id.clone(); - move |input| { - input.upload_id.as_ref() == Some(&upload_id) - && input.bucket() == Some(bucket) - && input.key() == Some(key) - } - }) - .then_output(|| AbortMultipartUploadOutput::builder().build()); - let s3_client = mock_client_with_stubbed_http_client!( - aws_sdk_s3, - RuleMode::MatchAny, - &[create_mpu, upload_part, abort_mpu] - ); - let config = crate::Config::builder() - .set_multipart_threshold(PartSize::Target(MIN_MULTIPART_PART_SIZE_BYTES)) - .client(s3_client) - .build(); - - let scheduler = Scheduler::new(DEFAULT_CONCURRENCY); - - let handle = std::sync::Arc::new(Handle { config, scheduler }); - let input = UploadObjectsInputBuilder::default() - .source("doesnotmatter") - .bucket(bucket) - .build() - .unwrap(); - - // specify the size of the contents so it triggers multipart upload - let contents = vec![0; MIN_MULTIPART_PART_SIZE_BYTES as usize]; - let ctx = UploadObjectsContext::new(handle, input); - let job = UploadObjectJob { - object: InputStream::from(Bytes::copy_from_slice(contents.as_slice())), - key: key.to_owned(), - }; - - tokio::task::spawn({ - let ctx = ctx.clone(); - let resume_upload_single_obj_tx = resume_upload_single_obj_tx.clone(); - async move { - wait_till_create_mpu.wait(); - // The upload operation has reached a point where a `CreateMultipartUploadOutput` is being prepared, - // which means that cancellation can now be triggered. - ctx.state.cancel_tx.send(true).unwrap(); - // Tell `upload_single_obj` that it can now proceed. - resume_upload_single_obj_tx.send(()).unwrap(); - } - }); - - let err = upload_single_obj(&ctx, job).await.unwrap_err(); - - assert_eq!(&crate::error::ErrorKind::OperationCancelled, err.kind()); - } } diff --git a/aws-s3-transfer-manager/tests/upload_test.rs b/aws-s3-transfer-manager/tests/upload_test.rs index ae65797c..ad47aec3 100644 --- a/aws-s3-transfer-manager/tests/upload_test.rs +++ b/aws-s3-transfer-manager/tests/upload_test.rs @@ -131,8 +131,7 @@ async fn test_many_uploads_no_deadlock() { .bucket("test-bucket") .key(format!("many-async-uploads-{}.txt", i)) .body(InputStream::from_part_stream(stream)) - .send() - .await + .initiate() .unwrap(); transfers.push((handle, tx));