Skip to content

Commit

Permalink
Merge pull request #81 from stackhpc/resource-management
Browse files Browse the repository at this point in the history
Add support for resource management
  • Loading branch information
markgoddard authored Sep 22, 2023
2 parents 138c501 + fcb8b86 commit b832823
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 11 deletions.
6 changes: 4 additions & 2 deletions benches/s3_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use aws_sdk_s3::Client;
use aws_types::region::Region;
use axum::body::Bytes;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use reductionist::resource_manager::ResourceManager;
use reductionist::s3_client::{S3Client, S3ClientMap};
use url::Url;
// Bring trait into scope to use as_bytes method.
Expand Down Expand Up @@ -42,6 +43,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let bucket = "s3-client-bench";
let runtime = tokio::runtime::Runtime::new().unwrap();
let map = S3ClientMap::new();
let resource_manager = ResourceManager::new(None, None, None);
for size_k in [64, 256, 1024] {
let size: isize = size_k * 1024;
let data: Vec<u32> = (0_u32..(size as u32)).collect::<Vec<u32>>();
Expand All @@ -53,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) {
b.to_async(&runtime).iter(|| async {
let client = S3Client::new(&url, username, password).await;
client
.download_object(black_box(bucket), &key, None)
.download_object(black_box(bucket), &key, None, &resource_manager, &mut None)
.await
.unwrap();
})
Expand All @@ -63,7 +65,7 @@ fn criterion_benchmark(c: &mut Criterion) {
b.to_async(&runtime).iter(|| async {
let client = map.get(&url, username, password).await;
client
.download_object(black_box(bucket), &key, None)
.download_object(black_box(bucket), &key, None, &resource_manager, &mut None)
.await
.unwrap();
})
Expand Down
41 changes: 35 additions & 6 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::metrics::{metrics_handler, track_metrics};
use crate::models;
use crate::operation;
use crate::operations;
use crate::resource_manager::ResourceManager;
use crate::s3_client;
use crate::types::{ByteOrder, NATIVE_BYTE_ORDER};
use crate::validated_json::ValidatedJson;
Expand All @@ -25,6 +26,7 @@ use axum::{
};

use std::sync::Arc;
use tokio::sync::SemaphorePermit;
use tower::Layer;
use tower::ServiceBuilder;
use tower_http::normalize_path::NormalizePathLayer;
Expand Down Expand Up @@ -54,14 +56,21 @@ struct AppState {

/// Map of S3 client objects.
s3_client_map: s3_client::S3ClientMap,

/// Resource manager.
resource_manager: ResourceManager,
}

impl AppState {
    /// Create and return an [AppState].
    ///
    /// The resource manager is built from the S3 connection, memory and thread limits in `args`.
    /// When no thread limit is given, the task limit defaults to one less than the number of
    /// CPUs, but never less than one.
    fn new(args: &CommandLineArgs) -> Self {
        // A default of zero (single-CPU host) would create a task semaphore with no permits,
        // and every request waiting on a task permit would then wait forever.
        let task_limit = args
            .thread_limit
            .or_else(|| Some(num_cpus::get().saturating_sub(1).max(1)));
        let resource_manager =
            ResourceManager::new(args.s3_connection_limit, args.memory_limit, task_limit);
        Self {
            args: args.clone(),
            s3_client_map: s3_client::S3ClientMap::new(),
            resource_manager,
        }
    }
}
Expand Down Expand Up @@ -176,14 +185,26 @@ async fn schema() -> &'static str {
///
/// * `auth`: Basic authentication credentials
/// * `request_data`: RequestData object for the request
#[tracing::instrument(level = "DEBUG", skip(client, request_data))]
async fn download_object(
#[tracing::instrument(
level = "DEBUG",
skip(client, request_data, resource_manager, mem_permits)
)]
async fn download_object<'a>(
client: &s3_client::S3Client,
request_data: &models::RequestData,
resource_manager: &'a ResourceManager,
mem_permits: &mut Option<SemaphorePermit<'a>>,
) -> Result<Bytes, ActiveStorageError> {
let range = s3_client::get_range(request_data.offset, request_data.size);
let _conn_permits = resource_manager.s3_connection().await?;
client
.download_object(&request_data.bucket, &request_data.object, range)
.download_object(
&request_data.bucket,
&request_data.object,
range,
resource_manager,
mem_permits,
)
.await
}

