diff --git a/CHANGELOG.md b/CHANGELOG.md index a2a1a71..1d11a6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## What's Changed in replic-v0.1.3 +* chore(replicate): improve error handling + +**Full Changelog**: https://github.com///compare/replic-v0.1.2...replic-v0.1.3 + ## What's Changed in 0.2.0 * docs: update * chore(mesh): bump anthropic version diff --git a/Cargo.lock b/Cargo.lock index a2b48ac..22b3ce7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -711,7 +711,7 @@ dependencies = [ [[package]] name = "replic" -version = "0.1.2" +version = "0.1.3" dependencies = [ "futures-util", "pretty_assertions", diff --git a/replicate/Cargo.toml b/replicate/Cargo.toml index ed1794c..67e4682 100644 --- a/replicate/Cargo.toml +++ b/replicate/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "replic" -version = "0.1.2" +version = "0.1.3" edition = "2021" authors = ["Roushou "] description = "Replicate Rust SDK" diff --git a/replicate/src/client.rs b/replicate/src/client.rs index c256023..0dd0348 100644 --- a/replicate/src/client.rs +++ b/replicate/src/client.rs @@ -1,6 +1,6 @@ use reqwest::{ header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}, - Method, RequestBuilder, Url, + Method, RequestBuilder, Response, StatusCode, Url, }; use serde::{Deserialize, Serialize}; @@ -47,24 +47,14 @@ impl Client { /// Get the authenticated account. pub async fn account(&self) -> Result { - let response = self - .request(Method::GET, "account")? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, "account")?.send().await?; + self.handle_response::(response).await } /// List collections of models. pub async fn collections(&self) -> Result { - let response = self - .request(Method::GET, "collections")? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, "collections")?.send().await?; + self.handle_response::(response).await } /// List collection of models. @@ -73,36 +63,21 @@ impl Client { collection: String, ) -> Result { let path = format!("collections/{}", collection); - let response = self - .request(Method::GET, path.as_str())? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, path.as_str())?.send().await?; + self.handle_response::(response).await } /// Get information about a deployment by name including the current release. pub async fn deployment(&self, owner: String, name: String) -> Result { let path = format!("deployments/{}/{}", owner, name); - let response = self - .request(Method::GET, path.as_str())? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, path.as_str())?.send().await?; + self.handle_response::(response).await } /// List deployments associated with the current account, including the latest release configuration for each deployment. pub async fn deployments(&self) -> Result { - let response = self - .request(Method::GET, "deployments")? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, "deployments")?.send().await?; + self.handle_response::(response).await } /// Create a new deployment. @@ -111,10 +86,8 @@ impl Client { .request(Method::POST, "deployments")? .json(&payload) .send() - .await? - .json::() .await?; - Ok(response) + self.handle_response::(response).await } /// Update a deployment. @@ -128,10 +101,8 @@ impl Client { .request(Method::PATCH, path.as_str())? .json(&payload) .send() - .await? - .json::() .await?; - Ok(response) + self.handle_response::(response).await } /// Delete a deployment. @@ -140,31 +111,21 @@ impl Client { /// - You can only delete deployments that have been offline and unused for at least 15 minutes. pub async fn delete_deployment(&self, owner: String, name: String) -> Result<(), Error> { let path = format!("deployments/{}/{}", owner, name); - self.request(Method::DELETE, path.as_str())?.send().await?; - Ok(()) + let response = self.request(Method::DELETE, path.as_str())?.send().await?; + self.handle_response(response).await } /// Get a prediction. pub async fn prediction(&self, prediction_id: String) -> Result { let path = format!("predictions/{}", prediction_id); - let response = self - .request(Method::GET, path.as_str())? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, path.as_str())?.send().await?; + self.handle_response::(response).await } /// List predictions. pub async fn predictions(&self) -> Result { - let response = self - .request(Method::GET, "predictions")? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, "predictions")?.send().await?; + self.handle_response::(response).await } /// Create a prediction. @@ -173,10 +134,8 @@ impl Client { .request(Method::POST, "predictions")? .json(&payload) .send() - .await? - .json::() .await?; - Ok(response) + self.handle_response::(response).await } /// Create a prediction from an official model @@ -189,69 +148,47 @@ impl Client { .request(Method::POST, path.as_str())? .json(&serde_json::json!({ "input": payload.input })) .send() - .await? - .json::() .await?; - Ok(response) + self.handle_response::(response).await } /// Cancel a prediction. pub async fn cancel_prediction(&self, prediction_id: String) -> Result<(), Error> { let path = format!("predictions/{}/cancel", prediction_id); - self.request(Method::POST, path.as_str())?.send().await?; - Ok(()) + let response = self.request(Method::POST, path.as_str())?.send().await?; + self.handle_response(response).await } /// Get a training. pub async fn training(&self, training_id: String) -> Result { let path = format!("trainings/{}", training_id); - let response = self - .request(Method::GET, path.as_str())? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, path.as_str())?.send().await?; + self.handle_response::(response).await } /// List trainings. pub async fn trainings(&self) -> Result { - let response = self - .request(Method::GET, "trainings")? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, "trainings")?.send().await?; + self.handle_response::(response).await } /// Cancel a training. pub async fn cancel_training(&self, training_id: String) -> Result<(), Error> { let path = format!("trainings/{}/cancel", training_id); - self.request(Method::POST, path.as_str())?.send().await?; - Ok(()) + let response = self.request(Method::POST, path.as_str())?.send().await?; + self.handle_response(response).await } /// List available hardware for models. pub async fn hardware(&self) -> Result, Error> { - let response = self - .request(Method::GET, "hardware")? - .send() - .await? - .json::>() - .await?; - Ok(response) + let response = self.request(Method::GET, "hardware")?.send().await?; + self.handle_response::>(response).await } /// List public models. pub async fn public_models(&self) -> Result { - let response = self - .request(Method::GET, "models")? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, "models")?.send().await?; + self.handle_response::(response).await } /// Get model. @@ -261,13 +198,8 @@ impl Client { name: impl Into, ) -> Result { let path = format!("models/{}/{}", owner.into(), name.into()); - let response = self - .request(Method::GET, path.as_str())? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, path.as_str())?.send().await?; + self.handle_response::(response).await } /// List model versions. @@ -277,13 +209,8 @@ impl Client { name: impl Into, ) -> Result { let path = format!("models/{}/{}/versions", owner.into(), name.into()); - let response = self - .request(Method::GET, path.as_str())? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, path.as_str())?.send().await?; + self.handle_response::(response).await } /// Get model version. @@ -299,13 +226,8 @@ impl Client { name.into(), version_id.into() ); - let response = self - .request(Method::GET, path.as_str())? - .send() - .await? - .json::() - .await?; - Ok(response) + let response = self.request(Method::GET, path.as_str())?.send().await?; + self.handle_response::(response).await } /// Get WebHook default secret @@ -313,10 +235,8 @@ impl Client { let response = self .request(Method::GET, "webhooks/default/secret")? .send() - .await? - .json::() .await?; - Ok(response) + self.handle_response::(response).await } fn request(&self, method: Method, path: &str) -> Result { @@ -326,6 +246,48 @@ impl Client { .map_err(|err| Error::UrlParse(err.to_string()))?; Ok(self.http_client.request(method, url)) } + + async fn handle_response(&self, response: Response) -> Result + where + T: serde::de::DeserializeOwned, + { + let status = response.status(); + if status.is_success() | status.is_redirection() { + match response.json::().await { + Ok(data) => Ok(data), + // TODO: this should be a serde error + Err(err) => Err(Error::HttpRequest(err)), + } + } else { + match status { + StatusCode::BAD_REQUEST => { + let error_msg = response.text().await?; + Err(Error::BadRequest(error_msg)) + } + StatusCode::UNAUTHORIZED => { + let error_msg = response.text().await?; + Err(Error::Unauthorized(error_msg)) + } + StatusCode::FORBIDDEN => { + let error_msg = response.text().await?; + Err(Error::Forbidden(error_msg)) + } + StatusCode::TOO_MANY_REQUESTS => { + let error_msg = response.text().await?; + Err(Error::RateLimited(error_msg)) + } + StatusCode::INTERNAL_SERVER_ERROR => { + let error_msg = response.text().await?; + Err(Error::InternalServerError(error_msg)) + } + StatusCode::SERVICE_UNAVAILABLE => { + let error_msg = response.text().await?; + Err(Error::ServiceUnavailable(error_msg)) + } + status => Err(Error::UnexpectedStatus(status)), + } + } + } } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/replicate/src/error.rs b/replicate/src/error.rs index 962895b..a720925 100644 --- a/replicate/src/error.rs +++ b/replicate/src/error.rs @@ -1,18 +1,37 @@ +use reqwest::StatusCode; use serde::Deserialize; #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("API error: {0}")] - Api(ApiError), - - #[error("HTTP client error: {0}")] - Network(#[from] reqwest::Error), + #[error("HTTP request failed: {0}")] + HttpRequest(#[from] reqwest::Error), #[error("URL parse error: {0}")] UrlParse(String), + #[error("Resource not found: {0}")] + NotFound(String), + + #[error("Bad request: {0}")] + BadRequest(String), + + #[error("Unauthorized: {0}")] + Unauthorized(String), + + #[error("Forbidden: {0}")] + Forbidden(String), + + #[error("Rate limited: {0}")] + RateLimited(String), + + #[error("Internal server error: {0}")] + InternalServerError(String), + + #[error("Service unavailable: {0}")] + ServiceUnavailable(String), + #[error("Failed to deserialize: {0}")] - JsonDeserialize(#[from] serde_json::Error), + JsonDeserialization(#[from] serde_json::Error), #[error("Invalid header value: {0}")] InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue), @@ -20,11 +39,11 @@ pub enum Error { #[error("Missing API key {0}")] MissingApiKey(&'static str), - #[error("Invalid Stream Event")] - InvalidStreamEvent, + #[error("Unexpected status code: {0}")] + UnexpectedStatus(StatusCode), - #[error("Unexpected error: {0}")] - Unexpected(String), + #[error("API error: {0}")] + Api(ApiError), } #[derive(Debug, Deserialize, PartialEq, Eq, thiserror::Error)]