Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neo4j native GraphRAG integration #377

Merged
merged 16 commits
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def init_db(self, input_doc: list[Document]):
else:
raise ValueError("No input documents could be loaded.")

def add_records(self, new_records: list) -> bool:
def add_records(self, new_records: list[Document]) -> bool:
raise NotImplementedError("This method is not supported by FalkorDB SDK yet.")

def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def _reply_using_falkordb_query(
Returns:
A tuple containing a boolean indicating success and the assistant's reply.
"""
# question = self._get_last_question(messages[-1])
question = self._messages_summary(messages, recipient.system_message)
result: GraphStoreQueryResult = self.query_engine.query(question)

Expand Down
2 changes: 1 addition & 1 deletion autogen/agentchat/contrib/graph_rag/graph_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def init_db(self, input_doc: list[Document] | None = None):
"""
pass

def add_records(self, new_records: list) -> bool:
def add_records(self, new_records: list[Document]) -> bool:
"""
Add new records to the underlying database and add to the graph if required.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def connect_db(self):
show_progress=True,
)

def add_records(self, new_records: list) -> bool:
def add_records(self, new_records: list[Document]) -> bool:
"""
Add new records to the knowledge graph. Must be local files.

Expand All @@ -152,9 +152,8 @@ def add_records(self, new_records: list) -> bool:

try:
"""
SimpleDirectoryReader will select the best file reader based on the file extensions, including:
[DocxReader, EpubReader, HWPReader, ImageReader, IPYNBReader, MarkdownReader, MboxReader,
PandasCSVReader, PandasExcelReader,PDFReader,PptxReader, VideoAudioReader]
SimpleDirectoryReader will select the best file reader based on the file extensions,
see _load_doc for supported file types.
"""
new_documents = SimpleDirectoryReader(input_files=[doc.path_or_url for doc in new_records]).load_data()

Expand Down
218 changes: 218 additions & 0 deletions autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
from typing import List, Optional, Union

from neo4j import GraphDatabase
from neo4j_graphrag.embeddings import Embedder, OpenAIEmbeddings
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.indexes import create_vector_index
from neo4j_graphrag.llm.openai_llm import LLMInterface, OpenAILLM
from neo4j_graphrag.retrievers import VectorRetriever

from .document import Document, DocumentType
from .graph_query_engine import GraphQueryEngine, GraphStoreQueryResult

# Set up logging
logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


class Neo4jNativeGraphQueryEngine(GraphQueryEngine):
    """
    A graph query engine implemented using the Neo4j GraphRAG SDK.

    Provides functionality to initialize a knowledge graph from input documents,
    create a vector index over the generated chunks, and answer natural-language
    questions via retrieval-augmented generation against the Neo4j graph.
    """

    def __init__(
        self,
        host: str = "neo4j://localhost",
        port: int = 7687,
        username: str = "neo4j",
        password: str = "password",
        embeddings: Optional[Embedder] = None,
        embedding_dimension: Optional[int] = 3072,
        llm: Optional[LLMInterface] = None,
        query_llm: Optional[LLMInterface] = None,
        entities: Optional[List[str]] = None,
        relations: Optional[List[str]] = None,
        potential_schema: Optional[List[tuple[str, str, str]]] = None,
    ):
        """
        Initialize a Neo4j graph query engine.

        Args:
            host (str): Neo4j host URL.
            port (int): Neo4j port.
            username (str): Neo4j username.
            password (str): Neo4j password.
            embeddings (Embedder, optional): Embedding model to embed chunk data and retrieve answers.
                Defaults to OpenAI's "text-embedding-3-large" when not supplied.
            embedding_dimension (int): Dimension of the embeddings for the model.
                Must match the supplied ``embeddings`` model (3072 for "text-embedding-3-large").
            llm (LLMInterface, optional): Language model for creating the knowledge graph.
                Must return JSON responses; defaults to OpenAI "gpt-4o" in JSON mode.
            query_llm (LLMInterface, optional): Language model for querying the knowledge graph.
                Defaults to OpenAI "gpt-4o".
            entities (List[str], optional): Custom entities for guiding graph construction.
            relations (List[str], optional): Custom relations for guiding graph construction.
            potential_schema (List[tuple[str, str, str]], optional):
                Schema (triplets, i.e., [entity] -> [relationship] -> [entity]) to guide graph construction.
        """
        self.uri = f"{host}:{port}"
        self.driver = GraphDatabase.driver(self.uri, auth=(username, password))
        # Build the OpenAI-backed defaults lazily rather than as default argument
        # values: defaults in the signature are evaluated once at import time,
        # would be shared by every instance, and would require OpenAI credentials
        # merely to import this module.
        self.embeddings = (
            embeddings if embeddings is not None else OpenAIEmbeddings(model="text-embedding-3-large")
        )
        self.embedding_dimension = embedding_dimension
        self.llm = (
            llm
            if llm is not None
            else OpenAILLM(
                model_name="gpt-4o",
                model_params={"response_format": {"type": "json_object"}, "temperature": 0},
            )
        )
        self.query_llm = (
            query_llm if query_llm is not None else OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
        )
        self.entities = entities
        self.relations = relations
        self.potential_schema = potential_schema

    def init_db(self, input_doc: Union[list[Document], None] = None):
        """
        Initialize the Neo4j graph database using the provided input doc.
        Currently this method only supports a single document; passing more
        than one raises ValueError.

        This method supports both text and PDF documents. It performs the following steps:
        1. Clears the existing database.
        2. Extracts graph nodes and relationships from the input data to build a knowledge graph.
        3. Creates a vector index for efficient retrieval.

        Args:
            input_doc (list[Document]): Input documents for building the graph
                (exactly one document is expected).

        Raises:
            ValueError: If no document is provided, or more than one is provided.
        """
        if input_doc is None or len(input_doc) == 0:
            raise ValueError("Input document is required to initialize the database.")
        elif len(input_doc) > 1:
            # The SDK pipeline is run per document; multi-doc init is not supported yet.
            raise ValueError("Only a single input document is supported to initialize the database.")

        logger.info("Clearing the database...")
        self._clear_db()

        self._initialize_kg_builders()

        self._build_graph(input_doc)

        # _create_index logs its own progress message.
        self.index_name = "vector-index-name"
        self._create_index(self.index_name)

    def add_records(self, new_records: list[Document]) -> bool:
        """
        Add new records to the Neo4j database.

        Args:
            new_records (list[Document]): List of new Documents to be added

        Returns:
            bool: True if records were added successfully, False otherwise.

        Raises:
            ValueError: If any record is not a Document.
        """
        # Validate all records before mutating the graph.
        for record in new_records:
            if not isinstance(record, Document):
                raise ValueError("Invalid record type. Expected Document.")

        self._build_graph(new_records)

        return True

    def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
        """
        Query the Neo4j database using a natural language question.

        Args:
            question (str): The question to be answered by querying the graph.
            n_results (int): Unused; kept for interface compatibility.

        Returns:
            GraphStoreQueryResult: The result of the query.
        """
        self.retriever = VectorRetriever(
            driver=self.driver,
            index_name=self.index_name,
            embedder=self.embeddings,
        )
        rag = GraphRAG(retriever=self.retriever, llm=self.query_llm)
        result = rag.search(query_text=question, retriever_config={"top_k": 5})

        return GraphStoreQueryResult(answer=result.answer)

    def _create_index(self, name: str):
        """
        Create a vector index for the Neo4j knowledge graph.

        Args:
            name (str): Name of the vector index to create.
        """
        logger.info(f"Creating vector index '{name}'...")
        create_vector_index(
            self.driver,
            name=name,
            label="Chunk",
            embedding_property="embedding",
            dimensions=self.embedding_dimension,
            similarity_fn="euclidean",
        )
        logger.info(f"Vector index '{name}' created successfully.")

    def _clear_db(self):
        """
        Clear all nodes and relationships from the Neo4j database.
        """
        logger.info("Clearing all nodes and relationships in the database...")
        self.driver.execute_query("MATCH (n) DETACH DELETE n;")
        logger.info("Database cleared successfully.")

    def _initialize_kg_builders(self):
        """
        Initialize the knowledge graph builders

        Two pipelines are kept: one for raw text input and one that extracts
        text from PDFs first; they differ only in the `from_pdf` flag.
        """
        logger.info("Initializing the knowledge graph builders...")
        self.text_kg_builder = SimpleKGPipeline(
            driver=self.driver,
            embedder=self.embeddings,
            llm=self.llm,
            entities=self.entities,
            relations=self.relations,
            potential_schema=self.potential_schema,
            on_error="IGNORE",
            from_pdf=False,
        )

        self.pdf_kg_builder = SimpleKGPipeline(
            driver=self.driver,
            embedder=self.embeddings,
            llm=self.llm,
            entities=self.entities,
            relations=self.relations,
            potential_schema=self.potential_schema,
            on_error="IGNORE",
            from_pdf=True,
        )

    def _build_graph(self, input_doc: List[Document]) -> None:
        """
        Build the knowledge graph using the provided input documents.

        Args:
            input_doc (List[Document]): List of input documents for building the graph.

        Raises:
            ValueError: If a document's type is neither TEXT nor PDF.
        """
        logger.info("Building the knowledge graph...")
        for doc in input_doc:
            if doc.doctype == DocumentType.TEXT:
                with open(doc.path_or_url, "r", encoding="utf-8") as file:
                    text = file.read()
                asyncio.run(self.text_kg_builder.run_async(text=text))
            elif doc.doctype == DocumentType.PDF:
                asyncio.run(self.pdf_kg_builder.run_async(file_path=doc.path_or_url))
            else:
                raise ValueError(f"Unsupported document type: {doc.doctype}")

        logger.info("Knowledge graph built successfully.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional, Union

from autogen import Agent, ConversableAgent

from .graph_query_engine import GraphStoreQueryResult
from .graph_rag_capability import GraphRagCapability
from .neo4j_native_graph_query_engine import Neo4jNativeGraphQueryEngine


class Neo4jNativeGraphCapability(GraphRagCapability):
"""
The Neo4j native graph capability integrates Neo4j native query engine into a graph rag agent.

For usage, please refer to example notebook/agentchat_graph_rag_neo4j_native.ipynb
"""

def __init__(self, query_engine: Neo4jNativeGraphQueryEngine):
"""
initialize GraphRAG capability with a neo4j native graph query engine
"""
self.query_engine = query_engine

def add_to_agent(self, agent: ConversableAgent):
"""
Add native Neo4j GraphRAG capability to a ConversableAgent.
llm_config of the agent must be None/False (default) to make sure the returned message only contains information retrieved from the graph DB instead of any LLMs.
"""

self.graph_rag_agent = agent

# Validate the agent config
if agent.llm_config not in (None, False):
raise Exception(
"Agents with GraphRAG capabilities do not use an LLM configuration. Please set your llm_config to None or False."
)

# Register method to generate the reply using a Neo4j query
# All other reply methods will be removed
agent.register_reply(
[ConversableAgent, None], self._reply_using_native_neo4j_query, position=0, remove_other_reply_funcs=True
)

def _reply_using_native_neo4j_query(
self,
recipient: ConversableAgent,
messages: Optional[list[dict]] = None,
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> tuple[bool, Union[str, dict, None]]:
"""
Query Neo4j and return the message. Internally, it uses the Neo4jNativeGraphQueryEngine to query the graph.

The agent's system message will be incorporated into the query, if it's not blank.

If no results are found, a default message is returned: "I'm sorry, I don't have an answer for that."

Args:
recipient: The agent instance that will receive the message.
messages: A list of messages in the conversation history with the sender.
sender: The agent instance that sent the message.
config: Optional configuration for message processing.

Returns:
A tuple containing a boolean indicating success and the assistant's reply.
"""
question = self._messages_summary(messages, recipient.system_message)
result: GraphStoreQueryResult = self.query_engine.query(question)

return True, result.answer if result.answer else "I'm sorry, I don't have an answer for that."

def _messages_summary(self, messages: Union[dict, str], system_message: str) -> str:
"""Summarize the messages in the conversation history. Excluding any message with 'tool_calls' and 'tool_responses'
Includes the 'name' (if it exists) and the 'content', with a new line between each one, like:
customer:
<content>

agent:
<content>
"""

if isinstance(messages, str):
if system_message:
summary = f"IMPORTANT: {system_message}\nContext:\n\n{messages}"
else:
return messages

elif isinstance(messages, list):
summary = ""
for message in messages:
if "content" in message and "tool_calls" not in message and "tool_responses" not in message:
summary += f"{message.get('name', '')}: {message.get('content','')}\n\n"

if system_message:
summary = f"IMPORTANT: {system_message}\nContext:\n\n{summary}"

return summary

else:
raise ValueError("Invalid messages format. Must be a list of messages or a string.")
Loading
Loading