Skip to content

Commit

Permalink
Add a unit test to cover auto generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric-Shang committed Dec 20, 2024
1 parent bf35fc5 commit ea78a01
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions test/agentchat/contrib/graph_rag/test_neo4j_graph_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def neo4j_query_engine():
]

# define which entities can have which relations
validation_schema = {
schema = {
"EMPLOYEE": ["FOLLOWS", "APPLIES_TO", "ASSIGNED_TO", "ENTITLED_TO", "REPORTS_TO"],
"EMPLOYER": ["PROVIDES", "DEFINED_AS", "MANAGES", "REQUIRES"],
"POLICY": ["APPLIES_TO", "DEFINED_AS", "REQUIRES"],
Expand All @@ -69,7 +69,7 @@ def neo4j_query_engine():
database="neo4j", # Change if you want to store the graphh in your custom database
entities=entities, # possible entities
relations=relations, # possible relations
validation_schema=validation_schema, # schema to validate the extracted triplets
schema=schema,
strict=True, # enofrce the extracted triplets to be in the schema
)

Expand All @@ -78,6 +78,23 @@ def neo4j_query_engine():
return query_engine


# Test fixture to test auto-generation without given schema
@pytest.fixture(scope="module")
def neo4j_query_engine_auto():
"""
Test the engine with auto-generated property graph
"""
query_engine = Neo4jGraphQueryEngine(
username="neo4j",
password="password",
host="bolt://172.17.0.3",
port=7687,
database="neo4j",
)
query_engine.connect_db() # Connect to the existing graph
return query_engine


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip or skip_openai,
reason=reason,
Expand Down Expand Up @@ -117,3 +134,18 @@ def test_neo4j_add_records(neo4j_query_engine):
print(query_result.answer)

assert query_result.answer.find("Keanu Reeves") >= 0


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip or skip_openai,
reason=reason,
)
def test_neo4j_auto(neo4j_query_engine_auto):
"""
Test querying with auto-generated property graph
"""
question = "Which company is the employer?"
query_result: GraphStoreQueryResult = neo4j_query_engine_auto.query(question=question)

print(query_result.answer)
assert query_result.answer.find("BUZZ") >= 0

0 comments on commit ea78a01

Please sign in to comment.