Skip to content

Commit

Permalink
graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
0xForerunner committed Oct 25, 2024
1 parent 4f4eb5f commit 41c4246
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 13 deletions.
17 changes: 17 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ alloy = { version = "0.3", features = [
"transports",
"hyper",
"signer-local",
"signers"
"signers",
] }
eyre = "0.6"
futures = "0.3"
Expand All @@ -42,6 +42,7 @@ semaphore = { git = "https://github.com/worldcoin/semaphore-rs", rev = "59b2a0af
serde = { version = "1.0.189", features = ["derive"] }
serde_json = "1.0"
serde_path_to_error = "0.1.16"
humantime-serde = "1.1.1"
take_mut = "0.2.2"
telemetry-batteries = { git = "https://github.com/worldcoin/telemetry-batteries.git", rev = "12cc036234b4e9b86f22ff7e35d499e2ff1e6304" }
tempfile = "3.10.1"
Expand Down
10 changes: 2 additions & 8 deletions src/bin/world_tree.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use std::path::PathBuf;

use clap::Parser;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use telemetry_batteries::metrics::statsd::StatsdBattery;
use telemetry_batteries::tracing::datadog::DatadogBattery;
use telemetry_batteries::tracing::TracingShutdownHandle;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use world_tree::init_world_tree;
use world_tree::tasks::monitor_tasks;
use world_tree::tree::config::ServiceConfig;
use world_tree::tree::error::WorldTreeResult;
use world_tree::tree::service::InclusionProofService;
Expand Down Expand Up @@ -74,12 +73,7 @@ pub async fn main() -> WorldTreeResult<()> {

let service = InclusionProofService::new(world_tree);
let (_, handles) = service.serve(config.socket_address).await?;

let mut handles = handles.into_iter().collect::<FuturesUnordered<_>>();
while let Some(result) = handles.next().await {
tracing::error!("TreeAvailabilityError: {:?}", result);
result?;
}
monitor_tasks(handles, config.shutdown_delay).await?;

Ok(())
}
63 changes: 63 additions & 0 deletions src/tasks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,66 @@
use std::time::Duration;

use futures::stream::FuturesUnordered;
use futures::StreamExt;
use tokio::task::{JoinError, JoinHandle};
use tracing::info;

pub mod ingest;
pub mod observe;
pub mod update;

pub async fn monitor_tasks(
mut handles: FuturesUnordered<JoinHandle<()>>,
shutdown_delay: Duration,
) -> Result<(), JoinError> {
while let Some(result) = handles.next().await {
if let Err(error) = result {
tracing::error!(?error, "Task panicked");
// abort all other tasks
for handle in handles.iter() {
handle.abort();
}
info!("All tasks aborted");
// Give tasks a few seconds to get to an await point
tokio::time::sleep(shutdown_delay).await;
return Err(error);
}
}
Ok(())
}

#[cfg(test)]
mod test {
use super::*;
use std::time::Instant;
use tokio::time::sleep;

#[tokio::test]
async fn test_monitor_tasks() {
let shutdown_delay = Duration::from_millis(100);

let panic_handle = tokio::spawn(async {
panic!("Task failed");
});

let return_handle = tokio::spawn(async {});

let run_time = Duration::from_millis(100);
let run_handle = tokio::spawn(async {
sleep(Duration::from_secs(1)).await;
});

let handles = FuturesUnordered::from_iter([
panic_handle,
return_handle,
run_handle,
]);

let start = Instant::now();
assert!(monitor_tasks(handles, shutdown_delay).await.is_err());

let elapsed = start.elapsed();
assert!(elapsed >= shutdown_delay);
assert!(elapsed <= shutdown_delay + run_time);
}
}
13 changes: 13 additions & 0 deletions src/tree/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::time::Duration;

use alloy::primitives::Address;
use serde::{Deserialize, Serialize};
Expand All @@ -23,6 +24,12 @@ pub struct ServiceConfig {
pub socket_address: Option<SocketAddr>,
#[serde(default)]
pub telemetry: Option<TelemetryConfig>,
/// delay beofre shutting down the server
/// after a task has panicked. This is useful
/// to give tasks a chance to reach an await point
#[serde(with = "humantime_serde")]
#[serde(default = "default::shutdown_delay")]
pub shutdown_delay: Duration,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
Expand Down Expand Up @@ -144,6 +151,10 @@ mod default {
pub const fn bool_true() -> bool {
true
}

pub const fn shutdown_delay() -> Duration {
Duration::from_secs(1)
}
}

// Utility functions to convert map to vec
Expand Down Expand Up @@ -190,6 +201,7 @@ mod tests {
const S: &str = indoc::indoc! {r#"
tree_depth = 10
socket_address = "127.0.0.1:8080"
shutdown_delay = "1s"
[db]
connection_string = "postgresql://user:password@localhost:5432/dbname"
Expand Down Expand Up @@ -260,6 +272,7 @@ mod tests {
}],
socket_address: Some(([127, 0, 0, 1], 8080).into()),
telemetry: None,
shutdown_delay: Duration::from_secs(1),
};

let serialized = toml::to_string(&config).unwrap();
Expand Down
7 changes: 4 additions & 3 deletions src/tree/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::{middleware, Json};
use axum_middleware::logging;
use futures::stream::FuturesUnordered;
use serde::{Deserialize, Serialize};
use tokio::task::JoinHandle;

Expand Down Expand Up @@ -36,7 +37,7 @@ impl InclusionProofService {
pub async fn serve(
self,
addr: Option<SocketAddr>,
) -> WorldTreeResult<(SocketAddr, Vec<JoinHandle<()>>)> {
) -> WorldTreeResult<(SocketAddr, FuturesUnordered<JoinHandle<()>>)> {
// Initialize a new router and spawn the server
tracing::info!(?addr, "Initializing axum server");

Expand Down Expand Up @@ -71,13 +72,13 @@ impl InclusionProofService {

let runner = app_task::TaskRunner::new(world_tree);

let handles = vec![
let handles = FuturesUnordered::from_iter([
runner.spawn_task("Observe", crate::tasks::observe::observe),
runner.spawn_task("Ingest", crate::tasks::ingest::ingest_canonical),
runner.spawn_task("Update", crate::tasks::update::append_updates),
runner.spawn_task("Reallign", crate::tasks::update::reallign),
server_handle,
];
]);

Ok((local_addr, handles))
}
Expand Down
3 changes: 2 additions & 1 deletion tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::time::Duration;
use alloy::primitives::{Address, U256};
use alloy::providers::Provider;
use eyre::ContextCompat;
use futures::stream::FuturesUnordered;
use rand::Rng;
use semaphore::Field;
use testcontainers::core::{ContainerPort, Mount};
Expand Down Expand Up @@ -56,7 +57,7 @@ macro_rules! attempt_async {

pub async fn setup_world_tree(
config: &ServiceConfig,
) -> WorldTreeResult<(SocketAddr, Vec<JoinHandle<()>>)> {
) -> WorldTreeResult<(SocketAddr, FuturesUnordered<JoinHandle<()>>)> {
let world_tree = init_world_tree(config).await?;

let service = InclusionProofService::new(world_tree);
Expand Down
1 change: 1 addition & 0 deletions tests/empty_start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ async fn empty_start() -> WorldTreeResult<()> {
}],
socket_address: None,
telemetry: None,
shutdown_delay: Duration::from_secs(1),
};

let (local_addr, handles) = setup_world_tree(&service_config).await?;
Expand Down
1 change: 1 addition & 0 deletions tests/full_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ async fn full_flow() -> WorldTreeResult<()> {
}],
socket_address: None,
telemetry: None,
shutdown_delay: Duration::from_secs(1),
};

let (local_addr, handles) = setup_world_tree(&service_config).await?;
Expand Down
1 change: 1 addition & 0 deletions tests/many_batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ async fn many_batches() -> WorldTreeResult<()> {
}],
socket_address: None,
telemetry: None,
shutdown_delay: Duration::from_secs(1),
};

let (local_addr, handles) = setup_world_tree(&service_config).await?;
Expand Down
1 change: 1 addition & 0 deletions tests/missed_event_on_bridged.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ async fn missing_event_on_bridged() -> WorldTreeResult<()> {
}],
socket_address: None,
telemetry: None,
shutdown_delay: Duration::from_secs(1),
};

let (local_addr, handles) = setup_world_tree(&service_config).await?;
Expand Down

0 comments on commit 41c4246

Please sign in to comment.