Add Memgraph integration #16345

Merged: 9 commits, Oct 14, 2024
Changes from 2 commits
@@ -0,0 +1 @@
python_sources()
@@ -0,0 +1,17 @@
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help: ## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format: ## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pytest.
	pytest tests

watch-docs: ## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
@@ -0,0 +1 @@
# LlamaIndex Graph-Stores Integration: Memgraph
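
This package provides two Memgraph-backed graph stores for LlamaIndex: `MemgraphGraphStore` (used with `KnowledgeGraphIndex`) and `MemgraphPropertyGraphStore` (used with `PropertyGraphIndex`). A minimal quickstart sketch is shown below; it assumes the distribution follows the usual integration naming (`pip install llama-index-graph-stores-memgraph`) and that a Memgraph instance is reachable at `bolt://localhost:7687`.

```python
from llama_index.core import KnowledgeGraphIndex, SimpleDirectoryReader, StorageContext
from llama_index.graph_stores.memgraph import MemgraphGraphStore

# Connect to a running Memgraph instance (default credentials are empty strings)
graph_store = MemgraphGraphStore(
    username="",
    password="",
    url="bolt://localhost:7687",
    database="memgraph",
)
storage_context = StorageContext.from_defaults(graph_store=graph_store)

# Build a knowledge graph index from local documents and query it
documents = SimpleDirectoryReader("./data").load_data()
index = KnowledgeGraphIndex.from_documents(
    documents,
    storage_context=storage_context,
    max_triplets_per_chunk=3,
)
print(index.as_query_engine(include_text=False).query("What is in the graph?"))
```

See the example scripts in this package for a complete walkthrough of both stores.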
@@ -0,0 +1,61 @@
import os
import logging
from llama_index.llms.openai import OpenAI
from llama_index.core import Settings
from llama_index.core import KnowledgeGraphIndex, SimpleDirectoryReader, StorageContext
from llama_index.graph_stores.memgraph import MemgraphGraphStore


# Step 1: Set up OpenAI API key
os.environ["OPENAI_API_KEY"] = "<YOUR_API_KEY>" # Replace with your OpenAI API key

# Step 2: Configure logging
logging.basicConfig(level=logging.INFO)

# Step 3: Configure OpenAI LLM
llm = OpenAI(temperature=0, model="gpt-3.5-turbo")
Settings.llm = llm
Settings.chunk_size = 512

# Step 4: Write documents to text files (Simulating loading documents from disk)
documents = {
    "doc1.txt": "Python is a popular programming language known for its readability and simplicity. It was created by Guido van Rossum and first released in 1991. Python supports multiple programming paradigms, including procedural, object-oriented, and functional programming. It is widely used in web development, data science, artificial intelligence, and scientific computing.",
    "doc2.txt": "JavaScript is a high-level programming language primarily used for web development. It was created by Brendan Eich and first appeared in 1995. JavaScript is a core technology of the World Wide Web, alongside HTML and CSS. It enables interactive web pages and is an essential part of web applications. JavaScript is also used in server-side development with environments like Node.js.",
    "doc3.txt": "Java is a high-level, class-based, object-oriented programming language that is designed to have as few implementation dependencies as possible. It was developed by James Gosling and first released by Sun Microsystems in 1995. Java is widely used for building enterprise-scale applications, mobile applications, and large systems development.",
}

for filename, content in documents.items():
    with open(filename, "w") as file:
        file.write(content)

# Step 5: Load documents
loaded_documents = SimpleDirectoryReader(".").load_data()

# Step 6: Set up Memgraph connection
username = "" # Enter your Memgraph username (default "")
password = "" # Enter your Memgraph password (default "")
url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687'
database = "memgraph" # Name of the database, default is 'memgraph'

graph_store = MemgraphGraphStore(
    username=username,
    password=password,
    url=url,
    database=database,
)

storage_context = StorageContext.from_defaults(graph_store=graph_store)

# Step 7: Create a Knowledge Graph Index
index = KnowledgeGraphIndex.from_documents(
    loaded_documents,
    storage_context=storage_context,
    max_triplets_per_chunk=3,
)

# Step 8: Query the Knowledge Graph
query_engine = index.as_query_engine(include_text=False, response_mode="tree_summarize")
response = query_engine.query("Tell me about Python and its uses")

print("Query Response:")
print(response)
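
# Optional (illustrative sketch, not part of the original example): the triplets
# written above can be inspected directly with the same neo4j driver that
# MemgraphGraphStore uses internally. This assumes Memgraph is reachable at the
# `url` configured above and that the default "Entity" node label was kept.
import neo4j

driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
with driver.session(database=database) as session:
    records = session.run(
        "MATCH (n:Entity)-[r]->(m:Entity) RETURN n.id, type(r), m.id LIMIT 10"
    )
    for record in records:
        print(record.values())
driver.close()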
@@ -0,0 +1,80 @@
import os
import urllib.request
import nest_asyncio
import logging
from llama_index.core import SimpleDirectoryReader, PropertyGraphIndex
from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor


# 1. Set up the OpenAI API key (replace this with your actual key)
os.environ["OPENAI_API_KEY"] = "<YOUR_API_KEY>" # Replace with your OpenAI API key

# 2. Create the data directory and download the Paul Graham essay
os.makedirs('data/paul_graham/', exist_ok=True)

url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt'
output_path = 'data/paul_graham/paul_graham_essay.txt'
urllib.request.urlretrieve(url, output_path)

# 3. Ensure nest_asyncio is applied
nest_asyncio.apply()

# 4. Read the file, escape single quotes, and save the modified content
with open(output_path, 'r', encoding='utf-8') as file:
    content = file.read()

# Replace single quotes with escaped single quotes
modified_content = content.replace("'", "\\'")

# Save the modified content back to the same file
with open(output_path, 'w', encoding='utf-8') as file:
    file.write(modified_content)

# 5. Load the document data
documents = SimpleDirectoryReader("./data/paul_graham/").load_data()

# 6. Set up the Memgraph connection (ensure Memgraph is running)
username = ""  # Enter your Memgraph username (default "")
password = ""  # Enter your Memgraph password (default "")
url = ""  # Specify the connection URL, e.g., 'bolt://localhost:7687'

graph_store = MemgraphPropertyGraphStore(
    username=username,
    password=password,
    url=url,
)

# 7. Create the Property Graph Index
index = PropertyGraphIndex.from_documents(
    documents,
    embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002"),
    kg_extractors=[
        SchemaLLMPathExtractor(
            llm=OpenAI(model="gpt-3.5-turbo", temperature=0.0),
        )
    ],
    property_graph_store=graph_store,
    show_progress=True,
)

# 8. Query the graph
retriever = index.as_retriever(include_text=False)

# Example query: "What happened at Interleaf and Viaweb?"
nodes = retriever.retrieve("What happened at Interleaf and Viaweb?")

# Output results
print("Query Results:")
for node in nodes:
    print(node.text)

# Alternatively, using a query engine
query_engine = index.as_query_engine(include_text=True)

# Perform a query and print the detailed response
response = query_engine.query("What happened at Interleaf and Viaweb?")
print("\nDetailed Query Response:")
print(str(response))
@@ -0,0 +1,5 @@
from llama_index.graph_stores.memgraph.kg_base import MemgraphGraphStore
from llama_index.graph_stores.memgraph.property_graph import MemgraphPropertyGraphStore

__all__ = ["MemgraphGraphStore", "MemgraphPropertyGraphStore"]

@@ -0,0 +1,169 @@
"""Memgraph graph store index."""
import logging
from typing import Any, Dict, List, Optional

from llama_index.core.graph_stores.types import GraphStore

logger = logging.getLogger(__name__)

node_properties_query = """
CALL schema.node_type_properties()
YIELD nodeType AS label, propertyName AS property, propertyTypes AS type
WITH label AS nodeLabels, collect({property: property, type: type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output
"""

rel_properties_query = """
CALL schema.rel_type_properties()
YIELD relType AS label, propertyName AS property, propertyTypes AS type
WITH label, collect({property: property, type: type}) AS properties
RETURN {type: label, properties: properties} AS output
"""

rel_query = """
MATCH (start_node)-[r]->(end_node)
WITH labels(start_node) AS start, type(r) AS relationship_type, labels(end_node) AS end, keys(r) AS relationship_properties
UNWIND end AS end_label
RETURN DISTINCT {start: start[0], type: relationship_type, end: end_label} AS output
"""

class MemgraphGraphStore(GraphStore):
    def __init__(
        self,
        username: str,
        password: str,
        url: str,
        database: str = "memgraph",
        node_label: str = "Entity",
        **kwargs: Any,
    ) -> None:
        try:
            import neo4j
        except ImportError:
            raise ImportError("Please install neo4j: pip install neo4j")
        self.node_label = node_label
        self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
        self._database = database
        self.schema = ""
        # verify connection
        try:
            self._driver.verify_connectivity()
        except neo4j.exceptions.ServiceUnavailable:
            raise ValueError(
                "Could not connect to Memgraph database. "
                "Please ensure that the url is correct"
            )
        except neo4j.exceptions.AuthError:
            raise ValueError(
                "Could not connect to Memgraph database. "
                "Please ensure that the username and password are correct"
            )
        # set schema
        self.refresh_schema()

        # create constraint
        self.query(
            """
            CREATE CONSTRAINT ON (n:%s) ASSERT n.id IS UNIQUE;
            """
            % (self.node_label)
        )

        # create index
        self.query(
            """
            CREATE INDEX ON :%s(id);
            """
            % (self.node_label)
        )

    @property
    def client(self) -> Any:
        return self._driver

    def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
        """Execute a Cypher query."""
        with self._driver.session(database=self._database) as session:
            result = session.run(query, param_map)
            return [record.data() for record in result]

    def get(self, subj: str) -> List[List[str]]:
        """Get triplets."""
        query = f"""
            MATCH (n1:{self.node_label})-[r]->(n2:{self.node_label})
            WHERE n1.id = $subj
            RETURN type(r), n2.id;
        """

        with self._driver.session(database=self._database) as session:
            data = session.run(query, {"subj": subj})
            return [record.values() for record in data]

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2
    ) -> Dict[str, List[List[str]]]:
        """Get flat relation map."""
        rel_map: Dict[Any, List[Any]] = {}
        if subjs is None or len(subjs) == 0:
            return rel_map

        query = (
            f"""MATCH p=(n1:{self.node_label})-[*1..{depth}]->() """
            f"""{"WHERE n1.id IN $subjs" if subjs else ""} """
            "UNWIND relationships(p) AS rel "
            "WITH n1.id AS subj, collect([type(rel), endNode(rel).id]) AS rels "
            "RETURN subj, rels"
        )

        data = list(self.query(query, {"subjs": subjs}))
        if not data:
            return rel_map

        for record in data:
            rel_map[record["subj"]] = record["rels"]

        return rel_map

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet."""
        query = f"""
            MERGE (n1:`{self.node_label}` {{id:$subj}})
            MERGE (n2:`{self.node_label}` {{id:$obj}})
            MERGE (n1)-[:`{rel.replace(" ", "_").upper()}`]->(n2)
        """
        self.query(query, {"subj": subj, "obj": obj})

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete triplet."""
        query = f"""
            MATCH (n1:`{self.node_label}`)-[r:`{rel}`]->(n2:`{self.node_label}`)
            WHERE n1.id = $subj AND n2.id = $obj
            DELETE r
        """
        self.query(query, {"subj": subj, "obj": obj})

    def refresh_schema(self) -> None:
        """
        Refreshes the Memgraph graph schema information.
        """
        node_properties = self.query(node_properties_query)
        relationships_properties = self.query(rel_properties_query)
        relationships = self.query(rel_query)

        self.schema = f"""
        Node properties are the following:
        {[el for el in node_properties]}
        Relationship properties are the following:
        {[el for el in relationships_properties]}
        The relationships are the following:
        {[el for el in relationships]}
        """

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the MemgraphGraph store."""
        if self.schema and not refresh:
            return self.schema
        self.refresh_schema()
        logger.debug(f"get_schema() schema:\n{self.schema}")
        return self.schema
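

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the integration's test suite):
    # exercises the GraphStore API implemented above. Assumes a local Memgraph
    # instance at bolt://localhost:7687 with empty credentials.
    store = MemgraphGraphStore(username="", password="", url="bolt://localhost:7687")
    store.upsert_triplet("python", "CREATED_BY", "Guido van Rossum")
    print(store.get("python"))  # e.g. [["CREATED_BY", "Guido van Rossum"]]
    print(store.get_rel_map(["python"], depth=2))
    store.delete("python", "CREATED_BY", "Guido van Rossum")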