From 3c5ac94553722600c9bc59859d4fa5f2fa513eef Mon Sep 17 00:00:00 2001
From: Matea Pesic <80577904+matea16@users.noreply.github.com>
Date: Mon, 14 Oct 2024 21:27:13 +0200
Subject: [PATCH] Add Memgraph integration (#16345)

---
 .../property_graph_memgraph.ipynb             | 284 ++
 .../llama-index-graph-stores-memgraph/BUILD   |   3 +
 .../Makefile                                  |  17 +
 .../README.md                                 | 147 +++
 .../llama_index/graph_stores/memgraph/BUILD   |   1 +
 .../graph_stores/memgraph/__init__.py         |   4 +
 .../graph_stores/memgraph/kg_base.py          | 169 +++
 .../graph_stores/memgraph/property_graph.py   | 962 ++++++++++++++
 .../pyproject.toml                            |  58 ++
 .../tests/BUILD                               |   1 +
 .../tests/__init__.py                         |   0
 .../tests/test_graph_stores_memgraph.py       |  10 +
 .../tests/test_pg_stores_memgraph.py          |  92 ++
 13 files changed, 1748 insertions(+)
 create mode 100644 docs/docs/examples/property_graph/property_graph_memgraph.ipynb
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/Makefile
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/BUILD
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/BUILD
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/__init__.py
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py
 create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py

diff --git a/docs/docs/examples/property_graph/property_graph_memgraph.ipynb b/docs/docs/examples/property_graph/property_graph_memgraph.ipynb
new file mode 100644
index 0000000000000..ccb6491e8fada
--- /dev/null
+++ b/docs/docs/examples/property_graph/property_graph_memgraph.ipynb
@@ -0,0 +1,284 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Memgraph Property Graph Index\n",
+    "\n",
+    "[Memgraph](https://memgraph.com/) is an open source graph database built for real-time streaming and fast analysis of your stored data.\n",
+    "\n",
+    "Before running Memgraph, ensure you have Docker running in the background. 
The quickest way to try out [Memgraph Platform](https://memgraph.com/docs/getting-started#install-memgraph-platform) (Memgraph database + MAGE library + Memgraph Lab) for the first time is by running the following command:\n",
+    "\n",
+    "For Linux/macOS:\n",
+    "```shell\n",
+    "curl https://install.memgraph.com | sh\n",
+    "```\n",
+    "\n",
+    "For Windows:\n",
+    "```shell\n",
+    "iwr https://windows.memgraph.com | iex\n",
+    "```\n",
+    "\n",
+    "From here, you can open Memgraph's visual tool, Memgraph Lab, at [http://localhost:3000/](http://localhost:3000/) or use the [desktop version](https://memgraph.com/download) of the app."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install llama-index llama-index-graph-stores-memgraph"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Environment setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.environ[\n",
+    "    \"OPENAI_API_KEY\"\n",
+    "] = \"sk-proj-...\"  # Replace with your OpenAI API key"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Create the data directory and download the Paul Graham essay we'll be using as the input data for this example."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import urllib.request\n",
+    "\n",
+    "os.makedirs(\"data/paul_graham/\", exist_ok=True)\n",
+    "\n",
+    "url = \"https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt\"\n",
+    "output_path = \"data/paul_graham/paul_graham_essay.txt\"\n",
+    "urllib.request.urlretrieve(url, output_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nest_asyncio\n",
+    "\n",
+    "nest_asyncio.apply()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Read the file, replace single quotes, save the modified content, and load the document data using `SimpleDirectoryReader`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.core import SimpleDirectoryReader\n",
+    "\n",
+    "with open(output_path, \"r\", encoding=\"utf-8\") as file:\n",
+    "    content = file.read()\n",
+    "\n",
+    "modified_content = content.replace(\"'\", \"\\\\'\")\n",
+    "\n",
+    "with open(output_path, \"w\", encoding=\"utf-8\") as file:\n",
+    "    file.write(modified_content)\n",
+    "\n",
+    "documents = SimpleDirectoryReader(\"./data/paul_graham/\").load_data()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set up Memgraph connection"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set up your graph store class by providing the database credentials.\n",
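+    "\n",
+    "If you started Memgraph with the quick-start command above, the defaults are typically an empty username and password, with the connection URL `bolt://localhost:7687`."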
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore\n",
+    "\n",
+    "username = \"\"  # Enter your Memgraph username (default \"\")\n",
+    "password = \"\"  # Enter your Memgraph password (default \"\")\n",
+    "url = \"\"  # Specify the connection URL, e.g., 'bolt://localhost:7687'\n",
+    "\n",
+    "graph_store = MemgraphPropertyGraphStore(\n",
+    "    username=username,\n",
+    "    password=password,\n",
+    "    url=url,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Index Construction"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.core import PropertyGraphIndex\n",
+    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
+    "from llama_index.llms.openai import OpenAI\n",
+    "from llama_index.core.indices.property_graph import SchemaLLMPathExtractor\n",
+    "\n",
+    "index = PropertyGraphIndex.from_documents(\n",
+    "    documents,\n",
+    "    embed_model=OpenAIEmbedding(model_name=\"text-embedding-ada-002\"),\n",
+    "    kg_extractors=[\n",
+    "        SchemaLLMPathExtractor(\n",
+    "            llm=OpenAI(model=\"gpt-3.5-turbo\", temperature=0.0)\n",
+    "        )\n",
+    "    ],\n",
+    "    property_graph_store=graph_store,\n",
+    "    show_progress=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now that the graph is created, we can explore it in the UI by visiting [http://localhost:3000/](http://localhost:3000/).\n",
+    "\n",
+    "The easiest way to visualize the entire graph is by running a Cypher command similar to this:\n",
+    "\n",
+    "```shell\n",
+    "MATCH p=()-[]-() RETURN p;\n",
+    "```\n",
+    "\n",
+    "This command matches all possible paths in the graph and returns the entire graph. 
\n", + "\n", + "To visualize the schema of the graph, visit the Graph schema tab and generate the new schema based on the newly created graph.\n", + "\n", + "To delete an entire graph, use:\n", + "\n", + "```shell\n", + "MATCH (n) DETACH DELETE n;\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Querying and retrieval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = index.as_retriever(include_text=False)\n", + "\n", + "# Example query: \"What happened at Interleaf and Viaweb?\"\n", + "nodes = retriever.retrieve(\"What happened at Interleaf and Viaweb?\")\n", + "\n", + "# Output results\n", + "print(\"Query Results:\")\n", + "for node in nodes:\n", + " print(node.text)\n", + "\n", + "# Alternatively, using a query engine\n", + "query_engine = index.as_query_engine(include_text=True)\n", + "\n", + "# Perform a query and print the detailed response\n", + "response = query_engine.query(\"What happened at Interleaf and Viaweb?\")\n", + "print(\"\\nDetailed Query Response:\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading from an existing graph" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you have an existing graph (either created with LlamaIndex or otherwise), we can connect to and use it!\n", + "\n", + "**NOTE:** If your graph was created outside of LlamaIndex, the most useful retrievers will be [text to cypher](../../module_guides/indexing/lpg_index_guide.md#texttocypherretriever) or [cypher templates](../../module_guides/indexing/lpg_index_guide.md#cyphertemplateretriever). Other retrievers rely on properties that LlamaIndex inserts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm = OpenAI(model=\"gpt-4\", temperature=0.0)\n", + "kg_extractors = [SchemaLLMPathExtractor(llm=llm)]\n", + "\n", + "index = PropertyGraphIndex.from_existing(\n", + " property_graph_store=graph_store,\n", + " kg_extractors=kg_extractors,\n", + " embed_model=OpenAIEmbedding(model_name=\"text-embedding-ada-002\"),\n", + " show_progress=True,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 64-bit (microsoft store)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "vscode": { + "interpreter": { + "hash": "289d8ae9ac585fcc15d0d9333c941ae27bdf80d3e799883224b20975f2046730" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD new file mode 100644 index 0000000000000..0896ca890d8bf --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/Makefile b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. 
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
+
+format: ## Run code autoformatters (black).
+	pre-commit install
+	git ls-files | xargs pre-commit run black --files
+
+lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
+	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
+
+test: ## Run tests via pytest.
+	pytest tests
+
+watch-docs: ## Build and watch documentation.
+	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md
new file mode 100644
index 0000000000000..04a925ae066e5
--- /dev/null
+++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md
@@ -0,0 +1,147 @@
+# LlamaIndex Graph-Stores Integration: Memgraph
+
+Memgraph is an open source graph database built for real-time streaming and fast analysis.
+
+In this project, we integrate Memgraph as a graph store that LlamaIndex can use to store and query its graph data.
+
+- Property Graph Store: `MemgraphPropertyGraphStore`
+- Knowledge Graph Store: `MemgraphGraphStore`
+
+## Installation
+
+```shell
+pip install llama-index llama-index-graph-stores-memgraph
+```
+
+## Usage
+
+### Property Graph Store
+
+```python
+import os
+import urllib.request
+import nest_asyncio
+from llama_index.core import SimpleDirectoryReader, PropertyGraphIndex
+from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.llms.openai import OpenAI
+from llama_index.core.indices.property_graph import SchemaLLMPathExtractor
+
+
+os.environ[
+    "OPENAI_API_KEY"
+] = ""  # Replace with your OpenAI API key
+
+os.makedirs("data/paul_graham/", exist_ok=True)
+
+url = "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt"
+output_path = "data/paul_graham/paul_graham_essay.txt"
+urllib.request.urlretrieve(url, output_path)
+
+nest_asyncio.apply()
+
+with open(output_path, "r", encoding="utf-8") as file:
+    content = file.read()
+
+modified_content = content.replace("'", "\\'")
+
+with open(output_path, "w", encoding="utf-8") as file:
+    file.write(modified_content)
+
+documents = SimpleDirectoryReader("./data/paul_graham/").load_data()
+
+# Set up the Memgraph connection (ensure Memgraph is running)
+username = ""  # Enter your Memgraph username (default "")
+password = ""  # Enter your Memgraph password (default "")
+url = ""  # Specify the connection URL, e.g., 'bolt://localhost:7687'
+
+graph_store = MemgraphPropertyGraphStore(
+    username=username,
+    password=password,
+    url=url,
+)
+
+index = PropertyGraphIndex.from_documents(
+    documents,
+    embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002"),
+    kg_extractors=[
+        SchemaLLMPathExtractor(
+            llm=OpenAI(model="gpt-3.5-turbo", temperature=0.0),
+        )
+    ],
+    property_graph_store=graph_store,
+    show_progress=True,
+)
+
+query_engine = index.as_query_engine(include_text=True)
+
+response = query_engine.query("What happened at Interleaf and Viaweb?")
+print("\nDetailed Query Response:")
+print(str(response))
+```
+
+### Knowledge Graph Store
+
+```python
+import os
+import logging
+from llama_index.llms.openai import OpenAI
+from llama_index.core import Settings
+from llama_index.core import (
+    
KnowledgeGraphIndex, + SimpleDirectoryReader, + StorageContext, +) +from llama_index.graph_stores.memgraph import MemgraphGraphStore + +os.environ[ + "OPENAI_API_KEY" +] = "" # Replace with your OpenAI API key + +logging.basicConfig(level=logging.INFO) + +llm = OpenAI(temperature=0, model="gpt-3.5-turbo") +Settings.llm = llm +Settings.chunk_size = 512 + +documents = { + "doc1.txt": "Python is a popular programming language known for its readability and simplicity. It was created by Guido van Rossum and first released in 1991. Python supports multiple programming paradigms, including procedural, object-oriented, and functional programming. It is widely used in web development, data science, artificial intelligence, and scientific computing.", + "doc2.txt": "JavaScript is a high-level programming language primarily used for web development. It was created by Brendan Eich and first appeared in 1995. JavaScript is a core technology of the World Wide Web, alongside HTML and CSS. It enables interactive web pages and is an essential part of web applications. JavaScript is also used in server-side development with environments like Node.js.", + "doc3.txt": "Java is a high-level, class-based, object-oriented programming language that is designed to have as few implementation dependencies as possible. It was developed by James Gosling and first released by Sun Microsystems in 1995. Java is widely used for building enterprise-scale applications, mobile applications, and large systems development.", +} + +for filename, content in documents.items(): + with open(filename, "w") as file: + file.write(content) + +loaded_documents = SimpleDirectoryReader(".").load_data() + +# Setup Memgraph connection (ensure Memgraph is running) +username = "" # Enter your Memgraph username (default "") +password = "" # Enter your Memgraph password (default "") +url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' +database = "memgraph" # Name of the database, default is 'memgraph' + +graph_store = MemgraphGraphStore( + username=username, + password=password, + url=url, + database=database, +) + +storage_context = StorageContext.from_defaults(graph_store=graph_store) + +index = KnowledgeGraphIndex.from_documents( + loaded_documents, + storage_context=storage_context, + max_triplets_per_chunk=3, +) + +query_engine = index.as_query_engine( + include_text=False, response_mode="tree_summarize" +) +response = query_engine.query("Tell me about Python and its uses") + +print("Query Response:") +print(response) +``` diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/BUILD b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py new file mode 100644 index 0000000000000..83b3a573e428e --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py @@ -0,0 +1,4 @@ +from llama_index.graph_stores.memgraph.kg_base import MemgraphGraphStore +from 
llama_index.graph_stores.memgraph.property_graph import MemgraphPropertyGraphStore + +__all__ = ["MemgraphGraphStore", "MemgraphPropertyGraphStore"] diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py new file mode 100644 index 0000000000000..58bd70d57443e --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py @@ -0,0 +1,169 @@ +"""Memgraph graph store index.""" +import logging +from typing import Any, Dict, List, Optional + +from llama_index.core.graph_stores.types import GraphStore + +logger = logging.getLogger(__name__) + +node_properties_query = """ +CALL schema.node_type_properties() +YIELD nodeType AS label, propertyName AS property, propertyTypes AS type +WITH label AS nodeLabels, collect({property: property, type: type}) AS properties +RETURN {labels: nodeLabels, properties: properties} AS output +""" + +rel_properties_query = """ +CALL schema.rel_type_properties() +YIELD relType AS label, propertyName AS property, propertyTypes AS type +WITH label, collect({property: property, type: type}) AS properties +RETURN {type: label, properties: properties} AS output +""" + +rel_query = """ +MATCH (start_node)-[r]->(end_node) +WITH labels(start_node) AS start, type(r) AS relationship_type, labels(end_node) AS end, keys(r) AS relationship_properties +UNWIND end AS end_label +RETURN DISTINCT {start: start[0], type: relationship_type, end: end_label} AS output +""" + + +class MemgraphGraphStore(GraphStore): + def __init__( + self, + username: str, + password: str, + url: str, + database: str = "memgraph", + node_label: str = "Entity", + **kwargs: Any, + ) -> None: + try: + import neo4j + except ImportError: + raise ImportError("Please install neo4j: pip install neo4j") + self.node_label = node_label + self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) + self._database = database + self.schema = "" + # verify connection + try: + self._driver.verify_connectivity() + except neo4j.exceptions.ServiceUnavailable: + raise ValueError( + "Could not connect to Memgraph database. " + "Please ensure that the url is correct" + ) + except neo4j.exceptions.AuthError: + raise ValueError( + "Could not connect to Memgraph database. 
" + "Please ensure that the username and password are correct" + ) + # set schema + self.refresh_schema() + + # create constraint + self.query( + """ + CREATE CONSTRAINT ON (n:%s) ASSERT n.id IS UNIQUE; + """ + % (self.node_label) + ) + + # create index + self.query( + """ + CREATE INDEX ON :%s(id); + """ + % (self.node_label) + ) + + @property + def client(self) -> Any: + return self._driver + + def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any: + """Execute a Cypher query.""" + with self._driver.session(database=self._database) as session: + result = session.run(query, param_map) + return [record.data() for record in result] + + def get(self, subj: str) -> List[List[str]]: + """Get triplets.""" + query = f""" + MATCH (n1:{self.node_label})-[r]->(n2:{self.node_label}) + WHERE n1.id = $subj + RETURN type(r), n2.id; + """ + + with self._driver.session(database=self._database) as session: + data = session.run(query, {"subj": subj}) + return [record.values() for record in data] + + def get_rel_map( + self, subjs: Optional[List[str]] = None, depth: int = 2 + ) -> Dict[str, List[List[str]]]: + """Get flat relation map.""" + rel_map: Dict[Any, List[Any]] = {} + if subjs is None or len(subjs) == 0: + return rel_map + + query = ( + f"""MATCH p=(n1:{self.node_label})-[*1..{depth}]->() """ + f"""{"WHERE n1.id IN $subjs" if subjs else ""} """ + "UNWIND relationships(p) AS rel " + "WITH n1.id AS subj, collect([type(rel), endNode(rel).id]) AS rels " + "RETURN subj, rels" + ) + + data = list(self.query(query, {"subjs": subjs})) + if not data: + return rel_map + + for record in data: + rel_map[record["subj"]] = record["rels"] + + return rel_map + + def upsert_triplet(self, subj: str, rel: str, obj: str) -> None: + """Add triplet.""" + query = f""" + MERGE (n1:`{self.node_label}` {{id:$subj}}) + MERGE (n2:`{self.node_label}` {{id:$obj}}) + MERGE (n1)-[:`{rel.replace(" ", "_").upper()}`]->(n2) + """ + self.query(query, {"subj": subj, "obj": obj}) + + def delete(self, subj: str, rel: str, obj: str) -> None: + """Delete triplet.""" + query = f""" + MATCH (n1:`{self.node_label}`)-[r:`{rel}`]->(n2:`{self.node_label}`) + WHERE n1.id = $subj AND n2.id = $obj + DELETE r + """ + self.query(query, {"subj": subj, "obj": obj}) + + def refresh_schema(self) -> None: + """ + Refreshes the Memgraph graph schema information. 
+ """ + node_properties = self.query(node_properties_query) + relationships_properties = self.query(rel_properties_query) + relationships = self.query(rel_query) + + self.schema = f""" + Node properties are the following: + {node_properties} + Relationship properties are the following: + {relationships_properties} + The relationships are the following: + {relationships} + """ + + def get_schema(self, refresh: bool = False) -> str: + """Get the schema of the MemgraphGraph store.""" + if self.schema and not refresh: + return self.schema + self.refresh_schema() + logger.debug(f"get_schema() schema:\n{self.schema}") + return self.schema diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py new file mode 100644 index 0000000000000..48854a895d478 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py @@ -0,0 +1,962 @@ +from typing import Any, List, Dict, Optional, Tuple +from llama_index.core.graph_stores.prompts import DEFAULT_CYPHER_TEMPALTE +from llama_index.core.graph_stores.types import ( + PropertyGraphStore, + Triplet, + LabelledNode, + Relation, + EntityNode, + ChunkNode, +) +from llama_index.core.graph_stores.utils import ( + clean_string_values, + value_sanitize, + LIST_LIMIT, +) + +from llama_index.core.prompts import PromptTemplate +from llama_index.core.vector_stores.types import VectorStoreQuery +import neo4j + + +def remove_empty_values(input_dict): + """ + Remove entries with empty values from the dictionary. + """ + return {key: value for key, value in input_dict.items() if value} + + +BASE_ENTITY_LABEL = "__Entity__" +BASE_NODE_LABEL = "__Node__" +EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"] +EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"] +EXHAUSTIVE_SEARCH_LIMIT = 10000 +# Threshold for returning all available prop values in graph schema +DISTINCT_VALUE_LIMIT = 10 +CHUNK_SIZE = 1000 +# Threshold for max number of returned triplets +LIMIT = 100 + +node_properties_query = """ +MATCH (n) +UNWIND labels(n) AS label +WITH label, COUNT(n) AS count +CALL schema.node_type_properties() +YIELD propertyName, nodeLabels, propertyTypes +WITH label, nodeLabels, count, collect({property: propertyName, type: propertyTypes[0]}) AS properties +WHERE label IN nodeLabels +RETURN {labels: label, count: count, properties: properties} AS output +ORDER BY count DESC +""" + +rel_properties_query = """ +CALL schema.rel_type_properties() +YIELD relType AS label, propertyName AS property, propertyTypes AS type +WITH label, collect({property: property, type: type}) AS properties +RETURN {type: label, properties: properties} AS output +""" + +rel_query = """ +MATCH (start_node)-[r]->(end_node) +WITH DISTINCT labels(start_node) AS start_labels, type(r) AS relationship_type, labels(end_node) AS end_labels, keys(r) AS relationship_properties +UNWIND start_labels AS start_label +UNWIND end_labels AS end_label +RETURN DISTINCT {start: start_label, type: relationship_type, end: end_label} AS output +""" + + +class MemgraphPropertyGraphStore(PropertyGraphStore): + r""" + Memgraph Property Graph Store. + + This class implements a Memgraph property graph store. + + Args: + username (str): The username for the Memgraph database. + password (str): The password for the Memgraph database. 
+ url (str): The URL for the Memgraph database. + database (Optional[str]): The name of the database to connect to. Defaults to "memgraph". + + Examples: + ```python + from llama_index.core.indices.property_graph import PropertyGraphIndex + from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore + + # Create a MemgraphPropertyGraphStore instance + graph_store = MemgraphPropertyGraphStore( + username="memgraph", + password="password", + url="bolt://localhost:7687", + database="memgraph" + ) + + # Create the index + index = PropertyGraphIndex.from_documents( + documents, + property_graph_store=graph_store, + ) + + # Close the Memgraph connection explicitly. + graph_store.close() + ``` + """ + + supports_structured_queries: bool = True + text_to_cypher_template: PromptTemplate = DEFAULT_CYPHER_TEMPALTE + + def __init__( + self, + username: str, + password: str, + url: str, + database: Optional[str] = "memgraph", + refresh_schema: bool = True, + sanitize_query_output: bool = True, + enhanced_schema: bool = False, + **neo4j_kwargs: Any, + ) -> None: + self.sanitize_query_output = sanitize_query_output + self.enhanced_schema = enhanced_schema + self._driver = neo4j.GraphDatabase.driver( + url, auth=(username, password), **neo4j_kwargs + ) + self._database = database + self.structured_schema = {} + if refresh_schema: + self.refresh_schema() + + # Create index for faster imports and retrieval + self.structured_query(f"""CREATE INDEX ON :{BASE_NODE_LABEL}(id);""") + self.structured_query(f"""CREATE INDEX ON :{BASE_ENTITY_LABEL}(id);""") + + @property + def client(self): + return self._driver + + def close(self) -> None: + self._driver.close() + + def refresh_schema(self) -> None: + """Refresh the schema.""" + # Leave schema empty if db is empty + if self.structured_query("MATCH (n) RETURN n LIMIT 1") == []: + return + + node_query_results = self.structured_query( + node_properties_query, + param_map={ + "EXCLUDED_LABELS": [ + *EXCLUDED_LABELS, + BASE_ENTITY_LABEL, + BASE_NODE_LABEL, + ] + }, + ) + node_properties = {} + for el in node_query_results: + if el["output"]["labels"] in [ + *EXCLUDED_LABELS, + BASE_ENTITY_LABEL, + BASE_NODE_LABEL, + ]: + continue + + label = el["output"]["labels"] + properties = el["output"]["properties"] + if label in node_properties: + node_properties[label]["properties"].extend( + prop + for prop in properties + if prop not in node_properties[label]["properties"] + ) + else: + node_properties[label] = {"properties": properties} + + node_properties = [ + {"labels": label, **value} for label, value in node_properties.items() + ] + rels_query_result = self.structured_query( + rel_properties_query, param_map={"EXCLUDED_LABELS": EXCLUDED_RELS} + ) + rel_properties = ( + [ + el["output"] + for el in rels_query_result + if any(prop["property"] for prop in el["output"].get("properties", [])) + ] + if rels_query_result + else [] + ) + rel_objs_query_result = self.structured_query( + rel_query, + param_map={ + "EXCLUDED_LABELS": [ + *EXCLUDED_LABELS, + BASE_ENTITY_LABEL, + BASE_NODE_LABEL, + ] + }, + ) + relationships = [ + el["output"] + for el in rel_objs_query_result + if rel_objs_query_result + and el["output"]["start"] + not in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL] + and el["output"]["end"] + not in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL] + ] + self.structured_schema = { + "node_props": {el["labels"]: el["properties"] for el in node_properties}, + "rel_props": {el["type"]: el["properties"] for el in rel_properties}, + 
"relationships": relationships, + } + schema_nodes = self.structured_query( + "MATCH (n) UNWIND labels(n) AS label RETURN label AS node, COUNT(n) AS count ORDER BY count DESC" + ) + schema_rels = self.structured_query( + "MATCH ()-[r]->() RETURN TYPE(r) AS relationship_type, COUNT(r) AS count" + ) + schema_counts = [ + { + "nodes": [ + {"name": item["node"], "count": item["count"]} + for item in schema_nodes + ], + "relationships": [ + {"name": item["relationship_type"], "count": item["count"]} + for item in schema_rels + ], + } + ] + # Update node info + for node in schema_counts[0].get("nodes", []): + # Skip bloom labels + if node["name"] in EXCLUDED_LABELS: + continue + node_props = self.structured_schema["node_props"].get(node["name"]) + if not node_props: # The node has no properties + continue + + enhanced_cypher = self._enhanced_schema_cypher( + node["name"], node_props, node["count"] < EXHAUSTIVE_SEARCH_LIMIT + ) + output = self.structured_query(enhanced_cypher) + enhanced_info = output[0]["output"] + for prop in node_props: + if prop["property"] in enhanced_info: + prop.update(enhanced_info[prop["property"]]) + + # Update rel info + for rel in schema_counts[0].get("relationships", []): + if rel["name"] in EXCLUDED_RELS: + continue + rel_props = self.structured_schema["rel_props"].get(f":`{rel['name']}`") + if not rel_props: # The rel has no properties + continue + enhanced_cypher = self._enhanced_schema_cypher( + rel["name"], + rel_props, + rel["count"] < EXHAUSTIVE_SEARCH_LIMIT, + is_relationship=True, + ) + try: + enhanced_info = self.structured_query(enhanced_cypher)[0]["output"] + for prop in rel_props: + if prop["property"] in enhanced_info: + prop.update(enhanced_info[prop["property"]]) + except neo4j.exceptions.ClientError: + pass + + def upsert_nodes(self, nodes: List[LabelledNode]) -> None: + # Lists to hold separated types + entity_dicts: List[dict] = [] + chunk_dicts: List[dict] = [] + + # Sort by type + for item in nodes: + if isinstance(item, EntityNode): + entity_dicts.append({**item.dict(), "id": item.id}) + elif isinstance(item, ChunkNode): + chunk_dicts.append({**item.dict(), "id": item.id}) + else: + pass + if chunk_dicts: + for index in range(0, len(chunk_dicts), CHUNK_SIZE): + chunked_params = chunk_dicts[index : index + CHUNK_SIZE] + for param in chunked_params: + formatted_properties = ", ".join( + [ + f"{key}: {value!r}" + for key, value in param["properties"].items() + ] + ) + self.structured_query( + f""" + MERGE (c:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) + SET c.`text` = '{param["text"]}', c:Chunk + WITH c + SET c += {{{formatted_properties}}} + RETURN count(*) + """ + ) + if entity_dicts: + for index in range(0, len(entity_dicts), CHUNK_SIZE): + chunked_params = entity_dicts[index : index + CHUNK_SIZE] + for param in chunked_params: + formatted_properties = ", ".join( + [ + f"{key}: {value!r}" + for key, value in param["properties"].items() + ] + ) + self.structured_query( + f""" + MERGE (e:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) + SET e += {{{formatted_properties}}} + SET e.name = '{param["name"]}', e:`{BASE_ENTITY_LABEL}` + WITH e + SET e :{param["label"]} + """ + ) + triplet_source_id = param["properties"].get("triplet_source_id") + if triplet_source_id: + self.structured_query( + f""" + MERGE (e:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) + MERGE (c:{BASE_NODE_LABEL} {{id: '{triplet_source_id}'}}) + MERGE (e)<-[:MENTIONS]-(c) + """ + ) + + def upsert_relations(self, relations: List[Relation]) -> None: + """Add relations.""" + params = [r.dict() 
for r in relations] + for index in range(0, len(params), CHUNK_SIZE): + chunked_params = params[index : index + CHUNK_SIZE] + for param in chunked_params: + formatted_properties = ", ".join( + [f"{key}: {value!r}" for key, value in param["properties"].items()] + ) + + self.structured_query( + f""" + MERGE (source: {BASE_NODE_LABEL} {{id: '{param["source_id"]}'}}) + ON CREATE SET source:Chunk + MERGE (target: {BASE_NODE_LABEL} {{id: '{param["target_id"]}'}}) + ON CREATE SET target:Chunk + WITH source, target + MERGE (source)-[r:{param["label"]}]->(target) + SET r += {{{formatted_properties}}} + RETURN count(*) + """ + ) + + def get( + self, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[LabelledNode]: + """Get nodes.""" + cypher_statement = f"MATCH (e:{BASE_NODE_LABEL}) " + + params = {} + cypher_statement += "WHERE e.id IS NOT NULL " + + if ids: + cypher_statement += "AND e.id IN $ids " + params["ids"] = ids + + if properties: + prop_list = [] + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + params[f"property_{i}"] = properties[prop] + cypher_statement += " AND " + " AND ".join(prop_list) + + return_statement = """ + RETURN + e.id AS name, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL + END + ELSE labels(e)[0] + END AS type, + properties(e) AS properties + """ + cypher_statement += return_statement + response = self.structured_query(cypher_statement, param_map=params) + response = response if response else [] + + nodes = [] + for record in response: + if "text" in record["properties"] or record["type"] is None: + text = record["properties"].pop("text", "") + nodes.append( + ChunkNode( + id_=record["name"], + text=text, + properties=remove_empty_values(record["properties"]), + ) + ) + else: + nodes.append( + EntityNode( + name=record["name"], + label=record["type"], + properties=remove_empty_values(record["properties"]), + ) + ) + + return nodes + + def get_triplets( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[Triplet]: + cypher_statement = f"MATCH (e:`{BASE_ENTITY_LABEL}`)-[r]->(t) " + + params = {} + if entity_names or relation_names or properties or ids: + cypher_statement += "WHERE " + + if entity_names: + cypher_statement += "e.name in $entity_names " + params["entity_names"] = entity_names + + if relation_names and entity_names: + cypher_statement += f"AND " + + if relation_names: + cypher_statement += "type(r) in $relation_names " + params[f"relation_names"] = relation_names + + if ids: + cypher_statement += "e.id in $ids " + params["ids"] = ids + + if properties: + prop_list = [] + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + params[f"property_{i}"] = properties[prop] + cypher_statement += " AND ".join(prop_list) + + if not (entity_names or properties or relation_names or ids): + return_statement = """ + WHERE NOT ANY(label IN labels(e) WHERE label = 'Chunk') + RETURN type(r) as type, properties(r) as rel_prop, e.id as source_id, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL + END + ELSE labels(e)[0] + END AS source_type, + properties(e) AS source_properties, + t.id as target_id, 
+ CASE + WHEN labels(t)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(t)) > 2 THEN labels(t)[2] + WHEN size(labels(t)) > 1 THEN labels(t)[1] + ELSE NULL + END + ELSE labels(t)[0] + END AS target_type, properties(t) AS target_properties LIMIT 100; + """ + else: + return_statement = """ + AND NOT ANY(label IN labels(e) WHERE label = 'Chunk') + RETURN type(r) as type, properties(r) as rel_prop, e.id as source_id, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL + END + ELSE labels(e)[0] + END AS source_type, + properties(e) AS source_properties, + t.id as target_id, + CASE + WHEN labels(t)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(t)) > 2 THEN labels(t)[2] + WHEN size(labels(t)) > 1 THEN labels(t)[1] + ELSE NULL + END + ELSE labels(t)[0] + END AS target_type, properties(t) AS target_properties LIMIT 100; + """ + + cypher_statement += return_statement + data = self.structured_query(cypher_statement, param_map=params) + data = data if data else [] + + triplets = [] + for record in data: + source = EntityNode( + name=record["source_id"], + label=record["source_type"], + properties=remove_empty_values(record["source_properties"]), + ) + target = EntityNode( + name=record["target_id"], + label=record["target_type"], + properties=remove_empty_values(record["target_properties"]), + ) + rel = Relation( + source_id=record["source_id"], + target_id=record["target_id"], + label=record["type"], + properties=remove_empty_values(record["rel_prop"]), + ) + triplets.append([source, rel, target]) + return triplets + + def get_rel_map( + self, + graph_nodes: List[LabelledNode], + depth: int = 2, + limit: int = 30, + ignore_rels: Optional[List[str]] = None, + ) -> List[Triplet]: + """Get depth-aware rel map.""" + triples = [] + + ids = [node.id for node in graph_nodes] + response = self.structured_query( + f""" + WITH $ids AS id_list + UNWIND range(0, size(id_list) - 1) AS idx + MATCH (e:__Node__) + WHERE e.id = id_list[idx] + MATCH p=(e)-[r*1..{depth}]-(other) + WHERE ALL(rel in relationships(p) WHERE type(rel) <> 'MENTIONS') + UNWIND relationships(p) AS rel + WITH DISTINCT rel, idx + WITH startNode(rel) AS source, + type(rel) AS type, + rel{{.*}} AS rel_properties, + endNode(rel) AS endNode, + idx + LIMIT toInteger($limit) + RETURN source.id AS source_id, + CASE + WHEN labels(source)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(source)) > 2 THEN labels(source)[2] + WHEN size(labels(source)) > 1 THEN labels(source)[1] + ELSE NULL + END + ELSE labels(source)[0] + END AS source_type, + properties(source) AS source_properties, + type, + rel_properties, + endNode.id AS target_id, + CASE + WHEN labels(endNode)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(endNode)) > 2 THEN labels(endNode)[2] + WHEN size(labels(endNode)) > 1 THEN labels(endNode)[1] ELSE NULL + END + ELSE labels(endNode)[0] + END AS target_type, + properties(endNode) AS target_properties, + idx + ORDER BY idx + LIMIT toInteger($limit) + """, + param_map={"ids": ids, "limit": limit}, + ) + response = response if response else [] + + ignore_rels = ignore_rels or [] + for record in response: + if record["type"] in ignore_rels: + continue + + source = EntityNode( + name=record["source_id"], + label=record["source_type"], + properties=remove_empty_values(record["source_properties"]), + ) + target = EntityNode( + name=record["target_id"], + 
label=record["target_type"], + properties=remove_empty_values(record["target_properties"]), + ) + rel = Relation( + source_id=record["source_id"], + target_id=record["target_id"], + label=record["type"], + properties=remove_empty_values(record["rel_properties"]), + ) + triples.append([source, rel, target]) + + return triples + + def structured_query( + self, query: str, param_map: Optional[Dict[str, Any]] = None + ) -> Any: + param_map = param_map or {} + + with self._driver.session(database=self._database) as session: + result = session.run(query, param_map) + full_result = [d.data() for d in result] + + if self.sanitize_query_output: + return [value_sanitize(el) for el in full_result] + return full_result + + def vector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[List[LabelledNode], List[float]]: + raise NotImplementedError( + "Vector query is not currently implemented for MemgraphPropertyGraphStore." + ) + + def delete( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> None: + """Delete matching data.""" + if entity_names: + self.structured_query( + "MATCH (n) WHERE n.name IN $entity_names DETACH DELETE n", + param_map={"entity_names": entity_names}, + ) + if ids: + self.structured_query( + "MATCH (n) WHERE n.id IN $ids DETACH DELETE n", + param_map={"ids": ids}, + ) + if relation_names: + for rel in relation_names: + self.structured_query(f"MATCH ()-[r:`{rel}`]->() DELETE r") + + if properties: + cypher = "MATCH (e) WHERE " + prop_list = [] + params = {} + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + params[f"property_{i}"] = properties[prop] + cypher += " AND ".join(prop_list) + self.structured_query(cypher + " DETACH DELETE e", param_map=params) + + def _enhanced_schema_cypher( + self, + label_or_type: str, + properties: List[Dict[str, Any]], + exhaustive: bool, + is_relationship: bool = False, + ) -> str: + if is_relationship: + match_clause = f"MATCH ()-[n:`{label_or_type}`]->()" + else: + match_clause = f"MATCH (n:`{label_or_type}`)" + + with_clauses = [] + return_clauses = [] + output_dict = {} + if exhaustive: + for prop in properties: + if prop["property"]: + prop_name = prop["property"] + else: + prop_name = None + if prop["type"]: + prop_type = prop["type"] + else: + prop_type = None + if prop_type == "String": + with_clauses.append( + f"collect(distinct substring(toString(n.`{prop_name}`), 0, 50)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append( + f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}]," + f" distinct_count: size(`{prop_name}_values`)" + ) + elif prop_type in [ + "Int", + "Double", + "Date", + "LocalTime", + "LocalDateTime", + ]: + with_clauses.append(f"min(n.`{prop_name}`) AS `{prop_name}_min`") + with_clauses.append(f"max(n.`{prop_name}`) AS `{prop_name}_max`") + with_clauses.append( + f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`" + ) + return_clauses.append( + f"min: toString(`{prop_name}_min`), " + f"max: toString(`{prop_name}_max`), " + f"distinct_count: `{prop_name}_distinct`" + ) + elif prop_type in ["List", "List[Any]"]: + with_clauses.append( + f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, " + f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`" + ) + return_clauses.append( + f"min_size: `{prop_name}_size_min`, " + f"max_size: `{prop_name}_size_max`" + ) + elif prop_type in ["Bool", "Duration"]: + continue + if 
return_clauses: + output_dict[prop_name] = "{" + return_clauses.pop() + "}" + else: + output_dict[prop_name] = None + else: + # Just sample 5 random nodes + match_clause += " WITH n LIMIT 5" + for prop in properties: + prop_name = prop["property"] + prop_type = prop["type"] + + # Check if indexed property, we can still do exhaustive + prop_index = [ + el + for el in self.structured_schema["metadata"]["index"] + if el["label"] == label_or_type + and el["properties"] == [prop_name] + and el["type"] == "RANGE" + ] + if prop_type == "String": + if ( + prop_index + and prop_index[0].get("size") > 0 + and prop_index[0].get("distinctValues") <= DISTINCT_VALUE_LIMIT + ): + distinct_values_query = f""" + MATCH (n:{label_or_type}) + RETURN DISTINCT n.`{prop_name}` AS value + LIMIT {DISTINCT_VALUE_LIMIT} + """ + distinct_values = self.query(distinct_values_query) + + # Extract values from the result set + distinct_values = [ + record["value"] for record in distinct_values + ] + + return_clauses.append( + f"values: {distinct_values}," + f" distinct_count: {len(distinct_values)}" + ) + else: + with_clauses.append( + f"collect(distinct substring(n.`{prop_name}`, 0, 50)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append(f"values: `{prop_name}_values`") + elif prop_type in [ + "Int", + "Double", + "Float", + "Date", + "LocalTime", + "LocalDateTime", + ]: + if not prop_index: + with_clauses.append( + f"collect(distinct toString(n.`{prop_name}`)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append(f"values: `{prop_name}_values`") + else: + with_clauses.append( + f"min(n.`{prop_name}`) AS `{prop_name}_min`" + ) + with_clauses.append( + f"max(n.`{prop_name}`) AS `{prop_name}_max`" + ) + with_clauses.append( + f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`" + ) + return_clauses.append( + f"min: toString(`{prop_name}_min`), " + f"max: toString(`{prop_name}_max`), " + f"distinct_count: `{prop_name}_distinct`" + ) + + elif prop_type in ["List", "List[Any]"]: + with_clauses.append( + f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, " + f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`" + ) + return_clauses.append( + f"min_size: `{prop_name}_size_min`, " + f"max_size: `{prop_name}_size_max`" + ) + elif prop_type in ["Bool", "Duration"]: + continue + if return_clauses: + output_dict[prop_name] = "{" + return_clauses.pop() + "}" + else: + output_dict[prop_name] = None + + with_clause = "WITH " + ",\n ".join(with_clauses) + return_clause = ( + "RETURN {" + + ", ".join(f"`{k}`: {v}" for k, v in output_dict.items()) + + "} AS output" + ) + # Combine all parts of the Cypher query + return f"{match_clause}\n{with_clause}\n{return_clause}" + + def get_schema(self, refresh: bool = False) -> Any: + if refresh: + self.refresh_schema() + + return self.structured_schema + + def get_schema_str(self, refresh: bool = False) -> str: + schema = self.get_schema(refresh=refresh) + + formatted_node_props = [] + formatted_rel_props = [] + + if self.enhanced_schema: + # Enhanced formatting for nodes + for node_type, properties in schema["node_props"].items(): + formatted_node_props.append(f"- **{node_type}**") + for prop in properties: + example = "" + if prop["type"] == "String" and prop.get("values"): + if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: + example = ( + f'Example: "{clean_string_values(prop["values"][0])}"' + if prop["values"] + else "" + ) + else: # If less than 10 possible values return all + example = ( + ( + "Available options: " + f'{[clean_string_values(el) for el in 
prop["values"]]}' + ) + if prop["values"] + else "" + ) + + elif prop["type"] in [ + "Int", + "Double", + "Float", + "Date", + "LocalTime", + "LocalDateTime", + ]: + if prop.get("min") is not None: + example = f'Min: {prop["min"]}, Max: {prop["max"]}' + else: + example = ( + f'Example: "{prop["values"][0]}"' + if prop.get("values") + else "" + ) + elif prop["type"] in ["List", "List[Any]"]: + # Skip embeddings + if not prop.get("min_size") or prop["min_size"] > LIST_LIMIT: + continue + example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' + formatted_node_props.append( + f" - `{prop['property']}`: {prop['type']} {example}" + ) + + # Enhanced formatting for relationships + for rel_type, properties in schema["rel_props"].items(): + formatted_rel_props.append(f"- **{rel_type}**") + for prop in properties: + example = "" + if prop["type"] == "STRING": + if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: + example = ( + f'Example: "{clean_string_values(prop["values"][0])}"' + if prop.get("values") + else "" + ) + else: # If less than 10 possible values return all + example = ( + ( + "Available options: " + f'{[clean_string_values(el) for el in prop["values"]]}' + ) + if prop.get("values") + else "" + ) + elif prop["type"] in [ + "Int", + "Double", + "Float", + "Date", + "LocalTime", + "LocalDateTime", + ]: + if prop.get("min"): # If we have min/max + example = f'Min: {prop["min"]}, Max: {prop["max"]}' + else: # return a single value + example = ( + f'Example: "{prop["values"][0]}"' + if prop.get("values") + else "" + ) + elif prop["type"] == "List[Any]": + # Skip embeddings + if prop["min_size"] > LIST_LIMIT: + continue + example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' + formatted_rel_props.append( + f" - `{prop['property']}: {prop['type']}` {example}" + ) + else: + # Format node properties + for label, props in schema["node_props"].items(): + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in props] + ) + formatted_node_props.append(f"{label} {{{props_str}}}") + + # Format relationship properties using structured_schema + for type, props in schema["rel_props"].items(): + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in props] + ) + formatted_rel_props.append(f"{type} {{{props_str}}}") + + # Format relationships + formatted_rels = [ + f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" + for el in schema["relationships"] + ] + + return "\n".join( + [ + "Node properties:", + "\n".join(formatted_node_props), + "Relationship properties:", + "\n".join(formatted_rel_props), + "The relationships:", + "\n".join(formatted_rels), + ] + ) diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml new file mode 100644 index 0000000000000..c2ebce54d7a65 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +# Feel free to un-skip examples, and experimental, you will just need to +# work through many typos (--write-changes and --interactive will help) +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = true +import_path = "llama_index.graph_stores.memgraph" + +[tool.llamahub.class_authors] +MemgraphGraphStore = 
"llama-index" +MemgraphPropertyGraphStore = "llama-index" + +[tool.mypy] +disallow_untyped_defs = true +# Remove venv skip when integrated with pre-commit +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +authors = ["Your Name "] +description = "llama-index graph-stores memgraph integration" +license = "MIT" +name = "llama-index-graph-stores-memgraph" +packages = [{include = "llama_index/"}] +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.8.1,<4.0" +llama-index-core = "^0.10.0" +neo4j = "^5.24.0" + +[tool.poetry.group.dev.dependencies] +black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} +codespell = {extras = ["toml"], version = ">=v2.2.6"} +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991 +types-setuptools = "67.1.0.0" diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/BUILD b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/BUILD new file mode 100644 index 0000000000000..dabf212d7e716 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/__init__.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py new file mode 100644 index 0000000000000..6b82a46196742 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py @@ -0,0 +1,10 @@ +from unittest.mock import MagicMock, patch + +from llama_index.core.graph_stores.types import GraphStore +from llama_index.graph_stores.memgraph import MemgraphGraphStore + + +@patch("llama_index.graph_stores.memgraph.MemgraphGraphStore") +def test_memgraph_graph_store(MockMemgraphGraphStore: MagicMock): + instance: MemgraphGraphStore = MockMemgraphGraphStore.return_value() + assert isinstance(instance, GraphStore) diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py new file mode 100644 index 0000000000000..a6260027b1c63 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py @@ -0,0 +1,92 @@ +import os +import pytest +from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore +from llama_index.core.graph_stores.types import ( + EntityNode, + Relation, +) +from llama_index.core.schema import TextNode + +memgraph_user = os.environ.get("MEMGRAPH_TEST_USER") +memgraph_pass = os.environ.get("MEMGRAPH_TEST_PASS") +memgraph_url = os.environ.get("MEMGRAPH_TEST_URL") + +if not memgraph_user or not memgraph_pass or not memgraph_url: + 
memgraph_available = False
+else:
+    memgraph_available = True
+
+
+@pytest.fixture()
+def pg_store() -> MemgraphPropertyGraphStore:
+    if not memgraph_available:
+        pytest.skip("No Memgraph credentials provided")
+    return MemgraphPropertyGraphStore(
+        username=memgraph_user, password=memgraph_pass, url=memgraph_url
+    )
+
+
+def test_memgraph_pg_store(pg_store: MemgraphPropertyGraphStore) -> None:
+    # Clear the database
+    pg_store.structured_query("STORAGE MODE IN_MEMORY_ANALYTICAL")
+    pg_store.structured_query("DROP GRAPH")
+    pg_store.structured_query("STORAGE MODE IN_MEMORY_TRANSACTIONAL")
+
+    # Test upsert nodes
+    entity1 = EntityNode(label="PERSON", name="Logan", properties={"age": 28})
+    entity2 = EntityNode(label="ORGANIZATION", name="LlamaIndex")
+    pg_store.upsert_nodes([entity1, entity2])
+
+    # Assert the nodes are inserted correctly
+    kg_nodes = pg_store.get(ids=[entity1.id])
+    assert len(kg_nodes) == 1
+
+    # Test inserting relations into Memgraph.
+    relation = Relation(
+        label="WORKS_FOR",
+        source_id=entity1.id,
+        target_id=entity2.id,
+        properties={"since": 2023},
+    )
+
+    pg_store.upsert_relations([relation])
+
+    # Assert the relation is inserted correctly by retrieving the relation map
+    kg_nodes = pg_store.get(ids=[entity1.id])
+    paths = pg_store.get_rel_map(kg_nodes, depth=1)
+    assert len(paths) > 0
+
+    # Test inserting a source text node and 'MENTIONS' relations.
+    source_node = TextNode(text="Logan (age 28), works for LlamaIndex since 2023.")
+
+    relations = [
+        Relation(label="MENTIONS", target_id=entity1.id, source_id=source_node.node_id),
+        Relation(label="MENTIONS", target_id=entity2.id, source_id=source_node.node_id),
+    ]
+
+    pg_store.upsert_llama_nodes([source_node])
+    pg_store.upsert_relations(relations)
+
+    # Assert the source node and relations are inserted correctly
+    llama_nodes = pg_store.get_llama_nodes([source_node.node_id])
+    assert len(llama_nodes) == 1
+
+    # Test retrieving nodes by properties.
+    kg_nodes = pg_store.get(properties={"age": 28})
+    assert len(kg_nodes) == 1
+
+    # Test executing a structured query in Memgraph.
+    query = "MATCH (n:`__Entity__`) RETURN n"
+    result = pg_store.structured_query(query)
+    assert len(result) > 0
+
+    # Test upserting a new node with additional properties.
+    new_node = EntityNode(
+        label="PERSON", name="Logan", properties={"age": 28, "location": "Canada"}
+    )
+    pg_store.upsert_nodes([new_node])
+
+    # Assert the node has been updated with the new property
+    kg_nodes = pg_store.get(properties={"age": 28})
+    assert len(kg_nodes) == 1
+    assert kg_nodes[0].properties.get("location") == "Canada"
+
+    # Test deleting nodes from Memgraph.
+    pg_store.delete(ids=[source_node.node_id])
+    pg_store.delete(ids=[entity1.id, entity2.id])
+
+    # Assert the nodes have been deleted
+    nodes = pg_store.get(ids=[entity1.id, entity2.id])
+    assert len(nodes) == 0