Skip to content

Commit

Permalink
Adding a test and support for option type parsing in instructor.rs'
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanleomk committed Jul 8, 2024
1 parent 5753984 commit 55ca8bd
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 17 deletions.
142 changes: 130 additions & 12 deletions instructor/src/helpers/response_model.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
use std::collections::HashMap;

use instruct_macros_types::{Parameter, StructInfo};
use instruct_macros_types::{Parameter, ParameterInfo, StructInfo};
use openai_api_rs::v1::chat_completion::{self, JSONSchemaDefine};

fn get_required_properties(info: &StructInfo) -> Vec<String> {
let mut required = Vec::new();

for param in info.parameters.iter() {
match param {
Parameter::Field(field_info) => {
required.push(field_info.name.clone());
if !field_info.is_optional {
required.push(field_info.name.clone());
}
}
Parameter::Struct(struct_info) => {
required.push(struct_info.name.clone());
if !struct_info.is_optional {
required.push(struct_info.name.clone());
}
}
Parameter::Enum(enum_info) => {
required.push(enum_info.title.clone());
if !enum_info.is_optional {
required.push(enum_info.title.clone());
}
}
}
}
Expand All @@ -27,10 +34,19 @@ fn convert_parameter_type(info: &str) -> chat_completion::JSONSchemaType {
"u8" | "i8" | "u16" | "i16" | "u32" | "i32" | "u64" | "i64" | "u128" | "i128" | "usize"
| "isize" => chat_completion::JSONSchemaType::Number,
"bool" => chat_completion::JSONSchemaType::Boolean,

_ => panic!("Unsupported type: {}", info),
}
}

fn get_base_type(field_info: &ParameterInfo) -> &str {
if field_info.r#type.starts_with("Option<") && field_info.r#type.ends_with('>') {
&field_info.r#type[7..field_info.r#type.len() - 1]
} else {
&field_info.r#type
}
}

fn get_response_model_parameters(t: &StructInfo) -> HashMap<String, Box<JSONSchemaDefine>> {
let mut properties = HashMap::new();

Expand All @@ -40,7 +56,8 @@ fn get_response_model_parameters(t: &StructInfo) -> HashMap<String, Box<JSONSche
let parameter_name = field_info.name.clone();
let parameter_description = field_info.comment.clone();

let parameter_type = convert_parameter_type(&field_info.r#type.to_string());
let base_type = get_base_type(field_info);
let parameter_type = convert_parameter_type(&base_type.to_string());

properties.insert(
parameter_name,
Expand Down Expand Up @@ -282,10 +299,6 @@ mod tests {
]),
};

if expected_parameters != parameters {
println!("Expected Parameters: {:?}", expected_parameters);
println!("Actual Parameters: {:?}", parameters);
}
assert_eq!(expected_parameters, parameters);
}

Expand Down Expand Up @@ -333,10 +346,115 @@ mod tests {
required: Some(vec!["name".to_string(), "age".to_string()]),
};

if expected_parameters != parameters {
println!("Expected Parameters: {:?}", expected_parameters);
println!("Actual Parameters: {:?}", parameters);
assert_eq!(expected_parameters, parameters);
}

#[test]
fn test_struct_with_optional_field() {
#[derive(InstructMacro, Debug, Serialize, Deserialize)]
struct StructWithOptionalField {
#[description("The name of the user")]
name: String,
#[description("The age of the user")]
age: Option<u8>,
}

let struct_info = StructWithOptionalField::get_info();
let parsed_model: StructInfo = match struct_info {
InstructMacroResult::Struct(info) => info,
_ => {
panic!("Expected StructInfo but got a different InstructMacroResult variant");
}
};
let parameters = get_response_model(parsed_model);

let expected_parameters = chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some({
let mut props = std::collections::HashMap::new();
props.insert(
"name".to_string(),
Box::new(chat_completion::JSONSchemaDefine {
schema_type: Some(chat_completion::JSONSchemaType::String),
description: Some("The name of the user".to_string()),
..Default::default()
}),
);
props.insert(
"age".to_string(),
Box::new(chat_completion::JSONSchemaDefine {
schema_type: Some(chat_completion::JSONSchemaType::Number),
description: Some("The age of the user".to_string()),
..Default::default()
}),
);
props
}),
required: Some(vec!["name".to_string()]), // Only "name" should be required
};

assert_eq!(expected_parameters, parameters);
}

