Skip to content

Commit

Permalink
Custom Model Client class support for 'auto' speaker selection in Gro…
Browse files Browse the repository at this point in the history
…up Chat (#65)

* Changes to support custom model client class for auto speaker selection group chats

* Documentation added

---------

Co-authored-by: Qingyun Wu <[email protected]>
  • Loading branch information
marklysze and qingyun-wu authored Oct 20, 2024
1 parent a81a6a7 commit 1111adc
Show file tree
Hide file tree
Showing 3 changed files with 590 additions and 42 deletions.
123 changes: 82 additions & 41 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from ..formatting_utils import colored
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
from ..io.base import IOStream
from ..oai.client import ModelClient
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent
from .chat import ChatResult
from .contrib.capabilities import transform_messages
from .conversable_agent import ConversableAgent

Expand Down Expand Up @@ -106,6 +106,8 @@ def custom_speaker_selection_func(
"clear history" phrase in user prompt. This is experimental feature.
See description of GroupChatManager.clear_agents_history function for more info.
- send_introductions: send a round of introductions at the start of the group chat, so agents know who they can speak to (default: False)
- select_speaker_auto_model_client_cls: Custom model client class for the internal speaker select agent used during 'auto' speaker selection (optional)
- select_speaker_auto_llm_config: LLM config for the internal speaker select agent used during 'auto' speaker selection (optional)
- role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system')
"""

Expand Down Expand Up @@ -143,6 +145,8 @@ def custom_speaker_selection_func(
Respond with ONLY the name of the speaker and DO NOT provide a reason."""
select_speaker_transform_messages: Optional[transform_messages.TransformMessages] = None
select_speaker_auto_verbose: Optional[bool] = False
select_speaker_auto_model_client_cls: Optional[Union[ModelClient, List[ModelClient]]] = None
select_speaker_auto_llm_config: Optional[Union[Dict, Literal[False]]] = None
role_for_select_speaker_messages: Optional[str] = "system"

_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
Expand Down Expand Up @@ -587,6 +591,79 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents:
agent = self.agent_by_name(name)
return agent if agent else self.next_agent(last_speaker, agents)

def _register_client_from_config(self, agent: Agent, config: Dict):
model_client_cls_to_match = config.get("model_client_cls")
if model_client_cls_to_match:
if not self.select_speaker_auto_model_client_cls:
raise ValueError(
"A custom model was detected in the config but no 'model_client_cls' "
"was supplied for registration in GroupChat."
)

if isinstance(self.select_speaker_auto_model_client_cls, list):
# Register the first custom model client class matching the name specified in the config
matching_model_cls = [
client_cls
for client_cls in self.select_speaker_auto_model_client_cls
if client_cls.__name__ == model_client_cls_to_match
]
if len(set(matching_model_cls)) > 1:
raise RuntimeError(
f"More than one unique 'model_client_cls' with __name__ '{model_client_cls_to_match}'."
)
if not matching_model_cls:
raise ValueError(
"No model's __name__ matches the model client class "
f"'{model_client_cls_to_match}' specified in select_speaker_auto_llm_config."
)
select_speaker_auto_model_client_cls = matching_model_cls[0]
else:
# Register the only custom model client
select_speaker_auto_model_client_cls = self.select_speaker_auto_model_client_cls

agent.register_model_client(select_speaker_auto_model_client_cls)

def _register_custom_model_clients(self, agent: ConversableAgent):
if not self.select_speaker_auto_llm_config:
return

config_format_is_list = "config_list" in self.select_speaker_auto_llm_config.keys()
if config_format_is_list:
for config in self.select_speaker_auto_llm_config["config_list"]:
self._register_client_from_config(agent, config)
elif not config_format_is_list:
self._register_client_from_config(agent, self.select_speaker_auto_llm_config)

def _create_internal_agents(
self, agents, max_attempts, messages, validate_speaker_name, selector: Optional[ConversableAgent] = None
):
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# Override the selector's config if one was passed as a parameter to this class
speaker_selection_llm_config = self.select_speaker_auto_llm_config or selector.llm_config

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages={checking_agent: messages},
llm_config=speaker_selection_llm_config,
human_input_mode="NEVER",
# Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
)

# Register any custom model passed in select_speaker_auto_llm_config with the speaker_selection_agent
self._register_custom_model_clients(speaker_selection_agent)

return checking_agent, speaker_selection_agent

def _auto_select_speaker(
self,
last_speaker: Agent,
Expand Down Expand Up @@ -640,28 +717,8 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
# Two-agent chat for speaker selection

# Agent for checking the response from the speaker_select_agent
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# NOTE: Do we have a speaker prompt (select_speaker_prompt_template is not None)? If we don't, we need to feed in the last message to start the nested chat

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages=(
{checking_agent: messages}
if self.select_speaker_prompt_template is not None
else {checking_agent: messages[:-1]}
),
llm_config=selector.llm_config,
human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
checking_agent, speaker_selection_agent = self._create_internal_agents(
agents, max_attempts, messages, validate_speaker_name, selector
)

# Create the starting message
Expand Down Expand Up @@ -743,24 +800,8 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
# Two-agent chat for speaker selection

# Agent for checking the response from the speaker_select_agent
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# NOTE: Do we have a speaker prompt (select_speaker_prompt_template is not None)? If we don't, we need to feed in the last message to start the nested chat

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages={checking_agent: messages},
llm_config=selector.llm_config,
human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
checking_agent, speaker_selection_agent = self._create_internal_agents(
agents, max_attempts, messages, validate_speaker_name, selector
)

# Create the starting message
Expand Down
58 changes: 57 additions & 1 deletion test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io
import json
import logging
from types import SimpleNamespace
from typing import Any, Dict, List, Optional
from unittest import TestCase, mock

Expand Down Expand Up @@ -2068,6 +2069,60 @@ def test_manager_resume_messages():
return_agent, return_message = manager.resume(messages="Let's get this conversation started.")


def test_custom_model_client():
class CustomModelClient:
def __init__(self, config, **kwargs):
print(f"CustomModelClient config: {config}")

def create(self, params):
num_of_responses = params.get("n", 1)

response = SimpleNamespace()
response.choices = []
response.model = "test_model_name"

for _ in range(num_of_responses):
text = "this is a dummy text response"
choice = SimpleNamespace()
choice.message = SimpleNamespace()
choice.message.content = text
choice.message.function_call = None
response.choices.append(choice)
return response

def message_retrieval(self, response):
choices = response.choices
return [choice.message.content for choice in choices]

def cost(self, response) -> float:
response.cost = 0
return 0

@staticmethod
def get_usage(response):
return {}

llm_config = {"config_list": [{"model": "test_model_name", "model_client_cls": "CustomModelClient"}]}

group_chat = autogen.GroupChat(
agents=[],
messages=[],
max_round=3,
select_speaker_auto_llm_config=llm_config,
select_speaker_auto_model_client_cls=CustomModelClient,
)

checking_agent, speaker_selection_agent = group_chat._create_internal_agents(
agents=[], messages=[], max_attempts=3, validate_speaker_name=(True, "test")
)

# Check that the custom model client is assigned to the speaker selection agent
assert isinstance(speaker_selection_agent.client._clients[0], CustomModelClient)

# Check that the LLM Config is assigned
assert speaker_selection_agent.client._config_list == llm_config["config_list"]


def test_select_speaker_transform_messages():
"""Tests adding transform messages to a GroupChat for speaker selection when in 'auto' mode"""

Expand Down Expand Up @@ -2182,5 +2237,6 @@ def test_manager_resume_message_assignment():
# test_manager_resume_returns()
# test_manager_resume_messages()
# test_select_speaker_transform_messages()
test_manager_resume_message_assignment()
# test_manager_resume_message_assignment()
test_custom_model_client()
pass
Loading

0 comments on commit 1111adc

Please sign in to comment.