Skip to content

Commit

Permalink
feat(graphQL): impl UserValue in MessageAttachmentDoc
Browse files Browse the repository at this point in the history
Signed-off-by: Wei Zhang <[email protected]>
  • Loading branch information
zwpaper committed Dec 9, 2024
1 parent 0a4bbdb commit da07589
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 65 deletions.
10 changes: 2 additions & 8 deletions ee/tabby-schema/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -489,12 +489,6 @@ type MessageAttachment {
doc: [MessageAttachmentDoc!]!
}

type MessageAttachmentAuthor {
id: String!
email: String
name: String
}

type MessageAttachmentClientCode {
filepath: String
startLine: Int
Expand All @@ -518,15 +512,15 @@ type MessageAttachmentCodeScores {
type MessageAttachmentIssueDoc {
title: String!
link: String!
author: MessageAttachmentAuthor
author: User
body: String!
closed: Boolean!
}

type MessageAttachmentPullDoc {
title: String!
link: String!
author: MessageAttachmentAuthor
author: User
body: String!
patch: String!
merged: Boolean!
Expand Down
20 changes: 8 additions & 12 deletions ee/tabby-schema/src/dao.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,7 @@ impl From<ThreadMessageAttachmentDoc> for thread::MessageAttachmentDoc {
thread::MessageAttachmentDoc::Issue(thread::MessageAttachmentIssueDoc {
title: val.title,
link: val.link,
author: val.author_user_id.map(|x| UserValue {
id: x,
email: None,
name: None,
}),
author: None, // will be filled in service layer
body: val.body,
closed: val.closed,
})
Expand All @@ -256,11 +252,7 @@ impl From<ThreadMessageAttachmentDoc> for thread::MessageAttachmentDoc {
thread::MessageAttachmentDoc::Pull(thread::MessageAttachmentPullDoc {
title: val.title,
link: val.link,
author: val.author_user_id.map(|x| UserValue {
id: x,
email: None,
name: None,
}),
author: None, // will be filled in service layer
body: val.body,
patch: val.diff,
merged: val.merged,
Expand All @@ -284,7 +276,9 @@ impl From<&thread::MessageAttachmentDoc> for ThreadMessageAttachmentDoc {
ThreadMessageAttachmentDoc::Issue(ThreadMessageAttachmentIssueDoc {
title: val.title.clone(),
link: val.link.clone(),
author_user_id: val.author.as_ref().map(|x| x.id.clone()),
author_user_id: val.author.as_ref().map(|x| match x {
UserValue::UserSecured(user) => user.id.to_string(),
}),
body: val.body.clone(),
closed: val.closed,
})
Expand All @@ -293,7 +287,9 @@ impl From<&thread::MessageAttachmentDoc> for ThreadMessageAttachmentDoc {
ThreadMessageAttachmentDoc::Pull(ThreadMessageAttachmentPullDoc {
title: val.title.clone(),
link: val.link.clone(),
author_user_id: val.author.as_ref().map(|x| x.id.clone()),
author_user_id: val.author.as_ref().map(|x| match x {
UserValue::UserSecured(user) => user.id.to_string(),
}),
body: val.body.clone(),
diff: val.patch.clone(),
merged: val.merged,
Expand Down
17 changes: 4 additions & 13 deletions ee/tabby-schema/src/schema/thread/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ pub struct MessageAttachmentPullDoc {
pub merged: bool,
}

impl From<DocSearchDocument> for MessageAttachmentDoc {
fn from(doc: DocSearchDocument) -> Self {
impl MessageAttachmentDoc {
pub fn from_doc_search_document(doc: DocSearchDocument, author: Option<UserValue>) -> Self {
match doc {
DocSearchDocument::Web(web) => MessageAttachmentDoc::Web(MessageAttachmentWebDoc {
title: web.title,
Expand All @@ -170,15 +170,15 @@ impl From<DocSearchDocument> for MessageAttachmentDoc {
MessageAttachmentDoc::Issue(MessageAttachmentIssueDoc {
title: issue.title,
link: issue.link,
author: None,
author: author,
body: issue.body,
closed: issue.closed,
})
}
DocSearchDocument::Pull(pull) => MessageAttachmentDoc::Pull(MessageAttachmentPullDoc {
title: pull.title,
link: pull.link,
author: None,
author: author,
body: pull.body,
patch: pull.diff,
merged: pull.merged,
Expand All @@ -194,15 +194,6 @@ pub struct MessageDocSearchHit {
pub score: f64,
}

impl From<DocSearchHit> for MessageDocSearchHit {
fn from(hit: DocSearchHit) -> Self {
Self {
doc: hit.doc.into(),
score: hit.score as f64,
}
}
}

#[derive(GraphQLObject)]
#[graphql(context = Context)]
pub struct Thread {
Expand Down
46 changes: 39 additions & 7 deletions ee/tabby-webserver/src/service/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,19 @@ use tabby_common::{
CodeSearch, CodeSearchError, CodeSearchHit, CodeSearchParams, CodeSearchQuery,
CodeSearchScores,
},
structured_doc::{DocSearch, DocSearchError, DocSearchHit},
structured_doc::{DocSearch, DocSearchDocument, DocSearchError, DocSearchHit},
},
config::AnswerConfig,
};
use tabby_inference::ChatCompletionStream;
use tabby_schema::{
auth::AuthenticationService,
context::{ContextInfoHelper, ContextService},
policy::AccessPolicy,
repository::{Repository, RepositoryService},
thread::{
self, CodeQueryInput, CodeSearchParamsOverrideInput, DocQueryInput, MessageAttachment,
MessageAttachmentDoc, ThreadAssistantMessageAttachmentsCode,
MessageAttachmentDoc, MessageDocSearchHit, ThreadAssistantMessageAttachmentsCode,
ThreadAssistantMessageAttachmentsDoc, ThreadAssistantMessageContentDelta,
ThreadRelevantQuestions, ThreadRunItem, ThreadRunOptionsInput,
},
Expand All @@ -44,6 +45,7 @@ use crate::bail;

pub struct AnswerService {
config: AnswerConfig,
auth: Arc<dyn AuthenticationService>,
chat: Arc<dyn ChatCompletionStream>,
code: Arc<dyn CodeSearch>,
doc: Arc<dyn DocSearch>,
Expand All @@ -55,6 +57,7 @@ pub struct AnswerService {
impl AnswerService {
fn new(
config: &AnswerConfig,
auth: Arc<dyn AuthenticationService>,
chat: Arc<dyn ChatCompletionStream>,
code: Arc<dyn CodeSearch>,
doc: Arc<dyn DocSearch>,
Expand All @@ -64,6 +67,7 @@ impl AnswerService {
) -> Self {
Self {
config: config.clone(),
auth,
chat,
code,
doc,
Expand Down Expand Up @@ -122,14 +126,24 @@ impl AnswerService {
if let Some(doc_query) = options.doc_query.as_ref() {
let hits = self.collect_relevant_docs(&context_info_helper, doc_query)
.await;
attachment.doc = hits.iter()
.map(|x| x.doc.clone().into())
.collect::<Vec<_>>();
attachment.doc = futures::future::join_all(hits.iter().map(|x| async {
Self::new_message_attachment_doc(self.auth.clone(), x.doc.clone()).await
})).await;

debug!("doc content: {:?}: {:?}", doc_query.content, attachment.doc.len());

if !attachment.doc.is_empty() {
let hits = hits.into_iter().map(|x| x.into()).collect::<Vec<_>>();
let hits = futures::future::join_all(hits.into_iter().map(|x| {
let score = x.score;
let doc = x.doc.clone();
let auth = self.auth.clone();
async move {
MessageDocSearchHit {
score: score as f64,
doc: Self::new_message_attachment_doc(auth, doc).await,
}
}
})).await;
yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsDoc(
ThreadAssistantMessageAttachmentsDoc { hits }
));
Expand Down Expand Up @@ -201,6 +215,23 @@ impl AnswerService {
Ok(Box::pin(s))
}

async fn new_message_attachment_doc(
auth: Arc<dyn AuthenticationService>,
doc: DocSearchDocument,
) -> MessageAttachmentDoc {
let email = match &doc {
DocSearchDocument::Issue(issue) => issue.author_email.as_deref(),
DocSearchDocument::Pull(pull) => pull.author_email.as_deref(),
_ => None,
};
let user = if let Some(email) = email {
auth.get_user_by_email(&email).await.ok().map(|x| x.into())
} else {
None
};
MessageAttachmentDoc::from_doc_search_document(doc, user)
}

async fn collect_relevant_code(
&self,
helper: &ContextInfoHelper,
Expand Down Expand Up @@ -377,14 +408,15 @@ fn trim_bullet(s: &str) -> String {

pub fn create(
config: &AnswerConfig,
auth: Arc<dyn AuthenticationService>,
chat: Arc<dyn ChatCompletionStream>,
code: Arc<dyn CodeSearch>,
doc: Arc<dyn DocSearch>,
context: Arc<dyn ContextService>,
serper: Option<Box<dyn DocSearch>>,
repository: Arc<dyn RepositoryService>,
) -> AnswerService {
AnswerService::new(config, chat, code, doc, context, serper, repository)
AnswerService::new(config, auth, chat, code, doc, context, serper, repository)
}

fn convert_messages_to_chat_completion_request(
Expand Down
45 changes: 20 additions & 25 deletions ee/tabby-webserver/src/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
mod access_policy;
mod analytic;
pub mod answer;
mod auth;
pub mod auth;
pub mod background_job;
pub mod context;
mod email;
pub mod email;
pub mod event_logger;
pub mod integration;
pub mod job;
mod license;
pub mod license;
mod preset_web_documents_data;
pub mod repository;
mod setting;
pub mod setting;
mod thread;
mod user_event;
mod user_group;
Expand Down Expand Up @@ -58,10 +58,9 @@ use tabby_schema::{
AsID, AsRowid, CoreError, Result, ServiceLocator,
};

use self::{
analytic::new_analytic_service, email::new_email_service, license::new_license_service,
};
use self::analytic::new_analytic_service;
use crate::rate_limit::UserRateLimiter;

struct ServerContext {
db_conn: DbConn,
mail: Arc<dyn EmailService>,
Expand Down Expand Up @@ -91,6 +90,7 @@ struct ServerContext {
impl ServerContext {
pub async fn new(
logger: Arc<dyn EventLogger>,
auth: Arc<dyn AuthenticationService>,
chat: Option<Arc<dyn ChatCompletionStream>>,
completion: Option<Arc<dyn CompletionStream>>,
code: Arc<dyn CodeSearch>,
Expand All @@ -100,21 +100,13 @@ impl ServerContext {
answer: Option<Arc<AnswerService>>,
context: Arc<dyn ContextService>,
web_documents: Arc<dyn WebDocumentService>,
mail: Arc<dyn EmailService>,
license: Arc<dyn LicenseService>,
setting: Arc<dyn SettingService>,
db_conn: DbConn,
embedding: Arc<dyn EmbeddingService>,
) -> Self {
let mail = Arc::new(
new_email_service(db_conn.clone())
.await
.expect("failed to initialize mail service"),
);
let license = Arc::new(
new_license_service(db_conn.clone())
.await
.expect("failed to initialize license service"),
);
let user_event = Arc::new(user_event::create(db_conn.clone()));
let setting = Arc::new(setting::create(db_conn.clone()));
let thread = Arc::new(thread::create(db_conn.clone(), answer.clone()));
let user_group = Arc::new(user_group::create(db_conn.clone()));
let access_policy = Arc::new(access_policy::create(db_conn.clone(), context.clone()));
Expand All @@ -132,16 +124,11 @@ impl ServerContext {
.await;

Self {
mail: mail.clone(),
mail,
embedding,
chat,
completion,
auth: Arc::new(auth::create(
db_conn.clone(),
mail,
license.clone(),
setting.clone(),
)),
auth,
web_documents,
thread,
context,
Expand Down Expand Up @@ -354,6 +341,7 @@ impl ServiceLocator for ArcServerContext {

pub async fn create_service_locator(
logger: Arc<dyn EventLogger>,
auth: Arc<dyn AuthenticationService>,
chat: Option<Arc<dyn ChatCompletionStream>>,
completion: Option<Arc<dyn CompletionStream>>,
code: Arc<dyn CodeSearch>,
Expand All @@ -363,12 +351,16 @@ pub async fn create_service_locator(
answer: Option<Arc<AnswerService>>,
context: Arc<dyn ContextService>,
web_documents: Arc<dyn WebDocumentService>,
mail: Arc<dyn EmailService>,
license: Arc<dyn LicenseService>,
setting: Arc<dyn SettingService>,
db: DbConn,
embedding: Arc<dyn EmbeddingService>,
) -> Arc<dyn ServiceLocator> {
Arc::new(ArcServerContext::new(
ServerContext::new(
logger,
auth,
chat,
completion,
code,
Expand All @@ -378,6 +370,9 @@ pub async fn create_service_locator(
answer,
context,
web_documents,
mail,
license,
setting,
db,
embedding,
)
Expand Down
Loading

0 comments on commit da07589

Please sign in to comment.