Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Model Client class support for 'auto' speaker selection in Group Chat #65

Merged
merged 3 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 82 additions & 41 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from ..formatting_utils import colored
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
from ..io.base import IOStream
from ..oai.client import ModelClient
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent
from .chat import ChatResult
from .contrib.capabilities import transform_messages
from .conversable_agent import ConversableAgent

Expand Down Expand Up @@ -106,6 +106,8 @@ def custom_speaker_selection_func(
"clear history" phrase in user prompt. This is experimental feature.
See description of GroupChatManager.clear_agents_history function for more info.
- send_introductions: send a round of introductions at the start of the group chat, so agents know who they can speak to (default: False)
- select_speaker_auto_model_client_cls: Custom model client class for the internal speaker select agent used during 'auto' speaker selection (optional)
- select_speaker_auto_llm_config: LLM config for the internal speaker select agent used during 'auto' speaker selection (optional)
- role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system')
"""

Expand Down Expand Up @@ -143,6 +145,8 @@ def custom_speaker_selection_func(
Respond with ONLY the name of the speaker and DO NOT provide a reason."""
select_speaker_transform_messages: Optional[transform_messages.TransformMessages] = None
select_speaker_auto_verbose: Optional[bool] = False
select_speaker_auto_model_client_cls: Optional[Union[ModelClient, List[ModelClient]]] = None
select_speaker_auto_llm_config: Optional[Union[Dict, Literal[False]]] = None
role_for_select_speaker_messages: Optional[str] = "system"

_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
Expand Down Expand Up @@ -587,6 +591,79 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents:
agent = self.agent_by_name(name)
return agent if agent else self.next_agent(last_speaker, agents)

def _register_client_from_config(self, agent: Agent, config: Dict):
model_client_cls_to_match = config.get("model_client_cls")
if model_client_cls_to_match:
if not self.select_speaker_auto_model_client_cls:
raise ValueError(
"A custom model was detected in the config but no 'model_client_cls' "
"was supplied for registration in GroupChat."
)

if isinstance(self.select_speaker_auto_model_client_cls, list):
# Register the first custom model client class matching the name specified in the config
matching_model_cls = [
client_cls
for client_cls in self.select_speaker_auto_model_client_cls
if client_cls.__name__ == model_client_cls_to_match
]
if len(set(matching_model_cls)) > 1:
raise RuntimeError(
f"More than one unique 'model_client_cls' with __name__ '{model_client_cls_to_match}'."
)
if not matching_model_cls:
raise ValueError(
"No model's __name__ matches the model client class "
f"'{model_client_cls_to_match}' specified in select_speaker_auto_llm_config."
)
select_speaker_auto_model_client_cls = matching_model_cls[0]
else:
# Register the only custom model client
select_speaker_auto_model_client_cls = self.select_speaker_auto_model_client_cls

agent.register_model_client(select_speaker_auto_model_client_cls)

def _register_custom_model_clients(self, agent: ConversableAgent):
if not self.select_speaker_auto_llm_config:
return

config_format_is_list = "config_list" in self.select_speaker_auto_llm_config.keys()
if config_format_is_list:
for config in self.select_speaker_auto_llm_config["config_list"]:
self._register_client_from_config(agent, config)
elif not config_format_is_list:
self._register_client_from_config(agent, self.select_speaker_auto_llm_config)

def _create_internal_agents(
self, agents, max_attempts, messages, validate_speaker_name, selector: Optional[ConversableAgent] = None
):
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# Override the selector's config if one was passed as a parameter to this class
speaker_selection_llm_config = self.select_speaker_auto_llm_config or selector.llm_config

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages={checking_agent: messages},
llm_config=speaker_selection_llm_config,
human_input_mode="NEVER",
# Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
)

# Register any custom model passed in select_speaker_auto_llm_config with the speaker_selection_agent
self._register_custom_model_clients(speaker_selection_agent)

return checking_agent, speaker_selection_agent

def _auto_select_speaker(
self,
last_speaker: Agent,
Expand Down Expand Up @@ -640,28 +717,8 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
# Two-agent chat for speaker selection

# Agent for checking the response from the speaker_select_agent
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# NOTE: Do we have a speaker prompt (select_speaker_prompt_template is not None)? If we don't, we need to feed in the last message to start the nested chat

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages=(
{checking_agent: messages}
if self.select_speaker_prompt_template is not None
else {checking_agent: messages[:-1]}
),
llm_config=selector.llm_config,
human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
checking_agent, speaker_selection_agent = self._create_internal_agents(
agents, max_attempts, messages, validate_speaker_name, selector
)

# Create the starting message
Expand Down Expand Up @@ -743,24 +800,8 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
# Two-agent chat for speaker selection

# Agent for checking the response from the speaker_select_agent
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# NOTE: Do we have a speaker prompt (select_speaker_prompt_template is not None)? If we don't, we need to feed in the last message to start the nested chat

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages={checking_agent: messages},
llm_config=selector.llm_config,
human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
checking_agent, speaker_selection_agent = self._create_internal_agents(
agents, max_attempts, messages, validate_speaker_name, selector
)

# Create the starting message
Expand Down
58 changes: 57 additions & 1 deletion test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io
import json
import logging
from types import SimpleNamespace
from typing import Any, Dict, List, Optional
from unittest import TestCase, mock

Expand Down Expand Up @@ -2068,6 +2069,60 @@ def test_manager_resume_messages():
return_agent, return_message = manager.resume(messages="Let's get this conversation started.")


def test_custom_model_client():
class CustomModelClient:
def __init__(self, config, **kwargs):
print(f"CustomModelClient config: {config}")

def create(self, params):
num_of_responses = params.get("n", 1)

response = SimpleNamespace()
response.choices = []
response.model = "test_model_name"

for _ in range(num_of_responses):
text = "this is a dummy text response"
choice = SimpleNamespace()
choice.message = SimpleNamespace()
choice.message.content = text
choice.message.function_call = None
response.choices.append(choice)
return response

def message_retrieval(self, response):
choices = response.choices
return [choice.message.content for choice in choices]

def cost(self, response) -> float:
response.cost = 0
return 0

@staticmethod
def get_usage(response):
return {}

llm_config = {"config_list": [{"model": "test_model_name", "model_client_cls": "CustomModelClient"}]}

group_chat = autogen.GroupChat(
agents=[],
messages=[],
max_round=3,
select_speaker_auto_llm_config=llm_config,
select_speaker_auto_model_client_cls=CustomModelClient,
)

checking_agent, speaker_selection_agent = group_chat._create_internal_agents(
agents=[], messages=[], max_attempts=3, validate_speaker_name=(True, "test")
)

# Check that the custom model client is assigned to the speaker selection agent
assert isinstance(speaker_selection_agent.client._clients[0], CustomModelClient)

# Check that the LLM Config is assigned
assert speaker_selection_agent.client._config_list == llm_config["config_list"]


def test_select_speaker_transform_messages():
"""Tests adding transform messages to a GroupChat for speaker selection when in 'auto' mode"""

Expand Down Expand Up @@ -2182,5 +2237,6 @@ def test_manager_resume_message_assignment():
# test_manager_resume_returns()
# test_manager_resume_messages()
# test_select_speaker_transform_messages()
test_manager_resume_message_assignment()
# test_manager_resume_message_assignment()
test_custom_model_client()
pass
Loading
Loading