From a10bdd53f219913a59fc62f50a93b2cad8d347bb Mon Sep 17 00:00:00 2001 From: matea16 Date: Thu, 3 Oct 2024 11:49:21 +0200 Subject: [PATCH 1/6] Add Memgraph integration --- .../llama-index-graph-stores-memgraph/BUILD | 1 + .../Makefile | 17 + .../README.md | 1 + .../examples/kg_example.py | 61 ++ .../examples/pg_example.py | 80 ++ .../graph_stores/memgraph/__init__.py | 5 + .../graph_stores/memgraph/kg_base.py | 169 ++++ .../graph_stores/memgraph/property_graph.py | 927 ++++++++++++++++++ .../pyproject.toml | 57 ++ .../tests/__init__.py | 0 .../tests/test_graph_stores_memgraph.py | 40 + .../tests/test_pg_stores_memgraph.py | 111 +++ 12 files changed, 1469 insertions(+) create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/Makefile create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/__init__.py create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py create mode 100644 
llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/Makefile b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. 
"""Example: build and query a KnowledgeGraphIndex backed by Memgraph.

Writes three small documents to disk, extracts triplets with an OpenAI
LLM, stores them in Memgraph via ``MemgraphGraphStore``, and runs a
query against the resulting knowledge graph.
"""
import logging
import os

from llama_index.core import (
    KnowledgeGraphIndex,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
)
from llama_index.graph_stores.memgraph import MemgraphGraphStore
from llama_index.llms.openai import OpenAI

# Step 1: Set up the OpenAI API key.
# NOTE(review): this overwrites any OPENAI_API_KEY already exported in the
# environment — fill in a real key before running.
os.environ["OPENAI_API_KEY"] = ""  # Replace with your OpenAI API key

# Step 2: Configure logging.
logging.basicConfig(level=logging.INFO)

# Step 3: Configure the OpenAI LLM used for triplet extraction.
llm = OpenAI(temperature=0, model="gpt-3.5-turbo")
Settings.llm = llm
Settings.chunk_size = 512

# Step 4: Write documents to text files (simulating loading documents from disk).
documents = {
    "doc1.txt": "Python is a popular programming language known for its readability and simplicity. It was created by Guido van Rossum and first released in 1991. Python supports multiple programming paradigms, including procedural, object-oriented, and functional programming. It is widely used in web development, data science, artificial intelligence, and scientific computing.",
    "doc2.txt": "JavaScript is a high-level programming language primarily used for web development. It was created by Brendan Eich and first appeared in 1995. JavaScript is a core technology of the World Wide Web, alongside HTML and CSS. It enables interactive web pages and is an essential part of web applications. JavaScript is also used in server-side development with environments like Node.js.",
    "doc3.txt": "Java is a high-level, class-based, object-oriented programming language that is designed to have as few implementation dependencies as possible. It was developed by James Gosling and first released by Sun Microsystems in 1995. Java is widely used for building enterprise-scale applications, mobile applications, and large systems development.",
}

for filename, content in documents.items():
    # Explicit encoding: without it the written bytes depend on the platform's
    # locale, which can corrupt non-ASCII text on Windows.
    with open(filename, "w", encoding="utf-8") as file:
        file.write(content)

# Step 5: Load the documents back through the reader.
loaded_documents = SimpleDirectoryReader(".").load_data()

# Step 6: Set up the Memgraph connection.
username = ""  # Enter your Memgraph username (default "")
password = ""  # Enter your Memgraph password (default "")
url = ""  # Specify the connection URL, e.g., 'bolt://localhost:7687'
database = "memgraph"  # Name of the database, default is 'memgraph'

graph_store = MemgraphGraphStore(
    username=username,
    password=password,
    url=url,
    database=database,
)

storage_context = StorageContext.from_defaults(graph_store=graph_store)

# Step 7: Create a Knowledge Graph Index (extracts triplets with the LLM).
index = KnowledgeGraphIndex.from_documents(
    loaded_documents,
    storage_context=storage_context,
    max_triplets_per_chunk=3,
)

# Step 8: Query the Knowledge Graph.
query_engine = index.as_query_engine(include_text=False, response_mode="tree_summarize")
response = query_engine.query("Tell me about Python and its uses")

print("Query Response:")
print(response)
"""Example: build and query a PropertyGraphIndex backed by Memgraph.

Downloads the Paul Graham essay, escapes single quotes, builds a
property graph with an OpenAI LLM + embeddings, and queries it both
through a retriever and a query engine.
"""
import logging
import os
import urllib.request

import nest_asyncio
from llama_index.core import PropertyGraphIndex, SimpleDirectoryReader
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore
from llama_index.llms.openai import OpenAI

# 1. Set up the OpenAI API key (replace with your actual key).
os.environ["OPENAI_API_KEY"] = ""  # Replace with your OpenAI API key

# 2. Create the data directory and download the Paul Graham essay.
os.makedirs("data/paul_graham/", exist_ok=True)

url = "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt"
output_path = "data/paul_graham/paul_graham_essay.txt"
urllib.request.urlretrieve(url, output_path)

# 3. Apply nest_asyncio so async extraction can run inside a plain script.
nest_asyncio.apply()

# 4. Read the file, escape single quotes, and save the modified content.
with open(output_path, "r", encoding="utf-8") as file:
    content = file.read()

# Replace single quotes with escaped single quotes.
modified_content = content.replace("'", "\\'")

# Save the modified content back to the same file.
with open(output_path, "w", encoding="utf-8") as file:
    file.write(modified_content)

# 5. Load the document data.
documents = SimpleDirectoryReader("./data/paul_graham/").load_data()

# 6. Set up the Memgraph connection (ensure Memgraph is running).
username = ""  # Enter your Memgraph username (default "")
password = ""  # Enter your Memgraph password (default "")
url = ""  # Specify the connection URL, e.g., 'bolt://localhost:7687'

graph_store = MemgraphPropertyGraphStore(
    username=username,
    password=password,
    url=url,
)

# 7. Create the Property Graph Index.
index = PropertyGraphIndex.from_documents(
    documents,
    embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002"),
    kg_extractors=[
        SchemaLLMPathExtractor(
            llm=OpenAI(model="gpt-3.5-turbo", temperature=0.0),
        )
    ],
    property_graph_store=graph_store,
    show_progress=True,
)

# 8. Query the graph with a retriever.
retriever = index.as_retriever(include_text=False)

# Example query: "What happened at Interleaf and Viaweb?"
nodes = retriever.retrieve("What happened at Interleaf and Viaweb?")

# Output results.
print("Query Results:")
for node in nodes:
    print(node.text)

# Alternatively, use a query engine for a synthesized answer.
query_engine = index.as_query_engine(include_text=True)

# Perform a query and print the detailed response.
response = query_engine.query("What happened at Interleaf and Viaweb?")
print("\nDetailed Query Response:")
print(str(response))
"""Memgraph graph store index."""
import logging
from typing import Any, Dict, List, Optional

from llama_index.core.graph_stores.types import GraphStore

logger = logging.getLogger(__name__)

node_properties_query = """
CALL schema.node_type_properties()
YIELD nodeType AS label, propertyName AS property, propertyTypes AS type
WITH label AS nodeLabels, collect({property: property, type: type}) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output
"""

rel_properties_query = """
CALL schema.rel_type_properties()
YIELD relType AS label, propertyName AS property, propertyTypes AS type
WITH label, collect({property: property, type: type}) AS properties
RETURN {type: label, properties: properties} AS output
"""

rel_query = """
MATCH (start_node)-[r]->(end_node)
WITH labels(start_node) AS start, type(r) AS relationship_type, labels(end_node) AS end, keys(r) AS relationship_properties
UNWIND end AS end_label
RETURN DISTINCT {start: start[0], type: relationship_type, end: end_label} AS output
"""


class MemgraphGraphStore(GraphStore):
    """Memgraph-backed knowledge-graph triplet store.

    Stores (subject, relation, object) triplets as
    ``(:Entity {id})-[:REL]->(:Entity {id})`` patterns over the Bolt
    protocol (via the ``neo4j`` driver, which Memgraph speaks).

    Args:
        username: Memgraph username (default installs accept "").
        password: Memgraph password (default installs accept "").
        url: Bolt connection URL, e.g. ``bolt://localhost:7687``.
        database: Target database name. Defaults to ``"memgraph"``.
        node_label: Label applied to every entity node. Defaults to ``"Entity"``.

    Raises:
        ImportError: if the ``neo4j`` driver is not installed.
        ValueError: if the database is unreachable or credentials are wrong.
    """

    def __init__(
        self,
        username: str,
        password: str,
        url: str,
        database: str = "memgraph",
        node_label: str = "Entity",
        **kwargs: Any,
    ) -> None:
        try:
            import neo4j
        except ImportError:
            raise ImportError("Please install neo4j: pip install neo4j")
        self.node_label = node_label
        self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
        self._database = database
        self.schema = ""
        # Verify the connection up front so misconfiguration surfaces as a
        # clear error instead of failing on the first query.
        try:
            self._driver.verify_connectivity()
        except neo4j.exceptions.ServiceUnavailable:
            raise ValueError(
                "Could not connect to Memgraph database. "
                "Please ensure that the url is correct"
            )
        except neo4j.exceptions.AuthError:
            raise ValueError(
                "Could not connect to Memgraph database. "
                "Please ensure that the username and password are correct"
            )
        # Set schema.
        self.refresh_schema()

        # Uniqueness constraint on entity ids (MERGE correctness relies on it).
        self.query(
            """
            CREATE CONSTRAINT ON (n:%s) ASSERT n.id IS UNIQUE;
            """
            % (self.node_label)
        )

        # Index on id for faster lookups.
        self.query(
            """
            CREATE INDEX ON :%s(id);
            """
            % (self.node_label)
        )

    @property
    def client(self) -> Any:
        """Return the underlying ``neo4j`` driver instance."""
        return self._driver

    def query(self, query: str, param_map: Optional[Dict[str, Any]] = None) -> Any:
        """Execute a Cypher query and return all records as dicts.

        Args:
            query: Cypher statement to run.
            param_map: Optional query parameters. (Default changed from a
                mutable ``{}`` to ``None`` to avoid the shared-mutable-default
                pitfall; behavior for callers is unchanged.)
        """
        param_map = param_map or {}
        with self._driver.session(database=self._database) as session:
            result = session.run(query, param_map)
            return [record.data() for record in result]

    def get(self, subj: str) -> List[List[str]]:
        """Return ``[rel_type, obj_id]`` pairs for all triplets with subject *subj*."""
        query = f"""
            MATCH (n1:{self.node_label})-[r]->(n2:{self.node_label})
            WHERE n1.id = $subj
            RETURN type(r), n2.id;
        """

        with self._driver.session(database=self._database) as session:
            data = session.run(query, {"subj": subj})
            return [record.values() for record in data]

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2
    ) -> Dict[str, List[List[str]]]:
        """Get a flat relation map: subject id -> list of [rel_type, obj_id].

        Traverses outgoing paths up to *depth* hops from each subject.
        Returns an empty map when *subjs* is None or empty.
        """
        rel_map: Dict[Any, List[Any]] = {}
        if subjs is None or len(subjs) == 0:
            return rel_map

        query = (
            f"""MATCH p=(n1:{self.node_label})-[*1..{depth}]->() """
            f"""{"WHERE n1.id IN $subjs" if subjs else ""} """
            "UNWIND relationships(p) AS rel "
            "WITH n1.id AS subj, collect([type(rel), endNode(rel).id]) AS rels "
            "RETURN subj, rels"
        )

        data = list(self.query(query, {"subjs": subjs}))
        if not data:
            return rel_map

        for record in data:
            rel_map[record["subj"]] = record["rels"]

        return rel_map

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add a triplet; nodes and the relationship are MERGEd (idempotent).

        The relation name is normalized to an UPPER_SNAKE edge type.
        """
        query = f"""
            MERGE (n1:`{self.node_label}` {{id:$subj}})
            MERGE (n2:`{self.node_label}` {{id:$obj}})
            MERGE (n1)-[:`{rel.replace(" ", "_").upper()}`]->(n2)
        """
        self.query(query, {"subj": subj, "obj": obj})

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete the relationship *rel* between *subj* and *obj* (nodes are kept)."""
        query = f"""
            MATCH (n1:`{self.node_label}`)-[r:`{rel}`]->(n2:`{self.node_label}`)
            WHERE n1.id = $subj AND n2.id = $obj
            DELETE r
        """
        self.query(query, {"subj": subj, "obj": obj})

    def refresh_schema(self) -> None:
        """Refresh the cached Memgraph graph schema description (``self.schema``)."""
        node_properties = self.query(node_properties_query)
        relationships_properties = self.query(rel_properties_query)
        relationships = self.query(rel_query)

        self.schema = f"""
        Node properties are the following:
        {[el for el in node_properties]}
        Relationship properties are the following:
        {[el for el in relationships_properties]}
        The relationships are the following:
        {[el for el in relationships]}
        """

    def get_schema(self, refresh: bool = False) -> str:
        """Return the schema string, refreshing it first when *refresh* is True."""
        if self.schema and not refresh:
            return self.schema
        self.refresh_schema()
        logger.debug(f"get_schema() schema:\n{self.schema}")
        return self.schema
