
Commit

Merge branch 'main' into release
a5huynh committed Nov 18, 2024
2 parents ec72561 + 4cea662 commit 451e194
Showing 9 changed files with 334 additions and 66 deletions.
1 change: 1 addition & 0 deletions crates/entities/src/models/mod.rs
@@ -16,6 +16,7 @@ pub mod resource_rule;
pub mod schema;
pub mod tag;
pub mod vec_documents;
pub mod vec_to_indexed;

use shared::config::Config;

56 changes: 50 additions & 6 deletions crates/entities/src/models/vec_documents.rs
@@ -73,6 +73,24 @@ where
db.execute(statement).await
}

pub async fn delete_embedding_by_ids<C>(db: &C, ids: &[i64]) -> Result<ExecResult, DbErr>
where
C: ConnectionTrait,
{
let st = format!(
r#"
delete from vec_documents where rowid in ({})
"#,
ids.iter()
.map(|id| format!("{}", id))
.collect::<Vec<String>>()
.join(",")
);
let statement = Statement::from_string(db.get_database_backend(), st);

db.execute(statement).await
}

pub async fn delete_embeddings_by_url<C>(db: &C, urls: &[String]) -> Result<ExecResult, DbErr>
where
C: ConnectionTrait,
@@ -127,20 +145,46 @@ where
Statement::from_sql_and_values(
db.get_database_backend(),
r#"
SELECT vec_documents.rowid as id, vec_documents.distance, indexed_document.doc_id FROM vec_documents
left join indexed_document on indexed_document.id = vec_documents.rowid
WITH RankedScores AS (
SELECT
indexed_document.id AS score_id,
vd.distance,
indexed_document.doc_id,
ROW_NUMBER() OVER (PARTITION BY indexed_document.doc_id ORDER BY vd.distance ASC) AS rank
FROM
vec_documents vd
left JOIN
vec_to_indexed vti
ON vd.rowid = vti.id
left JOIN indexed_document
ON vti.indexed_id = indexed_document.id
left join document_tag on document_tag.indexed_document_id = indexed_document.id
WHERE document_tag.id in $1 AND vec_documents.embedding MATCH $2 AND k = 10 ORDER BY vec_documents.distance ASC limit 20;
WHERE document_tag.id in $1 AND vd.embedding MATCH $2 AND k = 25 ORDER BY vd.distance ASC
)
SELECT score_id as id, distance, doc_id FROM RankedScores WHERE rank = 1 ORDER BY distance ASC limit 10;
"#,
vec![lens_ids.to_owned().into(), embedding_string.into()],
)
} else {
Statement::from_sql_and_values(
db.get_database_backend(),
r#"
SELECT vec_documents.rowid as id, vec_documents.distance, indexed_document.doc_id FROM vec_documents
left join indexed_document on indexed_document.id = vec_documents.rowid
WHERE vec_documents.embedding MATCH $1 AND k = 10 ORDER BY vec_documents.distance ASC limit 20;
WITH RankedScores AS (
SELECT
indexed_document.id AS score_id,
vd.distance,
indexed_document.doc_id,
ROW_NUMBER() OVER (PARTITION BY indexed_document.doc_id ORDER BY vd.distance ASC) AS rank
FROM
vec_documents vd
left JOIN
vec_to_indexed vti
ON vd.rowid = vti.id
left JOIN indexed_document
ON vti.indexed_id = indexed_document.id
WHERE vd.embedding MATCH $1 AND k = 25 ORDER BY vd.distance ASC
)
SELECT score_id as id, distance, doc_id FROM RankedScores WHERE rank = 1 ORDER BY distance ASC limit 10;
"#,
vec![embedding_string.into()],
)
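
A note on the reworked query: because one indexed document can now map to several embedding rows (one per content chunk), the RankedScores CTE keeps only the closest chunk per doc_id before returning the top ten documents. A minimal sketch of the same dedup step in Rust, assuming scores arrive as (doc_id, distance) pairs; the names are illustrative and not part of this commit:

use std::collections::HashMap;

/// Keep only the best (smallest) distance per document, mirroring the
/// ROW_NUMBER() ... WHERE rank = 1 filter in the SQL above.
fn dedupe_by_doc(scores: Vec<(String, f32)>) -> Vec<(String, f32)> {
    let mut best: HashMap<String, f32> = HashMap::new();
    for (doc_id, distance) in scores {
        best.entry(doc_id)
            .and_modify(|d| *d = (*d).min(distance))
            .or_insert(distance);
    }
    let mut ranked: Vec<(String, f32)> = best.into_iter().collect();
    ranked.sort_by(|a, b| a.1.total_cmp(&b.1));
    ranked.truncate(10);
    ranked
}
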
96 changes: 96 additions & 0 deletions crates/entities/src/models/vec_to_indexed.rs
@@ -0,0 +1,96 @@
use sea_orm::{entity::prelude::*, InsertResult, Set};
use serde::Serialize;

use super::{indexed_document, vec_documents};

#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Eq)]
#[sea_orm(table_name = "vec_to_indexed")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i64,
pub indexed_id: i64,
/// When this mapping was first created.
pub created_at: DateTimeUtc,
/// When this mapping was last updated.
pub updated_at: DateTimeUtc,
}

#[derive(Copy, Clone, Debug, EnumIter)]
pub enum Relation {
IndexedId,
}

impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
Self::IndexedId => Entity::belongs_to(super::indexed_document::Entity)
.from(Column::IndexedId)
.to(super::indexed_document::Column::Id)
.into(),
}
}
}

#[async_trait::async_trait]
impl ActiveModelBehavior for ActiveModel {
fn new() -> Self {
Self {
created_at: Set(chrono::Utc::now()),
updated_at: Set(chrono::Utc::now()),
..ActiveModelTrait::default()
}
}

// Triggered before insert / update
async fn before_save<C>(mut self, _db: &C, _insert: bool) -> Result<Self, DbErr>
where
C: ConnectionTrait,
{
Ok(self)
}
}

pub async fn insert_embedding_mapping(
db: &DatabaseConnection,
indexed_id: i64,
) -> Result<InsertResult<ActiveModel>, DbErr> {
let mut active_model = ActiveModel::new();
active_model.indexed_id = Set(indexed_id);

Entity::insert(active_model).exec(db).await
}

pub async fn delete_all_for_document(
db: &DatabaseConnection,
indexed_id: i64,
) -> Result<(), DbErr> {
let documents = Entity::find()
.filter(Column::IndexedId.eq(indexed_id))
.all(db)
.await?;

if !documents.is_empty() {
let ids = documents.iter().map(|val| val.id).collect::<Vec<i64>>();
let _ = vec_documents::delete_embedding_by_ids(db, &ids).await?;

let _ = Entity::delete_many()
.filter(Column::Id.is_in(ids))
.exec(db)
.await;
Ok(())
} else {
Ok(())
}
}

pub async fn delete_all_by_urls(db: &DatabaseConnection, urls: &[String]) -> Result<(), DbErr> {
let documents = indexed_document::Entity::find()
.filter(indexed_document::Column::Url.is_in(urls))
.all(db)
.await?;

for doc in documents {
delete_all_for_document(db, doc.id).await?;
}
Ok(())
}
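
A hedged usage sketch of the new mapping helpers: each chunk embedding gets its own row in vec_documents, keyed by the id of a vec_to_indexed row, and delete_all_for_document clears both when a document is re-indexed. The crate path and the surrounding function are assumptions for illustration; the step that actually writes the vector into vec_documents is not shown in this diff:

use sea_orm::{DatabaseConnection, DbErr};
use entities::models::vec_to_indexed;

async fn index_chunks(
    db: &DatabaseConnection,
    indexed_id: i64,
    chunk_embeddings: &[Vec<f32>],
) -> Result<(), DbErr> {
    // Clear rows left over from a previous version of the document.
    vec_to_indexed::delete_all_for_document(db, indexed_id).await?;

    for embedding in chunk_embeddings {
        // One mapping row per chunk; its id is used as the vec_documents rowid
        // (delete_embedding_by_ids above deletes by those rowids).
        let mapping = vec_to_indexed::insert_embedding_mapping(db, indexed_id).await?;
        log::debug!(
            "storing {}-dim chunk embedding at rowid {}",
            embedding.len(),
            mapping.last_insert_id
        );
    }
    Ok(())
}
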
2 changes: 2 additions & 0 deletions crates/migrations/src/lib.rs
@@ -32,6 +32,7 @@ mod m20230220_000001_remove_legacy_plugins;
mod m20230315_000001_migrate_search_schema;
mod m20241029_000001_add_vector;
mod m20241105_000001_add_embeddings_table;
mod m20241115_000001_embedding_to_indexed_document;
mod utils;

pub struct Migrator;
@@ -69,6 +70,7 @@ impl MigratorTrait for Migrator {
Box::new(m20230315_000001_migrate_search_schema::Migration),
Box::new(m20241029_000001_add_vector::Migration),
Box::new(m20241105_000001_add_embeddings_table::Migration),
Box::new(m20241115_000001_embedding_to_indexed_document::Migration),
]
}
}
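
The new module only takes effect once the migrator runs. A minimal sketch of applying pending migrations with sea-orm-migration; the crate name and connection string are assumptions for illustration:

use migrations::Migrator;
use sea_orm::Database;
use sea_orm_migration::MigratorTrait;

async fn run_pending_migrations() -> Result<(), sea_orm::DbErr> {
    // Connection string is illustrative; Spyglass resolves its own database path.
    let db = Database::connect("sqlite://spyglass.db?mode=rwc").await?;
    // Applies every unapplied migration in order, including
    // m20241115_000001_embedding_to_indexed_document.
    Migrator::up(&db, None).await?;
    Ok(())
}
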
71 changes: 71 additions & 0 deletions crates/migrations/src/m20241115_000001_embedding_to_indexed_document.rs
@@ -0,0 +1,71 @@
use sea_orm_migration::prelude::*;

#[derive(DeriveMigrationName)]
pub struct Migration;

#[derive(Iden)]
enum VecToIndexed {
#[iden = "vec_to_indexed"]
Table,
Id,
IndexedId,
CreatedAt,
UpdatedAt,
}

#[derive(Iden)]
enum IndexedDocument {
#[iden = "indexed_document"]
Table,
Id,
}

#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.create_table(
Table::create()
.table(VecToIndexed::Table)
.if_not_exists()
.foreign_key(
ForeignKey::create()
.name("fk-vec_to_indexed-indexed_document")
.from(VecToIndexed::Table, VecToIndexed::IndexedId)
.to(IndexedDocument::Table, IndexedDocument::Id)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::NoAction),
)
.col(
ColumnDef::new(VecToIndexed::Id)
.big_integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(
ColumnDef::new(VecToIndexed::IndexedId)
.big_integer()
.not_null(),
)
.col(
ColumnDef::new(VecToIndexed::CreatedAt)
.timestamp_with_time_zone()
.not_null(),
)
.col(
ColumnDef::new(VecToIndexed::UpdatedAt)
.timestamp_with_time_zone()
.not_null(),
)
.to_owned(),
)
.await?;

Ok(())
}

async fn down(&self, _manager: &SchemaManager) -> Result<(), DbErr> {
Ok(())
}
}
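
For reference, a rough SQLite equivalent of the table created above, kept as a Rust constant; the actual DDL is generated by sea-orm-migration and may differ in detail:

// Illustrative only; not emitted by the migration itself.
const VEC_TO_INDEXED_DDL: &str = r#"
CREATE TABLE IF NOT EXISTS vec_to_indexed (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    indexed_id BIGINT NOT NULL,
    created_at TIMESTAMP NOT NULL,
    updated_at TIMESTAMP NOT NULL,
    FOREIGN KEY (indexed_id) REFERENCES indexed_document (id)
        ON DELETE CASCADE ON UPDATE NO ACTION
);
"#;
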
59 changes: 52 additions & 7 deletions crates/spyglass-model-interface/src/embedding_api.rs
@@ -1,6 +1,6 @@
use std::{path::PathBuf, sync::Arc, time::Instant};

use tokenizers::Tokenizer;
use tokenizers::{Encoding, Tokenizer};

use crate::{batch, load_tokenizer, Backend, CandleBackend, Embedding, ModelType, Pool};

Expand Down Expand Up @@ -36,27 +36,71 @@ impl EmbeddingApi {
&self,
content: &str,
content_type: EmbeddingContentType,
) -> anyhow::Result<Vec<f32>> {
) -> anyhow::Result<Vec<Vec<f32>>> {
// TODO need to properly segment the data
let doc_content = match content_type {
EmbeddingContentType::Document => {
format!("search_document: {}", content)
format!("search_document: {}", content.trim())
}
EmbeddingContentType::Query => {
format!("search_query: {}", content)
format!("search_query: {}", content.trim())
}
};

let mut tokens = self
let tokens = self
.tokenizer
.encode(doc_content, false)
.map_err(|err| anyhow::format_err!("Error tokenizing {:?}", err))?;
let token_length = tokens.len();
let mut content_chunks = Vec::new();
if token_length > MAX_TOKENS {
tokens.truncate(MAX_TOKENS, 1, tokenizers::TruncationDirection::Right);
let segment_count = token_length.div_ceil(MAX_TOKENS);
let char_per_segment = content.len().div_euclid(segment_count);

let chunks: Vec<String> = content
.trim()
.chars()
.collect::<Vec<char>>()
.chunks(char_per_segment)
.map(|chunk| chunk.iter().collect::<String>())
.collect();

log::debug!(
"Splitting text into chunks of {} chars long",
char_per_segment
);
for chunk in chunks {
let doc_content = match content_type {
EmbeddingContentType::Document => {
format!("search_document: {}", chunk)
}
EmbeddingContentType::Query => {
format!("search_query: {}", chunk)
}
};
let tokens = self
.tokenizer
.encode(doc_content, false)
.map_err(|err| anyhow::format_err!("Error tokenizing {:?}", err))?;
log::trace!("Chunk was {} tokens long", tokens.len());
content_chunks.push(tokens);
}
} else {
content_chunks.push(tokens);
}

let mut embeddings = Vec::new();
for chunk in content_chunks {
let embedding = self.embed_tokens(chunk.to_owned())?;
embeddings.push(embedding);
}
let input_batch = batch(vec![tokens], [0].to_vec(), vec![]);

Ok(embeddings)
}

pub fn embed_tokens(&self, tokens: Encoding) -> anyhow::Result<Vec<f32>> {
let token_length = tokens.len();
let input_batch = batch(vec![tokens], [0].to_vec(), vec![]);
let start = Instant::now();

match self.backend.embed(input_batch) {
@@ -66,6 +110,7 @@ impl EmbeddingApi {
token_length,
start.elapsed().as_millis()
);

if let Some(Embedding::Pooled(embedding)) = embed.get(&0) {
Ok(embedding.to_owned())
} else {
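
With this change, embed() returns one embedding per chunk instead of truncating long content at MAX_TOKENS. A hedged usage sketch; the module path and construction of the EmbeddingApi handle are assumed, not shown in this diff:

use spyglass_model_interface::embedding_api::{EmbeddingApi, EmbeddingContentType};

fn embed_document(embedding_api: &EmbeddingApi, article_text: &str) -> anyhow::Result<()> {
    // Short content yields a single entry; anything beyond MAX_TOKENS is split
    // into roughly equal character chunks, each embedded separately.
    let embeddings: Vec<Vec<f32>> =
        embedding_api.embed(article_text, EmbeddingContentType::Document)?;
    log::debug!("generated {} chunk embedding(s)", embeddings.len());
    Ok(())
}
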
10 changes: 8 additions & 2 deletions crates/spyglass/src/api/handler/search.rs
@@ -63,8 +63,11 @@ pub async fn search_docs(
}

if let Some(embedding_api) = state.embedding_api.load_full().as_ref() {
match embedding_api.embed(&query, EmbeddingContentType::Query) {
Ok(embedding) => {
match embedding_api
.embed(&query, EmbeddingContentType::Query)
.map(|embedding| embedding.first().map(|val| val.to_owned()))
{
Ok(Some(embedding)) => {
let mut distances =
vec_documents::get_document_distance(&state.db, &lens_ids, &embedding).await;

Expand Down Expand Up @@ -96,6 +99,9 @@ pub async fn search_docs(
}
}
}
Ok(None) => {
log::error!("No embedding could be generated");
}
Err(err) => {
log::error!("Error embedding query {:?}", err);
}
Expand Down