diff --git a/Cargo.lock b/Cargo.lock index 0347172..8c3bc63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -661,6 +661,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "dyn-clone" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b0cf012f1230e43cd00ebb729c6bb58707ecfa8ad08b52ef3a4ccd2697fc30" + [[package]] name = "either" version = "1.8.1" @@ -1578,7 +1584,7 @@ checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "s3-active-storage" -version = "0.1.0" +version = "0.1.1" dependencies = [ "async-trait", "aws-credential-types", @@ -1598,6 +1604,7 @@ dependencies = [ "ndarray-stats", "num-traits", "regex", + "schemars", "serde", "serde_json", "serde_test", @@ -1624,6 +1631,32 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "schemars" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02c613288622e5f0c3fdc5dbd4db1c5fbe752746b1d1a56a0630b78fd00de44f" +dependencies = [ + "bytes", + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", + "url", +] + +[[package]] +name = "schemars_derive" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109da1e6b197438deb6db99952990c7f959572794b80ff93707d55a232545e7c" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 1.0.109", +] + [[package]] name = "scopeguard" version = "1.1.0" @@ -1689,6 +1722,17 @@ dependencies = [ "syn 2.0.15", ] +[[package]] +name = "serde_derive_internals" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85bf8229e7920a9f636479437026331ce11aa132b4dde37d121944a44d6e5f3c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "serde_json" version = "1.0.96" diff --git a/Cargo.toml b/Cargo.toml index 7af401b..c9ee49a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ ndarray-stats = "0.5" num-traits = "0.2.15" serde = { version = "1.0", features = ["derive"] } serde_json = "*" +schemars = { version = "0.8.11", features = ["url", "bytes"] } strum_macros = "0.24" thiserror = "1.0" tokio = { version = "1.28", features = ["full"] } diff --git a/src/app.rs b/src/app.rs index e4f9d14..fdf0f02 100644 --- a/src/app.rs +++ b/src/app.rs @@ -19,6 +19,7 @@ use axum::{ Router, TypedHeader, }; +use schemars::schema_for; use tower::Layer; use tower::ServiceBuilder; use tower_http::normalize_path::NormalizePathLayer; @@ -82,7 +83,14 @@ fn router() -> Router { } Router::new() - .route("/.well-known/s3-active-storage-schema", get(schema)) + .route( + "/.well-known/s3-active-storage-request-schema", + get(request_schema), + ) + .route( + "/.well-known/s3-active-storage-response-schema", + get(response_schema), + ) .nest("/v1", v1()) } @@ -110,8 +118,19 @@ pub fn service() -> Service { } /// TODO: Return an OpenAPI schema -async fn schema() -> &'static str { - "Hello, world!" +async fn request_schema() -> Result { + let result = serde_json::to_string_pretty(&schema_for!(models::RequestData)); + match result { + Ok(json_schema) => Ok(json_schema), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + } +} +async fn response_schema() -> Result { + let result = serde_json::to_string_pretty(&schema_for!(models::Response)); + match result { + Ok(json_schema) => Ok(json_schema), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + } } /// Download an object from S3 diff --git a/src/models.rs b/src/models.rs index 5dc2c86..7ab7c1a 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,13 +1,14 @@ //! Data types and associated functions and methods use axum::body::Bytes; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use strum_macros::Display; use url::Url; use validator::{Validate, ValidationError}; /// Supported numerical data types -#[derive(Clone, Copy, Debug, Deserialize, Display, PartialEq)] +#[derive(Clone, Copy, Debug, Deserialize, Display, PartialEq, JsonSchema)] #[serde(rename_all = "lowercase")] pub enum DType { /// [i32] @@ -41,7 +42,7 @@ impl DType { /// Array ordering /// /// Defines an ordering for multi-dimensional arrays. -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Debug, Deserialize, PartialEq, JsonSchema)] pub enum Order { /// Row-major (C) ordering C, @@ -51,7 +52,7 @@ pub enum Order { /// A slice of a single dimension of an array /// -/// The API uses NumPy slice semantics: +/// The API uses NumPy slice (i.e. [start, end, stride]) semantics where: /// /// When start or end is negative: /// * positive_start = start + length @@ -65,7 +66,7 @@ pub enum Order { /// * positive_end <= i < positive_start // NOTE: In serde, structs can be deserialised from sequences or maps. This allows us to support // the [, , ] API, with the convenience of named fields. -#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Serialize, Validate)] +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Serialize, Validate, JsonSchema)] #[serde(deny_unknown_fields)] #[validate(schema(function = "validate_slice"))] pub struct Slice { @@ -86,7 +87,7 @@ impl Slice { } /// Request data for operations -#[derive(Debug, Deserialize, PartialEq, Validate)] +#[derive(Debug, Deserialize, PartialEq, Validate, JsonSchema)] #[serde(deny_unknown_fields)] #[validate(schema(function = "validate_request_data"))] pub struct RequestData { @@ -179,8 +180,9 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr } /// Response containing the result of a computation and associated metadata. +#[derive(JsonSchema)] pub struct Response { - /// Response data. May be a scalar or multi-dimensional array. + /// Raw response data as bytes. May represent a scalar or multi-dimensional array. pub body: Bytes, /// Data type of the response pub dtype: DType,