Skip to content

Commit

Permalink
Allow thresholding on vector and fulltext indexes for Hybrid retrievers
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Jan 3, 2025
1 parent 39fd4f7 commit 95dd2d9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 30 deletions.
30 changes: 8 additions & 22 deletions src/neo4j_graphrag/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,25 +122,29 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
f"CALL () {{ {VECTOR_INDEX_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
f"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index "
f"THEN (n.score / vector_index_max_score) ELSE 0 END AS score "
f"UNION "
f"{FULL_TEXT_SEARCH_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
f"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index "
f"THEN (n.score / ft_index_max_score) ELSE 0 END AS score }} "
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
)
else:
return (
f"CALL {{ {VECTOR_INDEX_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
f"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index "
f"THEN (n.score / vector_index_max_score) ELSE 0 END AS score "
f"UNION "
f"{FULL_TEXT_SEARCH_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
f"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index "
f"THEN (n.score / ft_index_max_score) ELSE 0 END AS score }} "
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
)

Expand Down Expand Up @@ -174,7 +178,6 @@ def _get_filtered_vector_query(
query_params["embedding_dimension"] = embedding_dimension
return f"{base_query} AND ({where_filters}) {vector_query}", query_params


def get_search_query(
search_type: SearchType,
return_properties: Optional[list[str]] = None,
Expand All @@ -185,23 +188,6 @@ def get_search_query(
filters: Optional[dict[str, Any]] = None,
neo4j_version_is_5_23_or_above: bool = False,
) -> tuple[str, dict[str, Any]]:
"""Build the search query, including pre-filtering if needed, and return clause.
Args
search_type: Search type we want to search for:
return_properties (list[str]): list of property names to return.
It can't be provided together with retrieval_query.
retrieval_query (str): the query to use to retrieve the search results
It can't be provided together with return_properties.
node_label (str): node label we want to search for
embedding_node_property (str): the name of the property holding the embeddings
embedding_dimension (int): the dimension of the embeddings
filters (dict[str, Any]): filters used to pre-filter the nodes before vector search
Returns:
tuple[str, dict[str, Any]]: query and parameters
"""
warnings.warn(
"The default returned 'id' field in the search results will be removed. Please switch to using 'elementId' instead.",
DeprecationWarning,
Expand Down
14 changes: 14 additions & 0 deletions src/neo4j_graphrag/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def get_search_results(
query_text: str,
query_vector: Optional[list[float]] = None,
top_k: int = 5,
threshold_vector_index: float = 0.0,
threshold_fulltext_index: float = 0.0,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -159,6 +161,8 @@ def get_search_results(
query_text (str): The text to get the closest neighbors of.
query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
threshold_vector_index (float, optional): The minimum normalized score from the vector index to include in the top k search.
threshold_fulltext_index (float, optional): The minimum normalized score from the fulltext index to include in the top k search.
Raises:
SearchValidationError: If validation of the input arguments fail.
Expand All @@ -180,6 +184,9 @@ def get_search_results(
parameters["vector_index_name"] = self.vector_index_name
parameters["fulltext_index_name"] = self.fulltext_index_name

parameters["threshold_vector_index"] = threshold_vector_index
parameters["threshold_fulltext_index"] = threshold_fulltext_index

if query_text and not query_vector:
if not self.embedder:
raise EmbeddingRequiredError(
Expand Down Expand Up @@ -296,6 +303,8 @@ def get_search_results(
query_vector: Optional[list[float]] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
threshold_vector_index: float = 0.0,
threshold_fulltext_index: float = 0.0,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -313,6 +322,8 @@ def get_search_results(
query_vector (Optional[list[float]]): The vector embeddings to get the closest neighbors of. Defaults to None.
top_k (int): The number of neighbors to return. Defaults to 5.
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
threshold_vector_index (float, optional): The minimum normalized score from the vector index to include in the top k search.
threshold_fulltext_index (float, optional): The minimum normalized score from the fulltext index to include in the top k search.
Raises:
SearchValidationError: If validation of the input arguments fail.
Expand All @@ -335,6 +346,9 @@ def get_search_results(
parameters["vector_index_name"] = self.vector_index_name
parameters["fulltext_index_name"] = self.fulltext_index_name

parameters["threshold_vector_index"] = threshold_vector_index
parameters["threshold_fulltext_index"] = threshold_fulltext_index

if query_text and not query_vector:
if not self.embedder:
raise EmbeddingRequiredError(
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/retrievers/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def test_hybrid_search_text_happy_path(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": embed_query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
Expand Down Expand Up @@ -262,6 +264,8 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=database,
routing_=neo4j.RoutingControl.READ,
Expand Down Expand Up @@ -345,6 +349,8 @@ def test_hybrid_retriever_return_properties(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": embed_query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
Expand Down
24 changes: 16 additions & 8 deletions tests/unit/test_neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ def test_hybrid_search_basic() -> None:
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index "
"THEN (n.score / vector_index_max_score) ELSE 0 END AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index "
"THEN (n.score / ft_index_max_score) ELSE 0 END AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
Expand Down Expand Up @@ -129,17 +131,20 @@ def test_hybrid_search_with_retrieval_query() -> None:
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index THEN (n.score / vector_index_max_score) ELSE 0 END AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index THEN (n.score / ft_index_max_score) ELSE 0 END AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+ retrieval_query
)
result, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query)
result, _ = get_search_query(
SearchType.HYBRID,
retrieval_query=retrieval_query,
)
assert result.strip() == expected.strip()


Expand All @@ -151,17 +156,20 @@ def test_hybrid_search_with_properties() -> None:
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"RETURN n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index THEN (n.score / vector_index_max_score) ELSE 0 END AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"RETURN n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index THEN (n.score / ft_index_max_score) ELSE 0 END AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, _ = get_search_query(SearchType.HYBRID, return_properties=properties)
result, _ = get_search_query(
SearchType.HYBRID,
return_properties=properties,
)
assert result.strip() == expected.strip()


Expand Down

0 comments on commit 95dd2d9

Please sign in to comment.