Skip to content

Commit

Permalink
refactor: extract IndexReaderProvider for DocSearchService / CodeSearchService (#2242)
Browse files Browse the repository at this point in the history

* refactor: extract IndexReaderProvider for DocSearchService / CodeSearchService

* update

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored May 26, 2024
1 parent f9c4240 commit 978de8b
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 131 deletions.
24 changes: 21 additions & 3 deletions crates/tabby/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::{
event::create_event_logger,
health,
model::download_model_if_needed,
tantivy::IndexReaderProvider,
},
to_local_config, Device,
};
Expand Down Expand Up @@ -154,8 +155,21 @@ pub async fn main(config: &Config, args: &ServeArgs) {
repository_access = ws.repository_access();
}

let code = Arc::new(create_code_search(repository_access));
let mut api = api_router(args, &config, logger.clone(), code.clone(), webserver).await;
let index_reader_provider = Arc::new(IndexReaderProvider::default());

let code = Arc::new(create_code_search(
repository_access,
index_reader_provider.clone(),
));
let mut api = api_router(
args,
&config,
logger.clone(),
code.clone(),
index_reader_provider,
webserver,
)
.await;
let mut ui = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.fallback(|| async { axum::response::Redirect::temporary("/swagger-ui") });
Expand Down Expand Up @@ -190,6 +204,7 @@ async fn api_router(
config: &Config,
logger: Arc<dyn EventLogger>,
code: Arc<dyn CodeSearch>,
index_reader_provider: Arc<IndexReaderProvider>,
webserver: Option<bool>,
) -> Router {
let model = &config.model;
Expand All @@ -209,7 +224,10 @@ async fn api_router(

let docsearch_state: Option<Arc<dyn DocSearch>> = if let Some(embedding) = &model.embedding {
let embedding = embedding::create(embedding).await;
Some(Arc::new(services::doc::create(embedding)))
Some(Arc::new(services::doc::create(
embedding,
index_reader_provider,
)))
} else {
None
};
Expand Down
130 changes: 54 additions & 76 deletions crates/tabby/src/services/code.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{sync::Arc, time::Duration};
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
Expand All @@ -11,48 +11,26 @@ use tabby_common::{
},
config::{RepositoryAccess, RepositoryConfig},
index::{code, IndexSchema},
path,
};
use tantivy::{
collector::{Count, TopDocs},
schema::{self, document::ReferenceValue, Value},
Index, IndexReader, TantivyDocument,
IndexReader, TantivyDocument,
};
use tokio::{
sync::{Mutex, RwLock},
time::sleep,
};
use tracing::debug;
use tokio::sync::Mutex;

struct CodeSearchImpl {
reader: IndexReader,
use super::tantivy::IndexReaderProvider;

struct CodeSearchImpl {
repository_access: Arc<dyn RepositoryAccess>,
repo_cache: Mutex<TimedCache<(), Vec<RepositoryConfig>>>,
}

impl CodeSearchImpl {
fn load(repository_access: Arc<dyn RepositoryAccess>) -> Result<Self> {
let index = Index::open_in_dir(path::index_dir())?;

let reader = index
.reader_builder()
.reload_policy(tantivy::ReloadPolicy::OnCommitWithDelay)
.try_into()?;
Ok(Self {
fn new(repository_access: Arc<dyn RepositoryAccess>) -> Self {
Self {
repository_access,
reader,
repo_cache: Mutex::new(TimedCache::with_lifespan(10 * 60)),
})
}

async fn load_async(repository_access: Arc<dyn RepositoryAccess>) -> CodeSearchImpl {
loop {
if let Ok(doc) = Self::load(repository_access.clone()) {
debug!("Index is ready, enabling code search...");
return doc;
}
sleep(Duration::from_secs(60)).await;
}
}

Expand Down Expand Up @@ -95,11 +73,12 @@ impl CodeSearchImpl {

async fn search_with_query(
&self,
reader: &IndexReader,
q: &dyn tantivy::query::Query,
limit: usize,
offset: usize,
) -> Result<CodeSearchResponse, CodeSearchError> {
let searcher = self.reader.searcher();
let searcher = reader.searcher();
let (top_docs, num_hits) =
{ searcher.search(q, &(TopDocs::with_limit(limit).and_offset(offset), Count))? };
let hits: Vec<CodeSearchHit> = {
Expand All @@ -113,6 +92,32 @@ impl CodeSearchImpl {
};
Ok(CodeSearchResponse { num_hits, hits })
}

/// Resolves `query.git_url` to the closest known repository and runs the
/// resulting code search against `reader`.
///
/// The repository list is read from `repo_cache`; when the cache has no
/// entry it is populated from `repository_access.list_repositories()`.
/// If no repository matches the requested git URL, an empty (default)
/// response is returned without touching the index.
///
/// # Errors
///
/// Propagates a failure to list repositories, and any error returned by
/// `search_with_query`.
async fn search_in_language(
    &self,
    reader: &IndexReader,
    mut query: CodeSearchQuery,
    limit: usize,
    offset: usize,
) -> Result<CodeSearchResponse, CodeSearchError> {
    // Resolve the URL in its own scope so the cache guard is released
    // before the search runs. Holding the mutex for the whole function
    // would make every concurrent search wait on the repository cache.
    let git_url = {
        let mut cache = self.repo_cache.lock().await;

        let repos = cache
            .try_get_or_set_with((), || async {
                let repos = self.repository_access.list_repositories().await?;
                Ok::<_, anyhow::Error>(repos)
            })
            .await?;

        let Some(git_url) = closest_match(&query.git_url, repos.iter()) else {
            return Ok(CodeSearchResponse::default());
        };

        // Copy the match out of the cache so nothing borrows from the guard
        // once this scope ends.
        git_url.to_owned()
    };

    query.git_url = git_url;

    let query = code::code_search_query(&query);
    self.search_with_query(reader, &query, limit, offset).await
}
}

fn get_text(doc: &TantivyDocument, field: schema::Field) -> &str {
Expand Down Expand Up @@ -143,34 +148,6 @@ fn get_json_text_field<'a>(doc: &'a TantivyDocument, field: schema::Field, name:
.unwrap()
}

// `CodeSearch` trait implementation for `CodeSearchImpl`: searches the index
// after mapping the requested git URL onto a known repository.
#[async_trait]
impl CodeSearch for CodeSearchImpl {
/// Rewrites `query.git_url` to the closest match among the repositories
/// returned by `repository_access` (read through `repo_cache`), then runs
/// the search via `search_with_query`.
///
/// Returns an empty (default) response when no repository matches.
/// Errors from listing repositories or from the search are propagated.
async fn search_in_language(
&self,
mut query: CodeSearchQuery,
limit: usize,
offset: usize,
) -> Result<CodeSearchResponse, CodeSearchError> {
// NOTE(review): `cache` is a function-scope guard, so the mutex appears to
// stay locked until this method returns, including during the search below.
let mut cache = self.repo_cache.lock().await;

// Use the cached repository list, fetching it on a cache miss.
let repos = cache
.try_get_or_set_with((), || async {
let repos = self.repository_access.list_repositories().await?;
Ok::<_, anyhow::Error>(repos)
})
.await?;

// No matching repository: report zero hits rather than an error.
let Some(git_url) = closest_match(&query.git_url, repos.iter()) else {
return Ok(CodeSearchResponse::default());
};

// Replace the caller-supplied URL with the matched repository URL.
query.git_url = git_url.to_owned();

let query = code::code_search_query(&query);
self.search_with_query(&query, limit, offset).await
}
}

fn closest_match<'a>(
search_term: &'a str,
search_input: impl IntoIterator<Item = &'a RepositoryConfig>,
Expand All @@ -186,28 +163,27 @@ fn closest_match<'a>(
}

struct CodeSearchService {
search: Arc<RwLock<Option<CodeSearchImpl>>>,
imp: CodeSearchImpl,
provider: Arc<IndexReaderProvider>,
}

impl CodeSearchService {
pub fn new(repository_access: Arc<dyn RepositoryAccess>) -> Self {
let search = Arc::new(RwLock::new(None));

let ret = Self {
search: search.clone(),
};

tokio::spawn(async move {
let code = CodeSearchImpl::load_async(repository_access).await;
*search.write().await = Some(code);
});

ret
pub fn new(
repository_access: Arc<dyn RepositoryAccess>,
provider: Arc<IndexReaderProvider>,
) -> Self {
Self {
imp: CodeSearchImpl::new(repository_access),
provider,
}
}
}

pub fn create_code_search(repository_access: Arc<dyn RepositoryAccess>) -> impl CodeSearch {
CodeSearchService::new(repository_access)
pub fn create_code_search(
repository_access: Arc<dyn RepositoryAccess>,
provider: Arc<IndexReaderProvider>,
) -> impl CodeSearch {
CodeSearchService::new(repository_access, provider)
}

#[async_trait]
Expand All @@ -218,8 +194,10 @@ impl CodeSearch for CodeSearchService {
limit: usize,
offset: usize,
) -> Result<CodeSearchResponse, CodeSearchError> {
if let Some(imp) = self.search.read().await.as_ref() {
imp.search_in_language(query, limit, offset).await
if let Some(reader) = self.provider.reader().await.as_ref() {
self.imp
.search_in_language(reader, query, limit, offset)
.await
} else {
Err(CodeSearchError::NotReady)
}
Expand Down
6 changes: 4 additions & 2 deletions crates/tabby/src/services/doc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ use std::sync::Arc;
use tabby_common::api::doc::DocSearch;
use tabby_inference::Embedding;

pub fn create(embedding: Arc<dyn Embedding>) -> impl DocSearch {
tantivy::DocSearchService::new(embedding)
use super::tantivy::IndexReaderProvider;

/// Builds the tantivy-backed doc search service.
///
/// `embedding` and `provider` are handed unchanged to
/// `tantivy::DocSearchService::new`; the concrete service type stays hidden
/// behind `impl DocSearch`.
pub fn create(
    embedding: Arc<dyn Embedding>,
    provider: Arc<IndexReaderProvider>,
) -> impl DocSearch {
    let service = tantivy::DocSearchService::new(embedding, provider);
    service
}

pub fn create_serper(api_key: &str) -> impl DocSearch {
Expand Down
66 changes: 16 additions & 50 deletions crates/tabby/src/services/doc/tantivy.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,33 @@
use std::{sync::Arc, time::Duration};
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use tabby_common::{
api::doc::{DocSearch, DocSearchDocument, DocSearchError, DocSearchHit, DocSearchResponse},
index::{self, doc},
path,
};
use tabby_inference::Embedding;
use tantivy::{
collector::TopDocs,
schema::{self, document::ReferenceValue, Value},
Index, IndexReader, TantivyDocument,
IndexReader, TantivyDocument,
};
use tokio::{sync::RwLock, time::sleep};
use tracing::{debug, warn};
use tracing::warn;

use crate::services::tantivy::IndexReaderProvider;

struct DocSearchImpl {
reader: IndexReader,
embedding: Arc<dyn Embedding>,
}

impl DocSearchImpl {
fn load(embedding: Arc<dyn Embedding>) -> Result<Self> {
let index = Index::open_in_dir(path::index_dir())?;

Ok(Self {
reader: index.reader_builder().try_into()?,
embedding,
})
fn new(embedding: Arc<dyn Embedding>) -> Self {
Self { embedding }
}

async fn load_async(embedding: Arc<dyn Embedding>) -> DocSearchImpl {
loop {
if let Ok(doc) = Self::load(embedding.clone()) {
debug!("Index is ready, enabling doc search...");
return doc;
}

sleep(Duration::from_secs(60)).await;
}
}
}

#[async_trait]
impl DocSearch for DocSearchImpl {
async fn search(
&self,
reader: &IndexReader,
q: &str,
limit: usize,
offset: usize,
Expand All @@ -56,7 +37,7 @@ impl DocSearch for DocSearchImpl {
let embedding_tokens_query =
index::embedding_tokens_query(embedding.len(), embedding.iter());

let searcher = self.reader.searcher();
let searcher = reader.searcher();
let top_chunks = searcher.search(
&embedding_tokens_query,
&TopDocs::with_limit(limit).and_offset(offset),
Expand Down Expand Up @@ -118,30 +99,15 @@ fn get_json_text_field<'a>(doc: &'a TantivyDocument, field: schema::Field, name:
}

pub struct DocSearchService {
search: Arc<RwLock<Option<DocSearchImpl>>>,
loader: tokio::task::JoinHandle<()>,
}

impl Drop for DocSearchService {
fn drop(&mut self) {
if !self.loader.is_finished() {
self.loader.abort();
}
}
imp: DocSearchImpl,
provider: Arc<IndexReaderProvider>,
}

impl DocSearchService {
pub fn new(embedding: Arc<dyn Embedding>) -> Self {
let search = Arc::new(RwLock::new(None));
let cloned_search = search.clone();
let loader = tokio::spawn(async move {
let doc = DocSearchImpl::load_async(embedding).await;
*cloned_search.write().await = Some(doc);
});

pub fn new(embedding: Arc<dyn Embedding>, provider: Arc<IndexReaderProvider>) -> Self {
Self {
search: search.clone(),
loader,
imp: DocSearchImpl::new(embedding),
provider,
}
}
}
Expand All @@ -154,8 +120,8 @@ impl DocSearch for DocSearchService {
limit: usize,
offset: usize,
) -> Result<DocSearchResponse, DocSearchError> {
if let Some(imp) = self.search.read().await.as_ref() {
imp.search(q, limit, offset).await
if let Some(reader) = self.provider.reader().await.as_ref() {
self.imp.search(reader, q, limit, offset).await
} else {
Err(DocSearchError::NotReady)
}
Expand Down
1 change: 1 addition & 0 deletions crates/tabby/src/services/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ pub mod embedding;
pub mod event;
pub mod health;
pub mod model;
pub mod tantivy;
Loading

0 comments on commit 978de8b

Please sign in to comment.