Skip to content

Commit

Permalink
support text2gql search for GraphRAG
Browse files Browse the repository at this point in the history
  • Loading branch information
SonglinLyu committed Dec 20, 2024
1 parent a801a97 commit d4113a3
Show file tree
Hide file tree
Showing 18 changed files with 13,768 additions and 54 deletions.
10 changes: 10 additions & 0 deletions dbgpt/rag/transformer/awel_intent_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""AWELIntentInterpreter class."""
import logging

from dbgpt.rag.transformer.base import TranslatorBase

logger = logging.getLogger(__name__)


class AWELIntentInterpreter(TranslatorBase):
"""AWELIntentInterpreter class."""
4 changes: 4 additions & 0 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,7 @@ async def batch_extract(

class TranslatorBase(TransformerBase, ABC):
"""Translator base class."""

@abstractmethod
async def translate(self, text: str, limit: Optional[int] = None) -> Dict:
"""Translate results from text."""
123 changes: 123 additions & 0 deletions dbgpt/rag/transformer/intent_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""IntentInterpreter class."""
import logging, re, json
from typing import Dict, Optional

from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.llm_translator import LLMTranslator

INTENT_INTERPRET_PT = (
"A question is provided below. Given the question, analyze and classify it into one of the following categories:\n"
"1. Single Entity Search: search for the detail of the given entity.\n"
"2. One Hop Entity Search: given one entity and one relation, "
"search for all entities that have the relation with the given entity.\n"
"3. One Hop Relation Search: given two entities, serach for the relation between them.\n"
"4. Two Hop Entity Search: given one entity and one relation, break that relation into two consecutive relation, "
"then search all entities that have the two hop relation with the given entity.\n"
"5. Freestyle Question: questions that are not in above four categories. "
"Search all related entities and two-hop subgraphs centered on them.\n"
"After classfied the given question, rewrite the question in a graph query language style, "
"return the category of the given question, the rewrited question in json format."
"Also return entities and relations that might be used for query generation in json format."
"Here are some examples to guide your classification:\n"
"---------------------\n"
"Example:\n"
"Question: Introduce TuGraph.\n"
"Return:\n{{\"category\": \"Single Entity Search\", \"rewrited_question\": \"Query the entity named TuGraph then return the entity.\", "
"\"entities\": [\"TuGraph\"], \"relations\": []}}\n"
"Question: Who commits code to TuGraph.\n"
"Return:\n{{\"category\": \"One Hop Entity Search\", \"rewrited_question\": \"Query all one hop paths that has a entity named TuGraph and a relation named commit, then return them.\", "
"\"entities\": [\"TuGraph\"], \"relations\": [\"commit\"]}}\n"
"Question: What is the relation between Alex and TuGraph?\n"
"Return:\n{{\"category\": \"One Hop Relation Search\", \"rewrited_question\": \"Query all one hop paths between the entity named Alex and the entity named TuGraph, then return them.\", "
"\"entities\": [\"Alex\", \"TuGraph\"], \"relations\": []}}\n"
"Question: Who is the colleague of Bob?\n"
"Return:\n{{\"category\": \"Two Hop Entity Search\", \"rewrited_question\": \"Query all entities that have a two hop path between them and the entity named Bob, both entities should have a work for relation with the middle entity.\", "
"\"entities\": [\"Bob\"], \"relations\": [\"work for\"]}}\n"
"Question: Introduce TuGraph and DBGPT seperately.\n"
"Return:\n{{\"category\": \"Freestyle Question\", \"rewrited_question\": \"Query the entity named TuGraph and the entity named DBGPT, then return two-hop subgraphs centered on them.\", "
"\"entities\": [\"TuGraph\", \"DBGPT\"], \"relations\": []}}\n"
"---------------------\n"
"Text: {text}\n"
"Keywords:\n"
)

logger = logging.getLogger(__name__)


class IntentInterpreter(LLMTranslator):
"""IntentInterpreter class."""

def __init__(self, llm_client: LLMClient, model_name: str):
"""Initialize the IntentInterpreter."""
super().__init__(llm_client, model_name, INTENT_INTERPRET_PT)

