diff --git a/Cargo.toml b/Cargo.toml index aa6020e..3fe03a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,9 +23,9 @@ thiserror = "1.0.50" tracing = "0.1.40" [dev-dependencies] -axum = { version = "0.6.20", default-features = false, features = ["tokio"] } +axum = { version = "0.7.5", default-features = false, features = ["tokio", "http1"] } tokio = { version = "1.33.0", default-features = false, features = ["macros", "test-util"] } -tower-http = { version = "0.4.4", default-features = false, features = ["fs"] } +tower-http = { version = "0.5.2", default-features = false, features = ["fs"] } async_zip = { version = "0.0.15", default-features = false, features = ["tokio"] } assert_matches = "1.5.0" rstest = { version = "0.18.2" } diff --git a/src/lib.rs b/src/lib.rs index 2b99aa0..786f454 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -669,7 +669,9 @@ mod test { async fn async_range_reader_zip(#[case] check_method: CheckSupportMethod) { // Spawn a static file server let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data"); - let server = StaticDirectoryServer::new(&path); + let server = StaticDirectoryServer::new(&path) + .await + .expect("could not initialize server"); // check that file is there and has the right size let filepath = path.join("andes-1.8.3-pyhd8ed1ab_0.conda"); @@ -776,7 +778,9 @@ mod test { async fn async_range_reader(#[case] check_method: CheckSupportMethod) { // Spawn a static file server let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data"); - let server = StaticDirectoryServer::new(&path); + let server = StaticDirectoryServer::new(&path) + .await + .expect("could not initialize server"); // Construct an AsyncRangeReader let (mut range, _) = AsyncHttpRangeReader::new( @@ -820,7 +824,9 @@ mod test { #[tokio::test] async fn test_not_found() { - let server = StaticDirectoryServer::new(Path::new(env!("CARGO_MANIFEST_DIR"))); + let server = StaticDirectoryServer::new(Path::new(env!("CARGO_MANIFEST_DIR"))) + .await + .expect("could not initialize server"); let err = AsyncHttpRangeReader::new( Client::new(), server.url().join("not-found").unwrap(), diff --git a/src/static_directory_server.rs b/src/static_directory_server.rs index 7d082f7..4cb2b4d 100644 --- a/src/static_directory_server.rs +++ b/src/static_directory_server.rs @@ -1,4 +1,5 @@ use axum::routing::get_service; +use axum::ServiceExt; use reqwest::Url; use std::net::SocketAddr; use std::path::Path; @@ -21,7 +22,7 @@ impl StaticDirectoryServer { } impl StaticDirectoryServer { - pub fn new(path: impl AsRef) -> Self { + pub async fn new(path: impl AsRef) -> Result { let service = get_service(ServeDir::new(path)); // Create a router that will serve the static files @@ -31,25 +32,27 @@ impl StaticDirectoryServer { // port is very important because it enables creating multiple instances at the same time. // We need this to be able to run tests in parallel. let addr = SocketAddr::new([127, 0, 0, 1].into(), 0); - let server = axum::Server::bind(&addr).serve(app.into_make_service()); + let listener = tokio::net::TcpListener::bind(addr).await?; // Get the address of the server so we can bind to it at a later stage. - let addr = server.local_addr(); + let addr = listener.local_addr()?; // Setup a graceful shutdown trigger which is fired when this instance is dropped. let (tx, rx) = oneshot::channel(); - let server = server.with_graceful_shutdown(async { - rx.await.ok(); - }); - // Spawn the server. Let go of the JoinHandle, we can use the graceful shutdown trigger to - // stop the server. - tokio::spawn(server); + // Spawn the server in the background. + tokio::spawn(async move { + let _ = axum::serve(listener, app.into_make_service()) + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .await; + }); - Self { + Ok(Self { local_addr: addr, shutdown_sender: Some(tx), - } + }) } } @@ -60,3 +63,9 @@ impl Drop for StaticDirectoryServer { } } } +/// Error type used for [`StaticDirectoryServerError`] +#[derive(Debug, thiserror::Error)] +pub enum StaticDirectoryServerError { + #[error(transparent)] + Io(#[from] std::io::Error), +}