Skip to content

Commit

Permalink
Make GraphRAG search's query backwards compatible (#97)
Browse files Browse the repository at this point in the history
* Make query backwards compatible

* Update CHANGELOG
  • Loading branch information
willtai authored Aug 8, 2024
1 parent b4cde39 commit 149d1e9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@

### Added
- Add optional custom_prompt arg to the Text2CypherRetriever class.


### Changed
- `GraphRAG.search` method first parameter has been renamed `query_text` (was `query`) for consistency with the retrievers interface.
- Made `GraphRAG.search` method backwards compatible with the query parameter, raising warnings to encourage using query_text instead.

## 0.3.1

### Fixed
- Corrected initialization to allow specifying the embedding model name.
- Removed sentence_transformers from embeddings/__init__.py to avoid ImportError when the package is not installed.

### Changed
- `GraphRAG.search` method first parameter has been renamed `query_text` (was `query`) for consistency with the retrievers interface.

## 0.3.0

### Added
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ rag = GraphRAG(retriever=retriever, llm=llm)

# Query the graph
query_text = "How do I do similarity search in Neo4j?"
response = rag.search(query=query_text, retriever_config={"top_k": 5})
response = rag.search(query_text=query_text, retriever_config={"top_k": 5})
print(response.answer)
```

Expand Down
20 changes: 19 additions & 1 deletion src/neo4j_genai/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import logging
import warnings
from typing import Any, Optional

from pydantic import ValidationError
Expand Down Expand Up @@ -53,10 +54,11 @@ def __init__(

def search(
self,
query_text: str,
query_text: str = "",
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool = False,
query: Optional[str] = None,
) -> RagResultModel:
"""This method performs a full RAG search:
1. Retrieval: context retrieval
Expand All @@ -69,12 +71,28 @@ def search(
retriever_config (Optional[dict]): Parameters passed to the retriever
search method; e.g.: top_k
return_context (bool): Whether to return the retriever result (default: False)
query (Optional[str]): The user question. Will be deprecated in favor of query_text.
Returns:
RagResultModel: The LLM-generated answer
"""
try:
if query is not None:
if query_text:
warnings.warn(
"Both 'query' and 'query_text' are provided, 'query_text' will be used.",
DeprecationWarning,
stacklevel=2,
)
elif isinstance(query, str):
warnings.warn(
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
DeprecationWarning,
stacklevel=2,
)
query_text = query

validated_data = RagSearchModel(
query_text=query_text,
examples=examples,
Expand Down

0 comments on commit 149d1e9

Please sign in to comment.