Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow handlers to return user-defined error types #1180

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
af226ea
[WIP] custom error responses using `HttpResponse`
hawkw Nov 13, 2024
4d7c3e4
use a new trait, but HttpResponseContent
hawkw Nov 13, 2024
5bd7e3d
hmmm maybe this is good actually
hawkw Nov 14, 2024
2eff88a
wip schema generation
hawkw Nov 18, 2024
f2c7f5f
use schemars existing deduplication
hawkw Nov 18, 2024
8da1c05
use a refined type for error status
hawkw Nov 20, 2024
86b3afb
just have `HttpError` be a normal `HttpResponseError`
hawkw Nov 20, 2024
1f611bf
just rely on `schemars` to disambiguate colliding names
hawkw Nov 20, 2024
90e7247
start documenting stuff
hawkw Nov 20, 2024
06a1af3
TRYBUILD=overwrite
hawkw Nov 20, 2024
6de6de3
docs etc
hawkw Nov 20, 2024
cc4a2b9
remove unneeded `JsonSchema` impl for `HttpError`
hawkw Nov 20, 2024
c513c46
theory of operation comment in error module
hawkw Nov 20, 2024
cfc582b
actually, we can completely insulate the user from `HandlerError`
hawkw Nov 20, 2024
3d0575c
EXPECTORATE=overwrite
hawkw Nov 20, 2024
6b4b6d4
fix wsrong doctest
hawkw Nov 20, 2024
5f374b8
Merge branch 'main' into eliza/custom-error-httpresponse-result
hawkw Nov 21, 2024
53ed323
rustfmt (NEVER use the github web merge editor)
hawkw Nov 21, 2024
ab798a9
update to track merged changes
hawkw Nov 21, 2024
e87ad82
EXPECTORATE=overwrite
hawkw Nov 21, 2024
6c9c824
Apply docs suggestions from @ahl
hawkw Nov 21, 2024
10a4a99
remove local envrc
hawkw Nov 21, 2024
b9f194c
update copyright dates
hawkw Nov 21, 2024
576ba5f
reticulating comments
hawkw Nov 21, 2024
8a4d52f
reticulating comments
hawkw Nov 21, 2024
f9642d1
nicer error for missing `HttpResponse` impls
hawkw Nov 21, 2024
8f6d70e
fix trait-based stub API not knowing about error schemas
hawkw Nov 21, 2024
ccbbbe2
EXPECTORATE=overwrite
hawkw Nov 21, 2024
46b4df1
whoops i forgot to add changes to endpoint tests
hawkw Nov 22, 2024
00bcea7
convert `HttpError`s into endpoint's error type
hawkw Nov 22, 2024
a6c3472
add a note about `HttpError`
hawkw Nov 22, 2024
4c93e2e
reticulating implementation comments
hawkw Nov 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions dropshot/examples/custom-error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright 2024 Oxide Computer Company

//! An example demonstrating how to return user-defined error types from
//! endpoint handlers.

use dropshot::endpoint;
use dropshot::ApiDescription;
use dropshot::ConfigLogging;
use dropshot::ConfigLoggingLevel;
use dropshot::ErrorStatusCode;
use dropshot::HttpError;
use dropshot::HttpResponseError;
use dropshot::HttpResponseOk;
use dropshot::Path;
use dropshot::RequestContext;
use dropshot::ServerBuilder;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;

#[derive(Debug, thiserror::Error, Serialize, JsonSchema)]
enum ThingyError {
#[allow(dead_code)]
#[error("no thingies are currently available")]
NoThingies,
#[error("invalid thingy: {:?}", .name)]
InvalidThingy { name: String },
#[error("{message}")]
Other {
message: String,
#[serde(skip)]
internal_message: String,
#[serde(skip)]
status: ErrorStatusCode,
error_code: Option<String>,
},
}

