Skip to content

Commit

Permalink
Add bm25_score function and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ngalstyan4 committed Nov 6, 2024
1 parent 0fbeb95 commit 50ebe89
Showing 1 changed file with 125 additions and 1 deletion.
126 changes: 125 additions & 1 deletion lantern_extras/src/bm25_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,66 @@ CREATE TYPE bm25result AS (
const BM25RESULT_COMPOSITE_TYPE: &str = "bm25result";
type BM25ResultSQLType = Vec<Option<pgrx::composite_type!('static, BM25RESULT_COMPOSITE_TYPE)>>;

#[pg_extern(immutable, parallel_safe)]
fn bm25_score(table_fqn: String, document: String, query: String) -> f32 {
let document_stemmed =
Spi::get_one::<Vec<String>>(&format!("SELECT text_to_stem_array('{}')", document))
.expect("Failed to get stemmed document")
.expect("Stemmed document was NULL");
let table_fqn = table_fqn.to_string();
let (corpus_size, avg_doc_len) = match Spi::get_two::<i32, f32>(&format!("SELECT term_freq AS corpus_size, (doc_ids_len / 100.0)::real AS avg_doc_len FROM {}_bm25 WHERE term IS NULL;", table_fqn))
.expect("Failed to get corpus size") {
(Some(corpus_size), Some(avg_doc_len)) => (corpus_size as u64, avg_doc_len),
_ => panic!("Failed to get corpus size and avg doc len"),
};
let term_data: Vec<(String, f32)> = Spi::connect(|client| {
client
.select(
&format!(
"SELECT term, term_freq FROM {}_bm25 WHERE term = ANY(text_to_stem_array('{}'));",
table_fqn, query
),
None,
None,
)
.expect("Failed to select from _bm25 table")
.into_iter()
.map(|row| {
let term = row
.get::<String>(1)
.expect("Failed to get term from _bm25 table")
.expect("term in _bm25 table was NULL");

let term_freq =
row.get::<i32>(2)
.expect("Failed to get term_freq from _bm25 table")
.expect("term_freq in _bm25 table was NULL") as f32;
(term, term_freq)
})
.collect()
});
let mut bm25 = 0.0;
for (word, term_freq) in term_data {
let mut fq = document_stemmed.iter().filter(|&x| x == &word).count() as f32;
let mut doc_len = document_stemmed.len() as f32;

if term_freq as i32 > BM25_DEFAULT_APPROXIMATION_THRESHHOLD.get() {
fq = if fq > 1. { 1. } else { fq };
doc_len = avg_doc_len;
}
bm25 += calculate_bm25(
doc_len,
fq,
term_freq,
corpus_size as u64,
avg_doc_len,
BM25_DEFAULT_K1.get() as f32,
BM25_DEFAULT_B.get() as f32,
);
}
bm25
}

#[inline(always)]
fn calculate_bm25(
doc_len: f32,
Expand Down Expand Up @@ -519,6 +579,7 @@ impl Aggregate for bm25_agg_limit_bm25params {
#[cfg(any(test, feature = "pg_test"))]
#[pgrx::pg_schema]
mod tests {
use itertools::Itertools;
use pgrx::prelude::*;

// TODO: turn this into a test:
Expand All @@ -529,7 +590,7 @@ mod tests {
// TODO: check that doc_ids contain UNIQUE doc ids

#[pg_test]
fn test_bm25_agg() -> spi::Result<()> {
fn test_bm25_agg_and_api() -> spi::Result<()> {
// Step 1: Create the documents table
Spi::run(
"CREATE TEMP TABLE documents (
Expand Down Expand Up @@ -627,6 +688,69 @@ mod tests {
);
assert!(results.1.unwrap() > 0.0, "BM25 score must be positive.");

// insert more documents and check reindexing
let content = "pomegranate pomegranate pomegranate";

Spi::run_with_args(
"INSERT INTO documents (doc_id, content, stemmed_content) VALUES
(5, $1, text_to_stem_array($1::text));",
Some(vec![(PgBuiltInOids::TEXTOID.oid(), content.into_datum())]),
)?;

let prev_len = Spi::get_one::<i64>("SELECT COUNT(*) FROM documents_bm25;")?.unwrap();
Spi::run(
"SELECT create_bm25_table(
table_name => 'documents',
id_column => 'doc_id',
index_columns => ARRAY['stemmed_content'],
drop_if_exists => TRUE
);",
)?;

let new_len = Spi::get_one::<i64>("SELECT COUNT(*) FROM documents_bm25;")?.unwrap();
// the added row had one new word, so the bm25 table must grow by 1
assert_eq!(new_len, prev_len + 1);

// Test the bm25_score function
// The bm25_score function calculates the BM25 score for the given query in the given
// document
Spi::connect(|client| {
let q = client
.select(
"SELECT bm25_score('documents', 'apple banana kiwi', 'apple'),
bm25_score('documents', 'apple banana kiwi', 'kiwi'),
bm25_score('documents', 'nonexistent_document', 'nonexistent_document');",
None,
None,
)
.expect("Failed to select from _bm25 table")
.into_iter()
.next()
.expect("Failed to get result");

let (bm25_apple, bm25_kiwi, bm25_nonexistent) = (1..=3)
.map(|ind| {
q.get::<f32>(ind)
.expect("Failed to get bm25 score")
.expect("Bm25 is null")
})
.collect_tuple()
.unwrap();
info!(
"values: {}, {}, {}",
bm25_apple, bm25_kiwi, bm25_nonexistent
);
assert!(bm25_apple > 0.0, "BM25 score must be positive.");
assert!(
bm25_kiwi > bm25_apple,
"term kiwi is more rare, so should have higher score"
);
assert!(
bm25_nonexistent == 0.0,
" bm25 score of a term that is not in the corpus should be zero"
);
});

Ok(())
}
}

0 comments on commit 50ebe89

Please sign in to comment.