From e99ad511fa2448607048c637abf673de1aaef929 Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Sat, 8 Jun 2024 16:29:27 -0700
Subject: [PATCH] Initial chat memory implementation (#59)

---
 examples/chess_game.py                   |  3 ++
 examples/orchestrator.py                 |  5 ++
 .../chat/agents/chat_completion_agent.py | 50 ++++++++++---------
 src/agnext/chat/memory/__init__.py       |  5 ++
 src/agnext/chat/memory/_base.py          | 15 ++++++
 src/agnext/chat/memory/_buffered.py      | 29 +++++++++++
 src/agnext/chat/memory/_full.py          | 24 +++++++++
 7 files changed, 107 insertions(+), 24 deletions(-)
 create mode 100644 src/agnext/chat/memory/__init__.py
 create mode 100644 src/agnext/chat/memory/_base.py
 create mode 100644 src/agnext/chat/memory/_buffered.py
 create mode 100644 src/agnext/chat/memory/_full.py

diff --git a/examples/chess_game.py b/examples/chess_game.py
index d21d92653ded..ff63db8b93f4 100644
--- a/examples/chess_game.py
+++ b/examples/chess_game.py
@@ -11,6 +11,7 @@
 
 from agnext.application import SingleThreadedAgentRuntime
 from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
+from agnext.chat.memory import BufferedChatMemory
 from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
 from agnext.chat.patterns.two_agent_chat import TwoAgentChat
 from agnext.chat.types import TextMessage
@@ -175,6 +176,7 @@ def get_board_text() -> Annotated[str, "The current board state"]:
                 "Think about your strategy and call make_move(thinking, move) to make a move."
             ),
         ],
+        memory=BufferedChatMemory(buffer_size=10),
         model_client=OpenAI(model="gpt-4-turbo"),
         tools=black_tools,
     )
@@ -190,6 +192,7 @@ def get_board_text() -> Annotated[str, "The current board state"]:
                 "Think about your strategy and call make_move(thinking, move) to make a move."
             ),
         ],
+        memory=BufferedChatMemory(buffer_size=10),
         model_client=OpenAI(model="gpt-4-turbo"),
         tools=white_tools,
     )
diff --git a/examples/orchestrator.py b/examples/orchestrator.py
index e38b7cbb3818..1a161ef306ca 100644
--- a/examples/orchestrator.py
+++ b/examples/orchestrator.py
@@ -11,6 +11,7 @@
 )
 from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
 from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
+from agnext.chat.memory import BufferedChatMemory
 from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
 from agnext.chat.types import TextMessage
 from agnext.components.models import OpenAI, SystemMessage
@@ -83,6 +84,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat:  # type: ig
         description="A developer that writes code.",
         runtime=runtime,
         system_messages=[SystemMessage("You are a Python developer.")],
+        memory=BufferedChatMemory(buffer_size=10),
         model_client=OpenAI(model="gpt-4-turbo"),
     )
 
@@ -109,6 +111,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat:  # type: ig
             SystemMessage("You are a product manager good at translating customer needs into software specifications."),
             SystemMessage("You can use the search tool to find information on the web."),
         ],
+        memory=BufferedChatMemory(buffer_size=10),
         model_client=OpenAI(model="gpt-4-turbo"),
         tools=[SearchTool()],
     )
@@ -118,6 +121,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat:  # type: ig
         description="A planner that organizes and schedules tasks.",
         runtime=runtime,
         system_messages=[SystemMessage("You are a planner of complex tasks.")],
+        memory=BufferedChatMemory(buffer_size=10),
         model_client=OpenAI(model="gpt-4-turbo"),
     )
 
@@ -128,6 +132,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat:  # type: ig
         system_messages=[
             SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
         ],
+        memory=BufferedChatMemory(buffer_size=10),
         model_client=OpenAI(model="gpt-4-turbo"),
     )
 
diff --git a/src/agnext/chat/agents/chat_completion_agent.py b/src/agnext/chat/agents/chat_completion_agent.py
index c6a1fcda3354..a6a45e035554 100644
--- a/src/agnext/chat/agents/chat_completion_agent.py
+++ b/src/agnext/chat/agents/chat_completion_agent.py
@@ -2,29 +2,30 @@
 import json
 from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
 
-from agnext.chat.agents.base import BaseChatAgent
-from agnext.chat.types import (
-    FunctionCallMessage,
-    Message,
-    Reset,
-    RespondNow,
-    ResponseFormat,
-    TextMessage,
-)
-from agnext.chat.utils import convert_messages_to_llm_messages
-from agnext.components import (
+from ...components import (
     FunctionCall,
     TypeRoutedAgent,
     message_handler,
 )
-from agnext.components.models import (
+from ...components.models import (
     ChatCompletionClient,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     SystemMessage,
 )
-from agnext.components.tools import Tool
-from agnext.core import AgentRuntime, CancellationToken
+from ...components.tools import Tool
+from ...core import AgentRuntime, CancellationToken
+from ..memory import ChatMemory
+from ..types import (
+    FunctionCallMessage,
+    Message,
+    Reset,
+    RespondNow,
+    ResponseFormat,
+    TextMessage,
+)
+from ..utils import convert_messages_to_llm_messages
+from .base import BaseChatAgent
 
 
 class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
@@ -34,24 +35,25 @@ def __init__(
         self,
         description: str,
         runtime: AgentRuntime,
         system_messages: List[SystemMessage],
+        memory: ChatMemory,
         model_client: ChatCompletionClient,
         tools: Sequence[Tool] = [],
     ) -> None:
         super().__init__(name, description, runtime)
         self._system_messages = system_messages
         self._client = model_client
-        self._chat_messages: List[Message] = []
+        self._memory = memory
         self._tools = tools
 
     @message_handler()
     async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
         # Add a user message.
-        self._chat_messages.append(message)
+        self._memory.add_message(message)
 
     @message_handler()
     async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
         # Reset the chat messages.
-        self._chat_messages = []
+        self._memory.clear()
 
     @message_handler()
     async def on_respond_now(
@@ -59,7 +61,7 @@
     ) -> TextMessage | FunctionCallMessage:
         # Get a response from the model.
         response = await self._client.create(
-            self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
+            self._system_messages + convert_messages_to_llm_messages(self._memory.get_messages(), self.name),
             tools=self._tools,
             json_output=message.response_format == ResponseFormat.json_object,
         )
@@ -80,7 +82,7 @@ async def on_respond_now(
             )
             # Make an assistant message from the response.
             response = await self._client.create(
-                self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
+                self._system_messages + convert_messages_to_llm_messages(self._memory.get_messages(), self.name),
                 tools=self._tools,
                 json_output=message.response_format == ResponseFormat.json_object,
             )
@@ -96,7 +98,7 @@
             raise ValueError(f"Unexpected response: {response.content}")
 
         # Add the response to the chat messages.
-        self._chat_messages.append(final_response)
+        self._memory.add_message(final_response)
 
         # Return the response.
         return final_response
@@ -109,7 +111,7 @@ async def on_tool_call_message(
             raise ValueError("No tools available")
 
         # Add a tool call message.
-        self._chat_messages.append(message)
+        self._memory.add_message(message)
 
         # Execute the tool calls.
         results: List[FunctionExecutionResult] = []
@@ -146,7 +148,7 @@ async def on_tool_call_message(
         tool_call_result_msg = FunctionExecutionResultMessage(content=results)
 
         # Add tool call result message.
-        self._chat_messages.append(tool_call_result_msg)
+        self._memory.add_message(tool_call_result_msg)
 
         # Return the results.
         return tool_call_result_msg
@@ -172,11 +174,11 @@ async def execute_function(
     def save_state(self) -> Mapping[str, Any]:
         return {
             "description": self.description,
-            "chat_messages": self._chat_messages,
+            "memory": self._memory.save_state(),
             "system_messages": self._system_messages,
         }
 
     def load_state(self, state: Mapping[str, Any]) -> None:
-        self._chat_messages = state["chat_messages"]
+        self._memory.load_state(state["memory"])
         self._system_messages = state["system_messages"]
         self._description = state["description"]
diff --git a/src/agnext/chat/memory/__init__.py b/src/agnext/chat/memory/__init__.py
new file mode 100644
index 000000000000..0cefbdef1c7b
--- /dev/null
+++ b/src/agnext/chat/memory/__init__.py
@@ -0,0 +1,5 @@
+from ._base import ChatMemory
+from ._buffered import BufferedChatMemory
+from ._full import FullChatMemory
+
+__all__ = ["ChatMemory", "FullChatMemory", "BufferedChatMemory"]
diff --git a/src/agnext/chat/memory/_base.py b/src/agnext/chat/memory/_base.py
new file mode 100644
index 000000000000..f22dc9a19880
--- /dev/null
+++ b/src/agnext/chat/memory/_base.py
@@ -0,0 +1,15 @@
+from typing import Any, List, Mapping, Protocol
+
+from ..types import Message
+
+
+class ChatMemory(Protocol):
+    def add_message(self, message: Message) -> None: ...
+
+    def get_messages(self) -> List[Message]: ...
+
+    def clear(self) -> None: ...
+
+    def save_state(self) -> Mapping[str, Any]: ...
+
+    def load_state(self, state: Mapping[str, Any]) -> None: ...
diff --git a/src/agnext/chat/memory/_buffered.py b/src/agnext/chat/memory/_buffered.py
new file mode 100644
index 000000000000..2d2110f438a8
--- /dev/null
+++ b/src/agnext/chat/memory/_buffered.py
@@ -0,0 +1,29 @@
+from typing import Any, List, Mapping
+
+from ..types import Message
+from ._base import ChatMemory
+
+
+class BufferedChatMemory(ChatMemory):
+    def __init__(self, buffer_size: int) -> None:
+        self._messages: List[Message] = []
+        self._buffer_size = buffer_size
+
+    def add_message(self, message: Message) -> None:
+        self._messages.append(message)
+
+    def get_messages(self) -> List[Message]:
+        return self._messages[-self._buffer_size :]
+
+    def clear(self) -> None:
+        self._messages = []
+
+    def save_state(self) -> Mapping[str, Any]:
+        return {
+            "messages": [message for message in self._messages],
+            "buffer_size": self._buffer_size,
+        }
+
+    def load_state(self, state: Mapping[str, Any]) -> None:
+        self._messages = state["messages"]
+        self._buffer_size = state["buffer_size"]
diff --git a/src/agnext/chat/memory/_full.py b/src/agnext/chat/memory/_full.py
new file mode 100644
index 000000000000..166f40c39c32
--- /dev/null
+++ b/src/agnext/chat/memory/_full.py
@@ -0,0 +1,24 @@
+from typing import Any, List, Mapping
+
+from ..types import Message
+from ._base import ChatMemory
+
+
+class FullChatMemory(ChatMemory):
+    def __init__(self) -> None:
+        self._messages: List[Message] = []
+
+    def add_message(self, message: Message) -> None:
+        self._messages.append(message)
+
+    def get_messages(self) -> List[Message]:
+        return self._messages
+
+    def clear(self) -> None:
+        self._messages = []
+
+    def save_state(self) -> Mapping[str, Any]:
+        return {"messages": [message for message in self._messages]}
+
+    def load_state(self, state: Mapping[str, Any]) -> None:
+        self._messages = state["messages"]
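Usage note (illustrative, not part of the patch): because ChatMemory in src/agnext/chat/memory/_base.py is a Protocol, callers can supply their own retention strategy in addition to BufferedChatMemory and FullChatMemory. The sketch below assumes only the modules introduced by this change; the HeadTailChatMemory name, its parameters, and its truncation policy are hypothetical examples rather than code from this patch. An instance would be passed to ChatCompletionAgent through the new memory parameter, just as the examples above pass BufferedChatMemory(buffer_size=10).

    from typing import Any, List, Mapping

    from agnext.chat.memory import ChatMemory  # protocol added in this patch
    from agnext.chat.types import Message      # message type used by the protocol


    class HeadTailChatMemory(ChatMemory):
        """Hypothetical memory: keep the first head_size and the last tail_size messages."""

        def __init__(self, head_size: int, tail_size: int) -> None:
            self._messages: List[Message] = []
            self._head_size = head_size
            self._tail_size = tail_size

        def add_message(self, message: Message) -> None:
            self._messages.append(message)

        def get_messages(self) -> List[Message]:
            # Return everything while the history is short; otherwise drop the middle.
            if len(self._messages) <= self._head_size + self._tail_size:
                return list(self._messages)
            tail = self._messages[-self._tail_size :] if self._tail_size > 0 else []
            return self._messages[: self._head_size] + tail

        def clear(self) -> None:
            self._messages = []

        def save_state(self) -> Mapping[str, Any]:
            return {
                "messages": list(self._messages),
                "head_size": self._head_size,
                "tail_size": self._tail_size,
            }

        def load_state(self, state: Mapping[str, Any]) -> None:
            self._messages = list(state["messages"])
            self._head_size = state["head_size"]
            self._tail_size = state["tail_size"]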