diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index f1fadcca27..131f043038 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1024,6 +1024,20 @@ def _raise_exception_on_async_reply_functions(self) -> None: raise RuntimeError(msg) + def _get_related_agents_for_usage(self) -> List[Agent]: + """Gets all agents related to this agent so they can be used in the usage summary. + + In ConversableAgent, this is only itself, but this can be overridden in other Agent classes to add additional agents, such as for + the agents in a group chat when self is the GroupChatManager.""" + + return [self] + + @staticmethod + def _get_agents_for_usage_summary(sender: "ConversableAgent", recipient: "ConversableAgent") -> List[Agent]: + """Gets all agents for a chat session in order to calculate the usage summary.""" + + return list(set(sender._get_related_agents_for_usage() + recipient._get_related_agents_for_usage())) + def initiate_chat( self, recipient: "ConversableAgent", @@ -1162,7 +1176,7 @@ def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: d chat_result = ChatResult( chat_history=self.chat_messages[recipient], summary=summary, - cost=gather_usage_summary([self, recipient]), + cost=gather_usage_summary(ConversableAgent._get_agents_for_usage_summary(self, recipient)), human_input=self._human_input, ) return chat_result @@ -1228,7 +1242,7 @@ async def a_initiate_chat( chat_result = ChatResult( chat_history=self.chat_messages[recipient], summary=summary, - cost=gather_usage_summary([self, recipient]), + cost=gather_usage_summary(ConversableAgent._get_agents_for_usage_summary(self, recipient)), human_input=self._human_input, ) return chat_result diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 0e14bf35f8..0d32de74ef 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -1666,3 +1666,12 @@ def clear_agents_history(self, reply: dict, groupchat: GroupChat) -> str: reply_content = " ".join(words[:clear_word_index] + words[clear_word_index + skip_words_number :]) return reply_content + + def _get_related_agents_for_usage(self) -> List[Agent]: + """Gets all agents in the groupchat in order to calculate the usage summary. + + This overrides the ConversableAgent method to include groupchat agents. + + TODO: Persist and include the speaker selection agent from the auto speaker selection mode.""" + + return [self] + self._groupchat.agents