-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(openai): support image generation
- Loading branch information
Showing
3 changed files
with
234 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
use core::fmt; | ||
use reqwest::{Client as ReqwestClient, Method, RequestBuilder, Url}; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use crate::error::Error; | ||
|
||
pub struct ImageClient { | ||
base_url: Url, | ||
http_client: ReqwestClient, | ||
} | ||
|
||
impl ImageClient { | ||
pub fn new(base_url: Url, http_client: ReqwestClient) -> Self { | ||
Self { | ||
base_url, | ||
http_client, | ||
} | ||
} | ||
|
||
/// Creates an image given a prompt. | ||
pub async fn create_image(&self, payload: CreateImage) -> Result<CreateImageResponse, Error> { | ||
let image = self | ||
.request(Method::POST, "images/generations")? | ||
.json(&payload) | ||
.send() | ||
.await? | ||
.json::<CreateImageResponse>() | ||
.await?; | ||
Ok(image) | ||
} | ||
|
||
fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> { | ||
let url = self | ||
.base_url | ||
.join(path) | ||
.map_err(|err| Error::UrlParse(err.to_string()))?; | ||
Ok(self.http_client.request(method, url)) | ||
} | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct CreateImage { | ||
/// A text description of the desired image(s). The maximum length is 1000 characters for **dall-e-2** and 4000 characters for **dall-e-3**. | ||
pub prompt: String, | ||
|
||
/// The model to use for image generation. | ||
pub model: ImageModel, | ||
|
||
/// The number of images to generate. Must be between 1 and 10. For **dall-e-3**, only n=1 is supported. | ||
pub n: u8, | ||
|
||
/// The quality of the image that will be generated. **hd** creates images with finer details and greater consistency across the image. This param is only supported for **dall-e-3**. | ||
/// | ||
/// Defaults to **standard**. | ||
pub quality: ImageQuality, | ||
|
||
/// The format in which the generated images are returned. Must be one of **url** or **b64_json**. | ||
pub response_format: ImageResponseFormat, | ||
|
||
/// The size of the generated images. | ||
/// | ||
/// For **dall-e-2**: | ||
/// - **256x256** | ||
/// - **512x512** | ||
/// - **1024x1024** | ||
/// | ||
/// For **dall-e-3**: | ||
/// - **1024x1024** | ||
/// - **1792x1024** | ||
/// - **1024x1792** | ||
pub size: ImageSize, | ||
|
||
/// The style of the generated images. Must be one of **vivid** or **natural**. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. | ||
/// | ||
/// Note: This param is only supported for **dall-e-3** | ||
pub style: ImageStyle, | ||
|
||
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub user: Option<String>, | ||
} | ||
|
||
impl CreateImage { | ||
pub fn new(prompt: impl Into<String>, model: ImageModel) -> Self { | ||
Self { | ||
prompt: prompt.into(), | ||
model, | ||
n: 1, | ||
quality: ImageQuality::default(), | ||
response_format: ImageResponseFormat::default(), | ||
size: ImageSize::default(), | ||
style: ImageStyle::default(), | ||
user: None, | ||
} | ||
} | ||
|
||
pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self { | ||
self.prompt = prompt.into(); | ||
self | ||
} | ||
|
||
pub fn with_model(mut self, model: ImageModel) -> Self { | ||
self.model = model; | ||
self | ||
} | ||
|
||
pub fn with_n(mut self, n: u8) -> Self { | ||
self.n = if n > 10 { 10 } else { n }; | ||
self | ||
} | ||
|
||
pub fn with_quality(mut self, quality: ImageQuality) -> Self { | ||
self.quality = quality; | ||
self | ||
} | ||
|
||
pub fn with_response_format(mut self, response_format: ImageResponseFormat) -> Self { | ||
self.response_format = response_format; | ||
self | ||
} | ||
|
||
pub fn with_style(mut self, style: ImageStyle) -> Self { | ||
self.style = style; | ||
self | ||
} | ||
|
||
pub fn with_user(mut self, user: impl Into<String>) -> Self { | ||
self.user = Some(user.into()); | ||
self | ||
} | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub enum ImageModel { | ||
#[serde(rename = "dall-e-2")] | ||
DallE2, | ||
#[serde(rename = "dall-e-3")] | ||
DallE3, | ||
} | ||
|
||
impl fmt::Display for ImageModel { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
match self { | ||
Self::DallE2 => write!(f, "dall-e-2"), | ||
Self::DallE3 => write!(f, "dall-e-3"), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
#[serde(rename_all = "lowercase")] | ||
pub enum ImageQuality { | ||
HD, | ||
Standard, | ||
} | ||
|
||
impl Default for ImageQuality { | ||
fn default() -> Self { | ||
Self::Standard | ||
} | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct CreateImageResponse { | ||
pub created: u64, | ||
pub data: Vec<ImageData>, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct ImageData { | ||
pub url: String, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
#[serde(rename_all = "snake_case")] | ||
pub enum ImageResponseFormat { | ||
Url, | ||
Base64Json, | ||
} | ||
|
||
impl Default for ImageResponseFormat { | ||
fn default() -> Self { | ||
Self::Url | ||
} | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub enum ImageSize { | ||
#[serde(rename = "256x256")] | ||
S256x256, | ||
#[serde(rename = "512x512")] | ||
S512x512, | ||
#[serde(rename = "1024x1024")] | ||
S1024x1024, | ||
#[serde(rename = "1792x1024")] | ||
S1792x1024, | ||
#[serde(rename = "1024x1792")] | ||
S1024x1792, | ||
} | ||
|
||
impl Default for ImageSize { | ||
fn default() -> Self { | ||
Self::S1024x1024 | ||
} | ||
} | ||
|
||
impl fmt::Display for ImageSize { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
match self { | ||
Self::S256x256 => write!(f, "256x256"), | ||
Self::S512x512 => write!(f, "512x512"), | ||
Self::S1024x1024 => write!(f, "1024x1024"), | ||
Self::S1792x1024 => write!(f, "1792x1024"), | ||
Self::S1024x1792 => write!(f, "1024x1792"), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
#[serde(rename_all = "lowercase")] | ||
pub enum ImageStyle { | ||
Vivid, | ||
Natural, | ||
} | ||
|
||
impl Default for ImageStyle { | ||
fn default() -> Self { | ||
Self::Vivid | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters