Skip to content

Commit

Permalink
fix: Issues / PRs should only be searched when repository is selected…
Browse files Browse the repository at this point in the history
… in Answer Engine (#2703)

* temp: support collect source ids being enabled

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* collect sources for doc query

* fix source ids scoring

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Jul 23, 2024
1 parent bf7bf53 commit 2a9ec6d
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Fixed and Improvements
body: Issues / PRs should only be searched when repository is selected in Answer Engine
time: 2024-07-23T11:48:26.918006+08:00
7 changes: 6 additions & 1 deletion crates/tabby-common/src/api/doc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,10 @@ pub enum DocSearchError {

#[async_trait]
pub trait DocSearch: Send + Sync {
async fn search(&self, q: &str, limit: usize) -> Result<DocSearchResponse, DocSearchError>;
async fn search(
&self,
source_ids: &[String],
q: &str,
limit: usize,
) -> Result<DocSearchResponse, DocSearchError>;
}
17 changes: 17 additions & 0 deletions crates/tabby-common/src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,23 @@ impl IndexSchema {
tantivy::schema::IndexRecordOption::Basic,
))
}

pub fn source_ids_query(&self, source_ids: &[String]) -> impl Query {
BooleanQuery::new(
source_ids
.iter()
.map(|source_id| -> (Occur, Box<(dyn Query)>) {
(
Occur::Should,
Box::new(TermQuery::new(
Term::from_field_text(self.field_source_id, source_id),
tantivy::schema::IndexRecordOption::Basic,
)),
)
})
.collect::<Vec<_>>(),
)
}
}