llama_index.core.graph_stores.types import ( + PropertyGraphStore, + Triplet, + LabelledNode, + Relation, + EntityNode, + ChunkNode, +) +from llama_index.core.graph_stores.utils import ( + clean_string_values, + value_sanitize, + LIST_LIMIT, +) + +from llama_index.core.prompts import PromptTemplate +from llama_index.core.vector_stores.types import VectorStoreQuery +import neo4j + +def remove_empty_values(input_dict): + """ + Remove entries with empty values from the dictionary. + """ + return {key: value for key, value in input_dict.items() if value} + +BASE_ENTITY_LABEL = "__Entity__" +BASE_NODE_LABEL = "__Node__" +EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"] +EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"] +EXHAUSTIVE_SEARCH_LIMIT = 10000 +# Threshold for returning all available prop values in graph schema +DISTINCT_VALUE_LIMIT = 10 +CHUNK_SIZE = 1000 +# Threshold for max number of returned triplets +LIMIT = 100 + +node_properties_query = """ +MATCH (n) +UNWIND labels(n) AS label +WITH label, COUNT(n) AS count +CALL schema.node_type_properties() +YIELD propertyName, nodeLabels, propertyTypes +WITH label, nodeLabels, count, collect({property: propertyName, type: propertyTypes[0]}) AS properties +WHERE label IN nodeLabels +RETURN {labels: label, count: count, properties: properties} AS output +ORDER BY count DESC +""" + +rel_properties_query = """ +CALL schema.rel_type_properties() +YIELD relType AS label, propertyName AS property, propertyTypes AS type +WITH label, collect({property: property, type: type}) AS properties +RETURN {type: label, properties: properties} AS output +""" + +rel_query = """ +MATCH (start_node)-[r]->(end_node) +WITH DISTINCT labels(start_node) AS start_labels, type(r) AS relationship_type, labels(end_node) AS end_labels, keys(r) AS relationship_properties +UNWIND start_labels AS start_label +UNWIND end_labels AS end_label +RETURN DISTINCT {start: start_label, type: relationship_type, end: end_label} AS output +""" + +class 
MemgraphPropertyGraphStore(PropertyGraphStore): + r""" + Memgraph Property Graph Store. + + This class implements a Memgraph property graph store. + + Args: + username (str): The username for the Memgraph database. + password (str): The password for the Memgraph database. + url (str): The URL for the Memgraph database. + database (Optional[str]): The name of the database to connect to. Defaults to "memgraph". + + Examples: + ```python + from llama_index.core.indices.property_graph import PropertyGraphIndex + from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore + + # Create a MemgraphPropertyGraphStore instance + graph_store = MemgraphPropertyGraphStore( + username="memgraph", + password="password", + url="bolt://localhost:7687", + database="memgraph" + ) + + # Create the index + index = PropertyGraphIndex.from_documents( + documents, + property_graph_store=graph_store, + ) + + # Close the Memgraph connection explicitly. + graph_store.close() + ``` + """ + + supports_structured_queries: bool = True + text_to_cypher_template: PromptTemplate = DEFAULT_CYPHER_TEMPALTE + + def __init__( + self, + username: str, + password: str, + url: str, + database: Optional[str] = "memgraph", + refresh_schema: bool = True, + sanitize_query_output: bool = True, + enhanced_schema: bool = False, + **neo4j_kwargs: Any, + ) -> None: + self.sanitize_query_output = sanitize_query_output + self.enhanced_schema = enhanced_schema + self._driver = neo4j.GraphDatabase.driver( + url, auth=(username, password), **neo4j_kwargs + ) + self._database = database + self.structured_schema = {} + if refresh_schema: + self.refresh_schema() + + # Create index for faster imports and retrieval + self.structured_query( + f"""CREATE INDEX ON :{BASE_NODE_LABEL}(id);""" + ) + self.structured_query( + f"""CREATE INDEX ON :{BASE_ENTITY_LABEL}(id);""" + ) + + @property + def client(self): + return self._driver + + def close(self) -> None: + self._driver.close() + + def refresh_schema(self) -> 
None: + """Refresh the schema.""" + # Leave schema empty if db is empty + if self.structured_query("MATCH (n) RETURN n LIMIT 1") == []: + return + + node_query_results = self.structured_query( + node_properties_query, + param_map={ + "EXCLUDED_LABELS": [ + *EXCLUDED_LABELS, + BASE_ENTITY_LABEL, + BASE_NODE_LABEL, + ] + }, + ) + node_properties = {} + for el in node_query_results: + if el["output"]["labels"] in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL]: + continue + + label = el["output"]["labels"] + properties = el["output"]["properties"] + if label in node_properties: + node_properties[label]["properties"].extend( + prop for prop in properties if prop not in node_properties[label]["properties"] + ) + else: + node_properties[label] = {"properties": properties} + + node_properties = [{"labels": label, **value} for label, value in node_properties.items()] + rels_query_result = self.structured_query( + rel_properties_query, param_map={"EXCLUDED_LABELS": EXCLUDED_RELS} + ) + rel_properties = ( + [el["output"] for el in rels_query_result + if any(prop["property"] for prop in el["output"].get("properties", []))] + if rels_query_result + else [] + ) + rel_objs_query_result = self.structured_query( + rel_query, + param_map={ + "EXCLUDED_LABELS": [ + *EXCLUDED_LABELS, + BASE_ENTITY_LABEL, + BASE_NODE_LABEL, + ] + }, + ) + relationships = [ + el["output"] for el in rel_objs_query_result + if rel_objs_query_result and + el["output"]["start"] not in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL] and + el["output"]["end"] not in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL] + ] + self.structured_schema = { + "node_props": {el["labels"]: el["properties"] for el in node_properties}, + "rel_props": {el["type"]: el["properties"] for el in rel_properties}, + "relationships": relationships, + } + schema_nodes = self.structured_query( + "MATCH (n) UNWIND labels(n) AS label RETURN label AS node, COUNT(n) AS count ORDER BY count DESC" + ) + schema_rels = 
self.structured_query( + "MATCH ()-[r]->() RETURN TYPE(r) AS relationship_type, COUNT(r) AS count" + ) + schema_counts = [{ + 'nodes': [{'name': item['node'], 'count': item['count']} for item in schema_nodes], + 'relationships': [{'name': item['relationship_type'], 'count': item['count']} for item in schema_rels] + }] + # Update node info + for node in schema_counts[0].get("nodes", []): + # Skip bloom labels + if node["name"] in EXCLUDED_LABELS: + continue + node_props = self.structured_schema["node_props"].get(node['name']) + if not node_props: # The node has no properties + continue + + enhanced_cypher = self._enhanced_schema_cypher( + node["name"], node_props, node["count"] < EXHAUSTIVE_SEARCH_LIMIT + ) + output = self.structured_query(enhanced_cypher) + enhanced_info = output[0]["output"] + for prop in node_props: + if prop["property"] in enhanced_info: + prop.update(enhanced_info[prop["property"]]) + + # Update rel info + for rel in schema_counts[0].get("relationships", []): + if rel["name"] in EXCLUDED_RELS: + continue + rel_props = self.structured_schema["rel_props"].get(f":`{rel['name']}`") + if not rel_props: # The rel has no properties + continue + enhanced_cypher = self._enhanced_schema_cypher( + rel["name"], + rel_props, + rel["count"] < EXHAUSTIVE_SEARCH_LIMIT, + is_relationship=True, + ) + try: + enhanced_info = self.structured_query(enhanced_cypher)[0]["output"] + for prop in rel_props: + if prop["property"] in enhanced_info: + prop.update(enhanced_info[prop["property"]]) + except neo4j.exceptions.ClientError: + pass + + def upsert_nodes(self, nodes: List[LabelledNode]) -> None: + # Lists to hold separated types + entity_dicts: List[dict] = [] + chunk_dicts: List[dict] = [] + + # Sort by type + for item in nodes: + if isinstance(item, EntityNode): + entity_dicts.append({**item.dict(), "id": item.id}) + elif isinstance(item, ChunkNode): + chunk_dicts.append({**item.dict(), "id": item.id}) + else: + pass + if chunk_dicts: + for index in range(0, 
len(chunk_dicts), CHUNK_SIZE): + chunked_params = chunk_dicts[index : index + CHUNK_SIZE] + for param in chunked_params: + formatted_properties = ', '.join([f'{key}: {repr(value)}' for key, value in param["properties"].items()]) + self.structured_query( + f""" + MERGE (c:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) + SET c.`text` = '{param["text"]}', c:Chunk + WITH c + SET c += {{{formatted_properties}}} + RETURN count(*) + """ + ) + if entity_dicts: + for index in range(0, len(entity_dicts), CHUNK_SIZE): + chunked_params = entity_dicts[index : index + CHUNK_SIZE] + for param in chunked_params: + formatted_properties = ', '.join([f'{key}: {repr(value)}' for key, value in param["properties"].items()]) + self.structured_query( + f""" + MERGE (e:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) + SET e += {{{formatted_properties}}} + SET e.name = '{param["name"]}', e:`{BASE_ENTITY_LABEL}` + WITH e + SET e :{param["label"]} + """ + ) + triplet_source_id = param['properties'].get('triplet_source_id') + if triplet_source_id: + self.structured_query( + f""" + MERGE (e:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) + MERGE (c:{BASE_NODE_LABEL} {{id: '{triplet_source_id}'}}) + MERGE (e)<-[:MENTIONS]-(c) + """ + ) + + def upsert_relations(self, relations: List[Relation]) -> None: + """Add relations.""" + params = [r.dict() for r in relations] + for index in range(0, len(params), CHUNK_SIZE): + chunked_params = params[index : index + CHUNK_SIZE] + for param in chunked_params: + formatted_properties = ', '.join([f'{key}: {repr(value)}' for key, value in param["properties"].items()]) + + self.structured_query( + f""" + MERGE (source: {BASE_NODE_LABEL} {{id: '{param["source_id"]}'}}) + ON CREATE SET source:Chunk + MERGE (target: {BASE_NODE_LABEL} {{id: '{param["target_id"]}'}}) + ON CREATE SET target:Chunk + WITH source, target + MERGE (source)-[r:{param["label"]}]->(target) + SET r += {{{formatted_properties}}} + RETURN count(*) + """ + ) + + def get( + self, + properties: Optional[dict] = 
None, + ids: Optional[List[str]] = None, + ) -> List[LabelledNode]: + """Get nodes.""" + cypher_statement = f"MATCH (e:{BASE_NODE_LABEL}) " + + params = {} + cypher_statement += "WHERE e.id IS NOT NULL " + + if ids: + cypher_statement += "AND e.id IN $ids " + params["ids"] = ids + + if properties: + prop_list = [] + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + params[f"property_{i}"] = properties[prop] + cypher_statement += " AND " + " AND ".join(prop_list) + + return_statement = """ + RETURN + e.id AS name, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL + END + ELSE labels(e)[0] + END AS type, + properties(e) AS properties + """ + cypher_statement += return_statement + response = self.structured_query(cypher_statement, param_map=params) + response = response if response else [] + + nodes = [] + for record in response: + if "text" in record["properties"] or record["type"] is None: + text = record["properties"].pop("text", "") + nodes.append( + ChunkNode( + id_=record["name"], + text=text, + properties=remove_empty_values(record["properties"]), + ) + ) + else: + nodes.append( + EntityNode( + name=record["name"], + label=record["type"], + properties=remove_empty_values(record["properties"]), + ) + ) + + return nodes + + def get_triplets( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[Triplet]: + cypher_statement = f"MATCH (e:`{BASE_ENTITY_LABEL}`)-[r]->(t) " + + params = {} + if entity_names or relation_names or properties or ids: + cypher_statement += "WHERE " + + if entity_names: + cypher_statement += "e.name in $entity_names " + params["entity_names"] = entity_names + + if relation_names and entity_names: + cypher_statement += f"AND " + + if relation_names: + 
cypher_statement += "type(r) in $relation_names " + params[f"relation_names"] = relation_names + + if ids: + cypher_statement += "e.id in $ids " + params["ids"] = ids + + if properties: + prop_list = [] + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + params[f"property_{i}"] = properties[prop] + cypher_statement += " AND ".join(prop_list) + + if not (entity_names or properties or relation_names or ids): + return_statement = """ + WHERE NOT ANY(label IN labels(e) WHERE label = 'Chunk') + RETURN type(r) as type, properties(r) as rel_prop, e.id as source_id, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL + END + ELSE labels(e)[0] + END AS source_type, + properties(e) AS source_properties, + t.id as target_id, + CASE + WHEN labels(t)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(t)) > 2 THEN labels(t)[2] + WHEN size(labels(t)) > 1 THEN labels(t)[1] + ELSE NULL + END + ELSE labels(t)[0] + END AS target_type, properties(t) AS target_properties LIMIT 100; + """ + else: + return_statement = """ + AND NOT ANY(label IN labels(e) WHERE label = 'Chunk') + RETURN type(r) as type, properties(r) as rel_prop, e.id as source_id, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL + END + ELSE labels(e)[0] + END AS source_type, + properties(e) AS source_properties, + t.id as target_id, + CASE + WHEN labels(t)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(t)) > 2 THEN labels(t)[2] + WHEN size(labels(t)) > 1 THEN labels(t)[1] + ELSE NULL + END + ELSE labels(t)[0] + END AS target_type, properties(t) AS target_properties LIMIT 100; + """ + + cypher_statement += return_statement + data = self.structured_query(cypher_statement, param_map=params) + data = data if data else 
[] + + triplets = [] + for record in data: + source = EntityNode( + name=record["source_id"], + label=record["source_type"], + properties=remove_empty_values(record["source_properties"]), + ) + target = EntityNode( + name=record["target_id"], + label=record["target_type"], + properties=remove_empty_values(record["target_properties"]), + ) + rel = Relation( + source_id=record["source_id"], + target_id=record["target_id"], + label=record["type"], + properties=remove_empty_values(record["rel_prop"]), + ) + triplets.append([source, rel, target]) + return triplets + + def get_rel_map( + self, + graph_nodes: List[LabelledNode], + depth: int = 2, + limit: int = 30, + ignore_rels: Optional[List[str]] = None, + ) -> List[Triplet]: + """Get depth-aware rel map.""" + triples = [] + + ids = [node.id for node in graph_nodes] + response = self.structured_query( + f""" + WITH $ids AS id_list + UNWIND range(0, size(id_list) - 1) AS idx + MATCH (e:__Node__) + WHERE e.id = id_list[idx] + MATCH p=(e)-[r*1..{depth}]-(other) + WHERE ALL(rel in relationships(p) WHERE type(rel) <> 'MENTIONS') + UNWIND relationships(p) AS rel + WITH DISTINCT rel, idx + WITH startNode(rel) AS source, + type(rel) AS type, + rel{{.*}} AS rel_properties, + endNode(rel) AS endNode, + idx + LIMIT toInteger($limit) + RETURN source.id AS source_id, + CASE + WHEN labels(source)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(source)) > 2 THEN labels(source)[2] + WHEN size(labels(source)) > 1 THEN labels(source)[1] + ELSE NULL + END + ELSE labels(source)[0] + END AS source_type, + properties(source) AS source_properties, + type, + rel_properties, + endNode.id AS target_id, + CASE + WHEN labels(endNode)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(endNode)) > 2 THEN labels(endNode)[2] + WHEN size(labels(endNode)) > 1 THEN labels(endNode)[1] ELSE NULL + END + ELSE labels(endNode)[0] + END AS target_type, + properties(endNode) AS target_properties, + idx + ORDER BY idx + LIMIT 
toInteger($limit) + """, + param_map={"ids": ids, "limit": limit}, + ) + response = response if response else [] + + ignore_rels = ignore_rels or [] + for record in response: + if record["type"] in ignore_rels: + continue + + source = EntityNode( + name=record["source_id"], + label=record["source_type"], + properties=remove_empty_values(record["source_properties"]), + ) + target = EntityNode( + name=record["target_id"], + label=record["target_type"], + properties=remove_empty_values(record["target_properties"]), + ) + rel = Relation( + source_id=record["source_id"], + target_id=record["target_id"], + label=record["type"], + properties=remove_empty_values(record["rel_properties"]), + ) + triples.append([source, rel, target]) + + return triples + + def structured_query( + self, query: str, param_map: Optional[Dict[str, Any]] = None + ) -> Any: + param_map = param_map or {} + + with self._driver.session(database=self._database) as session: + result = session.run(query, param_map) + full_result = [d.data() for d in result] + + if self.sanitize_query_output: + return [value_sanitize(el) for el in full_result] + return full_result + + def vector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[List[LabelledNode], List[float]]: + raise NotImplementedError( + "Vector query is not currently implemented for MemgraphPropertyGraphStore." 
+ ) + + def delete( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> None: + """Delete matching data.""" + if entity_names: + self.structured_query( + "MATCH (n) WHERE n.name IN $entity_names DETACH DELETE n", + param_map={"entity_names": entity_names}, + ) + if ids: + self.structured_query( + "MATCH (n) WHERE n.id IN $ids DETACH DELETE n", + param_map={"ids": ids}, + ) + if relation_names: + for rel in relation_names: + self.structured_query(f"MATCH ()-[r:`{rel}`]->() DELETE r") + + if properties: + cypher = "MATCH (e) WHERE " + prop_list = [] + params = {} + for i, prop in enumerate(properties): + prop_list.append(f"e.`{prop}` = $property_{i}") + params[f"property_{i}"] = properties[prop] + cypher += " AND ".join(prop_list) + self.structured_query(cypher + " DETACH DELETE e", param_map=params) + + def _enhanced_schema_cypher( + self, + label_or_type: str, + properties: List[Dict[str, Any]], + exhaustive: bool, + is_relationship: bool = False, + ) -> str: + if is_relationship: + match_clause = f"MATCH ()-[n:`{label_or_type}`]->()" + else: + match_clause = f"MATCH (n:`{label_or_type}`)" + + with_clauses = [] + return_clauses = [] + output_dict = {} + if exhaustive: + for prop in properties: + if prop["property"]: + prop_name = prop["property"] + else: + prop_name = None + if prop["type"]: + prop_type = prop["type"] + else: + prop_type = None + if prop_type == "String": + with_clauses.append( + f"collect(distinct substring(toString(n.`{prop_name}`), 0, 50)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append( + f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}]," + f" distinct_count: size(`{prop_name}_values`)" + ) + elif prop_type in [ + "Int", + "Double", + "Date", + "LocalTime", + "LocalDateTime", + ]: + with_clauses.append(f"min(n.`{prop_name}`) AS `{prop_name}_min`") + with_clauses.append(f"max(n.`{prop_name}`) AS 
`{prop_name}_max`") + with_clauses.append( + f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`" + ) + return_clauses.append( + f"min: toString(`{prop_name}_min`), " + f"max: toString(`{prop_name}_max`), " + f"distinct_count: `{prop_name}_distinct`" + ) + elif prop_type in ["List", "List[Any]"]: + with_clauses.append( + f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, " + f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`" + ) + return_clauses.append( + f"min_size: `{prop_name}_size_min`, " + f"max_size: `{prop_name}_size_max`" + ) + elif prop_type in ["Bool", "Duration"]: + continue + if return_clauses: + output_dict[prop_name] = "{" + return_clauses.pop() + "}" + else: + output_dict[prop_name] = None + else: + # Just sample 5 random nodes + match_clause += " WITH n LIMIT 5" + for prop in properties: + prop_name = prop["property"] + prop_type = prop["type"] + + # Check if indexed property, we can still do exhaustive + prop_index = [ + el + for el in self.structured_schema["metadata"]["index"] + if el["label"] == label_or_type + and el["properties"] == [prop_name] + and el["type"] == "RANGE" + ] + if prop_type == "String": + if ( + prop_index + and prop_index[0].get("size") > 0 + and prop_index[0].get("distinctValues") <= DISTINCT_VALUE_LIMIT + ): + distinct_values_query = f""" + MATCH (n:{label_or_type}) + RETURN DISTINCT n.`{prop_name}` AS value + LIMIT {DISTINCT_VALUE_LIMIT} + """ + distinct_values = self.query(distinct_values_query) + + # Extract values from the result set + distinct_values = [record["value"] for record in distinct_values] + + return_clauses.append( + f"values: {distinct_values}," + f" distinct_count: {len(distinct_values)}" + ) + else: + with_clauses.append( + f"collect(distinct substring(n.`{prop_name}`, 0, 50)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append(f"values: `{prop_name}_values`") + elif prop_type in [ + "Int", + "Double", + "Float", + "Date", + "LocalTime", + "LocalDateTime", + ]: + if not 
prop_index: + with_clauses.append( + f"collect(distinct toString(n.`{prop_name}`)) " + f"AS `{prop_name}_values`" + ) + return_clauses.append(f"values: `{prop_name}_values`") + else: + with_clauses.append( + f"min(n.`{prop_name}`) AS `{prop_name}_min`" + ) + with_clauses.append( + f"max(n.`{prop_name}`) AS `{prop_name}_max`" + ) + with_clauses.append( + f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`" + ) + return_clauses.append( + f"min: toString(`{prop_name}_min`), " + f"max: toString(`{prop_name}_max`), " + f"distinct_count: `{prop_name}_distinct`" + ) + + elif prop_type in ["List", "List[Any]"]: + with_clauses.append( + f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, " + f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`" + ) + return_clauses.append( + f"min_size: `{prop_name}_size_min`, " + f"max_size: `{prop_name}_size_max`" + ) + elif prop_type in ["Bool", "Duration"]: + continue + if return_clauses: + output_dict[prop_name] = "{" + return_clauses.pop() + "}" + else: + output_dict[prop_name] = None + + with_clause = "WITH " + ",\n ".join(with_clauses) + return_clause = ( + "RETURN {" + + ", ".join(f"`{k}`: {v}" for k, v in output_dict.items()) + + "} AS output" + ) + # Combine all parts of the Cypher query + return f"{match_clause}\n{with_clause}\n{return_clause}" + + def get_schema(self, refresh: bool = False) -> Any: + if refresh: + self.refresh_schema() + + return self.structured_schema + + def get_schema_str(self, refresh: bool = False) -> str: + schema = self.get_schema(refresh=refresh) + + formatted_node_props = [] + formatted_rel_props = [] + + if self.enhanced_schema: + # Enhanced formatting for nodes + for node_type, properties in schema["node_props"].items(): + formatted_node_props.append(f"- **{node_type}**") + for prop in properties: + example = "" + if prop["type"] == "String" and prop.get("values"): + if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: + example = ( + f'Example: 
"{clean_string_values(prop["values"][0])}"' + if prop["values"] + else "" + ) + else: # If less than 10 possible values return all + example = ( + ( + "Available options: " + f'{[clean_string_values(el) for el in prop["values"]]}' + ) + if prop["values"] + else "" + ) + + elif prop["type"] in [ + "Int", + "Double", + "Float", + "Date", + "LocalTime", + "LocalDateTime", + ]: + if prop.get("min") is not None: + example = f'Min: {prop["min"]}, Max: {prop["max"]}' + else: + example = ( + f'Example: "{prop["values"][0]}"' + if prop.get("values") + else "" + ) + elif prop["type"] in ["List", "List[Any]"]: + # Skip embeddings + if not prop.get("min_size") or prop["min_size"] > LIST_LIMIT: + continue + example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' + formatted_node_props.append( + f" - `{prop['property']}`: {prop['type']} {example}" + ) + + # Enhanced formatting for relationships + for rel_type, properties in schema["rel_props"].items(): + formatted_rel_props.append(f"- **{rel_type}**") + for prop in properties: + example = "" + if prop["type"] == "STRING": + if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT: + example = ( + f'Example: "{clean_string_values(prop["values"][0])}"' + if prop.get("values") + else "" + ) + else: # If less than 10 possible values return all + example = ( + ( + "Available options: " + f'{[clean_string_values(el) for el in prop["values"]]}' + ) + if prop.get("values") + else "" + ) + elif prop["type"] in [ + "Int", + "Double", + "Float", + "Date", + "LocalTime", + "LocalDateTime", + ]: + if prop.get("min"): # If we have min/max + example = f'Min: {prop["min"]}, Max: {prop["max"]}' + else: # return a single value + example = ( + f'Example: "{prop["values"][0]}"' + if prop.get("values") + else "" + ) + elif prop["type"] == "List[Any]": + # Skip embeddings + if prop["min_size"] > LIST_LIMIT: + continue + example = f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}' + formatted_rel_props.append( + f" - 
`{prop['property']}: {prop['type']}` {example}" + ) + else: + # Format node properties + for label, props in schema["node_props"].items(): + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in props] + ) + formatted_node_props.append(f"{label} {{{props_str}}}") + + # Format relationship properties using structured_schema + for type, props in schema["rel_props"].items(): + props_str = ", ".join( + [f"{prop['property']}: {prop['type']}" for prop in props] + ) + formatted_rel_props.append(f"{type} {{{props_str}}}") + + # Format relationships + formatted_rels = [ + f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" + for el in schema["relationships"] + ] + + return "\n".join( + [ + "Node properties:", + "\n".join(formatted_node_props), + "Relationship properties:", + "\n".join(formatted_rel_props), + "The relationships:", + "\n".join(formatted_rels), + ] + ) diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml new file mode 100644 index 0000000000000..7fc88c0539248 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml @@ -0,0 +1,57 @@ +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.codespell] +check-filenames = true +check-hidden = true +# Feel free to un-skip examples, and experimental, you will just need to +# work through many typos (--write-changes and --interactive will help) +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = true +import_path = "llama_index.graph_stores.memgraph" + +[tool.llamahub.class_authors] +MemgraphGraphStore = "llama-index" +MemgraphPropertyGraphStore = "llama-index" + +[tool.mypy] +disallow_untyped_defs = true +# Remove venv skip when integrated with pre-commit +exclude = ["_static", "build", "examples", "notebooks", "venv"] 
+ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +name = "llama-index-graph-stores-memgraph" +version = "0.1.0" +description = "llama-index graph-stores memgraph integration" +authors = ["Your Name "] +license = "MIT" +readme = "README.md" +packages = [{include = "llama_index/"}] + +[tool.poetry.dependencies] +python = ">=3.8.1,<4.0" +llama-index-core = "^0.10.0" + +[tool.poetry.group.dev.dependencies] +black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} +codespell = {extras = ["toml"], version = ">=v2.2.6"} +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991 +types-setuptools = "67.1.0.0" diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/__init__.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py new file mode 100644 index 0000000000000..38986e6789496 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py @@ -0,0 +1,40 @@ +import unittest +from llama_index.core.graph_stores.types import GraphStore +from llama_index.graph_stores.memgraph import MemgraphGraphStore + +class TestMemgraphGraphStore(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.store = MemgraphGraphStore(username="", password="", url="bolt://localhost:7687") + + def test_connection(self): + """Test if 
connection to Memgraph is working.""" + try: + self.store.client.verify_connectivity() + connected = True + except Exception as e: + connected = False + self.assertTrue(connected, "Could not connect to Memgraph") + + def test_upsert_triplet(self): + """Test inserting a triplet into Memgraph.""" + self.store.upsert_triplet("Alice", "KNOWS", "Bob") + triplets = self.store.get("Alice") + self.assertIn(["KNOWS", "Bob"], triplets) + + def test_delete_triplet(self): + """Test deleting a triplet from Memgraph.""" + self.store.delete("Alice", "KNOWS", "Bob") + triplets = self.store.get("Alice") + self.assertNotIn(["KNOWS", "Bob"], triplets) + + def test_get_rel_map(self): + """Test retrieving relationships.""" + self.store.upsert_triplet("Alice", "KNOWS", "Bob") + rel_map = self.store.get_rel_map(["Alice"], depth=2) + self.assertIn("Alice", rel_map) + self.assertIn(["KNOWS", "Bob"], rel_map["Alice"]) + +if __name__ == '__main__': + unittest.main() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py new file mode 100644 index 0000000000000..c773573719f33 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py @@ -0,0 +1,111 @@ +import unittest +from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore +from llama_index.core.graph_stores.types import ( + EntityNode, + Relation, +) +from llama_index.core.schema import TextNode + +class TestMemgraphGraphStore(unittest.TestCase): + + @classmethod + def setUp(self): + self.pg_store = MemgraphPropertyGraphStore(username="", password="", url="bolt://localhost:7687") + + def test_connection(self): + """Test if connection to Memgraph is working.""" + try: + self.pg_store.client.verify_connectivity() + connected = True + except Exception as e: + connected = False + 
self.assertTrue(connected, "Could not connect to Memgraph") + + def test_memgraph_pg_store(self): + """Clear the database""" + self.pg_store.structured_query("STORAGE MODE IN_MEMORY_ANALYTICAL") + self.pg_store.structured_query("DROP GRAPH") + self.pg_store.structured_query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") + + """Test upsert nodes""" + entity1 = EntityNode(label="PERSON", name="Logan", properties={"age": 28}) + entity2 = EntityNode(label="ORGANIZATION", name="LlamaIndex") + self.pg_store.upsert_nodes([entity1, entity2]) + + # Assert the nodes are inserted correctly + kg_nodes = self.pg_store.get(ids=[entity1.id]) + + self.assertEqual(len(kg_nodes), 1) + self.assertEqual(kg_nodes[0].name, "Logan") + + """Test inserting relations into Memgraph.""" + relation = Relation( + label="WORKS_FOR", + source_id=entity1.id, + target_id=entity2.id, + properties={'since': 2023} + ) + + self.pg_store.upsert_relations([relation]) + + # Assert the relation is inserted correctly by retrieving the relation map + kg_nodes = self.pg_store.get(ids=[entity1.id]) + paths = self.pg_store.get_rel_map(kg_nodes, depth=1) + self.assertEqual(len(paths), 1) + path = paths[0] + self.assertEqual(path[0].id, entity1.id) + self.assertEqual(path[2].id, entity2.id) + self.assertEqual(path[1].label, "WORKS_FOR") + + """Test inserting a source text node and 'MENTIONS' relations.""" + source_node = TextNode(text="Logan (age 28), works for LlamaIndex since 2023.") + + relations = [ + Relation(label="MENTIONS", target_id=entity1.id, source_id=source_node.node_id), + Relation(label="MENTIONS", target_id=entity2.id, source_id=source_node.node_id) + ] + + self.pg_store.upsert_llama_nodes([source_node]) + self.pg_store.upsert_relations(relations) + + # Assert the source node and relations are inserted correctly + llama_nodes = self.pg_store.get_llama_nodes([source_node.node_id]) + self.assertEqual(len(llama_nodes), 1) + self.assertEqual(llama_nodes[0].text, source_node.text) + + """Test retrieving nodes 
by properties.""" + kg_nodes = self.pg_store.get(properties={"age": 28}) + self.assertEqual(len(kg_nodes), 1) + self.assertEqual(kg_nodes[0].label, "PERSON") + self.assertEqual(kg_nodes[0].name, "Logan") + + """Test executing a structured query in Memgraph.""" + query = "MATCH (n:`__Entity__`) RETURN n" + result = self.pg_store.structured_query(query) + self.assertEqual(len(result), 2) + + """Test upserting a new node with additional properties.""" + new_node = EntityNode( + label="PERSON", name="Logan", properties={"age": 28, "location": "Canada"} + ) + self.pg_store.upsert_nodes([new_node]) + + # Assert the node has been updated with the new property + kg_nodes = self.pg_store.get(properties={"age": 28}) + self.assertEqual(len(kg_nodes), 1) + self.assertEqual(kg_nodes[0].label, "PERSON") + self.assertEqual(kg_nodes[0].name, "Logan") + self.assertEqual(kg_nodes[0].properties["location"], "Canada") + + """Test deleting nodes from Memgraph.""" + self.pg_store.delete(ids=[source_node.node_id]) + self.pg_store.delete(ids=[entity1.id, entity2.id]) + + # Assert the nodes have been deleted + nodes = self.pg_store.get(ids=[entity1.id, entity2.id]) + self.assertEqual(len(nodes), 0) + text_nodes = self.pg_store.get_llama_nodes([source_node.node_id]) + self.assertEqual(len(text_nodes), 0) + +if __name__ == '__main__': + unittest.main() From 568be488b68cfb9721adc6fbeaa0c8e2bfbb6c6e Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Thu, 3 Oct 2024 14:48:44 -0600 Subject: [PATCH 2/6] linting + BUILD files --- .../llama-index-graph-stores-memgraph/BUILD | 4 +- .../examples/BUILD | 1 + .../examples/kg_example.py | 12 +- .../examples/pg_example.py | 24 +- .../llama_index/graph_stores/memgraph/BUILD | 1 + .../graph_stores/memgraph/__init__.py | 1 - .../graph_stores/memgraph/kg_base.py | 22 +- .../graph_stores/memgraph/property_graph.py | 243 ++++++++++-------- .../pyproject.toml | 12 +- .../tests/BUILD | 1 + .../tests/test_graph_stores_memgraph.py | 20 +- 
.../tests/test_pg_stores_memgraph.py | 61 +++-- 12 files changed, 225 insertions(+), 177 deletions(-) create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/BUILD create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/BUILD create mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/BUILD diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD index db46e8d6c978c..0896ca890d8bf 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/BUILD @@ -1 +1,3 @@ -python_sources() +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/BUILD b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py index c5fa02edaf9a1..ad97247e3fb04 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py @@ -7,7 +7,7 @@ # Step 1: Set up OpenAI API key -os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key +os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key # Step 2: Configure logging logging.basicConfig(level=logging.INFO) @@ -21,7 
+21,7 @@ documents = { "doc1.txt": "Python is a popular programming language known for its readability and simplicity. It was created by Guido van Rossum and first released in 1991. Python supports multiple programming paradigms, including procedural, object-oriented, and functional programming. It is widely used in web development, data science, artificial intelligence, and scientific computing.", "doc2.txt": "JavaScript is a high-level programming language primarily used for web development. It was created by Brendan Eich and first appeared in 1995. JavaScript is a core technology of the World Wide Web, alongside HTML and CSS. It enables interactive web pages and is an essential part of web applications. JavaScript is also used in server-side development with environments like Node.js.", - "doc3.txt": "Java is a high-level, class-based, object-oriented programming language that is designed to have as few implementation dependencies as possible. It was developed by James Gosling and first released by Sun Microsystems in 1995. Java is widely used for building enterprise-scale applications, mobile applications, and large systems development." + "doc3.txt": "Java is a high-level, class-based, object-oriented programming language that is designed to have as few implementation dependencies as possible. It was developed by James Gosling and first released by Sun Microsystems in 1995. 
Java is widely used for building enterprise-scale applications, mobile applications, and large systems development.", } for filename, content in documents.items(): @@ -32,10 +32,10 @@ loaded_documents = SimpleDirectoryReader(".").load_data() # Step 6: Set up Memgraph connection -username = "" # Enter your Memgraph username (default "") -password = "" # Enter your Memgraph password (default "") -url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' -database = "memgraph" # Name of the database, default is 'memgraph' +username = "" # Enter your Memgraph username (default "") +password = "" # Enter your Memgraph password (default "") +url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' +database = "memgraph" # Name of the database, default is 'memgraph' graph_store = MemgraphGraphStore( username=username, diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py index 32324735d3a49..ef2164192fe39 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py @@ -10,41 +10,41 @@ # 1. Setup OpenAI API Key (replace this with your actual key) -os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key +os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key # 2. 
Create the data directory and download the Paul Graham essay -os.makedirs('data/paul_graham/', exist_ok=True) +os.makedirs("data/paul_graham/", exist_ok=True) -url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -output_path = 'data/paul_graham/paul_graham_essay.txt' +url = "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt" +output_path = "data/paul_graham/paul_graham_essay.txt" urllib.request.urlretrieve(url, output_path) # 3. Ensure nest_asyncio is applied nest_asyncio.apply() # Step 2: Read the file, replace single quotes, and save the modified content -with open(output_path, 'r', encoding='utf-8') as file: +with open(output_path, "r", encoding="utf-8") as file: content = file.read() # Replace single quotes with escaped single quotes modified_content = content.replace("'", "\\'") # Save the modified content back to the same file -with open(output_path, 'w', encoding='utf-8') as file: +with open(output_path, "w", encoding="utf-8") as file: file.write(modified_content) # 4. Load the document data documents = SimpleDirectoryReader("./data/paul_graham/").load_data() # 5. Setup Memgraph connection (ensure Memgraph is running) -username = "" # Enter your Memgraph username (default "") -password = "" # Enter your Memgraph password (default "") -url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' +username = "" # Enter your Memgraph username (default "") +password = "" # Enter your Memgraph password (default "") +url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' graph_store = MemgraphPropertyGraphStore( username=username, - password=password, - url=url, + password=password, + url=url, ) # 6. 
Create the Property Graph Index @@ -53,7 +53,7 @@ embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002"), kg_extractors=[ SchemaLLMPathExtractor( - llm=OpenAI(model="gpt-3.5-turbo", temperature=0.0), + llm=OpenAI(model="gpt-3.5-turbo", temperature=0.0), ) ], property_graph_store=graph_store, diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/BUILD b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py index 1678ed26586cf..83b3a573e428e 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/__init__.py @@ -2,4 +2,3 @@ from llama_index.graph_stores.memgraph.property_graph import MemgraphPropertyGraphStore __all__ = ["MemgraphGraphStore", "MemgraphPropertyGraphStore"] - diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py index dd0d044cd7552..58bd70d57443e 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py +++ 
b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/kg_base.py @@ -27,6 +27,7 @@ RETURN DISTINCT {start: start[0], type: relationship_type, end: end_label} AS output """ + class MemgraphGraphStore(GraphStore): def __init__( self, @@ -59,7 +60,7 @@ def __init__( "Please ensure that the username and password are correct" ) # set schema - self.refresh_schema() + self.refresh_schema() # create constraint self.query( @@ -81,13 +82,12 @@ def __init__( def client(self) -> Any: return self._driver - def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any: """Execute a Cypher query.""" with self._driver.session(database=self._database) as session: result = session.run(query, param_map) return [record.data() for record in result] - + def get(self, subj: str) -> List[List[str]]: """Get triplets.""" query = f""" @@ -99,7 +99,7 @@ def get(self, subj: str) -> List[List[str]]: with self._driver.session(database=self._database) as session: data = session.run(query, {"subj": subj}) return [record.values() for record in data] - + def get_rel_map( self, subjs: Optional[List[str]] = None, depth: int = 2 ) -> Dict[str, List[List[str]]]: @@ -124,7 +124,7 @@ def get_rel_map( rel_map[record["subj"]] = record["rels"] return rel_map - + def upsert_triplet(self, subj: str, rel: str, obj: str) -> None: """Add triplet.""" query = f""" @@ -133,7 +133,7 @@ def upsert_triplet(self, subj: str, rel: str, obj: str) -> None: MERGE (n1)-[:`{rel.replace(" ", "_").upper()}`]->(n2) """ self.query(query, {"subj": subj, "obj": obj}) - + def delete(self, subj: str, rel: str, obj: str) -> None: """Delete triplet.""" query = f""" @@ -147,17 +147,17 @@ def refresh_schema(self) -> None: """ Refreshes the Memgraph graph schema information. 
""" - node_properties = self.query(node_properties_query) + node_properties = self.query(node_properties_query) relationships_properties = self.query(rel_properties_query) relationships = self.query(rel_query) self.schema = f""" Node properties are the following: - {[el for el in node_properties]} + {node_properties} Relationship properties are the following: - {[el for el in relationships_properties]} + {relationships_properties} The relationships are the following: - {[el for el in relationships]} + {relationships} """ def get_schema(self, refresh: bool = False) -> str: @@ -166,4 +166,4 @@ def get_schema(self, refresh: bool = False) -> str: return self.schema self.refresh_schema() logger.debug(f"get_schema() schema:\n{self.schema}") - return self.schema \ No newline at end of file + return self.schema diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py index bab513891e6cd..48854a895d478 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/llama_index/graph_stores/memgraph/property_graph.py @@ -18,12 +18,14 @@ from llama_index.core.vector_stores.types import VectorStoreQuery import neo4j + def remove_empty_values(input_dict): """ Remove entries with empty values from the dictionary. 
""" return {key: value for key, value in input_dict.items() if value} + BASE_ENTITY_LABEL = "__Entity__" BASE_NODE_LABEL = "__Node__" EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"] @@ -36,10 +38,10 @@ def remove_empty_values(input_dict): LIMIT = 100 node_properties_query = """ -MATCH (n) -UNWIND labels(n) AS label -WITH label, COUNT(n) AS count -CALL schema.node_type_properties() +MATCH (n) +UNWIND labels(n) AS label +WITH label, COUNT(n) AS count +CALL schema.node_type_properties() YIELD propertyName, nodeLabels, propertyTypes WITH label, nodeLabels, count, collect({property: propertyName, type: propertyTypes[0]}) AS properties WHERE label IN nodeLabels @@ -62,6 +64,7 @@ def remove_empty_values(input_dict): RETURN DISTINCT {start: start_label, type: relationship_type, end: end_label} AS output """ + class MemgraphPropertyGraphStore(PropertyGraphStore): r""" Memgraph Property Graph Store. @@ -123,12 +126,8 @@ def __init__( self.refresh_schema() # Create index for faster imports and retrieval - self.structured_query( - f"""CREATE INDEX ON :{BASE_NODE_LABEL}(id);""" - ) - self.structured_query( - f"""CREATE INDEX ON :{BASE_ENTITY_LABEL}(id);""" - ) + self.structured_query(f"""CREATE INDEX ON :{BASE_NODE_LABEL}(id);""") + self.structured_query(f"""CREATE INDEX ON :{BASE_ENTITY_LABEL}(id);""") @property def client(self): @@ -155,26 +154,37 @@ def refresh_schema(self) -> None: ) node_properties = {} for el in node_query_results: - if el["output"]["labels"] in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL]: + if el["output"]["labels"] in [ + *EXCLUDED_LABELS, + BASE_ENTITY_LABEL, + BASE_NODE_LABEL, + ]: continue label = el["output"]["labels"] properties = el["output"]["properties"] if label in node_properties: node_properties[label]["properties"].extend( - prop for prop in properties if prop not in node_properties[label]["properties"] + prop + for prop in properties + if prop not in node_properties[label]["properties"] ) else: node_properties[label] = 
{"properties": properties} - node_properties = [{"labels": label, **value} for label, value in node_properties.items()] + node_properties = [ + {"labels": label, **value} for label, value in node_properties.items() + ] rels_query_result = self.structured_query( rel_properties_query, param_map={"EXCLUDED_LABELS": EXCLUDED_RELS} ) rel_properties = ( - [el["output"] for el in rels_query_result - if any(prop["property"] for prop in el["output"].get("properties", []))] - if rels_query_result + [ + el["output"] + for el in rels_query_result + if any(prop["property"] for prop in el["output"].get("properties", [])) + ] + if rels_query_result else [] ) rel_objs_query_result = self.structured_query( @@ -188,10 +198,13 @@ def refresh_schema(self) -> None: }, ) relationships = [ - el["output"] for el in rel_objs_query_result - if rel_objs_query_result and - el["output"]["start"] not in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL] and - el["output"]["end"] not in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL] + el["output"] + for el in rel_objs_query_result + if rel_objs_query_result + and el["output"]["start"] + not in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL] + and el["output"]["end"] + not in [*EXCLUDED_LABELS, BASE_ENTITY_LABEL, BASE_NODE_LABEL] ] self.structured_schema = { "node_props": {el["labels"]: el["properties"] for el in node_properties}, @@ -204,16 +217,24 @@ def refresh_schema(self) -> None: schema_rels = self.structured_query( "MATCH ()-[r]->() RETURN TYPE(r) AS relationship_type, COUNT(r) AS count" ) - schema_counts = [{ - 'nodes': [{'name': item['node'], 'count': item['count']} for item in schema_nodes], - 'relationships': [{'name': item['relationship_type'], 'count': item['count']} for item in schema_rels] - }] + schema_counts = [ + { + "nodes": [ + {"name": item["node"], "count": item["count"]} + for item in schema_nodes + ], + "relationships": [ + {"name": item["relationship_type"], "count": item["count"]} + for item in 
schema_rels + ], + } + ] # Update node info for node in schema_counts[0].get("nodes", []): # Skip bloom labels if node["name"] in EXCLUDED_LABELS: continue - node_props = self.structured_schema["node_props"].get(node['name']) + node_props = self.structured_schema["node_props"].get(node["name"]) if not node_props: # The node has no properties continue @@ -225,7 +246,7 @@ def refresh_schema(self) -> None: for prop in node_props: if prop["property"] in enhanced_info: prop.update(enhanced_info[prop["property"]]) - + # Update rel info for rel in schema_counts[0].get("relationships", []): if rel["name"] in EXCLUDED_RELS: @@ -251,7 +272,7 @@ def upsert_nodes(self, nodes: List[LabelledNode]) -> None: # Lists to hold separated types entity_dicts: List[dict] = [] chunk_dicts: List[dict] = [] - + # Sort by type for item in nodes: if isinstance(item, EntityNode): @@ -264,7 +285,12 @@ def upsert_nodes(self, nodes: List[LabelledNode]) -> None: for index in range(0, len(chunk_dicts), CHUNK_SIZE): chunked_params = chunk_dicts[index : index + CHUNK_SIZE] for param in chunked_params: - formatted_properties = ', '.join([f'{key}: {repr(value)}' for key, value in param["properties"].items()]) + formatted_properties = ", ".join( + [ + f"{key}: {value!r}" + for key, value in param["properties"].items() + ] + ) self.structured_query( f""" MERGE (c:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) @@ -278,7 +304,12 @@ def upsert_nodes(self, nodes: List[LabelledNode]) -> None: for index in range(0, len(entity_dicts), CHUNK_SIZE): chunked_params = entity_dicts[index : index + CHUNK_SIZE] for param in chunked_params: - formatted_properties = ', '.join([f'{key}: {repr(value)}' for key, value in param["properties"].items()]) + formatted_properties = ", ".join( + [ + f"{key}: {value!r}" + for key, value in param["properties"].items() + ] + ) self.structured_query( f""" MERGE (e:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) @@ -288,8 +319,8 @@ def upsert_nodes(self, nodes: List[LabelledNode]) -> None: SET 
e :{param["label"]} """ ) - triplet_source_id = param['properties'].get('triplet_source_id') - if triplet_source_id: + triplet_source_id = param["properties"].get("triplet_source_id") + if triplet_source_id: self.structured_query( f""" MERGE (e:{BASE_NODE_LABEL} {{id: '{param["id"]}'}}) @@ -304,7 +335,9 @@ def upsert_relations(self, relations: List[Relation]) -> None: for index in range(0, len(params), CHUNK_SIZE): chunked_params = params[index : index + CHUNK_SIZE] for param in chunked_params: - formatted_properties = ', '.join([f'{key}: {repr(value)}' for key, value in param["properties"].items()]) + formatted_properties = ", ".join( + [f"{key}: {value!r}" for key, value in param["properties"].items()] + ) self.structured_query( f""" @@ -342,16 +375,16 @@ def get( cypher_statement += " AND " + " AND ".join(prop_list) return_statement = """ - RETURN - e.id AS name, - CASE - WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN - CASE - WHEN size(labels(e)) > 2 THEN labels(e)[2] - WHEN size(labels(e)) > 1 THEN labels(e)[1] - ELSE NULL + RETURN + e.id AS name, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL END - ELSE labels(e)[0] + ELSE labels(e)[0] END AS type, properties(e) AS properties """ @@ -378,7 +411,7 @@ def get( properties=remove_empty_values(record["properties"]), ) ) - + return nodes def get_triplets( @@ -418,52 +451,52 @@ def get_triplets( if not (entity_names or properties or relation_names or ids): return_statement = """ - WHERE NOT ANY(label IN labels(e) WHERE label = 'Chunk') - RETURN type(r) as type, properties(r) as rel_prop, e.id as source_id, - CASE - WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN - CASE - WHEN size(labels(e)) > 2 THEN labels(e)[2] - WHEN size(labels(e)) > 1 THEN labels(e)[1] - ELSE NULL - END - ELSE labels(e)[0] - END AS source_type, - properties(e) AS source_properties, - t.id as target_id, - CASE - 
WHEN labels(t)[0] IN ['__Entity__', '__Node__'] THEN - CASE - WHEN size(labels(t)) > 2 THEN labels(t)[2] - WHEN size(labels(t)) > 1 THEN labels(t)[1] - ELSE NULL - END - ELSE labels(t)[0] + WHERE NOT ANY(label IN labels(e) WHERE label = 'Chunk') + RETURN type(r) as type, properties(r) as rel_prop, e.id as source_id, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL + END + ELSE labels(e)[0] + END AS source_type, + properties(e) AS source_properties, + t.id as target_id, + CASE + WHEN labels(t)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(t)) > 2 THEN labels(t)[2] + WHEN size(labels(t)) > 1 THEN labels(t)[1] + ELSE NULL + END + ELSE labels(t)[0] END AS target_type, properties(t) AS target_properties LIMIT 100; """ else: return_statement = """ - AND NOT ANY(label IN labels(e) WHERE label = 'Chunk') - RETURN type(r) as type, properties(r) as rel_prop, e.id as source_id, - CASE - WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN - CASE - WHEN size(labels(e)) > 2 THEN labels(e)[2] - WHEN size(labels(e)) > 1 THEN labels(e)[1] - ELSE NULL - END - ELSE labels(e)[0] - END AS source_type, - properties(e) AS source_properties, - t.id as target_id, - CASE - WHEN labels(t)[0] IN ['__Entity__', '__Node__'] THEN - CASE - WHEN size(labels(t)) > 2 THEN labels(t)[2] - WHEN size(labels(t)) > 1 THEN labels(t)[1] - ELSE NULL - END - ELSE labels(t)[0] + AND NOT ANY(label IN labels(e) WHERE label = 'Chunk') + RETURN type(r) as type, properties(r) as rel_prop, e.id as source_id, + CASE + WHEN labels(e)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(e)) > 2 THEN labels(e)[2] + WHEN size(labels(e)) > 1 THEN labels(e)[1] + ELSE NULL + END + ELSE labels(e)[0] + END AS source_type, + properties(e) AS source_properties, + t.id as target_id, + CASE + WHEN labels(t)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(t)) > 2 
THEN labels(t)[2] + WHEN size(labels(t)) > 1 THEN labels(t)[1] + ELSE NULL + END + ELSE labels(t)[0] END AS target_type, properties(t) AS target_properties LIMIT 100; """ @@ -493,7 +526,7 @@ def get_triplets( return triplets def get_rel_map( - self, + self, graph_nodes: List[LabelledNode], depth: int = 2, limit: int = 30, @@ -519,27 +552,27 @@ def get_rel_map( endNode(rel) AS endNode, idx LIMIT toInteger($limit) - RETURN source.id AS source_id, - CASE - WHEN labels(source)[0] IN ['__Entity__', '__Node__'] THEN - CASE - WHEN size(labels(source)) > 2 THEN labels(source)[2] - WHEN size(labels(source)) > 1 THEN labels(source)[1] - ELSE NULL + RETURN source.id AS source_id, + CASE + WHEN labels(source)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(source)) > 2 THEN labels(source)[2] + WHEN size(labels(source)) > 1 THEN labels(source)[1] + ELSE NULL END - ELSE labels(source)[0] + ELSE labels(source)[0] END AS source_type, properties(source) AS source_properties, type, rel_properties, - endNode.id AS target_id, - CASE - WHEN labels(endNode)[0] IN ['__Entity__', '__Node__'] THEN - CASE - WHEN size(labels(endNode)) > 2 THEN labels(endNode)[2] - WHEN size(labels(endNode)) > 1 THEN labels(endNode)[1] ELSE NULL + endNode.id AS target_id, + CASE + WHEN labels(endNode)[0] IN ['__Entity__', '__Node__'] THEN + CASE + WHEN size(labels(endNode)) > 2 THEN labels(endNode)[2] + WHEN size(labels(endNode)) > 1 THEN labels(endNode)[1] ELSE NULL END - ELSE labels(endNode)[0] + ELSE labels(endNode)[0] END AS target_type, properties(endNode) AS target_properties, idx @@ -587,7 +620,7 @@ def structured_query( if self.sanitize_query_output: return [value_sanitize(el) for el in full_result] return full_result - + def vector_query( self, query: VectorStoreQuery, **kwargs: Any ) -> Tuple[List[LabelledNode], List[float]]: @@ -722,8 +755,10 @@ def _enhanced_schema_cypher( distinct_values = self.query(distinct_values_query) # Extract values from the result set - distinct_values = 
[record["value"] for record in distinct_values] - + distinct_values = [ + record["value"] for record in distinct_values + ] + return_clauses.append( f"values: {distinct_values}," f" distinct_count: {len(distinct_values)}" @@ -788,7 +823,7 @@ def _enhanced_schema_cypher( ) # Combine all parts of the Cypher query return f"{match_clause}\n{with_clause}\n{return_clause}" - + def get_schema(self, refresh: bool = False) -> Any: if refresh: self.refresh_schema() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml index 7fc88c0539248..e76ff4e809ec6 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] [tool.codespell] check-filenames = true @@ -25,13 +25,13 @@ ignore_missing_imports = true python_version = "3.8" [tool.poetry] -name = "llama-index-graph-stores-memgraph" -version = "0.1.0" -description = "llama-index graph-stores memgraph integration" authors = ["Your Name "] +description = "llama-index graph-stores memgraph integration" license = "MIT" -readme = "README.md" +name = "llama-index-graph-stores-memgraph" packages = [{include = "llama_index/"}] +readme = "README.md" +version = "0.1.0" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" @@ -53,5 +53,5 @@ types-Deprecated = ">=0.1.0" types-PyYAML = "^6.0.12.12" types-protobuf = "^4.24.0.4" types-redis = "4.5.5.0" -types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991 +types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991 types-setuptools = "67.1.0.0" diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/BUILD 
b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/BUILD new file mode 100644 index 0000000000000..dabf212d7e716 --- /dev/null +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py index 38986e6789496..9702101e86233 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py @@ -1,13 +1,14 @@ import unittest -from llama_index.core.graph_stores.types import GraphStore -from llama_index.graph_stores.memgraph import MemgraphGraphStore +from llama_index.graph_stores.memgraph import MemgraphGraphStore -class TestMemgraphGraphStore(unittest.TestCase): +class TestMemgraphGraphStore(unittest.TestCase): @classmethod def setUpClass(cls): - cls.store = MemgraphGraphStore(username="", password="", url="bolt://localhost:7687") - + cls.store = MemgraphGraphStore( + username="", password="", url="bolt://localhost:7687" + ) + def test_connection(self): """Test if connection to Memgraph is working.""" try: @@ -16,19 +17,19 @@ def test_connection(self): except Exception as e: connected = False self.assertTrue(connected, "Could not connect to Memgraph") - + def test_upsert_triplet(self): """Test inserting a triplet into Memgraph.""" self.store.upsert_triplet("Alice", "KNOWS", "Bob") triplets = self.store.get("Alice") self.assertIn(["KNOWS", "Bob"], triplets) - + def test_delete_triplet(self): """Test deleting a triplet from Memgraph.""" self.store.delete("Alice", "KNOWS", "Bob") triplets = self.store.get("Alice") self.assertNotIn(["KNOWS", "Bob"], triplets) - + def test_get_rel_map(self): 
"""Test retrieving relationships.""" self.store.upsert_triplet("Alice", "KNOWS", "Bob") @@ -36,5 +37,6 @@ def test_get_rel_map(self): self.assertIn("Alice", rel_map) self.assertIn(["KNOWS", "Bob"], rel_map["Alice"]) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py index c773573719f33..d5fefca3ed05d 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py @@ -1,16 +1,18 @@ import unittest from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore from llama_index.core.graph_stores.types import ( - EntityNode, + EntityNode, Relation, ) from llama_index.core.schema import TextNode - -class TestMemgraphGraphStore(unittest.TestCase): + +class TestMemgraphGraphStore(unittest.TestCase): @classmethod def setUp(self): - self.pg_store = MemgraphPropertyGraphStore(username="", password="", url="bolt://localhost:7687") + self.pg_store = MemgraphPropertyGraphStore( + username="", password="", url="bolt://localhost:7687" + ) def test_connection(self): """Test if connection to Memgraph is working.""" @@ -22,28 +24,28 @@ def test_connection(self): self.assertTrue(connected, "Could not connect to Memgraph") def test_memgraph_pg_store(self): - """Clear the database""" + # Clear the database self.pg_store.structured_query("STORAGE MODE IN_MEMORY_ANALYTICAL") self.pg_store.structured_query("DROP GRAPH") self.pg_store.structured_query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") - - """Test upsert nodes""" + + # Test upsert nodes entity1 = EntityNode(label="PERSON", name="Logan", properties={"age": 28}) entity2 = EntityNode(label="ORGANIZATION", 
name="LlamaIndex") self.pg_store.upsert_nodes([entity1, entity2]) # Assert the nodes are inserted correctly kg_nodes = self.pg_store.get(ids=[entity1.id]) - + self.assertEqual(len(kg_nodes), 1) - self.assertEqual(kg_nodes[0].name, "Logan") - - """Test inserting relations into Memgraph.""" + self.assertEqual(kg_nodes[0].name, "Logan") + + # Test inserting relations into Memgraph. relation = Relation( label="WORKS_FOR", source_id=entity1.id, target_id=entity2.id, - properties={'since': 2023} + properties={"since": 2023}, ) self.pg_store.upsert_relations([relation]) @@ -56,13 +58,17 @@ def test_memgraph_pg_store(self): self.assertEqual(path[0].id, entity1.id) self.assertEqual(path[2].id, entity2.id) self.assertEqual(path[1].label, "WORKS_FOR") - - """Test inserting a source text node and 'MENTIONS' relations.""" + + # Test inserting a source text node and 'MENTIONS' relations. source_node = TextNode(text="Logan (age 28), works for LlamaIndex since 2023.") relations = [ - Relation(label="MENTIONS", target_id=entity1.id, source_id=source_node.node_id), - Relation(label="MENTIONS", target_id=entity2.id, source_id=source_node.node_id) + Relation( + label="MENTIONS", target_id=entity1.id, source_id=source_node.node_id + ), + Relation( + label="MENTIONS", target_id=entity2.id, source_id=source_node.node_id + ), ] self.pg_store.upsert_llama_nodes([source_node]) @@ -72,19 +78,19 @@ def test_memgraph_pg_store(self): llama_nodes = self.pg_store.get_llama_nodes([source_node.node_id]) self.assertEqual(len(llama_nodes), 1) self.assertEqual(llama_nodes[0].text, source_node.text) - - """Test retrieving nodes by properties.""" + + # Test retrieving nodes by properties. 
kg_nodes = self.pg_store.get(properties={"age": 28}) self.assertEqual(len(kg_nodes), 1) self.assertEqual(kg_nodes[0].label, "PERSON") - self.assertEqual(kg_nodes[0].name, "Logan") - - """Test executing a structured query in Memgraph.""" + self.assertEqual(kg_nodes[0].name, "Logan") + + # Test executing a structured query in Memgraph. query = "MATCH (n:`__Entity__`) RETURN n" result = self.pg_store.structured_query(query) - self.assertEqual(len(result), 2) + self.assertEqual(len(result), 2) - """Test upserting a new node with additional properties.""" + # Test upserting a new node with additional properties. new_node = EntityNode( label="PERSON", name="Logan", properties={"age": 28, "location": "Canada"} ) @@ -96,16 +102,17 @@ def test_memgraph_pg_store(self): self.assertEqual(kg_nodes[0].label, "PERSON") self.assertEqual(kg_nodes[0].name, "Logan") self.assertEqual(kg_nodes[0].properties["location"], "Canada") - - """Test deleting nodes from Memgraph.""" + + # Test deleting nodes from Memgraph. 
self.pg_store.delete(ids=[source_node.node_id]) self.pg_store.delete(ids=[entity1.id, entity2.id]) - + # Assert the nodes have been deleted nodes = self.pg_store.get(ids=[entity1.id, entity2.id]) self.assertEqual(len(nodes), 0) text_nodes = self.pg_store.get_llama_nodes([source_node.node_id]) self.assertEqual(len(text_nodes), 0) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() From 053c9258f87ddda349e21f0bf1a4437052ecf3ad Mon Sep 17 00:00:00 2001 From: matea16 Date: Wed, 9 Oct 2024 13:32:40 +0200 Subject: [PATCH 3/6] move examples to readme --- .../README.md | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md index 9b972b7a8a899..370627da0760c 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md @@ -1 +1,138 @@ # LlamaIndex Graph-Stores Integration: Memgraph + +Memgraph is an open source graph database built for real-time streaming and fast analysis. + +In this project, we integrated Memgraph as a graph store to store the LlamaIndex graph data and query it. 
+ +- Property Graph Store: `MemgraphPropertyGraphStore` +- Knowledege Graph Store: `MemgraphGraphStore` + + +## Instalation + +```shell +pip install llama-index llama-index-graph-stores-memgraph +``` + +## Usage + +### Property Graph Store + +```python +import os +import urllib.request +import nest_asyncio +from llama_index.core import SimpleDirectoryReader, PropertyGraphIndex +from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.llms.openai import OpenAI +from llama_index.core.indices.property_graph import SchemaLLMPathExtractor + + +os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key + +os.makedirs("data/paul_graham/", exist_ok=True) + +url = "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt" +output_path = "data/paul_graham/paul_graham_essay.txt" +urllib.request.urlretrieve(url, output_path) + +nest_asyncio.apply() + +with open(output_path, "r", encoding="utf-8") as file: + content = file.read() + +modified_content = content.replace("'", "\\'") + +with open(output_path, "w", encoding="utf-8") as file: + file.write(modified_content) + +documents = SimpleDirectoryReader("./data/paul_graham/").load_data() + +# Setup Memgraph connection (ensure Memgraph is running) +username = "" # Enter your Memgraph username (default "") +password = "" # Enter your Memgraph password (default "") +url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' + +graph_store = MemgraphPropertyGraphStore( + username=username, + password=password, + url=url, +) + +index = PropertyGraphIndex.from_documents( + documents, + embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002"), + kg_extractors=[ + SchemaLLMPathExtractor( + llm=OpenAI(model="gpt-3.5-turbo", temperature=0.0), + ) + ], + property_graph_store=graph_store, + show_progress=True, +) + +query_engine = 
index.as_query_engine(include_text=True) + +response = query_engine.query("What happened at Interleaf and Viaweb?") +print("\nDetailed Query Response:") +print(str(response)) +``` + +### Knowledge Graph Store + +```python +import os +import logging +from llama_index.llms.openai import OpenAI +from llama_index.core import Settings +from llama_index.core import KnowledgeGraphIndex, SimpleDirectoryReader, StorageContext +from llama_index.graph_stores.memgraph import MemgraphGraphStore + +os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key + +logging.basicConfig(level=logging.INFO) + +llm = OpenAI(temperature=0, model="gpt-3.5-turbo") +Settings.llm = llm +Settings.chunk_size = 512 + +documents = { + "doc1.txt": "Python is a popular programming language known for its readability and simplicity. It was created by Guido van Rossum and first released in 1991. Python supports multiple programming paradigms, including procedural, object-oriented, and functional programming. It is widely used in web development, data science, artificial intelligence, and scientific computing.", + "doc2.txt": "JavaScript is a high-level programming language primarily used for web development. It was created by Brendan Eich and first appeared in 1995. JavaScript is a core technology of the World Wide Web, alongside HTML and CSS. It enables interactive web pages and is an essential part of web applications. JavaScript is also used in server-side development with environments like Node.js.", + "doc3.txt": "Java is a high-level, class-based, object-oriented programming language that is designed to have as few implementation dependencies as possible. It was developed by James Gosling and first released by Sun Microsystems in 1995. 
Java is widely used for building enterprise-scale applications, mobile applications, and large systems development.", +} + +for filename, content in documents.items(): + with open(filename, "w") as file: + file.write(content) + +loaded_documents = SimpleDirectoryReader(".").load_data() + +# Setup Memgraph connection (ensure Memgraph is running) +username = "" # Enter your Memgraph username (default "") +password = "" # Enter your Memgraph password (default "") +url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' +database = "memgraph" # Name of the database, default is 'memgraph' + +graph_store = MemgraphGraphStore( + username=username, + password=password, + url=url, + database=database, +) + +storage_context = StorageContext.from_defaults(graph_store=graph_store) + +index = KnowledgeGraphIndex.from_documents( + loaded_documents, + storage_context=storage_context, + max_triplets_per_chunk=3, +) + +query_engine = index.as_query_engine(include_text=False, response_mode="tree_summarize") +response = query_engine.query("Tell me about Python and its uses") + +print("Query Response:") +print(response) +``` From 00d29bbead7ae851639c93001cf28ddc26ff9b29 Mon Sep 17 00:00:00 2001 From: matea16 Date: Wed, 9 Oct 2024 13:35:40 +0200 Subject: [PATCH 4/6] add notebook example to docs --- .../property_graph_memgraph.ipynb | 286 ++++++++++++++++++ .../examples/BUILD | 1 - .../examples/kg_example.py | 61 ---- .../examples/pg_example.py | 80 ----- 4 files changed, 286 insertions(+), 142 deletions(-) create mode 100644 docs/docs/examples/property_graph/property_graph_memgraph.ipynb delete mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/BUILD delete mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py delete mode 100644 llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py diff --git 
a/docs/docs/examples/property_graph/property_graph_memgraph.ipynb b/docs/docs/examples/property_graph/property_graph_memgraph.ipynb new file mode 100644 index 0000000000000..3060140f08a43 --- /dev/null +++ b/docs/docs/examples/property_graph/property_graph_memgraph.ipynb @@ -0,0 +1,286 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Memgraph Property Graph Index\n", + "\n", + "[Memgraph](https://memgraph.com/) is an open source graph database built real-time streaming and fast analysis of your stored data.\n", + "\n", + "Before running Memgraph, ensure you have Docker running in the background. The quickest way to try out [Memgraph Platform](https://memgraph.com/docs/getting-started#install-memgraph-platform) (Memgraph database + MAGE library + Memgraph Lab) for the first time is running the following command:\n", + "\n", + "For Linux/macOS:\n", + "```shell\n", + "curl https://install.memgraph.com | sh\n", + "```\n", + "\n", + "For Windows:\n", + "```shell\n", + "iwr https://windows.memgraph.com | iex\n", + "```\n", + "\n", + "From here, you can check Memgraph's visual tool, Memgraph Lab on the [http://localhost:3000/](http://localhost:3000/) or the [desktop version](https://memgraph.com/download) of the app." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index llama-index-graph-stores-memgraph" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Environment setup " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-...\" # Replace with your OpenAI API key" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create the data directory and download the Paul Graham essay we'll be using as the input data for this example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import urllib.request\n", + "\n", + "os.makedirs(\"data/paul_graham/\", exist_ok=True)\n", + "\n", + "url = \"https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt\"\n", + "output_path = \"data/paul_graham/paul_graham_essay.txt\"\n", + "urllib.request.urlretrieve(url, output_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Read the file, replace single quotes, save the modified content and load the document data using the `SimpleDirectoryReader`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core import SimpleDirectoryReader\n", + "\n", + "with open(output_path, \"r\", encoding=\"utf-8\") as file:\n", + " content = file.read()\n", + "\n", + "modified_content = content.replace(\"'\", \"\\\\'\")\n", + "\n", + "with open(output_path, \"w\", encoding=\"utf-8\") as file:\n", + " file.write(modified_content)\n", + "\n", + "documents = SimpleDirectoryReader(\"./data/paul_graham/\").load_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Memgraph connection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set up your graph store class by providing the database credentials." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore\n", + "\n", + "username = \"\" # Enter your Memgraph username (default \"\")\n", + "password = \"\" # Enter your Memgraph password (default \"\")\n", + "url = \"\" # Specify the connection URL, e.g., 'bolt://localhost:7687'\n", + "\n", + "graph_store = MemgraphPropertyGraphStore(\n", + " username=username,\n", + " password=password,\n", + " url=url,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Index Construction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core import PropertyGraphIndex\n", + "from llama_index.embeddings.openai import OpenAIEmbedding\n", + "from llama_index.llms.openai import OpenAI\n", + "from llama_index.core.indices.property_graph import SchemaLLMPathExtractor\n", + "\n", + "index = PropertyGraphIndex.from_documents(\n", + " documents,\n", + " embed_model=OpenAIEmbedding(model_name=\"text-embedding-ada-002\"),\n", + " kg_extractors=[\n", + " SchemaLLMPathExtractor(\n", + " llm=OpenAI(model=\"gpt-3.5-turbo\", temperature=0.0)\n", + " )\n", + " ],\n", + " property_graph_store=graph_store,\n", + " show_progress=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that the graph is created, we can explore it in the UI by visiting [http://localhost:3000/](http://localhost:3000/).\n", + "\n", + "The easiest way to visualize the entire graph is by running a Cypher command similar to this:\n", + "\n", + "```shell\n", + "MATCH p=()-[]-() RETURN p;\n", + "```\n", + "\n", + "This command matches all of the possible paths in the graph and returns entire graph. 
\n", + "\n", + "To visualize the schema of the graph, visit the Graph schema tab and generate the new schema based on the newly created graph.\n", + "\n", + "To delete an entire graph, use:\n", + "\n", + "```shell\n", + "MATCH (n) DETACH DELETE n;\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Querying and retrieval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = index.as_retriever(include_text=False)\n", + "\n", + "# Example query: \"What happened at Interleaf and Viaweb?\"\n", + "nodes = retriever.retrieve(\"What happened at Interleaf and Viaweb?\")\n", + "\n", + "# Output results\n", + "print(\"Query Results:\")\n", + "for node in nodes:\n", + " print(node.text)\n", + "\n", + "# Alternatively, using a query engine\n", + "query_engine = index.as_query_engine(include_text=True)\n", + "\n", + "# Perform a query and print the detailed response\n", + "response = query_engine.query(\"What happened at Interleaf and Viaweb?\")\n", + "print(\"\\nDetailed Query Response:\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading from an existing graph" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you have an existing graph (either created with LlamaIndex or otherwise), we can connect to and use it!\n", + "\n", + "**NOTE:** If your graph was created outside of LlamaIndex, the most useful retrievers will be [text to cypher](../../module_guides/indexing/lpg_index_guide.md#texttocypherretriever) or [cypher templates](../../module_guides/indexing/lpg_index_guide.md#cyphertemplateretriever). Other retrievers rely on properties that LlamaIndex inserts." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm = OpenAI(model=\"gpt-4\", temperature=0.0)\n", + "kg_extractors = [\n", + " SchemaLLMPathExtractor(llm=llm) \n", + "]\n", + "\n", + "index = PropertyGraphIndex.from_existing(\n", + " property_graph_store=graph_store,\n", + " kg_extractors=kg_extractors,\n", + " embed_model=OpenAIEmbedding(model_name=\"text-embedding-ada-002\"),\n", + " show_progress=True,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 64-bit (microsoft store)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.9.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "289d8ae9ac585fcc15d0d9333c941ae27bdf80d3e799883224b20975f2046730" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/BUILD b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/BUILD deleted file mode 100644 index db46e8d6c978c..0000000000000 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/BUILD +++ /dev/null @@ -1 +0,0 @@ -python_sources() diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py deleted file mode 100644 index ad97247e3fb04..0000000000000 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/kg_example.py +++ /dev/null @@ -1,61 +0,0 @@ -import os -import logging -from llama_index.llms.openai import OpenAI -from llama_index.core import Settings -from llama_index.core import KnowledgeGraphIndex, SimpleDirectoryReader, StorageContext -from llama_index.graph_stores.memgraph import MemgraphGraphStore - - -# Step 1: Set up OpenAI API key 
-os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key - -# Step 2: Configure logging -logging.basicConfig(level=logging.INFO) - -# Step 3: Configure OpenAI LLM -llm = OpenAI(temperature=0, model="gpt-3.5-turbo") -Settings.llm = llm -Settings.chunk_size = 512 - -# Step 4: Write documents to text files (Simulating loading documents from disk) -documents = { - "doc1.txt": "Python is a popular programming language known for its readability and simplicity. It was created by Guido van Rossum and first released in 1991. Python supports multiple programming paradigms, including procedural, object-oriented, and functional programming. It is widely used in web development, data science, artificial intelligence, and scientific computing.", - "doc2.txt": "JavaScript is a high-level programming language primarily used for web development. It was created by Brendan Eich and first appeared in 1995. JavaScript is a core technology of the World Wide Web, alongside HTML and CSS. It enables interactive web pages and is an essential part of web applications. JavaScript is also used in server-side development with environments like Node.js.", - "doc3.txt": "Java is a high-level, class-based, object-oriented programming language that is designed to have as few implementation dependencies as possible. It was developed by James Gosling and first released by Sun Microsystems in 1995. 
Java is widely used for building enterprise-scale applications, mobile applications, and large systems development.", -} - -for filename, content in documents.items(): - with open(filename, "w") as file: - file.write(content) - -# Step 5: Load documents -loaded_documents = SimpleDirectoryReader(".").load_data() - -# Step 6: Set up Memgraph connection -username = "" # Enter your Memgraph username (default "") -password = "" # Enter your Memgraph password (default "") -url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' -database = "memgraph" # Name of the database, default is 'memgraph' - -graph_store = MemgraphGraphStore( - username=username, - password=password, - url=url, - database=database, -) - -storage_context = StorageContext.from_defaults(graph_store=graph_store) - -# Step 7: Create a Knowledge Graph Index -index = KnowledgeGraphIndex.from_documents( - loaded_documents, - storage_context=storage_context, - max_triplets_per_chunk=3, -) - -# Step 8: Query the Knowledge Graph -query_engine = index.as_query_engine(include_text=False, response_mode="tree_summarize") -response = query_engine.query("Tell me about Python and its uses") - -print("Query Response:") -print(response) diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py deleted file mode 100644 index ef2164192fe39..0000000000000 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/examples/pg_example.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -import urllib.request -import nest_asyncio -import logging -from llama_index.core import SimpleDirectoryReader, PropertyGraphIndex -from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore -from llama_index.embeddings.openai import OpenAIEmbedding -from llama_index.llms.openai import OpenAI -from llama_index.core.indices.property_graph import 
SchemaLLMPathExtractor - - -# 1. Setup OpenAI API Key (replace this with your actual key) -os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key - -# 2. Create the data directory and download the Paul Graham essay -os.makedirs("data/paul_graham/", exist_ok=True) - -url = "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt" -output_path = "data/paul_graham/paul_graham_essay.txt" -urllib.request.urlretrieve(url, output_path) - -# 3. Ensure nest_asyncio is applied -nest_asyncio.apply() - -# Step 2: Read the file, replace single quotes, and save the modified content -with open(output_path, "r", encoding="utf-8") as file: - content = file.read() - -# Replace single quotes with escaped single quotes -modified_content = content.replace("'", "\\'") - -# Save the modified content back to the same file -with open(output_path, "w", encoding="utf-8") as file: - file.write(modified_content) - -# 4. Load the document data -documents = SimpleDirectoryReader("./data/paul_graham/").load_data() - -# 5. Setup Memgraph connection (ensure Memgraph is running) -username = "" # Enter your Memgraph username (default "") -password = "" # Enter your Memgraph password (default "") -url = "" # Specify the connection URL, e.g., 'bolt://localhost:7687' - -graph_store = MemgraphPropertyGraphStore( - username=username, - password=password, - url=url, -) - -# 6. Create the Property Graph Index -index = PropertyGraphIndex.from_documents( - documents, - embed_model=OpenAIEmbedding(model_name="text-embedding-ada-002"), - kg_extractors=[ - SchemaLLMPathExtractor( - llm=OpenAI(model="gpt-3.5-turbo", temperature=0.0), - ) - ], - property_graph_store=graph_store, - show_progress=True, -) - -# 7. Querying the graph -retriever = index.as_retriever(include_text=False) - -# Example query: "What happened at Interleaf and Viaweb?" 
-nodes = retriever.retrieve("What happened at Interleaf and Viaweb?") - -# Output results -print("Query Results:") -for node in nodes: - print(node.text) - -# Alternatively, using a query engine -query_engine = index.as_query_engine(include_text=True) - -# Perform a query and print the detailed response -response = query_engine.query("What happened at Interleaf and Viaweb?") -print("\nDetailed Query Response:") -print(str(response)) From 4c08eb543c1606ce03d9994be43fdcc80a47943d Mon Sep 17 00:00:00 2001 From: matea16 Date: Wed, 9 Oct 2024 13:35:48 +0200 Subject: [PATCH 5/6] update tests --- .../pyproject.toml | 1 + .../tests/test_graph_stores_memgraph.py | 47 +---- .../tests/test_pg_stores_memgraph.py | 199 ++++++++---------- 3 files changed, 98 insertions(+), 149 deletions(-) diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml index e76ff4e809ec6..c2ebce54d7a65 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/pyproject.toml @@ -36,6 +36,7 @@ version = "0.1.0" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.0" +neo4j = "^5.24.0" [tool.poetry.group.dev.dependencies] black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py index 9702101e86233..f66ca48b66454 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py @@ -1,42 +1,11 @@ -import unittest -from 
llama_index.graph_stores.memgraph import MemgraphGraphStore - - -class TestMemgraphGraphStore(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.store = MemgraphGraphStore( - username="", password="", url="bolt://localhost:7687" - ) - - def test_connection(self): - """Test if connection to Memgraph is working.""" - try: - self.store.client.verify_connectivity() - connected = True - except Exception as e: - connected = False - self.assertTrue(connected, "Could not connect to Memgraph") +from unittest.mock import MagicMock, patch - def test_upsert_triplet(self): - """Test inserting a triplet into Memgraph.""" - self.store.upsert_triplet("Alice", "KNOWS", "Bob") - triplets = self.store.get("Alice") - self.assertIn(["KNOWS", "Bob"], triplets) - - def test_delete_triplet(self): - """Test deleting a triplet from Memgraph.""" - self.store.delete("Alice", "KNOWS", "Bob") - triplets = self.store.get("Alice") - self.assertNotIn(["KNOWS", "Bob"], triplets) - - def test_get_rel_map(self): - """Test retrieving relationships.""" - self.store.upsert_triplet("Alice", "KNOWS", "Bob") - rel_map = self.store.get_rel_map(["Alice"], depth=2) - self.assertIn("Alice", rel_map) - self.assertIn(["KNOWS", "Bob"], rel_map["Alice"]) +from llama_index.core.graph_stores.types import GraphStore +from llama_index.graph_stores.memgraph import MemgraphGraphStore -if __name__ == "__main__": - unittest.main() +@patch("llama_index.graph_stores.memgraph.MemgraphGraphStore") +def test_memgraph_graph_store(MockMemgraphGraphStore: MagicMock): + instance: MemgraphGraphStore = MockMemgraphGraphStore.return_value() + assert isinstance(instance, GraphStore) + \ No newline at end of file diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py index d5fefca3ed05d..bd1e4cfc48e36 100644 --- 
a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py @@ -1,4 +1,6 @@ import unittest +import os +import pytest from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore from llama_index.core.graph_stores.types import ( EntityNode, @@ -6,113 +8,90 @@ ) from llama_index.core.schema import TextNode - -class TestMemgraphGraphStore(unittest.TestCase): - @classmethod - def setUp(self): - self.pg_store = MemgraphPropertyGraphStore( - username="", password="", url="bolt://localhost:7687" - ) - - def test_connection(self): - """Test if connection to Memgraph is working.""" - try: - self.pg_store.client.verify_connectivity() - connected = True - except Exception as e: - connected = False - self.assertTrue(connected, "Could not connect to Memgraph") - - def test_memgraph_pg_store(self): - # Clear the database - self.pg_store.structured_query("STORAGE MODE IN_MEMORY_ANALYTICAL") - self.pg_store.structured_query("DROP GRAPH") - self.pg_store.structured_query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") - - # Test upsert nodes - entity1 = EntityNode(label="PERSON", name="Logan", properties={"age": 28}) - entity2 = EntityNode(label="ORGANIZATION", name="LlamaIndex") - self.pg_store.upsert_nodes([entity1, entity2]) - - # Assert the nodes are inserted correctly - kg_nodes = self.pg_store.get(ids=[entity1.id]) - - self.assertEqual(len(kg_nodes), 1) - self.assertEqual(kg_nodes[0].name, "Logan") - - # Test inserting relations into Memgraph. 
- relation = Relation( - label="WORKS_FOR", - source_id=entity1.id, - target_id=entity2.id, - properties={"since": 2023}, - ) - - self.pg_store.upsert_relations([relation]) - - # Assert the relation is inserted correctly by retrieving the relation map - kg_nodes = self.pg_store.get(ids=[entity1.id]) - paths = self.pg_store.get_rel_map(kg_nodes, depth=1) - self.assertEqual(len(paths), 1) - path = paths[0] - self.assertEqual(path[0].id, entity1.id) - self.assertEqual(path[2].id, entity2.id) - self.assertEqual(path[1].label, "WORKS_FOR") - - # Test inserting a source text node and 'MENTIONS' relations. - source_node = TextNode(text="Logan (age 28), works for LlamaIndex since 2023.") - - relations = [ - Relation( - label="MENTIONS", target_id=entity1.id, source_id=source_node.node_id - ), - Relation( - label="MENTIONS", target_id=entity2.id, source_id=source_node.node_id - ), - ] - - self.pg_store.upsert_llama_nodes([source_node]) - self.pg_store.upsert_relations(relations) - - # Assert the source node and relations are inserted correctly - llama_nodes = self.pg_store.get_llama_nodes([source_node.node_id]) - self.assertEqual(len(llama_nodes), 1) - self.assertEqual(llama_nodes[0].text, source_node.text) - - # Test retrieving nodes by properties. - kg_nodes = self.pg_store.get(properties={"age": 28}) - self.assertEqual(len(kg_nodes), 1) - self.assertEqual(kg_nodes[0].label, "PERSON") - self.assertEqual(kg_nodes[0].name, "Logan") - - # Test executing a structured query in Memgraph. - query = "MATCH (n:`__Entity__`) RETURN n" - result = self.pg_store.structured_query(query) - self.assertEqual(len(result), 2) - - # Test upserting a new node with additional properties. 
- new_node = EntityNode( - label="PERSON", name="Logan", properties={"age": 28, "location": "Canada"} - ) - self.pg_store.upsert_nodes([new_node]) - - # Assert the node has been updated with the new property - kg_nodes = self.pg_store.get(properties={"age": 28}) - self.assertEqual(len(kg_nodes), 1) - self.assertEqual(kg_nodes[0].label, "PERSON") - self.assertEqual(kg_nodes[0].name, "Logan") - self.assertEqual(kg_nodes[0].properties["location"], "Canada") - - # Test deleting nodes from Memgraph. - self.pg_store.delete(ids=[source_node.node_id]) - self.pg_store.delete(ids=[entity1.id, entity2.id]) - - # Assert the nodes have been deleted - nodes = self.pg_store.get(ids=[entity1.id, entity2.id]) - self.assertEqual(len(nodes), 0) - text_nodes = self.pg_store.get_llama_nodes([source_node.node_id]) - self.assertEqual(len(text_nodes), 0) - - -if __name__ == "__main__": - unittest.main() +memgraph_user = os.environ.get("MEMGRAPH_TEST_USER") +memgraph_pass = os.environ.get("MEMGRAPH_TEST_PASS") +memgraph_url = os.environ.get("MEMGRAPH_TEST_URL") + +if not memgraph_user or not memgraph_pass or not memgraph_url: + memgraph_available = False +else: + memgraph_available = True + +@pytest.fixture() +def pg_store() -> MemgraphPropertyGraphStore: + if not memgraph_available: + pytest.skip("No Memgraph credentials provided") + pg_store = MemgraphPropertyGraphStore( + username=memgraph_user, password=memgraph_pass, url=memgraph_url + ) + return pg_store + +def test_memgraph_pg_store(pg_store: MemgraphPropertyGraphStore) -> None: + # Clear the database + pg_store.structured_query("STORAGE MODE IN_MEMORY_ANALYTICAL") + pg_store.structured_query("DROP GRAPH") + pg_store.structured_query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") + + # Test upsert nodes + entity1 = EntityNode(label="PERSON", name="Logan", properties={"age": 28}) + entity2 = EntityNode(label="ORGANIZATION", name="LlamaIndex") + pg_store.upsert_nodes([entity1, entity2]) + + # Assert the nodes are inserted correctly + 
kg_nodes = pg_store.get(ids=[entity1.id]) + + # Test inserting relations into Memgraph. + relation = Relation( + label="WORKS_FOR", + source_id=entity1.id, + target_id=entity2.id, + properties={"since": 2023}, + ) + + pg_store.upsert_relations([relation]) + + # Assert the relation is inserted correctly by retrieving the relation map + kg_nodes = pg_store.get(ids=[entity1.id]) + paths = pg_store.get_rel_map(kg_nodes, depth=1) + + # Test inserting a source text node and 'MENTIONS' relations. + source_node = TextNode(text="Logan (age 28), works for LlamaIndex since 2023.") + + relations = [ + Relation( + label="MENTIONS", target_id=entity1.id, source_id=source_node.node_id + ), + Relation( + label="MENTIONS", target_id=entity2.id, source_id=source_node.node_id + ), + ] + + pg_store.upsert_llama_nodes([source_node]) + pg_store.upsert_relations(relations) + + # Assert the source node and relations are inserted correctly + llama_nodes = pg_store.get_llama_nodes([source_node.node_id]) + + # Test retrieving nodes by properties. + kg_nodes = pg_store.get(properties={"age": 28}) + + # Test executing a structured query in Memgraph. + query = "MATCH (n:`__Entity__`) RETURN n" + result = pg_store.structured_query(query) + + # Test upserting a new node with additional properties. + new_node = EntityNode( + label="PERSON", name="Logan", properties={"age": 28, "location": "Canada"} + ) + pg_store.upsert_nodes([new_node]) + + # Assert the node has been updated with the new property + kg_nodes = pg_store.get(properties={"age": 28}) + + # Test deleting nodes from Memgraph. 
+ pg_store.delete(ids=[source_node.node_id]) + pg_store.delete(ids=[entity1.id, entity2.id]) + + # Assert the nodes have been deleted + nodes = pg_store.get(ids=[entity1.id, entity2.id]) + \ No newline at end of file From aa0a2dc5544901e7f903154ef9a3ff14062e41a5 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Mon, 14 Oct 2024 13:13:39 -0600 Subject: [PATCH 6/6] linting --- .../property_graph_memgraph.ipynb | 12 +- .../README.md | 23 ++- .../tests/test_graph_stores_memgraph.py | 1 - .../tests/test_pg_stores_memgraph.py | 147 +++++++++--------- 4 files changed, 92 insertions(+), 91 deletions(-) diff --git a/docs/docs/examples/property_graph/property_graph_memgraph.ipynb b/docs/docs/examples/property_graph/property_graph_memgraph.ipynb index 3060140f08a43..ccb6491e8fada 100644 --- a/docs/docs/examples/property_graph/property_graph_memgraph.ipynb +++ b/docs/docs/examples/property_graph/property_graph_memgraph.ipynb @@ -47,7 +47,9 @@ "source": [ "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-...\" # Replace with your OpenAI API key" + "os.environ[\n", + " \"OPENAI_API_KEY\"\n", + "] = \"sk-proj-...\" # Replace with your OpenAI API key" ] }, { @@ -251,9 +253,7 @@ "outputs": [], "source": [ "llm = OpenAI(model=\"gpt-4\", temperature=0.0)\n", - "kg_extractors = [\n", - " SchemaLLMPathExtractor(llm=llm) \n", - "]\n", + "kg_extractors = [SchemaLLMPathExtractor(llm=llm)]\n", "\n", "index = PropertyGraphIndex.from_existing(\n", " property_graph_store=graph_store,\n", @@ -271,10 +271,8 @@ "name": "python3" }, "language_info": { - "name": "python", - "version": "3.9.13" + "name": "python" }, - "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "289d8ae9ac585fcc15d0d9333c941ae27bdf80d3e799883224b20975f2046730" diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md index 370627da0760c..04a925ae066e5 100644 --- 
a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/README.md @@ -2,13 +2,12 @@ Memgraph is an open source graph database built for real-time streaming and fast analysis. -In this project, we integrated Memgraph as a graph store to store the LlamaIndex graph data and query it. +In this project, we integrated Memgraph as a graph store to store the LlamaIndex graph data and query it. - Property Graph Store: `MemgraphPropertyGraphStore` - Knowledege Graph Store: `MemgraphGraphStore` - -## Instalation +## Installation ```shell pip install llama-index llama-index-graph-stores-memgraph @@ -29,7 +28,9 @@ from llama_index.llms.openai import OpenAI from llama_index.core.indices.property_graph import SchemaLLMPathExtractor -os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key +os.environ[ + "OPENAI_API_KEY" +] = "" # Replace with your OpenAI API key os.makedirs("data/paul_graham/", exist_ok=True) @@ -86,10 +87,16 @@ import os import logging from llama_index.llms.openai import OpenAI from llama_index.core import Settings -from llama_index.core import KnowledgeGraphIndex, SimpleDirectoryReader, StorageContext +from llama_index.core import ( + KnowledgeGraphIndex, + SimpleDirectoryReader, + StorageContext, +) from llama_index.graph_stores.memgraph import MemgraphGraphStore -os.environ["OPENAI_API_KEY"] = "" # Replace with your OpenAI API key +os.environ[ + "OPENAI_API_KEY" +] = "" # Replace with your OpenAI API key logging.basicConfig(level=logging.INFO) @@ -130,7 +137,9 @@ index = KnowledgeGraphIndex.from_documents( max_triplets_per_chunk=3, ) -query_engine = index.as_query_engine(include_text=False, response_mode="tree_summarize") +query_engine = index.as_query_engine( + include_text=False, response_mode="tree_summarize" +) response = query_engine.query("Tell me about Python and its uses") print("Query Response:") diff --git 
a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py index f66ca48b66454..6b82a46196742 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_graph_stores_memgraph.py @@ -8,4 +8,3 @@ def test_memgraph_graph_store(MockMemgraphGraphStore: MagicMock): instance: MemgraphGraphStore = MockMemgraphGraphStore.return_value() assert isinstance(instance, GraphStore) - \ No newline at end of file diff --git a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py index bd1e4cfc48e36..a6260027b1c63 100644 --- a/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py +++ b/llama-index-integrations/graph_stores/llama-index-graph-stores-memgraph/tests/test_pg_stores_memgraph.py @@ -1,4 +1,3 @@ -import unittest import os import pytest from llama_index.graph_stores.memgraph import MemgraphPropertyGraphStore @@ -15,83 +14,79 @@ if not memgraph_user or not memgraph_pass or not memgraph_url: memgraph_available = False else: - memgraph_available = True + memgraph_available = True + @pytest.fixture() def pg_store() -> MemgraphPropertyGraphStore: - if not memgraph_available: - pytest.skip("No Memgraph credentials provided") - pg_store = MemgraphPropertyGraphStore( - username=memgraph_user, password=memgraph_pass, url=memgraph_url - ) - return pg_store + if not memgraph_available: + pytest.skip("No Memgraph credentials provided") + return MemgraphPropertyGraphStore( + username=memgraph_user, password=memgraph_pass, url=memgraph_url + ) + def 
test_memgraph_pg_store(pg_store: MemgraphPropertyGraphStore) -> None: - # Clear the database - pg_store.structured_query("STORAGE MODE IN_MEMORY_ANALYTICAL") - pg_store.structured_query("DROP GRAPH") - pg_store.structured_query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") - - # Test upsert nodes - entity1 = EntityNode(label="PERSON", name="Logan", properties={"age": 28}) - entity2 = EntityNode(label="ORGANIZATION", name="LlamaIndex") - pg_store.upsert_nodes([entity1, entity2]) - - # Assert the nodes are inserted correctly - kg_nodes = pg_store.get(ids=[entity1.id]) - - # Test inserting relations into Memgraph. - relation = Relation( - label="WORKS_FOR", - source_id=entity1.id, - target_id=entity2.id, - properties={"since": 2023}, - ) - - pg_store.upsert_relations([relation]) - - # Assert the relation is inserted correctly by retrieving the relation map - kg_nodes = pg_store.get(ids=[entity1.id]) - paths = pg_store.get_rel_map(kg_nodes, depth=1) - - # Test inserting a source text node and 'MENTIONS' relations. - source_node = TextNode(text="Logan (age 28), works for LlamaIndex since 2023.") - - relations = [ - Relation( - label="MENTIONS", target_id=entity1.id, source_id=source_node.node_id - ), - Relation( - label="MENTIONS", target_id=entity2.id, source_id=source_node.node_id - ), - ] - - pg_store.upsert_llama_nodes([source_node]) - pg_store.upsert_relations(relations) - - # Assert the source node and relations are inserted correctly - llama_nodes = pg_store.get_llama_nodes([source_node.node_id]) - - # Test retrieving nodes by properties. - kg_nodes = pg_store.get(properties={"age": 28}) - - # Test executing a structured query in Memgraph. - query = "MATCH (n:`__Entity__`) RETURN n" - result = pg_store.structured_query(query) - - # Test upserting a new node with additional properties. 
- new_node = EntityNode( - label="PERSON", name="Logan", properties={"age": 28, "location": "Canada"} - ) - pg_store.upsert_nodes([new_node]) - - # Assert the node has been updated with the new property - kg_nodes = pg_store.get(properties={"age": 28}) - - # Test deleting nodes from Memgraph. - pg_store.delete(ids=[source_node.node_id]) - pg_store.delete(ids=[entity1.id, entity2.id]) - - # Assert the nodes have been deleted - nodes = pg_store.get(ids=[entity1.id, entity2.id]) - \ No newline at end of file + # Clear the database + pg_store.structured_query("STORAGE MODE IN_MEMORY_ANALYTICAL") + pg_store.structured_query("DROP GRAPH") + pg_store.structured_query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") + + # Test upsert nodes + entity1 = EntityNode(label="PERSON", name="Logan", properties={"age": 28}) + entity2 = EntityNode(label="ORGANIZATION", name="LlamaIndex") + pg_store.upsert_nodes([entity1, entity2]) + + # Assert the nodes are inserted correctly + kg_nodes = pg_store.get(ids=[entity1.id]) + + # Test inserting relations into Memgraph. + relation = Relation( + label="WORKS_FOR", + source_id=entity1.id, + target_id=entity2.id, + properties={"since": 2023}, + ) + + pg_store.upsert_relations([relation]) + + # Assert the relation is inserted correctly by retrieving the relation map + kg_nodes = pg_store.get(ids=[entity1.id]) + paths = pg_store.get_rel_map(kg_nodes, depth=1) + + # Test inserting a source text node and 'MENTIONS' relations. + source_node = TextNode(text="Logan (age 28), works for LlamaIndex since 2023.") + + relations = [ + Relation(label="MENTIONS", target_id=entity1.id, source_id=source_node.node_id), + Relation(label="MENTIONS", target_id=entity2.id, source_id=source_node.node_id), + ] + + pg_store.upsert_llama_nodes([source_node]) + pg_store.upsert_relations(relations) + + # Assert the source node and relations are inserted correctly + llama_nodes = pg_store.get_llama_nodes([source_node.node_id]) + + # Test retrieving nodes by properties. 
+ kg_nodes = pg_store.get(properties={"age": 28}) + + # Test executing a structured query in Memgraph. + query = "MATCH (n:`__Entity__`) RETURN n" + result = pg_store.structured_query(query) + + # Test upserting a new node with additional properties. + new_node = EntityNode( + label="PERSON", name="Logan", properties={"age": 28, "location": "Canada"} + ) + pg_store.upsert_nodes([new_node]) + + # Assert the node has been updated with the new property + kg_nodes = pg_store.get(properties={"age": 28}) + + # Test deleting nodes from Memgraph. + pg_store.delete(ids=[source_node.node_id]) + pg_store.delete(ids=[entity1.id, entity2.id]) + + # Assert the nodes have been deleted + nodes = pg_store.get(ids=[entity1.id, entity2.id])