From b9852642b27ddd31fd0baef39ff95624e31c12bc Mon Sep 17 00:00:00 2001
From: caoyu <33222683+handsomecaoyu@users.noreply.github.com>
Date: Fri, 6 Sep 2024 15:34:45 +0800
Subject: [PATCH] feat: add faiss storage and graphml visualization (#12)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Added FAISS as a vector database
* Added a simple network visualization
* Changed how FAISS ids are handled
---
 examples/using_faiss_as_vextorDB.py |  97 ++++++++++
 examples/visualize.py               | 270 ++++++++++++++++++++++++++++
 nano_graphrag/_storage.py           |   1 -
 nano_graphrag/_utils.py             |   2 +-
 4 files changed, 368 insertions(+), 2 deletions(-)
 create mode 100644 examples/using_faiss_as_vextorDB.py
 create mode 100644 examples/visualize.py

diff --git a/examples/using_faiss_as_vextorDB.py b/examples/using_faiss_as_vextorDB.py
new file mode 100644
index 0000000..543d6fb
--- /dev/null
+++ b/examples/using_faiss_as_vextorDB.py
@@ -0,0 +1,97 @@
+import os
+import asyncio
+import numpy as np
+from nano_graphrag.graphrag import GraphRAG, QueryParam
+from nano_graphrag._utils import logger
+from nano_graphrag.base import BaseVectorStorage
+from dataclasses import dataclass
+import faiss
+import pickle
+import logging
+import xxhash
+
+logging.getLogger('msal').setLevel(logging.WARNING)
+logging.getLogger('azure').setLevel(logging.WARNING)
+logging.getLogger("httpx").setLevel(logging.WARNING)
+
+WORKING_DIR = "./nano_graphrag_cache_faiss_TEST"
+
+
+@dataclass
+class FAISSStorage(BaseVectorStorage):
+
+    def __post_init__(self):
+        self._index_file_name = os.path.join(
+            self.global_config["working_dir"], f"{self.namespace}_faiss.index"
+        )
+        self._metadata_file_name = os.path.join(
+            self.global_config["working_dir"], f"{self.namespace}_metadata.pkl"
+        )
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+
+        if os.path.exists(self._index_file_name) and os.path.exists(self._metadata_file_name):
+            self._index = faiss.read_index(self._index_file_name)
+            with open(self._metadata_file_name, 'rb') as f:
+                self._metadata = pickle.load(f)
+        else:
+            self._index = faiss.IndexIDMap(faiss.IndexFlatIP(self.embedding_func.embedding_dim))
+            self._metadata = {}
+
+    async def upsert(self, data: dict[str, dict]):
+        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+
+        contents = [v["content"] for v in data.values()]
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+        embeddings_list = await asyncio.gather(
+            *[self.embedding_func(batch) for batch in batches]
+        )
+        # FAISS expects float32 input
+        embeddings = np.concatenate(embeddings_list).astype(np.float32)
+
+        ids = []
+        for k, v in data.items():
+            # Hash the string key to an integer id, since IndexIDMap only accepts int64 ids
+            int_id = xxhash.xxh32_intdigest(k.encode())
+            metadata = {k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}
+            metadata['id'] = k
+            self._metadata[int_id] = metadata
+            ids.append(int_id)
+
+        ids = np.array(ids, dtype=np.int64)
+        self._index.add_with_ids(embeddings, ids)
+
+        return len(data)
+
+    async def query(self, query, top_k=5):
+        embedding = await self.embedding_func([query])
+        embedding = np.asarray(embedding, dtype=np.float32)
+        distances, indices = self._index.search(embedding, top_k)
+
+        results = []
+        for distance, int_id in zip(distances[0], indices[0]):
+            if int_id == -1:  # FAISS returns -1 for empty slots
+                continue
+            if int_id in self._metadata:
+                metadata = self._metadata[int_id]
+                # IndexFlatIP returns inner products; 1 - score is a cosine
+                # distance only when the embeddings are L2-normalized
+                results.append({**metadata, "distance": 1 - distance})
+
+        return results
+
+    async def index_done_callback(self):
+        faiss.write_index(self._index, self._index_file_name)
+        with open(self._metadata_file_name, 'wb') as f:
+            pickle.dump(self._metadata, f)
+
+
+if __name__ == "__main__":
+    graph_func = GraphRAG(
+        working_dir=WORKING_DIR,
+        enable_llm_cache=True,
+        vector_db_storage_cls=FAISSStorage,
+    )
+
+    with open(r"tests/mock_data.txt", encoding='utf-8') as f:
+        graph_func.insert(f.read()[:30000])
+
+    # Perform global graphrag search
+    print(graph_func.query("What are the top themes in this story?"))
\ No newline at end of file
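Note on the scoring above: IndexFlatIP computes raw inner products, so the "1 - distance" conversion in query() yields a true cosine distance only when the stored and query vectors are L2-normalized. A minimal sketch of normalizing before insertion, assuming float32 numpy arrays (the array sizes and the sample_embeddings name are illustrative, not part of this patch):

    import faiss
    import numpy as np

    # Illustrative embeddings; in FAISSStorage this would be the
    # concatenated output of self.embedding_func.
    sample_embeddings = np.random.rand(4, 128).astype(np.float32)

    # normalize_L2 scales each row to unit length in place, so the
    # inner product computed by IndexFlatIP equals cosine similarity.
    faiss.normalize_L2(sample_embeddings)

    index = faiss.IndexIDMap(faiss.IndexFlatIP(128))
    index.add_with_ids(sample_embeddings, np.arange(4, dtype=np.int64))
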
diff --git a/examples/visualize.py b/examples/visualize.py
new file mode 100644
index 0000000..4baf1e4
--- /dev/null
+++ b/examples/visualize.py
@@ -0,0 +1,270 @@
+import networkx as nx
+import json
+import webbrowser
+import os
+import time
+import http.server
+import socketserver
+import threading
+
+
+# Read a GraphML file and convert it to node-link JSON
+def graphml_to_json(graphml_file):
+    G = nx.read_graphml(graphml_file)
+    data = nx.node_link_data(G)
+    return json.dumps(data)
+
+
+# Create the HTML file
+def create_html(json_data, html_path):
+    json_data = json_data.replace('\\"', '')
+    html_content = '''<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="utf-8">
+    <title>Graph Visualization</title>
+</head>
+<body>
+    <div id="graph"></div>
+    <script>
+        const data = JSON.parse('{json_data}');
+        // ... force-layout rendering script (elided from this excerpt) ...
+    </script>
+</body>
+</html>
+'''.replace("{json_data}", json_data.replace("'", "\\'").replace("\n", ""))
+
+    with open(html_path, 'w', encoding='utf-8') as f:
+        f.write(html_content)
+
+
+# Start a simple HTTP server
+def start_server():
+    handler = http.server.SimpleHTTPRequestHandler
+    with socketserver.TCPServer(("", 8000), handler) as httpd:
+        print("Server started at http://localhost:8000")
+        httpd.serve_forever()
+
+
+# Main function
+def visualize_graphml(graphml_file, html_path):
+    json_data = graphml_to_json(graphml_file)
+    create_html(json_data, html_path)
+
+    # Start the server in the background
+    server_thread = threading.Thread(target=start_server)
+    server_thread.daemon = True
+    server_thread.start()
+
+    # Open the default browser
+    webbrowser.open('http://localhost:8000/graph_visualization.html')
+
+    print("Visualization is ready. Press Ctrl+C to exit.")
+    try:
+        # Keep the main thread alive without busy-waiting
+        while True:
+            time.sleep(1)
+    except KeyboardInterrupt:
+        print("Shutting down...")
+
+
+# Usage example
+if __name__ == "__main__":
+    graphml_file = r"nano_graphrag_cache_azure_openai_TEST\graph_chunk_entity_relation.graphml"  # Replace with your GraphML file path
+    html_path = "graph_visualization.html"
+    visualize_graphml(graphml_file, html_path)
\ No newline at end of file
diff --git a/nano_graphrag/_storage.py b/nano_graphrag/_storage.py
index 785f969..aefbf60 100644
--- a/nano_graphrag/_storage.py
+++ b/nano_graphrag/_storage.py
@@ -21,7 +21,6 @@
 )
 from .prompt import GRAPH_FIELD_SEP
 
-
 @dataclass
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
diff --git a/nano_graphrag/_utils.py b/nano_graphrag/_utils.py
index 0c063dc..e3a2db3 100644
--- a/nano_graphrag/_utils.py
+++ b/nano_graphrag/_utils.py
@@ -70,7 +70,7 @@ def compute_mdhash_id(content, prefix: str = ""):
 
 
 def write_json(json_obj, file_name):
-    with open(file_name, "w") as f:
+    with open(file_name, "w", encoding='utf-8') as f:
         json.dump(json_obj, f, indent=2, ensure_ascii=False)
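
For reference, a toy sketch of the node-link JSON shape that graphml_to_json feeds into the HTML template. The two-node graph and its weight attribute here are made up for illustration; the "nodes" and "links" key names are the networkx node_link_data defaults:

    import json
    import networkx as nx

    G = nx.Graph()
    G.add_edge("entity_a", "entity_b", weight=1.0)

    # node_link_data returns a dict with "nodes" and "links" lists,
    # the shape that browser-side force-directed renderers typically consume.
    print(json.dumps(nx.node_link_data(G), indent=2))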