Skip to content

Commit

Permalink
Vec documents updates (#560)
Browse files Browse the repository at this point in the history
* change score_id to id

* Add access to context for a document
---------

Co-authored-by: travolin <[email protected]>
  • Loading branch information
travolin and travolin authored Nov 26, 2024
1 parent f3f0762 commit 69e5d7e
Showing 1 changed file with 77 additions and 2 deletions.
79 changes: 77 additions & 2 deletions crates/entities/src/models/vec_documents.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use sea_orm::{ConnectionTrait, DbErr, ExecResult, FromQueryResult, Statement};
use sea_orm::{
ColumnTrait, ConnectionTrait, DbErr, EntityTrait, ExecResult, FromQueryResult, QueryFilter,
Statement,
};

use super::vec_to_indexed;

pub async fn insert_embedding<C>(db: &C, id: i64, embedding: &[f32]) -> Result<ExecResult, DbErr>
where
Expand Down Expand Up @@ -270,7 +275,7 @@ where
db.get_database_backend(),
r#"
SELECT
indexed_document.id AS score_id,
indexed_document.id AS id,
vd.distance,
indexed_document.doc_id,
indexed_document.url,
Expand Down Expand Up @@ -298,3 +303,73 @@ where
err
})
}

pub async fn get_context_for_doc<C>(
db: &C,
document_id: i64,
embedding: &[f32],
) -> Result<Vec<DocDistance>, DbErr>
where
C: ConnectionTrait,
{
let embedding_string = serde_json::to_string(embedding)
.map_err(|err| {
log::error!("Error {:?}", err);
err
})
.unwrap();

let indexed = vec_to_indexed::Entity::find()
.filter(vec_to_indexed::Column::IndexedId.eq(document_id))
.all(db)
.await;

let indexed_vectors = indexed
.map(|indexed| {
indexed
.iter()
.map(|val| val.id.to_string())
.collect::<Vec<String>>()
.join(",")
})
.unwrap_or("".to_string());

let query = format!(
r#"
SELECT
indexed_document.id AS id,
vec_distance_L2(vd.embedding, $1) as distance,
indexed_document.doc_id,
vti.segment_start,
vti.segment_end,
indexed_document.url
FROM
vec_documents vd
left JOIN
vec_to_indexed vti
ON vd.rowid = vti.id
left JOIN indexed_document
ON vti.indexed_id = indexed_document.id
WHERE vd.rowid in ({}) ORDER BY vd.distance ASC
"#,
indexed_vectors
);

let statement = Statement::from_sql_and_values(
db.get_database_backend(),
query,
vec![embedding_string.into()],
);

DocDistance::find_by_statement(statement)
.all(db)
.await
.map_err(|err| {
log::error!("Error is {:?}", err);
err
})
.map(|mut segments| {
segments.sort_by(|a, b| a.distance.total_cmp(&b.distance));
segments
})
}

0 comments on commit 69e5d7e

Please sign in to comment.