From 86a928762e20d99c917f8ac8df28ed5257e7ddce Mon Sep 17 00:00:00 2001 From: Gustavo Ye Date: Thu, 5 Sep 2024 15:53:48 +0800 Subject: [PATCH] feat: add naive RAG for chunks (#25) * feat: add naive RAG * docs: add naive rag --- ROADMAP.md | 2 +- nano_graphrag/_op.py | 39 +++++++++++++++++++++++++++++++++++---- nano_graphrag/base.py | 4 +++- nano_graphrag/graphrag.py | 34 +++++++++++++++++++++++++++++++++- nano_graphrag/prompt.py | 14 ++++++++++++++ readme.md | 17 +++++++++++++++++ 6 files changed, 103 insertions(+), 7 deletions(-) diff --git a/ROADMAP.md b/ROADMAP.md index 17c6b58..cd46b02 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1,7 +1,7 @@ ## Next Version TODO - [ ] Add `eval` method for `GraphRAG`, add at least one benchmark dataset -- [ ] Add `before_query`, `before_insert` to separate component loadings. +- [x] Add naive RAG pipeline ## Interesting directions diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py index bd671bb..1e346d3 100644 --- a/nano_graphrag/_op.py +++ b/nano_graphrag/_op.py @@ -4,9 +4,6 @@ from typing import Union from collections import Counter, defaultdict -from openai import AsyncOpenAI - -from ._llm import gpt_4o_complete from ._utils import ( logger, clean_str, @@ -992,4 +989,38 @@ async def global_query( report_data=points_context, response_type=query_param.response_type ), ) - return response \ No newline at end of file + return response + + +async def naive_query( + query, + chunks_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, +): + use_model_func = global_config["best_model_func"] + results = await chunks_vdb.query(query, top_k=query_param.top_k) + if not len(results): + return PROMPTS["fail_response"] + chunks_ids = [r["id"] for r in results] + chunks = await text_chunks_db.get_by_ids(chunks_ids) + + maybe_trun_chunks = truncate_list_by_token_size( + chunks, + key=lambda x: x["content"], + max_token_size=query_param.naive_max_token_for_text_unit, 
+ ) + logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") + section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) + if query_param.only_need_context: + return section + sys_prompt_temp = PROMPTS["naive_rag_response"] + sys_prompt = sys_prompt_temp.format( + content_data=section, response_type=query_param.response_type + ) + response = await use_model_func( + query, + system_prompt=sys_prompt, + ) + return response diff --git a/nano_graphrag/base.py b/nano_graphrag/base.py index b6fed83..ea9bb21 100644 --- a/nano_graphrag/base.py +++ b/nano_graphrag/base.py @@ -8,11 +8,13 @@ @dataclass class QueryParam: - mode: Literal["local", "global"] = "global" + mode: Literal["local", "global", "naive"] = "global" only_need_context: bool = False response_type: str = "Multiple Paragraphs" level: int = 2 top_k: int = 20 + # naive search + naive_max_token_for_text_unit = 12000 # local search local_max_token_for_text_unit: int = 4000 # 12000 * 0.33 local_max_token_for_local_context: int = 4800 # 12000 * 0.4 diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py index 7dd2fdb..3f663db 100644 --- a/nano_graphrag/graphrag.py +++ b/nano_graphrag/graphrag.py @@ -13,6 +13,7 @@ generate_community_report, local_query, global_query, + naive_query, ) from ._storage import ( JsonKVStorage, @@ -54,6 +55,7 @@ class GraphRAG: ) # graph mode enable_local: bool = True + enable_naive_rag: bool = False # text chunking chunk_token_size: int = 1200 @@ -151,11 +153,20 @@ def __post_init__(self): namespace="entities", global_config=asdict(self), embedding_func=self.embedding_func, - meta_fields={"entity_name"} + meta_fields={"entity_name"}, ) if self.enable_local else None ) + self.chunks_vdb = ( + self.vector_db_storage_cls( + namespace="chunks", + global_config=asdict(self), + embedding_func=self.embedding_func, + ) + if self.enable_naive_rag + else None + ) self.best_model_func = limit_async_func_call(self.best_model_max_async)( 
+ partial(self.best_model_func, hashing_kv=self.llm_response_cache) @@ -172,9 +183,15 @@ def query(self, query: str, param: QueryParam = QueryParam()): loop = always_get_an_event_loop() return loop.run_until_complete(self.aquery(query, param)) + def eval(self, querys: list[str], contexts: list[str], answers: list[str]): + loop = always_get_an_event_loop() + return loop.run_until_complete(self.aeval(querys, contexts, answers)) + async def aquery(self, query: str, param: QueryParam = QueryParam()): if param.mode == "local" and not self.enable_local: raise ValueError("enable_local is False, cannot query in local mode") + if param.mode == "naive" and not self.enable_naive_rag: + raise ValueError("enable_naive_rag is False, cannot query in naive mode") if param.mode == "local": response = await local_query( query, @@ -195,6 +212,14 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()): param, asdict(self), ) + elif param.mode == "naive": + response = await naive_query( + query, + self.chunks_vdb, + self.text_chunks, + param, + asdict(self), + ) else: raise ValueError(f"Unknown mode {param.mode}") await self._query_done() @@ -242,6 +267,9 @@ async def ainsert(self, string_or_strings): logger.warning(f"All chunks are already in the storage") return logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") + if self.enable_naive_rag: + logger.info("Insert chunks for naive RAG") + await self.chunks_vdb.upsert(inserting_chunks) # TODO: no incremental update for communities now, so just drop all await self.community_reports.drop() @@ -273,6 +301,9 @@ async def ainsert(self, string_or_strings): finally: await self._insert_done() + async def aeval(self, querys: list[str], contexts: list[str], answers: list[str]): + pass + async def _insert_done(self): tasks = [] for storage_inst in [ @@ -281,6 +312,7 @@ async def _insert_done(self): self.llm_response_cache, self.community_reports, self.entities_vdb, + self.chunks_vdb,
+ self.chunk_entity_relation_graph, ]: if storage_inst is None: diff --git a/nano_graphrag/prompt.py b/nano_graphrag/prompt.py index 30bc466..d40cf7f 100644 --- a/nano_graphrag/prompt.py +++ b/nano_graphrag/prompt.py @@ -473,6 +473,20 @@ Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. """ +PROMPTS[ + "naive_rag_response" +] = """You're a helpful assistant +Below is the knowledge you know: +{content_data} +--- +If you don't know the answer or if the provided knowledge does not contain sufficient information to provide an answer, just say so. Do not make anything up. +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. +If you don't know the answer, just say so. Do not make anything up. +Do not include information where the supporting evidence for it is not provided. +---Target response length and format--- +{response_type} +""" + PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question." PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] diff --git a/readme.md b/readme.md index 5d37626..cdc38a6 100644 --- a/readme.md +++ b/readme.md @@ -103,6 +103,23 @@ with open("./book.txt") as f: +
+ Naive RAG + +`nano-graphrag` supports naive RAG insert and query as well: + +```python +graph_func = GraphRAG(working_dir="./dickens", enable_naive_rag=True) +... +# Query +print(graph_func.query( + "What are the top themes in this story?", + param=QueryParam(mode="naive") +)) +``` +
+ + ### Async For each method `NAME(...)` , there is a corresponding async method `aNAME(...)`