-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support text2gql search for GraphRAG
- Loading branch information
1 parent
a801a97
commit d4113a3
Showing
18 changed files
with
13,768 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""AWELIntentInterpreter class.""" | ||
import logging | ||
|
||
from dbgpt.rag.transformer.base import TranslatorBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AWELIntentInterpreter(TranslatorBase):
    """Intent interpreter for AWEL-based workflows.

    NOTE(review): currently an empty placeholder — it inherits TranslatorBase
    without overriding anything; implementation presumably to follow.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
"""IntentInterpreter class.""" | ||
import logging, re, json | ||
from typing import Dict, Optional | ||
|
||
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest | ||
from dbgpt.rag.transformer.llm_translator import LLMTranslator | ||
|
||
# Prompt that classifies a question into one of five graph-search categories
# and rewrites it in a graph-query style. Double braces escape literal JSON
# braces for the prompt template; {text} is the only template variable.
# NOTE: the JSON key "rewrited_question" is read verbatim by downstream code,
# so its (misspelled) name must not change.
INTENT_INTERPRET_PT = (
    "A question is provided below. Given the question, analyze and classify it into one of the following categories:\n"
    "1. Single Entity Search: search for the detail of the given entity.\n"
    "2. One Hop Entity Search: given one entity and one relation, "
    "search for all entities that have the relation with the given entity.\n"
    "3. One Hop Relation Search: given two entities, search for the relation between them.\n"
    "4. Two Hop Entity Search: given one entity and one relation, break that relation into two consecutive relation, "
    "then search all entities that have the two hop relation with the given entity.\n"
    "5. Freestyle Question: questions that are not in above four categories. "
    "Search all related entities and two-hop subgraphs centered on them.\n"
    "After classifying the given question, rewrite the question in a graph query language style, "
    "return the category of the given question, the rewritten question in json format. "
    "Also return entities and relations that might be used for query generation in json format. "
    "Here are some examples to guide your classification:\n"
    "---------------------\n"
    "Example:\n"
    "Question: Introduce TuGraph.\n"
    "Return:\n{{\"category\": \"Single Entity Search\", \"rewrited_question\": \"Query the entity named TuGraph then return the entity.\", "
    "\"entities\": [\"TuGraph\"], \"relations\": []}}\n"
    "Question: Who commits code to TuGraph.\n"
    "Return:\n{{\"category\": \"One Hop Entity Search\", \"rewrited_question\": \"Query all one hop paths that has a entity named TuGraph and a relation named commit, then return them.\", "
    "\"entities\": [\"TuGraph\"], \"relations\": [\"commit\"]}}\n"
    "Question: What is the relation between Alex and TuGraph?\n"
    "Return:\n{{\"category\": \"One Hop Relation Search\", \"rewrited_question\": \"Query all one hop paths between the entity named Alex and the entity named TuGraph, then return them.\", "
    "\"entities\": [\"Alex\", \"TuGraph\"], \"relations\": []}}\n"
    "Question: Who is the colleague of Bob?\n"
    "Return:\n{{\"category\": \"Two Hop Entity Search\", \"rewrited_question\": \"Query all entities that have a two hop path between them and the entity named Bob, both entities should have a work for relation with the middle entity.\", "
    "\"entities\": [\"Bob\"], \"relations\": [\"work for\"]}}\n"
    "Question: Introduce TuGraph and DBGPT separately.\n"
    "Return:\n{{\"category\": \"Freestyle Question\", \"rewrited_question\": \"Query the entity named TuGraph and the entity named DBGPT, then return two-hop subgraphs centered on them.\", "
    "\"entities\": [\"TuGraph\", \"DBGPT\"], \"relations\": []}}\n"
    "---------------------\n"
    # Final labels now match the examples ("Question:"/"Return:"); the
    # original ended with "Text:"/"Keywords:", left over from the keyword
    # extractor prompt this was adapted from.
    "Question: {text}\n"
    "Return:\n"
)
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class IntentInterpreter(LLMTranslator):
    """Classify a question and rewrite it in a graph-query style.

    Uses ``INTENT_INTERPRET_PT`` to ask the LLM for a JSON object of the
    form::

        {
            "category": "Type of the given question.",
            "rewrited_question": "Question rewritten in graph query style.",
            "entities": ["entities", "for", "query"],
            "relations": ["relations", "for", "query"]
        }
    """

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Initialize the IntentInterpreter."""
        super().__init__(llm_client, model_name, INTENT_INTERPRET_PT)

    async def _translate(
        self,
        text: str,
        history: Optional[str] = None,
        limit: Optional[int] = None,
        type: Optional[str] = "PROMPT",  # unused; kept for call compatibility
    ) -> Dict:
        """Interpret the intent of ``text`` with a single LLM call.

        Args:
            text: the user question.
            history: optional conversation history injected into the prompt.
            limit: optional result limit; must be >= 1 when given.
            type: unused. NOTE(review): not part of the base-class
                signature — confirm whether it can be dropped.

        Returns:
            The parsed intent dict, or an empty dict when the LLM call fails.

        Raises:
            ValueError: if ``limit`` is given and < 1.
        """
        # Validate before spending an LLM call. The original built the
        # request first and constructed (but never raised) the ValueError.
        if limit is not None and limit < 1:
            raise ValueError("optional argument limit >= 1")

        # Interpret intent with a single prompt only.
        template = HumanPromptTemplate.from_template(self._prompt_template)
        messages = (
            template.format_messages(text=text, history=history)
            if history is not None
            else template.format_messages(text=text)
        )

        # Use the first available model if none was configured.
        if not self._model_name:
            models = await self._llm_client.models()
            if not models:
                raise Exception("No models available")
            self._model_name = models[0].model
            logger.info(f"Using model {self._model_name} to extract")

        model_messages = ModelMessage.from_base_messages(messages)
        request = ModelRequest(model=self._model_name, messages=model_messages)
        response = await self._llm_client.generate(request=request)

        if not response.success:
            logger.error(f"request llm failed ({response.error_code}) {response.text}")
            # Keep the declared Dict return type on failure (was `[]`).
            return {}

        return self._parse_response(response.text, limit)

    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""

    def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
        """Extract the intent JSON from the raw LLM response.

        Prefers a ```json fenced block, then falls back to the first
        brace-delimited span. NOTE(review): the non-greedy ``{.*?}`` stops at
        the first ``}``, so nested JSON would be truncated — acceptable for
        the flat intent schema this prompt requests, but confirm.
        """
        intention = text

        code_block_pattern = re.compile(r"```json(.*?)```", re.S)
        json_pattern = re.compile(r"{.*?}", re.S)

        fenced = code_block_pattern.findall(intention)
        if fenced:
            intention = fenced[0]

        bodies = json_pattern.findall(intention)
        if not bodies:
            # Was json.loads("") -> JSONDecodeError; fail soft instead.
            logger.error(f"no JSON object found in llm response: {text}")
            return {}

        return json.loads(bodies[0])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""LLMTranslator class.""" | ||
|
||
import asyncio | ||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import Dict, Optional | ||
|
||
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest | ||
from dbgpt.rag.transformer.base import TranslatorBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LLMTranslator(TranslatorBase, ABC):
    """Abstract base for translators that delegate to an LLM.

    Subclasses supply a prompt template and implement ``_translate`` (the
    LLM round-trip) and ``_parse_response`` (turning the raw LLM text into
    a result dict).
    """

    def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
        """Initialize the LLMTranslator.

        Args:
            llm_client: client used to issue generate requests.
            model_name: model to use; may be empty, in which case subclasses
                fall back to the first available model.
            prompt_template: prompt text with template variables.
        """
        self._llm_client = llm_client
        self._model_name = model_name
        self._prompt_template = prompt_template

    async def translate(self, text: str, limit: Optional[int] = None) -> Dict:
        """Translate ``text`` by LLM and return the parsed result dict."""
        # No history is threaded through the public entry point.
        return await self._translate(text, None, limit)

    @abstractmethod
    async def _translate(
        self, text: str, history: Optional[str] = None, limit: Optional[int] = None
    ) -> Dict:
        """Inner translate by LLM."""

    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""

    @abstractmethod
    def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
        """Parse llm response."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""MASIntentInterpreter class.""" | ||
import logging | ||
|
||
from dbgpt.rag.transformer.base import TranslatorBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class MASIntentInterpreter(TranslatorBase):
    """Intent interpreter for multi-agent-system (MAS) scenarios.

    NOTE(review): currently an empty placeholder — it inherits TranslatorBase
    without overriding anything; implementation presumably to follow.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,137 @@ | ||
"""Text2Cypher class.""" | ||
import logging | ||
import logging, re, json | ||
from typing import Dict, Optional | ||
|
||
from dbgpt.rag.transformer.base import TranslatorBase | ||
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest | ||
from dbgpt.rag.transformer.llm_translator import LLMTranslator | ||
from dbgpt.rag.transformer.intent_interpreter import IntentInterpreter | ||
|
||
# Prompt that turns an interpreted question (category plus candidate
# entities/relations) into a cypher query. {schema}, {question}, {category},
# {entities} and {relations} are template variables.
TEXT_TO_CYPHER_PT = (
    "A question written in graph query language style is provided below. "
    "The category of this question, "
    "entities and relations that might be used in the cypher query are also provided. "
    "Given the question, translate the question into a cypher query that "
    "can be executed on the given knowledge graph. "
    "Make sure the syntax of the translated cypher query is correct.\n"
    "To help query generation, the schema of the knowledge graph is:\n"
    "{schema}\n"
    "---------------------\n"
    "Example:\n"
    "Question: Query the entity named TuGraph then return the entity.\n"
    "Category: Single Entity Search\n"
    "entities: [\"TuGraph\"]\n"
    "relations: []\n"
    "Query:\nMatch (n) WHERE n.id=\"TuGraph\" RETURN n\n"
    # One entity + one relation -> One Hop Entity Search. The original prompt
    # had this label and the One Hop Relation Search label swapped relative
    # to the category definitions used by the intent interpreter.
    "Question: Query all one hop paths that has a entity named TuGraph and a relation named commit, then return them.\n"
    "Category: One Hop Entity Search\n"
    "entities: [\"TuGraph\"]\n"
    "relations: [\"commit\"]\n"
    "Query:\nMATCH p=(n)-[r]-(m) WHERE n.id=\"TuGraph\" AND r.id=\"commit\" RETURN p \n"
    # Two entities -> One Hop Relation Search.
    "Question: Query all one hop paths between the entity named Alex and the entity named TuGraph, then return them.\n"
    "Category: One Hop Relation Search\n"
    "entities: [\"Alex\", \"TuGraph\"]\n"
    "relations: []\n"
    "Query:\nMATCH p=(n)-[r]-(m) WHERE n.id=\"Alex\" AND m.id=\"TuGraph\" RETURN p \n"
    "Question: Query all entities that have a two hop path between them and the entity named Bob, both entities should have a work for relation with the middle entity.\n"
    "Category: Two Hop Entity Search\n"
    "entities: [\"Bob\"]\n"
    "relations: [\"work for\"]\n"
    "Query:\nMATCH p=(n)-[r1]-(m)-[r2]-(l) WHERE n.id=\"Bob\" AND r1.id=\"work for\" AND r2.id=\"work for\" RETURN p \n"
    "Question: Introduce TuGraph and DBGPT separately.\n"
    "Category: Freestyle Question\n"
    "entities: [\"TuGraph\", \"DBGPT\"]\n"
    "relations: []\n"
    # Entity id matches the entities list above ("DBGPT", not "DB-GPT").
    "Query:\nMATCH p=(n)-[r:relation*2]-(m) WHERE n.id IN [\"TuGraph\", \"DBGPT\"] RETURN p\n"
    "---------------------\n"
    "Question: {question}\n"
    "Category: {category}\n"
    "entities: {entities}\n"
    "relations: {relations}\n"
    "Query:\n"
)
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Text2Cypher(LLMTranslator):
    """Translate a natural-language question into a cypher query.

    Two-stage pipeline: an ``IntentInterpreter`` first classifies the
    question and rewrites it in graph-query style, then
    ``TEXT_TO_CYPHER_PT`` turns the interpreted intent into a cypher query
    against the provided schema.
    """

    def __init__(self, llm_client: LLMClient, model_name: str, schema: str):
        """Initialize the Text2Cypher.

        Args:
            llm_client: client used to issue generate requests.
            model_name: model to use for both stages.
            schema: JSON string describing the knowledge-graph schema;
                re-serialized pretty-printed for the prompt.
        """
        super().__init__(llm_client, model_name, TEXT_TO_CYPHER_PT)
        self._schema = json.dumps(json.loads(schema), indent=4)
        self._intent_interpreter = IntentInterpreter(llm_client, model_name)

    async def _translate(
        self, text: str, history: Optional[str] = None, limit: Optional[int] = None
    ) -> Dict:
        """Translate ``text`` into a cypher query via intent interpretation.

        Returns:
            ``{"query": <cypher>}``, or an empty dict when the LLM call fails.

        Raises:
            ValueError: if ``limit`` is given and < 1.
            KeyError: if the intent interpreter response misses a field.
        """
        # Validate before spending two LLM calls. The original constructed
        # (but never raised) the ValueError after the generation request.
        if limit is not None and limit < 1:
            raise ValueError("optional argument limit >= 1")

        # Stage 1: interpret the intent of the question.
        intention = await self._intent_interpreter.translate(text)
        question = intention["rewrited_question"]
        category = intention["category"]
        entities = intention["entities"]
        relations = intention["relations"]

        # Stage 2: translate the query with the interpreted intention.
        template = HumanPromptTemplate.from_template(self._prompt_template)
        format_args = {
            "schema": self._schema,
            "question": question,
            "category": category,
            "entities": entities,
            "relations": relations,
        }
        if history is not None:
            format_args["history"] = history
        messages = template.format_messages(**format_args)

        # Use the first available model if none was configured.
        if not self._model_name:
            models = await self._llm_client.models()
            if not models:
                raise Exception("No models available")
            self._model_name = models[0].model
            logger.info(f"Using model {self._model_name} to extract")

        model_messages = ModelMessage.from_base_messages(messages)
        request = ModelRequest(model=self._model_name, messages=model_messages)
        response = await self._llm_client.generate(request=request)

        if not response.success:
            logger.error(f"request llm failed ({response.error_code}) {response.text}")
            # Keep the declared Dict return type on failure (was `[]`).
            return {}

        return self._parse_response(response.text, limit)

    def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
        """Extract the cypher query from the LLM response.

        Prefers a ```cypher fenced block; otherwise treats the whole
        response as the query. Returns ``{"query": <stripped query>}``.
        """
        code_block_pattern = re.compile(r"```cypher(.*?)```", re.S)

        matches = code_block_pattern.findall(text)
        query = matches[0] if matches else text

        return {"query": query.strip()}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.