From 55ca8bd8509bdfed347765953097153d519f9721 Mon Sep 17 00:00:00 2001 From: Ivan Leo Date: Mon, 8 Jul 2024 12:42:39 +0800 Subject: [PATCH] Adding a test and support for option type parsing in instructor.rs' --- instructor/src/helpers/response_model.rs | 142 +++++++++++++++++++++-- instructor/src/lib.rs | 19 ++- instructor/tests/test_option.rs | 54 +++++++++ 3 files changed, 198 insertions(+), 17 deletions(-) create mode 100644 instructor/tests/test_option.rs diff --git a/instructor/src/helpers/response_model.rs b/instructor/src/helpers/response_model.rs index 8359149..4ae42f5 100644 --- a/instructor/src/helpers/response_model.rs +++ b/instructor/src/helpers/response_model.rs @@ -1,20 +1,27 @@ use std::collections::HashMap; -use instruct_macros_types::{Parameter, StructInfo}; +use instruct_macros_types::{Parameter, ParameterInfo, StructInfo}; use openai_api_rs::v1::chat_completion::{self, JSONSchemaDefine}; fn get_required_properties(info: &StructInfo) -> Vec { let mut required = Vec::new(); + for param in info.parameters.iter() { match param { Parameter::Field(field_info) => { - required.push(field_info.name.clone()); + if !field_info.is_optional { + required.push(field_info.name.clone()); + } } Parameter::Struct(struct_info) => { - required.push(struct_info.name.clone()); + if !struct_info.is_optional { + required.push(struct_info.name.clone()); + } } Parameter::Enum(enum_info) => { - required.push(enum_info.title.clone()); + if !enum_info.is_optional { + required.push(enum_info.title.clone()); + } } } } @@ -27,10 +34,19 @@ fn convert_parameter_type(info: &str) -> chat_completion::JSONSchemaType { "u8" | "i8" | "u16" | "i16" | "u32" | "i32" | "u64" | "i64" | "u128" | "i128" | "usize" | "isize" => chat_completion::JSONSchemaType::Number, "bool" => chat_completion::JSONSchemaType::Boolean, + _ => panic!("Unsupported type: {}", info), } } +fn get_base_type(field_info: &ParameterInfo) -> &str { + if field_info.r#type.starts_with("Option<") && field_info.r#type.ends_with('>') { + &field_info.r#type[7..field_info.r#type.len() - 1] + } else { + &field_info.r#type + } +} + fn get_response_model_parameters(t: &StructInfo) -> HashMap> { let mut properties = HashMap::new(); @@ -40,7 +56,8 @@ fn get_response_model_parameters(t: &StructInfo) -> HashMap, + } + + let struct_info = StructWithOptionalField::get_info(); + let parsed_model: StructInfo = match struct_info { + InstructMacroResult::Struct(info) => info, + _ => { + panic!("Expected StructInfo but got a different InstructMacroResult variant"); + } + }; + let parameters = get_response_model(parsed_model); + + let expected_parameters = chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some({ + let mut props = std::collections::HashMap::new(); + props.insert( + "name".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::String), + description: Some("The name of the user".to_string()), + ..Default::default() + }), + ); + props.insert( + "age".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::Number), + description: Some("The age of the user".to_string()), + ..Default::default() + }), + ); + props + }), + required: Some(vec!["name".to_string()]), // Only "name" should be required + }; + + assert_eq!(expected_parameters, parameters); + } + + #[test] + fn test_struct_with_nested_optional_field() { + #[derive(InstructMacro, Debug, Serialize, Deserialize)] + struct User { + name: String, + age: u8, + } + + #[derive(InstructMacro, Debug, Serialize, Deserialize)] + struct MaybeUser { + user: Option, } + + let struct_info = MaybeUser::get_info(); + let parsed_model: StructInfo = match struct_info { + InstructMacroResult::Struct(info) => info, + _ => { + panic!("Expected StructInfo but got a different InstructMacroResult variant"); + } + }; + let parameters = get_response_model(parsed_model); + + let expected_parameters = chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some({ + let mut props = std::collections::HashMap::new(); + props.insert( + "user".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::Object), + description: Some("".to_string()), + properties: Some({ + let mut user_props = std::collections::HashMap::new(); + user_props.insert( + "name".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::String), + description: Some("".to_string()), + ..Default::default() + }), + ); + user_props.insert( + "age".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::Number), + description: Some("".to_string()), + ..Default::default() + }), + ); + user_props + }), + ..Default::default() + }), + ); + props + }), + required: Some(vec![]), // No required fields + }; + assert_eq!(expected_parameters, parameters); } } diff --git a/instructor/src/lib.rs b/instructor/src/lib.rs index a1f3e71..31c09a8 100644 --- a/instructor/src/lib.rs +++ b/instructor/src/lib.rs @@ -6,7 +6,7 @@ use openai_api_rs::v1::{ error::APIError, }; -use instruct_macros_types::{InstructMacro, InstructMacroResult, StructInfo}; +use instruct_macros_types::{InstructMacro, InstructMacroResult, Parameter, StructInfo}; pub struct InstructorClient { client: Client, @@ -43,9 +43,12 @@ impl InstructorClient { name: None, }; req.messages.push(new_message); + + println!("Error encountered: {}", error); } let result = self._retry_sync::(req.clone(), parsed_model.clone()); + match result { Ok(value) => { match T::validate(&value) { @@ -85,10 +88,12 @@ impl InstructorClient { function: chat_completion::Function { name: parsed_model.name.clone(), description: Some(parsed_model.description.clone()), - parameters: helpers::get_response_model(parsed_model), + parameters: helpers::get_response_model(parsed_model.clone()), }, }; + let parameters_json = serde_json::to_string(&func_call.function).unwrap(); + let req = req .tools(vec![func_call]) .tool_choice(chat_completion::ToolChoiceType::Auto); @@ -104,18 +109,22 @@ impl InstructorClient { 1 => { let tool_call = &tool_calls[0]; let arguments = tool_call.function.arguments.clone().unwrap(); - return serde_json::from_str(&arguments); } _ => { // TODO: Support multiple tool calls at some point let error_message = - format!("Unexpected number of tool calls: {:?}", tool_calls); + format!("Unexpected number of tool calls: {:?}. PLease only generate a single tool call.", tool_calls); return Err(serde::de::Error::custom(error_message)); } } } - _ => panic!("Unexpected finish reason"), + _ => { + let error_message = + "You must call a tool. Make sure to adhere to the provided response format." + .to_string(); + return Err(serde::de::Error::custom(error_message)); + } } } } diff --git a/instructor/tests/test_option.rs b/instructor/tests/test_option.rs new file mode 100644 index 0000000..e7ae602 --- /dev/null +++ b/instructor/tests/test_option.rs @@ -0,0 +1,54 @@ +extern crate instruct_macros; +extern crate instruct_macros_types; + +use instruct_macros::InstructMacro; +use instruct_macros_types::{Parameter, ParameterInfo, StructInfo}; +use instructor_ai::from_openai; +use openai_api_rs::v1::api::Client; + +#[cfg(test)] +mod tests { + use std::env; + + use openai_api_rs::v1::{ + chat_completion::{self, ChatCompletionRequest}, + common::GPT4_O, + }; + use serde::{Deserialize, Serialize}; + + use super::*; + + #[test] + fn test_from_openai() { + let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let instructor_client = from_openai(client); + + #[derive(InstructMacro, Debug, Serialize, Deserialize)] + struct UserInfo { + name: String, + age: u8, + } + + #[derive(InstructMacro, Debug, Serialize, Deserialize)] + struct MaybeUser { + #[description("This is an optional user field. If the user is not present, the field will be null")] + user: Option, + } + + let req = ChatCompletionRequest::new( + GPT4_O.to_string(), + vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: chat_completion::Content::Text(String::from("It's a beautiful day out")), + name: None, + }], + ); + + let result = instructor_client + .chat_completion::(req, 3) + .unwrap(); + + println!("{:?}", result); + // assert!(result.user.is_none()); + } +}