async def _translate(
self, text: str, history: str = None, limit: Optional[int] = None, type: Optional[str] = "PROMPT"
) -> Dict:
"""Inner translate by LLM."""

"""
The returned diction should contain the following content.
{
"category": "Type of the given question.",
"original_question: "The original question provided by user.",
"rewrited_question": "Question that has been rewritten in graph query language style."
"entities": ["entities", "that", "might", "be", "used", "in", "query"],
"relations" ["relations", "that", "might", "be", "used", "in", "query"]
}
"""

# interprete intent with single prompt only.
template = HumanPromptTemplate.from_template(self._prompt_template)

messages = (
template.format_messages(text=text, history=history)
if history is not None
else template.format_messages(text=text)
)

# use default model if needed
if not self._model_name:
models = await self._llm_client.models()
if not models:
raise Exception("No models available")
self._model_name = models[0].model
logger.info(f"Using model {self._model_name} to extract")

model_messages = ModelMessage.from_base_messages(messages)
request = ModelRequest(model=self._model_name, messages=model_messages)
response = await self._llm_client.generate(request=request)

if not response.success:
code = str(response.error_code)
reason = response.text
logger.error(f"request llm failed ({code}) {reason}")
return []

if limit and limit < 1:
ValueError("optional argument limit >= 1")
return self._parse_response(response.text, limit)

def truncate(self):
"""Do nothing by default."""

def drop(self):
"""Do nothing by default."""

def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
"""Parse llm response."""
intention = text

code_block_pattern = re.compile(r'```json(.*?)```', re.S)
json_pattern = re.compile(r'{.*?}', re.S)

result = re.findall(code_block_pattern, intention)
if result:
intention = result[0]
result = re.findall(json_pattern, intention)
if result:
intention = result[0]
else:
intention = ""

return json.loads(intention)
41 changes: 41 additions & 0 deletions dbgpt/rag/transformer/llm_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""LLMTranslator class."""

import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Dict, Optional

from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.base import TranslatorBase

logger = logging.getLogger(__name__)


class LLMTranslator(TranslatorBase, ABC):
"""LLMTranslator class."""

def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
"""Initialize the LLMExtractor."""
self._llm_client = llm_client
self._model_name = model_name
self._prompt_template = prompt_template

async def translate(self, text: str, limit: Optional[int] = None) -> Dict:
"""Translate by LLM."""
return await self._translate(text, None, limit)

@abstractmethod
async def _translate(
self, text: str, history: str = None, limit: Optional[int] = None
) -> Dict:
"""Inner translate by LLM."""

def truncate(self):
"""Do nothing by default."""

def drop(self):
"""Do nothing by default."""

@abstractmethod
def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
"""Parse llm response."""
10 changes: 10 additions & 0 deletions dbgpt/rag/transformer/mas_intent_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""MASIntentInterpreter class."""
import logging

from dbgpt.rag.transformer.base import TranslatorBase

logger = logging.getLogger(__name__)


class MASIntentInterpreter(TranslatorBase):
"""MASIntentInterpreter class."""
133 changes: 130 additions & 3 deletions dbgpt/rag/transformer/text2cypher.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,137 @@
"""Text2Cypher class."""
import logging
import logging, re, json
from typing import Dict, Optional

from dbgpt.rag.transformer.base import TranslatorBase
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.llm_translator import LLMTranslator
from dbgpt.rag.transformer.intent_interpreter import IntentInterpreter

