fix(agent): optimize of base agent #2276

Open
wants to merge 14 commits into main
2 changes: 2 additions & 0 deletions dbgpt/agent/core/action/base.py
@@ -53,6 +53,8 @@ class ActionOutput(BaseModel):
ask_user: Optional[bool] = False
# If the current agent can determine the next speaker, specify it here
next_speakers: Optional[List[str]] = None
# Force retry; not limited by the retry count
force_retry: Optional[bool] = False

@model_validator(mode="before")
@classmethod
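For context, a minimal sketch (hypothetical action, not part of this diff) of how an Action's run() could use the new field to request another round regardless of the retry limit:

    return ActionOutput(
        is_exe_success=True,
        content="Partial result; another round is needed to refine the answer.",
        force_retry=True,  # re-enters the reply loop even if max_retry_count is reached
    )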
40 changes: 39 additions & 1 deletion dbgpt/agent/core/base_agent.py
@@ -47,6 +47,7 @@ class ConversableAgent(Role, Agent):
stream_out: bool = True
# Whether the current agent needs to display reference resources
show_reference: bool = False
name_prefix: Optional[str] = None

executor: Executor = Field(
default_factory=lambda: ThreadPoolExecutor(max_workers=1),
@@ -58,6 +59,13 @@ def __init__(self, **kwargs):
Role.__init__(self, **kwargs)
Agent.__init__(self)

@property
def name(self) -> str:
"""Return the name of the agent."""
if self.name_prefix is not None:
[Collaborator review comment] This name is difficult to understand. Maybe it's better to use . or - instead, such as: Report.ChatReport or Report-ChatReport?
return f"{self.current_profile.get_name()}[{self.name_prefix}]"
return self.current_profile.get_name()
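For example (hypothetical values), with a profile named ChatReport and name_prefix set to "Report", the property above returns:

    agent.name  # -> "ChatReport[Report]"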

def check_available(self) -> None:
"""Check if the agent is available.

@@ -347,7 +355,22 @@ async def generate_reply(
fail_reason = None
current_retry_counter = 0
is_success = True
while current_retry_counter < self.max_retry_count:
force_retry = False
while force_retry or current_retry_counter < self.max_retry_count:
# Handling of Action force_retry: the Action explicitly requested a retry, and the
# retry message is regenerated according to the following rules:
# - Regenerate the message, keep the previous round's Action, and increment the rounds
# - Use the previous round's Action content as the current input message
if force_retry:
if reply_message.action_report is None:
raise ValueError("action output is None when force_retry")
received_message.content = reply_message.action_report.content
received_message.rounds = reply_message.rounds + 1
reply_message = self._init_reply_message(
received_message=received_message,
rely_messages=rely_messages,
)

# Handling of a normal retry
if current_retry_counter > 0:
retry_message = self._init_reply_message(
received_message=received_message,
@@ -460,6 +483,21 @@ async def generate_reply(

question: str = received_message.content or ""
ai_message: str = llm_reply or ""

# force_retry means this reply is not complete;
# the agent should re-enter the loop and do more work
force_retry = False
if act_out is not None and act_out.force_retry:
await self.write_memories(
question=question,
ai_message=ai_message,
action_output=act_out,
check_pass=check_pass,
check_fail_reason=fail_reason,
)
force_retry = True
continue

# 5.Optimize wrong answers myself
if not check_pass:
if not act_out.have_retry:
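A condensed sketch of the retry loop with the new force_retry path (names follow the diff above; the normal generate/act/verify steps and the retry-counter bookkeeping outside these hunks are elided):

    force_retry = False
    current_retry_counter = 0
    while force_retry or current_retry_counter < self.max_retry_count:
        if force_retry:
            # Rebuild the input from the previous round's action output and add a round.
            received_message.content = reply_message.action_report.content
            received_message.rounds = reply_message.rounds + 1
            reply_message = self._init_reply_message(
                received_message=received_message, rely_messages=rely_messages
            )
        ...  # generate the LLM reply, run actions, verify (elided)
        force_retry = False
        if act_out is not None and act_out.force_retry:
            # The action asked for another round, even past max_retry_count.
            force_retry = True
            continue
        # normal retry / exit handling follows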
5 changes: 4 additions & 1 deletion dbgpt/agent/core/memory/gpts/default_gpts_memory.py
@@ -146,4 +146,7 @@ def get_by_conv_id(self, conv_id: str) -> List[GptsMessage]:

def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
"""Get the last message in the conversation."""
return None
messages: List[GptsMessage] = self.get_by_conv_id(conv_id)
if messages is None or len(messages) == 0:
return None
return messages[-1]
5 changes: 3 additions & 2 deletions dbgpt/agent/core/role.py
@@ -233,12 +233,13 @@ async def write_memories(
raise ValueError("Action output is required to save to memory.")

mem_thoughts = action_output.thoughts or ai_message
observation = action_output.observations
observation = check_fail_reason or action_output.observations
action = action_output.action

memory_map = {
"question": question,
"thought": mem_thoughts,
"action": check_fail_reason,
"action": action,
"observation": observation,
}
write_memory_template = self.write_memory_template
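For illustration (made-up values), with this fix a failed verification now stores the action string under "action" and the failure reason under "observation":

    memory_map = {
        "question": "Query last month's revenue",
        "thought": "I should call the db_query tool with the revenue table ...",
        "action": "{'tool_name': 'db_query', 'args': {...}}",
        "observation": "Verification failed: the referenced column does not exist",
    }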
48 changes: 42 additions & 6 deletions dbgpt/agent/expand/actions/tool_action.py
@@ -21,12 +21,14 @@ class ToolInput(BaseModel):
tool_name: str = Field(
...,
description="The name of a tool that can be used to answer the current question"
" or solve the current task.",
" or solve the current task. "
"If no suitable tool is selected, leave this blank.",
)
args: dict = Field(
default={"arg name1": "", "arg name2": ""},
description="The tool selected for the current target, the parameter "
"information required for execution",
"information required for execution, "
"If no suitable tool is selected, leave this blank.",
)
thought: str = Field(..., description="Summary of thoughts to the user")

@@ -68,9 +70,9 @@ def ai_out_schema(self) -> Optional[str]:
}

return f"""Please response in the following json format:
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
Make sure the response is correct json and can be parsed by Python json.loads.
"""
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
Make sure the response is correct json and can be parsed by Python json.loads.
Do not write comments in the JSON; only write the JSON content."""
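For illustration, a reply that satisfies this prompt could look like the following (field names taken from ToolInput above; the exact out_put_schema lives outside this hunk):

    {
      "tool_name": "web_search",
      "args": {"query": "DB-GPT agent framework"},
      "thought": "Use the web_search tool to look up background information."
    }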

async def run(
self,
@@ -91,6 +93,15 @@ async def run(
need_vis_render (bool, optional): Whether need visualization rendering.
Defaults to True.
"""
success, error = parse_json_safe(ai_message)
if not success:
return ActionOutput(
is_exe_success=False,
content=f"Tool Action execute failed! llm reply {ai_message} "
f"is not a valid json format, json error: {error}. "
f"You need to strictly return the raw JSON format. ",
)

try:
param: ToolInput = self._input_convert(ai_message, ToolInput)
except Exception as e:
@@ -100,6 +111,16 @@ async def run(
content="The requested correctly structured answer could not be found.",
)

if param.tool_name is None or param.tool_name == "":
# Could not choose a tool; there must be a reason for it
return ActionOutput(
is_exe_success=False,
# content= param.thought,
content=f"There are no suitable tools available "
f"to achieve the user's goal: '{param.thought}'",
have_retry=False,
)

try:
tool_packs = ToolPack.from_resource(self.resource)
if not tool_packs:
@@ -137,10 +158,25 @@ async def run(
is_exe_success=response_success,
content=str(tool_result),
view=view,
thoughts=param.thought,
action=str({"tool_name": param.tool_name, "args": param.args}),
observations=str(tool_result),
)
except Exception as e:
logger.exception("Tool Action Run Failed!")
return ActionOutput(
is_exe_success=False, content=f"Tool action run failed!{str(e)}"
is_exe_success=False,
content=f"Tool action run failed!{str(e)}",
action=str({"tool_name": param.tool_name, "args": param.args}),
)


def parse_json_safe(json_str):
"""Try to parse json."""
try:
# try to parse json
data = json.loads(json_str)
return True, data
except json.JSONDecodeError as e:
# Catch JSON parsing errors and return the details
return False, e.msg
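A quick usage sketch of parse_json_safe (illustrative inputs):

    ok, data = parse_json_safe('{"tool_name": "web_search", "args": {}}')
    # ok is True, data is the parsed dict

    ok, err = parse_json_safe('{"tool_name": }')
    # ok is False, err is the JSONDecodeError message, e.g. "Expecting value"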
45 changes: 44 additions & 1 deletion dbgpt/agent/expand/tool_assistant_agent.py
@@ -1,9 +1,12 @@
"""Plugin Assistant Agent."""

import logging
from typing import List, Optional

from .. import Resource, ResourceType
from ..core.base_agent import ConversableAgent
from ..core.profile import DynConfig, ProfileConfig
from ..resource import BaseTool
from .actions.tool_action import ToolAction

logger = logging.getLogger(__name__)
@@ -37,7 +40,10 @@ class ToolAssistantAgent(ConversableAgent):
"goal.",
"Please output the selected tool name and specific parameter "
"information in json format according to the following required format."
" If there is an example, please refer to the sample format output.",
"If there is an example, please refer to the sample format output.",
"It is not necessarily required to select a tool for execution. "
"If the tool to be used or its parameters cannot be clearly "
"determined based on the user's input, you can choose not to execute.",
],
category="agent",
key="dbgpt_agent_expand_plugin_assistant_agent_constraints",
@@ -54,3 +60,40 @@ def __init__(self, **kwargs):
"""Create a new instance of ToolAssistantAgent."""
super().__init__(**kwargs)
self._init_actions([ToolAction])

@property
def desc(self) -> Optional[str]:
"""Return desc of this agent."""
tools = _get_tools_by_resource(self.resource)
if tools is None or len(tools) == 0:
return "Has no tools to use"

tools_desc_list = []
for i in range(len(tools)):
tool = tools[i]
s = f"{i + 1}. tool {tool.name}, can {tool.description}."
tools_desc_list.append(s)

return (
"Can use the following tools to complete the task objectives, "
"tool information: "
f"{' '.join(tools_desc_list)}"
)


def _get_tools_by_resource(resource: Optional[Resource]) -> Optional[List[BaseTool]]:
tools: List[BaseTool] = []

if resource is None:
return tools

if resource.type() == ResourceType.Tool and isinstance(resource, BaseTool):
tools.append(resource)
elif resource.type() == ResourceType.Pack:
for sub_res in resource.sub_resources:
res_list = _get_tools_by_resource(sub_res)
if res_list is not None and len(res_list) > 0:
for res in res_list:
tools.append(res)

return tools
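For illustration (hypothetical tools), an agent whose resource pack contains tools named search and calculator, described as "query the web for documents" and "evaluate arithmetic expressions", would report:

    # agent.desc returns:
    # "Can use the following tools to complete the task objectives, tool information: "
    # "1. tool search, can query the web for documents. 2. tool calculator, can evaluate arithmetic expressions."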