From 31ca1cfd7db35c3589bdbb8529c40ac17ea97597 Mon Sep 17 00:00:00 2001 From: Kumaran Rajendhiran Date: Wed, 25 Dec 2024 12:58:05 +0530 Subject: [PATCH] Fix mypy issues --- autogen/messages.py | 88 +++++++++++++++++++++++++------------------ test/test_messages.py | 4 ++ 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/autogen/messages.py b/autogen/messages.py index 1fe9570814..3b93ee14e1 100644 --- a/autogen/messages.py +++ b/autogen/messages.py @@ -1,3 +1,7 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + from typing import Any, Callable, Literal, Optional, TypeVar, Union from pydantic import BaseModel @@ -7,7 +11,7 @@ from .code_utils import content_str from .oai.client import OpenAIWrapper -MessageRole = TypeVar("MessageRole", bound=Literal["assistant", "function", "tool"]) +MessageRole = Literal["assistant", "function", "tool"] class BaseMessage(BaseModel): @@ -88,24 +92,25 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None: class FunctionCallMessage(BaseMessage): - content: Optional[str] = None + content: Optional[str] = None # type: ignore [assignment] function_call: FunctionCall # ToDo: Does function call has context? context: Optional[dict[str, Any]] = None - llm_config: Union[dict, Literal[False]] + llm_config: Optional[Union[dict[str, Any], Literal[False]]] = None def print(self, f: Optional[Callable[..., Any]] = None) -> None: f = f or print super().print(f) if self.content is not None: - content = self.content - if self.context is not None: - content = OpenAIWrapper.instantiate( - content, - self.context, - self.llm_config and self.llm_config.get("allow_format_str_template", False), - ) + allow_format_str_template = ( + self.llm_config.get("allow_format_str_template", False) if self.llm_config else False + ) + content = OpenAIWrapper.instantiate( + self.content, + self.context, + allow_format_str_template, + ) f(content_str(content), flush=True) self.function_call.print(f) @@ -115,7 +120,7 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None: class ToolCall(BaseModel): id: Optional[str] = None - function: Optional[FunctionCall] = None + function: FunctionCall type: str def print(self, f: Optional[Callable[..., Any]] = None) -> None: @@ -138,7 +143,7 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None: class ToolCallMessage(BaseMessage): - content: Optional[str] = None + content: Optional[str] = None # type: ignore [assignment] refusal: Optional[str] = None role: MessageRole audio: Optional[str] = None @@ -146,20 +151,21 @@ class ToolCallMessage(BaseMessage): tool_calls: list[ToolCall] # ToDo: Does tool calls has context? context: Optional[dict[str, Any]] = None - llm_config: Union[dict, Literal[False]] + llm_config: Optional[Union[dict[str, Any], Literal[False]]] = None def print(self, f: Optional[Callable[..., Any]] = None) -> None: f = f or print super().print(f) if self.content is not None: - content = self.content - if self.context is not None: - content = OpenAIWrapper.instantiate( - content, - self.context, - self.llm_config and self.llm_config.get("allow_format_str_template", False), - ) + allow_format_str_template = ( + self.llm_config.get("allow_format_str_template", False) if self.llm_config else False + ) + content = OpenAIWrapper.instantiate( + self.content, + self.context, + allow_format_str_template, + ) f(content_str(content), flush=True) for tool_call in self.tool_calls: @@ -169,32 +175,31 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None: class ContentMessage(BaseMessage): - content: Optional[Union[str, Callable[..., Any]]] = None + content: Optional[Union[str, Callable[..., Any]]] = None # type: ignore [assignment] context: Optional[dict[str, Any]] = None - llm_config: Union[dict, Literal[False]] + llm_config: Optional[Union[dict[str, Any], Literal[False]]] = None def print(self, f: Optional[Callable[..., Any]] = None) -> None: f = f or print super().print(f) if self.content is not None: - content = self.content - if self.context is not None: - content = OpenAIWrapper.instantiate( - content, - self.context, - self.llm_config and self.llm_config.get("allow_format_str_template", False), - ) + allow_format_str_template = ( + self.llm_config.get("allow_format_str_template", False) if self.llm_config else False + ) + content = OpenAIWrapper.instantiate( + self.content, + self.context, + allow_format_str_template, + ) f(content_str(content), flush=True) f("\n", "-" * 80, flush=True, sep="") -def create_message_model(message: Union[dict[str, Any], str], sender: Agent, receiver: Agent) -> BaseMessage: - print(f"{message=}") - print(f"{sender=}") - if isinstance(message, str): - return +def create_message_model(message: dict[str, Any], sender: Agent, receiver: Agent) -> BaseMessage: + # print(f"{message=}") + # print(f"{sender=}") role = message.get("role") if role == "function": @@ -206,16 +211,25 @@ def create_message_model(message: Union[dict[str, Any], str], sender: Agent, rec if "function_call" in message and message["function_call"]: return FunctionCallMessage( - **message, sender_name=sender.name, receiver_name=receiver.name, llm_config=receiver.llm_config + **message, + sender_name=sender.name, + receiver_name=receiver.name, + llm_config=receiver.llm_config, # type: ignore [attr-defined] ) if "tool_calls" in message and message["tool_calls"]: return ToolCallMessage( - **message, sender_name=sender.name, receiver_name=receiver.name, llm_config=receiver.llm_config + **message, + sender_name=sender.name, + receiver_name=receiver.name, + llm_config=receiver.llm_config, # type: ignore [attr-defined] ) # Now message is a simple content message return ContentMessage( - **message, sender_name=sender.name, receiver_name=receiver.name, llm_config=receiver.llm_config + **message, + sender_name=sender.name, + receiver_name=receiver.name, + llm_config=receiver.llm_config, # type: ignore [attr-defined] ) diff --git a/test/test_messages.py b/test/test_messages.py index 837056e2bc..24c68d6337 100644 --- a/test/test_messages.py +++ b/test/test_messages.py @@ -1,3 +1,7 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + from unittest.mock import MagicMock import pytest