TEXT_TO_CYPHER_PT = (
"A question written in graph query language style is provided below. "
"The category of this question, "
"entities and relations that might be used in the cypher query are also provided. "
"Given the question, translate the question into a cypher query that "
"can be executed on the given knowledge graph. "
"Make sure the syntax of the translated cypher query is correct.\n"
"To help query generation, the schema of the knowledge graph is:\n"
"{schema}\n"
"---------------------\n"
"Example:\n"
"Question: Query the entity named TuGraph then return the entity.\n"
"Category: Single Entity Search\n"
"entities: [\"TuGraph\"]\n"
"relations: []\n"
"Query:\nMatch (n) WHERE n.id=\"TuGraph\" RETURN n\n"
"Question: Query all one hop paths between the entity named Alex and the entity named TuGraph, then return them.\n"
"Category: One Hop Entity Search\n"
"entities: [\"Alex\", \"TuGraph\"]\n"
"relations: []\n"
"Query:\nMATCH p=(n)-[r]-(m) WHERE n.id=\"Alex\" AND m.id=\"TuGraph\" RETURN p \n"
"Question: Query all one hop paths that has a entity named TuGraph and a relation named commit, then return them.\n"
"Category: One Hop Relation Search\n"
"entities: [\"TuGraph\"]\n"
"relations: [\"commit\"]\n"
"Query:\nMATCH p=(n)-[r]-(m) WHERE n.id=\"TuGraph\" AND r.id=\"commit\" RETURN p \n"
"Question: Query all entities that have a two hop path between them and the entity named Bob, both entities should have a work for relation with the middle entity.\n"
"Category: Two Hop Entity Search\n"
"entities: [\"Bob\"]\n"
"relations: [\"work for\"]\n"
"Query:\nMATCH p=(n)-[r1]-(m)-[r2]-(l) WHERE n.id=\"Bob\" AND r1.id=\"work for\" AND r2.id=\"work for\" RETURN p \n"
"Question: Introduce TuGraph and DBGPT seperately.\n"
"Category: Freestyle Question\n"
"entities: [\"TuGraph\", \"DBGPT\"]\n"
"relations: []\n"
"Query:\nMATCH p=(n)-[r:relation*2]-(m) WHERE n.id IN [\"TuGraph\", \"DB-GPT\"] RETURN p\n"
"---------------------\n"
"Question: {question}\n"
"Category: {category}\n"
"entities: {entities}\n"
"relations: {relations}\n"
"Query:\n"
)

logger = logging.getLogger(__name__)


class Text2Cypher(TranslatorBase):
class Text2Cypher(LLMTranslator):
"""Text2Cypher class."""

def __init__(self, llm_client: LLMClient, model_name: str, schema: str):
"""Initialize the Text2Cypher."""
super().__init__(llm_client, model_name, TEXT_TO_CYPHER_PT)
self._schema = json.dumps(json.loads(schema), indent=4)
self._intent_interpreter = IntentInterpreter(llm_client, model_name)

async def _translate(
self, text: str, history: str = None, limit: Optional[int] = None
) -> Dict:
"""Inner translate by LLM."""

"""Interprete the intent of the question."""
intention = await self._intent_interpreter.translate(text)
question = intention["rewrited_question"]
category = intention["category"]
entities = intention["entities"]
relations = intention["relations"]

"""Translate query with intention."""
template = HumanPromptTemplate.from_template(self._prompt_template)

messages = (
template.format_messages(
schema=self._schema,
question=question,
category=category,
entities=entities,
relations=relations,
history=history
)
if history is not None
else template.format_messages(
schema=self._schema,
question=question,
category=category,
entities=entities,
relations=relations
)
)

# use default model if needed
if not self._model_name:
models = await self._llm_client.models()
if not models:
raise Exception("No models available")
self._model_name = models[0].model
logger.info(f"Using model {self._model_name} to extract")

model_messages = ModelMessage.from_base_messages(messages)
request = ModelRequest(model=self._model_name, messages=model_messages)
response = await self._llm_client.generate(request=request)

if not response.success:
code = str(response.error_code)
reason = response.text
logger.error(f"request llm failed ({code}) {reason}")
return []

if limit and limit < 1:
ValueError("optional argument limit >= 1")
return self._parse_response(response.text, limit)


def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
"""Parse llm response."""
interaction = {}
query = ""

code_block_pattern = re.compile(r'```cypher(.*?)```', re.S)

result = re.findall(code_block_pattern, text)
if result:
query = result[0]
else:
query = text

interaction["query"] = query.strip()

return interaction
7 changes: 7 additions & 0 deletions dbgpt/storage/knowledge_graph/community/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,10 @@ async def truncate(self):
@abstractmethod
def drop(self):
"""Drop community metastore."""

class GraphSyntaxValidator(ABC):
"""Community Syntax Validator."""

@abstractmethod
def validate(self, query: str) -> bool:
"""Validate query syntax."""
Loading

0 comments on commit d4113a3

Please sign in to comment.