diff --git a/lantern_extras/src/bloom.rs b/lantern_extras/src/bloom.rs index cd753a19..1c15a22a 100644 --- a/lantern_extras/src/bloom.rs +++ b/lantern_extras/src/bloom.rs @@ -68,10 +68,9 @@ fn array_to_bloom_bigint(arr: Vec) -> Bloom { return array_to_bloom(arr); } -#[pg_extern(requires = [Bloom])] -fn elem_in_bloom(elem: i32, bloom: Bloom) -> bool { - let bloom: BloomFilter = bloom.into(); - bloom.contains(&elem) +#[pg_extern(immutable, parallel_safe, name = "array_to_bloom")] +fn array_to_bloom_text(arr: Vec) -> Bloom { + return array_to_bloom(arr); } extension_sql!( @@ -79,6 +78,7 @@ extension_sql!( CREATE CAST (smallint[] AS bloom) WITH FUNCTION array_to_bloom(smallint[]); CREATE CAST (integer[] AS bloom) WITH FUNCTION array_to_bloom(integer[]); CREATE CAST (bigint[] AS bloom) WITH FUNCTION array_to_bloom(bigint[]); + CREATE CAST (text[] AS bloom) WITH FUNCTION array_to_bloom(text[]); "#, name = "bloom_type_casts", requires = [ @@ -86,5 +86,18 @@ extension_sql!( array_to_bloom_smallint, array_to_bloom_integer, array_to_bloom_bigint, + array_to_bloom_text, ] ); + +#[pg_extern(immutable, parallel_safe, name = "elem_in_bloom", requires = [Bloom])] +fn elem_in_bloom_numeric(elem: i32, bloom: Bloom) -> bool { + let bloom: BloomFilter = bloom.into(); + bloom.contains(&elem) +} + +#[pg_extern(immutable, parallel_safe, name = "elem_in_bloom", requires = [Bloom])] +fn elem_in_bloom_text(elem: String, bloom: Bloom) -> bool { + let bloom: BloomFilter = bloom.into(); + bloom.contains(&elem) +} diff --git a/lantern_extras/src/bm25_api.rs b/lantern_extras/src/bm25_api.rs index 5abf65f5..57e7ff20 100644 --- a/lantern_extras/src/bm25_api.rs +++ b/lantern_extras/src/bm25_api.rs @@ -1,6 +1,6 @@ use pgrx::extension_sql_file; -extension_sql_file!("./bm25_api.sql", requires = [Bloom]); +extension_sql_file!("./bm25_api.sql", requires = [Bloom, "bloom_type_casts"]); #[cfg(any(test, feature = "pg_test"))] #[pgrx::pg_schema]