From 0c38d77c99741a2d5501f7d3bad921c492d484bb Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Tue, 13 Aug 2024 16:54:01 -0700 Subject: [PATCH] fix(answer): properly handle user input attachment (#2865) * fix(answer): properly handle user input attachment * update * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- Cargo.lock | 1 + ee/tabby-db/src/threads.rs | 5 +- ee/tabby-schema/graphql/schema.graphql | 6 +- ee/tabby-schema/src/dao.rs | 6 + ee/tabby-schema/src/schema/mod.rs | 18 +- ee/tabby-schema/src/schema/thread.rs | 1 + ee/tabby-schema/src/schema/thread/inputs.rs | 5 +- ee/tabby-schema/src/schema/thread/types.rs | 6 +- ee/tabby-webserver/Cargo.toml | 1 + ee/tabby-webserver/src/service/answer.rs | 156 +++++++++++++----- ...t_messages_to_chat_completion_request.snap | 10 ++ ee/tabby-webserver/src/service/thread.rs | 62 +++---- 12 files changed, 187 insertions(+), 90 deletions(-) create mode 100644 ee/tabby-webserver/src/service/snapshots/tabby_webserver__service__answer__tests__convert_messages_to_chat_completion_request.snap diff --git a/Cargo.lock b/Cargo.lock index 6528a8b65b8e..0e1a7102a414 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5239,6 +5239,7 @@ dependencies = [ "futures", "gitlab", "hyper 1.3.1", + "insta", "jsonwebtoken", "juniper", "juniper_axum", diff --git a/ee/tabby-db/src/threads.rs b/ee/tabby-db/src/threads.rs index 0e747607917e..7466b903d026 100644 --- a/ee/tabby-db/src/threads.rs +++ b/ee/tabby-db/src/threads.rs @@ -39,8 +39,11 @@ pub struct ThreadMessageAttachmentDoc { #[derive(Serialize, Deserialize)] pub struct ThreadMessageAttachmentCode { - pub filepath: Option, + pub git_url: String, + pub language: String, + pub filepath: String, pub content: String, + pub start_line: usize, } impl DbConn { diff --git a/ee/tabby-schema/graphql/schema.graphql b/ee/tabby-schema/graphql/schema.graphql index 61eea5535a95..691513bca27a 100644 --- a/ee/tabby-schema/graphql/schema.graphql +++ b/ee/tabby-schema/graphql/schema.graphql @@ -396,10 +396,12 @@ type MessageAttachment { doc: [MessageAttachmentDoc!]! } -"If you want to change the struct, please make sure the change is backward compatible." type MessageAttachmentCode { - filepath: String + gitUrl: String! + filepath: String! + language: String! content: String! + startLine: Int! } type MessageAttachmentDoc { diff --git a/ee/tabby-schema/src/dao.rs b/ee/tabby-schema/src/dao.rs index 4aa494352564..b53b677699b5 100644 --- a/ee/tabby-schema/src/dao.rs +++ b/ee/tabby-schema/src/dao.rs @@ -233,8 +233,11 @@ impl TryFrom for UserEvent { impl From for thread::MessageAttachmentCode { fn from(value: ThreadMessageAttachmentCode) -> Self { Self { + git_url: value.git_url, filepath: value.filepath, + language: value.language, content: value.content, + start_line: value.start_line as i32, } } } @@ -242,8 +245,11 @@ impl From for thread::MessageAttachmentCode { impl From<&thread::MessageAttachmentCode> for ThreadMessageAttachmentCode { fn from(val: &thread::MessageAttachmentCode) -> Self { ThreadMessageAttachmentCode { + git_url: val.git_url.clone(), filepath: val.filepath.clone(), + language: val.language.clone(), content: val.content.clone(), + start_line: val.start_line as usize, } } } diff --git a/ee/tabby-schema/src/schema/mod.rs b/ee/tabby-schema/src/schema/mod.rs index 583f8c8f5bad..6d396d14be8c 100644 --- a/ee/tabby-schema/src/schema/mod.rs +++ b/ee/tabby-schema/src/schema/mod.rs @@ -1000,7 +1000,13 @@ impl Subscription { let thread_id = thread.create(&user.id, &input.thread).await?; thread - .create_run(&thread_id, &input.options, true, true) + .create_run( + &thread_id, + &input.options, + input.thread.user_message.attachments.as_ref(), + true, + true, + ) .await } @@ -1029,8 +1035,14 @@ impl Subscription { svc.append_user_message(&input.thread_id, &input.additional_user_message) .await?; - svc.create_run(&input.thread_id, &input.options, true, false) - .await + svc.create_run( + &input.thread_id, + &input.options, + input.additional_user_message.attachments.as_ref(), + true, + false, + ) + .await } } diff --git a/ee/tabby-schema/src/schema/thread.rs b/ee/tabby-schema/src/schema/thread.rs index 7cc227cb58cf..e91744625d5c 100644 --- a/ee/tabby-schema/src/schema/thread.rs +++ b/ee/tabby-schema/src/schema/thread.rs @@ -35,6 +35,7 @@ pub trait ThreadService: Send + Sync { &self, id: &ID, options: &ThreadRunOptionsInput, + attachment_input: Option<&MessageAttachmentInput>, yield_last_user_message: bool, yield_thread_created: bool, ) -> Result; diff --git a/ee/tabby-schema/src/schema/thread/inputs.rs b/ee/tabby-schema/src/schema/thread/inputs.rs index 39721e0aa5d5..02b141be5db7 100644 --- a/ee/tabby-schema/src/schema/thread/inputs.rs +++ b/ee/tabby-schema/src/schema/thread/inputs.rs @@ -97,14 +97,13 @@ pub struct CreateThreadRunInput { pub options: ThreadRunOptionsInput, } -#[derive(GraphQLInputObject)] +#[derive(GraphQLInputObject, Clone)] pub struct MessageAttachmentInput { pub code: Vec, } -#[derive(GraphQLInputObject)] +#[derive(GraphQLInputObject, Clone)] pub struct MessageAttachmentCodeInput { pub filepath: Option, - pub content: String, } diff --git a/ee/tabby-schema/src/schema/thread/types.rs b/ee/tabby-schema/src/schema/thread/types.rs index 9455046bf518..143b1fc87a8e 100644 --- a/ee/tabby-schema/src/schema/thread/types.rs +++ b/ee/tabby-schema/src/schema/thread/types.rs @@ -46,11 +46,13 @@ pub struct MessageAttachment { pub doc: Vec, } -/// If you want to change the struct, please make sure the change is backward compatible. #[derive(GraphQLObject, Clone)] pub struct MessageAttachmentCode { - pub filepath: Option, + pub git_url: String, + pub filepath: String, + pub language: String, pub content: String, + pub start_line: i32, } #[derive(GraphQLObject, Clone)] diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 5937f2dac44e..3c1108be3d22 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -63,3 +63,4 @@ tabby-db = { path = "../../ee/tabby-db", features = ["testutils"] } tabby-common = { path = "../../crates/tabby-common", features = ["testutils"] } serial_test = { workspace = true } temp_testdir = { workspace = true } +insta = { workspace = true, features = ["yaml", "redactions"] } diff --git a/ee/tabby-webserver/src/service/answer.rs b/ee/tabby-webserver/src/service/answer.rs index 19a5c3ac6eb8..49663f787d76 100644 --- a/ee/tabby-webserver/src/service/answer.rs +++ b/ee/tabby-webserver/src/service/answer.rs @@ -23,7 +23,8 @@ use tabby_schema::{ repository::RepositoryService, thread::{ self, CodeQueryInput, CodeSearchParamsOverrideInput, DocQueryInput, MessageAttachment, - MessageAttachmentCode, MessageAttachmentDoc, ThreadRunItem, ThreadRunOptionsInput, + MessageAttachmentCode, MessageAttachmentCodeInput, MessageAttachmentDoc, ThreadRunItem, + ThreadRunOptionsInput, }, web_crawler::WebCrawlerService, }; @@ -139,7 +140,7 @@ impl AnswerService { yield AnswerResponseChunk::RelevantQuestions(relevant_questions); } - let code_snippets: Vec = code_snippets.iter().map(|x| MessageAttachmentCode { + let code_snippets: Vec = code_snippets.iter().map(|x| MessageAttachmentCodeInput { filepath: x.filepath.clone(), content: x.content.clone(), }).collect(); @@ -190,9 +191,11 @@ impl AnswerService { self: Arc, messages: &[tabby_schema::thread::Message], options: &ThreadRunOptionsInput, + user_attachment_input: Option<&tabby_schema::thread::MessageAttachmentInput>, ) -> tabby_schema::Result>> { let messages = messages.to_vec(); let options = options.clone(); + let user_attachment_input = user_attachment_input.cloned(); let s = stream! { let query = match messages.last() { @@ -209,9 +212,12 @@ impl AnswerService { // 1. Collect relevant code if needed. if let Some(code_query) = options.code_query.as_ref() { attachment.code = self.collect_relevant_code(code_query, &self.config.code_search_params, options.debug_options.as_ref().and_then(|x| x.code_search_params_override.as_ref())).await.iter() - .map(|x| MessageAttachmentCode{ - filepath: Some(x.doc.filepath.clone()), + .map(|x| MessageAttachmentCode { + git_url: x.doc.git_url.clone(), + filepath: x.doc.filepath.clone(), + language: x.doc.language.clone(), content: x.doc.body.clone(), + start_line: x.doc.start_line as i32, }) .collect::>(); }; @@ -248,7 +254,7 @@ impl AnswerService { // 4. Prepare requesting LLM let request = { - let chat_messages = convert_messages_to_chat_completion_request(&messages, &attachment)?; + let chat_messages = convert_messages_to_chat_completion_request(&messages, &attachment, user_attachment_input.as_ref())?; CreateChatCompletionRequestArgs::default() .messages(chat_messages) @@ -398,11 +404,10 @@ impl AnswerService { .code .iter() .map(|snippet| { - if let Some(filepath) = &snippet.filepath { - format!("``` title=\"{}\"\n{}\n```", filepath, snippet.content) - } else { - format!("```\n{}\n```", snippet.content) - } + format!( + "```{} title=\"{}\"\n{}\n```", + snippet.language, snippet.filepath, snippet.content + ) }) .chain( attachment @@ -531,7 +536,7 @@ Remember, based on the original question and related contexts, suggest three suc async fn generate_prompt( &self, - code_snippets: &[MessageAttachmentCode], + code_snippets: &[MessageAttachmentCodeInput], relevant_code: &[CodeSearchHit], relevant_docs: &[DocSearchHit], question: &str, @@ -628,6 +633,7 @@ fn set_content(message: &mut ChatCompletionRequestMessage, content: String) { fn convert_messages_to_chat_completion_request( messages: &[tabby_schema::thread::Message], attachment: &tabby_schema::thread::MessageAttachment, + user_attachment_input: Option<&tabby_schema::thread::MessageAttachmentInput>, ) -> anyhow::Result> { let mut output = vec![]; output.reserve(messages.len()); @@ -646,7 +652,7 @@ fn convert_messages_to_chat_completion_request( let y = &messages[i + 1]; - build_user_prompt(x, &y.attachment) + build_user_prompt(&x.content, &y.attachment, None) } else { x.content.clone() }; @@ -662,7 +668,11 @@ fn convert_messages_to_chat_completion_request( output.push(ChatCompletionRequestMessage::System( ChatCompletionRequestSystemMessage { - content: build_user_prompt(&messages[messages.len() - 1], attachment), + content: build_user_prompt( + &messages[messages.len() - 1].content, + attachment, + user_attachment_input, + ), role: Role::User, name: None, }, @@ -672,41 +682,43 @@ fn convert_messages_to_chat_completion_request( } fn build_user_prompt( - user_message: &tabby_schema::thread::Message, + user_input: &str, assistant_attachment: &tabby_schema::thread::MessageAttachment, + user_attachment_input: Option<&tabby_schema::thread::MessageAttachmentInput>, ) -> String { // If the user message has no code attachment and the assistant message has no code attachment or doc attachment, return the user message directly. - if user_message.attachment.code.is_empty() + if user_attachment_input + .map(|x| x.code.is_empty()) + .unwrap_or(true) && assistant_attachment.code.is_empty() && assistant_attachment.doc.is_empty() { - return user_message.content.clone(); + return user_input.to_owned(); } - let snippets: Vec = user_message - .attachment - .code + let snippets: Vec = assistant_attachment + .doc .iter() - .map(|snippet| { - if let Some(filepath) = &snippet.filepath { - format!("```title=\"{}\"\n{}\n```", filepath, snippet.content) - } else { - format!("```\n{}\n```", snippet.content) - } - }) - .chain(assistant_attachment.code.iter().map(|snippet| { - if let Some(filepath) = &snippet.filepath { - format!("```title=\"{}\"\n{}\n```", filepath, snippet.content) - } else { - format!("```\n{}\n```", snippet.content) - } - })) + .map(|doc| format!("```\n{}\n```", doc.content)) .chain( - assistant_attachment - .doc + user_attachment_input + .map(|x| &x.code) + .unwrap_or(&vec![]) .iter() - .map(|doc| format!("```\n{}\n```", doc.content)), + .map(|snippet| { + if let Some(filepath) = &snippet.filepath { + format!("```title=\"{}\"\n{}\n```", filepath, snippet.content) + } else { + format!("```\n{}\n```", snippet.content) + } + }), ) + .chain(assistant_attachment.code.iter().map(|snippet| { + format!( + "```{} title=\"{}\"\n{}\n```", + snippet.language, snippet.filepath, snippet.content + ) + })) .collect(); let citations: Vec = snippets @@ -716,7 +728,6 @@ fn build_user_prompt( .collect(); let context = citations.join("\n\n"); - let question = &user_message.content; format!( r#" @@ -732,7 +743,76 @@ Here are the set of contexts: Remember, don't blindly repeat the contexts verbatim. When possible, give code snippet to demonstrate the answer. And here is the user question: -{question} +{user_input} "# ) } + +#[cfg(test)] +mod tests { + use juniper::ID; + use tabby_schema::AsID; + + fn make_message( + id: i32, + content: &str, + role: tabby_schema::thread::Role, + attachment: Option, + ) -> tabby_schema::thread::Message { + tabby_schema::thread::Message { + id: id.as_id(), + thread_id: ID::new("0"), + content: content.to_owned(), + role, + attachment: attachment.unwrap_or_default(), + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + } + } + + #[test] + fn test_convert_messages_to_chat_completion_request() { + // Fake assistant attachment + let attachment = tabby_schema::thread::MessageAttachment { + doc: vec![tabby_schema::thread::MessageAttachmentDoc { + title: "1. Example Document".to_owned(), + content: "This is an example".to_owned(), + link: "https://example.com".to_owned(), + }], + code: vec![tabby_schema::thread::MessageAttachmentCode { + git_url: "https://github.com".to_owned(), + filepath: "server.py".to_owned(), + language: "python".to_owned(), + content: "print('Hello, server!')".to_owned(), + start_line: 1, + }], + }; + + let messages = vec![ + make_message(1, "Hello", tabby_schema::thread::Role::User, None), + make_message( + 2, + "Hi", + tabby_schema::thread::Role::Assistant, + Some(attachment), + ), + make_message(3, "How are you?", tabby_schema::thread::Role::User, None), + ]; + + let user_attachment_input = tabby_schema::thread::MessageAttachmentInput { + code: vec![tabby_schema::thread::MessageAttachmentCodeInput { + filepath: Some("client.py".to_owned()), + content: "print('Hello, client!')".to_owned(), + }], + }; + + let output = super::convert_messages_to_chat_completion_request( + &messages, + &tabby_schema::thread::MessageAttachment::default(), + Some(&user_attachment_input), + ) + .unwrap(); + + insta::assert_yaml_snapshot!(output); + } +} diff --git a/ee/tabby-webserver/src/service/snapshots/tabby_webserver__service__answer__tests__convert_messages_to_chat_completion_request.snap b/ee/tabby-webserver/src/service/snapshots/tabby_webserver__service__answer__tests__convert_messages_to_chat_completion_request.snap new file mode 100644 index 000000000000..23ad8f67903c --- /dev/null +++ b/ee/tabby-webserver/src/service/snapshots/tabby_webserver__service__answer__tests__convert_messages_to_chat_completion_request.snap @@ -0,0 +1,10 @@ +--- +source: ee/tabby-webserver/src/service/answer.rs +expression: output +--- +- content: "\nYou are a professional developer AI assistant. You are given a user question, and please write clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context and cite the context at the end of each sentence if applicable.\n\nYour answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say \"information is missing on\" followed by the related topic, if the given context do not provide sufficient information.\n\nPlease cite the contexts with the reference numbers, in the format [citation:x]. If a sentence comes from multiple contexts, please list all applicable citations, like [citation:3][citation:5]. Other than code and specific names and citations, your answer must be written in the same language as the question.\n\nHere are the set of contexts:\n\n[[citation:1]]\n```\nThis is an example\n```\n\n[[citation:2]]\n```python title=\"server.py\"\nprint('Hello, server!')\n```\n\nRemember, don't blindly repeat the contexts verbatim. When possible, give code snippet to demonstrate the answer. And here is the user question:\n\nHello\n" + role: user +- content: Hi + role: assistant +- content: "\nYou are a professional developer AI assistant. You are given a user question, and please write clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context and cite the context at the end of each sentence if applicable.\n\nYour answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say \"information is missing on\" followed by the related topic, if the given context do not provide sufficient information.\n\nPlease cite the contexts with the reference numbers, in the format [citation:x]. If a sentence comes from multiple contexts, please list all applicable citations, like [citation:3][citation:5]. Other than code and specific names and citations, your answer must be written in the same language as the question.\n\nHere are the set of contexts:\n\n[[citation:1]]\n```title=\"client.py\"\nprint('Hello, client!')\n```\n\nRemember, don't blindly repeat the contexts verbatim. When possible, give code snippet to demonstrate the answer. And here is the user question:\n\nHow are you?\n" + role: user diff --git a/ee/tabby-webserver/src/service/thread.rs b/ee/tabby-webserver/src/service/thread.rs index 7c6b853ec8f7..5f939bfd77db 100644 --- a/ee/tabby-webserver/src/service/thread.rs +++ b/ee/tabby-webserver/src/service/thread.rs @@ -3,16 +3,15 @@ use std::sync::Arc; use async_trait::async_trait; use futures::StreamExt; use juniper::ID; -use tabby_db::{DbConn, ThreadMessageAttachmentCode, ThreadMessageDAO}; +use tabby_db::{DbConn, ThreadMessageDAO}; use tabby_schema::{ bail, thread::{ - self, CreateMessageInput, CreateThreadInput, ThreadRunItem, ThreadRunOptionsInput, - ThreadRunStream, ThreadService, + self, CreateMessageInput, CreateThreadInput, MessageAttachmentInput, ThreadRunItem, + ThreadRunOptionsInput, ThreadRunStream, ThreadService, }, AsID, AsRowid, DbEnum, Result, }; -use tracing::error; use super::{answer::AnswerService, graphql_pagination_to_filter}; @@ -27,7 +26,7 @@ impl ThreadServiceImpl { .db .list_thread_messages(thread_id.as_rowid()?, None, None, false) .await?; - Ok(to_vec_messages(messages)) + to_vec_messages(messages) } } @@ -52,6 +51,7 @@ impl ThreadService for ThreadServiceImpl { &self, thread_id: &ID, options: &ThreadRunOptionsInput, + attachment_input: Option<&MessageAttachmentInput>, yield_last_user_message: bool, yield_thread_created: bool, ) -> Result { @@ -59,18 +59,7 @@ impl ThreadService for ThreadServiceImpl { bail!("Answer service is not available"); }; - let messages: Vec = self - .get_thread_messages(thread_id) - .await? - .into_iter() - .flat_map(|x| match x.try_into() { - Ok(x) => Some(x), - Err(e) => { - error!("Failed to convert thread message: {}", e); - None - } - }) - .collect(); + let messages = self.get_thread_messages(thread_id).await?; let Some(last_message) = messages.last() else { bail!("Thread has no messages"); @@ -94,7 +83,9 @@ impl ThreadService for ThreadServiceImpl { ) .await?; - let s = answer.answer_v2(&messages, options).await?; + let s = answer + .answer_v2(&messages, options, attachment_input) + .await?; // Copy ownership of db and thread_id for the stream let db = self.db.clone(); @@ -163,22 +154,12 @@ impl ThreadService for ThreadServiceImpl { ) -> Result<()> { let thread_id = thread_id.as_rowid()?; - let code = message.attachments.as_ref().map(|x| { - x.code - .iter() - .map(|x| ThreadMessageAttachmentCode { - filepath: x.filepath.clone(), - content: x.content.clone(), - }) - .collect::>() - }); - self.db .create_thread_message( thread_id, thread::Role::User.as_enum_str(), &message.content, - code.as_deref(), + None, None, true, ) @@ -226,21 +207,20 @@ impl ThreadService for ThreadServiceImpl { .list_thread_messages(thread_id, limit, skip_id, backwards) .await?; - Ok(to_vec_messages(messages)) + to_vec_messages(messages) } } -fn to_vec_messages(messages: Vec) -> Vec { - messages - .into_iter() - .filter_map(|x| match x.try_into() { - Ok(x) => Some(x), - Err(e) => { - error!("Failed to convert thread message: {}", e); - None - } - }) - .collect() +fn to_vec_messages(messages: Vec) -> Result> { + let mut output = vec![]; + output.reserve(messages.len()); + + for x in messages { + let message: thread::Message = x.try_into()?; + output.push(message); + } + + Ok(output) } pub fn create(db: DbConn, answer: Option>) -> impl ThreadService {