Skip to content

Commit

Permalink
feat(openai): support image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
roushou committed Sep 12, 2024
1 parent 432f318 commit 47e01ad
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 1 deletion.
4 changes: 3 additions & 1 deletion openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ use reqwest::{
};

use crate::{
models::ModelClient, moderations::ModerationClient,
chats::ChatClient, config::Config, embeddings::EmbeddingClient, error::Error,
images::ImageClient, models::ModelClient, moderations::ModerationClient,
};

pub struct Client {
api_key: String,
base_url: Url,
pub chat: ChatClient,
pub image: ImageClient,
pub model: ModelClient,
pub moderation: ModerationClient,
pub embedding: EmbeddingClient,
Expand All @@ -37,6 +38,7 @@ impl Client {
api_key: config.api_key,
base_url: base_url.clone(),
chat: ChatClient::new(base_url.clone(), http_client.clone()),
image: ImageClient::new(base_url.clone(), http_client.clone()),
model: ModelClient::new(base_url.clone(), http_client.clone()),
moderation: ModerationClient::new(base_url.clone(), http_client.clone()),
embedding: EmbeddingClient::new(base_url, http_client),
Expand Down
230 changes: 230 additions & 0 deletions openai/src/images/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
use core::fmt;
use reqwest::{Client as ReqwestClient, Method, RequestBuilder, Url};
use serde::{Deserialize, Serialize};

use crate::error::Error;

pub struct ImageClient {
base_url: Url,
http_client: ReqwestClient,
}

impl ImageClient {
pub fn new(base_url: Url, http_client: ReqwestClient) -> Self {
Self {
base_url,
http_client,
}
}

/// Creates an image given a prompt.
pub async fn create_image(&self, payload: CreateImage) -> Result<CreateImageResponse, Error> {
let image = self
.request(Method::POST, "images/generations")?
.json(&payload)
.send()
.await?
.json::<CreateImageResponse>()
.await?;
Ok(image)
}

fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> {
let url = self
.base_url
.join(path)
.map_err(|err| Error::UrlParse(err.to_string()))?;
Ok(self.http_client.request(method, url))
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct CreateImage {
/// A text description of the desired image(s). The maximum length is 1000 characters for **dall-e-2** and 4000 characters for **dall-e-3**.
pub prompt: String,

/// The model to use for image generation.
pub model: ImageModel,

/// The number of images to generate. Must be between 1 and 10. For **dall-e-3**, only n=1 is supported.
pub n: u8,

/// The quality of the image that will be generated. **hd** creates images with finer details and greater consistency across the image. This param is only supported for **dall-e-3**.
///
/// Defaults to **standard**.
pub quality: ImageQuality,

/// The format in which the generated images are returned. Must be one of **url** or **b64_json**.
pub response_format: ImageResponseFormat,

/// The size of the generated images.
///
/// For **dall-e-2**:
/// - **256x256**
/// - **512x512**
/// - **1024x1024**
///
/// For **dall-e-3**:
/// - **1024x1024**
/// - **1792x1024**
/// - **1024x1792**
pub size: ImageSize,

/// The style of the generated images. Must be one of **vivid** or **natural**. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.
///
/// Note: This param is only supported for **dall-e-3**
pub style: ImageStyle,

/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}

impl CreateImage {
pub fn new(prompt: impl Into<String>, model: ImageModel) -> Self {
Self {
prompt: prompt.into(),
model,
n: 1,
quality: ImageQuality::default(),
response_format: ImageResponseFormat::default(),
size: ImageSize::default(),
style: ImageStyle::default(),
user: None,
}
}

pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
self.prompt = prompt.into();
self
}

pub fn with_model(mut self, model: ImageModel) -> Self {
self.model = model;
self
}

pub fn with_n(mut self, n: u8) -> Self {
self.n = if n > 10 { 10 } else { n };
self
}

pub fn with_quality(mut self, quality: ImageQuality) -> Self {
self.quality = quality;
self
}

pub fn with_response_format(mut self, response_format: ImageResponseFormat) -> Self {
self.response_format = response_format;
self
}

pub fn with_style(mut self, style: ImageStyle) -> Self {
self.style = style;
self
}

pub fn with_user(mut self, user: impl Into<String>) -> Self {
self.user = Some(user.into());
self
}
}

#[derive(Debug, Serialize, Deserialize)]
pub enum ImageModel {
#[serde(rename = "dall-e-2")]
DallE2,
#[serde(rename = "dall-e-3")]
DallE3,
}

impl fmt::Display for ImageModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::DallE2 => write!(f, "dall-e-2"),
Self::DallE3 => write!(f, "dall-e-3"),
}
}
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageQuality {
HD,
Standard,
}

impl Default for ImageQuality {
fn default() -> Self {
Self::Standard
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct CreateImageResponse {
pub created: u64,
pub data: Vec<ImageData>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ImageData {
pub url: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImageResponseFormat {
Url,
Base64Json,
}

impl Default for ImageResponseFormat {
fn default() -> Self {
Self::Url
}
}

#[derive(Debug, Serialize, Deserialize)]
pub enum ImageSize {
#[serde(rename = "256x256")]
S256x256,
#[serde(rename = "512x512")]
S512x512,
#[serde(rename = "1024x1024")]
S1024x1024,
#[serde(rename = "1792x1024")]
S1792x1024,
#[serde(rename = "1024x1792")]
S1024x1792,
}

impl Default for ImageSize {
fn default() -> Self {
Self::S1024x1024
}
}

impl fmt::Display for ImageSize {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::S256x256 => write!(f, "256x256"),
Self::S512x512 => write!(f, "512x512"),
Self::S1024x1024 => write!(f, "1024x1024"),
Self::S1792x1024 => write!(f, "1792x1024"),
Self::S1024x1792 => write!(f, "1024x1792"),
}
}
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageStyle {
Vivid,
Natural,
}

impl Default for ImageStyle {
fn default() -> Self {
Self::Vivid
}
}
1 change: 1 addition & 0 deletions openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ pub mod client;
pub mod config;
pub mod embeddings;
pub mod error;
pub mod images;
pub mod models;
pub mod moderations;

0 comments on commit 47e01ad

Please sign in to comment.