/// Any type implementing `dropshot::HttpResponseError` and
/// `HttpResponseContent` may be used as an error type for a
/// return value from an endpoint handler.
impl HttpResponseError for ThingyError {
// Note that this method returns a `dropshot::ErrorStatusCode`, rather than
// an `http::StatusCode`. This type is a refinement of `http::StatusCode`
// that can only be constructed from status codes in 4xx (client error) or
// 5xx (server error) ranges.
fn status_code(&self) -> dropshot::ErrorStatusCode {
match self {
ThingyError::NoThingies => {
// The `dropshot::ErrorStatusCode` type provides constants for
// all well-known 4xx and 5xx status codes, such as 503 Service
// Unavailable.
dropshot::ErrorStatusCode::SERVICE_UNAVAILABLE
}
ThingyError::InvalidThingy { .. } => {
// Alternatively, an `ErrorStatusCode` can be constructed from a
// u16, but the `ErrorStatusCode::from_u16` constructor
// validates that the status code is a 4xx or 5xx.
//
// This allows using extended status codes, while still
// validating that they are errors.
dropshot::ErrorStatusCode::from_u16(442)
.expect("442 is a 4xx status code")
}
ThingyError::Other { status, .. } => *status,
}
}
}

impl From<HttpError> for ThingyError {
fn from(error: HttpError) -> Self {
ThingyError::Other {
message: error.external_message,
internal_message: error.internal_message,
status: error.status_code,
error_code: error.error_code,
}
}
}

/// Just some kind of thingy returned by the API. This doesn't actually matter.
#[derive(Deserialize, Serialize, JsonSchema)]
struct Thingy {
magic_number: u64,
}

#[derive(Deserialize, JsonSchema)]
struct ThingyPathParams {
name: String,
}

/// Fetch the thingy with the provided name.
#[endpoint {
method = GET,
path = "/thingy/{name}",
}]
async fn get_thingy(
_rqctx: RequestContext<()>,
path_params: Path<ThingyPathParams>,
) -> Result<HttpResponseOk<Thingy>, ThingyError> {
let ThingyPathParams { name } = path_params.into_inner();
Err(ThingyError::InvalidThingy { name })
}

#[endpoint {
method = GET,
path = "/nothing",
}]
async fn get_nothing(
_rqctx: RequestContext<()>,
) -> Result<HttpResponseOk<Thingy>, ThingyError> {
Err(ThingyError::NoThingies)
}

/// An example of an endpoint which returns a `Result<_, HttpError>`.
#[endpoint {
method = GET,
path = "/something",
}]
async fn get_something(
_rqctx: RequestContext<()>,
) -> Result<HttpResponseOk<Thingy>, dropshot::HttpError> {
Ok(HttpResponseOk(Thingy { magic_number: 42 }))
}

#[tokio::main]
async fn main() -> Result<(), String> {
// See dropshot/examples/basic.rs for more details on most of these pieces.
let config_logging =
ConfigLogging::StderrTerminal { level: ConfigLoggingLevel::Info };
let log = config_logging
.to_logger("example-custom-error")
.map_err(|error| format!("failed to create logger: {}", error))?;

let mut api = ApiDescription::new();
api.register(get_thingy).unwrap();
api.register(get_nothing).unwrap();
api.register(get_something).unwrap();

api.openapi("Custom Error Example", semver::Version::new(0, 0, 0))
.write(&mut std::io::stdout())
.map_err(|e| e.to_string())?;

let server = ServerBuilder::new(api, (), log)
.start()
.map_err(|error| format!("failed to create server: {}", error))?;

server.await
}
111 changes: 69 additions & 42 deletions dropshot/src/api_description.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright 2023 Oxide Computer Company
// Copyright 2024 Oxide Computer Company
//! Describes the endpoints and handler functions in your API

use crate::extractor::RequestExtractor;
use crate::handler::HttpHandlerFunc;
use crate::handler::HttpResponse;
use crate::handler::HttpResponseContent;
use crate::handler::HttpResponseError;
use crate::handler::HttpRouteHandler;
use crate::handler::RouteHandler;
use crate::handler::StubRouteHandler;
Expand All @@ -14,8 +16,6 @@ use crate::schema_util::j2oas_schema;
use crate::server::ServerContext;
use crate::type_util::type_is_scalar;
use crate::type_util::type_is_string_enum;
use crate::HttpError;
use crate::HttpErrorResponseBody;
use crate::CONTENT_TYPE_JSON;
use crate::CONTENT_TYPE_MULTIPART_FORM_DATA;
use crate::CONTENT_TYPE_OCTET_STREAM;
Expand Down Expand Up @@ -53,6 +53,7 @@ pub struct ApiEndpoint<Context: ServerContext> {
pub parameters: Vec<ApiEndpointParameter>,
pub body_content_type: ApiEndpointBodyContentType,
pub response: ApiEndpointResponse,
pub error: ApiEndpointErrorResponse,
pub summary: Option<String>,
pub description: Option<String>,
pub tags: Vec<String>,
Expand Down Expand Up @@ -81,6 +82,9 @@ impl<'a, Context: ServerContext> ApiEndpoint<Context> {
.expect("unsupported mime type");
let func_parameters = FuncParams::metadata(body_content_type.clone());
let response = ResponseType::response_metadata();
let error = ApiEndpointErrorResponse {
schema: <HandlerType::Error>::content_metadata(),
};
ApiEndpoint {
operation_id,
handler: HttpRouteHandler::new(handler),
Expand All @@ -89,6 +93,7 @@ impl<'a, Context: ServerContext> ApiEndpoint<Context> {
parameters: func_parameters.parameters,
body_content_type,
response,
error,
summary: None,
description: None,
tags: vec![],
Expand Down Expand Up @@ -179,7 +184,10 @@ impl<'a> ApiEndpoint<StubContext> {
ApiEndpointBodyContentType::from_mime_type(content_type)
.expect("unsupported mime type");
let func_parameters = FuncParams::metadata(body_content_type.clone());
let response = ResultType::Response::response_metadata();
let response = <ResultType::Response>::response_metadata();
let error = ApiEndpointErrorResponse {
schema: <ResultType::Error>::content_metadata(),
};
let handler = StubRouteHandler::new_with_name(&operation_id);
ApiEndpoint {
operation_id,
Expand All @@ -189,6 +197,7 @@ impl<'a> ApiEndpoint<StubContext> {
parameters: func_parameters.parameters,
body_content_type,
response,
error,
summary: None,
description: None,
tags: vec![],
Expand All @@ -202,13 +211,16 @@ impl<'a> ApiEndpoint<StubContext> {

pub trait HttpResultType {
type Response: HttpResponse + Send + Sync + 'static;
type Error: HttpResponseError + Send + Sync + 'static;
}

impl<T> HttpResultType for Result<T, HttpError>
impl<T, E> HttpResultType for Result<T, E>
where
T: HttpResponse + Send + Sync + 'static,
E: HttpResponseError + Send + Sync + 'static,
{
type Response = T;
type Error = E;
}

/// ApiEndpointParameter represents the discrete path and query parameters for a
Expand Down Expand Up @@ -333,6 +345,12 @@ pub struct ApiEndpointResponse {
pub description: Option<String>,
}

/// Metadata for an API endpoint's error response type.
#[derive(Debug, Default)]
pub struct ApiEndpointErrorResponse {
pub(crate) schema: Option<ApiSchemaGenerator>,
}

/// Wrapper for both dynamically generated and pre-generated schemas.
pub enum ApiSchemaGenerator {
Gen {
Expand Down Expand Up @@ -919,27 +937,59 @@ impl<Context: ServerContext> ApiDescription<Context> {
}
};

// If the endpoint defines an error type, emit that for
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we could have added to components.responses as before and then referenced that. I can see the inline approach you've taken as potentially simpler, though it does bloat up the json output...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, I'd like to put them in components.responses, too. The reason I didn't is that it might be a bit annoying to determine the name for each response schema. schemars internally disambiguates colliding schema names by turning subsequent ones into like Error2 or whatever, but (AFAICT) we only get that when we actually generate the schema and it gives us back a reference (into components.schemas). We could then try to parse that reference and get the name back out to then use it to generate a components.responses entry for that response, which seems possible, I just thought it seemed annoying enough that I didn't really want to bother with it. Do you think it's worth doing?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we name the response based on <T as JsonSchema>::schema_name()? Might that work?

Do I think it's worth doing? I think it's worth trying. It might make the code worse, but it might make the output simpler. At a minimum it will make the diffs against current json simpler. These together--I think--at least warrant giving it a shot.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe that deduplication is applied to JsonSchema::schema_name(); as
far as I can tell, it only happens once a schema has already been generated,
because that's when the generator can check if the name already exists in the
set of schemas that have been generated so far?

// the 4xx and 5xx responses.
if let Some(ref schema) = endpoint.error.schema {
let error_schema = match schema {
ApiSchemaGenerator::Gen { ref name, ref schema } => {
j2oas_schema(Some(&name()), &schema(&mut generator))
}
ApiSchemaGenerator::Static {
ref schema,
ref dependencies,
} => {
definitions.extend(dependencies.clone());
j2oas_schema(None, &schema)
}
};
let mut content = indexmap::IndexMap::new();
content.insert(
CONTENT_TYPE_JSON.to_string(),
openapiv3::MediaType {
schema: Some(error_schema),
..Default::default()
},
);
operation.responses.responses.insert(
openapiv3::StatusCode::Range(4),
openapiv3::ReferenceOr::Item(openapiv3::Response {
description: "client error".to_string(),
content: content.clone(),
..Default::default()
}),
);
operation.responses.responses.insert(
openapiv3::StatusCode::Range(5),
openapiv3::ReferenceOr::Item(openapiv3::Response {
description: "server error".to_string(),
content: content.clone(),
..Default::default()
}),
);
}

if let Some(code) = &endpoint.response.success {
// `Ok` response has a known status code. In this case,
// emit it as the response for that status code only.
operation.responses.responses.insert(
openapiv3::StatusCode::Code(code.as_u16()),
openapiv3::ReferenceOr::Item(response),
);

// 4xx and 5xx responses all use the same error information
let err_ref = openapiv3::ReferenceOr::ref_(
"#/components/responses/Error",
);
operation
.responses
.responses
.insert(openapiv3::StatusCode::Range(4), err_ref.clone());
operation
.responses
.responses
.insert(openapiv3::StatusCode::Range(5), err_ref);
} else {
// The `Ok` response could be any status code, so emit it as
// the default response.
operation.responses.default =
Some(openapiv3::ReferenceOr::Item(response))
Some(openapiv3::ReferenceOr::Item(response));
}

// Drop in the operation.
Expand All @@ -950,29 +1000,6 @@ impl<Context: ServerContext> ApiDescription<Context> {
.components
.get_or_insert_with(openapiv3::Components::default);

// All endpoints share an error response
let responses = &mut components.responses;
let mut content = indexmap::IndexMap::new();
content.insert(
CONTENT_TYPE_JSON.to_string(),
openapiv3::MediaType {
schema: Some(j2oas_schema(
None,
&generator.subschema_for::<HttpErrorResponseBody>(),
)),
..Default::default()
},
);

responses.insert(
"Error".to_string(),
openapiv3::ReferenceOr::Item(openapiv3::Response {
description: "Error".to_string(),
content,
..Default::default()
}),
);

// Add the schemas for which we generated references.
let schemas = &mut components.schemas;

Expand Down
Loading