diff --git a/crates/entities/src/models/mod.rs b/crates/entities/src/models/mod.rs
index 244beff31..30ef368f2 100644
--- a/crates/entities/src/models/mod.rs
+++ b/crates/entities/src/models/mod.rs
@@ -16,6 +16,7 @@ pub mod resource_rule;
 pub mod schema;
 pub mod tag;
 pub mod vec_documents;
+pub mod vec_to_indexed;
 
 use shared::config::Config;
diff --git a/crates/entities/src/models/vec_documents.rs b/crates/entities/src/models/vec_documents.rs
index daf0ef55b..004e4bcd5 100644
--- a/crates/entities/src/models/vec_documents.rs
+++ b/crates/entities/src/models/vec_documents.rs
@@ -73,6 +73,24 @@ where
     db.execute(statement).await
 }
 
+pub async fn delete_embedding_by_ids<C>(db: &C, ids: &[i64]) -> Result<ExecResult, DbErr>
+where
+    C: ConnectionTrait,
+{
+    let st = format!(
+        r#"
+        delete from vec_documents where rowid in ({})
+        "#,
+        ids.iter()
+            .map(|id| format!("{}", id))
+            .collect::<Vec<String>>()
+            .join(",")
+    );
+    let statement = Statement::from_string(db.get_database_backend(), st);
+
+    db.execute(statement).await
+}
+
 pub async fn delete_embeddings_by_url<C>(db: &C, urls: &[String]) -> Result<ExecResult, DbErr>
 where
     C: ConnectionTrait,
@@ -127,10 +145,23 @@ where
         Statement::from_sql_and_values(
             db.get_database_backend(),
             r#"
-            SELECT vec_documents.rowid as id, vec_documents.distance, indexed_document.doc_id FROM vec_documents
-            left join indexed_document on indexed_document.id = vec_documents.rowid
+            WITH RankedScores AS (
+                SELECT
+                    indexed_document.id AS score_id,
+                    vd.distance,
+                    indexed_document.doc_id,
+                    ROW_NUMBER() OVER (PARTITION BY indexed_document.doc_id ORDER BY vd.distance ASC) AS rank
+                FROM
+                    vec_documents vd
+                left JOIN
+                    vec_to_indexed vti
+                    ON vd.rowid = vti.id
+                left JOIN indexed_document
+                    ON vti.indexed_id = indexed_document.id
             left join document_tag on document_tag.indexed_document_id = indexed_document.id
-            WHERE document_tag.id in $1 AND vec_documents.embedding MATCH $2 AND k = 10 ORDER BY vec_documents.distance ASC limit 20;
+            WHERE document_tag.id in $1 AND vd.embedding MATCH $2 AND k = 25 ORDER BY vd.distance ASC
+            )
+            SELECT score_id as id, distance, doc_id FROM RankedScores WHERE rank = 1 ORDER BY distance ASC limit 10;
             "#,
             vec![lens_ids.to_owned().into(), embedding_string.into()],
         )
@@ -138,9 +169,22 @@ where
         Statement::from_sql_and_values(
             db.get_database_backend(),
             r#"
-            SELECT vec_documents.rowid as id, vec_documents.distance, indexed_document.doc_id FROM vec_documents
-            left join indexed_document on indexed_document.id = vec_documents.rowid
-            WHERE vec_documents.embedding MATCH $1 AND k = 10 ORDER BY vec_documents.distance ASC limit 20;
+            WITH RankedScores AS (
+                SELECT
+                    indexed_document.id AS score_id,
+                    vd.distance,
+                    indexed_document.doc_id,
+                    ROW_NUMBER() OVER (PARTITION BY indexed_document.doc_id ORDER BY vd.distance ASC) AS rank
+                FROM
+                    vec_documents vd
+                left JOIN
+                    vec_to_indexed vti
+                    ON vd.rowid = vti.id
+                left JOIN indexed_document
+                    ON vti.indexed_id = indexed_document.id
+            WHERE vd.embedding MATCH $1 AND k = 25 ORDER BY vd.distance ASC
+            )
+            SELECT score_id as id, distance, doc_id FROM RankedScores WHERE rank = 1 ORDER BY distance ASC limit 10;
             "#,
             vec![embedding_string.into()],
         )
diff --git a/crates/entities/src/models/vec_to_indexed.rs b/crates/entities/src/models/vec_to_indexed.rs
new file mode 100644
index 000000000..fb7809a8a
--- /dev/null
+++ b/crates/entities/src/models/vec_to_indexed.rs
@@ -0,0 +1,96 @@
+use sea_orm::{entity::prelude::*, InsertResult, Set};
+use serde::Serialize;
+
+use super::{indexed_document, vec_documents};
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Eq)]
+#[sea_orm(table_name = "vec_to_indexed")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: i64,
+    pub indexed_id: i64,
+    /// When this mapping was first created.
+    pub created_at: DateTimeUtc,
+    /// When this mapping was last updated.
+    pub updated_at: DateTimeUtc,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter)]
+pub enum Relation {
+    IndexedId,
+}
+
+impl RelationTrait for Relation {
+    fn def(&self) -> RelationDef {
+        match self {
+            Self::IndexedId => Entity::belongs_to(super::indexed_document::Entity)
+                .from(Column::IndexedId)
+                .to(super::indexed_document::Column::Id)
+                .into(),
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl ActiveModelBehavior for ActiveModel {
+    fn new() -> Self {
+        Self {
+            created_at: Set(chrono::Utc::now()),
+            updated_at: Set(chrono::Utc::now()),
+            ..ActiveModelTrait::default()
+        }
+    }
+
+    // Triggered before insert / update
+    async fn before_save<C>(mut self, _db: &C, _insert: bool) -> Result<Self, DbErr>
+    where
+        C: ConnectionTrait,
+    {
+        Ok(self)
+    }
+}
+
+pub async fn insert_embedding_mapping(
+    db: &DatabaseConnection,
+    indexed_id: i64,
+) -> Result<InsertResult<ActiveModel>, DbErr> {
+    let mut active_model = ActiveModel::new();
+    active_model.indexed_id = Set(indexed_id);
+
+    Entity::insert(active_model).exec(db).await
+}
+
+pub async fn delete_all_for_document(
+    db: &DatabaseConnection,
+    indexed_id: i64,
+) -> Result<(), DbErr> {
+    let documents = Entity::find()
+        .filter(Column::IndexedId.eq(indexed_id))
+        .all(db)
+        .await?;
+
+    if !documents.is_empty() {
+        let ids = documents.iter().map(|val| val.id).collect::<Vec<i64>>();
+        let _ = vec_documents::delete_embedding_by_ids(db, &ids).await?;
+
+        let _ = Entity::delete_many()
+            .filter(Column::Id.is_in(ids))
+            .exec(db)
+            .await;
+        Ok(())
+    } else {
+        Ok(())
+    }
+}
+
+pub async fn delete_all_by_urls(db: &DatabaseConnection, urls: &[String]) -> Result<(), DbErr> {
+    let documents = indexed_document::Entity::find()
+        .filter(indexed_document::Column::Url.is_in(urls))
+        .all(db)
+        .await?;
+
+    for doc in documents {
+        delete_all_for_document(db, doc.id).await?;
+    }
+    Ok(())
+}
diff --git a/crates/migrations/src/lib.rs b/crates/migrations/src/lib.rs
index e5c10057f..04bc3aa63 100644
--- a/crates/migrations/src/lib.rs
+++ b/crates/migrations/src/lib.rs
@@ -32,6 +32,7 @@ mod m20230220_000001_remove_legacy_plugins;
 mod m20230315_000001_migrate_search_schema;
 mod m20241029_000001_add_vector;
 mod m20241105_000001_add_embeddings_table;
+mod m20241115_000001_embedding_to_indexed_document;
 mod utils;
 
 pub struct Migrator;
@@ -69,6 +70,7 @@ impl MigratorTrait for Migrator {
             Box::new(m20230315_000001_migrate_search_schema::Migration),
             Box::new(m20241029_000001_add_vector::Migration),
             Box::new(m20241105_000001_add_embeddings_table::Migration),
+            Box::new(m20241115_000001_embedding_to_indexed_document::Migration),
         ]
     }
 }
diff --git a/crates/migrations/src/m20241115_000001_embedding_to_indexed_document.rs b/crates/migrations/src/m20241115_000001_embedding_to_indexed_document.rs
new file mode 100644
index 000000000..b112b3cfc
--- /dev/null
+++ b/crates/migrations/src/m20241115_000001_embedding_to_indexed_document.rs
@@ -0,0 +1,71 @@
+use sea_orm_migration::prelude::*;
+
+#[derive(DeriveMigrationName)]
+pub struct Migration;
+
+#[derive(Iden)]
+enum VecToIndexed {
+    #[iden = "vec_to_indexed"]
+    Table,
+    Id,
+    IndexedId,
+    CreatedAt,
+    UpdatedAt,
+}
+
+#[derive(Iden)]
+enum IndexedDocument {
+    #[iden = "indexed_document"]
+    Table,
+    Id,
+}
+
+#[async_trait::async_trait]
+impl MigrationTrait for Migration {
+    async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
+        manager
+            .create_table(
+                Table::create()
+                    .table(VecToIndexed::Table)
+                    .if_not_exists()
+                    .foreign_key(
+                        ForeignKey::create()
+                            .name("fk-vec_to_indexed-indexed_document")
+                            .from(VecToIndexed::Table, VecToIndexed::IndexedId)
+                            .to(IndexedDocument::Table, IndexedDocument::Id)
+                            .on_delete(ForeignKeyAction::Cascade)
+                            .on_update(ForeignKeyAction::NoAction),
+                    )
+                    .col(
+                        ColumnDef::new(VecToIndexed::Id)
+                            .big_integer()
+                            .not_null()
+                            .auto_increment()
+                            .primary_key(),
+                    )
+                    .col(
+                        ColumnDef::new(VecToIndexed::IndexedId)
+                            .big_integer()
+                            .not_null(),
+                    )
+                    .col(
+                        ColumnDef::new(VecToIndexed::CreatedAt)
+                            .timestamp_with_time_zone()
+                            .not_null(),
+                    )
+                    .col(
+                        ColumnDef::new(VecToIndexed::UpdatedAt)
+                            .timestamp_with_time_zone()
+                            .not_null(),
+                    )
+                    .to_owned(),
+            )
+            .await?;
+
+        Ok(())
+    }
+
+    async fn down(&self, _manager: &SchemaManager) -> Result<(), DbErr> {
+        Ok(())
+    }
+}
diff --git a/crates/spyglass-model-interface/src/embedding_api.rs b/crates/spyglass-model-interface/src/embedding_api.rs
index 1790713af..0d9d2aa79 100644
--- a/crates/spyglass-model-interface/src/embedding_api.rs
+++ b/crates/spyglass-model-interface/src/embedding_api.rs
@@ -1,6 +1,6 @@
 use std::{path::PathBuf, sync::Arc, time::Instant};
 
-use tokenizers::Tokenizer;
+use tokenizers::{Encoding, Tokenizer};
 
 use crate::{batch, load_tokenizer, Backend, CandleBackend, Embedding, ModelType, Pool};
 
@@ -36,27 +36,71 @@ impl EmbeddingApi {
         &self,
         content: &str,
         content_type: EmbeddingContentType,
-    ) -> anyhow::Result<Vec<f32>> {
+    ) -> anyhow::Result<Vec<Vec<f32>>> {
         // TODO need to properly segment the data
         let doc_content = match content_type {
             EmbeddingContentType::Document => {
-                format!("search_document: {}", content)
+                format!("search_document: {}", content.trim())
             }
             EmbeddingContentType::Query => {
-                format!("search_query: {}", content)
+                format!("search_query: {}", content.trim())
             }
         };
 
-        let mut tokens = self
+        let tokens = self
            .tokenizer
            .encode(doc_content, false)
            .map_err(|err| anyhow::format_err!("Error tokenizing {:?}", err))?;
         let token_length = tokens.len();
+        let mut content_chunks = Vec::new();
         if token_length > MAX_TOKENS {
-            tokens.truncate(MAX_TOKENS, 1, tokenizers::TruncationDirection::Right);
+            let segment_count = token_length.div_ceil(MAX_TOKENS);
+            let char_per_segment = content.len().div_euclid(segment_count);
+
+            let chunks: Vec<String> = content
+                .trim()
+                .chars()
+                .collect::<Vec<char>>()
+                .chunks(char_per_segment)
+                .map(|chunk| chunk.iter().collect::<String>())
+                .collect();
+
+            log::debug!(
+                "Splitting text into chunks of {} chars long",
+                char_per_segment
+            );
+            for chunk in chunks {
+                let doc_content = match content_type {
+                    EmbeddingContentType::Document => {
+                        format!("search_document: {}", chunk)
+                    }
+                    EmbeddingContentType::Query => {
+                        format!("search_query: {}", chunk)
+                    }
+                };
+                let tokens = self
+                    .tokenizer
+                    .encode(doc_content, false)
+                    .map_err(|err| anyhow::format_err!("Error tokenizing {:?}", err))?;
+                log::trace!("Chunk was {} tokens long", tokens.len());
+                content_chunks.push(tokens);
+            }
+        } else {
+            content_chunks.push(tokens);
+        }
+
+        let mut embeddings = Vec::new();
+        for chunk in content_chunks {
+            let embedding = self.embed_tokens(chunk.to_owned())?;
+            embeddings.push(embedding);
         }
 
-        let input_batch = batch(vec![tokens], [0].to_vec(), vec![]);
+        Ok(embeddings)
+    }
+
+    pub fn embed_tokens(&self, tokens: Encoding) -> anyhow::Result<Vec<f32>> {
+        let token_length = tokens.len();
+        let input_batch = batch(vec![tokens], [0].to_vec(), vec![]);
         let start = Instant::now();
         match self.backend.embed(input_batch) {
@@ -66,6 +110,7 @@ impl EmbeddingApi {
                     token_length,
                     start.elapsed().as_millis()
                 );
+
                 if let Some(Embedding::Pooled(embedding)) = embed.get(&0) {
                     Ok(embedding.to_owned())
                 } else {
diff --git a/crates/spyglass/src/api/handler/search.rs b/crates/spyglass/src/api/handler/search.rs
index 4c3568aac..597734ecd 100644
--- a/crates/spyglass/src/api/handler/search.rs
+++ b/crates/spyglass/src/api/handler/search.rs
@@ -63,8 +63,11 @@ pub async fn search_docs(
     }
 
     if let Some(embedding_api) = state.embedding_api.load_full().as_ref() {
-        match embedding_api.embed(&query, EmbeddingContentType::Query) {
-            Ok(embedding) => {
+        match embedding_api
+            .embed(&query, EmbeddingContentType::Query)
+            .map(|embedding| embedding.first().map(|val| val.to_owned()))
+        {
+            Ok(Some(embedding)) => {
                 let mut distances =
                     vec_documents::get_document_distance(&state.db, &lens_ids, &embedding).await;
 
@@ -96,6 +99,9 @@ pub async fn search_docs(
                     }
                 }
             }
+            Ok(None) => {
+                log::error!("No embedding could be generated");
+            }
             Err(err) => {
                 log::error!("Error embedding query {:?}", err);
             }
diff --git a/crates/spyglass/src/documents/embeddings.rs b/crates/spyglass/src/documents/embeddings.rs
index 749d43952..7d7fac4ba 100644
--- a/crates/spyglass/src/documents/embeddings.rs
+++ b/crates/spyglass/src/documents/embeddings.rs
@@ -1,5 +1,5 @@
 use entities::{
-    models::{embedding_queue, vec_documents},
+    models::{embedding_queue, vec_documents, vec_to_indexed},
     sea_orm::EntityTrait,
 };
 use spyglass_model_interface::embedding_api::EmbeddingContentType;
@@ -15,37 +15,37 @@ pub async fn processing_embedding(state: AppState, job_id: i64) {
         .one(&state.db)
         .await
     {
-        Ok(Some(job)) => {
-            match job.content {
-                Some(content) => {
-                    let embedding = if let Some(api) = state.embedding_api.load_full().as_ref() {
-                        api.embed(&content, EmbeddingContentType::Document)
-                    } else {
-                        Err(anyhow::format_err!(
-                            "Embedding Model is not properly configured"
-                        ))
-                    };
-                    match embedding {
-                        Ok(embedding) => {
-                            match vec_documents::insert_embedding(
+        Ok(Some(job)) => match job.content {
+            Some(content) => {
+                let embeddings = if let Some(api) = state.embedding_api.load_full().as_ref() {
+                    api.embed(&content, EmbeddingContentType::Document)
+                } else {
+                    Err(anyhow::format_err!(
+                        "Embedding Model is not properly configured"
+                    ))
+                };
+                match embeddings {
+                    Ok(embeddings) => {
+                        if let Err(error) = vec_to_indexed::delete_all_for_document(
+                            &state.db,
+                            job.indexed_document_id,
+                        )
+                        .await
+                        {
+                            log::error!("Error deleting document vectors {:?}", error);
+                        }
+
+                        for embedding in embeddings {
+                            match vec_to_indexed::insert_embedding_mapping(
                                 &state.db,
                                 job.indexed_document_id,
-                                &embedding,
                             )
                             .await
                             {
-                                Ok(_) => {
-                                    let _ = embedding_queue::mark_done(&state.db, job_id).await;
-                                }
-                                Err(insert_error) => {
-                                    // The virtual table does not support on conflict so we try to
-                                    // insert first then update.
-                                    match vec_documents::update_embedding(
-                                        &state.db,
-                                        job.indexed_document_id,
-                                        &embedding,
-                                    )
-                                    .await
+                                Ok(insert_result) => {
+                                    let id: i64 = insert_result.last_insert_id;
+                                    match vec_documents::insert_embedding(&state.db, id, &embedding)
+                                        .await
                                     {
                                         Ok(_) => {
                                             let _ =
@@ -56,39 +56,42 @@ pub async fn processing_embedding(state: AppState, job_id: i64) {
                                                 &state.db,
                                                 job_id,
                                                 Some(format!(
-                                                    "Error storing embedding for {}. Error {:?} and {:?}",
-                                                    job.document_id, insert_error, error
+                                                    "Error storing embedding for {}. Error {:?}",
+                                                    job.document_id, error
                                                 )),
                                             )
                                             .await;
                                         }
                                     }
                                 }
+                                Err(error) => {
+                                    log::error!("Error inserting mapping {:?}", error);
+                                }
                             }
                         }
-                        Err(error) => {
-                            let _ = embedding_queue::mark_failed(
-                                &state.db,
-                                job_id,
-                                Some(format!(
-                                    "Error generating embedding for {}. Error {:?}",
-                                    job.document_id, error
-                                )),
-                            )
-                            .await;
-                        }
                     }
-                }
-                None => {
-                    let _ = embedding_queue::mark_failed(
-                        &state.db,
-                        job_id,
-                        Some(format!("No content found for document {}", job.document_id)),
-                    )
-                    .await;
+                    Err(error) => {
+                        let _ = embedding_queue::mark_failed(
+                            &state.db,
+                            job_id,
+                            Some(format!(
+                                "Error generating embedding for {}. Error {:?}",
+                                job.document_id, error
+                            )),
+                        )
+                        .await;
+                    }
                 }
             }
-        }
+            None => {
+                let _ = embedding_queue::mark_failed(
+                    &state.db,
+                    job_id,
+                    Some(format!("No content found for document {}", job.document_id)),
+                )
+                .await;
+            }
+        },
         Ok(None) => {
             let _ = embedding_queue::mark_failed(
                 &state.db,
diff --git a/crates/spyglass/src/documents/mod.rs b/crates/spyglass/src/documents/mod.rs
index a594099be..a92efc47b 100644
--- a/crates/spyglass/src/documents/mod.rs
+++ b/crates/spyglass/src/documents/mod.rs
@@ -4,7 +4,7 @@ use entities::{
         crawl_queue, embedding_queue,
         indexed_document::{self, find_by_doc_ids},
         tag::{self, TagPair},
-        vec_documents,
+        vec_to_indexed,
     },
     sea_orm::{ActiveModelTrait, DatabaseConnection, TryIntoModel},
     BATCH_SIZE,
@@ -87,7 +87,7 @@ pub async fn delete_documents_by_uri(state: &AppState, uri: Vec<String>) {
     }
 
         // delete their embeddings from the database
-        if let Err(error) = vec_documents::delete_embeddings_by_url(&state.db, chunk).await {
+        if let Err(error) = vec_to_indexed::delete_all_by_urls(&state.db, chunk).await {
            log::warn!("Error deleting document embeddings {:?}", error);
        }
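
Taken together, the patch makes embed() return one vector per content chunk, keys each vector to its document through the new vec_to_indexed table, and stores the vector itself in vec_documents under the mapping row's id. Below is a minimal sketch of that flow, assuming a DatabaseConnection and a configured EmbeddingApi are already in hand; the helper name reindex_document is made up for illustration and is not part of the patch.

use entities::models::{vec_documents, vec_to_indexed};
use sea_orm::DatabaseConnection;
use spyglass_model_interface::embedding_api::{EmbeddingApi, EmbeddingContentType};

// Sketch only: mirrors the flow in processing_embedding() above.
// `db`, `api`, `indexed_document_id`, and `content` are assumed to be supplied by the caller.
async fn reindex_document(
    db: &DatabaseConnection,
    api: &EmbeddingApi,
    indexed_document_id: i64,
    content: &str,
) -> anyhow::Result<()> {
    // Clear any previous mapping rows and their vectors for this document.
    vec_to_indexed::delete_all_for_document(db, indexed_document_id).await?;

    // embed() yields one embedding per chunk when the content exceeds MAX_TOKENS.
    for embedding in api.embed(content, EmbeddingContentType::Document)? {
        // Create the mapping row first; its id becomes the rowid used in vec_documents.
        let mapping = vec_to_indexed::insert_embedding_mapping(db, indexed_document_id).await?;
        vec_documents::insert_embedding(db, mapping.last_insert_id, &embedding).await?;
    }
    Ok(())
}

Because a long document now produces several vectors, a single doc_id can match a query more than once; that is what the RankedScores CTE in vec_documents.rs addresses by keeping only the best-scoring chunk per doc_id before the final LIMIT.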