Skip to content

Commit

Permalink
fix(doc_search): skip indexing if no embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
zwpaper committed Nov 19, 2024
1 parent 8b26c4f commit d6374c0
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 6 deletions.
3 changes: 3 additions & 0 deletions crates/tabby-index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ use indexer::{IndexAttributeBuilder, Indexer};

mod structured_doc;

#[cfg(test)]
mod testutils;

pub mod public {
use indexer::IndexGarbageCollector;

Expand Down
2 changes: 1 addition & 1 deletion crates/tabby-index/src/structured_doc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub struct StructuredDocBuilder {
}

impl StructuredDocBuilder {
fn new(embedding: Arc<dyn Embedding>) -> Self {
pub fn new(embedding: Arc<dyn Embedding>) -> Self {
Self { embedding }
}
}
Expand Down
25 changes: 21 additions & 4 deletions crates/tabby-index/src/structured_doc/public.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use chrono::{DateTime, Utc};
use futures::StreamExt;
use tabby_common::index::corpus;
use tabby_inference::Embedding;
use tantivy::TantivyDocument;

pub use super::types::{
issue::IssueDocument as StructuredDocIssueFields, web::WebDocument as StructuredDocWebFields,
Expand Down Expand Up @@ -34,16 +35,32 @@ impl StructuredDocIndexer {
return false;
};

stream! {
let docs: Vec<TantivyDocument> = stream! {
let (id, s) = self.builder.build(document).await;
self.indexer.delete(&id);

for await doc in s.buffer_unordered(std::cmp::max(std::thread::available_parallelism().unwrap().get() * 2, 32)) {
for await doc in s.buffer_unordered(std::cmp::max(
std::thread::available_parallelism().unwrap().get() * 2,
32,
)) {
if let Ok(Some(doc)) = doc {
self.indexer.add(doc).await;
yield doc
}
}
}.count().await;
}
.collect()
.await;

// If exactly one document was collected, it is the top-level doc itself:
// none of its chunks produced an indexable document.
// Skip indexing in that case.
if docs.len() == 1 {
return false;
}

for doc in docs.iter() {
self.indexer.add(doc.clone()).await;
}
true
}

Expand Down
174 changes: 174 additions & 0 deletions crates/tabby-index/src/testutils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
mod structured_doc_tests {
use crate::indexer::Indexer;
use crate::structured_doc::public::{
StructuredDoc, StructuredDocFields, StructuredDocIndexer, StructuredDocIssueFields,
};
use std::env;
use std::sync::Arc;
use tabby_common::index::corpus;
use tabby_inference::MockEmbedding;
use temp_testdir::TempDir;

#[test]
fn test_structured_doc_empty_embedding() {
let temp_dir = TempDir::default();
env::set_var("TABBY_ROOT", temp_dir.as_ref());

let embedding = MockEmbedding::new(vec![]);
let embedding = Arc::new(embedding);
let indexer = StructuredDocIndexer::new(embedding.clone());
let doc = StructuredDoc {
source_id: "source".to_owned(),
fields: StructuredDocFields::Issue(StructuredDocIssueFields {
link: "empty_embedding".to_owned(),
title: "title".to_owned(),
body: "body".to_owned(),
closed: false,
}),
};

let updated_at = chrono::Utc::now();
let res = tokio::runtime::Runtime::new()
.unwrap()
.block_on(async { indexer.add(updated_at, doc).await });
assert!(!res);
indexer.commit();

let validator = Indexer::new(corpus::STRUCTURED_DOC);
assert!(!validator.is_indexed("empty_embedding"));

env::remove_var("TABBY_ROOT");
}

#[test]
fn test_structured_doc_with_embedding() {
let temp_dir = TempDir::default();
env::set_var("TABBY_ROOT", temp_dir.as_ref());

let embedding = MockEmbedding::new(vec![1.0]);
let embedding = Arc::new(embedding);
let indexer = StructuredDocIndexer::new(embedding.clone());
let doc = StructuredDoc {
source_id: "source".to_owned(),
fields: StructuredDocFields::Issue(StructuredDocIssueFields {
link: "with_embedding".to_owned(),
title: "title".to_owned(),
body: "body".to_owned(),
closed: false,
}),
};

let updated_at = chrono::Utc::now();
let res = tokio::runtime::Runtime::new()
.unwrap()
.block_on(async { indexer.add(updated_at, doc).await });
assert!(res);
indexer.commit();

let validator = Indexer::new(corpus::STRUCTURED_DOC);
assert!(validator.is_indexed("with_embedding"));

env::remove_var("TABBY_ROOT");
}
}

mod indexer_tests {
    use crate::indexer::TantivyDocBuilder;
    use crate::structured_doc::{
        public::{StructuredDoc, StructuredDocFields, StructuredDocIssueFields},
        StructuredDocBuilder,
    };
    use futures::StreamExt;
    use std::env;
    use std::sync::Arc;
    use tabby_common::index::corpus;
    use tabby_inference::MockEmbedding;
    use temp_testdir::TempDir;

    /// Builds the open issue document used by every test in this module;
    /// only the `link` differs between tests.
    fn issue_doc(link: &str) -> StructuredDoc {
        StructuredDoc {
            source_id: "source".to_owned(),
            fields: StructuredDocFields::Issue(StructuredDocIssueFields {
                link: link.to_owned(),
                title: "title".to_owned(),
                body: "body".to_owned(),
                closed: false,
            }),
        }
    }

    /// Number of build futures polled concurrently: twice the available
    /// parallelism, but never fewer than 32.
    fn max_in_flight() -> usize {
        std::cmp::max(std::thread::available_parallelism().unwrap().get() * 2, 32)
    }

    /// Test that the builder yields `None` for the chunk when the embedding
    /// is empty, meaning that the chunk is not saved to tantivy.
    #[test]
    fn test_indexer_empty_embedding() {
        let temp_dir = TempDir::default();
        env::set_var("TABBY_ROOT", temp_dir.as_ref());

        let embedding = MockEmbedding::new(vec![]);
        let builder = StructuredDocBuilder::new(Arc::new(embedding));
        let indexer = TantivyDocBuilder::new(corpus::STRUCTURED_DOC, builder);

        // One runtime both builds and drains the stream, so any task tied to
        // the build step is still alive while the results are collected.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let (id, s) = runtime.block_on(indexer.build(issue_doc("empty_embedding")));
        assert_eq!(id, "empty_embedding");

        let res = runtime.block_on(s.buffer_unordered(max_in_flight()).collect::<Vec<_>>());

        // Two items are expected: the document itself and its single chunk.
        assert_eq!(res.len(), 2);
        // `buffer_unordered` yields in completion order, so do not rely on
        // positions: every item must succeed, and exactly one of them (the
        // document) carries a value while the chunk is `None` because the
        // MockEmbedding returns an empty vector.
        assert!(res.iter().all(|item| item.is_ok()));
        let built = res
            .iter()
            .filter(|item| matches!(item, Ok(Some(_))))
            .count();
        assert_eq!(built, 1);

        env::remove_var("TABBY_ROOT");
    }

    /// Test that both the document and its chunk are built when the
    /// embedding is non-empty.
    #[test]
    fn test_indexer_with_embedding() {
        let temp_dir = TempDir::default();
        env::set_var("TABBY_ROOT", temp_dir.as_ref());

        let embedding = MockEmbedding::new(vec![1.0, 2.0]);
        let builder = StructuredDocBuilder::new(Arc::new(embedding));
        let indexer = TantivyDocBuilder::new(corpus::STRUCTURED_DOC, builder);

        // One runtime both builds and drains the stream, so any task tied to
        // the build step is still alive while the results are collected.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let (id, s) = runtime.block_on(indexer.build(issue_doc("with_embedding")));
        assert_eq!(id, "with_embedding");

        let res = runtime.block_on(s.buffer_unordered(max_in_flight()).collect::<Vec<_>>());

        // Two items are expected: the document itself and its single chunk.
        assert_eq!(res.len(), 2);
        // Order-independent check (see `buffer_unordered`): both the document
        // and the chunk must have been built successfully.
        assert!(res.iter().all(|item| matches!(item, Ok(Some(_)))));

        env::remove_var("TABBY_ROOT");
    }
}
26 changes: 26 additions & 0 deletions crates/tabby-inference/src/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,29 @@ use async_trait::async_trait;
pub trait Embedding: Sync + Send {
async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>>;
}

pub mod tests {
    use super::*;
    use anyhow::Result;

    /// Test double for [`Embedding`] that answers every prompt with the same
    /// pre-configured vector.
    pub struct MockEmbedding {
        canned: Vec<f32>,
    }

    impl MockEmbedding {
        /// Creates a mock whose successful `embed` calls return a copy of `result`.
        pub fn new(result: Vec<f32>) -> Self {
            MockEmbedding { canned: result }
        }
    }

    #[async_trait]
    impl Embedding for MockEmbedding {
        /// Returns a clone of the configured vector, unless `prompt` starts
        /// with `"error"`, in which case an error carrying the prompt text as
        /// its message is returned.
        async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
            // Prompts prefixed with "error" let callers exercise failure paths.
            if prompt.starts_with("error") {
                return Err(anyhow::anyhow!(prompt.to_owned()));
            }

            Ok(self.canned.clone())
        }
    }
}
2 changes: 1 addition & 1 deletion crates/tabby-inference/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod embedding;
pub use chat::{ChatCompletionStream, ExtendedOpenAIConfig};
pub use code::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
pub use completion::{CompletionOptions, CompletionOptionsBuilder, CompletionStream};
pub use embedding::Embedding;
pub use embedding::{tests::MockEmbedding, Embedding};

fn default_seed() -> u64 {
std::time::SystemTime::now()
Expand Down

0 comments on commit d6374c0

Please sign in to comment.