Skip to content

Commit

Permalink
chore(webserver): add scores field to streaming code / docs (#2880)
Browse files Browse the repository at this point in the history
* chore(webserver): add scores field to MessageAttachmentCode struct

* update

* update

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Aug 15, 2024
1 parent dc55cbe commit 931bb80
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 40 deletions.
20 changes: 18 additions & 2 deletions ee/tabby-schema/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -404,17 +404,33 @@ type MessageAttachmentCode {
startLine: Int!
}

type MessageAttachmentCodeScores {
rrf: Float!
bm25: Float!
embedding: Float!
}

type MessageAttachmentDoc {
title: String!
link: String!
content: String!
}

type MessageCodeSearchHit {
code: MessageAttachmentCode!
scores: MessageAttachmentCodeScores!
}

type MessageConnection {
edges: [MessageEdge!]!
pageInfo: PageInfo!
}

type MessageDocSearchHit {
doc: MessageAttachmentDoc!
score: Float!
}

type MessageEdge {
node: Message!
cursor: String!
Expand Down Expand Up @@ -626,8 +642,8 @@ type ThreadRunItem {
threadRelevantQuestions: [String!]
threadUserMessageCreated: ID
threadAssistantMessageCreated: ID
threadAssistantMessageAttachmentsCode: [MessageAttachmentCode!]
threadAssistantMessageAttachmentsDoc: [MessageAttachmentDoc!]
threadAssistantMessageAttachmentsCode: [MessageCodeSearchHit!]
threadAssistantMessageAttachmentsDoc: [MessageDocSearchHit!]
threadAssistantMessageContentDelta: String
threadAssistantMessageCompleted: ID
}
Expand Down
85 changes: 79 additions & 6 deletions ee/tabby-schema/src/schema/thread/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use chrono::{DateTime, Utc};
use juniper::{graphql_object, GraphQLEnum, GraphQLObject, ID};
use serde::Serialize;
use tabby_common::api::{
code::{CodeSearchDocument, CodeSearchHit, CodeSearchScores},
doc::{DocSearchDocument, DocSearchHit},
};

use crate::{juniper::relay::NodeType, Context};

Expand Down Expand Up @@ -55,13 +59,82 @@ pub struct MessageAttachmentCode {
pub start_line: i32,
}

impl From<CodeSearchDocument> for MessageAttachmentCode {
fn from(doc: CodeSearchDocument) -> Self {
Self {
git_url: doc.git_url,
filepath: doc.filepath,
language: doc.language,
content: doc.body,
start_line: doc.start_line as i32,
}
}
}

#[derive(GraphQLObject, Clone)]
pub struct MessageAttachmentCodeScores {
pub rrf: f64,
pub bm25: f64,
pub embedding: f64,
}

impl From<CodeSearchScores> for MessageAttachmentCodeScores {
fn from(scores: CodeSearchScores) -> Self {
Self {
rrf: scores.rrf as f64,
bm25: scores.bm25 as f64,
embedding: scores.embedding as f64,
}
}
}

#[derive(GraphQLObject)]
pub struct MessageCodeSearchHit {
pub code: MessageAttachmentCode,
pub scores: MessageAttachmentCodeScores,
}

impl From<CodeSearchHit> for MessageCodeSearchHit {
fn from(hit: CodeSearchHit) -> Self {
Self {
code: hit.doc.into(),
scores: hit.scores.into(),
}
}
}

#[derive(GraphQLObject, Clone)]
pub struct MessageAttachmentDoc {
pub title: String,
pub link: String,
pub content: String,
}

impl From<DocSearchDocument> for MessageAttachmentDoc {
fn from(doc: DocSearchDocument) -> Self {
Self {
title: doc.title,
link: doc.link,
content: doc.snippet,
}
}
}

#[derive(GraphQLObject)]
pub struct MessageDocSearchHit {
pub doc: MessageAttachmentDoc,
pub score: f64,
}

impl From<DocSearchHit> for MessageDocSearchHit {
fn from(hit: DocSearchHit) -> Self {
Self {
doc: hit.doc.into(),
score: hit.score as f64,
}
}
}

#[derive(GraphQLObject)]
#[graphql(context = Context)]
pub struct Thread {
Expand Down Expand Up @@ -95,8 +168,8 @@ pub enum ThreadRunItem {
ThreadRelevantQuestions(Vec<String>),
ThreadUserMessageCreated(ID),
ThreadAssistantMessageCreated(ID),
ThreadAssistantMessageAttachmentsCode(Vec<MessageAttachmentCode>),
ThreadAssistantMessageAttachmentsDoc(Vec<MessageAttachmentDoc>),
ThreadAssistantMessageAttachmentsCode(Vec<MessageCodeSearchHit>),
ThreadAssistantMessageAttachmentsDoc(Vec<MessageDocSearchHit>),
ThreadAssistantMessageContentDelta(String),
ThreadAssistantMessageCompleted(ID),
}
Expand Down Expand Up @@ -131,16 +204,16 @@ impl ThreadRunItem {
}
}

fn thread_assistant_message_attachments_code(&self) -> Option<&Vec<MessageAttachmentCode>> {
fn thread_assistant_message_attachments_code(&self) -> Option<&Vec<MessageCodeSearchHit>> {
match self {
ThreadRunItem::ThreadAssistantMessageAttachmentsCode(attachments) => Some(attachments),
ThreadRunItem::ThreadAssistantMessageAttachmentsCode(hits) => Some(hits),
_ => None,
}
}

fn thread_assistant_message_attachments_doc(&self) -> Option<&Vec<MessageAttachmentDoc>> {
fn thread_assistant_message_attachments_doc(&self) -> Option<&Vec<MessageDocSearchHit>> {
match self {
ThreadRunItem::ThreadAssistantMessageAttachmentsDoc(attachments) => Some(attachments),
ThreadRunItem::ThreadAssistantMessageAttachmentsDoc(hits) => Some(hits),
_ => None,
}
}
Expand Down
49 changes: 21 additions & 28 deletions ee/tabby-webserver/src/service/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ use tabby_schema::{
repository::RepositoryService,
thread::{
self, CodeQueryInput, CodeSearchParamsOverrideInput, DocQueryInput, MessageAttachment,
MessageAttachmentCode, MessageAttachmentCodeInput, MessageAttachmentDoc, ThreadRunItem,
ThreadRunOptionsInput,
MessageAttachmentCodeInput, ThreadRunItem, ThreadRunOptionsInput,
},
web_crawler::WebCrawlerService,
};
Expand Down Expand Up @@ -211,38 +210,32 @@ impl AnswerService {

// 1. Collect relevant code if needed.
if let Some(code_query) = options.code_query.as_ref() {
attachment.code = self.collect_relevant_code(code_query, &self.config.code_search_params, options.debug_options.as_ref().and_then(|x| x.code_search_params_override.as_ref())).await.iter()
.map(|x| MessageAttachmentCode {
git_url: x.doc.git_url.clone(),
filepath: x.doc.filepath.clone(),
language: x.doc.language.clone(),
content: x.doc.body.clone(),
start_line: x.doc.start_line as i32,
})
.collect::<Vec<_>>();
let hits = self.collect_relevant_code(code_query, &self.config.code_search_params, options.debug_options.as_ref().and_then(|x| x.code_search_params_override.as_ref())).await;
attachment.code = hits.iter().map(|x| x.doc.clone().into()).collect::<Vec<_>>();

if !hits.is_empty() {
let message_hits = hits.into_iter().map(|x| x.into()).collect::<Vec<_>>();
yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsCode(
message_hits
));
}
};

if !attachment.code.is_empty() {
yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsCode(attachment.code.clone()));
}

// 2. Collect relevant docs if needed.
if let Some(doc_query) = options.doc_query.as_ref() {
attachment.doc = self.collect_relevant_docs(git_url.as_deref(), doc_query)
.await.iter()
.map(|x| MessageAttachmentDoc {
title: x.doc.title.clone(),
content: x.doc.snippet.clone(),
link: x.doc.link.clone(),
})
let hits = self.collect_relevant_docs(git_url.as_deref(), doc_query)
.await;
attachment.doc = hits.iter()
.map(|x| x.doc.clone().into())
.collect::<Vec<_>>();
};

if !attachment.doc.is_empty() {
yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsDoc(
attachment.doc.clone()
));
}
if !attachment.doc.is_empty() {
let message_hits = hits.into_iter().map(|x| x.into()).collect::<Vec<_>>();
yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsDoc(
message_hits
));
}
};

// 3. Generate relevant questions.
if options.generate_relevant_questions {
Expand Down
8 changes: 4 additions & 4 deletions ee/tabby-webserver/src/service/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ impl ThreadService for ThreadServiceImpl {
db.append_thread_message_content(assistant_message_id, content).await?;
}

Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsCode(code)) => {
let code = code
Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsCode(hits)) => {
let code = hits
.iter()
.map(Into::into)
.map(|x| (&x.code).into())
.collect::<Vec<_>>();
db.update_thread_message_attachments(
assistant_message_id,
Expand All @@ -122,7 +122,7 @@ impl ThreadService for ThreadServiceImpl {
Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsDoc(doc)) => {
let doc = doc
.iter()
.map(Into::into)
.map(|x| (&x.doc).into())
.collect::<Vec<_>>();
db.update_thread_message_attachments(
assistant_message_id,
Expand Down

0 comments on commit 931bb80

Please sign in to comment.