Expand All @@ -206,19 +227,27 @@ async fn operation_handler<T: operation::Operation>(
TypedHeader(auth): TypedHeader<Authorization<Basic>>,
ValidatedJson(request_data): ValidatedJson<models::RequestData>,
) -> Result<models::Response, ActiveStorageError> {
let memory = request_data.size.unwrap_or(0);
let mut _mem_permits = state.resource_manager.memory(memory).await?;
let s3_client = state
.s3_client_map
.get(&request_data.source, auth.username(), auth.password())
.instrument(tracing::Span::current())
.await;
let data = download_object(&s3_client, &request_data)
.instrument(tracing::Span::current())
.await?;
let data = download_object(
&s3_client,
&request_data,
&state.resource_manager,
&mut _mem_permits,
)
.instrument(tracing::Span::current())
.await?;
// All remaining work is synchronous. If the use_rayon argument was specified, delegate to the
// Rayon thread pool. Otherwise, execute as normal using Tokio.
if state.args.use_rayon {
tokio_rayon::spawn(move || operation::<T>(request_data, data)).await
} else {
let _task_permit = state.resource_manager.task().await?;
operation::<T>(request_data, data)
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ pub struct CommandLineArgs {
/// Whether to use Rayon for execution of CPU-bound tasks.
#[arg(long, default_value_t = false, env = "REDUCTIONIST_USE_RAYON")]
pub use_rayon: bool,
/// Memory limit in bytes. Default is no limit.
#[arg(long, env = "REDUCTIONIST_MEMORY_LIMIT")]
pub memory_limit: Option<usize>,
/// S3 connection limit. Default is no limit.
#[arg(long, env = "REDUCTIONIST_S3_CONNECTION_LIMIT")]
pub s3_connection_limit: Option<usize>,
/// Thread limit for CPU-bound tasks. Default is one less than the number of CPUs. Used only
/// when use_rayon is false.
#[arg(long, env = "REDUCTIONIST_THREAD_LIMIT")]
pub thread_limit: Option<usize>,
}

/// Returns parsed command line arguments.
Expand Down
39 changes: 38 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use ndarray::ShapeError;
use serde::{Deserialize, Serialize};
use std::error::Error;
use thiserror::Error;
use tokio::sync::AcquireError;
use tracing::{event, Level};
use zune_inflate::errors::InflateDecodeErrors;

Expand Down Expand Up @@ -41,9 +42,14 @@ pub enum ActiveStorageError {
#[error("failed to convert from bytes to {type_name}")]
FromBytes { type_name: &'static str },

/// Incompatible missing data descriptor
#[error("Incompatible value {0} for missing")]
IncompatibleMissing(DValue),

/// Insufficient memory to process request
#[error("Insufficient memory to process request ({requested} > {total})")]
InsufficientMemory { requested: usize, total: usize },

/// Error deserialising request data into RequestData
#[error("request data is not valid")]
RequestDataJsonRejection(#[from] JsonRejection),
Expand All @@ -64,6 +70,10 @@ pub enum ActiveStorageError {
#[error("error retrieving object from S3 storage")]
S3GetObject(#[from] SdkError<GetObjectError>),

/// Error acquiring a semaphore
#[error("error acquiring resources")]
SemaphoreAcquireError(#[from] AcquireError),

/// Error creating ndarray ArrayView from Shape
#[error("failed to create array from shape")]
ShapeInvalid(#[from] ShapeError),
Expand Down Expand Up @@ -196,6 +206,10 @@ impl From<ActiveStorageError> for ErrorResponse {
| ActiveStorageError::DecompressionZune(_)
| ActiveStorageError::EmptyArray { operation: _ }
| ActiveStorageError::IncompatibleMissing(_)
| ActiveStorageError::InsufficientMemory {
requested: _,
total: _,
}
| ActiveStorageError::RequestDataJsonRejection(_)
| ActiveStorageError::RequestDataValidationSingle(_)
| ActiveStorageError::RequestDataValidation(_)
Expand All @@ -207,7 +221,8 @@ impl From<ActiveStorageError> for ErrorResponse {
// Internal server error
ActiveStorageError::FromBytes { type_name: _ }
| ActiveStorageError::TryFromInt(_)
| ActiveStorageError::S3ByteStream(_) => Self::internal_server_error(&error),
| ActiveStorageError::S3ByteStream(_)
| ActiveStorageError::SemaphoreAcquireError(_) => Self::internal_server_error(&error),

ActiveStorageError::S3GetObject(sdk_error) => {
// Tailor the response based on the specific SdkError variant.
Expand Down Expand Up @@ -377,6 +392,17 @@ mod tests {
test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
}

#[tokio::test]
async fn insufficient_memory() {
    // An over-sized memory request is reported as a client error carrying both figures.
    test_active_storage_error(
        ActiveStorageError::InsufficientMemory {
            requested: 2,
            total: 1,
        },
        StatusCode::BAD_REQUEST,
        "Insufficient memory to process request (2 > 1)",
        None,
    )
    .await;
}

#[tokio::test]
async fn request_data_validation_single() {
let validation_error = validator::ValidationError::new("foo");
Expand Down Expand Up @@ -504,6 +530,17 @@ mod tests {
.await;
}

#[tokio::test]
async fn semaphore_acquire_error() {
    // Acquiring from a closed semaphore fails, which gives us an AcquireError to wrap.
    let closed = tokio::sync::Semaphore::new(1);
    closed.close();
    let acquire_error = closed.acquire().await.unwrap_err();
    test_active_storage_error(
        ActiveStorageError::SemaphoreAcquireError(acquire_error),
        StatusCode::INTERNAL_SERVER_ERROR,
        "error acquiring resources",
        Some(vec!["semaphore closed"]),
    )
    .await;
}

#[tokio::test]
async fn shape_error() {
let error = ActiveStorageError::ShapeInvalid(ShapeError::from_kind(
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ pub mod metrics;
pub mod models;
pub mod operation;
pub mod operations;
pub mod resource_manager;
pub mod s3_client;
pub mod server;
#[cfg(test)]
Expand Down
127 changes: 127 additions & 0 deletions src/resource_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
//! Resource management
use crate::error::ActiveStorageError;

use tokio::sync::{Semaphore, SemaphorePermit};

/// [crate::resource_manager::ResourceManager] hands out limited resources to tasks.
///
/// Each kind of resource is guarded by its own Tokio Semaphore. A field holding `None` means
/// that no limit was configured for that resource.
pub struct ResourceManager {
    /// Semaphore limiting S3 connections, if a limit was configured.
    s3_connections: Option<Semaphore>,

    /// Semaphore limiting memory, with one permit per byte, if a limit was configured.
    memory: Option<Semaphore>,

    /// Size of the whole memory pool in bytes, if a limit was configured.
    total_memory: Option<usize>,

    /// Semaphore limiting tasks, if a limit was configured.
    tasks: Option<Semaphore>,
}

impl ResourceManager {
    /// Build a ResourceManager from optional per-resource limits.
    ///
    /// # Arguments
    ///
    /// * `s3_connection_limit`: Optional number of S3 connection permits
    /// * `memory_limit`: Optional size of the memory pool in bytes
    /// * `task_limit`: Optional number of task permits
    pub fn new(
        s3_connection_limit: Option<usize>,
        memory_limit: Option<usize>,
        task_limit: Option<usize>,
    ) -> Self {
        let s3_connections = s3_connection_limit.map(Semaphore::new);
        let memory = memory_limit.map(Semaphore::new);
        let tasks = task_limit.map(Semaphore::new);
        Self {
            s3_connections,
            memory,
            // Kept alongside the semaphore so over-sized requests can be rejected up front.
            total_memory: memory_limit,
            tasks,
        }
    }

    /// Acquire a single S3 connection permit, or `None` if connections are not limited.
    pub async fn s3_connection(&self) -> Result<Option<SemaphorePermit>, ActiveStorageError> {
        optional_acquire(&self.s3_connections, 1).await
    }

    /// Acquire `bytes` permits from the memory pool, or `None` if memory is not limited.
    ///
    /// A request larger than the whole pool is rejected with
    /// [ActiveStorageError::InsufficientMemory] before any permits are requested.
    pub async fn memory(
        &self,
        bytes: usize,
    ) -> Result<Option<SemaphorePermit>, ActiveStorageError> {
        match self.total_memory {
            Some(total) if bytes > total => Err(ActiveStorageError::InsufficientMemory {
                requested: bytes,
                total,
            }),
            _ => optional_acquire(&self.memory, bytes).await,
        }
    }

    /// Acquire a single task permit, or `None` if tasks are not limited.
    pub async fn task(&self) -> Result<Option<SemaphorePermit>, ActiveStorageError> {
        optional_acquire(&self.tasks, 1).await
    }
}

/// Acquire permits on an optional Semaphore, if present.
///
/// # Arguments
///
/// * `sem`: Optional semaphore to acquire permits from
/// * `n`: Number of permits to acquire
///
/// Returns `Ok(None)` when there is no semaphore, otherwise the acquired permit.
///
/// # Errors
///
/// Returns an error if a semaphore is present and `n` does not fit in a `u32`, or if acquiring
/// from the semaphore fails.
async fn optional_acquire(
    sem: &Option<Semaphore>,
    n: usize,
) -> Result<Option<SemaphorePermit>, ActiveStorageError> {
    match sem {
        Some(sem) => {
            // Convert only when permits are actually needed, so that an unlimited resource
            // never rejects a request merely because `n` exceeds `u32::MAX`.
            let permits = u32::try_from(n)?;
            Ok(Some(sem.acquire_many(permits).await?))
        }
        None => Ok(None),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use tokio::sync::TryAcquireError;

    /// With no limits configured there are no semaphores and no permits are handed out.
    #[tokio::test]
    async fn no_resource_management() {
        let manager = ResourceManager::new(None, None, None);
        assert!(manager.s3_connections.is_none());
        assert!(manager.memory.is_none());
        assert!(manager.tasks.is_none());
        let conn_permit = manager.s3_connection().await.unwrap();
        let mem_permit = manager.memory(1).await.unwrap();
        let task_permit = manager.task().await.unwrap();
        assert!(conn_permit.is_none());
        assert!(mem_permit.is_none());
        assert!(task_permit.is_none());
    }

    /// With a limit of one on every resource, a single acquisition exhausts each semaphore.
    #[tokio::test]
    async fn full_resource_management() {
        let manager = ResourceManager::new(Some(1), Some(1), Some(1));
        assert!(manager.s3_connections.is_some());
        assert!(manager.memory.is_some());
        assert!(manager.tasks.is_some());
        // The permits stay bound until the end of the test so the semaphores remain exhausted.
        let conn_permit = manager.s3_connection().await.unwrap();
        let mem_permit = manager.memory(1).await.unwrap();
        let task_permit = manager.task().await.unwrap();
        assert!(conn_permit.is_some());
        assert!(mem_permit.is_some());
        assert!(task_permit.is_some());
        // Check that there are no more resources (without blocking).
        for exhausted in [&manager.s3_connections, &manager.memory, &manager.tasks] {
            assert_eq!(
                exhausted.as_ref().unwrap().try_acquire().err(),
                Some(TryAcquireError::NoPermits)
            );
        }
    }
}
Loading

0 comments on commit b832823

Please sign in to comment.