Skip to content

Commit

Permalink
feat(agent):Fix agent bug (eosphoros-ai#1953)
Browse files Browse the repository at this point in the history
Co-authored-by: aries_ckt <[email protected]>
  • Loading branch information
yhjun1026 and Aries-ckt authored Sep 4, 2024
1 parent d72bfb2 commit b951b50
Show file tree
Hide file tree
Showing 18 changed files with 67 additions and 46 deletions.
4 changes: 2 additions & 2 deletions dbgpt/agent/core/action/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def to_dict(self) -> Dict[str, Any]:
class Action(ABC, Generic[T]):
"""Base Action class for defining agent actions."""

def __init__(self):
def __init__(self, language: str = "en"):
"""Create an action."""
self.resource: Optional[Resource] = None
self.language: str = "en"
self.language: str = language

def init_resource(self, resource: Optional[Resource]):
"""Initialize the resource."""
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/agent/core/action/blank_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
class BlankAction(Action):
"""Blank action class."""

def __init__(self):
"""Create a blank action."""
super().__init__()
def __init__(self, **kwargs):
"""Blank action init."""
super().__init__(**kwargs)

@property
def ai_out_schema(self) -> Optional[str]:
Expand Down
4 changes: 0 additions & 4 deletions dbgpt/agent/core/agent_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,13 @@ def all_agents(self) -> Dict[str, str]:
def list_agents(self):
"""Return a list of all registered agents and their descriptions."""
result = []
from datetime import datetime

logger.info(f"List Agent Begin:{datetime.now()}")
for name, value in self._agents.items():
result.append(
{
"name": value[1].role,
"desc": value[1].goal,
}
)
logger.info(f"List Agent End:{datetime.now()}")
return result


Expand Down
2 changes: 1 addition & 1 deletion dbgpt/agent/core/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def _init_actions(self, actions: List[Type[Action]]):
self.actions = []
for idx, action in enumerate(actions):
if issubclass(action, Action):
self.actions.append(action())
self.actions.append(action(language=self.language))

async def _a_append_message(
self, message: AgentMessage, role, sender: Agent
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/agent/core/plan/plan_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class PlanAction(Action[List[PlanInput]]):

def __init__(self, **kwargs):
"""Create a plan action."""
super().__init__()
super().__init__(**kwargs)
self._render_protocol = VisAgentPlans()

@property
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/agent/expand/actions/chart_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class SqlInput(BaseModel):
class ChartAction(Action[SqlInput]):
"""Chart action class."""

def __init__(self):
"""Create a chart action."""
super().__init__()
def __init__(self, **kwargs):
"""Chart action init."""
super().__init__(**kwargs)
self._render_protocol = VisChart()

@property
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/agent/expand/actions/code_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
class CodeAction(Action[None]):
"""Code Action Module."""

def __init__(self):
"""Create a code action."""
super().__init__()
def __init__(self, **kwargs):
"""Code action init."""
super().__init__(**kwargs)
self._render_protocol = VisCode()
self._code_execution_config = {}

Expand Down
6 changes: 3 additions & 3 deletions dbgpt/agent/expand/actions/dashboard_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def to_dict(self):
class DashboardAction(Action[List[ChartItem]]):
"""Dashboard action class."""

def __init__(self):
"""Create a dashboard action."""
super().__init__()
def __init__(self, **kwargs):
"""Dashboard action init."""
super().__init__(**kwargs)
self._render_protocol = VisDashboard()

@property
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/agent/expand/actions/indicator_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ class IndicatorInput(BaseModel):
class IndicatorAction(Action[IndicatorInput]):
"""Indicator Action."""

def __init__(self):
"""Init Indicator Action."""
super().__init__()
def __init__(self, **kwargs):
"""Init indicator action."""
super().__init__(**kwargs)
self._render_protocol = VisApiResponse()

@property
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/agent/expand/actions/tool_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class ToolInput(BaseModel):
class ToolAction(Action[ToolInput]):
"""Tool action class."""

def __init__(self):
"""Create a plugin action."""
super().__init__()
def __init__(self, **kwargs):
"""Tool action init."""
super().__init__(**kwargs)
self._render_protocol = VisPlugin()

@property
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/agent/resource/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ async def get_prompt(
prompt_template = f"\nResources-{self.name}:\n {content}"
prompt_template_zh = f"\n资源-{self.name}:\n {content}"
if lang == "en":
return prompt_template.format(content=content), self._get_references(chunks)
return prompt_template_zh.format(content=content), self._get_references(chunks)
return prompt_template, self._get_references(chunks)
return prompt_template_zh, self._get_references(chunks)

async def get_resources(
self,
Expand Down
11 changes: 11 additions & 0 deletions dbgpt/app/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,17 @@ async def chat_completions(
headers=headers,
media_type="text/plain",
)
except Exception as e:
logger.exception(f"Chat Exception!{dialogue}", e)

async def error_text(err_msg):
yield f"data:{err_msg}\n\n"

return StreamingResponse(
error_text(str(e)),
headers=headers,
media_type="text/plain",
)
finally:
# write to recent usage app.
if dialogue.user_name is not None and dialogue.app_code is not None:
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __init__(self, chat_param: Dict):
"""

self.select_param = chat_param["select_param"]
if not self.select_param:
raise ValueError("Please upload the Excel document you want to talk to!")
self.model_name = chat_param["model_name"]
chat_param["chat_mode"] = ChatScene.ChatExcel
self.chat_param = chat_param
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_data/chat_excel/excel_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def is_chinese(text):


class ExcelReader:
def __init__(self, conv_uid, file_param):
def __init__(self, conv_uid: str, file_param: str):
self.conv_uid = conv_uid
self.file_param = file_param
if isinstance(file_param, str) and os.path.isabs(file_param):
Expand Down
9 changes: 7 additions & 2 deletions dbgpt/datasource/manages/connect_config_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,15 @@ def get_db_config(self, db_name):
def get_db_list(self, db_name: Optional[str] = None, user_id: Optional[str] = None):
"""Get db list."""
session = self.get_raw_session()
if db_name:
if db_name and user_id:
sql = f"SELECT * FROM connect_config where (user_id='{user_id}' or user_id='' or user_id IS NULL) and db_name='{db_name}'" # noqa
else:
elif user_id:
sql = f"SELECT * FROM connect_config where user_id='{user_id}' or user_id='' or user_id IS NULL" # noqa
elif db_name:
sql = f"SELECT * FROM connect_config where db_name='{db_name}'" # noqa
else:
sql = f"SELECT * FROM connect_config" # noqa

result = session.execute(text(sql))
fields = [field[0] for field in result.cursor.description] # type: ignore
data = []
Expand Down
3 changes: 2 additions & 1 deletion dbgpt/serve/agent/agents/expand/actions/app_start_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def run(
**kwargs,
) -> ActionOutput:
conv_id = kwargs.get("conv_id")
user_input = kwargs.get("user_input")
paren_agent = kwargs.get("paren_agent")
init_message_rounds = kwargs.get("init_message_rounds")

Expand Down Expand Up @@ -83,7 +84,7 @@ async def run(
from dbgpt.serve.agent.agents.controller import multi_agents

await multi_agents.agent_team_chat_new(
new_user_input,
new_user_input if new_user_input else user_input,
conv_id,
gpts_app,
paren_agent.memory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,28 @@ def out_model_type(self):

@property
def ai_out_schema(self) -> Optional[str]:
out_put_schema = {
"intent": "[The recognized intent is placed here]",
"app_code": "[App code in selected intent]",
"slots": {"意图定义中槽位属性1": "具体值", "意图定义中槽位属性2": "具体值"},
"ask_user": "If you want the user to supplement slot data, ask the user a question",
"user_input": "[Complete instructions generated based on intent and slot]",
}
if self.language == "en":
out_put_schema = {
"intent": "[The recognized intent is placed here]",
"app_code": "[App code in selected intent]",
"slots": {
"Slot attribute 1 in intent definition": "value",
"Slot attribute 2 in intent definition": "value",
},
"ask_user": "[If you want the user to supplement slot data, ask the user a question]",
"user_input": "[Complete instructions generated based on intent and slot]",
}
return f"""Please reply in the following json format:
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
Make sure the output is only json and can be parsed by Python json.loads.""" # noqa: E501
else:
out_put_schema = {
"intent": "选择的意图放在这里",
"app_code": "选择意图对应的Appcode值",
"slots": {"意图定义中槽位属性1": "具体值", "意图定义中槽位属性2": "具体值"},
"ask_user": "如果需要用户补充槽位属性的具体值,请向用户进行提问",
"user_input": "根据意图和槽位生成完整指令问题",
}
return f"""请按如下JSON格式输出:
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
确保输出只有json,且可以被python json.loads加载."""
Expand Down
10 changes: 3 additions & 7 deletions dbgpt/storage/knowledge_graph/community/community_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,17 @@ def __init__(

async def build_communities(self):
"""Discover communities."""
community_ids = await (self._community_store_adapter.discover_communities())
community_ids = await self._community_store_adapter.discover_communities()

# summarize communities
communities = []
for community_id in community_ids:
community = await (
self._community_store_adapter.get_community(community_id)
)
community = await self._community_store_adapter.get_community(community_id)
graph = community.data.format()
if not graph:
break

community.summary = await (
self._community_summarizer.summarize(graph=graph)
)
community.summary = await self._community_summarizer.summarize(graph=graph)
communities.append(community)
logger.info(
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
Expand Down

0 comments on commit b951b50

Please sign in to comment.