Skip to content

Commit

Permalink
feat(openai): add support for O1 models
Browse files Browse the repository at this point in the history
  • Loading branch information
roushou committed Sep 12, 2024
1 parent 7714846 commit e090514
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 107 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## What's Changed in opai-v0.3.0
* feat(openai): add support for `O1` models

**Full Changelog**: https://github.com/roushou/mesh/compare/opai-v0.2.0...opai-v0.3.0

## What's Changed in opai-v0.2.0
* docs(openai): fix typo
* feat(openai): support image generation
Expand Down
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ members = ["anthropic", "mesh", "openai", "replicate", "examples/*"]
resolver = "2"

[workspace.package]
version = "0.4.0"
version = "0.5.0"
edition = "2021"
authors = ["Roushou <[email protected]>"]
description = "Rust SDK to build AI-powered apps"
Expand Down
4 changes: 2 additions & 2 deletions examples/openai-chat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use opai::{
chats::message::{CreateChatCompletion, Message, Role},
client::Client,
config::Config,
models::gpt::GptModel,
models::{gpt::Gpt, Model},
};

#[tokio::main]
Expand All @@ -15,7 +15,7 @@ async fn main() {
role: Role::User,
name: None,
}];
let request = CreateChatCompletion::new(GptModel::GPT4o, messages);
let request = CreateChatCompletion::new(Model::Gpt(Gpt::GPT4), messages);
let completion = client.chat.create_completion(request).await.unwrap();
println!("{:?}", completion);
}
2 changes: 1 addition & 1 deletion mesh/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ keywords.workspace = true

[dependencies]
anthropic-rs = { path = "../anthropic", version = "0.1.6" }
opai = { path = "../openai", version = "0.2.0" }
opai = { path = "../openai", version = "0.3.0" }
replic = { path = "../replicate", version = "0.1.1" }

[package.metadata.docs.rs]
Expand Down
2 changes: 1 addition & 1 deletion openai/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "opai"
version = "0.2.0"
version = "0.3.0"
edition = "2021"
authors = ["Roushou <[email protected]>"]
description = "OpenAI Rust SDK"
Expand Down
19 changes: 10 additions & 9 deletions openai/src/chats/message.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};

use crate::models::gpt::GptModel;
use crate::models::{gpt::Gpt, Model};

