Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LMM: use _generate_oai_reply_from_client #58

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions autogen/agentchat/contrib/multimodal_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
)
from autogen.code_utils import content_str

from ..._pydantic import model_dump

DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
DEFAULT_MODEL = "gpt-4-vision-preview"
DEFAULT_MODEL = "gpt-4-turbo"


class MultimodalConversableAgent(ConversableAgent):
Expand Down Expand Up @@ -116,13 +114,5 @@ def generate_oai_reply(

messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)

# TODO: #1143 handle token limit exceeded error
response = client.create(
context=messages[-1].pop("context", None), messages=messages_with_b64_img, agent=self.name
)

# TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
extracted_response = client.extract_text_or_completion_object(response)[0]
if not isinstance(extracted_response, str):
extracted_response = model_dump(extracted_response)
return True, extracted_response
extracted_response = self._generate_oai_reply_from_client(client, messages_with_b64_img, self.client_cache)
return (False, None) if extracted_response is None else (True, extracted_response)
46 changes: 44 additions & 2 deletions test/agentchat/contrib/test_lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from unittest.mock import MagicMock

import pytest
from annotated_types import Annotated
from conftest import MOCK_OPEN_AI_API_KEY

import autogen
from autogen.agentchat.conversable_agent import ConversableAgent
from autogen import AssistantAgent, ConversableAgent, UserProxyAgent

try:
from autogen.agentchat.contrib.img_utils import get_pil_image
Expand All @@ -23,6 +24,7 @@
else:
skip = False

VISION_MODEL_NAME = "gpt-4-turbo"

base64_encoded_image = (
""
Expand All @@ -44,7 +46,7 @@ def setUp(self):
llm_config={
"timeout": 600,
"seed": 42,
"config_list": [{"model": "gpt-4-vision-preview", "api_key": MOCK_OPEN_AI_API_KEY}],
"config_list": [{"model": VISION_MODEL_NAME, "api_key": MOCK_OPEN_AI_API_KEY}],
},
)

Expand Down Expand Up @@ -144,5 +146,45 @@ def test_group_chat_with_lmm():
assert all(len(arr) <= max_round for arr in user_proxy._oai_messages.values()), "User proxy exceeded max rounds"


@pytest.mark.skipif(skip, reason="Dependency not installed")
def test_func_call_with_lmm():
assistant = MultimodalConversableAgent(
name="Assistant",
system_message="Describe all the colors in the image.",
human_input_mode="NEVER",
max_consecutive_auto_reply=2,
llm_config={
"timeout": 600,
"seed": 42,
"config_list": [{"model": VISION_MODEL_NAME, "api_key": MOCK_OPEN_AI_API_KEY}],
},
)

coder = AssistantAgent(
name="Coder",
system_message="YOU MUST USE THE FUNCTION PROVIDED.",
llm_config={
"timeout": 600,
"seed": 42,
"config_list": [{"model": VISION_MODEL_NAME, "api_key": MOCK_OPEN_AI_API_KEY}],
},
human_input_mode="NEVER",
code_execution_config=False,
max_consecutive_auto_reply=2,
)

def count_colors(colors: list) -> int:
return len(colors)

coder.register_for_llm(name="count_colors", description="Count colors.")(count_colors)
assistant.register_for_execution(name="count_colors")(count_colors)

coder.initiate_chat(
assistant, clear_history=True, message=f"""How many colors here: <img {base64_encoded_image}>"""
)

assert len(coder._oai_messages[assistant]) > 1, "Function call did not happen"


if __name__ == "__main__":
unittest.main()
Loading