lazy_static! {
Expand Down
12 changes: 11 additions & 1 deletion crates/tabby/src/services/doc/serper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
use tabby_common::api::doc::{
DocSearch, DocSearchDocument, DocSearchError, DocSearchHit, DocSearchResponse,
};
use tracing::warn;

#[derive(Debug, Serialize)]
struct SerperRequest {
Expand Down Expand Up @@ -45,7 +46,16 @@ impl SerperService {

#[async_trait]
impl DocSearch for SerperService {
async fn search(&self, q: &str, limit: usize) -> Result<DocSearchResponse, DocSearchError> {
async fn search(
&self,
source_ids: &[String],
q: &str,
limit: usize,
) -> Result<DocSearchResponse, DocSearchError> {
if !source_ids.is_empty() {
warn!("Serper does not support source filtering");
}

let request = SerperRequest {
q: q.to_string(),
num: limit,
Expand Down
44 changes: 30 additions & 14 deletions crates/tabby/src/services/doc/tantivy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tabby_common::{
use tabby_inference::Embedding;
use tantivy::{
collector::TopDocs,
query::{BooleanQuery, ConstScoreQuery, Occur},
query::{BooleanQuery, ConstScoreQuery, Occur, Query},
schema::{self, Value},
IndexReader, TantivyDocument,
};
Expand All @@ -30,22 +30,33 @@ impl DocSearchImpl {

async fn search(
&self,
source_ids: &[String],
reader: &IndexReader,
q: &str,
limit: usize,
) -> Result<DocSearchResponse, DocSearchError> {
let schema = index::IndexSchema::instance();
let embedding = self.embedding.embed(q).await?;
let embedding_tokens_query =
index::embedding_tokens_query(embedding.len(), embedding.iter());
let corpus_query = schema.corpus_query(corpus::WEB);
let query = BooleanQuery::new(vec![
(
Occur::Must,
Box::new(ConstScoreQuery::new(corpus_query, 0.0)),
),
(Occur::Must, Box::new(embedding_tokens_query)),
]);
let query = {
let embedding = self.embedding.embed(q).await?;
let embedding_tokens_query =
index::embedding_tokens_query(embedding.len(), embedding.iter());
let corpus_query = schema.corpus_query(corpus::WEB);

let mut query_clauses: Vec<(Occur, Box<dyn Query>)> = vec![
(
Occur::Must,
Box::new(ConstScoreQuery::new(corpus_query, 0.0)),
),
(Occur::Must, Box::new(embedding_tokens_query)),
];

if !source_ids.is_empty() {
let source_ids_query = Box::new(schema.source_ids_query(source_ids));
let source_ids_query = ConstScoreQuery::new(source_ids_query, 0.0);
query_clauses.push((Occur::Must, Box::new(source_ids_query)));
}
BooleanQuery::new(query_clauses)
};

let searcher = reader.searcher();
let top_chunks = searcher.search(&query, &TopDocs::with_limit(limit * 2))?;
Expand Down Expand Up @@ -160,9 +171,14 @@ impl DocSearchService {

#[async_trait]
impl DocSearch for DocSearchService {
async fn search(&self, q: &str, limit: usize) -> Result<DocSearchResponse, DocSearchError> {
async fn search(
&self,
source_ids: &[String],
q: &str,
limit: usize,
) -> Result<DocSearchResponse, DocSearchError> {
if let Some(reader) = self.provider.reader().await.as_ref() {
self.imp.search(reader, q, limit).await
self.imp.search(source_ids, reader, q, limit).await
} else {
Err(DocSearchError::NotReady)
}
Expand Down
2 changes: 2 additions & 0 deletions ee/tabby-schema/src/schema/repository/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,4 +290,6 @@ pub trait RepositoryService: Send + Sync {

async fn list_all_repository_urls(&self) -> Result<Vec<RepositoryConfig>>;
async fn list_all_sources(&self) -> Result<Vec<(String, String)>>;

async fn resolve_web_source_id_by_git_url(&self, git_url: &str) -> Result<String>;
}
50 changes: 45 additions & 5 deletions ee/tabby-webserver/src/service/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tabby_common::api::{
doc::{DocSearch, DocSearchDocument, DocSearchError},
};
use tabby_inference::ChatCompletionStream;
use tabby_schema::{repository::RepositoryService, web_crawler::WebCrawlerService};
use tracing::{debug, warn};
use utoipa::ToSchema;

Expand Down Expand Up @@ -53,6 +54,8 @@ pub struct AnswerService {
chat: Arc<dyn ChatCompletionStream>,
code: Arc<dyn CodeSearch>,
doc: Arc<dyn DocSearch>,
web: Arc<dyn WebCrawlerService>,
repository: Arc<dyn RepositoryService>,
serper: Option<Box<dyn DocSearch>>,
}

Expand All @@ -64,6 +67,8 @@ impl AnswerService {
chat: Arc<dyn ChatCompletionStream>,
code: Arc<dyn CodeSearch>,
doc: Arc<dyn DocSearch>,
web: Arc<dyn WebCrawlerService>,
repository: Arc<dyn RepositoryService>,
serper_factory_fn: impl Fn(&str) -> Box<dyn DocSearch>,
) -> Self {
let serper: Option<Box<dyn DocSearch>> =
Expand All @@ -77,6 +82,8 @@ impl AnswerService {
chat,
code,
doc,
web,
repository,
serper,
}
}
Expand All @@ -95,6 +102,8 @@ impl AnswerService {
}
};

let git_url = req.code_query.as_ref().map(|x| x.git_url.clone());

// 1. Collect relevant code if needed.
let relevant_code = if let Some(mut code_query) = req.code_query {
if req.collect_relevant_code_using_user_message {
Expand All @@ -116,7 +125,7 @@ impl AnswerService {

// 2. Collect relevant docs if needed.
let relevant_docs = if req.doc_query {
self.collect_relevant_docs(get_content(query)).await
self.collect_relevant_docs(git_url.as_deref(), get_content(query)).await
} else {
vec![]
};
Expand Down Expand Up @@ -199,10 +208,39 @@ impl AnswerService {
.collect()
}

async fn collect_relevant_docs(&self, query: &str) -> Vec<DocSearchDocument> {
async fn collect_relevant_docs(
&self,
code_query_git_url: Option<&str>,
content: &str,
) -> Vec<DocSearchDocument> {
let source_ids = {
// 1. By default only web sources are considered.
let mut source_ids: Vec<_> = self
.web
.list_web_crawler_urls(None, None, None, None)
.await
.unwrap_or_default()
.into_iter()
.map(|url| url.source_id())
.collect();

// 2. If code_query is available, we also issues / PRs coming from the source.
if let Some(git_url) = code_query_git_url {
if let Ok(git_source_id) = self
.repository
.resolve_web_source_id_by_git_url(git_url)
.await
{
source_ids.push(git_source_id);
}
}

source_ids
};

// 1. Collect relevant docs from the tantivy doc search.
let mut hits = vec![];
let doc_hits = match self.doc.search(query, 5).await {
let doc_hits = match self.doc.search(&source_ids, content, 5).await {
Ok(docs) => docs.hits,
Err(err) => {
if let DocSearchError::NotReady = err {
Expand All @@ -217,7 +255,7 @@ impl AnswerService {

// 2. If serper is available, we also collect from serper
if let Some(serper) = self.serper.as_ref() {
let serper_hits = match serper.search(query, 5).await {
let serper_hits = match serper.search(&[], content, 5).await {
Ok(docs) => docs.hits,
Err(err) => {
warn!("Failed to search serper: {:?}", err);
Expand Down Expand Up @@ -364,9 +402,11 @@ pub fn create(
chat: Arc<dyn ChatCompletionStream>,
code: Arc<dyn CodeSearch>,
doc: Arc<dyn DocSearch>,
web: Arc<dyn WebCrawlerService>,
repository: Arc<dyn RepositoryService>,
serper_factory_fn: impl Fn(&str) -> Box<dyn DocSearch>,
) -> AnswerService {
AnswerService::new(chat, code, doc, serper_factory_fn)
AnswerService::new(chat, code, doc, web, repository, serper_factory_fn)
}

fn get_content(message: &ChatCompletionRequestMessage) -> &str {
Expand Down
15 changes: 15 additions & 0 deletions ee/tabby-webserver/src/service/repository/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,21 @@ impl RepositoryService for RepositoryServiceImpl {

Ok(ret)
}

async fn resolve_web_source_id_by_git_url(&self, git_url: &str) -> Result<String> {
let git_url = RepositoryConfig::canonicalize_url(git_url);

// Only third_party repositories with a git_url could generates a web source (e.g Issues, PRs)
let tp = self.third_party();
let repos = tp
.list_repositories_with_filter(None, None, Some(true), None, None, None, None)
.await?;
repos
.iter()
.find(|r| RepositoryConfig::canonicalize_url(&r.git_url) == git_url)
.map(|r| r.source_id())
.ok_or_else(|| anyhow::anyhow!("No web source found for git_url: {}", git_url).into())
}
}

fn to_grep_file(file: tabby_git::GrepFile) -> tabby_schema::repository::GrepFile {
Expand Down
2 changes: 2 additions & 0 deletions ee/tabby-webserver/src/webserver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ impl Webserver {
chat.clone(),
code.clone(),
docsearch.clone(),
ctx.web_crawler().clone(),
ctx.repository().clone(),
serper_factory_fn,
))
});
Expand Down

0 comments on commit 2a9ec6d

Please sign in to comment.