Skip to content

Commit

Permalink
Fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaranvpl committed Dec 25, 2024
1 parent 0435414 commit 31ca1cf
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 37 deletions.
88 changes: 51 additions & 37 deletions autogen/messages.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Literal, Optional, TypeVar, Union

from pydantic import BaseModel
Expand All @@ -7,7 +11,7 @@
from .code_utils import content_str
from .oai.client import OpenAIWrapper

MessageRole = TypeVar("MessageRole", bound=Literal["assistant", "function", "tool"])
MessageRole = Literal["assistant", "function", "tool"]


class BaseMessage(BaseModel):
Expand Down Expand Up @@ -88,24 +92,25 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None:


class FunctionCallMessage(BaseMessage):
content: Optional[str] = None
content: Optional[str] = None # type: ignore [assignment]
function_call: FunctionCall
# ToDo: Does function call has context?
context: Optional[dict[str, Any]] = None
llm_config: Union[dict, Literal[False]]
llm_config: Optional[Union[dict[str, Any], Literal[False]]] = None

def print(self, f: Optional[Callable[..., Any]] = None) -> None:
f = f or print
super().print(f)

if self.content is not None:
content = self.content
if self.context is not None:
content = OpenAIWrapper.instantiate(
content,
self.context,
self.llm_config and self.llm_config.get("allow_format_str_template", False),
)
allow_format_str_template = (
self.llm_config.get("allow_format_str_template", False) if self.llm_config else False
)
content = OpenAIWrapper.instantiate(
self.content,
self.context,
allow_format_str_template,
)
f(content_str(content), flush=True)

self.function_call.print(f)
Expand All @@ -115,7 +120,7 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None:

class ToolCall(BaseModel):
id: Optional[str] = None
function: Optional[FunctionCall] = None
function: FunctionCall
type: str

def print(self, f: Optional[Callable[..., Any]] = None) -> None:
Expand All @@ -138,28 +143,29 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None:


class ToolCallMessage(BaseMessage):
content: Optional[str] = None
content: Optional[str] = None # type: ignore [assignment]
refusal: Optional[str] = None
role: MessageRole
audio: Optional[str] = None
function_call: Optional[FunctionCall] = None
tool_calls: list[ToolCall]
# ToDo: Does tool calls has context?
context: Optional[dict[str, Any]] = None
llm_config: Union[dict, Literal[False]]
llm_config: Optional[Union[dict[str, Any], Literal[False]]] = None

def print(self, f: Optional[Callable[..., Any]] = None) -> None:
f = f or print
super().print(f)

if self.content is not None:
content = self.content
if self.context is not None:
content = OpenAIWrapper.instantiate(
content,
self.context,
self.llm_config and self.llm_config.get("allow_format_str_template", False),
)
allow_format_str_template = (
self.llm_config.get("allow_format_str_template", False) if self.llm_config else False
)
content = OpenAIWrapper.instantiate(
self.content,
self.context,
allow_format_str_template,
)
f(content_str(content), flush=True)

for tool_call in self.tool_calls:
Expand All @@ -169,32 +175,31 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None:


class ContentMessage(BaseMessage):
content: Optional[Union[str, Callable[..., Any]]] = None
content: Optional[Union[str, Callable[..., Any]]] = None # type: ignore [assignment]
context: Optional[dict[str, Any]] = None
llm_config: Union[dict, Literal[False]]
llm_config: Optional[Union[dict[str, Any], Literal[False]]] = None

def print(self, f: Optional[Callable[..., Any]] = None) -> None:
f = f or print
super().print(f)

if self.content is not None:
content = self.content
if self.context is not None:
content = OpenAIWrapper.instantiate(
content,
self.context,
self.llm_config and self.llm_config.get("allow_format_str_template", False),
)
allow_format_str_template = (
self.llm_config.get("allow_format_str_template", False) if self.llm_config else False
)
content = OpenAIWrapper.instantiate(
self.content,
self.context,
allow_format_str_template,
)
f(content_str(content), flush=True)

f("\n", "-" * 80, flush=True, sep="")


def create_message_model(message: Union[dict[str, Any], str], sender: Agent, receiver: Agent) -> BaseMessage:
print(f"{message=}")
print(f"{sender=}")
if isinstance(message, str):
return
def create_message_model(message: dict[str, Any], sender: Agent, receiver: Agent) -> BaseMessage:
# print(f"{message=}")
# print(f"{sender=}")

role = message.get("role")
if role == "function":
Expand All @@ -206,16 +211,25 @@ def create_message_model(message: Union[dict[str, Any], str], sender: Agent, rec

if "function_call" in message and message["function_call"]:
return FunctionCallMessage(
**message, sender_name=sender.name, receiver_name=receiver.name, llm_config=receiver.llm_config
**message,
sender_name=sender.name,
receiver_name=receiver.name,
llm_config=receiver.llm_config, # type: ignore [attr-defined]
)

if "tool_calls" in message and message["tool_calls"]:
return ToolCallMessage(
**message, sender_name=sender.name, receiver_name=receiver.name, llm_config=receiver.llm_config
**message,
sender_name=sender.name,
receiver_name=receiver.name,
llm_config=receiver.llm_config, # type: ignore [attr-defined]
)

# Now message is a simple content message

return ContentMessage(
**message, sender_name=sender.name, receiver_name=receiver.name, llm_config=receiver.llm_config
**message,
sender_name=sender.name,
receiver_name=receiver.name,
llm_config=receiver.llm_config, # type: ignore [attr-defined]
)
4 changes: 4 additions & 0 deletions test/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock

import pytest
Expand Down

0 comments on commit 31ca1cf

Please sign in to comment.