Skip to content

Commit

Permalink
fix(answer): fix panic when failed to gen relevant questions (#3462)
Browse files Browse the repository at this point in the history
* fix(answer): fix panic when failed to gen relevant questions

* test: fix fake chat stream
  • Loading branch information
zwpaper authored Nov 25, 2024
1 parent 3b77c9c commit a03054c
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 23 deletions.
106 changes: 85 additions & 21 deletions ee/tabby-webserver/src/service/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,18 @@ impl AnswerService {
if options.generate_relevant_questions {
// Rewrite [[source:${id}]] tags to the actual source name for generate relevant questions.
let content = context_info_helper.rewrite_tag(&query.content);
let questions = self
match self
.generate_relevant_questions_v2(&attachment, &content)
.await;
yield Ok(ThreadRunItem::ThreadRelevantQuestions(ThreadRelevantQuestions{
questions
}));
.await{
Ok(questions) => {
yield Ok(ThreadRunItem::ThreadRelevantQuestions(ThreadRelevantQuestions{
questions
}));
}
Err(err) => {
warn!("Failed to generate relevant questions: {}", err);
}
}
}

// 4. Prepare requesting LLM
Expand Down Expand Up @@ -299,9 +305,9 @@ impl AnswerService {
&self,
attachment: &MessageAttachment,
question: &str,
) -> Vec<String> {
) -> anyhow::Result<Vec<String>> {
if attachment.code.is_empty() && attachment.doc.is_empty() {
return vec![];
return Ok(vec![]);
}

let snippets: Vec<String> = attachment
Expand Down Expand Up @@ -343,24 +349,20 @@ Remember, based on the original question and related contexts, suggest three suc
.build()
.expect("Failed to create ChatCompletionRequestUserMessage"),
)])
.build()
.expect("Failed to create ChatCompletionRequest");
.build()?;

let chat = self.chat.clone();
let s = chat
.chat(request)
.await
.expect("Failed to create chat completion stream");
let s = chat.chat(request).await?;
let content = s.choices[0]
.message
.content
.as_deref()
.expect("Failed to get content from chat completion");
content
.ok_or_else(|| anyhow!("Failed to get content from chat completion"))?;
Ok(content
.lines()
.map(trim_bullet)
.filter(|x| !x.is_empty())
.collect()
.collect())
}
}

Expand Down Expand Up @@ -820,7 +822,9 @@ mod tests {

#[tokio::test]
async fn test_collect_relevant_code() {
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream);
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream {
return_error: false,
});
let code: Arc<dyn CodeSearch> = Arc::new(FakeCodeSearch);
let doc: Arc<dyn DocSearch> = Arc::new(FakeDocSearch);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
Expand Down Expand Up @@ -919,7 +923,9 @@ mod tests {

#[tokio::test]
async fn test_generate_relevant_questions_v2() {
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream);
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream {
return_error: false,
});
let code: Arc<dyn CodeSearch> = Arc::new(FakeCodeSearch);
let doc: Arc<dyn DocSearch> = Arc::new(FakeDocSearch);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
Expand Down Expand Up @@ -972,11 +978,67 @@ mod tests {
"Can you explain how the Flask app works in this context?".to_string(),
];

assert_eq!(result, expected);
assert_eq!(result.unwrap(), expected);
}

#[tokio::test]
async fn test_generate_relevant_questions_v2_error() {
let chat: Arc<dyn ChatCompletionStream> =
Arc::new(FakeChatCompletionStream { return_error: true });
let code: Arc<dyn CodeSearch> = Arc::new(FakeCodeSearch);
let doc: Arc<dyn DocSearch> = Arc::new(FakeDocSearch);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
let serper = Some(Box::new(FakeDocSearch) as Box<dyn DocSearch>);
let config = make_answer_config();
let db = DbConn::new_in_memory().await.unwrap();
let repo = make_repository_service(db).await.unwrap();

let service = AnswerService::new(
&config,
chat.clone(),
code.clone(),
doc.clone(),
context.clone(),
serper,
repo,
);

let attachment = MessageAttachment {
doc: vec![tabby_schema::thread::MessageAttachmentDoc::Web(
tabby_schema::thread::MessageAttachmentWebDoc {
title: "1. Example Document".to_owned(),
content: "This is an example".to_owned(),
link: "https://example.com".to_owned(),
},
)],
code: vec![tabby_schema::thread::MessageAttachmentCode {
git_url: "https://github.com".to_owned(),
filepath: "server.py".to_owned(),
language: "python".to_owned(),
content: "print('Hello, server!')".to_owned(),
start_line: 1,
}],
client_code: vec![tabby_schema::thread::MessageAttachmentClientCode {
filepath: Some("client.py".to_owned()),
content: "print('Hello, client!')".to_owned(),
start_line: Some(1),
}],
};

let question = "What is the purpose of this code?";

let result = service
.generate_relevant_questions_v2(&attachment, question)
.await;

assert!(result.is_err());
}

#[tokio::test]
async fn test_collect_relevant_docs() {
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream);
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream {
return_error: false,
});
let code: Arc<dyn CodeSearch> = Arc::new(FakeCodeSearch);
let doc: Arc<dyn DocSearch> = Arc::new(FakeDocSearch);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
Expand Down Expand Up @@ -1043,7 +1105,9 @@ mod tests {
use futures::StreamExt;
use tabby_schema::{policy::AccessPolicy, thread::ThreadRunOptionsInput};

let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream);
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream {
return_error: false,
});
let code: Arc<dyn CodeSearch> = Arc::new(FakeCodeSearch);
let doc: Arc<dyn DocSearch> = Arc::new(FakeDocSearch);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
Expand Down
14 changes: 13 additions & 1 deletion ee/tabby-webserver/src/service/answer/testutils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,25 @@ use tabby_schema::{

use crate::{integration, job, repository};

pub struct FakeChatCompletionStream;
pub struct FakeChatCompletionStream {
pub return_error: bool,
}

#[async_trait]
impl ChatCompletionStream for FakeChatCompletionStream {
async fn chat(
&self,
_request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
if self.return_error {
return Err(OpenAIError::ApiError(async_openai::error::ApiError {
message: "error".to_string(),
code: None,
param: None,
r#type: None,
}));
}

Ok(CreateChatCompletionResponse {
id: "test-response".to_owned(),
created: 0,
Expand Down
4 changes: 3 additions & 1 deletion ee/tabby-webserver/src/service/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,9 @@ mod tests {
async fn test_create_run() {
let db = DbConn::new_in_memory().await.unwrap();
let user_id = create_user(&db).await.as_id();
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream);
let chat: Arc<dyn ChatCompletionStream> = Arc::new(FakeChatCompletionStream {
return_error: false,
});
let code: Arc<dyn CodeSearch> = Arc::new(FakeCodeSearch);
let doc: Arc<dyn DocSearch> = Arc::new(FakeDocSearch);
let context: Arc<dyn ContextService> = Arc::new(FakeContextService);
Expand Down

0 comments on commit a03054c

Please sign in to comment.