#[test]
fn test_struct_with_nested_optional_field() {
#[derive(InstructMacro, Debug, Serialize, Deserialize)]
struct User {
name: String,
age: u8,
}

#[derive(InstructMacro, Debug, Serialize, Deserialize)]
struct MaybeUser {
user: Option<User>,
}

let struct_info = MaybeUser::get_info();
let parsed_model: StructInfo = match struct_info {
InstructMacroResult::Struct(info) => info,
_ => {
panic!("Expected StructInfo but got a different InstructMacroResult variant");
}
};
let parameters = get_response_model(parsed_model);

let expected_parameters = chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some({
let mut props = std::collections::HashMap::new();
props.insert(
"user".to_string(),
Box::new(chat_completion::JSONSchemaDefine {
schema_type: Some(chat_completion::JSONSchemaType::Object),
description: Some("".to_string()),
properties: Some({
let mut user_props = std::collections::HashMap::new();
user_props.insert(
"name".to_string(),
Box::new(chat_completion::JSONSchemaDefine {
schema_type: Some(chat_completion::JSONSchemaType::String),
description: Some("".to_string()),
..Default::default()
}),
);
user_props.insert(
"age".to_string(),
Box::new(chat_completion::JSONSchemaDefine {
schema_type: Some(chat_completion::JSONSchemaType::Number),
description: Some("".to_string()),
..Default::default()
}),
);
user_props
}),
..Default::default()
}),
);
props
}),
required: Some(vec![]), // No required fields
};

assert_eq!(expected_parameters, parameters);
}
}
19 changes: 14 additions & 5 deletions instructor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use openai_api_rs::v1::{
error::APIError,
};

use instruct_macros_types::{InstructMacro, InstructMacroResult, StructInfo};
use instruct_macros_types::{InstructMacro, InstructMacroResult, Parameter, StructInfo};

pub struct InstructorClient {
client: Client,
Expand Down Expand Up @@ -43,9 +43,12 @@ impl InstructorClient {
name: None,
};
req.messages.push(new_message);

println!("Error encountered: {}", error);
}

let result = self._retry_sync::<T>(req.clone(), parsed_model.clone());

match result {
Ok(value) => {
match T::validate(&value) {
Expand Down Expand Up @@ -85,10 +88,12 @@ impl InstructorClient {
function: chat_completion::Function {
name: parsed_model.name.clone(),
description: Some(parsed_model.description.clone()),
parameters: helpers::get_response_model(parsed_model),
parameters: helpers::get_response_model(parsed_model.clone()),
},
};

let parameters_json = serde_json::to_string(&func_call.function).unwrap();

let req = req
.tools(vec![func_call])
.tool_choice(chat_completion::ToolChoiceType::Auto);
Expand All @@ -104,18 +109,22 @@ impl InstructorClient {
1 => {
let tool_call = &tool_calls[0];
let arguments = tool_call.function.arguments.clone().unwrap();

return serde_json::from_str(&arguments);
}
_ => {
// TODO: Support multiple tool calls at some point
let error_message =
format!("Unexpected number of tool calls: {:?}", tool_calls);
format!("Unexpected number of tool calls: {:?}. PLease only generate a single tool call.", tool_calls);
return Err(serde::de::Error::custom(error_message));
}
}
}
_ => panic!("Unexpected finish reason"),
_ => {
let error_message =
"You must call a tool. Make sure to adhere to the provided response format."
.to_string();
return Err(serde::de::Error::custom(error_message));
}
}
}
}
Expand Down
54 changes: 54 additions & 0 deletions instructor/tests/test_option.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
extern crate instruct_macros;
extern crate instruct_macros_types;

use instruct_macros::InstructMacro;
use instruct_macros_types::{Parameter, ParameterInfo, StructInfo};
use instructor_ai::from_openai;
use openai_api_rs::v1::api::Client;

#[cfg(test)]
mod tests {
use std::env;

use openai_api_rs::v1::{
chat_completion::{self, ChatCompletionRequest},
common::GPT4_O,
};
use serde::{Deserialize, Serialize};

use super::*;

#[test]
fn test_from_openai() {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let instructor_client = from_openai(client);

#[derive(InstructMacro, Debug, Serialize, Deserialize)]
struct UserInfo {
name: String,
age: u8,
}

#[derive(InstructMacro, Debug, Serialize, Deserialize)]
struct MaybeUser {
#[description("This is an optional user field. If the user is not present, the field will be null")]
user: Option<UserInfo>,
}

let req = ChatCompletionRequest::new(
GPT4_O.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: chat_completion::Content::Text(String::from("It's a beautiful day out")),
name: None,
}],
);

let result = instructor_client
.chat_completion::<MaybeUser>(req, 3)
.unwrap();

println!("{:?}", result);
// assert!(result.user.is_none());
}
}

0 comments on commit 55ca8bd

Please sign in to comment.