diff --git a/lantern_extras/src/bm25_agg.rs b/lantern_extras/src/bm25_agg.rs index e053db1c..e179a91c 100644 --- a/lantern_extras/src/bm25_agg.rs +++ b/lantern_extras/src/bm25_agg.rs @@ -582,13 +582,6 @@ mod tests { use itertools::Itertools; use pgrx::prelude::*; - // TODO: turn this into a test: - // select term, cardinality(array_agg(DISTINCT term_freq)) as uniq_term_freq_must_be_1, - // (array_agg(DISTINCT term_freq))[1] any_term_freq_must_be_equal_to_next_col , - // SUM(cardinality(doc_ids)) from corpus_bm25 where term_freq != cardinality(doc_ids) - // GROUP BY term ORDER BY term; - // TODO: check that doc_ids contain UNIQUE doc ids - #[pg_test] fn test_bm25_agg_and_api() -> spi::Result<()> { // Step 1: Create the documents table @@ -711,6 +704,37 @@ mod tests { // the added row had one new word, so the bm25 table must grow by 1 assert_eq!(new_len, prev_len + 1); + // Test some axuilary _bm25 table invariants + + let doc_ids_invariant = Spi::get_one::( + "SELECT NOT EXISTS (SELECT 1 FROM documents_bm25 WHERE NOT cardinality(doc_ids) = cardinality((SELECT array_agg(DISTINCT e) FROM unnest( doc_ids) e)))", + )?; + assert!( + doc_ids_invariant.unwrap(), + "doc_ids must contain unique doc ids" + ); + + // Test for term frequency invariants + let term_freq_invariant = Spi::get_one::( + "SELECT NOT EXISTS ( + SELECT 1 + FROM ( + SELECT term, cardinality(array_agg(DISTINCT term_freq)) AS uniq_term_freq_must_be_1, + (array_agg(DISTINCT term_freq))[1] AS any_term_freq_must_be_equal_to_next_col, + SUM(cardinality(doc_ids)) + FROM documents_bm25 + WHERE term_freq != cardinality(doc_ids) + GROUP BY term + ) AS subquery + WHERE uniq_term_freq_must_be_1 != 1 + );", + )?; + + // note: this invariant needs to be modified once insertions become allowed + assert!( + term_freq_invariant.unwrap(), + "Each term must have a unique term frequency equal to the cardinality of doc_ids" + ); // Test the bm25_score function // The bm25_score function calculates the BM25 score for the given query in the given // document