/// Chat completion response returned by the model
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -15,15 +15,15 @@ pub struct ChatCompletion {
pub created: u64,

/// The model used for the chat completion.
pub model: GptModel,
pub model: Model,

/// The service tier used for processing the request. This field is only included if the **service_tier** parameter is specified in the request.
pub service_tier: Option<String>,

/// This fingerprint represents the backend configuration that the model runs with.
///
/// Can be used in conjunction with the seed request parameter to understand when backend changes have been made that might impact determinism
pub system_fingerprint: String,
pub system_fingerprint: Option<String>,

/// The object type, which is always **chat.completion**.
pub object: String,
Expand Down Expand Up @@ -90,8 +90,9 @@ pub struct ToolCall {
pub struct LogProb {
/// A list of message content tokens with log probability information.
pub content: Option<Vec<LogProbContent>>,
// A list of message refusal tokens with log probability information.
// pub refusal: Option<String>,

/// A list of message refusal tokens with log probability information.
pub refusal: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -132,7 +133,7 @@ pub enum Role {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateChatCompletion {
/// The model that will complete your prompt e.g. GPT-4o
pub model: GptModel,
pub model: Model,

/// Input messages.
pub messages: Vec<Message>,
Expand Down Expand Up @@ -264,15 +265,15 @@ pub struct StreamOptions {
}

impl CreateChatCompletion {
pub fn new(model: GptModel, messages: Vec<Message>) -> Self {
pub fn new(model: Model, messages: Vec<Message>) -> Self {
Self {
model,
messages,
..Default::default()
}
}

pub fn with_model(mut self, model: GptModel) -> Self {
pub fn with_model(mut self, model: Model) -> Self {
self.model = model;
self
}
Expand Down Expand Up @@ -383,7 +384,7 @@ impl CreateChatCompletion {
impl Default for CreateChatCompletion {
fn default() -> Self {
Self {
model: GptModel::GPT4o,
model: Model::Gpt(Gpt::GPT4),
max_tokens: Some(1000),
messages: Vec::new(),
n: None,
Expand Down
150 changes: 63 additions & 87 deletions openai/src/models/gpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use std::str::FromStr;

#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub enum GptModel {
pub enum Gpt {
#[serde(rename(serialize = "gpt-4"))]
GPT4,
#[serde(rename(serialize = "gpt-4o"))]
Expand All @@ -16,15 +16,15 @@ pub enum GptModel {
GPT35Turbo,
}

impl<'de> Deserialize<'de> for GptModel {
impl<'de> Deserialize<'de> for Gpt {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct GptModelVisitor;

impl<'de> serde::de::Visitor<'de> for GptModelVisitor {
type Value = GptModel;
type Value = Gpt;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string representing a GPT model")
Expand All @@ -35,17 +35,17 @@ impl<'de> Deserialize<'de> for GptModel {
E: serde::de::Error,
{
match value {
"gpt-4" => Ok(GptModel::GPT4),
"gpt-4o" => Ok(GptModel::GPT4o),
"gpt-4o-mini" => Ok(GptModel::GPT4oMini),
"gpt-4-turbo" => Ok(GptModel::GPT4Turbo),
"chatgpt-4o-latest" => Ok(GptModel::GPT4o),
"gpt-4" => Ok(Gpt::GPT4),
"gpt-4o" => Ok(Gpt::GPT4o),
"gpt-4o-mini" => Ok(Gpt::GPT4oMini),
"gpt-4-turbo" => Ok(Gpt::GPT4Turbo),
"chatgpt-4o-latest" => Ok(Gpt::GPT4o),
// The order is important for correct matching
_ if value.starts_with("gpt-3.5-turbo") => Ok(GptModel::GPT35Turbo),
_ if value.starts_with("gpt-4-turbo-") => Ok(GptModel::GPT4Turbo),
_ if value.starts_with("gpt-4-") => Ok(GptModel::GPT4),
_ if value.starts_with("gpt-4o-mini-") => Ok(GptModel::GPT4oMini),
_ if value.starts_with("gpt-4o-") => Ok(GptModel::GPT4o),
_ if value.starts_with("gpt-3.5-turbo") => Ok(Gpt::GPT35Turbo),
_ if value.starts_with("gpt-4-turbo-") => Ok(Gpt::GPT4Turbo),
_ if value.starts_with("gpt-4-") => Ok(Gpt::GPT4),
_ if value.starts_with("gpt-4o-mini-") => Ok(Gpt::GPT4oMini),
_ if value.starts_with("gpt-4o-") => Ok(Gpt::GPT4o),
_ => Err(E::custom(format!("Unknown GPT model: {}", value))),
}
}
Expand All @@ -55,28 +55,28 @@ impl<'de> Deserialize<'de> for GptModel {
}
}

impl FromStr for GptModel {
impl FromStr for Gpt {
type Err = crate::error::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gpt-4" => Ok(GptModel::GPT4),
"gpt-4o" => Ok(GptModel::GPT4o),
"gpt-4o-mini" => Ok(GptModel::GPT4oMini),
"gpt-4-turbo" => Ok(GptModel::GPT4Turbo),
"chatgpt-4o-latest" => Ok(GptModel::GPT4o),
_ if s.starts_with("gpt-3.5-turbo") => Ok(GptModel::GPT35Turbo),
"gpt-4" => Ok(Gpt::GPT4),
"gpt-4o" => Ok(Gpt::GPT4o),
"gpt-4o-mini" => Ok(Gpt::GPT4oMini),
"gpt-4-turbo" => Ok(Gpt::GPT4Turbo),
"chatgpt-4o-latest" => Ok(Gpt::GPT4o),
_ if s.starts_with("gpt-3.5-turbo") => Ok(Gpt::GPT35Turbo),
// The order is important for correct matching
_ if s.starts_with("gpt-4o-mini-") => Ok(GptModel::GPT4oMini),
_ if s.starts_with("gpt-4o-") => Ok(GptModel::GPT4o),
_ if s.starts_with("gpt-4-turbo-") => Ok(GptModel::GPT4Turbo),
_ if s.starts_with("gpt-4-") => Ok(GptModel::GPT4),
_ if s.starts_with("gpt-4o-mini-") => Ok(Gpt::GPT4oMini),
_ if s.starts_with("gpt-4o-") => Ok(Gpt::GPT4o),
_ if s.starts_with("gpt-4-turbo-") => Ok(Gpt::GPT4Turbo),
_ if s.starts_with("gpt-4-") => Ok(Gpt::GPT4),
_ => Err(crate::error::Error::ModelNotSupported(s.to_string())),
}
}
}

impl fmt::Display for GptModel {
impl fmt::Display for Gpt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::GPT4 => write!(f, "gpt-4"),
Expand All @@ -97,102 +97,78 @@ mod tests {
#[test]
fn should_deserialize_gpt_models() {
// Test exact matches
assert_eq!(from_str::<GptModel>(r#""gpt-4""#).unwrap(), GptModel::GPT4);
assert_eq!(from_str::<Gpt>(r#""gpt-4""#).unwrap(), Gpt::GPT4);
assert_eq!(from_str::<Gpt>(r#""gpt-4o""#).unwrap(), Gpt::GPT4o);
assert_eq!(from_str::<Gpt>(r#""gpt-4o-mini""#).unwrap(), Gpt::GPT4oMini);
assert_eq!(from_str::<Gpt>(r#""gpt-4-turbo""#).unwrap(), Gpt::GPT4Turbo);
assert_eq!(
from_str::<GptModel>(r#""gpt-4o""#).unwrap(),
GptModel::GPT4o
);
assert_eq!(
from_str::<GptModel>(r#""gpt-4o-mini""#).unwrap(),
GptModel::GPT4oMini
);
assert_eq!(
from_str::<GptModel>(r#""gpt-4-turbo""#).unwrap(),
GptModel::GPT4Turbo
);
assert_eq!(
from_str::<GptModel>(r#""chatgpt-4o-latest""#).unwrap(),
GptModel::GPT4o
from_str::<Gpt>(r#""chatgpt-4o-latest""#).unwrap(),
Gpt::GPT4o
);

// Test prefix matches
assert_eq!(
from_str::<GptModel>(r#""gpt-3.5-turbo""#).unwrap(),
GptModel::GPT35Turbo
from_str::<Gpt>(r#""gpt-3.5-turbo""#).unwrap(),
Gpt::GPT35Turbo
);
assert_eq!(
from_str::<GptModel>(r#""gpt-3.5-turbo-0125""#).unwrap(),
GptModel::GPT35Turbo
from_str::<Gpt>(r#""gpt-3.5-turbo-0125""#).unwrap(),
Gpt::GPT35Turbo
);
assert_eq!(
from_str::<GptModel>(r#""gpt-4-0125-preview""#).unwrap(),
GptModel::GPT4
from_str::<Gpt>(r#""gpt-4-0125-preview""#).unwrap(),
Gpt::GPT4
);
assert_eq!(
from_str::<GptModel>(r#""gpt-4o-2024-05-13""#).unwrap(),
GptModel::GPT4o
from_str::<Gpt>(r#""gpt-4o-2024-05-13""#).unwrap(),
Gpt::GPT4o
);
assert_eq!(
from_str::<GptModel>(r#""gpt-4o-mini-1234""#).unwrap(),
GptModel::GPT4oMini
from_str::<Gpt>(r#""gpt-4o-mini-1234""#).unwrap(),
Gpt::GPT4oMini
);
assert_eq!(
from_str::<GptModel>(r#""gpt-4-turbo-2024-04-09""#).unwrap(),
GptModel::GPT4Turbo
from_str::<Gpt>(r#""gpt-4-turbo-2024-04-09""#).unwrap(),
Gpt::GPT4Turbo
);

// Test error case
assert!(from_str::<GptModel>(r#""unknown-model""#).is_err());
assert!(from_str::<Gpt>(r#""unknown-model""#).is_err());
}

#[test]
fn test_gpt_model_from_str() {
// Test exact matches
assert_eq!("gpt-4".parse::<GptModel>().unwrap(), GptModel::GPT4);
assert_eq!("gpt-4o".parse::<GptModel>().unwrap(), GptModel::GPT4o);
assert_eq!(
"gpt-4o-mini".parse::<GptModel>().unwrap(),
GptModel::GPT4oMini
);
assert_eq!(
"gpt-4-turbo".parse::<GptModel>().unwrap(),
GptModel::GPT4Turbo
);
assert_eq!(
"chatgpt-4o-latest".parse::<GptModel>().unwrap(),
GptModel::GPT4o
);
assert_eq!("gpt-4".parse::<Gpt>().unwrap(), Gpt::GPT4);
assert_eq!("gpt-4o".parse::<Gpt>().unwrap(), Gpt::GPT4o);
assert_eq!("gpt-4o-mini".parse::<Gpt>().unwrap(), Gpt::GPT4oMini);
assert_eq!("gpt-4-turbo".parse::<Gpt>().unwrap(), Gpt::GPT4Turbo);
assert_eq!("chatgpt-4o-latest".parse::<Gpt>().unwrap(), Gpt::GPT4o);

// Test prefix matches
assert_eq!("gpt-3.5-turbo".parse::<Gpt>().unwrap(), Gpt::GPT35Turbo);
assert_eq!(
"gpt-3.5-turbo".parse::<GptModel>().unwrap(),
GptModel::GPT35Turbo
);
assert_eq!(
"gpt-3.5-turbo-0125".parse::<GptModel>().unwrap(),
GptModel::GPT35Turbo
);
assert_eq!("gpt-4-9012".parse::<GptModel>().unwrap(), GptModel::GPT4);
assert_eq!("gpt-4o-5678".parse::<GptModel>().unwrap(), GptModel::GPT4o);
assert_eq!(
"gpt-4o-mini-1234".parse::<GptModel>().unwrap(),
GptModel::GPT4oMini
"gpt-3.5-turbo-0125".parse::<Gpt>().unwrap(),
Gpt::GPT35Turbo
);
assert_eq!("gpt-4-9012".parse::<Gpt>().unwrap(), Gpt::GPT4);
assert_eq!("gpt-4o-5678".parse::<Gpt>().unwrap(), Gpt::GPT4o);
assert_eq!("gpt-4o-mini-1234".parse::<Gpt>().unwrap(), Gpt::GPT4oMini);
assert_eq!(
"gpt-4-turbo-2024-04-09".parse::<GptModel>().unwrap(),
GptModel::GPT4Turbo
"gpt-4-turbo-2024-04-09".parse::<Gpt>().unwrap(),
Gpt::GPT4Turbo
);

// Test error case
assert!("unknown-model".parse::<GptModel>().is_err());
assert!("unknown-model".parse::<Gpt>().is_err());
}

#[test]
fn should_display_gpt_models() {
assert_eq!(GptModel::GPT35Turbo.to_string(), "gpt-3.5-turbo");
assert_eq!(GptModel::GPT4.to_string(), "gpt-4");
assert_eq!(GptModel::GPT4o.to_string(), "gpt-4o");
assert_eq!(GptModel::GPT4oMini.to_string(), "gpt-4o-mini");
assert_eq!(GptModel::GPT4Turbo.to_string(), "gpt-4-turbo");
assert_eq!(Gpt::GPT35Turbo.to_string(), "gpt-3.5-turbo");
assert_eq!(Gpt::GPT4.to_string(), "gpt-4");
assert_eq!(Gpt::GPT4o.to_string(), "gpt-4o");
assert_eq!(Gpt::GPT4oMini.to_string(), "gpt-4o-mini");
assert_eq!(Gpt::GPT4Turbo.to_string(), "gpt-4-turbo");
}
}
Loading

0 comments on commit e090514

Please sign in to comment.