Skip to content

Commit

Permalink
feat(replicate): support creating prediction from official model
Browse files Browse the repository at this point in the history
  • Loading branch information
roushou committed Sep 8, 2024
1 parent 4e4e848 commit 273c49f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 13 deletions.
Binary file added assets/dragon.webp
Binary file not shown.
3 changes: 3 additions & 0 deletions examples/replicate-basic/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Replicate Basic Example

![spartan](../../assets/dragon.webp)
27 changes: 14 additions & 13 deletions examples/replicate-basic/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use replic::{
client::{Client, CreatePrediction},
client::{Client, CreateModelPrediction},
config::Config,
};

Expand All @@ -8,16 +8,17 @@ async fn main() {
let config = Config::from_env().unwrap();
let client = Client::new(config).unwrap();

let collections = client
.create_prediction(CreatePrediction {
version: "f2ab8a5bfe79f02f0789a146cf5e73d2a4ff2684a98c2b303d1e1ff3814271db".to_string(),
input: serde_json::json!({
"prompt": "black forest gateau cake spelling out the words \"FLUX SCHNELL\", tasty, food photography, dynamic shot"
}),
webhook: None,
webhook_event_filters: None,
})
.await
.unwrap();
println!("{:?}", collections);
let payload = CreateModelPrediction {
owner: "black-forest-labs".to_string(),
name: "flux-schnell".to_string(),
input: serde_json::json!({
"prompt": "3D model of a baby dragon",
"num_outputs": 1,
"aspect_ratio": "1:1",
"output_format": "webp",
"output_quality": 100
}),
};
let prediction = client.create_model_prediction(payload).await.unwrap();
println!("{:?}", prediction);
}
28 changes: 28 additions & 0 deletions replicate/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ impl Client {
Ok(response)
}

/// Create a prediction from an official model
pub async fn create_model_prediction(
&self,
payload: CreateModelPrediction,
) -> Result<Prediction, Error> {
let path = format!("models/{}/{}/predictions", payload.owner, payload.name);
let response = self
.request(Method::POST, path.as_str())?
.json(&serde_json::json!({ "input": payload.input }))
.send()
.await?
.json::<Prediction>()
.await?;
Ok(response)
}

/// Cancel a prediction.
pub async fn cancel_prediction(&self, prediction_id: String) -> Result<(), Error> {
let path = format!("predictions/{}/cancel", prediction_id);
Expand Down Expand Up @@ -491,6 +507,18 @@ pub struct CreatePrediction {
pub webhook_event_filters: Option<Vec<WebHookEvent>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateModelPrediction {
/// Model owner
pub owner: String,

/// Model name
pub name: String,

/// The model's input as a JSON object.
pub input: serde_json::Value,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WebHookEvent {
Expand Down

0 comments on commit 273c49f

Please sign in to comment.