diff --git a/src/openai/conversation_handler.rs b/src/openai/conversation_handler.rs
new file mode 100644
index 0000000..bda35a4
--- /dev/null
+++ b/src/openai/conversation_handler.rs
@@ -0,0 +1,246 @@
+use async_openai::{
+    config::OpenAIConfig,
+    types::{
+        ChatCompletionFunctions, ChatCompletionFunctionsArgs, ChatCompletionRequestMessage,
+        ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
+        CreateChatCompletionRequestArgs, FinishReason, FunctionCall, Role,
+    },
+    Client,
+};
+use async_trait::async_trait;
+use futures::StreamExt;
+use schemars::{gen::SchemaSettings, JsonSchema};
+use std::{collections::HashMap, sync::Arc};
+use tracing::info;
+
+use super::OpenAiHistory;
+
+fn get_schema_generator() -> schemars::gen::SchemaGenerator {
+    let settings = SchemaSettings::draft07().with(|s| {
+        s.inline_subschemas = true;
+        s.meta_schema = None;
+    });
+    settings.into_generator()
+}
+
+pub enum OpenAiApiResponse {
+    AssistantResponse(String),
+    FunctionCallWithNoResponse,
+}
+
+#[async_trait]
+pub trait AsyncCallback: Send + Sync {
+    async fn call(&self, args: &str) -> anyhow::Result<serde_json::Value>;
+}
+
+#[derive(Clone)]
+pub struct ChatGptConversation {
+    history: Vec<ChatCompletionRequestMessage>,
+    functions: Vec<ChatCompletionFunctions>,
+    temperature: Option<f32>,
+    top_p: Option<f32>,
+    model_name: String,
+    function_table: HashMap<String, Arc<dyn AsyncCallback>>,
+}
+
+impl ChatGptConversation {
+    pub fn new(system_prompt: &str, model_name: &str) -> Self {
+        let history = vec![ChatCompletionRequestMessageArgs::default()
+            .content(system_prompt)
+            .role(Role::System)
+            .build()
+            // build() only fails if a required field is missing, so this should not panic
+            .expect("Failed to build system prompt message")];
+        Self {
+            history,
+            functions: vec![],
+            temperature: None,
+            top_p: None,
+            model_name: model_name.to_string(),
+            function_table: HashMap::new(),
+        }
+    }
+
+    pub fn add_function<T: JsonSchema>(
+        &mut self,
+        function_name: &str,
+        function_description: &str,
+        func: Arc<dyn AsyncCallback>,
+    ) -> anyhow::Result<()> {
+        let schema = get_schema_generator().into_root_schema_for::<T>();
+        let schema_json = serde_json::to_value(&schema)?;
+        let new_function = ChatCompletionFunctionsArgs::default()
+            .name(function_name)
+            .description(function_description)
+            .parameters(schema_json)
+            .build()?;
+
+        self.functions.push(new_function);
+
+        self.function_table.insert(function_name.to_string(), func);
+        Ok(())
+    }
+
+    async fn call_function(&self, name: &str, args: &str) -> anyhow::Result<serde_json::Value> {
+        info!("Calling function {:?} with args {:?}", name, args);
+        let function = self
+            .function_table
+            .get(name)
+            .ok_or_else(|| anyhow::anyhow!("Function {} not found", name))?;
+        function.call(args).await
+    }
+
+    /// Build the chat completion request from the current conversation state.
+    fn build_request_message(&self) -> anyhow::Result<CreateChatCompletionRequest> {
+        // the request is built in steps because temperature and top_p are optional
+        let mut request_builder = CreateChatCompletionRequestArgs::default();
+
+        request_builder
+            .model(self.model_name.clone())
+            .messages(self.history.clone())
+            .functions(self.functions.clone())
+            .function_call("auto");
+
+        if let Some(temperature) = self.temperature {
+            request_builder.temperature(temperature);
+        }
+
+        if let Some(top_p) = self.top_p {
+            request_builder.top_p(top_p);
+        }
+
+        Ok(request_builder.build()?)
+    }
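+
+    /// Stream the next assistant message.
+    ///
+    /// Sends the accumulated history to the chat completion endpoint and
+    /// consumes the response stream, buffering content chunks and any
+    /// function call name/arguments until a finish reason arrives.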
+    pub async fn next_message_stream(
+        &mut self,
+        message_text: Option<&str>,
+        client: &Client<OpenAIConfig>,
+    ) -> anyhow::Result<OpenAiApiResponse> {
+        if let Some(message_text) = message_text {
+            let user_message = ChatCompletionRequestMessageArgs::default()
+                .content(message_text)
+                .role(Role::User)
+                .build()?;
+
+            self.history.push(user_message);
+        }
+
+        let request = self.build_request_message()?;
+
+        let mut stream = client.chat().create_stream(request).await?;
+
+        let mut response_role = None;
+        let mut response_content_buffer = String::new();
+        let mut fn_name = String::new();
+        let mut fn_args = String::new();
+
+        // The streaming response has a different shape than the non-streaming
+        // one and omits some fields; this is not documented in the OpenAI
+        // docs / OpenAPI spec.
+        while let Some(result) = stream.next().await {
+            let response = result?;
+
+            // we expect exactly one choice per chunk
+            if response.choices.len() != 1 {
+                return Err(anyhow::anyhow!(
+                    "expected 1 choice from OpenAI, got {}",
+                    response.choices.len()
+                ));
+            }
+            let choice = response
+                .choices
+                .first()
+                .expect("Failed to get first choice from response");
+
+            // take response role
+            if let Some(role) = choice.delta.role {
+                response_role = Some(role);
+            }
+
+            // accumulate the function call name and arguments as they stream in
+            if let Some(fn_call) = &choice.delta.function_call {
+                if let Some(name) = &fn_call.name {
+                    fn_name = name.clone();
+                }
+                if let Some(args) = &fn_call.arguments {
+                    fn_args.push_str(args);
+                }
+            }
+
+            // accumulate response content
+            if let Some(delta_content) = &choice.delta.content {
+                response_content_buffer.push_str(delta_content);
+                // chunks could also be forwarded here for incremental output
+            }
+
+            // check whether the response has finished
+            if let Some(finish_reason) = &choice.finish_reason {
+                // figure out why the conversation ended
+                if matches!(finish_reason, FinishReason::FunctionCall) {
+                    // function call
+
+                    // add function call to history
+                    let function_call_request = ChatCompletionRequestMessageArgs::default()
+                        .role(Role::Assistant)
+                        .function_call(FunctionCall {
+                            name: fn_name.clone(),
+                            arguments: fn_args.clone(),
+                        })
+                        .build()?;
+                    self.history.push(function_call_request);
+
+                    // call function
+                    let result = self.call_function(&fn_name, &fn_args).await?;
+
+                    // add function call result to history
+                    let function_call_result = ChatCompletionRequestMessageArgs::default()
+                        .role(Role::Function)
+                        .content(result.to_string())
+                        .name(fn_name.clone())
+                        .build()?;
+                    self.history.push(function_call_result);
+
+                    if !response_content_buffer.is_empty() {
+                        // a function call can also be accompanied by assistant text
+
+                        let added_response = ChatCompletionRequestMessageArgs::default()
+                            .content(&response_content_buffer)
+                            .role(response_role.unwrap_or(Role::Assistant))
+                            .build()?;
+
+                        self.history.push(added_response);
+                        return Ok(OpenAiApiResponse::AssistantResponse(
+                            response_content_buffer,
+                        ));
+                    } else {
+                        return Ok(OpenAiApiResponse::FunctionCallWithNoResponse);
+                    }
+                } else {
+                    // any other finish reason: add the assistant message to history
+                    let added_response = ChatCompletionRequestMessageArgs::default()
+                        .content(&response_content_buffer)
+                        .role(response_role.unwrap_or(Role::Assistant))
+                        .build()?;
+
+                    self.history.push(added_response);
+                    return Ok(OpenAiApiResponse::AssistantResponse(
+                        response_content_buffer,
+                    ));
+                }
+            }
+        }
+
+        // return the accumulated text even if the stream ended without a finish reason
+        Ok(OpenAiApiResponse::AssistantResponse(
+            response_content_buffer,
+        ))
+    }
+
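+    /// Serialize the conversation history, stamped with the current time,
+    /// into pretty-printed JSON for diagnostics.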
+    pub fn get_history(&self) -> String {
+        let history = OpenAiHistory {
+            history: self.history.clone(),
+            timestamp: chrono::Utc::now(),
+        };
+        serde_json::to_string_pretty(&history).expect("Failed to serialize chat history")
+    }
+}
diff --git a/src/openai/functions.rs b/src/openai/functions.rs
index 26bc991..1970f76 100644
--- a/src/openai/functions.rs
+++ b/src/openai/functions.rs
@@ -14,7 +14,7 @@ use crate::{
     zenoh_consts::STANCE_SUBSCRIBER,
 };
 
-use super::AsyncCallback;
+use super::conversation_handler::AsyncCallback;
 
 #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
 pub struct HopperBodyPoseFuncArgs {
diff --git a/src/openai/mod.rs b/src/openai/mod.rs
index 7a26529..c507e5f 100644
--- a/src/openai/mod.rs
+++ b/src/openai/mod.rs
@@ -1,22 +1,9 @@
+mod conversation_handler;
 mod functions;
 
-use async_openai::{
-    config::OpenAIConfig,
-    types::{
-        ChatCompletionFunctions, ChatCompletionFunctionsArgs, ChatCompletionRequestMessage,
-        ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
-        CreateChatCompletionRequestArgs, FinishReason, FunctionCall, Role,
-    },
-    Client,
-};
-use async_trait::async_trait;
-use futures::StreamExt;
-use schemars::{gen::SchemaSettings, JsonSchema};
+use async_openai::{config::OpenAIConfig, types::ChatCompletionRequestMessage, Client};
 use serde::{Deserialize, Serialize};
-use std::{
-    collections::HashMap,
-    sync::{atomic::AtomicU8, Arc, Mutex},
-};
+use std::sync::{atomic::AtomicU8, Arc, Mutex};
 use tokio::select;
 use tracing::info;
 use zenoh::prelude::r#async::*;
@@ -25,11 +12,12 @@ use crate::{
     error::HopperError,
     face::animations::Animation,
     ioc_container::IocContainer,
+    openai::conversation_handler::OpenAiApiResponse,
     speech::SpeechService,
     zenoh_consts::{HOPPER_OPENAI_COMMAND_SUBSCRIBER, OPENAI_DIAGNOSTICS_HISTORY},
 };
 
-use self::functions::*;
+use self::{conversation_handler::ChatGptConversation, functions::*};
 
 // cheap but dumb
 // const MODEL_NAME: &str = "gpt-3.5-turbo-0613";
@@ -301,238 +289,8 @@ async fn speak_with_face_animation(
     Ok(())
 }
 
-fn get_schema_generator() -> schemars::gen::SchemaGenerator {
-    let settings = SchemaSettings::draft07().with(|s| {
-        s.inline_subschemas = true;
-        s.meta_schema = None;
-    });
-    settings.into_generator()
-}
-
-pub enum OpenAiApiResponse {
-    AssistantResponse(String),
-    FunctionCallWithNoResponse,
-}
-
-#[async_trait]
-pub trait AsyncCallback: Send + Sync {
-    async fn call(&self, args: &str) -> anyhow::Result<serde_json::Value>;
-}
-
-#[derive(Clone)]
-pub struct ChatGptConversation {
-    history: Vec<ChatCompletionRequestMessage>,
-    functions: Vec<ChatCompletionFunctions>,
-    temperature: Option<f32>,
-    top_p: Option<f32>,
-    model_name: String,
-    function_table: HashMap<String, Arc<dyn AsyncCallback>>,
-}
-
-impl ChatGptConversation {
-    pub fn new(system_prompt: &str, model_name: &str) -> Self {
-        let history = vec![ChatCompletionRequestMessageArgs::default()
-            .content(system_prompt)
-            .role(Role::System)
-            .build()
-            // can this fail?
- .expect("Failed to build system prompt message")]; - Self { - history, - functions: vec![], - temperature: None, - top_p: None, - model_name: model_name.to_string(), - function_table: HashMap::new(), - } - } - - pub fn add_function( - &mut self, - function_name: &str, - function_description: &str, - func: Arc, - ) -> anyhow::Result<()> { - let schema = get_schema_generator().into_root_schema_for::(); - let schema_json = serde_json::to_value(&schema)?; - let new_function = ChatCompletionFunctionsArgs::default() - .name(function_name) - .description(function_description) - .parameters(schema_json) - .build()?; - - self.functions.push(new_function); - - self.function_table.insert(function_name.to_string(), func); - Ok(()) - } - - async fn call_function(&self, name: &str, args: &str) -> anyhow::Result { - info!("Calling function {:?} with args {:?}", name, args); - let function = self - .function_table - .get(name) - .ok_or_else(|| anyhow::anyhow!("Function {} not found", name))?; - function.call(args).await - } - - /// build request message - fn build_request_message(&self) -> anyhow::Result { - // request builder setup is a bit more complicated because of the optional parameters - let mut request_builder = CreateChatCompletionRequestArgs::default(); - - request_builder - .model(self.model_name.clone()) - .messages(self.history.clone()) - .functions(self.functions.clone()) - .function_call("auto"); - - if let Some(temperature) = self.temperature { - request_builder.temperature(temperature); - } - - if let Some(top_p) = self.top_p { - request_builder.top_p(top_p); - } - - Ok(request_builder.build()?) - } - - /// stream next message - pub async fn next_message_stream( - &mut self, - message_text: Option<&str>, - client: &Client, - ) -> anyhow::Result { - if let Some(message_text) = message_text { - let user_message = ChatCompletionRequestMessageArgs::default() - .content(message_text) - .role(Role::User) - .build()?; - - self.history.push(user_message); - } - - let request = self.build_request_message()?; - - let mut stream = client.chat().create_stream(request).await?; - - let mut response_role = None; - let mut response_content_buffer = String::new(); - let mut fn_name = String::new(); - let mut fn_args = String::new(); - - // For reasons not documented in OpenAI docs / OpenAPI spec, the response of streaming call is different and doesn't include all the same fields. - while let Some(result) = stream.next().await { - let response = result?; - - // assert that we only get one response - if response.choices.len() != 1 { - return Err(anyhow::anyhow!( - "expected 1 response from OpenAI, got {}", - response.choices.len() - )); - } - let choice = response - .choices - .first() - .expect("Failed to get first choice from response"); - - // take response role - if let Some(role) = choice.delta.role { - response_role = Some(role); - } - - // take function call - if let Some(fn_call) = &choice.delta.function_call { - if let Some(name) = &fn_call.name { - fn_name = name.clone(); - } - if let Some(args) = &fn_call.arguments { - fn_args.push_str(args); - } - } - - // take response content - if let Some(delta_content) = &choice.delta.content { - response_content_buffer.push_str(delta_content); - // process chunk (print it?) 
-            }
-
-            // check if response is end
-            if let Some(finish_reason) = &choice.finish_reason {
-                // figure out why the conversation ended
-                if matches!(finish_reason, FinishReason::FunctionCall) {
-                    // function call
-
-                    // add function call to history
-                    let function_call_request = ChatCompletionRequestMessageArgs::default()
-                        .role(Role::Assistant)
-                        .function_call(FunctionCall {
-                            name: fn_name.clone(),
-                            arguments: fn_args.clone(),
-                        })
-                        .build()?;
-                    self.history.push(function_call_request);
-
-                    // call function
-                    let result = self.call_function(&fn_name, &fn_args).await?;
-
-                    // add function call result to history
-                    let function_call_result = ChatCompletionRequestMessageArgs::default()
-                        .role(Role::Function)
-                        .content(result.to_string())
-                        .name(fn_name.clone())
-                        .build()?;
-                    self.history.push(function_call_result);
-
-                    if !response_content_buffer.is_empty() {
-                        // function calls can also include a response
-
-                        let added_response = ChatCompletionRequestMessageArgs::default()
-                            .content(&response_content_buffer)
-                            .role(response_role.unwrap_or(Role::Assistant))
-                            .build()?;
-
-                        self.history.push(added_response);
-                        return Ok(OpenAiApiResponse::AssistantResponse(
-                            response_content_buffer,
-                        ));
-                    } else {
-                        return Ok(OpenAiApiResponse::FunctionCallWithNoResponse);
-                    }
-                } else {
-                    // other reasons ass message from assistant
-                    let added_response = ChatCompletionRequestMessageArgs::default()
-                        .content(&response_content_buffer)
-                        .role(response_role.unwrap_or(Role::Assistant))
-                        .build()?;
-
-                    self.history.push(added_response);
-                    return Ok(OpenAiApiResponse::AssistantResponse(
-                        response_content_buffer,
-                    ));
-                }
-            }
-        }
-
-        // return text anyway even if we don't get an end reason
-        Ok(OpenAiApiResponse::AssistantResponse(
-            response_content_buffer,
-        ))
-    }
-
-    pub fn get_history(&self) -> String {
-        let history = OpenAiHistory {
-            history: self.history.clone(),
-            timestamp: chrono::Utc::now(),
-        };
-        serde_json::to_string_pretty(&history).expect("Failed to serialize chat history")
-    }
-}
-
 #[derive(Serialize, Deserialize, Debug)]
-struct OpenAiHistory {
+pub struct OpenAiHistory {
     history: Vec<ChatCompletionRequestMessage>,
     timestamp: chrono::DateTime<chrono::Utc>,
 }
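
For context, a minimal sketch of how the extracted ChatGptConversation is meant to be driven. The NoOpCallback, NoArgs, and demo names are illustrative placeholders, not part of this diff; it assumes the conversation_handler items are in scope and uses the "gpt-3.5-turbo-0613" model name mentioned in mod.rs.

    use std::sync::Arc;
    use async_openai::{config::OpenAIConfig, Client};

    // A do-nothing callback satisfying the AsyncCallback trait.
    struct NoOpCallback;

    #[async_trait::async_trait]
    impl AsyncCallback for NoOpCallback {
        async fn call(&self, _args: &str) -> anyhow::Result<serde_json::Value> {
            Ok(serde_json::json!({ "ok": true }))
        }
    }

    // Empty argument schema; schemars derives the JSON schema sent to OpenAI.
    #[derive(schemars::JsonSchema)]
    struct NoArgs {}

    async fn demo(client: &Client<OpenAIConfig>) -> anyhow::Result<()> {
        let mut conversation =
            ChatGptConversation::new("You are a hexapod robot", "gpt-3.5-turbo-0613");
        conversation.add_function::<NoArgs>("noop", "does nothing", Arc::new(NoOpCallback))?;

        // Each call streams one assistant turn; function calls are executed
        // and appended to the history internally.
        match conversation.next_message_stream(Some("hello"), client).await? {
            OpenAiApiResponse::AssistantResponse(text) => println!("{}", text),
            OpenAiApiResponse::FunctionCallWithNoResponse => {}
        }
        Ok(())
    }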