diff --git a/aws-s3-transfer-manager/examples/cp.rs b/aws-s3-transfer-manager/examples/cp.rs
index 3652ebb..eb1a241 100644
--- a/aws-s3-transfer-manager/examples/cp.rs
+++ b/aws-s3-transfer-manager/examples/cp.rs
@@ -10,7 +10,7 @@ use std::time;
 use aws_s3_transfer_manager::io::InputStream;
 use aws_s3_transfer_manager::metrics::unit::ByteUnit;
 use aws_s3_transfer_manager::metrics::Throughput;
-use aws_s3_transfer_manager::operation::download::body::Body;
+use aws_s3_transfer_manager::operation::download::Body;
 use aws_s3_transfer_manager::types::{ConcurrencySetting, PartSize};
 use aws_sdk_s3::error::DisplayErrorContext;
 use bytes::Buf;
diff --git a/aws-s3-transfer-manager/src/operation/download.rs b/aws-s3-transfer-manager/src/operation/download.rs
index df03fc3..13ed087 100644
--- a/aws-s3-transfer-manager/src/operation/download.rs
+++ b/aws-s3-transfer-manager/src/operation/download.rs
@@ -9,11 +9,13 @@ use aws_sdk_s3::error::DisplayErrorContext;
 /// Request type for dowloading a single object from Amazon S3
 pub use input::{DownloadInput, DownloadInputBuilder};
 
-/// Abstractions for response bodies and consuming data streams.
-pub mod body;
 /// Operation builders
 pub mod builders;
 
+/// Abstractions for responses and consuming data streams.
+mod body;
+pub use body::{Body, ChunkOutput};
+
 mod discovery;
 
 mod handle;
@@ -33,15 +35,14 @@ use crate::error;
 use crate::io::AggregatedBytes;
 use crate::runtime::scheduler::OwnedWorkPermit;
 use aws_smithy_types::byte_stream::ByteStream;
-use body::{Body, ChunkOutput};
 use discovery::discover_obj;
 use service::distribute_work;
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
-use tokio::sync::{mpsc, oneshot, Mutex, OnceCell};
+use tokio::sync::{mpsc, oneshot, watch, Mutex, OnceCell};
 use tokio::task::{self, JoinSet};
 
-use super::TransferContext;
+use super::{CancelNotificationReceiver, CancelNotificationSender, TransferContext};
 
 /// Operation struct for single object download
 #[derive(Clone, Default, Debug)]
@@ -101,7 +102,7 @@ async fn send_discovery(
     object_meta_tx: oneshot::Sender<ObjectMetadata>,
     input: DownloadInput,
     use_current_span_as_parent_for_tasks: bool,
-) -> Result<(), crate::error::Error> {
+) {
     // create span to serve as parent of spawned child tasks.
     let parent_span_for_tasks = tracing::debug_span!(
         parent: if use_current_span_as_parent_for_tasks { tracing::Span::current().id() } else { None } ,
@@ -115,13 +116,37 @@ async fn send_discovery(
     }
 
     // acquire a permit for discovery
-    let permit = ctx.handle.scheduler.acquire_permit().await?;
+    let permit = ctx.handle.scheduler.acquire_permit().await;
+    let permit = match permit {
+        Ok(permit) => permit,
+        Err(err) => {
+            if comp_tx.send(Err(err)).await.is_err() {
+                tracing::debug!("Download handle for key({:?}) has been dropped, aborting during the discovery phase", input.key);
+            }
+            return;
+        }
+    };
 
     // make initial discovery about the object size, metadata, possibly first chunk
-    let mut discovery = discover_obj(&ctx, &input).await?;
-    // FIXME - This will fail if the handle is dropped at this point. We should handle
-    // the cancellation gracefully here.
-    let _ = object_meta_tx.send(discovery.object_meta);
+    let discovery = discover_obj(&ctx, &input).await;
+    let mut discovery = match discovery {
+        Ok(discovery) => discovery,
+        Err(err) => {
+            if comp_tx.send(Err(err)).await.is_err() {
+                tracing::debug!("Download handle for key({:?}) has been dropped, aborting during the discovery phase", input.key);
+            }
+            return;
+        }
+    };
+
+    if object_meta_tx.send(discovery.object_meta).is_err() {
+        tracing::debug!(
+            "Download handle for key({:?}) has been dropped, aborting during the discovery phase",
+            input.key
+        );
+        return;
+    }
+
     let initial_chunk = discovery.initial_chunk.take();
 
     let mut tasks = tasks.lock().await;
@@ -148,7 +173,6 @@ async fn send_discovery(
             parent_span_for_tasks,
         );
     }
-    Ok(())
 }
 
 /// Handle possibly sending the first chunk of data received through discovery. Returns
@@ -200,14 +224,19 @@ fn handle_discovery_chunk(
 #[derive(Debug)]
 pub(crate) struct DownloadState {
     current_seq: AtomicU64,
+    cancel_tx: CancelNotificationSender,
+    cancel_rx: CancelNotificationReceiver,
 }
 
 type DownloadContext = TransferContext<DownloadState>;
 
 impl DownloadContext {
     fn new(handle: Arc<crate::client::Handle>) -> Self {
+        let (cancel_tx, cancel_rx) = watch::channel(false);
        let state = Arc::new(DownloadState {
             current_seq: AtomicU64::new(0),
+            cancel_tx,
+            cancel_rx,
         });
         TransferContext { handle, state }
     }
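The `DownloadState` above stores both halves of a `tokio::sync::watch` channel (`CancelNotificationSender`/`CancelNotificationReceiver` are the crate's aliases for them), seeded with `false` and flipped to `true` at most once on failure. The sketch below is not part of the diff; it is a minimal, self-contained illustration (names assumed) of why a watch channel works as a broadcast cancel flag: every cloned receiver wakes from `changed()` as soon as the sender stores a new value.

```rust
use std::time::Duration;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // One sender, any number of receiver clones; receivers see the latest value.
    let (cancel_tx, cancel_rx) = watch::channel(false);

    let mut task_rx = cancel_rx.clone();
    let worker = tokio::spawn(async move {
        tokio::select! {
            // Resolves as soon as the sender stores a new value.
            _ = task_rx.changed() => println!("worker: cancelled"),
            _ = tokio::time::sleep(Duration::from_secs(60)) => println!("worker: done"),
        }
    });

    // A single send fans the cancellation out to every receiver clone.
    cancel_tx.send(true).expect("at least one receiver is still alive");
    worker.await.unwrap();
}
```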
diff --git a/aws-s3-transfer-manager/src/operation/download/body.rs b/aws-s3-transfer-manager/src/operation/download/body.rs
index e5348c9..aa416a9 100644
--- a/aws-s3-transfer-manager/src/operation/download/body.rs
+++ b/aws-s3-transfer-manager/src/operation/download/body.rs
@@ -12,7 +12,7 @@ use crate::io::AggregatedBytes;
 
 use super::chunk_meta::ChunkMetadata;
 
-/// Stream of binary data representing an Amazon S3 Object's contents.
+/// Stream of [ChunkOutput] representing an Amazon S3 Object's contents and metadata.
 ///
 /// Wraps potentially multiple streams of binary data into a single coherent stream.
 /// The data on this stream is sequenced into the correct order.
@@ -81,7 +81,10 @@ impl Body {
             match self.inner.next().await {
                 None => break,
                 Some(Ok(chunk)) => self.sequencer.push(chunk),
-                Some(Err(err)) => return Some(Err(err)),
+                Some(Err(err)) => {
+                    self.close();
+                    return Some(Err(err));
+                }
             }
         }
 
diff --git a/aws-s3-transfer-manager/src/operation/download/handle.rs b/aws-s3-transfer-manager/src/operation/download/handle.rs
index 63689f3..1d1f98a 100644
--- a/aws-s3-transfer-manager/src/operation/download/handle.rs
+++ b/aws-s3-transfer-manager/src/operation/download/handle.rs
@@ -23,11 +23,11 @@ pub struct DownloadHandle {
     /// Object metadata.
     pub(crate) object_meta: OnceCell<ObjectMetadata>,
 
-    /// The object content
+    /// The object content, in chunks, and the metadata for each chunk
     pub(crate) body: Body,
 
     /// Discovery task
-    pub(crate) discovery: task::JoinHandle<Result<(), crate::error::Error>>,
+    pub(crate) discovery: task::JoinHandle<()>,
 
     /// All child tasks (ranged GetObject) spawned for this download
     pub(crate) tasks: Arc<Mutex<JoinSet<()>>>,
@@ -53,7 +53,7 @@ impl DownloadHandle {
         Ok(meta)
     }
 
-    /// Object content
+    /// The object content, in chunks, and the metadata for each chunk
     pub fn body(&self) -> &Body {
         &self.body
     }
@@ -63,19 +63,16 @@
         &mut self.body
     }
 
-    /// Consume the handle and wait for download transfer to complete
-    #[tracing::instrument(skip_all, level = "debug", name = "join-download")]
-    pub async fn join(mut self) -> Result<(), crate::error::Error> {
+    /// Abort the download and cancel any in-progress work.
+    pub async fn abort(mut self) {
         self.body.close();
-
-        self.discovery.await??;
+        self.discovery.abort();
+        let _ = self.discovery.await;
         // It's safe to grab the lock here because discovery is already complete, and we will never
         // lock tasks again after discovery to spawn more tasks.
         let mut tasks = self.tasks.lock().await;
-        while let Some(join_result) = tasks.join_next().await {
-            join_result?;
-        }
-        Ok(())
+        tasks.abort_all();
+        while (tasks.join_next().await).is_some() {}
     }
 }
 
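`abort()` above shuts a download down in three steps: close the body, abort-and-await the discovery task, then `abort_all` the `JoinSet` of chunk tasks and drain it. The standalone sketch below (not from this diff) isolates that `JoinSet` idiom: `abort_all` only requests cancellation, so each task must still be reaped with `join_next` before the set is fully quiesced.

```rust
use std::time::Duration;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let mut tasks: JoinSet<()> = JoinSet::new();
    for i in 0..4 {
        tasks.spawn(async move {
            tokio::time::sleep(Duration::from_secs(30)).await;
            println!("task {i} ran to completion"); // never reached; aborted first
        });
    }

    // Request cancellation of every task in the set...
    tasks.abort_all();
    // ...then drain it; aborted tasks surface as a cancelled JoinError.
    while let Some(result) = tasks.join_next().await {
        assert!(result.unwrap_err().is_cancelled());
    }
}
```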
diff --git a/aws-s3-transfer-manager/src/operation/download/service.rs b/aws-s3-transfer-manager/src/operation/download/service.rs
index ef20736..d3ec31f 100644
--- a/aws-s3-transfer-manager/src/operation/download/service.rs
+++ b/aws-s3-transfer-manager/src/operation/download/service.rs
@@ -3,6 +3,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 use crate::error;
+use crate::error::ErrorKind;
 use crate::http::header;
 use crate::io::AggregatedBytes;
 use crate::middleware::limit::concurrency::ConcurrencyLimitLayer;
@@ -69,25 +70,33 @@ async fn download_specific_chunk(
     );
     let op = input.into_sdk_operation(ctx.client());
 
-    let mut resp = op
-        .send()
-        // no instrument() here because parent span shows duration of send + collect
-        .await
-        .map_err(error::from_kind(error::ErrorKind::ChunkFailed))?;
-
-    let body = mem::replace(&mut resp.body, ByteStream::new(SdkBody::taken()));
-    let body = AggregatedBytes::from_byte_stream(body)
-        .instrument(tracing::debug_span!(
-            "collect-body-from-download-chunk",
-            seq
-        ))
-        .await?;
-
-    Ok(ChunkOutput {
-        seq,
-        data: body,
-        metadata: resp.into(),
-    })
+    let mut cancel_rx = ctx.state.cancel_rx.clone();
+    tokio::select! {
+        _ = cancel_rx.changed() => {
+            tracing::debug!("Received cancellation signal, exiting and not downloading chunk#{seq}");
+            Err(error::operation_cancelled())
+        },
+        resp = op.send() => {
+            match resp {
+                Err(err) => Err(error::from_kind(error::ErrorKind::ChunkFailed)(err)),
+                Ok(mut resp) => {
+                    let body = mem::replace(&mut resp.body, ByteStream::new(SdkBody::taken()));
+                    let body = AggregatedBytes::from_byte_stream(body)
+                        .instrument(tracing::debug_span!(
+                            "collect-body-from-download-chunk",
+                            seq
+                        ))
+                        .await?;
+
+                    Ok(ChunkOutput {
+                        seq,
+                        data: body,
+                        metadata: resp.into(),
+                    })
+                },
+            }
+        }
+    }
 }
 
 /// Create a new tower::Service for downloading individual chunks of an object from S3
@@ -139,12 +148,30 @@ pub(super) fn distribute_work(
         let svc = svc.clone();
         let comp_tx = comp_tx.clone();
+        let cancel_tx = ctx.state.cancel_tx.clone();
 
         let task = async move {
-            // TODO: If downloading a chunk fails, do we want to abort the download?
             let resp = svc.oneshot(req).await;
+            // If any chunk fails, send a cancel notification to kill any other in-flight chunks
+            if let Err(err) = &resp {
+                if *err.kind() == ErrorKind::OperationCancelled {
+                    // Ignore any OperationCancelled errors.
+                    return;
+                }
+                if cancel_tx.send(true).is_err() {
+                    tracing::debug!(
+                        "all receiver ends have dropped, unable to send a cancellation signal"
+                    );
+                }
+            }
+
             if let Err(err) = comp_tx.send(resp).await {
                 tracing::debug!(error = ?err, "chunk send failed, channel closed");
+                if cancel_tx.send(true).is_err() {
+                    tracing::debug!(
+                        "all receiver ends have dropped, unable to send a cancellation signal"
+                    );
+                }
             }
         };
         tasks.spawn(task.instrument(parent_span_for_tasks.clone()));
 
diff --git a/aws-s3-transfer-manager/src/operation/download_objects/worker.rs b/aws-s3-transfer-manager/src/operation/download_objects/worker.rs
index 35b938e..fa670ff 100644
--- a/aws-s3-transfer-manager/src/operation/download_objects/worker.rs
+++ b/aws-s3-transfer-manager/src/operation/download_objects/worker.rs
@@ -2,6 +2,7 @@
  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
  * SPDX-License-Identifier: Apache-2.0
  */
+use crate::operation::download::Body;
 use async_channel::{Receiver, Sender};
 use path_clean::PathClean;
 use std::borrow::Cow;
@@ -12,7 +13,6 @@ use tokio::fs;
 use tokio::io::AsyncWriteExt;
 
 use crate::error::{self, ErrorKind};
-use crate::operation::download::body::Body;
 use crate::operation::download::{DownloadInput, DownloadInputBuilder};
 use crate::operation::DEFAULT_DELIMITER;
 use crate::types::{DownloadFilter, FailedDownload, FailedTransferPolicy};
@@ -191,13 +191,8 @@ async fn download_single_obj(
         .has_changed()
         .expect("the channel should be open as it is owned by `DownloadObjectsState`")
     {
-        /*
-         * TODO(single download cleanup): Comment in the following lines of code once single download has been cleaned up.
-         * Note that it may not be called `.abort()` depending on the outcome of the cleanup.
-         *
-         * handle.abort().await;
-         * return Err(error::operation_cancelled());
-         */
+        handle.abort().await;
+        return Err(error::operation_cancelled());
     }
 
     let _ = handle.object_meta().await?;
@@ -214,8 +209,6 @@ async fn download_single_obj(
         }
     }
 
-    handle.join().await?;
-
     Ok(())
 }
 
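With the cleanup above, `download_single_obj` cancels an in-flight single download the same way an external caller would. A hedged sketch of that caller-side pattern follows (the client type path, bucket, and key are placeholder assumptions; only `initiate()`, `object_meta()`, and `abort()` come from this diff):

```rust
use aws_s3_transfer_manager::error::Error;

// Assumed entry point: `aws_s3_transfer_manager::Client` as the transfer
// manager client; bucket/key values are placeholders.
async fn start_then_abort(tm: &aws_s3_transfer_manager::Client) -> Result<(), Error> {
    let handle = tm
        .download()
        .bucket("example-bucket")
        .key("example-object")
        .initiate()?;

    // Metadata resolves once the discovery request completes.
    let _ = handle.object_meta().await?;

    // Consumes the handle: closes the body, aborts discovery, then aborts and
    // drains every ranged-GetObject child task.
    handle.abort().await;
    Ok(())
}
```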
diff --git a/aws-s3-transfer-manager/tests/download_test.rs b/aws-s3-transfer-manager/tests/download_test.rs
index db69353..cc57d96 100644
--- a/aws-s3-transfer-manager/tests/download_test.rs
+++ b/aws-s3-transfer-manager/tests/download_test.rs
@@ -5,7 +5,7 @@
 
 use aws_config::Region;
 use aws_s3_transfer_manager::{
-    error::BoxError,
+    error::{BoxError, Error},
     operation::download::DownloadHandle,
     types::{ConcurrencySetting, PartSize},
 };
@@ -45,14 +45,22 @@ fn dummy_expected_request() -> http_02x::Request<SdkBody> {
 }
 
 /// drain/consume the body
-async fn drain(handle: &mut DownloadHandle) -> Result<Bytes, BoxError> {
+async fn drain(handle: &mut DownloadHandle) -> Result<Bytes, Error> {
     let body = handle.body_mut();
     let mut data = BytesMut::new();
+    let mut error: Option<Error> = None;
     while let Some(chunk) = body.next().await {
-        let chunk = chunk?.data.into_bytes();
-        data.put(chunk);
+        match chunk {
+            Ok(chunk) => data.put(chunk.data.into_bytes()),
+            Err(err) => {
+                error.get_or_insert(err);
+            }
+        }
     }
 
+    if let Some(error) = error {
+        return Err(error);
+    }
     Ok(data.into())
 }
 
@@ -148,11 +156,9 @@
         requests[2].headers().get("Range"),
         Some("bytes=10485760-12582911")
     );
-
-    handle.join().await.unwrap();
 }
 
-/// Test body not consumed which should not prevent the handle from being joined
+/// Test body not consumed which should not prevent the handle from being dropped
 #[tokio::test]
 async fn test_body_not_consumed() {
     let data = rand_data(12 * MEBIBYTE);
@@ -160,14 +166,33 @@ async fn test_body_not_consumed() {
     let part_size = 8 * MEBIBYTE;
 
     let (tm, _) = simple_test_tm(&data, part_size);
 
-    let handle = tm
+    let mut handle = tm
         .download()
         .bucket("test-bucket")
         .key("test-object")
         .initiate()
         .unwrap();
-    handle.join().await.unwrap();
+    let _ = handle.body_mut().next().await;
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn test_abort_download() {
+    let data = rand_data(25 * MEBIBYTE);
+    let part_size = MEBIBYTE;
+
+    let (tm, http_client) = simple_test_tm(&data, part_size);
+
+    let handle = tm
+        .download()
+        .bucket("test-bucket")
+        .key("test-object")
+        .initiate()
+        .unwrap();
+    let _ = handle.object_meta().await;
+    handle.abort().await;
+    let requests = http_client.actual_requests().collect::<Vec<_>>();
+    assert!(requests.len() < data.len() / part_size);
 }
 
 pin_project! {
@@ -282,7 +307,6 @@ async fn test_retry_failed_chunk() {
     assert_eq!(data.len(), body.len());
     let requests = http_client.actual_requests().collect::<Vec<_>>();
     assert_eq!(3, requests.len());
-    handle.join().await.unwrap();
 }
 
 const ERROR_RESPONSE: &str = r#"
@@ -297,7 +321,7 @@
 /// Test non retryable SdkError
 #[tokio::test]
 async fn test_non_retryable_error() {
-    let data = rand_data(12 * MEBIBYTE);
+    let data = rand_data(20 * MEBIBYTE);
     let part_size = 8 * MEBIBYTE;
 
     let http_client = StaticReplayClient::new(vec![
@@ -334,7 +358,6 @@ async fn test_non_retryable_error() {
 
     let _ = drain(&mut handle).await.unwrap_err();
 
-    handle.join().await.unwrap();
     let requests = http_client.actual_requests().collect::<Vec<_>>();
     assert_eq!(2, requests.len());
 }
@@ -394,7 +417,6 @@ async fn test_retry_max_attempts() {
         .unwrap();
 
     let _ = drain(&mut handle).await.unwrap_err();
-    handle.join().await.unwrap();
     let requests = http_client.actual_requests().collect::<Vec<_>>();
     assert_eq!(4, requests.len());
 }
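With `join()` gone, errors reach the caller only through the body stream, which is why the updated `drain` helper records the first `Err` chunk and why `Body::next` now closes the stream after yielding an error. A consumer loop in the same shape as that helper (a sketch against the `Body`/`ChunkOutput` API shown in this diff):

```rust
use aws_s3_transfer_manager::{error::Error, operation::download::DownloadHandle};

/// Collect the whole object, stopping at the first failed chunk; the Body
/// closes itself after yielding an Err, so the loop always terminates.
async fn read_to_vec(handle: &mut DownloadHandle) -> Result<Vec<u8>, Error> {
    let body = handle.body_mut();
    let mut out = Vec::new();
    while let Some(chunk) = body.next().await {
        // ChunkOutput carries the data plus per-chunk metadata.
        out.extend_from_slice(chunk?.data.into_bytes().as_ref());
    }
    Ok(out)
}
```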