Skip to content

Commit

Permalink
chore(replicate): improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
roushou committed Sep 8, 2024
1 parent 51720ba commit 5521932
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 130 deletions.
202 changes: 82 additions & 120 deletions replicate/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use reqwest::{
header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
Method, RequestBuilder, Url,
Method, RequestBuilder, Response, StatusCode, Url,
};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -47,24 +47,14 @@ impl Client {

/// Get the authenticated account.
pub async fn account(&self) -> Result<Account, Error> {
let response = self
.request(Method::GET, "account")?
.send()
.await?
.json::<Account>()
.await?;
Ok(response)
let response = self.request(Method::GET, "account")?.send().await?;
self.handle_response::<Account>(response).await
}

/// List collections of models.
pub async fn collections(&self) -> Result<ListCollections, Error> {
let response = self
.request(Method::GET, "collections")?
.send()
.await?
.json::<ListCollections>()
.await?;
Ok(response)
let response = self.request(Method::GET, "collections")?.send().await?;
self.handle_response::<ListCollections>(response).await
}

/// List collection of models.
Expand All @@ -73,36 +63,21 @@ impl Client {
collection: String,
) -> Result<ListCollectionModels, Error> {
let path = format!("collections/{}", collection);
let response = self
.request(Method::GET, path.as_str())?
.send()
.await?
.json::<ListCollectionModels>()
.await?;
Ok(response)
let response = self.request(Method::GET, path.as_str())?.send().await?;
self.handle_response::<ListCollectionModels>(response).await
}

/// Get information about a deployment by name including the current release.
pub async fn deployment(&self, owner: String, name: String) -> Result<Deployment, Error> {
let path = format!("deployments/{}/{}", owner, name);
let response = self
.request(Method::GET, path.as_str())?
.send()
.await?
.json::<Deployment>()
.await?;
Ok(response)
let response = self.request(Method::GET, path.as_str())?.send().await?;
self.handle_response::<Deployment>(response).await
}

/// List deployments associated with the current account, including the latest release configuration for each deployment.
pub async fn deployments(&self) -> Result<ListDeployments, Error> {
let response = self
.request(Method::GET, "deployments")?
.send()
.await?
.json::<ListDeployments>()
.await?;
Ok(response)
let response = self.request(Method::GET, "deployments")?.send().await?;
self.handle_response::<ListDeployments>(response).await
}

/// Create a new deployment.
Expand All @@ -111,10 +86,8 @@ impl Client {
.request(Method::POST, "deployments")?
.json(&payload)
.send()
.await?
.json::<Deployment>()
.await?;
Ok(response)
self.handle_response::<Deployment>(response).await
}

/// Update a deployment.
Expand All @@ -128,10 +101,8 @@ impl Client {
.request(Method::PATCH, path.as_str())?
.json(&payload)
.send()
.await?
.json::<Deployment>()
.await?;
Ok(response)
self.handle_response::<Deployment>(response).await
}

/// Delete a deployment.
Expand All @@ -140,31 +111,21 @@ impl Client {
/// - You can only delete deployments that have been offline and unused for at least 15 minutes.
pub async fn delete_deployment(&self, owner: String, name: String) -> Result<(), Error> {
let path = format!("deployments/{}/{}", owner, name);
self.request(Method::DELETE, path.as_str())?.send().await?;
Ok(())
let response = self.request(Method::DELETE, path.as_str())?.send().await?;
self.handle_response(response).await
}

/// Get a prediction.
pub async fn prediction(&self, prediction_id: String) -> Result<Prediction, Error> {
let path = format!("predictions/{}", prediction_id);
let response = self
.request(Method::GET, path.as_str())?
.send()
.await?
.json::<Prediction>()
.await?;
Ok(response)
let response = self.request(Method::GET, path.as_str())?.send().await?;
self.handle_response::<Prediction>(response).await
}

/// List predictions.
pub async fn predictions(&self) -> Result<ListPredictions, Error> {
let response = self
.request(Method::GET, "predictions")?
.send()
.await?
.json::<ListPredictions>()
.await?;
Ok(response)
let response = self.request(Method::GET, "predictions")?.send().await?;
self.handle_response::<ListPredictions>(response).await
}

/// Create a prediction.
Expand All @@ -173,10 +134,8 @@ impl Client {
.request(Method::POST, "predictions")?
.json(&payload)
.send()
.await?
.json::<Prediction>()
.await?;
Ok(response)
self.handle_response::<Prediction>(response).await
}

/// Create a prediction from an official model
Expand All @@ -189,69 +148,47 @@ impl Client {
.request(Method::POST, path.as_str())?
.json(&serde_json::json!({ "input": payload.input }))
.send()
.await?
.json::<Prediction>()
.await?;
Ok(response)
self.handle_response::<Prediction>(response).await
}

/// Cancel a prediction.
pub async fn cancel_prediction(&self, prediction_id: String) -> Result<(), Error> {
let path = format!("predictions/{}/cancel", prediction_id);
self.request(Method::POST, path.as_str())?.send().await?;
Ok(())
let response = self.request(Method::POST, path.as_str())?.send().await?;
self.handle_response(response).await
}

/// Get a training.
pub async fn training(&self, training_id: String) -> Result<Training, Error> {
let path = format!("trainings/{}", training_id);
let response = self
.request(Method::GET, path.as_str())?
.send()
.await?
.json::<Training>()
.await?;
Ok(response)
let response = self.request(Method::GET, path.as_str())?.send().await?;
self.handle_response::<Training>(response).await
}

/// List trainings.
pub async fn trainings(&self) -> Result<ListTrainings, Error> {
let response = self
.request(Method::GET, "trainings")?
.send()
.await?
.json::<ListTrainings>()
.await?;
Ok(response)
let response = self.request(Method::GET, "trainings")?.send().await?;
self.handle_response::<ListTrainings>(response).await
}

/// Cancel a training.
pub async fn cancel_training(&self, training_id: String) -> Result<(), Error> {
let path = format!("trainings/{}/cancel", training_id);
self.request(Method::POST, path.as_str())?.send().await?;
Ok(())
let response = self.request(Method::POST, path.as_str())?.send().await?;
self.handle_response(response).await
}

/// List available hardware for models.
pub async fn hardware(&self) -> Result<Vec<Hardware>, Error> {
let response = self
.request(Method::GET, "hardware")?
.send()
.await?
.json::<Vec<Hardware>>()
.await?;
Ok(response)
let response = self.request(Method::GET, "hardware")?.send().await?;
self.handle_response::<Vec<Hardware>>(response).await
}

/// List public models.
pub async fn public_models(&self) -> Result<ListPublicModels, Error> {
let response = self
.request(Method::GET, "models")?
.send()
.await?
.json::<ListPublicModels>()
.await?;
Ok(response)
let response = self.request(Method::GET, "models")?.send().await?;
self.handle_response::<ListPublicModels>(response).await
}

/// Get model.
Expand All @@ -261,13 +198,8 @@ impl Client {
name: impl Into<String>,
) -> Result<Model, Error> {
let path = format!("models/{}/{}", owner.into(), name.into());
let response = self
.request(Method::GET, path.as_str())?
.send()
.await?
.json::<Model>()
.await?;
Ok(response)
let response = self.request(Method::GET, path.as_str())?.send().await?;
self.handle_response::<Model>(response).await
}

/// List model versions.
Expand All @@ -277,13 +209,8 @@ impl Client {
name: impl Into<String>,
) -> Result<ListModelVersions, Error> {
let path = format!("models/{}/{}/versions", owner.into(), name.into());
let response = self
.request(Method::GET, path.as_str())?
.send()
.await?
.json::<ListModelVersions>()
.await?;
Ok(response)
let response = self.request(Method::GET, path.as_str())?.send().await?;
self.handle_response::<ListModelVersions>(response).await
}

/// Get model version.
Expand All @@ -299,24 +226,17 @@ impl Client {
name.into(),
version_id.into()
);
let response = self
.request(Method::GET, path.as_str())?
.send()
.await?
.json::<ModelVersion>()
.await?;
Ok(response)
let response = self.request(Method::GET, path.as_str())?.send().await?;
self.handle_response::<ModelVersion>(response).await
}

/// Get WebHook default secret
pub async fn webhook_default_secret(&self) -> Result<WebHookSecret, Error> {
let response = self
.request(Method::GET, "webhooks/default/secret")?
.send()
.await?
.json::<WebHookSecret>()
.await?;
Ok(response)
self.handle_response::<WebHookSecret>(response).await
}

fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> {
Expand All @@ -326,6 +246,48 @@ impl Client {
.map_err(|err| Error::UrlParse(err.to_string()))?;
Ok(self.http_client.request(method, url))
}

async fn handle_response<T>(&self, response: Response) -> Result<T, Error>
where
T: serde::de::DeserializeOwned,
{
let status = response.status();
if status.is_success() | status.is_redirection() {
match response.json::<T>().await {
Ok(data) => Ok(data),
// TODO: this should be a serde error
Err(err) => Err(Error::HttpRequest(err)),
}
} else {
match status {
StatusCode::BAD_REQUEST => {
let error_msg = response.text().await?;
Err(Error::BadRequest(error_msg))
}
StatusCode::UNAUTHORIZED => {
let error_msg = response.text().await?;
Err(Error::Unauthorized(error_msg))
}
StatusCode::FORBIDDEN => {
let error_msg = response.text().await?;
Err(Error::Forbidden(error_msg))
}
StatusCode::TOO_MANY_REQUESTS => {
let error_msg = response.text().await?;
Err(Error::RateLimited(error_msg))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error_msg = response.text().await?;
Err(Error::InternalServerError(error_msg))
}
StatusCode::SERVICE_UNAVAILABLE => {
let error_msg = response.text().await?;
Err(Error::ServiceUnavailable(error_msg))
}
status => Err(Error::UnexpectedStatus(status)),
}
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
Loading

0 comments on commit 5521932

Please sign in to comment.