-
Notifications
You must be signed in to change notification settings - Fork 201
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add faiss storage and graphml visualization (#12)
* 增加了faiss作为向量数据库 * 增加了faiss作为向量数据库 * 增加了faiss作为向量数据库 * 增加了faiss作为向量数据库 * 增加了简单的网络可视化 * 修改了faiss的id处理方式
- Loading branch information
1 parent
2038a5a
commit b985264
Showing
4 changed files
with
368 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import os | ||
import asyncio | ||
import numpy as np | ||
from nano_graphrag.graphrag import GraphRAG, QueryParam | ||
from nano_graphrag._utils import logger | ||
from nano_graphrag.base import BaseVectorStorage | ||
from dataclasses import dataclass | ||
import faiss | ||
import pickle | ||
import logging | ||
import xxhash | ||
# Raise the threshold of chatty third-party loggers to WARNING.
for _noisy_logger in ("msal", "azure", "httpx"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

# Working directory handed to GraphRAG by the example at the bottom of this file.
WORKING_DIR = "./nano_graphrag_cache_faiss_TEST"
|
||
@dataclass
class FAISSStorage(BaseVectorStorage):
    """Vector storage backed by a FAISS inner-product index.

    Vectors live in a ``faiss.IndexIDMap`` wrapping ``faiss.IndexFlatIP``.
    Each record key is mapped to an integer id with a 32-bit xxhash digest,
    and the record's metadata is kept in a side dictionary keyed by that id.
    Both the index and the metadata dictionary are written to files inside
    ``global_config["working_dir"]`` by :meth:`index_done_callback`.

    NOTE(review): ids are 32-bit hashes, so two different keys can collide;
    in that case the later record's metadata overwrites the earlier one.
    """

    def __post_init__(self):
        """Load the persisted index and metadata, or create empty ones."""
        working_dir = self.global_config["working_dir"]
        self._index_file_name = os.path.join(
            working_dir, f"{self.namespace}_faiss.index"
        )
        self._metadata_file_name = os.path.join(
            working_dir, f"{self.namespace}_metadata.pkl"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]

        if os.path.exists(self._index_file_name) and os.path.exists(
            self._metadata_file_name
        ):
            self._index = faiss.read_index(self._index_file_name)
            # NOTE(review): pickle.load can execute arbitrary code. Only point
            # working_dir at files this process wrote itself.
            with open(self._metadata_file_name, 'rb') as f:
                self._metadata = pickle.load(f)
        else:
            self._index = faiss.IndexIDMap(
                faiss.IndexFlatIP(self.embedding_func.embedding_dim)
            )
            self._metadata = {}

    async def upsert(self, data: dict[str, dict]):
        """Embed and store the given records, replacing existing ones.

        Args:
            data: Mapping of record key to a dict that must contain
                ``"content"`` (the text to embed). Fields listed in
                ``self.meta_fields`` are kept as metadata.

        Returns:
            The number of records in ``data`` (0 for empty input).
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            # np.concatenate raises on an empty list, so bail out early.
            logger.warning("You insert an empty data to vector DB")
            return 0

        contents = [record["content"] for record in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        # FAISS works on contiguous float32 matrices.
        embeddings = np.ascontiguousarray(
            np.concatenate(embeddings_list), dtype=np.float32
        )

        vector_ids = []
        for key, record in data.items():
            vector_id = xxhash.xxh32_intdigest(key.encode())
            metadata = {
                name: value
                for name, value in record.items()
                if name in self.meta_fields
            }
            metadata['id'] = key
            self._metadata[vector_id] = metadata
            vector_ids.append(vector_id)

        ids = np.array(vector_ids, dtype=np.int64)
        # add_with_ids never overwrites: without removing first, re-inserting
        # a key would leave its stale vector in the index next to the new one.
        self._index.remove_ids(ids)
        self._index.add_with_ids(embeddings, ids)

        return len(data)

    async def query(self, query, top_k=5):
        """Return metadata of the ``top_k`` vectors closest to ``query``.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of hits to return.

        Returns:
            A list of metadata dicts, each extended with a ``"distance"``
            entry computed as ``1 - inner_product``.
        """
        embedding = np.ascontiguousarray(
            await self.embedding_func([query]), dtype=np.float32
        )
        distances, indices = self._index.search(embedding, top_k)

        results = []
        for score, vector_id in zip(distances[0], indices[0]):
            if vector_id == -1:  # FAISS returns -1 for empty slots
                continue
            metadata = self._metadata.get(int(vector_id))
            if metadata is None:
                continue
            # NOTE(review): 1 - inner product is a cosine distance only when
            # the embeddings are unit-normalised; confirm for embedding_func.
            results.append({**metadata, "distance": 1 - score})

        return results

    async def index_done_callback(self):
        """Persist the FAISS index and the metadata dictionary to disk."""
        faiss.write_index(self._index, self._index_file_name)
        with open(self._metadata_file_name, 'wb') as f:
            pickle.dump(self._metadata, f)
|
||
if __name__ == "__main__":
    # Build a GraphRAG instance that stores its vectors in FAISSStorage.
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=FAISSStorage,
    )

    # Index the first 30000 characters of the mock data set.
    with open(r"tests/mock_data.txt", encoding='utf-8') as source:
        rag.insert(source.read()[:30000])

    # Perform global graphrag search
    answer = rag.query("What are the top themes in this story?")
    print(answer)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
import networkx as nx | ||
import json | ||
import webbrowser | ||
import os | ||
import http.server | ||
import socketserver | ||
import threading | ||
|
||
# Read a GraphML file and convert it to JSON
def graphml_to_json(graphml_file):
    """Load a GraphML file and return it as a node-link JSON string.

    Args:
        graphml_file: Path of the GraphML file to read.

    Returns:
        A JSON string in NetworkX node-link format whose edge list is stored
        under the ``"links"`` key, which is the key the HTML page in this
        file reads.
    """
    G = nx.read_graphml(graphml_file)
    data = nx.node_link_data(G)
    # NOTE(review): newer NetworkX releases announce that the default edge
    # key changes from "links" to "edges". Normalise it here so the page's
    # `graphData.links` keeps working on either version.
    if "links" not in data and "edges" in data:
        data["links"] = data.pop("edges")
    return json.dumps(data)
|
||
def _strip_quotes(value):
    """Recursively remove double-quote characters from all strings in value."""
    if isinstance(value, str):
        return value.replace('"', '')
    if isinstance(value, list):
        return [_strip_quotes(item) for item in value]
    if isinstance(value, dict):
        return {_strip_quotes(key): _strip_quotes(item) for key, item in value.items()}
    return value


# Create the HTML file
def create_html(json_data, html_path):
    """Write a self-contained D3 force-graph page for the given graph.

    Args:
        json_data: JSON string in node-link format (``nodes`` / ``links``).
        html_path: Path of the HTML file to write (UTF-8).

    Raises:
        json.JSONDecodeError: If ``json_data`` is not valid JSON.

    Double quotes embedded in string values are dropped so that labels are
    shown unquoted. The stripping is done on the parsed data rather than on
    the JSON text, because removing ``\\"`` textually can corrupt the JSON
    when a value ends with a backslash.
    """
    cleaned = _strip_quotes(json.loads(json_data))
    # JSON is embedded directly as a JavaScript literal. Wrapping it in a
    # quoted JS string (as JSON.parse('...') requires) would need every
    # backslash escaped a second time, otherwise values containing "\n" or
    # "\\" make the page fail to load. "<" is escaped so that data such as
    # "</script>" cannot terminate the script element.
    payload = json.dumps(cleaned).replace("<", "\\u003c")

    html_content = '''
    <!DOCTYPE html>
    <html lang="en">
    <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Graph Visualization</title>
    <script src="https://d3js.org/d3.v7.min.js"></script>
    <style>
    body, html {
    margin: 0;
    padding: 0;
    width: 100%;
    height: 100%;
    overflow: hidden;
    }
    svg {
    width: 100%;
    height: 100%;
    }
    .links line {
    stroke: #999;
    stroke-opacity: 0.6;
    }
    .nodes circle {
    stroke: #fff;
    stroke-width: 1.5px;
    }
    .node-label {
    font-size: 12px;
    pointer-events: none;
    }
    .link-label {
    font-size: 10px;
    fill: #666;
    pointer-events: none;
    opacity: 0;
    transition: opacity 0.3s;
    }
    .link:hover .link-label {
    opacity: 1;
    }
    .tooltip {
    position: absolute;
    text-align: left;
    padding: 10px;
    font: 12px sans-serif;
    background: lightsteelblue;
    border: 0px;
    border-radius: 8px;
    pointer-events: none;
    opacity: 0;
    transition: opacity 0.3s;
    max-width: 300px;
    }
    .legend {
    position: absolute;
    top: 10px;
    right: 10px;
    background-color: rgba(255, 255, 255, 0.8);
    padding: 10px;
    border-radius: 5px;
    }
    .legend-item {
    margin: 5px 0;
    }
    .legend-color {
    display: inline-block;
    width: 20px;
    height: 20px;
    margin-right: 5px;
    vertical-align: middle;
    }
    </style>
    </head>
    <body>
    <svg></svg>
    <div class="tooltip"></div>
    <div class="legend"></div>
    <script>
    const graphData = {json_data};
    const svg = d3.select("svg"),
    width = window.innerWidth,
    height = window.innerHeight;
    svg.attr("viewBox", [0, 0, width, height]);
    const g = svg.append("g");
    const entityTypes = [...new Set(graphData.nodes.map(d => d.entity_type))];
    const color = d3.scaleOrdinal(d3.schemeCategory10).domain(entityTypes);
    const simulation = d3.forceSimulation(graphData.nodes)
    .force("link", d3.forceLink(graphData.links).id(d => d.id).distance(150))
    .force("charge", d3.forceManyBody().strength(-300))
    .force("center", d3.forceCenter(width / 2, height / 2))
    .force("collide", d3.forceCollide().radius(30));
    const linkGroup = g.append("g")
    .attr("class", "links")
    .selectAll("g")
    .data(graphData.links)
    .enter().append("g")
    .attr("class", "link");
    const link = linkGroup.append("line")
    .attr("stroke-width", d => Math.sqrt(d.value));
    const linkLabel = linkGroup.append("text")
    .attr("class", "link-label")
    .text(d => d.description || "");
    const node = g.append("g")
    .attr("class", "nodes")
    .selectAll("circle")
    .data(graphData.nodes)
    .enter().append("circle")
    .attr("r", 5)
    .attr("fill", d => color(d.entity_type))
    .call(d3.drag()
    .on("start", dragstarted)
    .on("drag", dragged)
    .on("end", dragended));
    const nodeLabel = g.append("g")
    .attr("class", "node-labels")
    .selectAll("text")
    .data(graphData.nodes)
    .enter().append("text")
    .attr("class", "node-label")
    .text(d => d.id);
    const tooltip = d3.select(".tooltip");
    node.on("mouseover", function(event, d) {
    tooltip.transition()
    .duration(200)
    .style("opacity", .9);
    tooltip.html(`<strong>${d.id}</strong><br>Entity Type: ${d.entity_type}<br>Description: ${d.description || "N/A"}`)
    .style("left", (event.pageX + 10) + "px")
    .style("top", (event.pageY - 28) + "px");
    })
    .on("mouseout", function(d) {
    tooltip.transition()
    .duration(500)
    .style("opacity", 0);
    });
    const legend = d3.select(".legend");
    entityTypes.forEach(type => {
    legend.append("div")
    .attr("class", "legend-item")
    .html(`<span class="legend-color" style="background-color: ${color(type)}"></span>${type}`);
    });
    simulation
    .nodes(graphData.nodes)
    .on("tick", ticked);
    simulation.force("link")
    .links(graphData.links);
    function ticked() {
    link
    .attr("x1", d => d.source.x)
    .attr("y1", d => d.source.y)
    .attr("x2", d => d.target.x)
    .attr("y2", d => d.target.y);
    linkLabel
    .attr("x", d => (d.source.x + d.target.x) / 2)
    .attr("y", d => (d.source.y + d.target.y) / 2)
    .attr("text-anchor", "middle")
    .attr("dominant-baseline", "middle");
    node
    .attr("cx", d => d.x)
    .attr("cy", d => d.y);
    nodeLabel
    .attr("x", d => d.x + 8)
    .attr("y", d => d.y + 3);
    }
    function dragstarted(event) {
    if (!event.active) simulation.alphaTarget(0.3).restart();
    event.subject.fx = event.subject.x;
    event.subject.fy = event.subject.y;
    }
    function dragged(event) {
    event.subject.fx = event.x;
    event.subject.fy = event.y;
    }
    function dragended(event) {
    if (!event.active) simulation.alphaTarget(0);
    event.subject.fx = null;
    event.subject.fy = null;
    }
    const zoom = d3.zoom()
    .scaleExtent([0.1, 10])
    .on("zoom", zoomed);
    svg.call(zoom);
    function zoomed(event) {
    g.attr("transform", event.transform);
    }
    </script>
    </body>
    </html>
    '''.replace("{json_data}", payload)

    with open(html_path, 'w', encoding='utf-8') as f:
        f.write(html_content)
|
||
# Start a simple HTTP server
def start_server(port=8000, host=""):
    """Serve the current working directory over HTTP until interrupted.

    Args:
        port: TCP port to listen on. Defaults to 8000.
        host: Interface to bind to. The default ``""`` binds to every
            interface, so the directory is reachable from other machines;
            pass ``"127.0.0.1"`` to restrict it to the local host.

    This call blocks inside ``serve_forever``.
    """
    handler = http.server.SimpleHTTPRequestHandler
    # Bind manually so allow_reuse_address can be set first; otherwise a quick
    # restart can fail with "Address already in use" while the previous
    # socket is still lingering.
    with socketserver.TCPServer((host, port), handler, bind_and_activate=False) as httpd:
        httpd.allow_reuse_address = True
        httpd.server_bind()
        httpd.server_activate()
        print(f"Server started at http://localhost:{port}")
        httpd.serve_forever()
|
||
# Main entry point
def visualize_graphml(graphml_file, html_path):
    """Render a GraphML file as an HTML page and open it in the browser.

    Args:
        graphml_file: Path of the GraphML file to visualise.
        html_path: Path of the HTML file to write. It has to be located under
            the current working directory, since that is the directory the
            background HTTP server serves.

    The call blocks until the server thread stops or Ctrl+C is pressed.
    """
    from urllib.parse import quote

    json_data = graphml_to_json(graphml_file)
    create_html(json_data, html_path)

    # Start the server in the background
    server_thread = threading.Thread(target=start_server)
    server_thread.daemon = True
    server_thread.start()

    # Open the default browser on the page that was actually written, rather
    # than on a fixed file name.
    page = quote(os.path.relpath(html_path).replace(os.sep, "/"))
    webbrowser.open(f'http://localhost:8000/{page}')

    print("Visualization is ready. Press Ctrl+C to exit.")
    try:
        # Keep the main thread alive without spinning the CPU, and stop
        # waiting if the server thread has died (e.g. the port was taken).
        while server_thread.is_alive():
            server_thread.join(timeout=1.0)
    except KeyboardInterrupt:
        print("Shutting down...")
|
||
# Usage example
if __name__ == "__main__":
    # Replace with the path to your GraphML file. The path is assembled with
    # os.path.join because a literal backslash separator only works on Windows.
    graphml_file = os.path.join(
        "nano_graphrag_cache_azure_openai_TEST",
        "graph_chunk_entity_relation.graphml",
    )
    html_path = "graph_visualization.html"
    visualize_graphml(graphml_file, html_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters