diff --git a/crates/sui-indexer-alt/src/ingestion/broadcaster.rs b/crates/sui-indexer-alt/src/ingestion/broadcaster.rs index 3cb0bbf9f89c3..29988a21e0924 100644 --- a/crates/sui-indexer-alt/src/ingestion/broadcaster.rs +++ b/crates/sui-indexer-alt/src/ingestion/broadcaster.rs @@ -1,16 +1,16 @@ // Copyright (c) Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use futures::{future::try_join_all, TryStreamExt}; +use futures::future::try_join_all; use mysten_metrics::spawn_monitored_task; use std::sync::Arc; use sui_types::full_checkpoint_content::CheckpointData; use tokio::{sync::mpsc, task::JoinHandle}; -use tokio_stream::{wrappers::ReceiverStream, StreamExt}; +use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; use tracing::{error, info}; -use crate::ingestion::error::Error; +use crate::{ingestion::error::Error, task::TrySpawnStreamExt}; use super::{client::IngestionClient, IngestionConfig}; @@ -31,8 +31,7 @@ pub(super) fn broadcaster( info!("Starting ingestion broadcaster"); match ReceiverStream::new(checkpoint_rx) - .map(Ok) - .try_for_each_concurrent(/* limit */ config.ingest_concurrency, |cp| { + .try_for_each_spawned(/* limit */ config.ingest_concurrency, |cp| { let client = client.clone(); let subscribers = subscribers.clone(); diff --git a/crates/sui-indexer-alt/src/pipeline/concurrent/committer.rs b/crates/sui-indexer-alt/src/pipeline/concurrent/committer.rs index bec3e25542705..6de124bc11a84 100644 --- a/crates/sui-indexer-alt/src/pipeline/concurrent/committer.rs +++ b/crates/sui-indexer-alt/src/pipeline/concurrent/committer.rs @@ -4,10 +4,9 @@ use std::{sync::Arc, time::Duration}; use backoff::ExponentialBackoff; -use futures::TryStreamExt; use mysten_metrics::spawn_monitored_task; use tokio::{sync::mpsc, task::JoinHandle}; -use tokio_stream::{wrappers::ReceiverStream, StreamExt}; +use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; @@ -15,6 +14,7 @@ use crate::{ db::Db, metrics::IndexerMetrics, pipeline::{Break, PipelineConfig, WatermarkPart}, + task::TrySpawnStreamExt, }; use super::{Batched, Handler}; @@ -48,8 +48,7 @@ pub(super) fn committer( let write_concurrency = H::WRITE_CONCURRENCY_OVERRIDE.unwrap_or(config.write_concurrency); match ReceiverStream::new(rx) - .map(Ok) - .try_for_each_concurrent(write_concurrency, |Batched { values, watermark }| { + .try_for_each_spawned(write_concurrency, |Batched { values, watermark }| { let values = Arc::new(values); let tx = tx.clone(); let db = db.clone(); diff --git a/crates/sui-indexer-alt/src/pipeline/processor.rs b/crates/sui-indexer-alt/src/pipeline/processor.rs index 795c347bc8f16..576be66422a2e 100644 --- a/crates/sui-indexer-alt/src/pipeline/processor.rs +++ b/crates/sui-indexer-alt/src/pipeline/processor.rs @@ -4,15 +4,14 @@ use std::sync::atomic::AtomicU64; use std::sync::Arc; -use futures::TryStreamExt; use mysten_metrics::spawn_monitored_task; use sui_types::full_checkpoint_content::CheckpointData; use tokio::{sync::mpsc, task::JoinHandle}; -use tokio_stream::{wrappers::ReceiverStream, StreamExt}; +use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info}; -use crate::{metrics::IndexerMetrics, pipeline::Break}; +use crate::{metrics::IndexerMetrics, pipeline::Break, task::TrySpawnStreamExt}; use super::Indexed; @@ -52,15 +51,15 @@ pub(super) fn processor( spawn_monitored_task!(async move { info!(pipeline = P::NAME, "Starting processor"); let 
latest_processed_checkpoint = Arc::new(AtomicU64::new(0));
+        let processor = Arc::new(processor);
 
         match ReceiverStream::new(rx)
-            .map(Ok)
-            .try_for_each_concurrent(P::FANOUT, |checkpoint| {
+            .try_for_each_spawned(P::FANOUT, |checkpoint| {
                 let tx = tx.clone();
                 let metrics = metrics.clone();
                 let cancel = cancel.clone();
                 let latest_processed_checkpoint = latest_processed_checkpoint.clone();
-                let processor = &processor;
+                let processor = processor.clone();
 
                 async move {
                     if cancel.is_cancelled() {
diff --git a/crates/sui-indexer-alt/src/task.rs b/crates/sui-indexer-alt/src/task.rs
index f0c59183942f7..101ae26a8aa03 100644
--- a/crates/sui-indexer-alt/src/task.rs
+++ b/crates/sui-indexer-alt/src/task.rs
@@ -1,12 +1,135 @@
 // Copyright (c) Mysten Labs, Inc.
 // SPDX-License-Identifier: Apache-2.0
 
-use std::iter;
+use std::{future::Future, iter, panic, pin::pin};
 
-use futures::future::{self, Either};
-use tokio::{signal, sync::oneshot, task::JoinHandle};
+use futures::{
+    future::{self, Either},
+    stream::{Stream, StreamExt},
+};
+use tokio::{
+    signal,
+    sync::oneshot,
+    task::{JoinHandle, JoinSet},
+};
 use tokio_util::sync::CancellationToken;
 
+/// Extension trait introducing `try_for_each_spawned` to all streams.
+pub trait TrySpawnStreamExt: Stream {
+    /// Attempts to run this stream to completion, executing the provided asynchronous closure on
+    /// each element from the stream as elements become available.
+    ///
+    /// This is similar to [StreamExt::for_each_concurrent], but it may take advantage of any
+    /// parallelism available in the underlying runtime, because each unit of work is spawned as
+    /// its own tokio task.
+    ///
+    /// The first argument is an optional limit on the number of tasks to spawn concurrently.
+    /// Values of `0` and `None` are interpreted as no limit, and any other value will result in no
+    /// more than that many tasks being spawned at one time.
+    ///
+    /// ## Safety
+    ///
+    /// This function will panic if any of its futures panics, will return early with success if
+    /// the runtime it is running on is cancelled, and will return early with an error propagated
+    /// from any worker that produces an error.
+    fn try_for_each_spawned<Fut, F, E>(
+        self,
+        limit: impl Into<Option<usize>>,
+        f: F,
+    ) -> impl Future<Output = Result<(), E>>
+    where
+        Fut: Future<Output = Result<(), E>> + Send + 'static,
+        F: FnMut(Self::Item) -> Fut,
+        E: Send + 'static;
+}
+
+impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
+    async fn try_for_each_spawned<Fut, F, E>(
+        self,
+        limit: impl Into<Option<usize>>,
+        mut f: F,
+    ) -> Result<(), E>
+    where
+        Fut: Future<Output = Result<(), E>> + Send + 'static,
+        F: FnMut(Self::Item) -> Fut,
+        E: Send + 'static,
+    {
+        // Maximum number of tasks to spawn concurrently.
+        let limit = match limit.into() {
+            Some(0) | None => usize::MAX,
+            Some(n) => n,
+        };
+
+        // Number of permits left to spawn tasks.
+        let mut permits = limit;
+        // Handles for already spawned tasks.
+        let mut join_set = JoinSet::new();
+        // Whether the worker pool has stopped accepting new items and is draining.
+        let mut draining = false;
+        // Error that occurred in one of the workers, to be propagated to the caller on exit.
+        let mut error = None;
+
+        let mut self_ = pin!(self);
+
+        loop {
+            tokio::select! {
+                next = self_.next(), if !draining && permits > 0 => {
+                    if let Some(item) = next {
+                        permits -= 1;
+                        join_set.spawn(f(item));
+                    } else {
+                        // If the stream is empty, signal that the worker pool is going to
+                        // start draining now, so that once we get all our permits back, we
+                        // know we can wind down the pool.
+ draining = true; + } + } + + Some(res) = join_set.join_next() => { + match res { + Ok(Err(e)) if error.is_none() => { + error = Some(e); + permits += 1; + draining = true; + } + + Ok(_) => permits += 1, + + // Worker panicked, propagate the panic. + Err(e) if e.is_panic() => { + panic::resume_unwind(e.into_panic()) + } + + // Worker was cancelled -- this can only happen if its join handle was + // cancelled (not possible because that was created in this function), + // or the runtime it was running in was wound down, in which case, + // prepare the worker pool to drain. + Err(e) => { + assert!(e.is_cancelled()); + permits += 1; + draining = true; + } + } + } + + else => { + // Not accepting any more items from the stream, and all our workers are + // idle, so we stop. + if permits == limit && draining { + break; + } + } + } + } + + if let Some(e) = error { + Err(e) + } else { + Ok(()) + } + } +} + /// Manages cleanly exiting the process, either because one of its constituent services has stopped /// or because an interrupt signal was sent to the process. /// @@ -29,7 +152,7 @@ pub async fn graceful_shutdown( None }; - tokio::pin!(interrupt); + let interrupt = pin!(interrupt); let futures: Vec<_> = services .into_iter() .map(|s| Either::Left(Box::pin(async move { s.await.ok() }))) @@ -46,3 +169,164 @@ pub async fn graceful_shutdown( results.extend(future::join_all(rest).await.into_iter().flatten()); results } + +#[cfg(test)] +mod tests { + use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, + time::Duration, + }; + + use futures::stream; + + use super::*; + + #[tokio::test] + async fn explicit_sequential_iteration() { + let actual = Arc::new(Mutex::new(vec![])); + let result = stream::iter(0..20) + .try_for_each_spawned(1, |i| { + let actual = actual.clone(); + async move { + tokio::time::sleep(Duration::from_millis(20 - i)).await; + actual.lock().unwrap().push(i); + Ok::<(), ()>(()) + } + }) + .await; + + assert!(result.is_ok()); + + let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap(); + let expect: Vec<_> = (0..20).collect(); + assert_eq!(expect, actual); + } + + #[tokio::test] + async fn concurrent_iteration() { + let actual = Arc::new(AtomicUsize::new(0)); + let result = stream::iter(0..100) + .try_for_each_spawned(16, |i| { + let actual = actual.clone(); + async move { + actual.fetch_add(i, Ordering::Relaxed); + Ok::<(), ()>(()) + } + }) + .await; + + assert!(result.is_ok()); + + let actual = Arc::try_unwrap(actual).unwrap().into_inner(); + let expect = 99 * 100 / 2; + assert_eq!(expect, actual); + } + + #[tokio::test] + async fn implicit_unlimited_iteration() { + let actual = Arc::new(AtomicUsize::new(0)); + let result = stream::iter(0..100) + .try_for_each_spawned(None, |i| { + let actual = actual.clone(); + async move { + actual.fetch_add(i, Ordering::Relaxed); + Ok::<(), ()>(()) + } + }) + .await; + + assert!(result.is_ok()); + + let actual = Arc::try_unwrap(actual).unwrap().into_inner(); + let expect = 99 * 100 / 2; + assert_eq!(expect, actual); + } + + #[tokio::test] + async fn explicit_unlimited_iteration() { + let actual = Arc::new(AtomicUsize::new(0)); + let result = stream::iter(0..100) + .try_for_each_spawned(0, |i| { + let actual = actual.clone(); + async move { + actual.fetch_add(i, Ordering::Relaxed); + Ok::<(), ()>(()) + } + }) + .await; + + assert!(result.is_ok()); + + let actual = Arc::try_unwrap(actual).unwrap().into_inner(); + let expect = 99 * 100 / 2; + assert_eq!(expect, actual); + } + + #[tokio::test] + async fn 
max_concurrency() { + #[derive(Default, Debug)] + struct Jobs { + max: AtomicUsize, + curr: AtomicUsize, + } + + let jobs = Arc::new(Jobs::default()); + + let result = stream::iter(0..32) + .try_for_each_spawned(4, |_| { + let jobs = jobs.clone(); + async move { + jobs.curr.fetch_add(1, Ordering::Relaxed); + tokio::time::sleep(Duration::from_millis(100)).await; + let prev = jobs.curr.fetch_sub(1, Ordering::Relaxed); + jobs.max.fetch_max(prev, Ordering::Relaxed); + Ok::<(), ()>(()) + } + }) + .await; + + assert!(result.is_ok()); + + let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap(); + assert_eq!(curr.into_inner(), 0); + assert!(max.into_inner() <= 4); + } + + #[tokio::test] + async fn error_propagation() { + let actual = Arc::new(Mutex::new(vec![])); + let result = stream::iter(0..100) + .try_for_each_spawned(None, |i| { + let actual = actual.clone(); + async move { + if i < 42 { + actual.lock().unwrap().push(i); + Ok(()) + } else { + Err(()) + } + } + }) + .await; + + assert!(result.is_err()); + + let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap(); + let expect: Vec<_> = (0..42).collect(); + assert_eq!(expect, actual); + } + + #[tokio::test] + #[should_panic] + async fn panic_propagation() { + let _ = stream::iter(0..100) + .try_for_each_spawned(None, |i| async move { + assert!(i < 42); + Ok::<(), ()>(()) + }) + .await; + } +}
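
Usage note (not part of the patch): every call-site change above follows the same pattern -- `.map(Ok).try_for_each_concurrent(limit, ...)` becomes `.try_for_each_spawned(limit, ...)`, and any state captured by the closure is cloned (e.g. `processor` is wrapped in an `Arc` in processor.rs) because each invocation now runs as its own spawned task and must be `Send + 'static` rather than borrow from the enclosing task. Below is a minimal, self-contained sketch of that pattern, assuming only that `TrySpawnStreamExt` from crates/sui-indexer-alt/src/task.rs is in scope; the stream contents and the per-item work are illustrative, not taken from the indexer.

use futures::stream;

use crate::task::TrySpawnStreamExt;

async fn drain_items(items: Vec<u64>) -> Result<(), ()> {
    stream::iter(items)
        // Spawn each closure invocation as its own tokio task, with at most 8 in flight
        // at once (passing 0 or None would mean "no limit").
        .try_for_each_spawned(8, |item| async move {
            // Hypothetical per-item work. Returning `Err` stops new tasks from being
            // spawned; the error is surfaced once in-flight tasks have drained.
            println!("processed item {item}");
            Ok::<(), ()>(())
        })
        .await
}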