diff --git a/autogen/_pydantic.py b/autogen/_pydantic.py
index a7caffe1d9..08dbab2eef 100644
--- a/autogen/_pydantic.py
+++ b/autogen/_pydantic.py
@@ -4,7 +4,7 @@
 #
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
-from typing import Any, Dict, Optional, Tuple, Type, Union, get_args
+from typing import Any, Tuple, TypeVar, Union, get_args

 from pydantic import BaseModel
 from pydantic.version import VERSION as PYDANTIC_VERSION
@@ -30,7 +30,7 @@ def type2schema(t: Any) -> JsonSchemaValue:
         """
         return TypeAdapter(t).json_schema()

-    def model_dump(model: BaseModel) -> Dict[str, Any]:
+    def model_dump(model: BaseModel) -> dict[str, Any]:
         """Convert a pydantic model to a dict

         Args:
@@ -59,7 +59,7 @@ def model_dump_json(model: BaseModel) -> str:
     from pydantic import schema_of
     from pydantic.typing import evaluate_forwardref as evaluate_forwardref  # type: ignore[no-redef]

-    JsonSchemaValue = Dict[str, Any]  # type: ignore[misc]
+    JsonSchemaValue = dict[str, Any]  # type: ignore[misc]

     def type2schema(t: Any) -> JsonSchemaValue:
         """Convert a type to a JSON schema
@@ -92,7 +92,7 @@ def type2schema(t: Any) -> JsonSchemaValue:

         return d

-    def model_dump(model: BaseModel) -> Dict[str, Any]:
+    def model_dump(model: BaseModel) -> dict[str, Any]:
         """Convert a pydantic model to a dict

         Args:
diff --git a/autogen/agentchat/agent.py b/autogen/agentchat/agent.py
index 3f2a494564..655ad388f1 100644
--- a/autogen/agentchat/agent.py
+++ b/autogen/agentchat/agent.py
@@ -28,7 +28,7 @@ def description(self) -> str:

     def send(
         self,
-        message: Union[Dict[str, Any], str],
+        message: Union[dict[str, Any], str],
         recipient: "Agent",
         request_reply: Optional[bool] = None,
     ) -> None:
@@ -44,7 +44,7 @@ def send(

     async def a_send(
         self,
-        message: Union[Dict[str, Any], str],
+        message: Union[dict[str, Any], str],
         recipient: "Agent",
         request_reply: Optional[bool] = None,
     ) -> None:
@@ -60,7 +60,7 @@ async def a_send(

     def receive(
         self,
-        message: Union[Dict[str, Any], str],
+        message: Union[dict[str, Any], str],
         sender: "Agent",
         request_reply: Optional[bool] = None,
     ) -> None:
@@ -75,7 +75,7 @@ def receive(

     async def a_receive(
         self,
-        message: Union[Dict[str, Any], str],
+        message: Union[dict[str, Any], str],
         sender: "Agent",
         request_reply: Optional[bool] = None,
     ) -> None:
@@ -91,10 +91,10 @@ async def a_receive(

     def generate_reply(
         self,
-        messages: Optional[List[Dict[str, Any]]] = None,
+        messages: Optional[list[dict[str, Any]]] = None,
         sender: Optional["Agent"] = None,
         **kwargs: Any,
-    ) -> Union[str, Dict[str, Any], None]:
+    ) -> Union[str, dict[str, Any], None]:
         """Generate a reply based on the received messages.

         Args:
@@ -109,10 +109,10 @@ def generate_reply(

     async def a_generate_reply(
         self,
-        messages: Optional[List[Dict[str, Any]]] = None,
+        messages: Optional[list[dict[str, Any]]] = None,
         sender: Optional["Agent"] = None,
         **kwargs: Any,
-    ) -> Union[str, Dict[str, Any], None]:
+    ) -> Union[str, dict[str, Any], None]:
         """(Async) Generate a reply based on the received messages.

         Args:
diff --git a/autogen/agentchat/assistant_agent.py b/autogen/agentchat/assistant_agent.py
index abae2fb9c2..f87fc9dcd9 100644
--- a/autogen/agentchat/assistant_agent.py
+++ b/autogen/agentchat/assistant_agent.py
@@ -41,8 +41,8 @@ def __init__(
         self,
         name: str,
         system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
-        llm_config: Optional[Union[Dict, Literal[False]]] = None,
-        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
+        llm_config: Optional[Union[dict, Literal[False]]] = None,
+        is_termination_msg: Optional[Callable[[dict], bool]] = None,
         max_consecutive_auto_reply: Optional[int] = None,
         human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
         description: Optional[str] = None,
diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py
index f105b63a31..5f1e18a511 100644
--- a/autogen/agentchat/chat.py
+++ b/autogen/agentchat/chat.py
@@ -18,7 +18,7 @@
 from .utils import consolidate_chat_info

 logger = logging.getLogger(__name__)
-Prerequisite = Tuple[int, int]
+Prerequisite = tuple[int, int]


 @dataclass
@@ -27,21 +27,21 @@ class ChatResult:
     chat_id: int = None
     """chat id"""
-    chat_history: List[Dict[str, Any]] = None
+    chat_history: list[dict[str, Any]] = None
     """The chat history."""
     summary: str = None
     """A summary obtained from the chat."""
-    cost: Dict[str, dict] = None  # keys: "usage_including_cached_inference", "usage_excluding_cached_inference"
+    cost: dict[str, dict] = None  # keys: "usage_including_cached_inference", "usage_excluding_cached_inference"
     """The cost of the chat.
     The value for each usage type is a dictionary containing cost information for that specific type.
         - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
        - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
     """
-    human_input: List[str] = None
+    human_input: list[str] = None
     """A list of human input solicited during the chat."""


-def _validate_recipients(chat_queue: List[Dict[str, Any]]) -> None:
+def _validate_recipients(chat_queue: list[dict[str, Any]]) -> None:
     """
     Validate recipients exist and warn about repetitive recipients.
""" @@ -56,7 +56,7 @@ def _validate_recipients(chat_queue: List[Dict[str, Any]]) -> None: ) -def __create_async_prerequisites(chat_queue: List[Dict[str, Any]]) -> List[Prerequisite]: +def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[Prerequisite]: """ Create list of Prerequisite (prerequisite_chat_id, chat_id) """ @@ -73,7 +73,7 @@ def __create_async_prerequisites(chat_queue: List[Dict[str, Any]]) -> List[Prere return prerequisites -def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite]) -> List[int]: +def __find_async_chat_order(chat_ids: set[int], prerequisites: list[Prerequisite]) -> list[int]: """Find chat order for async execution based on the prerequisite chats args: @@ -122,7 +122,7 @@ def _post_process_carryover_item(carryover_item): return str(carryover_item) -def __post_carryover_processing(chat_info: Dict[str, Any]) -> None: +def __post_carryover_processing(chat_info: dict[str, Any]) -> None: iostream = IOStream.get_default() if "message" not in chat_info: @@ -158,7 +158,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None: iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") -def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: +def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]: """Initiate a list of chats. Args: chat_queue (List[Dict]): A list of dictionaries containing the information about the chats. @@ -234,7 +234,7 @@ def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int): async def _dependent_chat_future( - chat_id: int, chat_info: Dict[str, Any], prerequisite_chat_futures: Dict[int, asyncio.Future] + chat_id: int, chat_info: dict[str, Any], prerequisite_chat_futures: dict[int, asyncio.Future] ) -> asyncio.Task: """ Create an async Task for each chat. @@ -272,7 +272,7 @@ async def _dependent_chat_future( return chat_res_future -async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]: +async def a_initiate_chats(chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]: """(async) Initiate a list of chats. args: diff --git a/autogen/agentchat/contrib/agent_builder.py b/autogen/agentchat/contrib/agent_builder.py index 822e7176dc..1c235e52e7 100644 --- a/autogen/agentchat/contrib/agent_builder.py +++ b/autogen/agentchat/contrib/agent_builder.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -def _config_check(config: Dict): +def _config_check(config: dict): # check config loading assert config.get("coding", None) is not None, 'Missing "coding" in your config.' assert config.get("default_llm_config", None) is not None, 'Missing "default_llm_config" in your config.' 
@@ -220,11 +220,11 @@ def __init__(
         self.config_file_location = config_file_location

         self.building_task: str = None
-        self.agent_configs: List[Dict] = []
-        self.open_ports: List[str] = []
-        self.agent_procs: Dict[str, Tuple[sp.Popen, str]] = {}
-        self.agent_procs_assign: Dict[str, Tuple[autogen.ConversableAgent, str]] = {}
-        self.cached_configs: Dict = {}
+        self.agent_configs: list[dict] = []
+        self.open_ports: list[str] = []
+        self.agent_procs: dict[str, tuple[sp.Popen, str]] = {}
+        self.agent_procs_assign: dict[str, tuple[autogen.ConversableAgent, str]] = {}
+        self.cached_configs: dict = {}

         self.max_agents = max_agents
@@ -236,8 +236,8 @@ def set_agent_model(self, model: str):

     def _create_agent(
         self,
-        agent_config: Dict,
-        member_name: List[str],
+        agent_config: dict,
+        member_name: list[str],
         llm_config: dict,
         use_oai_assistant: Optional[bool] = False,
     ) -> autogen.AssistantAgent:
@@ -366,14 +366,14 @@ def clear_all_agents(self, recycle_endpoint: Optional[bool] = True):
     def build(
         self,
         building_task: str,
-        default_llm_config: Dict,
+        default_llm_config: dict,
         coding: Optional[bool] = None,
-        code_execution_config: Optional[Dict] = None,
+        code_execution_config: Optional[dict] = None,
         use_oai_assistant: Optional[bool] = False,
         user_proxy: Optional[autogen.ConversableAgent] = None,
         max_agents: Optional[int] = None,
         **kwargs,
-    ) -> Tuple[List[autogen.ConversableAgent], Dict]:
+    ) -> tuple[list[autogen.ConversableAgent], dict]:
         """
         Auto build agents based on the building task.
@@ -496,15 +496,15 @@ def build_from_library(
         self,
         building_task: str,
         library_path_or_json: str,
-        default_llm_config: Dict,
+        default_llm_config: dict,
         top_k: int = 3,
         coding: Optional[bool] = None,
-        code_execution_config: Optional[Dict] = None,
+        code_execution_config: Optional[dict] = None,
         use_oai_assistant: Optional[bool] = False,
         embedding_model: Optional[str] = "all-mpnet-base-v2",
         user_proxy: Optional[autogen.ConversableAgent] = None,
         **kwargs,
-    ) -> Tuple[List[autogen.ConversableAgent], Dict]:
+    ) -> tuple[list[autogen.ConversableAgent], dict]:
         """
         Build agents from a library. The library is a list of agent configs, which contains the name and system_message for each agent.
@@ -551,7 +551,7 @@ def build_from_library(
         try:
             agent_library = json.loads(library_path_or_json)
         except json.decoder.JSONDecodeError:
-            with open(library_path_or_json, "r") as f:
+            with open(library_path_or_json) as f:
                 agent_library = json.load(f)
         except Exception as e:
             raise e
@@ -663,7 +663,7 @@ def build_from_library(
     def _build_agents(
         self, use_oai_assistant: Optional[bool] = False, user_proxy: Optional[autogen.ConversableAgent] = None, **kwargs
-    ) -> Tuple[List[autogen.ConversableAgent], Dict]:
+    ) -> tuple[list[autogen.ConversableAgent], dict]:
         """
         Build agents with generated configs.
@@ -731,7 +731,7 @@ def load(
         config_json: Optional[str] = None,
         use_oai_assistant: Optional[bool] = False,
         **kwargs,
-    ) -> Tuple[List[autogen.ConversableAgent], Dict]:
+    ) -> tuple[list[autogen.ConversableAgent], dict]:
         """
         Load building configs and call the build function to complete building without calling online LLMs' api.
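
The same file also drops the redundant `"r"` from `open(library_path_or_json, "r")`: reading in text mode is `open()`'s default, so the two calls are identical. A small self-contained sketch; the file name is made up for the demo:

```python
import json
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), "agent_library.json")  # hypothetical library file
with open(path, "w") as f:
    json.dump([{"name": "data_scientist"}], f)

# open(path) is exactly open(path, "r"): mode "r" (read, text) is the default.
with open(path) as f:
    agent_library = json.load(f)

assert agent_library[0]["name"] == "data_scientist"
```
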
diff --git a/autogen/agentchat/contrib/agent_eval/agent_eval.py b/autogen/agentchat/contrib/agent_eval/agent_eval.py
index 479a58fc9c..d6f3711cbf 100644
--- a/autogen/agentchat/contrib/agent_eval/agent_eval.py
+++ b/autogen/agentchat/contrib/agent_eval/agent_eval.py
@@ -15,7 +15,7 @@

 def generate_criteria(
-    llm_config: Optional[Union[Dict, Literal[False]]] = None,
+    llm_config: Optional[Union[dict, Literal[False]]] = None,
     task: Task = None,
     additional_instructions: str = "",
     max_round=2,
@@ -67,8 +67,8 @@ def generate_criteria(

 def quantify_criteria(
-    llm_config: Optional[Union[Dict, Literal[False]]] = None,
-    criteria: List[Criterion] = None,
+    llm_config: Optional[Union[dict, Literal[False]]] = None,
+    criteria: list[Criterion] = None,
     task: Task = None,
     test_case: str = "",
     ground_truth: str = "",
diff --git a/autogen/agentchat/contrib/agent_eval/criterion.py b/autogen/agentchat/contrib/agent_eval/criterion.py
index 9d089d08bb..9e682fcc95 100644
--- a/autogen/agentchat/contrib/agent_eval/criterion.py
+++ b/autogen/agentchat/contrib/agent_eval/criterion.py
@@ -21,8 +21,8 @@ class Criterion(BaseModel):

     name: str
     description: str
-    accepted_values: List[str]
-    sub_criteria: List[Criterion] = list()
+    accepted_values: list[str]
+    sub_criteria: list[Criterion] = list()

     @staticmethod
     def parse_json_str(criteria: str):
diff --git a/autogen/agentchat/contrib/agent_optimizer.py b/autogen/agentchat/contrib/agent_optimizer.py
index 2257cda69f..7291e5e4cd 100644
--- a/autogen/agentchat/contrib/agent_optimizer.py
+++ b/autogen/agentchat/contrib/agent_optimizer.py
@@ -217,7 +217,7 @@ def __init__(
         )
         self._client = autogen.OpenAIWrapper(**self.llm_config)

-    def record_one_conversation(self, conversation_history: List[Dict], is_satisfied: bool = None):
+    def record_one_conversation(self, conversation_history: list[dict], is_satisfied: bool = None):
         """
         Record one conversation history.

         Args:
@@ -234,10 +234,10 @@ def record_one_conversation(self, conversation_history: List[Dict], is_satisfied
         ], "The input is invalid. Please input 1 or 0. 1 represents satisfied. 0 represents not satisfied."
         is_satisfied = True if reply == "1" else False
         self._trial_conversations_history.append(
-            {"Conversation {i}".format(i=len(self._trial_conversations_history)): conversation_history}
+            {f"Conversation {len(self._trial_conversations_history)}": conversation_history}
         )
         self._trial_conversations_performance.append(
-            {"Conversation {i}".format(i=len(self._trial_conversations_performance)): 1 if is_satisfied else 0}
+            {f"Conversation {len(self._trial_conversations_performance)}": 1 if is_satisfied else 0}
         )

     def step(self):
@@ -290,8 +290,8 @@ def step(self):
             incumbent_functions = self._update_function_call(incumbent_functions, actions)

         remove_functions = list(
-            set([key for dictionary in self._trial_functions for key in dictionary.keys()])
-            - set([key for dictionary in incumbent_functions for key in dictionary.keys()])
+            {key for dictionary in self._trial_functions for key in dictionary.keys()}
+            - {key for dictionary in incumbent_functions for key in dictionary.keys()}
         )

         register_for_llm = []
diff --git a/autogen/agentchat/contrib/capabilities/generate_images.py b/autogen/agentchat/contrib/capabilities/generate_images.py
index 2dc9f22a2f..429a466945 100644
--- a/autogen/agentchat/contrib/capabilities/generate_images.py
+++ b/autogen/agentchat/contrib/capabilities/generate_images.py
@@ -73,7 +73,7 @@ class DalleImageGenerator:

     def __init__(
         self,
-        llm_config: Dict,
+        llm_config: dict,
         resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
         quality: Literal["standard", "hd"] = "standard",
         num_images: int = 1,
@@ -149,7 +149,7 @@ def __init__(
         self,
         image_generator: ImageGenerator,
         cache: Optional[AbstractCache] = None,
-        text_analyzer_llm_config: Optional[Dict] = None,
+        text_analyzer_llm_config: Optional[dict] = None,
         text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
         verbosity: int = 0,
         register_reply_position: int = 2,
@@ -212,10 +212,10 @@ def add_to_agent(self, agent: ConversableAgent):
     def _image_gen_reply(
         self,
         recipient: ConversableAgent,
-        messages: Optional[List[Dict]],
+        messages: Optional[list[dict]],
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         if messages is None:
             return False, None
@@ -268,13 +268,13 @@ def _cache_set(self, prompt: str, image: Image):
             key = self._image_generator.cache_key(prompt)
             self._cache.set(key, img_utils.pil_to_data_uri(image))

-    def _extract_analysis(self, analysis: Union[str, Dict, None]) -> str:
-        if isinstance(analysis, Dict):
+    def _extract_analysis(self, analysis: Union[str, dict, None]) -> str:
+        if isinstance(analysis, dict):
             return code_utils.content_str(analysis["content"])
         else:
             return code_utils.content_str(analysis)

-    def _generate_content_message(self, prompt: str, image: Image) -> Dict[str, Any]:
+    def _generate_content_message(self, prompt: str, image: Image) -> dict[str, Any]:
         return {
             "content": [
                 {"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
diff --git a/autogen/agentchat/contrib/capabilities/teachability.py b/autogen/agentchat/contrib/capabilities/teachability.py
index ccbbfedebc..5429b3df03 100644
--- a/autogen/agentchat/contrib/capabilities/teachability.py
+++ b/autogen/agentchat/contrib/capabilities/teachability.py
@@ -42,7 +42,7 @@ def __init__(
         path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
         recall_threshold: Optional[float] = 1.5,
         max_num_retrievals: Optional[int] = 10,
-        llm_config: Optional[Union[Dict, bool]] = None,
+        llm_config: Optional[Union[dict, bool]] = None,
     ):
         """
         Args:
@@ -92,7 +92,7 @@ def prepopulate_db(self):
         """Adds a few arbitrary memos to the DB."""
         self.memo_store.prepopulate()

-    def process_last_received_message(self, text: Union[Dict, str]):
+    def process_last_received_message(self, text: Union[dict, str]):
         """
         Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
         Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
@@ -109,7 +109,7 @@ def process_last_received_message(self, text: Union[Dict, str]):
         # Return the (possibly) expanded message text.
         return expanded_text

-    def _consider_memo_storage(self, comment: Union[Dict, str]):
+    def _consider_memo_storage(self, comment: Union[dict, str]):
         """Decides whether to store something from one user comment in the DB."""
         memo_added = False
@@ -167,7 +167,7 @@ def _consider_memo_storage(self, comment: Union[Dict, str]):
             # Yes. Save them to disk.
             self.memo_store._save_memos()

-    def _consider_memo_retrieval(self, comment: Union[Dict, str]):
+    def _consider_memo_retrieval(self, comment: Union[dict, str]):
         """Decides whether to retrieve memos from the DB, and add them to the chat context."""

         # First, use the comment directly as the lookup key.
@@ -231,7 +231,7 @@ def _concatenate_memo_texts(self, memo_list: list) -> str:
                 memo_texts = memo_texts + "\n" + info
         return memo_texts

-    def _analyze(self, text_to_analyze: Union[Dict, str], analysis_instructions: Union[Dict, str]):
+    def _analyze(self, text_to_analyze: Union[dict, str], analysis_instructions: Union[dict, str]):
         """Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
         self.analyzer.reset()  # Clear the analyzer's list of messages.
         self.teachable_agent.send(
@@ -280,7 +280,7 @@ def __init__(
         self.last_memo_id = 0
         if (not reset) and os.path.exists(self.path_to_dict):
             print(colored("\nLOADING MEMORY FROM DISK", "light_green"))
-            print(colored(" Location = {}".format(self.path_to_dict), "light_green"))
+            print(colored(f" Location = {self.path_to_dict}", "light_green"))
             with open(self.path_to_dict, "rb") as f:
                 self.uid_text_dict = pickle.load(f)
                 self.last_memo_id = len(self.uid_text_dict)
@@ -298,7 +298,7 @@ def list_memos(self):
             input_text, output_text = text
             print(
                 colored(
-                    " ID: {}\n INPUT TEXT: {}\n OUTPUT TEXT: {}".format(uid, input_text, output_text),
+                    f" ID: {uid}\n INPUT TEXT: {input_text}\n OUTPUT TEXT: {output_text}",
                     "light_green",
                 )
             )
diff --git a/autogen/agentchat/contrib/capabilities/text_compressors.py b/autogen/agentchat/contrib/capabilities/text_compressors.py
index 290d9929b6..1e861c1170 100644
--- a/autogen/agentchat/contrib/capabilities/text_compressors.py
+++ b/autogen/agentchat/contrib/capabilities/text_compressors.py
@@ -19,7 +19,7 @@ class TextCompressor(Protocol):
     """Defines a protocol for text compression to optimize agent interactions."""

-    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
         """This method takes a string as input and returns a dictionary containing the compressed text and other
         relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary.
         To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys.
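
The `TextCompressor` protocol above only constrains the shape of the returned dictionary. To make that contract concrete, here is a toy implementation that would satisfy it; the class and its word-count "tokens" are invented for illustration and are not part of the library:

```python
from typing import Any


class TruncatingCompressor:
    """Toy compressor: keeps only the first `max_words` words of the input."""

    def __init__(self, max_words: int = 50):
        self._max_words = max_words

    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
        words = text.split()
        kept = words[: self._max_words]
        return {
            "compressed_text": " ".join(kept),  # required key per the protocol docstring
            "origin_tokens": len(words),        # used to compute token savings
            "compressed_tokens": len(kept),
        }


result = TruncatingCompressor(max_words=3).compress_text("one two three four five")
assert result["compressed_text"] == "one two three"
assert result["origin_tokens"] - result["compressed_tokens"] == 2
```
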
@@ -36,7 +36,7 @@ class LLMLingua:

     def __init__(
         self,
-        prompt_compressor_kwargs: Dict = dict(
+        prompt_compressor_kwargs: dict = dict(
             model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
             use_llmlingua2=True,
             device_map="cpu",
@@ -68,5 +68,5 @@ def __init__(
             else self._prompt_compressor.compress_prompt
         )

-    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+    def compress_text(self, text: str, **compression_params) -> dict[str, Any]:
         return self._compression_method([text], **compression_params)
diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py
index 9546433468..78b4478647 100644
--- a/autogen/agentchat/contrib/capabilities/transform_messages.py
+++ b/autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -47,7 +47,7 @@ class TransformMessages:
     ```
     """

-    def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True):
+    def __init__(self, *, transforms: list[MessageTransform] = [], verbose: bool = True):
         """
         Args:
             transforms: A list of message transformations to apply.
@@ -66,7 +66,7 @@ def add_to_agent(self, agent: ConversableAgent):
         """
         agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

-    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
+    def _transform_messages(self, messages: list[dict]) -> list[dict]:
         post_transform_messages = copy.deepcopy(messages)
         system_message = None
diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index a5912cb248..740a81c366 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -26,7 +26,7 @@ class MessageTransform(Protocol):
     that takes a list of messages and returns the transformed list.
     """

-    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+    def apply_transform(self, messages: list[dict]) -> list[dict]:
         """Applies a transformation to a list of messages.

         Args:
@@ -37,7 +37,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         """
         ...

-    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+    def get_logs(self, pre_transform_messages: list[dict], post_transform_messages: list[dict]) -> tuple[str, bool]:
         """Creates the string including the logs of the transformation

         Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.
@@ -70,7 +70,7 @@ def __init__(self, max_messages: Optional[int] = None, keep_first_message: bool
         self._max_messages = max_messages
         self._keep_first_message = keep_first_message

-    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+    def apply_transform(self, messages: list[dict]) -> list[dict]:
         """Truncates the conversation history to the specified maximum number of messages.
         This method returns a new list containing the most recent messages up to the specified
@@ -110,7 +110,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

         return truncated_messages

-    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+    def get_logs(self, pre_transform_messages: list[dict], post_transform_messages: list[dict]) -> tuple[str, bool]:
         pre_transform_messages_len = len(pre_transform_messages)
         post_transform_messages_len = len(post_transform_messages)
@@ -161,7 +161,7 @@ def __init__(
         max_tokens: Optional[int] = None,
         min_tokens: Optional[int] = None,
         model: str = "gpt-3.5-turbo-0613",
-        filter_dict: Optional[Dict] = None,
+        filter_dict: Optional[dict] = None,
         exclude_filter: bool = True,
     ):
         """
@@ -185,7 +185,7 @@ def __init__(
         self._filter_dict = filter_dict
         self._exclude_filter = exclude_filter

-    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+    def apply_transform(self, messages: list[dict]) -> list[dict]:
         """Applies token truncation to the conversation history.

         Args:
@@ -237,7 +237,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

         return processed_messages

-    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+    def get_logs(self, pre_transform_messages: list[dict], post_transform_messages: list[dict]) -> tuple[str, bool]:
         pre_transform_messages_tokens = sum(
             transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
         )
@@ -253,7 +253,7 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
             return logs_str, True
         return "No tokens were truncated.", False

-    def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
+    def _truncate_str_to_tokens(self, contents: Union[str, list], n_tokens: int) -> Union[str, list]:
         if isinstance(contents, str):
             return self._truncate_tokens(contents, n_tokens)
         elif isinstance(contents, list):
@@ -261,7 +261,7 @@ def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) ->
         else:
             raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

-    def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]:
+    def _truncate_multimodal_text(self, contents: list[dict[str, Any]], n_tokens: int) -> list[dict[str, Any]]:
         """Truncates text content within a list of multimodal elements, preserving the overall structure."""
         tmp_contents = []
         for content in contents:
@@ -324,9 +324,9 @@ def __init__(
         self,
         text_compressor: Optional[TextCompressor] = None,
         min_tokens: Optional[int] = None,
-        compression_params: Dict = dict(),
+        compression_params: dict = dict(),
         cache: Optional[AbstractCache] = None,
-        filter_dict: Optional[Dict] = None,
+        filter_dict: Optional[dict] = None,
         exclude_filter: bool = True,
     ):
         """
@@ -364,7 +364,7 @@ def __init__(
         # Optimizing savings calculations to optimize log generation
         self._recent_tokens_savings = 0

-    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+    def apply_transform(self, messages: list[dict]) -> list[dict]:
         """Applies compression to messages in a conversation history based on the specified configuration.
         The function processes each message according to the `compression_args` and `min_tokens` settings, applying
@@ -414,13 +414,13 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
             self._recent_tokens_savings = total_savings
         return processed_messages

-    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+    def get_logs(self, pre_transform_messages: list[dict], post_transform_messages: list[dict]) -> tuple[str, bool]:
         if self._recent_tokens_savings > 0:
             return f"{self._recent_tokens_savings} tokens saved with text compression.", True
         else:
             return "No tokens saved with text compression.", False

-    def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
+    def _compress(self, content: MessageContentType) -> tuple[MessageContentType, int]:
         """Compresses the given text or multimodal content using the specified compression method."""
         if isinstance(content, str):
             return self._compress_text(content)
@@ -429,7 +429,7 @@ def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, in
         else:
             return content, 0

-    def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
+    def _compress_multimodal(self, content: MessageContentType) -> tuple[MessageContentType, int]:
         tokens_saved = 0
         for item in content:
             if isinstance(item, dict) and "text" in item:
@@ -442,7 +442,7 @@ def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageCont

         return content, tokens_saved

-    def _compress_text(self, text: str) -> Tuple[str, int]:
+    def _compress_text(self, text: str) -> tuple[str, int]:
         """Compresses the given text using the specified compression method."""
         compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
@@ -483,7 +483,7 @@ def __init__(
         position: str = "start",
         format_string: str = "{name}:\n",
         deduplicate: bool = True,
-        filter_dict: Optional[Dict] = None,
+        filter_dict: Optional[dict] = None,
         exclude_filter: bool = True,
     ):
         """
@@ -510,7 +510,7 @@ def __init__(
         # Track the number of messages changed for logging
         self._messages_changed = 0

-    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+    def apply_transform(self, messages: list[dict]) -> list[dict]:
         """Applies the name change to the message based on the position and format string.

         Args:
@@ -558,7 +558,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         self._messages_changed = messages_changed
         return processed_messages

-    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+    def get_logs(self, pre_transform_messages: list[dict], post_transform_messages: list[dict]) -> tuple[str, bool]:
         if self._messages_changed > 0:
             return f"{self._messages_changed} message(s) changed to incorporate name.", True
         else:
diff --git a/autogen/agentchat/contrib/capabilities/transforms_util.py b/autogen/agentchat/contrib/capabilities/transforms_util.py
index 279054f2f1..62decfa091 100644
--- a/autogen/agentchat/contrib/capabilities/transforms_util.py
+++ b/autogen/agentchat/contrib/capabilities/transforms_util.py
@@ -4,7 +4,8 @@
 #
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
-from typing import Any, Dict, Hashable, List, Optional, Tuple
+from collections.abc import Hashable
+from typing import Any, Dict, List, Optional, Tuple

 from autogen import token_count_utils
 from autogen.cache.abstract_cache_base import AbstractCache
@@ -23,7 +24,7 @@ def cache_key(content: MessageContentType, *args: Hashable) -> str:
     return "".join(str_keys)


-def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
+def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[tuple[MessageContentType, ...]]:
     """Retrieves cached content from the cache.

     Args:
@@ -50,7 +51,7 @@ def cache_content_set(cache: Optional[AbstractCache], key: str, content: Message
         cache.set(key, cache_value)


-def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
+def min_tokens_reached(messages: list[dict], min_tokens: Optional[int]) -> bool:
     """Returns True if the total number of tokens in the messages is greater than or equal to the specified value.

     Args:
@@ -106,7 +107,7 @@ def is_content_text_empty(content: MessageContentType) -> bool:
     return True


-def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
+def should_transform_message(message: dict[str, Any], filter_dict: Optional[dict[str, Any]], exclude: bool) -> bool:
     """Validates whether the transform should be applied according to the filter dictionary.

     Args:
diff --git a/autogen/agentchat/contrib/capabilities/vision_capability.py b/autogen/agentchat/contrib/capabilities/vision_capability.py
index ec227391b6..c0fe7a53eb 100644
--- a/autogen/agentchat/contrib/capabilities/vision_capability.py
+++ b/autogen/agentchat/contrib/capabilities/vision_capability.py
@@ -49,7 +49,7 @@ class VisionCapability(AgentCapability):

     def __init__(
         self,
-        lmm_config: Dict,
+        lmm_config: dict,
         description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
         custom_caption_func: Callable = None,
     ) -> None:
@@ -105,7 +105,7 @@ def add_to_agent(self, agent: ConversableAgent) -> None:
         # Register a hook for processing the last message.
         agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

-    def process_last_received_message(self, content: Union[str, List[dict]]) -> str:
+    def process_last_received_message(self, content: Union[str, list[dict]]) -> str:
         """
         Processes the last received message content by normalizing and augmenting it
         with descriptions of any included images.
         The function supports input content
diff --git a/autogen/agentchat/contrib/captainagent.py b/autogen/agentchat/contrib/captainagent.py
index 3e15d576d1..0229db02fa 100644
--- a/autogen/agentchat/contrib/captainagent.py
+++ b/autogen/agentchat/contrib/captainagent.py
@@ -135,12 +135,12 @@ def __init__(
         self,
         name: str,
         system_message: Optional[str] = None,
-        llm_config: Optional[Union[Dict, Literal[False]]] = None,
-        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
+        llm_config: Optional[Union[dict, Literal[False]]] = None,
+        is_termination_msg: Optional[Callable[[dict], bool]] = None,
         max_consecutive_auto_reply: Optional[int] = None,
         human_input_mode: Optional[str] = "NEVER",
-        code_execution_config: Optional[Union[Dict, Literal[False]]] = False,
-        nested_config: Optional[Dict] = None,
+        code_execution_config: Optional[Union[dict, Literal[False]]] = False,
+        nested_config: Optional[dict] = None,
         agent_lib: Optional[str] = None,
         tool_lib: Optional[str] = None,
         agent_config_save_path: Optional[str] = None,
@@ -220,7 +220,7 @@ def __init__(
         )

     @staticmethod
-    def _update_config(default_dict: Dict, update_dict: Optional[Dict]) -> Dict:
+    def _update_config(default_dict: dict, update_dict: Optional[dict]) -> dict:
         """
         Recursively updates the default_dict with values from update_dict.
         """
@@ -290,15 +290,15 @@ class CaptainUserProxyAgent(ConversableAgent):
     def __init__(
         self,
         name: str,
-        nested_config: Dict,
+        nested_config: dict,
         agent_config_save_path: str = None,
-        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
+        is_termination_msg: Optional[Callable[[dict], bool]] = None,
         max_consecutive_auto_reply: Optional[int] = None,
         human_input_mode: Optional[str] = "NEVER",
-        code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
-        default_auto_reply: Optional[Union[str, Dict, None]] = DEFAULT_AUTO_REPLY,
-        llm_config: Optional[Union[Dict, Literal[False]]] = False,
-        system_message: Optional[Union[str, List]] = "",
+        code_execution_config: Optional[Union[dict, Literal[False]]] = None,
+        default_auto_reply: Optional[Union[str, dict, None]] = DEFAULT_AUTO_REPLY,
+        llm_config: Optional[Union[dict, Literal[False]]] = False,
+        system_message: Optional[Union[str, list]] = "",
         description: Optional[str] = None,
     ):
         """
diff --git a/autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py b/autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py
index 24fce8edf1..95f2be6dfb 100644
--- a/autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py
+++ b/autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py
@@ -39,7 +39,7 @@ def image_processing(img):

 def text_processing(file_path):
     # Check the file extension
     if file_path.endswith(".txt"):
-        with open(file_path, "r") as file:
+        with open(file_path) as file:
             content = file.read()
     else:
         # if the file is not .txt, then it is a string, directly return the string
diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py
index 2e818c6365..4c2ac731f7 100644
--- a/autogen/agentchat/contrib/gpt_assistant_agent.py
+++ b/autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -32,8 +32,8 @@ def __init__(
         self,
         name="GPT Assistant",
         instructions: Optional[str] = None,
-        llm_config: Optional[Union[Dict, bool]] = None,
-        assistant_config: Optional[Dict] = None,
+        llm_config: Optional[Union[dict, bool]] = None,
+        assistant_config: Optional[dict] = None,
         overwrite_instructions: bool = False,
         overwrite_tools: bool = False,
         **kwargs,
@@ -184,10 +184,10 @@ def __init__(

     def _invoke_assistant(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         """
         Invokes the OpenAI assistant to generate a reply based on the given messages.
@@ -441,7 +441,7 @@ def pretty_print_thread(self, thread):
         print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

     @property
-    def oai_threads(self) -> Dict[Agent, Any]:
+    def oai_threads(self) -> dict[Agent, Any]:
         """Return the threads of the agent."""
         return self._openai_threads
@@ -475,15 +475,15 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools):
         matching_assistants = []

         # Preprocess the required tools for faster comparison
-        required_tool_types = set(
+        required_tool_types = {
             "file_search" if tool.get("type") in ["retrieval", "file_search"] else tool.get("type") for tool in tools
-        )
+        }

-        required_function_names = set(
+        required_function_names = {
             tool.get("function", {}).get("name")
             for tool in tools
             if tool.get("type") not in ["code_interpreter", "retrieval", "file_search"]
-        )
+        }

         for assistant in candidate_assistants:
             # Check if instructions are similar
@@ -496,10 +496,10 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools):
                 continue

             # Preprocess the assistant's tools
-            assistant_tool_types = set(
+            assistant_tool_types = {
                 "file_search" if tool.type in ["retrieval", "file_search"] else tool.type for tool in assistant.tools
-            )
-            assistant_function_names = set(tool.function.name for tool in assistant.tools if hasattr(tool, "function"))
+            }
+            assistant_function_names = {tool.function.name for tool in assistant.tools if hasattr(tool, "function")}

             # Check if the tool types, function names match
             if required_tool_types != assistant_tool_types or required_function_names != assistant_function_names:
diff --git a/autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py b/autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py
index d374c9ed46..607a2e3215 100644
--- a/autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py
+++ b/autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py
@@ -88,7 +88,7 @@ def connect_db(self):
         else:
             raise ValueError(f"Knowledge graph '{self.name}' does not exist")

-    def init_db(self, input_doc: List[Document]):
+    def init_db(self, input_doc: list[Document]):
         """
         Build the knowledge graph with input documents.
""" @@ -124,7 +124,7 @@ def init_db(self, input_doc: List[Document]): else: raise ValueError("No input documents could be loaded.") - def add_records(self, new_records: List) -> bool: + def add_records(self, new_records: list) -> bool: raise NotImplementedError("This method is not supported by FalkorDB SDK yet.") def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult: @@ -168,12 +168,12 @@ def _save_ontology_to_db(self, ontology: Ontology): Save graph ontology to a separate table with {graph_name}_ontology """ if self.ontology_table_name in self.falkordb.list_graphs(): - raise ValueError("Knowledge graph {} is already created.".format(self.name)) + raise ValueError(f"Knowledge graph {self.name} is already created.") graph = self.__get_ontology_storage_graph() ontology.save_to_graph(graph) def _load_ontology_from_db(self) -> Ontology: if self.ontology_table_name not in self.falkordb.list_graphs(): - raise ValueError("Knowledge graph {} has not been created.".format(self.name)) + raise ValueError(f"Knowledge graph {self.name} has not been created.") graph = self.__get_ontology_storage_graph() return Ontology.from_graph(graph) diff --git a/autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py b/autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py index b5432403c5..fd6eb1a5d4 100644 --- a/autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +++ b/autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py @@ -47,10 +47,10 @@ def add_to_agent(self, agent: UserProxyAgent): def _reply_using_falkordb_query( self, recipient: ConversableAgent, - messages: Optional[List[Dict]] = None, + messages: Optional[list[dict]] = None, sender: Optional[Agent] = None, config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: + ) -> tuple[bool, Union[str, dict, None]]: """ Query FalkorDB and return the message. Internally, it utilises OpenAI to generate a reply based on the given messages. The history with FalkorDB is also logged and updated. @@ -74,7 +74,7 @@ def _reply_using_falkordb_query( return True, result.answer if result.answer else "I'm sorry, I don't have an answer for that." - def _messages_summary(self, messages: Union[Dict, str], system_message: str) -> str: + def _messages_summary(self, messages: Union[dict, str], system_message: str) -> str: """Summarize the messages in the conversation history. Excluding any message with 'tool_calls' and 'tool_responses' Includes the 'name' (if it exists) and the 'content', with a new line between each one, like: customer: @@ -90,7 +90,7 @@ def _messages_summary(self, messages: Union[Dict, str], system_message: str) -> else: return messages - elif isinstance(messages, List): + elif isinstance(messages, list): summary = "" for message in messages: if "content" in message and "tool_calls" not in message and "tool_responses" not in message: diff --git a/autogen/agentchat/contrib/graph_rag/graph_query_engine.py b/autogen/agentchat/contrib/graph_rag/graph_query_engine.py index b15866f2db..b10562e7ee 100644 --- a/autogen/agentchat/contrib/graph_rag/graph_query_engine.py +++ b/autogen/agentchat/contrib/graph_rag/graph_query_engine.py @@ -29,7 +29,7 @@ class GraphQueryEngine(Protocol): This interface defines the basic methods for graph-based RAG. """ - def init_db(self, input_doc: List[Document] | None = None): + def init_db(self, input_doc: list[Document] | None = None): """ This method initializes graph database with the input documents or records. 
         Usually, it takes the following steps,
@@ -43,7 +43,7 @@ def init_db(self, input_doc: List[Document] | None = None):
         """
         pass

-    def add_records(self, new_records: List) -> bool:
+    def add_records(self, new_records: list) -> bool:
         """
         Add new records to the underlying database and add to the graph if required.
         """
diff --git a/autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py b/autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py
index 462371930b..a91a3f23e6 100644
--- a/autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py
+++ b/autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py
@@ -52,7 +52,7 @@ def __init__(
         embedding: BaseEmbedding = OpenAIEmbedding(model_name="text-embedding-3-small"),
         entities: Optional[TypeAlias] = None,
         relations: Optional[TypeAlias] = None,
-        schema: Optional[Union[Dict[str, str], List[Triple]]] = None,
+        schema: Optional[Union[dict[str, str], list[Triple]]] = None,
         strict: Optional[bool] = False,
     ):
         """
@@ -85,7 +85,7 @@ def __init__(
         self.schema = schema
         self.strict = strict

-    def init_db(self, input_doc: List[Document] | None = None):
+    def init_db(self, input_doc: list[Document] | None = None):
         """
         Build the knowledge graph with input documents.
         """
@@ -133,7 +133,7 @@ def connect_db(self):
             show_progress=True,
         )

-    def add_records(self, new_records: List) -> bool:
+    def add_records(self, new_records: list) -> bool:
         """
         Add new records to the knowledge graph. Must be local files.
@@ -195,7 +195,7 @@ def _clear(self) -> None:
         with self.graph_store._driver.session() as session:
             session.run("MATCH (n) DETACH DELETE n;")

-    def _load_doc(self, input_doc: List[Document]) -> List[Document]:
+    def _load_doc(self, input_doc: list[Document]) -> list[Document]:
         """
         Load documents from the input files.
         """
diff --git a/autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py b/autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py
index c4d952437d..fea72719ed 100644
--- a/autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py
+++ b/autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py
@@ -49,10 +49,10 @@ def add_to_agent(self, agent: UserProxyAgent):
     def _reply_using_neo4j_query(
         self,
         recipient: ConversableAgent,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         """
         Query neo4j and return the message. Internally, it queries the Property graph
         and returns the answer from the graph query engine.
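
Note that `init_db` in these graph engines already mixes PEP 604 union syntax (`X | None`) with PEP 585 builtins (`list[Document]`). Evaluated eagerly, that combination needs Python 3.10+; with postponed evaluation of annotations it parses on older interpreters too, since the annotation is stored as a string and never executed. A sketch of the distinction; the `Document` stub is illustrative:

```python
from __future__ import annotations  # postpone evaluation of annotations (PEP 563)


class Document:  # stand-in for the library's Document type
    pass


# With postponed evaluation, this annotation is never evaluated at runtime,
# so `list[Document] | None` is accepted even on interpreters older than 3.10.
def init_db(input_doc: list[Document] | None = None) -> int:
    return len(input_doc or [])


assert init_db() == 0
assert init_db([Document()]) == 1
```
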
@@ -73,11 +73,11 @@ def _reply_using_neo4j_query(

         return True, result.answer

-    def _get_last_question(self, message: Union[Dict, str]):
+    def _get_last_question(self, message: Union[dict, str]):
         """Retrieves the last message from the conversation history."""
         if isinstance(message, str):
             return message
-        if isinstance(message, Dict):
+        if isinstance(message, dict):
             if "content" in message:
                 return message["content"]
         return None
diff --git a/autogen/agentchat/contrib/img_utils.py b/autogen/agentchat/contrib/img_utils.py
index 9b9a01b89c..1cf718d53a 100644
--- a/autogen/agentchat/contrib/img_utils.py
+++ b/autogen/agentchat/contrib/img_utils.py
@@ -107,7 +107,7 @@ def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
     return content


-def llava_formatter(prompt: str, order_image_tokens: bool = False) -> Tuple[str, List[str]]:
+def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]:
     """
     Formats the input prompt by replacing image tags and returns the new prompt along with image locations.
@@ -189,7 +189,7 @@ def _get_mime_type_from_data_uri(base64_image):
     return data_uri


-def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dict]]:
+def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict]]:
     """
     Formats the input prompt by replacing image tags and returns a list of text and images.
@@ -274,7 +274,7 @@ def _to_pil(data: str) -> Image.Image:
     return Image.open(BytesIO(base64.b64decode(data)))


-def message_formatter_pil_to_b64(messages: List[Dict]) -> List[Dict]:
+def message_formatter_pil_to_b64(messages: list[dict]) -> list[dict]:
     """
     Converts the PIL image URLs in the messages to base64 encoded data URIs.
diff --git a/autogen/agentchat/contrib/llamaindex_conversable_agent.py b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
index c1a51cc491..d563a525dd 100644
--- a/autogen/agentchat/contrib/llamaindex_conversable_agent.py
+++ b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
@@ -80,10 +80,10 @@ def __init__(

     def _generate_oai_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[OpenAIWrapper] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         """Generate a reply using autogen.oai."""
         user_message, history = self._extract_message_and_history(messages=messages, sender=sender)
@@ -95,10 +95,10 @@ def _generate_oai_reply(

     async def _a_generate_oai_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[OpenAIWrapper] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         """Generate a reply using autogen.oai."""
         user_message, history = self._extract_message_and_history(messages=messages, sender=sender)
@@ -111,8 +111,8 @@ async def _a_generate_oai_reply(
         return (True, extracted_response)

     def _extract_message_and_history(
-        self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None
-    ) -> Tuple[str, List[ChatMessage]]:
+        self, messages: Optional[list[dict]] = None, sender: Optional[Agent] = None
+    ) -> tuple[str, list[ChatMessage]]:
         """Extract the message and history from the messages."""
         if not messages:
             messages = self._oai_messages[sender]
@@ -123,7 +123,7 @@ def _extract_message_and_history(
         message = messages[-1].get("content", "")

         history = messages[:-1]
-        history_messages: List[ChatMessage] = []
+        history_messages: list[ChatMessage] = []
         for history_message in history:
             content = history_message.get("content", "")
             role = history_message.get("role", "user")
diff --git a/autogen/agentchat/contrib/llava_agent.py b/autogen/agentchat/contrib/llava_agent.py
index d5ae5530c1..f2bf77e533 100644
--- a/autogen/agentchat/contrib/llava_agent.py
+++ b/autogen/agentchat/contrib/llava_agent.py
@@ -30,7 +30,7 @@ class LLaVAAgent(MultimodalConversableAgent):
     def __init__(
         self,
         name: str,
-        system_message: Optional[Tuple[str, List]] = DEFAULT_LLAVA_SYS_MSG,
+        system_message: Optional[tuple[str, list]] = DEFAULT_LLAVA_SYS_MSG,
         *args,
         **kwargs,
     ):
diff --git a/autogen/agentchat/contrib/math_user_proxy_agent.py b/autogen/agentchat/contrib/math_user_proxy_agent.py
index 65350371e5..8cdbd1bafc 100644
--- a/autogen/agentchat/contrib/math_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/math_user_proxy_agent.py
@@ -140,10 +140,10 @@ def __init__(
         self,
         name: Optional[str] = "MathChatAgent",  # default set to MathChatAgent
         is_termination_msg: Optional[
-            Callable[[Dict], bool]
+            Callable[[dict], bool]
         ] = _is_termination_msg_mathchat,  # terminate if \boxed{} in message
         human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",  # Fully automated
-        default_auto_reply: Optional[Union[str, Dict, None]] = DEFAULT_REPLY,
+        default_auto_reply: Optional[Union[str, dict, None]] = DEFAULT_REPLY,
         max_invalid_q_per_step=3,  # a parameter needed in MathChat
         **kwargs,
     ):
@@ -292,7 +292,7 @@ def execute_one_wolfram_query(self, query: str):

     def _generate_math_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
     ):
@@ -364,7 +364,7 @@ def _generate_math_reply(
 # THE SOFTWARE.

-def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None) -> str:
+def get_from_dict_or_env(data: dict[str, Any], key: str, env_key: str, default: Optional[str] = None) -> str:
     """Get a value from a dictionary or an environment variable."""
     if key in data and data[key]:
         return data[key]
@@ -402,7 +402,7 @@ class Config:
         extra = Extra.forbid

     @root_validator(skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(cls, values: dict) -> dict:
         """Validate that api key and python package exist in environment."""
         wolfram_alpha_appid = get_from_dict_or_env(values, "wolfram_alpha_appid", "WOLFRAM_ALPHA_APPID")
         values["wolfram_alpha_appid"] = wolfram_alpha_appid
@@ -417,7 +417,7 @@ def validate_environment(cls, values: Dict) -> Dict:

         return values

-    def run(self, query: str) -> Tuple[str, bool]:
+    def run(self, query: str) -> tuple[str, bool]:
         """Run query through WolframAlpha and parse result."""
         from urllib.error import HTTPError
diff --git a/autogen/agentchat/contrib/multimodal_conversable_agent.py b/autogen/agentchat/contrib/multimodal_conversable_agent.py
index a5cbada75c..b4ffcd48dd 100644
--- a/autogen/agentchat/contrib/multimodal_conversable_agent.py
+++ b/autogen/agentchat/contrib/multimodal_conversable_agent.py
@@ -29,7 +29,7 @@ class MultimodalConversableAgent(ConversableAgent):
     def __init__(
         self,
         name: str,
-        system_message: Optional[Union[str, List]] = DEFAULT_LMM_SYS_MSG,
+        system_message: Optional[Union[str, list]] = DEFAULT_LMM_SYS_MSG,
         is_termination_msg: str = None,
         *args,
         **kwargs,
@@ -64,7 +64,7 @@ def __init__(
             MultimodalConversableAgent.a_generate_oai_reply,
         )

-    def update_system_message(self, system_message: Union[Dict, List, str]):
+    def update_system_message(self, system_message: Union[dict, list, str]):
         """Update the system message.

         Args:
@@ -74,7 +74,7 @@ def update_system_message(self, system_message: Union[Dict, List, str]):
         self._oai_system_message[0]["role"] = "system"

     @staticmethod
-    def _message_to_dict(message: Union[Dict, List, str]) -> Dict:
+    def _message_to_dict(message: Union[dict, list, str]) -> dict:
         """Convert a message to a dictionary. This implementation
         handles the GPT-4V formatting for easier prompts.
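
The `_message_to_dict` hunk above is only a signature change, but the shape it normalizes is worth spelling out. A simplified sketch of what such a normalizer does; this is an illustration of the str/list/dict contract, not the library's actual logic:

```python
from typing import Union


def message_to_dict(message: Union[dict, list, str]) -> dict:
    """Normalize str / list / dict message inputs into a single dict shape."""
    if isinstance(message, str):
        # Bare strings become a one-item multimodal content list.
        return {"content": [{"type": "text", "text": message}]}
    if isinstance(message, list):
        # Lists are assumed to already be structured multimodal content.
        return {"content": message}
    return message  # dicts pass through unchanged


assert message_to_dict("hi") == {"content": [{"type": "text", "text": "hi"}]}
```
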
@@ -103,10 +103,10 @@ def _message_to_dict(message: Union[Dict, List, str]) -> Dict:

     def generate_oai_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[OpenAIWrapper] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         """Generate a reply using autogen.oai."""
         client = self.client if config is None else config
         if client is None:
diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
index 9e24629f7e..ab12d2c1a8 100644
--- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
@@ -31,8 +31,8 @@ def __init__(
         self,
         name="RetrieveChatAgent",  # default set to RetrieveChatAgent
         human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "ALWAYS",
-        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
-        retrieve_config: Optional[Dict] = None,  # config for the retrieve agent
+        is_termination_msg: Optional[Callable[[dict], bool]] = None,
+        retrieve_config: Optional[dict] = None,  # config for the retrieve agent
         **kwargs,
     ):
         """
@@ -169,7 +169,7 @@ def create_qdrant_from_dir(
     must_break_at_empty_line: bool = True,
     embedding_model: str = "BAAI/bge-small-en-v1.5",
     custom_text_split_function: Callable = None,
-    custom_text_types: List[str] = TEXT_FORMATS,
+    custom_text_types: list[str] = TEXT_FORMATS,
     recursive: bool = True,
     extra_docs: bool = False,
     parallel: int = 0,
@@ -177,7 +177,7 @@ def create_qdrant_from_dir(
     quantization_config: Optional[models.QuantizationConfig] = None,
     hnsw_config: Optional[models.HnswConfigDiff] = None,
     payload_indexing: bool = False,
-    qdrant_client_options: Optional[Dict] = {},
+    qdrant_client_options: Optional[dict] = {},
 ):
     """Create a Qdrant collection from all the files in a given directory,
     the directory can also be a single file or a url to a single file.
@@ -266,14 +266,14 @@ def create_qdrant_from_dir(

 def query_qdrant(
-    query_texts: List[str],
+    query_texts: list[str],
     n_results: int = 10,
     client: QdrantClient = None,
     collection_name: str = "all-my-documents",
     search_string: str = "",
     embedding_model: str = "BAAI/bge-small-en-v1.5",
-    qdrant_client_options: Optional[Dict] = {},
-) -> List[List[QueryResponse]]:
+    qdrant_client_options: Optional[dict] = {},
+) -> list[list[QueryResponse]]:
     """Perform a similarity search with filters on a Qdrant collection

     Args:
diff --git a/autogen/agentchat/contrib/reasoning_agent.py b/autogen/agentchat/contrib/reasoning_agent.py
index 1f623592b1..ac43fe3f48 100644
--- a/autogen/agentchat/contrib/reasoning_agent.py
+++ b/autogen/agentchat/contrib/reasoning_agent.py
@@ -81,7 +81,7 @@ def __init__(self, content: str, parent: Optional["ThinkNode"] = None) -> None:
             self.parent.children.append(self)

     @property
-    def _trajectory_arr(self) -> List[str]:
+    def _trajectory_arr(self) -> list[str]:
         """Get the full path from root to this node as a list of strings.

         Returns:
@@ -118,7 +118,7 @@ def __str__(self) -> str:
     def __repr__(self) -> str:
         return self.__str__()

-    def to_dict(self) -> Dict:
+    def to_dict(self) -> dict:
         """Convert ThinkNode to dictionary representation.

         Returns:
@@ -135,7 +135,7 @@ def to_dict(self) -> Dict:
         }

     @classmethod
-    def from_dict(cls, data: Dict, parent: Optional["ThinkNode"] = None) -> "ThinkNode":
+    def from_dict(cls, data: dict, parent: Optional["ThinkNode"] = None) -> "ThinkNode":
         """Create ThinkNode from dictionary representation.

         Args:
@@ -624,7 +624,7 @@ def _mtcs_reply(self, prompt, ground_truth=""):
                     (child.value / (child.visits + EPSILON))
                     +  # exploration term
                     self._exploration_constant
-                    * math.sqrt((2 * math.log(node.visits + EPSILON) / (child.visits + EPSILON)))
+                    * math.sqrt(2 * math.log(node.visits + EPSILON) / (child.visits + EPSILON))
                     for child in node.children
                 ]
                 node = node.children[choices_weights.index(max(choices_weights))]
@@ -657,7 +657,7 @@ def _mtcs_reply(self, prompt, ground_truth=""):
         best_ans_node = max(answer_nodes, key=lambda node: node.value)
         return best_ans_node.content

-    def _expand(self, node: ThinkNode) -> List:
+    def _expand(self, node: ThinkNode) -> list:
         """
         Expand the node by generating possible next steps based on the current trajectory.
diff --git a/autogen/agentchat/contrib/retrieve_assistant_agent.py b/autogen/agentchat/contrib/retrieve_assistant_agent.py
index 8bea9e46c3..e2e6c0a5cf 100644
--- a/autogen/agentchat/contrib/retrieve_assistant_agent.py
+++ b/autogen/agentchat/contrib/retrieve_assistant_agent.py
@@ -33,10 +33,10 @@ def __init__(self, *args, **kwargs):

     def _generate_retrieve_assistant_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         if config is None:
             config = self
         if messages is None:
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index 49a72f3946..bf5e417156 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -100,8 +100,8 @@ def __init__(
         self,
         name="RetrieveChatAgent",  # default set to RetrieveChatAgent
         human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "ALWAYS",
-        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
-        retrieve_config: Optional[Dict] = None,  # config for the retrieve agent
+        is_termination_msg: Optional[Callable[[dict], bool]] = None,
+        retrieve_config: Optional[dict] = None,  # config for the retrieve agent
         **kwargs,
     ):
         r"""
@@ -371,12 +371,10 @@ def _init_db(self):
             logger.info(f"Found {len(chunks)} chunks.")

         if self._new_docs:
-            all_docs_ids = set(
-                [
-                    doc["id"]
-                    for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name)
-                ]
-            )
+            all_docs_ids = {
+                doc["id"]
+                for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name)
+            }
         else:
             all_docs_ids = set()
@@ -525,10 +523,10 @@ def _check_update_context(self, message):

     def _generate_retrieve_user_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         """In this function, we will update the context and reset the conversation based on different conditions.
We'll update the context and reset the conversation if update_context is True and either of the following: (1) the last message contains "UPDATE CONTEXT", diff --git a/autogen/agentchat/contrib/society_of_mind_agent.py b/autogen/agentchat/contrib/society_of_mind_agent.py index e6f2b5f4dd..fbf2f15cc9 100644 --- a/autogen/agentchat/contrib/society_of_mind_agent.py +++ b/autogen/agentchat/contrib/society_of_mind_agent.py @@ -38,13 +38,13 @@ def __init__( name: str, chat_manager: GroupChatManager, response_preparer: Optional[Union[str, Callable]] = None, - is_termination_msg: Optional[Callable[[Dict], bool]] = None, + is_termination_msg: Optional[Callable[[dict], bool]] = None, max_consecutive_auto_reply: Optional[int] = None, human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE", - function_map: Optional[Dict[str, Callable]] = None, - code_execution_config: Union[Dict, Literal[False]] = False, - llm_config: Optional[Union[Dict, Literal[False]]] = False, - default_auto_reply: Optional[Union[str, Dict, None]] = "", + function_map: Optional[dict[str, Callable]] = None, + code_execution_config: Union[dict, Literal[False]] = False, + llm_config: Optional[Union[dict, Literal[False]]] = False, + default_auto_reply: Optional[Union[str, dict, None]] = "", **kwargs, ): super().__init__( @@ -162,10 +162,10 @@ def update_chat_manager(self, chat_manager: Union[GroupChatManager, None]): def generate_inner_monologue_reply( self, - messages: Optional[List[Dict]] = None, + messages: Optional[list[dict]] = None, sender: Optional[Agent] = None, config: Optional[OpenAIWrapper] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: + ) -> tuple[bool, Union[str, dict, None]]: """Generate a reply by running the group chat""" if self.chat_manager is None: return False, None diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 6ad536ae07..f604c13a57 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -66,7 +66,7 @@ class ON_CONDITION: If a string, it will look up the value of the context variable with that name, which should be a bool. 
""" - target: Union["SwarmAgent", Dict[str, Any]] = None + target: Union["SwarmAgent", dict[str, Any]] = None condition: str = "" available: Optional[Union[Callable, str]] = None @@ -74,7 +74,7 @@ def __post_init__(self): # Ensure valid types if self.target is not None: assert isinstance(self.target, SwarmAgent) or isinstance( - self.target, Dict + self.target, dict ), "'target' must be a SwarmAgent or a Dict" # Ensure they have a condition @@ -118,13 +118,13 @@ def __post_init__(self): def initiate_swarm_chat( initial_agent: "SwarmAgent", - messages: Union[List[Dict[str, Any]], str], - agents: List["SwarmAgent"], + messages: Union[list[dict[str, Any]], str], + agents: list["SwarmAgent"], user_agent: Optional[UserProxyAgent] = None, max_rounds: int = 20, - context_variables: Optional[Dict[str, Any]] = None, + context_variables: Optional[dict[str, Any]] = None, after_work: Optional[Union[AFTER_WORK, Callable]] = AFTER_WORK(AfterWorkOption.TERMINATE), -) -> Tuple[ChatResult, Dict[str, Any], "SwarmAgent"]: +) -> tuple[ChatResult, dict[str, Any], "SwarmAgent"]: """Initialize and run a swarm chat Args: @@ -248,10 +248,10 @@ def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat): else: raise ValueError("Invalid After Work condition or return value from callable") - def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent]): + def create_nested_chats(agent: SwarmAgent, nested_chat_agents: list[SwarmAgent]): """Create nested chat agents and register nested chats""" for i, nested_chat_handoff in enumerate(agent._nested_chat_handoffs): - nested_chats: Dict[str, Any] = nested_chat_handoff["nested_chats"] + nested_chats: dict[str, Any] = nested_chat_handoff["nested_chats"] condition = nested_chat_handoff["condition"] available = nested_chat_handoff["available"] @@ -365,7 +365,7 @@ class SwarmResult(BaseModel): values: str = "" agent: Optional[Union["SwarmAgent", str]] = None - context_variables: Dict[str, Any] = {} + context_variables: dict[str, Any] = {} class Config: # Add this inner class arbitrary_types_allowed = True @@ -388,15 +388,15 @@ def __init__( self, name: str, system_message: Optional[str] = "You are a helpful AI Assistant.", - llm_config: Optional[Union[Dict, Literal[False]]] = None, - functions: Union[List[Callable], Callable] = None, - is_termination_msg: Optional[Callable[[Dict], bool]] = None, + llm_config: Optional[Union[dict, Literal[False]]] = None, + functions: Union[list[Callable], Callable] = None, + is_termination_msg: Optional[Callable[[dict], bool]] = None, max_consecutive_auto_reply: Optional[int] = None, human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: Optional[str] = None, code_execution_config=False, update_agent_state_before_reply: Optional[ - Union[List[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE] + Union[list[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE] ] = None, **kwargs, ) -> None: @@ -439,7 +439,7 @@ def __init__( if name != __TOOL_EXECUTOR_NAME__: self.register_hook("update_agent_state", self._update_conditional_functions) - def register_update_agent_state_before_reply(self, functions: Optional[Union[List[Callable], Callable]]): + def register_update_agent_state_before_reply(self, functions: Optional[Union[list[Callable], Callable]]): """ Register functions that will be called when the agent is selected and before it speaks. You can add your own validation or precondition functions here. 
@@ -464,8 +464,8 @@ def register_update_agent_state_before_reply(self, functions: Optional[Union[Lis # Outer function to create a closure with the update function def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE): def update_system_message_wrapper( - agent: ConversableAgent, messages: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: + agent: ConversableAgent, messages: list[dict[str, Any]] + ) -> list[dict[str, Any]]: if isinstance(update_func.update_function, str): # Templates like "My context variable passport is {passport}" will # use the context_variables for substitution @@ -502,7 +502,7 @@ def __str__(self): def register_hand_off( self, - hand_to: Union[List[Union[ON_CONDITION, AFTER_WORK]], ON_CONDITION, AFTER_WORK], + hand_to: Union[list[Union[ON_CONDITION, AFTER_WORK]], ON_CONDITION, AFTER_WORK], ): """Register a function to hand off to another agent. @@ -555,7 +555,7 @@ def transfer_to_agent() -> "SwarmAgent": # Store function to add/remove later based on it being 'available' self._conditional_functions[func_name] = (transfer_func, transit) - elif isinstance(transit.target, Dict): + elif isinstance(transit.target, dict): # Transition to a nested chat # We will store them here and establish them in the initiate_swarm_chat self._nested_chat_handoffs.append( @@ -566,7 +566,7 @@ def transfer_to_agent() -> "SwarmAgent": raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK") @staticmethod - def _update_conditional_functions(agent: Agent, messages: Optional[List[Dict]] = None) -> None: + def _update_conditional_functions(agent: Agent, messages: Optional[list[dict]] = None) -> None: """Updates the agent's functions based on the ON_CONDITION's available condition.""" for func_name, (func, on_condition) in agent._conditional_functions.items(): is_available = True @@ -588,10 +588,10 @@ def _update_conditional_functions(agent: Agent, messages: Optional[List[Dict]] = def generate_swarm_tool_reply( self, - messages: Optional[List[Dict]] = None, + messages: Optional[list[dict]] = None, sender: Optional[Agent] = None, config: Optional[OpenAIWrapper] = None, - ) -> Tuple[bool, dict]: + ) -> tuple[bool, dict]: """Pre-processes and generates tool call replies. 
This function: @@ -697,15 +697,15 @@ def add_single_function(self, func: Callable, name=None, description=""): self.update_tool_signature(f_no_context, is_remove=False) self.register_function({func._name: func}) - def add_functions(self, func_list: List[Callable]): + def add_functions(self, func_list: list[Callable]): for func in func_list: self.add_single_function(func) @staticmethod def process_nested_chat_carryover( - chat: Dict[str, Any], + chat: dict[str, Any], recipient: ConversableAgent, - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], sender: ConversableAgent, config: Any, trim_n_messages: int = 0, @@ -730,7 +730,7 @@ def process_nested_chat_carryover( trim_n_messages: The number of latest messages to trim from the messages list """ - def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[str, Any]]]) -> str: + def concat_carryover(chat_message: str, carryover_message: Union[str, list[dict[str, Any]]]) -> str: """Concatenate the carryover message to the chat message.""" prefix = f"{chat_message}\n" if chat_message else "" @@ -799,8 +799,8 @@ def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[ @staticmethod def _summary_from_nested_chats( - chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any - ) -> Tuple[bool, Union[str, None]]: + chat_queue: list[dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any + ) -> tuple[bool, Union[str, None]]: """Overridden _summary_from_nested_chats method from ConversableAgent. This function initiates one or a sequence of chats between the "recipient" and the agents in the chat_queue. diff --git a/autogen/agentchat/contrib/text_analyzer_agent.py b/autogen/agentchat/contrib/text_analyzer_agent.py index 79edcf3b7b..914e020851 100644 --- a/autogen/agentchat/contrib/text_analyzer_agent.py +++ b/autogen/agentchat/contrib/text_analyzer_agent.py @@ -23,7 +23,7 @@ def __init__( name="analyzer", system_message: Optional[str] = system_message, human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", - llm_config: Optional[Union[Dict, bool]] = None, + llm_config: Optional[Union[dict, bool]] = None, **kwargs, ): """ @@ -48,10 +48,10 @@ def __init__( def _analyze_in_reply( self, - messages: Optional[List[Dict]] = None, + messages: Optional[list[dict]] = None, sender: Optional[Agent] = None, config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: + ) -> tuple[bool, Union[str, dict, None]]: """Analyzes the given text as instructed, and returns the analysis as a message. Assumes exactly two messages containing the text to analyze and the analysis instructions. See Teachability.analyze for an example of how to use this method.""" diff --git a/autogen/agentchat/contrib/tool_retriever.py b/autogen/agentchat/contrib/tool_retriever.py index daa32c4364..4ca2ddcda4 100644 --- a/autogen/agentchat/contrib/tool_retriever.py +++ b/autogen/agentchat/contrib/tool_retriever.py @@ -81,7 +81,7 @@ def get_full_tool_description(py_file): """ Retrieves the function signature for a given Python file. 
""" - with open(py_file, "r") as f: + with open(py_file) as f: code = f.read() exec(code) function_name = os.path.splitext(os.path.basename(py_file))[0] diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 1454d65318..d2f3e0685d 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -4,14 +4,13 @@ # # Portions derived from https://github.com/microsoft/autogen are under the MIT License. # SPDX-License-Identifier: MIT +from collections.abc import Mapping, Sequence from typing import ( Any, Callable, List, - Mapping, Optional, Protocol, - Sequence, Tuple, TypedDict, Union, @@ -42,7 +41,7 @@ class Document(TypedDict): A query is a list containing one string while queries is a list containing multiple strings. The response is a list of query results, each query result is a list of tuples containing the document and the distance. """ -QueryResults = List[List[Tuple[Document, float]]] +QueryResults = list[list[tuple[Document, float]]] @runtime_checkable @@ -67,7 +66,7 @@ class VectorDB(Protocol): active_collection: Any = None type: str = "" - embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = ( + embedding_function: Optional[Callable[[list[str]], list[list[float]]]] = ( None # embeddings = embedding_function(sentences) ) @@ -114,7 +113,7 @@ def delete_collection(self, collection_name: str) -> Any: """ ... - def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None: + def insert_docs(self, docs: list[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None: """ Insert documents into the collection of the vector database. @@ -129,7 +128,7 @@ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: """ ... - def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None: + def update_docs(self, docs: list[Document], collection_name: str = None, **kwargs) -> None: """ Update documents in the collection of the vector database. @@ -143,7 +142,7 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg """ ... - def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: + def delete_docs(self, ids: list[ItemID], collection_name: str = None, **kwargs) -> None: """ Delete documents from the collection of the vector database. @@ -159,7 +158,7 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) def retrieve_docs( self, - queries: List[str], + queries: list[str], collection_name: str = None, n_results: int = 10, distance_threshold: float = -1, @@ -183,8 +182,8 @@ def retrieve_docs( ... def get_docs_by_ids( - self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs - ) -> List[Document]: + self, ids: list[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> list[Document]: """ Retrieve documents from the collection of the vector database based on the ids. 
diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index c6e082fc22..f01b32b898 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -169,7 +169,7 @@ def _batch_insert( else: collection.add(**collection_kwargs) - def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: + def insert_docs(self, docs: list[Document], collection_name: str = None, upsert: bool = False) -> None: """ Insert documents into the collection of the vector database. @@ -204,7 +204,7 @@ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: metadatas = [doc.get("metadata") for doc in docs] self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert) - def update_docs(self, docs: List[Document], collection_name: str = None) -> None: + def update_docs(self, docs: list[Document], collection_name: str = None) -> None: """ Update documents in the collection of the vector database. @@ -217,7 +217,7 @@ def update_docs(self, docs: List[Document], collection_name: str = None) -> None """ self.insert_docs(docs, collection_name, upsert=True) - def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: + def delete_docs(self, ids: list[ItemID], collection_name: str = None, **kwargs) -> None: """ Delete documents from the collection of the vector database. @@ -234,7 +234,7 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) def retrieve_docs( self, - queries: List[str], + queries: list[str], collection_name: str = None, n_results: int = 10, distance_threshold: float = -1, @@ -269,7 +269,7 @@ def retrieve_docs( return results @staticmethod - def _chroma_get_results_to_list_documents(data_dict) -> List[Document]: + def _chroma_get_results_to_list_documents(data_dict) -> list[Document]: """Converts a dictionary with list values to a list of Document. Args: @@ -305,8 +305,8 @@ def _chroma_get_results_to_list_documents(data_dict) -> List[Document]: return results def get_docs_by_ids( - self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs - ) -> List[Document]: + self, ids: list[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> list[Document]: """ Retrieve documents from the collection of the vector database based on the ids. diff --git a/autogen/agentchat/contrib/vectordb/mongodb.py b/autogen/agentchat/contrib/vectordb/mongodb.py index aef05e35d7..b1a199c495 100644 --- a/autogen/agentchat/contrib/vectordb/mongodb.py +++ b/autogen/agentchat/contrib/vectordb/mongodb.py @@ -4,9 +4,10 @@ # # Portions derived from https://github.com/microsoft/autogen are under the MIT License. 
# SPDX-License-Identifier: MIT +from collections.abc import Iterable, Mapping from copy import deepcopy from time import monotonic, sleep -from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Set, Tuple, Union import numpy as np from pymongo import MongoClient, UpdateOne, errors @@ -25,7 +26,7 @@ _DELAY = 0.5 -def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]: +def with_id_rename(docs: Iterable) -> list[dict[str, Any]]: """Utility changes _id field from Collection into id for Document.""" return [{**{k: v for k, v in d.items() if k != "_id"}, "id": d["_id"]} for d in docs] @@ -271,7 +272,7 @@ def create_vector_search_index( def insert_docs( self, - docs: List[Document], + docs: list[Document], collection_name: str = None, upsert: bool = False, batch_size=DEFAULT_INSERT_BATCH_SIZE, @@ -341,8 +342,8 @@ def insert_docs( self._wait_for_document(collection, self.index_name, docs[-1]) def _insert_batch( - self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID] - ) -> Set[ItemID]: + self, collection: Collection, texts: list[str], metadatas: list[Mapping[str, Any]], ids: list[ItemID] + ) -> set[ItemID]: """Compute embeddings for and insert a batch of Documents into the Collection. For performance reasons, we chose to call self.embedding_function just once, @@ -373,7 +374,7 @@ def _insert_batch( insert_result = collection.insert_many(to_insert) # type: ignore return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs - def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None: + def update_docs(self, docs: list[Document], collection_name: str = None, **kwargs: Any) -> None: """Update documents, including their embeddings, in the Collection. Optionally allow upsert as kwarg. @@ -413,7 +414,7 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg result.upserted_count, ) - def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs): + def delete_docs(self, ids: list[ItemID], collection_name: str = None, **kwargs): """ Delete documents from the collection of the vector database. @@ -425,8 +426,8 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs): return collection.delete_many({"_id": {"$in": ids}}) def get_docs_by_ids( - self, ids: List[ItemID] = None, collection_name: str = None, include: List[str] = None, **kwargs - ) -> List[Document]: + self, ids: list[ItemID] = None, collection_name: str = None, include: list[str] = None, **kwargs + ) -> list[Document]: """ Retrieve documents from the collection of the vector database based on the ids. @@ -457,7 +458,7 @@ def get_docs_by_ids( def retrieve_docs( self, - queries: List[str], + queries: list[str], collection_name: str = None, n_results: int = 10, distance_threshold: float = -1, @@ -509,14 +510,14 @@ def retrieve_docs( def _vector_search( - embedding_vector: List[float], + embedding_vector: list[float], n_results: int, collection: Collection, index_name: str, distance_threshold: float = -1.0, oversampling_factor=10, include_embedding=False, -) -> List[Tuple[Dict, float]]: +) -> list[tuple[dict, float]]: """Core $vectorSearch Aggregation pipeline. 
Args: diff --git a/autogen/agentchat/contrib/vectordb/pgvectordb.py b/autogen/agentchat/contrib/vectordb/pgvectordb.py index de7d4e5179..431fa9a543 100644 --- a/autogen/agentchat/contrib/vectordb/pgvectordb.py +++ b/autogen/agentchat/contrib/vectordb/pgvectordb.py @@ -88,7 +88,7 @@ def set_collection_name(self, collection_name) -> str: self.name = name return self.name - def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None: + def add(self, ids: list[ItemID], documents: list, embeddings: list = None, metadatas: list = None) -> None: """ Add documents to the collection. @@ -131,7 +131,7 @@ def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metad cursor.executemany(sql_string, sql_values) cursor.close() - def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None: + def upsert(self, ids: list[ItemID], documents: list, embeddings: list = None, metadatas: list = None) -> None: """ Upsert documents into the collection. @@ -240,7 +240,7 @@ def get( where: Optional[str] = None, limit: Optional[Union[int, str]] = None, offset: Optional[Union[int, str]] = None, - ) -> List[Document]: + ) -> list[Document]: """ Retrieve documents from the collection. @@ -312,7 +312,7 @@ def get( cursor.close() return retrieved_documents - def update(self, ids: List, embeddings: List, metadatas: List, documents: List) -> None: + def update(self, ids: list, embeddings: list, metadatas: list, documents: list) -> None: """ Update documents in the collection. @@ -341,7 +341,7 @@ def update(self, ids: List, embeddings: List, metadatas: List, documents: List) cursor.close() @staticmethod - def euclidean_distance(arr1: List[float], arr2: List[float]) -> float: + def euclidean_distance(arr1: list[float], arr2: list[float]) -> float: """ Calculate the Euclidean distance between two vectors. @@ -356,7 +356,7 @@ def euclidean_distance(arr1: List[float], arr2: List[float]) -> float: return dist @staticmethod - def cosine_distance(arr1: List[float], arr2: List[float]) -> float: + def cosine_distance(arr1: list[float], arr2: list[float]) -> float: """ Calculate the cosine distance between two vectors. @@ -371,7 +371,7 @@ def cosine_distance(arr1: List[float], arr2: List[float]) -> float: return dist @staticmethod - def inner_product_distance(arr1: List[float], arr2: List[float]) -> float: + def inner_product_distance(arr1: list[float], arr2: list[float]) -> float: """ Calculate the Euclidean distance between two vectors. @@ -387,7 +387,7 @@ def inner_product_distance(arr1: List[float], arr2: List[float]) -> float: def query( self, - query_texts: List[str], + query_texts: list[str], collection_name: Optional[str] = None, n_results: Optional[int] = 10, distance_type: Optional[str] = "euclidean", @@ -458,7 +458,7 @@ def query( return results @staticmethod - def convert_string_to_array(array_string: str) -> List[float]: + def convert_string_to_array(array_string: str) -> list[float]: """ Convert a string representation of an array to a list of floats. @@ -494,7 +494,7 @@ def modify(self, metadata, collection_name: Optional[str] = None) -> None: ) cursor.close() - def delete(self, ids: List[ItemID], collection_name: Optional[str] = None) -> None: + def delete(self, ids: list[ItemID], collection_name: Optional[str] = None) -> None: """ Delete documents from the collection. 
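# For reference, the pgvectordb distance helpers whose signatures change above
# can be expressed over plain list[float] without NumPy; this is a
# dependency-free sketch under that assumption, not pgvectordb's actual
# implementation:
import math

def euclidean_distance(arr1: list[float], arr2: list[float]) -> float:
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(arr1, arr2)))

def cosine_distance(arr1: list[float], arr2: list[float]) -> float:
    dot = sum(a * b for a, b in zip(arr1, arr2))
    norms = math.hypot(*arr1) * math.hypot(*arr2)  # product of the L2 norms
    return 1.0 - dot / norms

assert euclidean_distance([0.0, 0.0], [3.0, 4.0]) == 5.0
assert abs(cosine_distance([1.0, 0.0], [1.0, 0.0])) < 1e-9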
@@ -836,7 +836,7 @@ def _batch_insert( else: collection.add(**collection_kwargs) - def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: + def insert_docs(self, docs: list[Document], collection_name: str = None, upsert: bool = False) -> None: """ Insert documents into the collection of the vector database. @@ -874,7 +874,7 @@ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert) - def update_docs(self, docs: List[Document], collection_name: str = None) -> None: + def update_docs(self, docs: list[Document], collection_name: str = None) -> None: """ Update documents in the collection of the vector database. @@ -887,7 +887,7 @@ def update_docs(self, docs: List[Document], collection_name: str = None) -> None """ self.insert_docs(docs, collection_name, upsert=True) - def delete_docs(self, ids: List[ItemID], collection_name: str = None) -> None: + def delete_docs(self, ids: list[ItemID], collection_name: str = None) -> None: """ Delete documents from the collection of the vector database. @@ -904,7 +904,7 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None) -> None: def retrieve_docs( self, - queries: List[str], + queries: list[str], collection_name: str = None, n_results: int = 10, distance_threshold: float = -1, @@ -936,8 +936,8 @@ def retrieve_docs( return results def get_docs_by_ids( - self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs - ) -> List[Document]: + self, ids: list[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> list[Document]: """ Retrieve documents from the collection of the vector database based on the ids. diff --git a/autogen/agentchat/contrib/vectordb/qdrant.py b/autogen/agentchat/contrib/vectordb/qdrant.py index 65ede2ec55..70564056d8 100644 --- a/autogen/agentchat/contrib/vectordb/qdrant.py +++ b/autogen/agentchat/contrib/vectordb/qdrant.py @@ -7,7 +7,8 @@ import abc import logging import os -from typing import Callable, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Callable, List, Optional, Tuple, Union from .base import Document, ItemID, QueryResults, VectorDB from .utils import get_logger @@ -24,7 +25,7 @@ class EmbeddingFunction(abc.ABC): @abc.abstractmethod - def __call__(self, inputs: List[str]) -> List[Embeddings]: + def __call__(self, inputs: list[str]) -> list[Embeddings]: raise NotImplementedError @@ -67,7 +68,7 @@ def __init__( self._parallel = parallel self._model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs) - def __call__(self, inputs: List[str]) -> List[Embeddings]: + def __call__(self, inputs: list[str]) -> list[Embeddings]: embeddings = self._model.embed(inputs, batch_size=self._batch_size, parallel=self._parallel) return [embedding.tolist() for embedding in embeddings] @@ -161,7 +162,7 @@ def delete_collection(self, collection_name: str) -> None: """ return self.client.delete_collection(collection_name) - def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: + def insert_docs(self, docs: list[Document], collection_name: str = None, upsert: bool = False) -> None: """ Insert documents into the collection of the vector database. 
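# The qdrant EmbeddingFunction ABC above keeps its contract but in builtin
# generics: a callable from list[str] to a list of float vectors. A toy,
# dependency-free implementation against that updated signature (illustrative
# only; real use goes through FastembedEmbeddingFunction):
import abc

Embeddings = list[float]  # local alias for this sketch

class EmbeddingFunction(abc.ABC):
    @abc.abstractmethod
    def __call__(self, inputs: list[str]) -> list[Embeddings]:
        raise NotImplementedError

class ToyEmbeddingFunction(EmbeddingFunction):
    def __call__(self, inputs: list[str]) -> list[Embeddings]:
        # Deterministic pseudo-embeddings: codepoints of the first four
        # characters, space-padded. Only useful for tests.
        return [[float(ord(c)) for c in text[:4].ljust(4)] for text in inputs]

vectors = ToyEmbeddingFunction()(["hello", "hi"])
assert len(vectors) == 2 and all(len(v) == 4 for v in vectors)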
@@ -186,7 +187,7 @@ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert:

         self.client.upsert(collection_name, points=self._documents_to_points(docs))

-    def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+    def update_docs(self, docs: list[Document], collection_name: str = None) -> None:
         if not docs:
             return
         if any(doc.get("id") is None for doc in docs):
@@ -198,7 +199,7 @@ def update_docs(self, docs: List[Document], collection_name: str = None) -> None
             raise ValueError("Some IDs do not exist. Skipping update")

-    def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
+    def delete_docs(self, ids: list[ItemID], collection_name: str = None, **kwargs) -> None:
         """
         Delete documents from the collection of the vector database.

@@ -214,7 +215,7 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs)
     def retrieve_docs(
         self,
-        queries: List[str],
+        queries: list[str],
         collection_name: str = None,
         n_results: int = 10,
         distance_threshold: float = 0,
@@ -251,8 +252,8 @@ def retrieve_docs(
         return [self._scored_points_to_documents(results) for results in batch_results]

     def get_docs_by_ids(
-        self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs
-    ) -> List[Document]:
+        self, ids: list[ItemID] = None, collection_name: str = None, include=True, **kwargs
+    ) -> list[Document]:
         """
         Retrieve documents from the collection of the vector database based on the ids.

@@ -280,13 +281,13 @@ def _point_to_document(self, point) -> Document:
             "embedding": point.vector,
         }

-    def _points_to_documents(self, points) -> List[Document]:
+    def _points_to_documents(self, points) -> list[Document]:
         return [self._point_to_document(point) for point in points]

-    def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> Tuple[Document, float]:
+    def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> tuple[Document, float]:
         return self._point_to_document(scored_point), scored_point.score

-    def _documents_to_points(self, documents: List[Document]):
+    def _documents_to_points(self, documents: list[Document]):
         contents = [document["content"] for document in documents]
         embeddings = self.embedding_function(contents)
         points = [
@@ -302,10 +303,10 @@ def _documents_to_points(self, documents: List[Document]):
         ]
         return points

-    def _scored_points_to_documents(self, scored_points: List[models.ScoredPoint]) -> List[Tuple[Document, float]]:
+    def _scored_points_to_documents(self, scored_points: list[models.ScoredPoint]) -> list[tuple[Document, float]]:
         return [self._scored_point_to_document(scored_point) for scored_point in scored_points]

-    def _validate_update_ids(self, collection_name: str, ids: List[str]) -> bool:
+    def _validate_update_ids(self, collection_name: str, ids: list[str]) -> bool:
         """
         Validates all the IDs exist in the collection
         """
@@ -319,7 +320,7 @@ def _validate_update_ids(self, collection_name: str, ids: List[str]) -> bool:
         return True

-    def _validate_upsert_ids(self, collection_name: str, ids: List[str]) -> bool:
+    def _validate_upsert_ids(self, collection_name: str, ids: list[str]) -> bool:
         """
         Validate none of the IDs exist in the collection
         """
diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py
index f0e5f00bce..6fb8cbacdf 100644
--- a/autogen/agentchat/contrib/vectordb/utils.py
+++ b/autogen/agentchat/contrib/vectordb/utils.py
@@ -64,7 +64,7 @@ def filter_results_by_distance(results: QueryResults, distance_threshold: float
     return results

-def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], special_key="distances") -> QueryResults:
+def chroma_results_to_query_results(data_dict: dict[str, list[list[Any]]], special_key="distances") -> QueryResults:
     """Converts a dictionary with list-of-list values to a list of tuples.

     Args:
diff --git a/autogen/agentchat/contrib/web_surfer.py b/autogen/agentchat/contrib/web_surfer.py
index b6dd58162f..e8c5da2c21 100644
--- a/autogen/agentchat/contrib/web_surfer.py
+++ b/autogen/agentchat/contrib/web_surfer.py
@@ -10,9 +10,7 @@ import re
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
-
-from typing_extensions import Annotated
+from typing import Annotated, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 from ... import Agent, AssistantAgent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper, UserProxyAgent
 from ...browser_utils import SimpleTextBrowser
@@ -36,17 +34,17 @@ class WebSurferAgent(ConversableAgent):
     def __init__(
         self,
         name: str,
-        system_message: Optional[Union[str, List[str]]] = DEFAULT_PROMPT,
+        system_message: Optional[Union[str, list[str]]] = DEFAULT_PROMPT,
         description: Optional[str] = DEFAULT_DESCRIPTION,
-        is_termination_msg: Optional[Callable[[Dict[str, Any]], bool]] = None,
+        is_termination_msg: Optional[Callable[[dict[str, Any]], bool]] = None,
         max_consecutive_auto_reply: Optional[int] = None,
         human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
-        function_map: Optional[Dict[str, Callable]] = None,
-        code_execution_config: Union[Dict, Literal[False]] = False,
-        llm_config: Optional[Union[Dict, Literal[False]]] = None,
-        summarizer_llm_config: Optional[Union[Dict, Literal[False]]] = None,
-        default_auto_reply: Optional[Union[str, Dict, None]] = "",
-        browser_config: Optional[Union[Dict, None]] = None,
+        function_map: Optional[dict[str, Callable]] = None,
+        code_execution_config: Union[dict, Literal[False]] = False,
+        llm_config: Optional[Union[dict, Literal[False]]] = None,
+        summarizer_llm_config: Optional[Union[dict, Literal[False]]] = None,
+        default_auto_reply: Optional[Union[str, dict, None]] = "",
+        browser_config: Optional[Union[dict, None]] = None,
         **kwargs,
     ):
         super().__init__(
@@ -94,7 +92,7 @@ def __init__(
         self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
         self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)

-    def _create_summarizer_client(self, summarizer_llm_config: Dict[str, Any], llm_config: Dict[str, Any]) -> None:
+    def _create_summarizer_client(self, summarizer_llm_config: dict[str, Any], llm_config: dict[str, Any]) -> None:
         # If the summarizer_llm_config is None, we copy it from the llm_config
         if summarizer_llm_config is None:
             if llm_config is None:  # Nothing to copy
@@ -127,7 +125,7 @@ def _register_functions(self) -> None:
         """Register the functions for the inner assistant and user proxy."""

         # Helper functions
-        def _browser_state() -> Tuple[str, str]:
+        def _browser_state() -> tuple[str, str]:
             header = f"Address: {self.browser.address}\n"
             if self.browser.page_title is not None:
                 header += f"Title: {self.browser.page_title}\n"
@@ -266,10 +264,10 @@ def _summarize_page(

     def generate_surfer_reply(
         self,
-        messages: Optional[List[Dict[str, str]]] = None,
+        messages: Optional[list[dict[str, str]]] = None,
         sender: Optional[Agent] = None,
         config: Optional[OpenAIWrapper] = None,
-    ) -> Tuple[bool, Optional[Union[str, Dict[str, str]]]]:
+    ) -> tuple[bool, Optional[Union[str, dict[str, str]]]]:
         """Generate a reply using autogen.oai."""
         if messages is None:
             messages = self._oai_messages[sender]
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index b2f22ce9c5..747990a7cb 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -69,23 +69,23 @@ class ConversableAgent(LLMAgent):
     DEFAULT_SUMMARY_PROMPT = "Summarize the takeaway from the conversation. Do not add any introductory phrases."
     DEFAULT_SUMMARY_METHOD = "last_msg"
-    llm_config: Union[Dict, Literal[False]]
+    llm_config: Union[dict, Literal[False]]

     def __init__(
         self,
         name: str,
-        system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.",
-        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
+        system_message: Optional[Union[str, list]] = "You are a helpful AI Assistant.",
+        is_termination_msg: Optional[Callable[[dict], bool]] = None,
         max_consecutive_auto_reply: Optional[int] = None,
         human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
-        function_map: Optional[Dict[str, Callable]] = None,
-        code_execution_config: Union[Dict, Literal[False]] = False,
-        llm_config: Optional[Union[Dict, Literal[False]]] = None,
-        default_auto_reply: Union[str, Dict] = "",
+        function_map: Optional[dict[str, Callable]] = None,
+        code_execution_config: Union[dict, Literal[False]] = False,
+        llm_config: Optional[Union[dict, Literal[False]]] = None,
+        default_auto_reply: Union[str, dict] = "",
         description: Optional[str] = None,
-        chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
+        chat_messages: Optional[dict[Agent, list[dict]]] = None,
         silent: Optional[bool] = None,
-        context_variables: Optional[Dict[str, Any]] = None,
+        context_variables: Optional[dict[str, Any]] = None,
     ):
         """
         Args:
@@ -260,7 +260,7 @@ def __init__(
         # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
         # New hookable methods should be added to this list as required to support new agent capabilities.
-        self.hook_lists: Dict[str, List[Callable]] = {
+        self.hook_lists: dict[str, list[Callable]] = {
             "process_last_received_message": [],
             "process_all_messages_before_reply": [],
             "process_message_before_send": [],
@@ -309,7 +309,7 @@ def code_executor(self) -> Optional[CodeExecutor]:

     def register_reply(
         self,
-        trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
+        trigger: Union[type[Agent], str, Agent, Callable[[Agent], bool], list],
         reply_func: Callable,
         position: int = 0,
         config: Optional[Any] = None,
@@ -392,8 +392,8 @@ def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable)

     @staticmethod
     def _get_chats_to_run(
-        chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
-    ) -> List[Dict[str, Any]]:
+        chat_queue: list[dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+    ) -> list[dict[str, Any]]:
         """A simple chat reply function.
         This function initiate one or a sequence of chats between the "recipient" and the agents in the chat_queue.
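# register_reply's trigger union above now reads type[Agent] and list instead
# of Type[Agent] and List. A condensed sketch of how such a trigger can be
# dispatched at runtime (illustrative, not the library's exact _match_trigger
# code; Agent stands in as a plain class):
from typing import Callable, Optional, Union

class Agent:
    def __init__(self, name: str) -> None:
        self.name = name

Trigger = Union[type, str, Agent, Callable[[Agent], bool], list]

def match_trigger(trigger: Optional[Trigger], sender: Agent) -> bool:
    if trigger is None:
        return False
    if isinstance(trigger, type):      # a class: match by isinstance
        return isinstance(sender, trigger)
    if isinstance(trigger, str):       # a name: match by agent name
        return trigger == sender.name
    if isinstance(trigger, Agent):     # an instance: match by identity
        return trigger is sender
    if isinstance(trigger, list):      # a list: match any element
        return any(match_trigger(t, sender) for t in trigger)
    return bool(trigger(sender))       # a predicate over the sender

bob = Agent("bob")
assert match_trigger(["alice", Agent, lambda a: False], bob)  # class entry matches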
@@ -424,8 +424,8 @@ def _get_chats_to_run(

     @staticmethod
     def _summary_from_nested_chats(
-        chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
-    ) -> Tuple[bool, Union[str, None]]:
+        chat_queue: list[dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+    ) -> tuple[bool, Union[str, None]]:
         """A simple chat reply function.
         This function initiate one or a sequence of chats between the "recipient" and the agents in the chat_queue.

@@ -443,8 +443,8 @@ def _summary_from_nested_chats(

     @staticmethod
     async def _a_summary_from_nested_chats(
-        chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
-    ) -> Tuple[bool, Union[str, None]]:
+        chat_queue: list[dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+    ) -> tuple[bool, Union[str, None]]:
         """A simple chat reply function.
         This function initiate one or a sequence of chats between the "recipient" and the agents in the chat_queue.

@@ -463,8 +463,8 @@ async def _a_summary_from_nested_chats(

     def register_nested_chats(
         self,
-        chat_queue: List[Dict[str, Any]],
-        trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
+        chat_queue: list[dict[str, Any]],
+        trigger: Union[type[Agent], str, Agent, Callable[[Agent], bool], list],
         reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
         position: int = 2,
         use_async: Union[bool, None] = None,
@@ -548,7 +548,7 @@ def set_context(self, key: str, value: Any) -> None:
         """
         self._context_variables[key] = value

-    def update_context(self, context_variables: Dict[str, Any]) -> None:
+    def update_context(self, context_variables: dict[str, Any]) -> None:
         """
         Update multiple context variables at once.
         Args:
@@ -599,15 +599,15 @@ def max_consecutive_auto_reply(self, sender: Optional[Agent] = None) -> int:
         return self._max_consecutive_auto_reply if sender is None else self._max_consecutive_auto_reply_dict[sender]

     @property
-    def chat_messages(self) -> Dict[Agent, List[Dict]]:
+    def chat_messages(self) -> dict[Agent, list[dict]]:
         """A dictionary of conversations from agent to list of messages."""
         return self._oai_messages

-    def chat_messages_for_summary(self, agent: Agent) -> List[Dict]:
+    def chat_messages_for_summary(self, agent: Agent) -> list[dict]:
         """A list of messages as a conversation to summarize."""
         return self._oai_messages[agent]

-    def last_message(self, agent: Optional[Agent] = None) -> Optional[Dict]:
+    def last_message(self, agent: Optional[Agent] = None) -> Optional[dict]:
         """The last message exchanged with the agent.

         Args:
@@ -640,7 +640,7 @@ def use_docker(self) -> Union[bool, str, None]:
         return None if self._code_execution_config is False else self._code_execution_config.get("use_docker")

     @staticmethod
-    def _message_to_dict(message: Union[Dict, str]) -> Dict:
+    def _message_to_dict(message: Union[dict, str]) -> dict:
         """Convert a message to a dictionary. The message can be a string or a dictionary.

         The string will be put in the "content" field of the new dictionary.
@@ -674,7 +674,7 @@ def _assert_valid_name(name):
             raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
         return name

-    def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent, is_sending: bool) -> bool:
+    def _append_oai_message(self, message: Union[dict, str], role, conversation_id: Agent, is_sending: bool) -> bool:
         """Append a message to the ChatCompletion conversation. If the message received is a string, it will be put in the "content" field of the new dictionary.
@@ -731,8 +731,8 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
         return True

     def _process_message_before_send(
-        self, message: Union[Dict, str], recipient: Agent, silent: bool
-    ) -> Union[Dict, str]:
+        self, message: Union[dict, str], recipient: Agent, silent: bool
+    ) -> Union[dict, str]:
         """Process the message before sending it to the recipient."""
         hook_list = self.hook_lists["process_message_before_send"]
         for hook in hook_list:
@@ -743,7 +743,7 @@ def _process_message_before_send(

     def send(
         self,
-        message: Union[Dict, str],
+        message: Union[dict, str],
         recipient: Agent,
         request_reply: Optional[bool] = None,
         silent: Optional[bool] = False,
@@ -793,7 +793,7 @@ def send(

     async def a_send(
         self,
-        message: Union[Dict, str],
+        message: Union[dict, str],
         recipient: Agent,
         request_reply: Optional[bool] = None,
         silent: Optional[bool] = False,
@@ -841,7 +841,7 @@ async def a_send(
             "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
         )

-    def _print_received_message(self, message: Union[Dict, str], sender: Agent, skip_head: bool = False):
+    def _print_received_message(self, message: Union[dict, str], sender: Agent, skip_head: bool = False):
         iostream = IOStream.get_default()
         # print the message received
         if not skip_head:
@@ -903,7 +903,7 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent, skip

         iostream.print("\n", "-" * 80, flush=True, sep="")

-    def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
+    def _process_received_message(self, message: Union[dict, str], sender: Agent, silent: bool):
         # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
         valid = self._append_oai_message(message, "user", sender, is_sending=False)
         if logging_enabled():
@@ -919,7 +919,7 @@ def _process_received_message(self, message: Union[Dict, str], sender: Agent, si

     def receive(
         self,
-        message: Union[Dict, str],
+        message: Union[dict, str],
         sender: Agent,
         request_reply: Optional[bool] = None,
         silent: Optional[bool] = False,
@@ -956,7 +956,7 @@ def receive(

     async def a_receive(
         self,
-        message: Union[Dict, str],
+        message: Union[dict, str],
         sender: Agent,
         request_reply: Optional[bool] = None,
         silent: Optional[bool] = False,
@@ -1034,7 +1034,7 @@ def initiate_chat(
         max_turns: Optional[int] = None,
         summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD,
         summary_args: Optional[dict] = {},
-        message: Optional[Union[Dict, str, Callable]] = None,
+        message: Optional[Union[dict, str, Callable]] = None,
         **kwargs,
     ) -> ChatResult:
         """Initiate a chat with the recipient agent.
@@ -1355,7 +1355,7 @@ def _reflection_with_llm(
         response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
         return response

-    def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    def _check_chat_queue_for_sender(self, chat_queue: list[dict[str, Any]]) -> list[dict[str, Any]]:
         """
         Check the chat queue and add the "sender" key if it's missing.

@@ -1372,7 +1372,7 @@ def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List
             chat_queue_with_sender.append(chat_info)
         return chat_queue_with_sender

-    def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
+    def initiate_chats(self, chat_queue: list[dict[str, Any]]) -> list[ChatResult]:
         """(Experimental) Initiate chats with multiple agents.

         Args:
@@ -1385,12 +1385,12 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
         self._finished_chats = initiate_chats(_chat_queue)
         return self._finished_chats

-    async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
+    async def a_initiate_chats(self, chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]:
         _chat_queue = self._check_chat_queue_for_sender(chat_queue)
         self._finished_chats = await a_initiate_chats(_chat_queue)
         return self._finished_chats

-    def get_chat_results(self, chat_index: Optional[int] = None) -> Union[List[ChatResult], ChatResult]:
+    def get_chat_results(self, chat_index: Optional[int] = None) -> Union[list[ChatResult], ChatResult]:
         """A summary from the finished chats of particular agents."""
         if chat_index is not None:
             return self._finished_chats[chat_index]
@@ -1462,10 +1462,10 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser

     def generate_oai_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[OpenAIWrapper] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         """Generate a reply using autogen.oai."""
         client = self.client if config is None else config
         if client is None:
@@ -1477,7 +1477,7 @@ def generate_oai_reply(
         )
         return (False, None) if extracted_response is None else (True, extracted_response)

-    def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[str, Dict, None]:
+    def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[str, dict, None]:
         # unroll tool_responses
         all_messages = []
         for message in messages:
@@ -1522,16 +1522,16 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[

     async def a_generate_oai_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[str, Dict, None]]:
+    ) -> tuple[bool, Union[str, dict, None]]:
         """Generate a reply using autogen.oai asynchronously."""
         iostream = IOStream.get_default()

         def _generate_oai_reply(
             self, iostream: IOStream, *args: Any, **kwargs: Any
-        ) -> Tuple[bool, Union[str, Dict, None]]:
+        ) -> tuple[bool, Union[str, dict, None]]:
             with IOStream.set_default(iostream):
                 return self.generate_oai_reply(*args, **kwargs)

@@ -1544,9 +1544,9 @@ def _generate_oai_reply(

     def _generate_code_execution_reply_using_executor(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
-        config: Optional[Union[Dict, Literal[False]]] = None,
+        config: Optional[Union[dict, Literal[False]]] = None,
     ):
         """Generate a reply using code executor."""
         iostream = IOStream.get_default()
@@ -1613,9 +1613,9 @@ def _generate_code_execution_reply_using_executor(

     def generate_code_execution_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
-        config: Optional[Union[Dict, Literal[False]]] = None,
+        config: Optional[Union[dict, Literal[False]]] = None,
     ):
         """Generate a reply using code execution."""
         code_execution_config = config if config is not None else self._code_execution_config
@@ -1665,10 +1665,10 @@ def generate_code_execution_reply(

     def generate_function_call_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[Dict, None]]:
+    ) -> tuple[bool, Union[dict, None]]:
         """
         Generate a reply using function call.

@@ -1703,10 +1703,10 @@ def generate_function_call_reply(

     async def a_generate_function_call_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[Dict, None]]:
+    ) -> tuple[bool, Union[dict, None]]:
         """
         Generate a reply using async function call.

@@ -1735,10 +1735,10 @@ def _str_for_tool_response(self, tool_response):

     def generate_tool_calls_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[Dict, None]]:
+    ) -> tuple[bool, Union[dict, None]]:
         """Generate a reply using tool call."""
         if config is None:
             config = self
@@ -1802,10 +1802,10 @@ async def _a_execute_tool_call(self, tool_call):

     async def a_generate_tool_calls_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[Dict, None]]:
+    ) -> tuple[bool, Union[dict, None]]:
         """Generate a reply using async function call."""
         if config is None:
             config = self
@@ -1827,10 +1827,10 @@ async def a_generate_tool_calls_reply(

     def check_termination_and_human_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[str, None]]:
+    ) -> tuple[bool, Union[str, None]]:
         """Check if the conversation should be terminated, and if human reply is provided.

         This method checks for conditions that require the conversation to be terminated, such as reaching
@@ -1940,10 +1940,10 @@ def check_termination_and_human_reply(

     async def a_check_termination_and_human_reply(
         self,
-        messages: Optional[List[Dict]] = None,
+        messages: Optional[list[dict]] = None,
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
-    ) -> Tuple[bool, Union[str, None]]:
+    ) -> tuple[bool, Union[str, None]]:
         """(async) Check if the conversation should be terminated, and if human reply is provided.
This method checks for conditions that require the conversation to be terminated, such as reaching @@ -2053,10 +2053,10 @@ async def a_check_termination_and_human_reply( def generate_reply( self, - messages: Optional[List[Dict[str, Any]]] = None, + messages: Optional[list[dict[str, Any]]] = None, sender: Optional["Agent"] = None, **kwargs: Any, - ) -> Union[str, Dict, None]: + ) -> Union[str, dict, None]: """Reply based on the conversation history and the sender. Either messages or sender must be provided. @@ -2126,10 +2126,10 @@ def generate_reply( async def a_generate_reply( self, - messages: Optional[List[Dict[str, Any]]] = None, + messages: Optional[list[dict[str, Any]]] = None, sender: Optional["Agent"] = None, **kwargs: Any, - ) -> Union[str, Dict[str, Any], None]: + ) -> Union[str, dict[str, Any], None]: """(async) Reply based on the conversation history and the sender. Either messages or sender must be provided. @@ -2192,7 +2192,7 @@ async def a_generate_reply( return reply return self._default_auto_reply - def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Optional[Agent]) -> bool: + def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, list], sender: Optional[Agent]) -> bool: """Check if the sender matches the trigger. Args: @@ -2348,7 +2348,7 @@ def _format_json_str(jstr): result.append(char) return "".join(result) - def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict[str, Any]]: + def execute_function(self, func_call, verbose: bool = False) -> tuple[bool, dict[str, Any]]: """Execute a function call and return the result. Override this function to modify the way to execute function and tool calls. @@ -2460,7 +2460,7 @@ async def a_execute_function(self, func_call): "content": content, } - def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: + def generate_init_message(self, message: Union[dict, str, None], **kwargs) -> Union[str, dict]: """Generate the initial message for the agent. If message is None, input() will be called to get the initial message. @@ -2478,7 +2478,7 @@ def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Un return self._handle_carryover(message, kwargs) - def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]: + def _handle_carryover(self, message: Union[str, dict], kwargs: dict) -> Union[str, dict]: if not kwargs.get("carryover"): return message @@ -2515,7 +2515,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str: ) return content - def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]: + def _process_multimodal_carryover(self, content: list[dict], kwargs: dict) -> list[dict]: """Prepends the context to a multimodal message.""" # Makes sure there's a carryover if not kwargs.get("carryover"): @@ -2523,7 +2523,7 @@ def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> Li return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content - async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: + async def a_generate_init_message(self, message: Union[dict, str, None], **kwargs) -> Union[str, dict]: """Generate the initial message for the agent. If message is None, input() will be called to get the initial message. 
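# _process_multimodal_carryover above prepends the carryover as a text part in
# front of an OpenAI-style multimodal content list. A standalone sketch of
# that shape (the field names follow the OpenAI chat format; the helper name
# is local to this example):
from typing import Any

def prepend_carryover(content: list[dict[str, Any]], carryover: str) -> list[dict[str, Any]]:
    # The carryover becomes the first "text" part, ahead of any image parts.
    return [{"type": "text", "text": carryover}] + content

msg = [{"type": "image_url", "image_url": {"url": "https://example.com/x.png"}}]
out = prepend_carryover(msg, "Context: findings from the previous chat.")
assert out[0]["type"] == "text" and len(out) == 2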
@@ -2538,7 +2538,7 @@ async def a_generate_init_message(self, message: Union[Dict, str, None], **kwarg return self._handle_carryover(message, kwargs) - def register_function(self, function_map: Dict[str, Union[Callable, None]]): + def register_function(self, function_map: dict[str, Union[Callable, None]]): """Register functions to the agent. Args: @@ -2553,7 +2553,7 @@ def register_function(self, function_map: Dict[str, Union[Callable, None]]): self._function_map.update(function_map) self._function_map = {k: v for k, v in self._function_map.items() if v is not None} - def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None): + def update_function_signature(self, func_sig: Union[str, dict], is_remove: None): """update a function_signature in the LLM configuration for function_call. Args: @@ -2571,7 +2571,7 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None) if is_remove: if "functions" not in self.llm_config.keys(): - error_msg = "The agent config doesn't have function {name}.".format(name=func_sig) + error_msg = f"The agent config doesn't have function {func_sig}." logger.error(error_msg) raise AssertionError(error_msg) else: @@ -2600,7 +2600,7 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None) self.client = OpenAIWrapper(**self.llm_config) - def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: bool): + def update_tool_signature(self, tool_sig: Union[str, dict], is_remove: bool): """update a tool_signature in the LLM configuration for tool_call. Args: @@ -2615,7 +2615,7 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: bool): if is_remove: if "tools" not in self.llm_config.keys(): - error_msg = "The agent config doesn't have tool {name}.".format(name=tool_sig) + error_msg = f"The agent config doesn't have tool {tool_sig}." logger.error(error_msg) raise AssertionError(error_msg) else: @@ -2644,13 +2644,13 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: bool): self.client = OpenAIWrapper(**self.llm_config) - def can_execute_function(self, name: Union[List[str], str]) -> bool: + def can_execute_function(self, name: Union[list[str], str]) -> bool: """Whether the agent can execute the function.""" names = name if isinstance(name, list) else [name] return all([n in self._function_map for n in names]) @property - def function_map(self) -> Dict[str, Callable]: + def function_map(self) -> dict[str, Callable]: """Return the function map.""" return self._function_map @@ -2854,7 +2854,7 @@ def register_hook(self, hookable_method: str, hook: Callable): assert hook not in hook_list, f"{hook} is already registered as a hook." hook_list.append(hook) - def update_agent_state_before_reply(self, messages: List[Dict]) -> None: + def update_agent_state_before_reply(self, messages: list[dict]) -> None: """ Calls any registered capability hooks to update the agent's state. Primarily used to update context variables. @@ -2866,7 +2866,7 @@ def update_agent_state_before_reply(self, messages: List[Dict]) -> None: for hook in hook_list: hook(self, messages) - def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: + def process_all_messages_before_reply(self, messages: list[dict]) -> list[dict]: """ Calls any registered capability hooks to process all messages, potentially modifying the messages. 
""" @@ -2881,7 +2881,7 @@ def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: processed_messages = hook(processed_messages) return processed_messages - def process_last_received_message(self, messages: List[Dict]) -> List[Dict]: + def process_last_received_message(self, messages: list[dict]) -> list[dict]: """ Calls any registered capability hooks to use and potentially modify the text of the last message, as long as the last message is not a function call or exit command. @@ -2924,7 +2924,7 @@ def process_last_received_message(self, messages: List[Dict]) -> List[Dict]: messages[-1]["content"] = processed_user_content return messages - def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None: + def print_usage_summary(self, mode: Union[str, list[str]] = ["actual", "total"]) -> None: """Print the usage summary.""" iostream = IOStream.get_default() @@ -2934,14 +2934,14 @@ def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) iostream.print(f"Agent '{self.name}':") self.client.print_usage_summary(mode) - def get_actual_usage(self) -> Union[None, Dict[str, int]]: + def get_actual_usage(self) -> Union[None, dict[str, int]]: """Get the actual usage summary.""" if self.client is None: return None else: return self.client.actual_usage_summary - def get_total_usage(self) -> Union[None, Dict[str, int]]: + def get_total_usage(self) -> Union[None, dict[str, int]]: """Get the total usage summary.""" if self.client is None: return None diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 0e14bf35f8..4a2bc18241 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -111,15 +111,15 @@ def custom_speaker_selection_func( - role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system') """ - agents: List[Agent] - messages: List[Dict] + agents: list[Agent] + messages: list[dict] max_round: int = 10 admin_name: str = "Admin" func_call_filter: bool = True speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto" max_retries_for_selecting_speaker: int = 2 - allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None - allowed_or_disallowed_speaker_transitions: Optional[Dict] = None + allow_repeat_speaker: Optional[Union[bool, list[Agent]]] = None + allowed_or_disallowed_speaker_transitions: Optional[dict] = None speaker_transitions_type: Literal["allowed", "disallowed", None] = None enable_clear_history: bool = False send_introductions: bool = False @@ -145,8 +145,8 @@ def custom_speaker_selection_func( Respond with ONLY the name of the speaker and DO NOT provide a reason.""" select_speaker_transform_messages: Optional[transform_messages.TransformMessages] = None select_speaker_auto_verbose: Optional[bool] = False - select_speaker_auto_model_client_cls: Optional[Union[ModelClient, List[ModelClient]]] = None - select_speaker_auto_llm_config: Optional[Union[Dict, Literal[False]]] = None + select_speaker_auto_model_client_cls: Optional[Union[ModelClient, list[ModelClient]]] = None + select_speaker_auto_llm_config: Optional[Union[dict, Literal[False]]] = None role_for_select_speaker_messages: Optional[str] = "system" _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"] @@ -157,7 +157,7 @@ def custom_speaker_selection_func( "Hello everyone. We have assembled a great team today to answer questions and solve tasks. 
In attendance are:" ) - allowed_speaker_transitions_dict: Dict = field(init=False) + allowed_speaker_transitions_dict: dict = field(init=False) def __post_init__(self): # Post init steers clears of the automatically generated __init__ method from dataclass @@ -277,7 +277,7 @@ def __post_init__(self): raise ValueError("select_speaker_auto_verbose cannot be None or non-bool") @property - def agent_names(self) -> List[str]: + def agent_names(self) -> list[str]: """Return the names of the agents in the group chat.""" return [agent.name for agent in self.agents] @@ -285,7 +285,7 @@ def reset(self): """Reset the group chat.""" self.messages.clear() - def append(self, message: Dict, speaker: Agent): + def append(self, message: dict, speaker: Agent): """Append a message to the group chat. We cast the content to str here so that it can be managed by text-based model. @@ -311,7 +311,7 @@ def agent_by_name( return filtered_agents[0] if filtered_agents else None - def nested_agents(self) -> List[Agent]: + def nested_agents(self) -> list[Agent]: """Returns all agents in the group chat manager.""" agents = self.agents.copy() for agent in agents: @@ -320,7 +320,7 @@ def nested_agents(self) -> List[Agent]: agents.extend(agent.groupchat.nested_agents()) return agents - def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agent: + def next_agent(self, agent: Agent, agents: Optional[list[Agent]] = None) -> Agent: """Return the next agent in the list.""" if agents is None: agents = self.agents @@ -344,7 +344,7 @@ def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agen # Explicitly handle cases where no valid next agent exists in the provided subset. raise UndefinedNextAgent() - def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str: + def select_speaker_msg(self, agents: Optional[list[Agent]] = None) -> str: """Return the system message for selecting the next speaker. This is always the *first* message in the context.""" if agents is None: agents = self.agents @@ -355,7 +355,7 @@ def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str: return_msg = self.select_speaker_message_template.format(roles=roles, agentlist=agentlist) return return_msg - def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str: + def select_speaker_prompt(self, agents: Optional[list[Agent]] = None) -> str: """Return the floating system prompt selecting the next speaker. This is always the *last* message in the context. Will return None if the select_speaker_prompt_template is None.""" @@ -371,7 +371,7 @@ def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str: return_prompt = self.select_speaker_prompt_template.format(agentlist=agentlist) return return_prompt - def introductions_msg(self, agents: Optional[List[Agent]] = None) -> str: + def introductions_msg(self, agents: Optional[list[Agent]] = None) -> str: """Return the system message for selecting the next speaker. 
This is always the *first* message in the context.""" if agents is None: agents = self.agents @@ -382,7 +382,7 @@ def introductions_msg(self, agents: Optional[List[Agent]] = None) -> str: return f"{intro_msg}\n\n{participant_roles}" - def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]: + def manual_select_speaker(self, agents: Optional[list[Agent]] = None) -> Union[Agent, None]: """Manually select the next speaker.""" iostream = IOStream.get_default() @@ -415,7 +415,7 @@ def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A iostream.print(f"Invalid input. Please enter a number between 1 and {_n_agents}.") return None - def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]: + def random_select_speaker(self, agents: Optional[list[Agent]] = None) -> Union[Agent, None]: """Randomly select the next speaker.""" if agents is None: agents = self.agents @@ -424,7 +424,7 @@ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A def _prepare_and_select_agents( self, last_speaker: Agent, - ) -> Tuple[Optional[Agent], List[Agent], Optional[List[Dict]]]: + ) -> tuple[Optional[Agent], list[Agent], Optional[list[dict]]]: # If self.speaker_selection_method is a callable, call it to get the next speaker. # If self.speaker_selection_method is a string, return it. speaker_selection_method = self.speaker_selection_method @@ -575,7 +575,7 @@ async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent # auto speaker selection with 2-agent chat return await self.a_auto_select_speaker(last_speaker, selector, messages, agents) - def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: Optional[List[Agent]]) -> Agent: + def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: Optional[list[Agent]]) -> Agent: if not final: # the LLM client is None, thus no reply is generated. Use round robin instead. return self.next_agent(last_speaker, agents) @@ -593,7 +593,7 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: agent = self.agent_by_name(name) return agent if agent else self.next_agent(last_speaker, agents) - def _register_client_from_config(self, agent: Agent, config: Dict): + def _register_client_from_config(self, agent: Agent, config: dict): model_client_cls_to_match = config.get("model_client_cls") if model_client_cls_to_match: if not self.select_speaker_auto_model_client_cls: @@ -670,8 +670,8 @@ def _auto_select_speaker( self, last_speaker: Agent, selector: ConversableAgent, - messages: Optional[List[Dict]], - agents: Optional[List[Agent]], + messages: Optional[list[dict]], + agents: Optional[list[Agent]], ) -> Agent: """Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying. 
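Since `speaker_selection_method` also accepts a callable (see the dataclass field earlier in this file), the auto-selection machinery above can be bypassed entirely. A sketch of a custom policy; the agent name and the fallback to the built-in "round_robin" method string are illustrative assumptions:

    def critic_then_round_robin(last_speaker, groupchat):
        # Hypothetical policy: every non-critic turn is followed by the critic;
        # after the critic speaks, defer to the built-in round-robin method.
        if last_speaker.name != "critic":
            return groupchat.agent_by_name("critic")
        return "round_robin"

    # groupchat = GroupChat(agents=[...], messages=[], speaker_selection_method=critic_then_round_robin)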
@@ -706,7 +706,7 @@ def _auto_select_speaker( attempt = 0 # Registered reply function for checking_agent, checks the result of the response for agent names - def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Union[str, Dict, None]]: + def validate_speaker_name(recipient, messages, sender, config) -> tuple[bool, Union[str, dict, None]]: # The number of retries left, starting at max_retries_for_selecting_speaker nonlocal attempts_left nonlocal attempt @@ -754,8 +754,8 @@ async def a_auto_select_speaker( self, last_speaker: Agent, selector: ConversableAgent, - messages: Optional[List[Dict]], - agents: Optional[List[Agent]], + messages: Optional[list[dict]], + agents: Optional[list[Agent]], ) -> Agent: """(Asynchronous) Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying. @@ -789,7 +789,7 @@ async def a_auto_select_speaker( attempt = 0 # Registered reply function for checking_agent, checks the result of the response for agent names - def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Union[str, Dict, None]]: + def validate_speaker_name(recipient, messages, sender, config) -> tuple[bool, Union[str, dict, None]]: # The number of retries left, starting at max_retries_for_selecting_speaker nonlocal attempts_left nonlocal attempt @@ -834,7 +834,7 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un def _validate_speaker_name( self, recipient, messages, sender, config, attempts_left, attempt, agents - ) -> Tuple[bool, Union[str, Dict, None]]: + ) -> tuple[bool, Union[str, dict, None]]: """Validates the speaker response for each round in the internal 2-agent chat within the auto select speaker method. @@ -928,7 +928,7 @@ def _validate_speaker_name( return True, None - def _process_speaker_selection_result(self, result, last_speaker: ConversableAgent, agents: Optional[List[Agent]]): + def _process_speaker_selection_result(self, result, last_speaker: ConversableAgent, agents: Optional[list[Agent]]): """Checks the result of the auto_select_speaker function, returning the agent to speak. @@ -948,7 +948,7 @@ def _process_speaker_selection_result(self, result, last_speaker: ConversableAge # No agent, return the failed reason return next_agent - def _participant_roles(self, agents: List[Agent] = None) -> str: + def _participant_roles(self, agents: list[Agent] = None) -> str: # Default to all agents registered if agents is None: agents = self.agents @@ -962,7 +962,7 @@ def _participant_roles(self, agents: List[Agent] = None) -> str: roles.append(f"{agent.name}: {agent.description}".strip()) return "\n".join(roles) - def _mentioned_agents(self, message_content: Union[str, List], agents: Optional[List[Agent]]) -> Dict: + def _mentioned_agents(self, message_content: Union[str, list], agents: Optional[list[Agent]]) -> dict: """Counts the number of times each agent is mentioned in the provided message content. 
Agent names will match under any of the following conditions (all case-sensitive): - Exact name match @@ -1013,7 +1013,7 @@ def __init__( # unlimited consecutive auto reply by default max_consecutive_auto_reply: Optional[int] = sys.maxsize, human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", - system_message: Optional[Union[str, List]] = "Group chat manager.", + system_message: Optional[Union[str, list]] = "Group chat manager.", silent: bool = False, **kwargs, ): @@ -1058,7 +1058,7 @@ def groupchat(self) -> GroupChat: """Returns the group chat managed by the group chat manager.""" return self._groupchat - def chat_messages_for_summary(self, agent: Agent) -> List[Dict]: + def chat_messages_for_summary(self, agent: Agent) -> list[dict]: """The list of messages in the group chat as a conversation to summarize. The agent is ignored. """ @@ -1129,10 +1129,10 @@ def print_messages(recipient, messages, sender, config): def run_chat( self, - messages: Optional[List[Dict]] = None, + messages: Optional[list[dict]] = None, sender: Optional[Agent] = None, config: Optional[GroupChat] = None, - ) -> Tuple[bool, Optional[str]]: + ) -> tuple[bool, Optional[str]]: """Run a group chat.""" if messages is None: messages = self._oai_messages[sender] @@ -1209,7 +1209,7 @@ def run_chat( async def a_run_chat( self, - messages: Optional[List[Dict]] = None, + messages: Optional[list[dict]] = None, sender: Optional[Agent] = None, config: Optional[GroupChat] = None, ): @@ -1275,10 +1275,10 @@ async def a_run_chat( def resume( self, - messages: Union[List[Dict], str], + messages: Union[list[dict], str], remove_termination_string: Optional[Union[str, Callable[[str], str]]] = None, silent: Optional[bool] = False, - ) -> Tuple[ConversableAgent, Dict]: + ) -> tuple[ConversableAgent, dict]: """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established as per the original group chat. @@ -1383,10 +1383,10 @@ def resume( async def a_resume( self, - messages: Union[List[Dict], str], + messages: Union[list[dict], str], remove_termination_string: Optional[Union[str, Callable[[str], str]]] = None, silent: Optional[bool] = False, - ) -> Tuple[ConversableAgent, Dict]: + ) -> tuple[ConversableAgent, dict]: """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established as per the original group chat. @@ -1489,7 +1489,7 @@ async def a_resume( return previous_last_agent, last_message - def _valid_resume_messages(self, messages: List[Dict]): + def _valid_resume_messages(self, messages: list[dict]): """Validates the messages used for resuming args: @@ -1515,7 +1515,7 @@ def _valid_resume_messages(self, messages: List[Dict]): raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}") def _process_resume_termination( - self, remove_termination_string: Union[str, Callable[[str], str]], messages: List[Dict] + self, remove_termination_string: Union[str, Callable[[str], str]], messages: list[dict] ): """Removes termination string, if required, and checks if termination may occur. 
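The resume path above accepts either the raw `list[dict]` or its JSON-string form, which pairs with `messages_to_string` in the hunks that follow. A hypothetical end-to-end flow, assuming `manager` is a GroupChatManager whose agents match the saved transcript:

    saved = manager.messages_to_string(manager.groupchat.messages)  # persist anywhere

    last_agent, last_message = manager.resume(
        messages=saved,                         # JSON string or list[dict]
        remove_termination_string="TERMINATE",  # strip the stop word so the chat can continue
    )
    # last_agent.initiate_chat(manager, message=last_message, clear_history=False)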
@@ -1548,7 +1548,7 @@ def _remove_termination_string(content: str) -> str: if self._is_termination_msg(last_message): logger.warning("WARNING: Last message meets termination criteria and this may terminate the chat.") - def messages_from_string(self, message_string: str) -> List[Dict]: + def messages_from_string(self, message_string: str) -> list[dict]: """Reads the saved state of messages in Json format for resume and returns as a messages list args: @@ -1564,7 +1564,7 @@ def messages_from_string(self, message_string: str) -> List[Dict]: return state - def messages_to_string(self, messages: List[Dict]) -> str: + def messages_to_string(self, messages: list[dict]) -> str: """Converts the provided messages into a Json string that can be used for resuming the chat. The state is made up of a list of messages diff --git a/autogen/agentchat/realtime_agent/realtime_agent.py b/autogen/agentchat/realtime_agent/realtime_agent.py index 09dadab27e..aadbc1f283 100644 --- a/autogen/agentchat/realtime_agent/realtime_agent.py +++ b/autogen/agentchat/realtime_agent/realtime_agent.py @@ -9,7 +9,8 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Tuple, TypeVar, Union +from collections.abc import Generator +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union import anyio import websockets @@ -51,8 +52,8 @@ def __init__( *, name: str, audio_adapter: RealtimeObserver, - system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.", - llm_config: Optional[Union[Dict, Literal[False]]] = None, + system_message: Optional[Union[str, list]] = "You are a helpful AI Assistant.", + llm_config: Optional[Union[dict, Literal[False]]] = None, voice: str = "alloy", ): """(Experimental) Agent for interacting with the Realtime Clients. @@ -102,7 +103,7 @@ def register_swarm( self, *, initial_agent: SwarmAgent, - agents: List[SwarmAgent], + agents: list[SwarmAgent], system_message: Optional[str] = None, ) -> None: """Register a swarm of agents with the Realtime Agent. @@ -207,10 +208,10 @@ async def _check_event_set(timeout: int = question_timeout) -> None: def check_termination_and_human_reply( self, - messages: Optional[List[Dict]] = None, + messages: Optional[list[dict]] = None, sender: Optional[Agent] = None, config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, None]]: + ) -> tuple[bool, Union[str, None]]: """Check if the conversation should be terminated and if the agent should reply. Called when its agents turn in the chat conversation. 
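`check_termination_and_human_reply` above returns the same `(final, reply)` pair as every registered reply function in these files: `final=True` short-circuits the remaining reply functions. A self-contained sketch of that contract; the ping/pong logic and the commented registration call are assumptions for illustration:

    from typing import Any, Optional, Union

    def pong_reply(
        recipient: Any,
        messages: Optional[list[dict[str, Any]]] = None,
        sender: Optional[Any] = None,
        config: Optional[Any] = None,
    ) -> tuple[bool, Union[str, dict[str, Any], None]]:
        # final=True stops the chain; final=False lets later reply functions run.
        if messages and "ping" in str(messages[-1].get("content", "")):
            return True, "pong"
        return False, None

    # agent.register_reply(trigger=[Agent, None], reply_func=pong_reply, position=0)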
diff --git a/autogen/agentchat/user_proxy_agent.py b/autogen/agentchat/user_proxy_agent.py index 6602f232b8..75f6e354b7 100644 --- a/autogen/agentchat/user_proxy_agent.py +++ b/autogen/agentchat/user_proxy_agent.py @@ -32,14 +32,14 @@ class UserProxyAgent(ConversableAgent): def __init__( self, name: str, - is_termination_msg: Optional[Callable[[Dict], bool]] = None, + is_termination_msg: Optional[Callable[[dict], bool]] = None, max_consecutive_auto_reply: Optional[int] = None, human_input_mode: Literal["ALWAYS", "TERMINATE", "NEVER"] = "ALWAYS", - function_map: Optional[Dict[str, Callable]] = None, - code_execution_config: Union[Dict, Literal[False]] = {}, - default_auto_reply: Optional[Union[str, Dict, None]] = "", - llm_config: Optional[Union[Dict, Literal[False]]] = False, - system_message: Optional[Union[str, List]] = "", + function_map: Optional[dict[str, Callable]] = None, + code_execution_config: Union[dict, Literal[False]] = {}, + default_auto_reply: Optional[Union[str, dict, None]] = "", + llm_config: Optional[Union[dict, Literal[False]]] = False, + system_message: Optional[Union[str, list]] = "", description: Optional[str] = None, **kwargs, ): diff --git a/autogen/agentchat/utils.py b/autogen/agentchat/utils.py index 490bd46f18..94b2d7c5e5 100644 --- a/autogen/agentchat/utils.py +++ b/autogen/agentchat/utils.py @@ -32,7 +32,7 @@ def consolidate_chat_info(chat_info, uniform_sender=None) -> None: ), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm." -def gather_usage_summary(agents: List[Agent]) -> Dict[Dict[str, Dict], Dict[str, Dict]]: +def gather_usage_summary(agents: list[Agent]) -> dict[dict[str, dict], dict[str, dict]]: r"""Gather usage summary from all agents. Args: @@ -74,7 +74,7 @@ def gather_usage_summary(agents: List[Agent]) -> Dict[Dict[str, Dict], Dict[str, If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`. """ - def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, Any]) -> None: + def aggregate_summary(usage_summary: dict[str, Any], agent_summary: dict[str, Any]) -> None: if agent_summary is None: return usage_summary["total_cost"] += agent_summary.get("total_cost", 0) @@ -102,7 +102,7 @@ def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, An } -def parse_tags_from_content(tag: str, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Dict[str, str]]]: +def parse_tags_from_content(tag: str, content: Union[str, list[dict[str, Any]]]) -> list[dict[str, dict[str, str]]]: """Parses HTML style tags from message contents. The parsing is done by looking for patterns in the text that match the format of HTML tags. 
The tag to be parsed is @@ -142,7 +142,7 @@ def parse_tags_from_content(tag: str, content: Union[str, List[Dict[str, Any]]]) return results -def _parse_tags_from_text(tag: str, text: str) -> List[Dict[str, str]]: +def _parse_tags_from_text(tag: str, text: str) -> list[dict[str, str]]: pattern = re.compile(f"<{tag} (.*?)>") results = [] @@ -180,7 +180,7 @@ def _append_src_value(content, value): return content -def _reconstruct_attributes(attrs: List[str]) -> List[str]: +def _reconstruct_attributes(attrs: list[str]) -> list[str]: """Reconstructs attributes from a list of strings where some attributes may be split across multiple elements.""" def is_attr(attr: str) -> bool: diff --git a/autogen/browser_utils.py b/autogen/browser_utils.py index dd8d9ca2b7..3251a5c9ab 100644 --- a/autogen/browser_utils.py +++ b/autogen/browser_utils.py @@ -44,15 +44,15 @@ def __init__( downloads_folder: Optional[Union[str, None]] = None, bing_base_url: str = "https://api.bing.microsoft.com/v7.0/search", bing_api_key: Optional[Union[str, None]] = None, - request_kwargs: Optional[Union[Dict[str, Any], None]] = None, + request_kwargs: Union[dict[str, Any], None] = None, ): self.start_page: str = start_page if start_page else "about:blank" self.viewport_size = viewport_size # Applies only to the standard uri types self.downloads_folder = downloads_folder - self.history: List[str] = list() + self.history: list[str] = list() self.page_title: Optional[str] = None self.viewport_current_page = 0 - self.viewport_pages: List[Tuple[int, int]] = list() + self.viewport_pages: list[tuple[int, int]] = list() self.set_address(self.start_page) self.bing_base_url = bing_base_url self.bing_api_key = bing_api_key @@ -132,7 +132,7 @@ def _split_pages(self) -> None: self.viewport_pages.append((start_idx, end_idx)) start_idx = end_idx - def _bing_api_call(self, query: str) -> Dict[str, Dict[str, List[Dict[str, Union[str, Dict[str, str]]]]]]: + def _bing_api_call(self, query: str) -> dict[str, dict[str, list[dict[str, Union[str, dict[str, str]]]]]]: # Make sure the key was set if self.bing_api_key is None: raise ValueError("Missing Bing API key.") @@ -162,7 +162,7 @@ def _bing_api_call(self, query: str) -> Dict[str, Dict[str, List[Dict[str, Union def _bing_search(self, query: str) -> None: results = self._bing_api_call(query) - web_snippets: List[str] = list() + web_snippets: list[str] = list() idx = 0 for page in results["webPages"]["value"]: idx += 1 diff --git a/autogen/cache/abstract_cache_base.py b/autogen/cache/abstract_cache_base.py index 6a5b88823d..d7d864508f 100644 --- a/autogen/cache/abstract_cache_base.py +++ b/autogen/cache/abstract_cache_base.py @@ -63,7 +63,7 @@ def __enter__(self) -> Self: def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: diff --git a/autogen/cache/cache.py b/autogen/cache/cache.py index 1e64e5d5f0..1b2bc26f6b 100644 --- a/autogen/cache/cache.py +++ b/autogen/cache/cache.py @@ -40,7 +40,7 @@ class Cache(AbstractCache): ] @staticmethod - def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> "Cache": + def redis(cache_seed: str | int = 42, redis_url: str = "redis://localhost:6379/0") -> Cache: """ Create a Redis cache instance. 
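Dropping the quotes around `"Cache"` and writing `str | int` only works on Python 3.9 if annotation evaluation is deferred, so this hunk presumes `from __future__ import annotations` at the top of cache.py (not shown in this excerpt). A standalone sketch of the mechanism:

    from __future__ import annotations  # defers evaluation of every annotation

    class Cache:
        @staticmethod
        def redis(cache_seed: str | int = 42) -> Cache:  # bare Cache, no forward-ref quotes
            return Cache()

    # Annotations stay as strings and are never evaluated at import time, so this
    # parses and runs even on interpreters that predate PEP 604 unions (3.10).
    print(Cache.redis.__annotations__)  # {'cache_seed': 'str | int', 'return': 'Cache'}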
@@ -54,7 +54,7 @@ def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost: return Cache({"cache_seed": cache_seed, "redis_url": redis_url}) @staticmethod - def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> "Cache": + def disk(cache_seed: str | int = 42, cache_path_root: str = ".cache") -> Cache: """ Create a Disk cache instance. @@ -69,11 +69,11 @@ def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> " @staticmethod def cosmos_db( - connection_string: Optional[str] = None, - container_id: Optional[str] = None, - cache_seed: Union[str, int] = 42, - client: Optional[any] = None, - ) -> "Cache": + connection_string: str | None = None, + container_id: str | None = None, + cache_seed: str | int = 42, + client: Any | None = None, + ) -> Cache: """ Create a Cosmos DB cache instance with 'autogen_cache' as database ID. @@ -93,7 +93,7 @@ def cosmos_db( } return Cache({"cache_seed": str(cache_seed), "cosmos_db_config": cosmos_db_config}) - def __init__(self, config: Dict[str, Any]): + def __init__(self, config: dict[str, Any]): """ Initialize the Cache with the given configuration. @@ -121,7 +121,7 @@ def __init__(self, config: Dict[str, Any]): cosmosdb_config=self.config.get("cosmos_db_config"), ) - def __enter__(self) -> "Cache": + def __enter__(self) -> Cache: """ Enter the runtime context related to the cache object. @@ -132,9 +132,9 @@ def __enter__(self) -> "Cache": def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: """ Exit the runtime context related to the cache object. @@ -149,7 +149,7 @@ def __exit__( """ return self.cache.__exit__(exc_type, exc_value, traceback) - def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + def get(self, key: str, default: Any | None = None) -> Any | None: """ Retrieve an item from the cache. diff --git a/autogen/cache/cache_factory.py b/autogen/cache/cache_factory.py index 8401c189d7..b64328cfe7 100644 --- a/autogen/cache/cache_factory.py +++ b/autogen/cache/cache_factory.py @@ -18,7 +18,7 @@ def cache_factory( seed: Union[str, int], redis_url: Optional[str] = None, cache_path_root: str = ".cache", - cosmosdb_config: Optional[Dict[str, Any]] = None, + cosmosdb_config: Optional[dict[str, Any]] = None, ) -> AbstractCache: """ Factory function for creating cache instances.
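The `__exit__` rewrites in this and the following cache files all converge on the same signature, with the now-generic builtin `type` replacing `typing.Type`. A minimal context-manager sketch of that signature (the class is hypothetical):

    from __future__ import annotations

    from types import TracebackType

    class ManagedCache:
        def __enter__(self) -> ManagedCache:
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            print("closed", exc_type)

    with ManagedCache():
        pass  # prints "closed None" on a clean exit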
diff --git a/autogen/cache/disk_cache.py b/autogen/cache/disk_cache.py index adb0720287..2838447cb6 100644 --- a/autogen/cache/disk_cache.py +++ b/autogen/cache/disk_cache.py @@ -92,7 +92,7 @@ def __enter__(self) -> Self: def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: diff --git a/autogen/cache/in_memory_cache.py b/autogen/cache/in_memory_cache.py index 8469554d04..f080530e56 100644 --- a/autogen/cache/in_memory_cache.py +++ b/autogen/cache/in_memory_cache.py @@ -20,7 +20,7 @@ class InMemoryCache(AbstractCache): def __init__(self, seed: Union[str, int] = ""): self._seed = str(seed) - self._cache: Dict[str, Any] = {} + self._cache: dict[str, Any] = {} def _prefixed_key(self, key: str) -> str: separator = "_" if self._seed else "" @@ -48,7 +48,7 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: """ Exit the runtime context related to the object. diff --git a/autogen/cache/redis_cache.py b/autogen/cache/redis_cache.py index b1c4c10557..b87863f083 100644 --- a/autogen/cache/redis_cache.py +++ b/autogen/cache/redis_cache.py @@ -113,7 +113,7 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: """ Exit the runtime context related to the object. diff --git a/autogen/code_utils.py b/autogen/code_utils.py index 0587e87728..9e9f1b91c2 100644 --- a/autogen/code_utils.py +++ b/autogen/code_utils.py @@ -48,7 +48,7 @@ logger = logging.getLogger(__name__) -def content_str(content: Union[str, List[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None]) -> str: +def content_str(content: Union[str, list[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None]) -> str: """Converts the `content` field of an OpenAI message into a string format. This function processes content that may be a string, a list of mixed text and image URLs, or None, @@ -108,8 +108,8 @@ def infer_lang(code: str) -> str: # TODO: In the future move, to better support https://spec.commonmark.org/0.30/#fenced-code-blocks # perhaps by using a full Markdown parser. def extract_code( - text: Union[str, List], pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False -) -> List[Tuple[str, str]]: + text: Union[str, list], pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False +) -> list[tuple[str, str]]: """Extract code from a text. Args: @@ -146,7 +146,7 @@ def extract_code( return extracted -def generate_code(pattern: str = CODE_BLOCK_PATTERN, **config) -> Tuple[str, float]: +def generate_code(pattern: str = CODE_BLOCK_PATTERN, **config) -> tuple[str, float]: """(openai<1) Generate code. 
Args: @@ -175,7 +175,7 @@ def improve_function(file_name, func_name, objective, **config): """(openai<1) Improve the function to achieve the objective.""" params = {**_IMPROVE_FUNCTION_CONFIG, **config} # read the entire file into a str - with open(file_name, "r") as f: + with open(file_name) as f: file_string = f.read() response = oai.Completion.create( {"func_name": func_name, "objective": objective, "file_string": file_string}, **params ) @@ -208,7 +208,7 @@ def improve_code(files, objective, suggest_only=True, **config): code = "" for file_name in files: # read the entire file into a string - with open(file_name, "r") as f: + with open(file_name) as f: file_string = f.read() code += f"""{file_name}: {file_string} @@ -358,9 +358,9 @@ def execute_code( timeout: Optional[int] = None, filename: Optional[str] = None, work_dir: Optional[str] = None, - use_docker: Union[List[str], str, bool] = SENTINEL, + use_docker: Union[list[str], str, bool] = SENTINEL, lang: Optional[str] = "python", -) -> Tuple[int, str, Optional[str]]: +) -> tuple[int, str, Optional[str]]: """Execute code in a docker container. This function is not tested on MacOS. @@ -552,7 +552,7 @@ def execute_code( } -def generate_assertions(definition: str, **config) -> Tuple[str, float]: +def generate_assertions(definition: str, **config) -> tuple[str, float]: """(openai<1) Generate assertions for a function. Args: @@ -582,14 +582,14 @@ def _remove_check(response): def eval_function_completions( - responses: List[str], + responses: list[str], definition: str, test: Optional[str] = None, entry_point: Optional[str] = None, - assertions: Optional[Union[str, Callable[[str], Tuple[str, float]]]] = None, + assertions: Optional[Union[str, Callable[[str], tuple[str, float]]]] = None, timeout: Optional[float] = 3, use_docker: Optional[bool] = True, -) -> Dict: +) -> dict[str, Any]: """(openai<1) Select a response from a list of responses for the function completion task (using generated assertions), and/or evaluate if the task is successful using a gold test. Args: @@ -692,9 +692,9 @@ def pass_assertions(self, context, response, **_): def implement( definition: str, - configs: Optional[List[Dict]] = None, - assertions: Optional[Union[str, Callable[[str], Tuple[str, float]]]] = generate_assertions, -) -> Tuple[str, float]: + configs: Optional[list[dict]] = None, + assertions: Optional[Union[str, Callable[[str], tuple[str, float]]]] = generate_assertions, +) -> tuple[str, float]: """(openai<1) Implement a function from a definition. Args: diff --git a/autogen/coding/base.py b/autogen/coding/base.py index 57af08ac5c..e17fa22a60 100644 --- a/autogen/coding/base.py +++ b/autogen/coding/base.py @@ -6,7 +6,8 @@ # SPDX-License-Identifier: MIT from __future__ import annotations -from typing import Any, List, Literal, Mapping, Optional, Protocol, TypedDict, Union, runtime_checkable +from collections.abc import Mapping +from typing import Any, List, Literal, Optional, Protocol, TypedDict, Union, runtime_checkable from pydantic import BaseModel, Field @@ -35,8 +36,8 @@ class CodeExtractor(Protocol): """(Experimental) A code extractor class that extracts code blocks from a message.""" def extract_code_blocks( - self, message: Union[str, List[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None] - ) -> List[CodeBlock]: + self, message: str | list[UserMessageTextContentPart | UserMessageImageContentPart] | None + ) -> list[CodeBlock]: """(Experimental) Extract code blocks from a message.
Args: @@ -57,7 +58,7 @@ def code_extractor(self) -> CodeExtractor: """(Experimental) The code extractor used by this code executor.""" ... # pragma: no cover - def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult: + def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CodeResult: """(Experimental) Execute code blocks and return the result. This method should be implemented by the code executor. @@ -83,7 +84,7 @@ def restart(self) -> None: class IPythonCodeResult(CodeResult): """(Experimental) A code result class for IPython code executor.""" - output_files: List[str] = Field( + output_files: list[str] = Field( default_factory=list, description="The list of files that the executed code blocks generated.", ) @@ -95,7 +96,7 @@ class IPythonCodeResult(CodeResult): "executor": Union[Literal["ipython-embedded", "commandline-local"], CodeExecutor], "last_n_messages": Union[int, Literal["auto"]], "timeout": int, - "use_docker": Union[bool, str, List[str]], + "use_docker": Union[bool, str, list[str]], "work_dir": str, "ipython-embedded": Mapping[str, Any], "commandline-local": Mapping[str, Any], @@ -107,7 +108,7 @@ class IPythonCodeResult(CodeResult): class CommandLineCodeResult(CodeResult): """(Experimental) A code result class for command line code executor.""" - code_file: Optional[str] = Field( + code_file: str | None = Field( default=None, description="The file that the executed code block was saved to.", ) diff --git a/autogen/coding/docker_commandline_code_executor.py b/autogen/coding/docker_commandline_code_executor.py index 2576f28ed7..395f61da27 100644 --- a/autogen/coding/docker_commandline_code_executor.py +++ b/autogen/coding/docker_commandline_code_executor.py @@ -45,7 +45,7 @@ def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) - class DockerCommandLineCodeExecutor(CodeExecutor): - DEFAULT_EXECUTION_POLICY: ClassVar[Dict[str, bool]] = { + DEFAULT_EXECUTION_POLICY: ClassVar[dict[str, bool]] = { "bash": True, "shell": True, "sh": True, @@ -57,18 +57,18 @@ class DockerCommandLineCodeExecutor(CodeExecutor): "html": False, "css": False, } - LANGUAGE_ALIASES: ClassVar[Dict[str, str]] = {"py": "python", "js": "javascript"} + LANGUAGE_ALIASES: ClassVar[dict[str, str]] = {"py": "python", "js": "javascript"} def __init__( self, image: str = "python:3-slim", - container_name: Optional[str] = None, + container_name: str | None = None, timeout: int = 60, - work_dir: Union[Path, str] = Path("."), - bind_dir: Optional[Union[Path, str]] = None, + work_dir: Path | str = Path("."), + bind_dir: Path | str | None = None, auto_remove: bool = True, stop_container: bool = True, - execution_policies: Optional[Dict[str, bool]] = None, + execution_policies: dict[str, bool] | None = None, ): """(Experimental) A code executor class that executes code through a command line environment in a Docker container. @@ -183,7 +183,7 @@ def code_extractor(self) -> CodeExtractor: """(Experimental) Export a code extractor that can be used by an agent.""" return MarkdownCodeExtractor() - def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult: + def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult: """(Experimental) Execute the code blocks and return the result. 
Args: @@ -257,6 +257,6 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: self.stop() diff --git a/autogen/coding/func_with_reqs.py b/autogen/coding/func_with_reqs.py index 7a755ffdce..1f842c9193 100644 --- a/autogen/coding/func_with_reqs.py +++ b/autogen/coding/func_with_reqs.py @@ -20,7 +20,7 @@ P = ParamSpec("P") -def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str: +def _to_code(func: FunctionWithRequirements[T, P] | Callable[P, T] | FunctionWithRequirementsStr) -> str: if isinstance(func, FunctionWithRequirementsStr): return func.func @@ -40,7 +40,7 @@ class Alias: @dataclass class ImportFromModule: module: str - imports: List[Union[str, Alias]] + imports: list[str | Alias] Import = Union[str, ImportFromModule, Alias] @@ -53,7 +53,7 @@ def _import_to_str(im: Import) -> str: return f"import {im.name} as {im.alias}" else: - def to_str(i: Union[str, Alias]) -> str: + def to_str(i: str | Alias) -> str: if isinstance(i, str): return i else: @@ -82,10 +82,10 @@ class FunctionWithRequirementsStr: func: str _compiled_func: Callable[..., Any] _func_name: str - python_packages: List[str] = field(default_factory=list) - global_imports: List[Import] = field(default_factory=list) + python_packages: list[str] = field(default_factory=list) + global_imports: list[Import] = field(default_factory=list) - def __init__(self, func: str, python_packages: List[str] = [], global_imports: List[Import] = []): + def __init__(self, func: str, python_packages: list[str] = [], global_imports: list[Import] = []): self.func = func self.python_packages = python_packages self.global_imports = global_imports @@ -117,18 +117,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: @dataclass class FunctionWithRequirements(Generic[T, P]): func: Callable[P, T] - python_packages: List[str] = field(default_factory=list) - global_imports: List[Import] = field(default_factory=list) + python_packages: list[str] = field(default_factory=list) + global_imports: list[Import] = field(default_factory=list) @classmethod def from_callable( - cls, func: Callable[P, T], python_packages: List[str] = [], global_imports: List[Import] = [] + cls, func: Callable[P, T], python_packages: list[str] = [], global_imports: list[Import] = [] ) -> FunctionWithRequirements[T, P]: return cls(python_packages=python_packages, global_imports=global_imports, func=func) @staticmethod def from_str( - func: str, python_packages: List[str] = [], global_imports: List[Import] = [] + func: str, python_packages: list[str] = [], global_imports: list[Import] = [] ) -> FunctionWithRequirementsStr: return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports) @@ -138,7 +138,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: def with_requirements( - python_packages: List[str] = [], global_imports: List[Import] = [] + python_packages: list[str] = [], global_imports: list[Import] = [] ) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: """Decorate a function with package and import requirements @@ -162,10 +162,10 @@ def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]: def _build_python_functions_file( - funcs: List[Union[FunctionWithRequirements[Any, P], Callable[..., Any], 
FunctionWithRequirementsStr]] + funcs: list[FunctionWithRequirements[Any, P] | Callable[..., Any] | FunctionWithRequirementsStr] ) -> str: # First collect all global imports - global_imports: Set[str] = set() + global_imports: set[str] = set() for func in funcs: if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)): global_imports.update(map(_import_to_str, func.global_imports)) @@ -178,7 +178,7 @@ def _build_python_functions_file( return content -def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str: +def to_stub(func: Callable[..., Any] | FunctionWithRequirementsStr) -> str: """Generate a stub for a function as a string Args: diff --git a/autogen/coding/jupyter/docker_jupyter_server.py b/autogen/coding/jupyter/docker_jupyter_server.py index 2a73a7b307..f90090dee6 100644 --- a/autogen/coding/jupyter/docker_jupyter_server.py +++ b/autogen/coding/jupyter/docker_jupyter_server.py @@ -59,12 +59,12 @@ class GenerateToken: def __init__( self, *, - custom_image_name: Optional[str] = None, - container_name: Optional[str] = None, + custom_image_name: str | None = None, + container_name: str | None = None, auto_remove: bool = True, stop_container: bool = True, - docker_env: Dict[str, str] = {}, - token: Union[str, GenerateToken] = GenerateToken(), + docker_env: dict[str, str] = {}, + token: str | GenerateToken = GenerateToken(), ): """Start a Jupyter kernel gateway server in a Docker container. @@ -159,6 +159,6 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: self.stop() diff --git a/autogen/coding/jupyter/embedded_ipython_code_executor.py b/autogen/coding/jupyter/embedded_ipython_code_executor.py index 231dca0ffd..4e0a8d828c 100644 --- a/autogen/coding/jupyter/embedded_ipython_code_executor.py +++ b/autogen/coding/jupyter/embedded_ipython_code_executor.py @@ -78,7 +78,7 @@ def code_extractor(self) -> CodeExtractor: """(Experimental) Export a code extractor that can be used by an agent.""" return MarkdownCodeExtractor() - def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> IPythonCodeResult: + def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> IPythonCodeResult: """(Experimental) Execute a list of code blocks and return the result. 
This method executes a list of code blocks as cells in an IPython kernel diff --git a/autogen/coding/jupyter/jupyter_client.py b/autogen/coding/jupyter/jupyter_client.py index e1df947969..5482c8537f 100644 --- a/autogen/coding/jupyter/jupyter_client.py +++ b/autogen/coding/jupyter/jupyter_client.py @@ -40,7 +40,7 @@ def __init__(self, connection_info: JupyterConnectionInfo): retries = Retry(total=5, backoff_factor=0.1) self._session.mount("http://", HTTPAdapter(max_retries=retries)) - def _get_headers(self) -> Dict[str, str]: + def _get_headers(self) -> dict[str, str]: if self._connection_info.token is None: return {} return {"Authorization": f"token {self._connection_info.token}"} @@ -54,13 +54,13 @@ def _get_ws_base_url(self) -> str: port = f":{self._connection_info.port}" if self._connection_info.port else "" return f"ws://{self._connection_info.host}{port}" - def list_kernel_specs(self) -> Dict[str, Dict[str, str]]: + def list_kernel_specs(self) -> dict[str, dict[str, str]]: response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers()) - return cast(Dict[str, Dict[str, str]], response.json()) + return cast(dict[str, dict[str, str]], response.json()) - def list_kernels(self) -> List[Dict[str, str]]: + def list_kernels(self) -> list[dict[str, str]]: response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers()) - return cast(List[Dict[str, str]], response.json()) + return cast(list[dict[str, str]], response.json()) def start_kernel(self, kernel_spec_name: str) -> str: """Start a new kernel. @@ -109,7 +109,7 @@ class DataItem: is_ok: bool output: str - data_items: List[DataItem] + data_items: list[DataItem] def __init__(self, websocket: WebSocket): self._session_id: str = uuid.uuid4().hex @@ -119,14 +119,14 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: self.stop() def stop(self) -> None: self._websocket.close() - def _send_message(self, *, content: Dict[str, Any], channel: str, message_type: str) -> str: + def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str: timestamp = datetime.datetime.now().isoformat() message_id = uuid.uuid4().hex message = { @@ -147,17 +147,17 @@ def _send_message(self, *, content: Dict[str, Any], channel: str, message_type: self._websocket.send_text(json.dumps(message)) return message_id - def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[Dict[str, Any]]: + def _receive_message(self, timeout_seconds: float | None) -> dict[str, Any] | None: self._websocket.settimeout(timeout_seconds) try: data = self._websocket.recv() if isinstance(data, bytes): data = data.decode("utf-8") - return cast(Dict[str, Any], json.loads(data)) + return cast(dict[str, Any], json.loads(data)) except websocket.WebSocketTimeoutException: return None - def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool: + def wait_for_ready(self, timeout_seconds: float | None = None) -> bool: message_id = self._send_message(content={}, channel="shell", message_type="kernel_info_request") while True: message = self._receive_message(timeout_seconds) @@ -170,7 +170,7 @@ def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool: ): return True - def execute(self, code: str, timeout_seconds: 
Optional[float] = None) -> ExecutionResult: + def execute(self, code: str, timeout_seconds: float | None = None) -> ExecutionResult: message_id = self._send_message( content={ "code": code, diff --git a/autogen/coding/jupyter/jupyter_code_executor.py b/autogen/coding/jupyter/jupyter_code_executor.py index 862885d797..afee16963d 100644 --- a/autogen/coding/jupyter/jupyter_code_executor.py +++ b/autogen/coding/jupyter/jupyter_code_executor.py @@ -82,7 +82,7 @@ def code_extractor(self) -> CodeExtractor: """(Experimental) Export a code extractor that can be used by an agent.""" return MarkdownCodeExtractor() - def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> IPythonCodeResult: + def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> IPythonCodeResult: """(Experimental) Execute a list of code blocks and return the result. This method executes a list of code blocks as cells in the Jupyter kernel. @@ -156,6 +156,6 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: self.stop() diff --git a/autogen/coding/jupyter/local_jupyter_server.py b/autogen/coding/jupyter/local_jupyter_server.py index 6aa1378313..6ea55c1f0a 100644 --- a/autogen/coding/jupyter/local_jupyter_server.py +++ b/autogen/coding/jupyter/local_jupyter_server.py @@ -32,8 +32,8 @@ class GenerateToken: def __init__( self, ip: str = "127.0.0.1", - port: Optional[int] = None, - token: Union[str, GenerateToken] = GenerateToken(), + port: int | None = None, + token: str | GenerateToken = GenerateToken(), log_file: str = "jupyter_gateway.log", log_level: str = "INFO", log_max_bytes: int = 1048576, @@ -59,8 +59,7 @@ def __init__( subprocess.run( [sys.executable, "-m", "jupyter", "kernelgateway", "--version"], check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, text=True, ) except subprocess.CalledProcessError: @@ -163,6 +162,6 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: self.stop() diff --git a/autogen/coding/local_commandline_code_executor.py b/autogen/coding/local_commandline_code_executor.py index 889182bfb1..bcbab20ef2 100644 --- a/autogen/coding/local_commandline_code_executor.py +++ b/autogen/coding/local_commandline_code_executor.py @@ -36,7 +36,7 @@ class LocalCommandLineCodeExecutor(CodeExecutor): - SUPPORTED_LANGUAGES: ClassVar[List[str]] = [ + SUPPORTED_LANGUAGES: ClassVar[list[str]] = [ "bash", "shell", "sh", @@ -48,7 +48,7 @@ class LocalCommandLineCodeExecutor(CodeExecutor): "html", "css", ] - DEFAULT_EXECUTION_POLICY: ClassVar[Dict[str, bool]] = { + DEFAULT_EXECUTION_POLICY: ClassVar[dict[str, bool]] = { "bash": True, "shell": True, "sh": True, @@ -74,9 +74,9 @@ def __init__( timeout: int = 60, virtual_env_context: Optional[SimpleNamespace] = None, work_dir: Union[Path, str] = Path("."), - functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]] = [], + functions: list[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]] = [], functions_module: str = "functions", - execution_policies: Optional[Dict[str, 
bool]] = None, + execution_policies: Optional[dict[str, bool]] = None, ): """(Experimental) A code executor class that executes or saves LLM generated code a local command line environment. @@ -168,7 +168,7 @@ def functions_module(self) -> str: @property def functions( self, - ) -> List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]: + ) -> list[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]: """(Experimental) The functions that are available to the code executor.""" return self._functions @@ -244,7 +244,7 @@ def _setup_functions(self) -> None: raise ValueError(f"Functions failed to load: {exec_result.output}") self._setup_functions_complete = True - def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult: + def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult: """(Experimental) Execute the code blocks and return the result. Args: @@ -256,7 +256,7 @@ def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeRe self._setup_functions() return self._execute_code_dont_check_setup(code_blocks) - def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult: + def _execute_code_dont_check_setup(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult: logs_all = "" file_names = [] for code_block in code_blocks: diff --git a/autogen/coding/markdown_code_extractor.py b/autogen/coding/markdown_code_extractor.py index 01dda0df52..8342ea2f92 100644 --- a/autogen/coding/markdown_code_extractor.py +++ b/autogen/coding/markdown_code_extractor.py @@ -18,8 +18,8 @@ class MarkdownCodeExtractor(CodeExtractor): """(Experimental) A class that extracts code blocks from a message using Markdown syntax.""" def extract_code_blocks( - self, message: Union[str, List[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None] - ) -> List[CodeBlock]: + self, message: Union[str, list[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None] + ) -> list[CodeBlock]: """(Experimental) Extract code blocks from a message. If no code blocks are found, return an empty list. 
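A usage sketch for the extractor interface above, assuming the public re-export from `autogen.coding` and the `code`/`language` fields of the `CodeBlock` model in coding/base.py:

    from autogen.coding import MarkdownCodeExtractor

    extractor = MarkdownCodeExtractor()
    blocks = extractor.extract_code_blocks("Run this:\n```python\nprint('hi')\n```")
    for block in blocks:
        print(block.language, "->", block.code)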
diff --git a/autogen/formatting_utils.py b/autogen/formatting_utils.py index fefc7f05ad..9c6ccb1676 100644 --- a/autogen/formatting_utils.py +++ b/autogen/formatting_utils.py @@ -6,7 +6,8 @@ # SPDX-License-Identifier: MIT from __future__ import annotations -from typing import Iterable, Literal +from collections.abc import Iterable +from typing import Literal try: from termcolor import colored diff --git a/autogen/function_utils.py b/autogen/function_utils.py index f4a6531fe5..8553e3e8e2 100644 --- a/autogen/function_utils.py +++ b/autogen/function_utils.py @@ -8,10 +8,10 @@ import inspect import json from logging import getLogger -from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Annotated, Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, TypeVar, Union from pydantic import BaseModel, Field -from typing_extensions import Annotated, Literal, get_args, get_origin +from typing_extensions import Literal, get_args, get_origin from ._pydantic import JsonSchemaValue, evaluate_forwardref, model_dump, model_dump_json, type2schema @@ -20,7 +20,7 @@ T = TypeVar("T") -def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: +def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: """Get the type annotation of a parameter. Args: @@ -79,7 +79,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return get_typed_annotation(annotation, globalns) -def get_param_annotations(typed_signature: inspect.Signature) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]: +def get_param_annotations(typed_signature: inspect.Signature) -> dict[str, Union[Annotated[type[Any], str], type[Any]]]: """Get the type annotations of the parameters of a function Args: @@ -97,8 +97,8 @@ class Parameters(BaseModel): """Parameters of a function as defined by the OpenAI API""" type: Literal["object"] = "object" - properties: Dict[str, JsonSchemaValue] - required: List[str] + properties: dict[str, JsonSchemaValue] + required: list[str] class Function(BaseModel): @@ -116,7 +116,7 @@ class ToolFunction(BaseModel): function: Annotated[Function, Field(description="Function under tool")] -def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> JsonSchemaValue: +def get_parameter_json_schema(k: str, v: Any, default_values: dict[str, Any]) -> JsonSchemaValue: """Get a JSON schema for a parameter as defined by the OpenAI API Args: @@ -128,7 +128,7 @@ def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> A Pydanitc model for the parameter """ - def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str: + def type2description(k: str, v: Union[Annotated[type[Any], str], type[Any]]) -> str: # handles Annotated if hasattr(v, "__metadata__"): retval = v.__metadata__[0] @@ -149,7 +149,7 @@ def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> return schema -def get_required_params(typed_signature: inspect.Signature) -> List[str]: +def get_required_params(typed_signature: inspect.Signature) -> list[str]: """Get the required parameters of a function Args: @@ -161,7 +161,7 @@ def get_required_params(typed_signature: inspect.Signature) -> List[str]: return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty] -def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]: +def get_default_values(typed_signature: inspect.Signature) -> dict[str, Any]: 
"""Get default values of parameters of a function Args: @@ -174,9 +174,9 @@ def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]: def get_parameters( - required: List[str], - param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]], - default_values: Dict[str, Any], + required: list[str], + param_annotations: dict[str, Union[Annotated[type[Any], str], type[Any]]], + default_values: dict[str, Any], ) -> Parameters: """Get the parameters of a function as defined by the OpenAI API @@ -197,7 +197,7 @@ def get_parameters( ) -def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]: +def get_missing_annotations(typed_signature: inspect.Signature, required: list[str]) -> tuple[set[str], set[str]]: """Get the missing annotations of a function Ignores the parameters with default values as they are not required to be annotated, but logs a warning. @@ -214,7 +214,7 @@ def get_missing_annotations(typed_signature: inspect.Signature, required: List[s return missing, unannotated_with_default -def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]: +def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> dict[str, Any]: """Get a JSON schema for a function as defined by the OpenAI API Args: @@ -289,7 +289,7 @@ def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Paramet return model_dump(function) -def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[Dict[str, Any], Type[BaseModel]], BaseModel]]: +def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[dict[str, Any], type[BaseModel]], BaseModel]]: """Get a function to load a parameter if it is a Pydantic model Args: @@ -302,7 +302,7 @@ def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[Dict[str, An if get_origin(t) is Annotated: return get_load_param_if_needed_function(get_args(t)[0]) - def load_base_model(v: Dict[str, Any], t: Type[BaseModel]) -> BaseModel: + def load_base_model(v: dict[str, Any], t: type[BaseModel]) -> BaseModel: return t(**v) return load_base_model if isinstance(t, type) and issubclass(t, BaseModel) else None diff --git a/autogen/graph_utils.py b/autogen/graph_utils.py index 82e8bf4ae9..a495fb4131 100644 --- a/autogen/graph_utils.py +++ b/autogen/graph_utils.py @@ -10,7 +10,7 @@ from autogen.agentchat import Agent -def has_self_loops(allowed_speaker_transitions: Dict) -> bool: +def has_self_loops(allowed_speaker_transitions: dict) -> bool: """ Returns True if there are self loops in the allowed_speaker_transitions_Dict. """ @@ -18,8 +18,8 @@ def has_self_loops(allowed_speaker_transitions: Dict) -> bool: def check_graph_validity( - allowed_speaker_transitions_dict: Dict, - agents: List[Agent], + allowed_speaker_transitions_dict: dict, + agents: list[Agent], ): """ allowed_speaker_transitions_dict: A dictionary of keys and list as values. The keys are the names of the agents, and the values are the names of the agents that the key agent can transition to. @@ -100,7 +100,7 @@ def check_graph_validity( ) -def invert_disallowed_to_allowed(disallowed_speaker_transitions_dict: dict, agents: List[Agent]) -> dict: +def invert_disallowed_to_allowed(disallowed_speaker_transitions_dict: dict, agents: list[Agent]) -> dict: """ Start with a fully connected allowed_speaker_transitions_dict of all agents. 
Remove edges from the fully connected allowed_speaker_transitions_dict according to the disallowed_speaker_transitions_dict to form the allowed_speaker_transitions_dict. """ @@ -117,7 +117,7 @@ def invert_disallowed_to_allowed(disallowed_speaker_transitions_dict: dict, agen def visualize_speaker_transitions_dict( - speaker_transitions_dict: dict, agents: List[Agent], export_path: Optional[str] = None + speaker_transitions_dict: dict, agents: list[Agent], export_path: Optional[str] = None ): """ Visualize the speaker_transitions_dict using networkx. diff --git a/autogen/interop/interoperability.py b/autogen/interop/interoperability.py index b86285d6a6..067571de5c 100644 --- a/autogen/interop/interoperability.py +++ b/autogen/interop/interoperability.py @@ -40,7 +40,7 @@ def convert_tool(cls, *, tool: Any, type: str, **kwargs: Any) -> Tool: return interop.convert_tool(tool, **kwargs) @classmethod - def get_interoperability_class(cls, type: str) -> Type[Interoperable]: + def get_interoperability_class(cls, type: str) -> type[Interoperable]: """ Retrieves the interoperability class corresponding to the specified type. @@ -63,7 +63,7 @@ def get_interoperability_class(cls, type: str) -> Type[Interoperable]: return cls.registry.get_class(type) @classmethod - def get_supported_types(cls) -> List[str]: + def get_supported_types(cls) -> list[str]: """ Returns a sorted list of all supported interoperability types. diff --git a/autogen/interop/pydantic_ai/pydantic_ai_tool.py b/autogen/interop/pydantic_ai/pydantic_ai_tool.py index 629f65e7ad..7ff50181ba 100644 --- a/autogen/interop/pydantic_ai/pydantic_ai_tool.py +++ b/autogen/interop/pydantic_ai/pydantic_ai_tool.py @@ -25,7 +25,7 @@ class PydanticAITool(Tool): """ def __init__( - self, name: str, description: str, func: Callable[..., Any], parameters_json_schema: Dict[str, Any] + self, name: str, description: str, func: Callable[..., Any], parameters_json_schema: dict[str, Any] ) -> None: """ Initializes a PydanticAITool object with the provided name, description, diff --git a/autogen/interop/registry.py b/autogen/interop/registry.py index 443dcb5beb..cb10b701ac 100644 --- a/autogen/interop/registry.py +++ b/autogen/interop/registry.py @@ -8,12 +8,12 @@ __all__ = ["register_interoperable_class", "InteroperableRegistry"] -InteroperableClass = TypeVar("InteroperableClass", bound=Type[Interoperable]) +InteroperableClass = TypeVar("InteroperableClass", bound=type[Interoperable]) class InteroperableRegistry: def __init__(self) -> None: - self._registry: Dict[str, Type[Interoperable]] = {} + self._registry: dict[str, type[Interoperable]] = {} def register(self, short_name: str, cls: InteroperableClass) -> InteroperableClass: if short_name in self._registry: @@ -23,15 +23,15 @@ def register(self, short_name: str, cls: InteroperableClass) -> InteroperableCla return cls - def get_short_names(self) -> List[str]: + def get_short_names(self) -> list[str]: return sorted(self._registry.keys()) - def get_supported_types(self) -> List[str]: + def get_supported_types(self) -> list[str]: short_names = self.get_short_names() supported_types = [name for name in short_names if self._registry[name].get_unsupported_reason() is None] return supported_types - def get_class(self, short_name: str) -> Type[Interoperable]: + def get_class(self, short_name: str) -> type[Interoperable]: return self._registry[short_name] @classmethod diff --git a/autogen/io/base.py b/autogen/io/base.py index 5c79832751..39b857f416 100644 --- a/autogen/io/base.py +++ b/autogen/io/base.py @@ -5,9 
+5,10 @@ # Portions derived from https://github.com/microsoft/autogen are under the MIT License. # SPDX-License-Identifier: MIT import logging +from collections.abc import Iterator from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Iterator, Optional, Protocol, runtime_checkable +from typing import Any, Optional, Protocol, runtime_checkable __all__ = ("OutputStream", "InputStream", "IOStream") diff --git a/autogen/io/websockets.py b/autogen/io/websockets.py index d30bcc69c5..2135727c8c 100644 --- a/autogen/io/websockets.py +++ b/autogen/io/websockets.py @@ -7,10 +7,11 @@ import logging import ssl import threading +from collections.abc import Iterable, Iterator from contextlib import contextmanager from functools import partial from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, Optional, Protocol, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, Union from .base import IOStream @@ -134,7 +135,7 @@ def run_server_in_thread( Yields: str: The URI of the websocket server. """ - server_dict: Dict[str, WebSocketServer] = {} + server_dict: dict[str, WebSocketServer] = {} def _run_server() -> None: if _import_error is not None: diff --git a/autogen/logger/base_logger.py b/autogen/logger/base_logger.py index 93a7c617eb..b01c112da7 100644 --- a/autogen/logger/base_logger.py +++ b/autogen/logger/base_logger.py @@ -18,8 +18,8 @@ from autogen import Agent, ConversableAgent, OpenAIWrapper F = TypeVar("F", bound=Callable[..., Any]) -ConfigItem = Dict[str, Union[str, List[str]]] -LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]] +ConfigItem = dict[str, Union[str, list[str]]] +LLMConfig = dict[str, Union[None, float, int, ConfigItem, list[ConfigItem]]] class BaseLogger(ABC): @@ -39,9 +39,9 @@ def log_chat_completion( invocation_id: uuid.UUID, client_id: int, wrapper_id: int, - source: Union[str, Agent], - request: Dict[str, Union[float, str, List[Dict[str, str]]]], - response: Union[str, ChatCompletion], + source: str | Agent, + request: dict[str, float | str | list[dict[str, str]]], + response: str | ChatCompletion, is_cached: int, cost: float, start_time: str, @@ -67,7 +67,7 @@ def log_chat_completion( ... @abstractmethod - def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None: + def log_new_agent(self, agent: ConversableAgent, init_args: dict[str, Any]) -> None: """ Log the birth of a new agent. @@ -78,7 +78,7 @@ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> N ... @abstractmethod - def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None: + def log_event(self, source: str | Agent, name: str, **kwargs: dict[str, Any]) -> None: """ Log an event for an agent. @@ -90,7 +90,7 @@ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, An ... @abstractmethod - def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None: + def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: dict[str, LLMConfig | list[LLMConfig]]) -> None: """ Log the birth of a new OpenAIWrapper. @@ -101,9 +101,7 @@ def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLM ... 
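Editor's note: the base_logger hunks above rewrite `Union[str, Agent]` as `str | Agent` and `Dict`/`List` as `dict`/`list`. A minimal, self-contained sketch of that pattern, assuming the module enables postponed annotation evaluation (PEP 563) so the `X | Y` syntax also works on Python 3.8/3.9; the names below are illustrative, not from this patch:

from __future__ import annotations

import sqlite3


def get_connection(path: str | None = None) -> sqlite3.Connection | None:
    # With postponed evaluation, "sqlite3.Connection | None" is never built
    # at runtime, so this module still imports on Python < 3.10.
    if path is None:
        return None
    return sqlite3.connect(path)


print(get_connection())  # -> None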
@abstractmethod - def log_new_client( - self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any] - ) -> None: + def log_new_client(self, client: AzureOpenAI | OpenAI, wrapper: OpenAIWrapper, init_args: dict[str, Any]) -> None: """ Log the birth of a new OpenAIWrapper. @@ -114,7 +112,7 @@ def log_new_client( ... @abstractmethod - def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None: + def log_function_use(self, source: str | Agent, function: F, args: dict[str, Any], returns: Any) -> None: """ Log the use of a registered function (could be a tool) @@ -133,7 +131,7 @@ def stop(self) -> None: ... @abstractmethod - def get_connection(self) -> Union[None, sqlite3.Connection]: + def get_connection(self) -> None | sqlite3.Connection: """ Return a connection to the logging database. """ diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index 625a96892c..249dd11015 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -51,7 +51,7 @@ def default(o: Any) -> str: class FileLogger(BaseLogger): - def __init__(self, config: Dict[str, Any]): + def __init__(self, config: dict[str, Any]): self.config = config self.session_id = str(uuid.uuid4()) @@ -85,9 +85,9 @@ def log_chat_completion( invocation_id: uuid.UUID, client_id: int, wrapper_id: int, - source: Union[str, Agent], - request: Dict[str, Union[float, str, List[Dict[str, str]]]], - response: Union[str, ChatCompletion], + source: str | Agent, + request: dict[str, float | str | list[dict[str, str]]], + response: str | ChatCompletion, is_cached: int, cost: float, start_time: str, @@ -122,7 +122,7 @@ def log_chat_completion( except Exception as e: self.logger.error(f"[file_logger] Failed to log chat completion: {e}") - def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any] = {}) -> None: + def log_new_agent(self, agent: ConversableAgent, init_args: dict[str, Any] = {}) -> None: """ Log a new agent instance. """ @@ -147,7 +147,7 @@ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any] = {}) except Exception as e: self.logger.error(f"[file_logger] Failed to log new agent: {e}") - def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None: + def log_event(self, source: str | Agent, name: str, **kwargs: dict[str, Any]) -> None: """ Log an event from an agent or a string source. """ @@ -191,9 +191,7 @@ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, An except Exception as e: self.logger.error(f"[file_logger] Failed to log event {e}") - def log_new_wrapper( - self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]] = {} - ) -> None: + def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: dict[str, LLMConfig | list[LLMConfig]] = {}) -> None: """ Log a new wrapper instance. """ @@ -229,7 +227,7 @@ def log_new_client( | BedrockClient ), wrapper: OpenAIWrapper, - init_args: Dict[str, Any], + init_args: dict[str, Any], ) -> None: """ Log a new client instance. @@ -252,7 +250,7 @@ def log_new_client( except Exception as e: self.logger.error(f"[file_logger] Failed to log event {e}") - def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None: + def log_function_use(self, source: str | Agent, function: F, args: dict[str, Any], returns: Any) -> None: """ Log a registered function(can be a tool) use from an agent or a string source. 
""" diff --git a/autogen/logger/logger_factory.py b/autogen/logger/logger_factory.py index f25cd1d6af..c3bab860a9 100644 --- a/autogen/logger/logger_factory.py +++ b/autogen/logger/logger_factory.py @@ -16,7 +16,7 @@ class LoggerFactory: @staticmethod def get_logger( - logger_type: Literal["sqlite", "file"] = "sqlite", config: Optional[Dict[str, Any]] = None + logger_type: Literal["sqlite", "file"] = "sqlite", config: Optional[dict[str, Any]] = None ) -> BaseLogger: if config is None: config = {} diff --git a/autogen/logger/logger_utils.py b/autogen/logger/logger_utils.py index 5c226d3d3a..f80f1eb426 100644 --- a/autogen/logger/logger_utils.py +++ b/autogen/logger/logger_utils.py @@ -17,9 +17,9 @@ def get_current_ts() -> str: def to_dict( - obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any], - exclude: Tuple[str, ...] = (), - no_recursive: Tuple[Any, ...] = (), + obj: Union[int, float, str, bool, dict[Any, Any], list[Any], tuple[Any, ...], Any], + exclude: tuple[str, ...] = (), + no_recursive: tuple[Any, ...] = (), ) -> Any: if isinstance(obj, (int, float, str, bool)): return obj diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index 31510ccfec..24bd7447e3 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -55,7 +55,7 @@ def default(o: Any) -> str: class SqliteLogger(BaseLogger): schema_version = 1 - def __init__(self, config: Dict[str, Any]): + def __init__(self, config: dict[str, Any]): self.config = config try: @@ -169,7 +169,7 @@ class TEXT, -- type or class name of cli finally: return self.session_id - def _get_current_db_version(self) -> Union[None, int]: + def _get_current_db_version(self) -> None | int: self.cur.execute("SELECT version_number FROM version ORDER BY id DESC LIMIT 1") result = self.cur.fetchone() return result[0] if result is not None else None @@ -188,7 +188,7 @@ def _apply_migration(self, migrations_dir: str = "./migrations") -> None: migrations_to_apply = [m for m in migrations if int(m.split("_")[0]) > current_version] for script in migrations_to_apply: - with open(script, "r") as f: + with open(script) as f: migration_sql = f.read() self._run_query_script(script=migration_sql) @@ -197,7 +197,7 @@ def _apply_migration(self, migrations_dir: str = "./migrations") -> None: args = (latest_version,) self._run_query(query=query, args=args) - def _run_query(self, query: str, args: Tuple[Any, ...] = ()) -> None: + def _run_query(self, query: str, args: tuple[Any, ...] = ()) -> None: """ Executes a given SQL query. 
@@ -231,9 +231,9 @@ def log_chat_completion( invocation_id: uuid.UUID, client_id: int, wrapper_id: int, - source: Union[str, Agent], - request: Dict[str, Union[float, str, List[Dict[str, str]]]], - response: Union[str, ChatCompletion], + source: str | Agent, + request: dict[str, float | str | list[dict[str, str]]], + response: str | ChatCompletion, is_cached: int, cost: float, start_time: str, @@ -275,7 +275,7 @@ def log_chat_completion( self._run_query(query=query, args=args) - def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> None: + def log_new_agent(self, agent: ConversableAgent, init_args: dict[str, Any]) -> None: from autogen import Agent if self.con is None: @@ -317,7 +317,7 @@ class = excluded.class, ) self._run_query(query=query, args=args) - def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None: + def log_event(self, source: str | Agent, name: str, **kwargs: dict[str, Any]) -> None: from autogen import Agent if self.con is None: @@ -352,7 +352,7 @@ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, An ) self._run_query(query=query, args=args_str_based) - def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None: + def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: dict[str, LLMConfig | list[LLMConfig]]) -> None: if self.con is None: return @@ -382,7 +382,7 @@ def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLM ) self._run_query(query=query, args=args) - def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None: + def log_function_use(self, source: str | Agent, function: F, args: dict[str, Any], returns: Any) -> None: if self.con is None: return @@ -390,7 +390,7 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st query = """ INSERT INTO function_calls (source_id, source_name, function_name, args, returns, timestamp) VALUES (?, ?, ?, ?, ?, ?) """ - query_args: Tuple[Any, ...] = ( + query_args: tuple[Any, ...] 
= ( id(source), source.name if hasattr(source, "name") else source, function.__name__, @@ -402,21 +402,21 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st def log_new_client( self, - client: Union[ - AzureOpenAI, - OpenAI, - CerebrasClient, - GeminiClient, - AnthropicClient, - MistralAIClient, - TogetherClient, - GroqClient, - CohereClient, - OllamaClient, - BedrockClient, - ], + client: ( + AzureOpenAI + | OpenAI + | CerebrasClient + | GeminiClient + | AnthropicClient + | MistralAIClient + | TogetherClient + | GroqClient + | CohereClient + | OllamaClient + | BedrockClient + ), wrapper: OpenAIWrapper, - init_args: Dict[str, Any], + init_args: dict[str, Any], ) -> None: if self.con is None: return @@ -453,7 +453,7 @@ def stop(self) -> None: if self.con: self.con.close() - def get_connection(self) -> Union[None, sqlite3.Connection]: + def get_connection(self) -> None | sqlite3.Connection: if self.con: return self.con return None diff --git a/autogen/math_utils.py b/autogen/math_utils.py index 0ef12d8f2d..069747f7c7 100644 --- a/autogen/math_utils.py +++ b/autogen/math_utils.py @@ -138,7 +138,7 @@ def _fix_a_slash_b(string: str) -> str: try: a = int(a_str) b = int(b_str) - if not string == "{}/{}".format(a, b): + if not string == f"{a}/{b}": raise AssertionError new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" return new_string diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py index 44dc7bd60d..a3a67baf85 100644 --- a/autogen/oai/anthropic.py +++ b/autogen/oai/anthropic.py @@ -56,7 +56,7 @@ import os import time import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Annotated, Any, Dict, List, Optional, Tuple, Union from anthropic import Anthropic, AnthropicBedrock from anthropic import __version__ as anthropic_version @@ -65,7 +65,6 @@ from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel -from typing_extensions import Annotated from autogen.oai.client_utils import validate_parameter @@ -134,7 +133,7 @@ def __init__(self, **kwargs: Any): self._last_tooluse_status = {} - def load_config(self, params: Dict[str, Any]): + def load_config(self, params: dict[str, Any]): """Load the configuration for the Anthropic API client.""" anthropic_params = {} @@ -183,7 +182,7 @@ def aws_session_token(self): def aws_region(self): return self._aws_region - def create(self, params: Dict[str, Any]) -> ChatCompletion: + def create(self, params: dict[str, Any]) -> ChatCompletion: if "tools" in params: converted_functions = self.convert_tools_to_functions(params["tools"]) params["functions"] = params.get("functions", []) + converted_functions @@ -270,7 +269,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: return response_oai - def message_retrieval(self, response) -> List: + def message_retrieval(self, response) -> list: """ Retrieve and return a list of strings or a list of Choice.Message from the response. 
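The math_utils hunk above replaces `"{}/{}".format(a, b)` with an f-string; the two render identically:

a, b = 3, 4
old_style = "{}/{}".format(a, b)
f_style = f"{a}/{b}"
assert old_style == f_style == "3/4"
print(f_style)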
@@ -286,7 +285,7 @@ def openai_func_to_anthropic(openai_func: dict) -> dict: return res @staticmethod - def get_usage(response: ChatCompletion) -> Dict: + def get_usage(response: ChatCompletion) -> dict: """Get the usage of tokens and their cost information.""" return { "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0, @@ -297,7 +296,7 @@ def get_usage(response: ChatCompletion) -> Dict: } @staticmethod - def convert_tools_to_functions(tools: List) -> List: + def convert_tools_to_functions(tools: list) -> list: functions = [] for tool in tools: if tool.get("type") == "function" and "function" in tool: @@ -306,7 +305,7 @@ def convert_tools_to_functions(tools: List) -> List: return functions -def oai_messages_to_anthropic_messages(params: Dict[str, Any]) -> list[dict[str, Any]]: +def oai_messages_to_anthropic_messages(params: dict[str, Any]) -> list[dict[str, Any]]: """Convert messages from OAI format to Anthropic format. We correct for any specific role orders and types, etc. """ diff --git a/autogen/oai/bedrock.py b/autogen/oai/bedrock.py index 5d8f34fe51..b624cc9125 100644 --- a/autogen/oai/bedrock.py +++ b/autogen/oai/bedrock.py @@ -120,7 +120,7 @@ def message_retrieval(self, response): """Retrieve the messages from the response.""" return [choice.message for choice in response.choices] - def parse_custom_params(self, params: Dict[str, Any]): + def parse_custom_params(self, params: dict[str, Any]): """ Parses custom parameters for logic in this client class """ @@ -129,7 +129,7 @@ def parse_custom_params(self, params: Dict[str, Any]): # This is required because not all models support a system prompt (e.g. Mistral Instruct). self._supports_system_prompts = params.get("supports_system_prompts", True) - def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]: + def parse_params(self, params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: """ Loads the valid parameters required to invoke Bedrock Converse Returns a tuple of (base_params, additional_params) @@ -273,7 +273,7 @@ def cost(self, response: ChatCompletion) -> float: return calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens, response.model) @staticmethod - def get_usage(response) -> Dict: + def get_usage(response) -> dict: """Get the usage of tokens and their cost information.""" return { "prompt_tokens": response.usage.prompt_tokens, @@ -284,7 +284,7 @@ def get_usage(response) -> Dict: } -def extract_system_messages(messages: List[dict]) -> List: +def extract_system_messages(messages: list[dict]) -> list: """Extract the system messages from the list of messages. Args: @@ -309,8 +309,8 @@ def extract_system_messages(messages: List[dict]) -> List: def oai_messages_to_bedrock_messages( - messages: List[Dict[str, Any]], has_tools: bool, supports_system_prompts: bool -) -> List[Dict]: + messages: list[dict[str, Any]], has_tools: bool, supports_system_prompts: bool +) -> list[dict]: """ Convert messages from OAI format to Bedrock format. We correct for any specific role orders and types, etc. 
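The bedrock hunk above types `extract_system_messages(messages: list[dict]) -> list`. As a hedged guess at its shape (the helper is renamed here to mark it as a hypothetical variant, and the real one may normalize content differently), the core partition-by-role step looks like:

def extract_system_contents(messages: list[dict[str, str]]) -> list[str]:
    # Keep only the content of system-role messages.
    return [m["content"] for m in messages if m.get("role") == "system"]


print(extract_system_contents([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "hi"},
]))  # -> ['You are terse.']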
@@ -453,9 +453,9 @@ def oai_messages_to_bedrock_messages( def parse_content_parts( - message: Dict[str, Any], -) -> List[dict]: - content: str | List[Dict[str, Any]] = message.get("content") + message: dict[str, Any], +) -> list[dict]: + content: str | list[dict[str, Any]] = message.get("content") if isinstance(content, str): return [ { @@ -487,7 +487,7 @@ def parse_content_parts( return content_parts -def parse_image(image_url: str) -> Tuple[bytes, str]: +def parse_image(image_url: str) -> tuple[bytes, str]: """Try to get the raw data from an image url. Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html @@ -516,7 +516,7 @@ def parse_image(image_url: str) -> Tuple[bytes, str]: raise RuntimeError("Unable to access the image url") -def format_tools(tools: List[Dict[str, Any]]) -> Dict[Literal["tools"], List[Dict[str, Any]]]: +def format_tools(tools: list[dict[str, Any]]) -> dict[Literal["tools"], list[dict[str, Any]]]: converted_schema = {"tools": []} for tool in tools: diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py index 201fb2ee55..cce38f1ca2 100644 --- a/autogen/oai/cerebras.py +++ b/autogen/oai/cerebras.py @@ -67,7 +67,7 @@ def __init__(self, api_key=None, **kwargs): if "response_format" in kwargs and kwargs["response_format"] is not None: warnings.warn("response_format is not supported for Crebras, it will be ignored.", UserWarning) - def message_retrieval(self, response: ChatCompletion) -> List: + def message_retrieval(self, response: ChatCompletion) -> list: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -81,7 +81,7 @@ def cost(self, response: ChatCompletion) -> float: return response.cost @staticmethod - def get_usage(response: ChatCompletion) -> Dict: + def get_usage(response: ChatCompletion) -> dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" # ... # pragma: no cover return { @@ -92,7 +92,7 @@ def get_usage(response: ChatCompletion) -> Dict: "model": response.model, } - def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: """Loads the parameters for Cerebras API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults""" cerebras_params = {} @@ -115,7 +115,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return cerebras_params - def create(self, params: Dict) -> ChatCompletion: + def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) # Convert AutoGen messages to Cerebras messages @@ -243,7 +243,7 @@ def create(self, params: Dict) -> ChatCompletion: return response_oai -def oai_messages_to_cerebras_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: +def oai_messages_to_cerebras_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert messages from OAI format to Cerebras's format. We correct for any specific role orders and types. 
""" diff --git a/autogen/oai/client.py b/autogen/oai/client.py index de83c7c1b0..5481741b34 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -26,7 +26,7 @@ try: import openai except ImportError: - ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") + ERROR: ImportError | None = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") OpenAI = object AzureOpenAI = object else: @@ -59,7 +59,7 @@ from autogen.oai.cerebras import CerebrasClient - cerebras_import_exception: Optional[ImportError] = None + cerebras_import_exception: ImportError | None = None except ImportError as e: cerebras_AuthenticationError = cerebras_InternalServerError = cerebras_RateLimitError = Exception cerebras_import_exception = e @@ -72,7 +72,7 @@ from autogen.oai.gemini import GeminiClient - gemini_import_exception: Optional[ImportError] = None + gemini_import_exception: ImportError | None = None except ImportError as e: gemini_InternalServerError = gemini_ResourceExhausted = Exception gemini_import_exception = e @@ -85,7 +85,7 @@ from autogen.oai.anthropic import AnthropicClient - anthropic_import_exception: Optional[ImportError] = None + anthropic_import_exception: ImportError | None = None except ImportError as e: anthorpic_InternalServerError = anthorpic_RateLimitError = Exception anthropic_import_exception = e @@ -98,7 +98,7 @@ from autogen.oai.mistral import MistralAIClient - mistral_import_exception: Optional[ImportError] = None + mistral_import_exception: ImportError | None = None except ImportError as e: mistral_SDKError = mistral_HTTPValidationError = Exception mistral_import_exception = e @@ -108,7 +108,7 @@ from autogen.oai.together import TogetherClient - together_import_exception: Optional[ImportError] = None + together_import_exception: ImportError | None = None except ImportError as e: together_TogetherException = Exception together_import_exception = e @@ -122,7 +122,7 @@ from autogen.oai.groq import GroqClient - groq_import_exception: Optional[ImportError] = None + groq_import_exception: ImportError | None = None except ImportError as e: groq_InternalServerError = groq_RateLimitError = groq_APIConnectionError = Exception groq_import_exception = e @@ -136,7 +136,7 @@ from autogen.oai.cohere import CohereClient - cohere_import_exception: Optional[ImportError] = None + cohere_import_exception: ImportError | None = None except ImportError as e: cohere_InternalServerError = cohere_TooManyRequestsError = cohere_ServiceUnavailableError = Exception cohere_import_exception = e @@ -149,7 +149,7 @@ from autogen.oai.ollama import OllamaClient - ollama_import_exception: Optional[ImportError] = None + ollama_import_exception: ImportError | None = None except ImportError as e: ollama_RequestError = ollama_ResponseError = Exception ollama_import_exception = e @@ -162,7 +162,7 @@ from autogen.oai.bedrock import BedrockClient - bedrock_import_exception: Optional[ImportError] = None + bedrock_import_exception: ImportError | None = None except ImportError as e: bedrock_BotoCoreError = bedrock_ClientError = Exception bedrock_import_exception = e @@ -201,18 +201,18 @@ class ModelClient(Protocol): class ModelClientResponseProtocol(Protocol): class Choice(Protocol): class Message(Protocol): - content: Optional[str] + content: str | None message: Message - choices: List[Choice] + choices: list[Choice] model: str - def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ... 
# pragma: no cover + def create(self, params: dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover def message_retrieval( self, response: ModelClientResponseProtocol - ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]: + ) -> list[str] | list[ModelClient.ModelClientResponseProtocol.Choice.Message]: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -224,7 +224,7 @@ def message_retrieval( def cost(self, response: ModelClientResponseProtocol) -> float: ... # pragma: no cover @staticmethod - def get_usage(response: ModelClientResponseProtocol) -> Dict: + def get_usage(response: ModelClientResponseProtocol) -> dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" ... # pragma: no cover @@ -237,7 +237,7 @@ def __init__(self, config): class OpenAIClient: """Follows the Client protocol and wraps the OpenAI client.""" - def __init__(self, client: Union[OpenAI, AzureOpenAI], response_format: Optional[BaseModel] = None): + def __init__(self, client: OpenAI | AzureOpenAI, response_format: BaseModel | None = None): self._oai_client = client self.response_format = response_format if ( @@ -249,9 +249,7 @@ def __init__(self, client: Union[OpenAI, AzureOpenAI], response_format: Optional "The API key specified is not a valid OpenAI format; it won't work with the OpenAI-hosted model." ) - def message_retrieval( - self, response: Union[ChatCompletion, Completion] - ) -> Union[List[str], List[ChatCompletionMessage]]: + def message_retrieval(self, response: ChatCompletion | Completion) -> list[str] | list[ChatCompletionMessage]: """Retrieve the messages from the response.""" choices = response.choices if isinstance(response, Completion): @@ -279,7 +277,7 @@ def _format_content(content: str) -> str: for choice in choices ] - def create(self, params: Dict[str, Any]) -> ChatCompletion: + def create(self, params: dict[str, Any]) -> ChatCompletion: """Create a completion for a given config using openai's client. 
Args: @@ -314,8 +312,8 @@ def _create_or_parse(*args, **kwargs): iostream.print("\033[32m", end="") # Prepare for potential function call - full_function_call: Optional[Dict[str, Any]] = None - full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None + full_function_call: dict[str, Any] | None = None + full_tool_calls: list[dict[str, Any] | None] | None = None # Send the chat completion request to OpenAI's API and process the response in chunks for chunk in create_or_parse(**params): @@ -424,7 +422,7 @@ def _create_or_parse(*args, **kwargs): return response - def cost(self, response: Union[ChatCompletion, Completion]) -> float: + def cost(self, response: ChatCompletion | Completion) -> float: """Calculate the cost of the response.""" model = response.model if model not in OAI_PRICE1K: @@ -445,7 +443,7 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float: return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator] @staticmethod - def get_usage(response: Union[ChatCompletion, Completion]) -> Dict: + def get_usage(response: ChatCompletion | Completion) -> dict: return { "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0, "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0, @@ -479,13 +477,13 @@ class OpenAIWrapper: openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs) openai_kwargs = openai_kwargs | aopenai_kwargs - total_usage_summary: Optional[Dict[str, Any]] = None - actual_usage_summary: Optional[Dict[str, Any]] = None + total_usage_summary: dict[str, Any] | None = None + actual_usage_summary: dict[str, Any] | None = None def __init__( self, *, - config_list: Optional[List[Dict[str, Any]]] = None, + config_list: list[dict[str, Any]] | None = None, **base_config: Any, ): """ @@ -526,8 +524,8 @@ def __init__( # It's OK if "model" is not provided in base_config or config_list # Because one can provide "model" at `create` time. 
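Earlier in client.py, each optional backend follows the same try/except pattern, storing the `ImportError` in a sentinel typed `ImportError | None` instead of failing at import time. A minimal sketch with a hypothetical package name:

from __future__ import annotations

try:
    import some_optional_sdk  # hypothetical optional dependency

    sdk_import_exception: ImportError | None = None
except ImportError as e:
    sdk_import_exception = e


def require_sdk() -> None:
    # Defer the failure to the first real use of the backend.
    if sdk_import_exception is not None:
        raise sdk_import_exception


print(sdk_import_exception)  # an ImportError unless some_optional_sdk is installed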
- self._clients: List[ModelClient] = [] - self._config_list: List[Dict[str, Any]] = [] + self._clients: list[ModelClient] = [] + self._config_list: list[dict[str, Any]] = [] if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying @@ -541,19 +539,19 @@ def __init__( self._config_list = [extra_kwargs] self.wrapper_id = id(self) - def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _separate_openai_config(self, config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: """Separate the config into openai_config and extra_kwargs.""" openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs} extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs} return openai_config, extra_kwargs - def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _separate_create_config(self, config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: """Separate the config into create_config and extra_kwargs.""" create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs} extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} return create_config, extra_kwargs - def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: + def _configure_azure_openai(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None: openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model")) if openai_config["azure_deployment"] is not None: openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") @@ -567,7 +565,7 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" ) - def _configure_openai_config_for_bedrock(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: + def _configure_openai_config_for_bedrock(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None: """Update openai_config with AWS credentials from config.""" required_keys = ["aws_access_key", "aws_secret_key", "aws_region"] optional_keys = ["aws_session_token", "aws_profile_name"] @@ -578,7 +576,7 @@ def _configure_openai_config_for_bedrock(self, config: Dict[str, Any], openai_co if key in config: openai_config[key] = config[key] - def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: + def _register_default_client(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None: """Create a client with the given config to override openai_config, after removing extra kwargs. 
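`_separate_openai_config` above partitions one config dict by key membership. The same idiom in isolation, with an illustrative (not the real) kwarg set:

from typing import Any

OPENAI_KWARGS = {"api_key", "base_url", "timeout"}  # illustrative subset


def separate_config(config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    openai_config = {k: v for k, v in config.items() if k in OPENAI_KWARGS}
    extra_kwargs = {k: v for k, v in config.items() if k not in OPENAI_KWARGS}
    return openai_config, extra_kwargs


print(separate_config({"api_key": "sk-test", "model": "gpt-4o", "temperature": 0.2}))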
@@ -690,21 +688,21 @@ def register_model_client(self, model_client_cls: ModelClient, **kwargs): @classmethod def instantiate( cls, - template: Optional[Union[str, Callable[[Dict[str, Any]], str]]], - context: Optional[Dict[str, Any]] = None, - allow_format_str_template: Optional[bool] = False, - ) -> Optional[str]: + template: str | Callable[[dict[str, Any]], str] | None, + context: dict[str, Any] | None = None, + allow_format_str_template: bool | None = False, + ) -> str | None: if not context or template is None: return template # type: ignore [return-value] if isinstance(template, str): return template.format(**context) if allow_format_str_template else template return template(context) - def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: Dict[str, Any]) -> Dict[str, Any]: + def _construct_create_params(self, create_config: dict[str, Any], extra_kwargs: dict[str, Any]) -> dict[str, Any]: """Prime the create_config with additional_kwargs.""" # Validate the config - prompt: Optional[str] = create_config.get("prompt") - messages: Optional[List[Dict[str, Any]]] = create_config.get("messages") + prompt: str | None = create_config.get("prompt") + messages: list[dict[str, Any]] | None = create_config.get("messages") if (prompt is None) == (messages is None): raise ValueError("Either prompt or messages should be in create config but not both.") context = extra_kwargs.get("context") @@ -961,7 +959,7 @@ def yes_or_no_filter(context, response): @staticmethod def _cost_with_customized_price( - response: ModelClient.ModelClientResponseProtocol, price_1k: Tuple[float, float] + response: ModelClient.ModelClientResponseProtocol, price_1k: tuple[float, float] ) -> None: """If a customized cost is passed, overwrite the cost in the response.""" n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr] @@ -971,7 +969,7 @@ def _cost_with_customized_price( return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000 @staticmethod - def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int: + def _update_dict_from_chunk(chunk: BaseModel, d: dict[str, Any], field: str) -> int: """Update the dict from the chunk. Reads `chunk.field` and if present updates `d[field]` accordingly. @@ -1006,10 +1004,10 @@ def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> @staticmethod def _update_function_call_from_chunk( - function_call_chunk: Union[ChoiceDeltaToolCallFunction, ChoiceDeltaFunctionCall], - full_function_call: Optional[Dict[str, Any]], + function_call_chunk: ChoiceDeltaToolCallFunction | ChoiceDeltaFunctionCall, + full_function_call: dict[str, Any] | None, completion_tokens: int, - ) -> Tuple[Dict[str, Any], int]: + ) -> tuple[dict[str, Any], int]: """Update the function call from the chunk. Args: @@ -1038,9 +1036,9 @@ def _update_function_call_from_chunk( @staticmethod def _update_tool_calls_from_chunk( tool_calls_chunk: ChoiceDeltaToolCall, - full_tool_call: Optional[Dict[str, Any]], + full_tool_call: dict[str, Any] | None, completion_tokens: int, - ) -> Tuple[Dict[str, Any], int]: + ) -> tuple[dict[str, Any], int]: """Update the tool call from the chunk. 
Args: @@ -1113,11 +1111,11 @@ def update_usage(usage_summary, response_usage): if actual_usage is not None: self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage) - def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None: + def print_usage_summary(self, mode: str | list[str] = ["actual", "total"]) -> None: """Print the usage summary.""" iostream = IOStream.get_default() - def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None: + def print_usage(usage_summary: dict[str, Any] | None, usage_type: str = "total") -> None: word_from_type = "including" if usage_type == "total" else "excluding" if usage_summary is None: iostream.print("No actual cost incurred (all completions are using cache).", flush=True) @@ -1174,7 +1172,7 @@ def clear_usage_summary(self) -> None: @classmethod def extract_text_or_completion_object( cls, response: ModelClient.ModelClientResponseProtocol - ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]: + ) -> list[str] | list[ModelClient.ModelClientResponseProtocol.Choice.Message]: """Extract the text or ChatCompletion objects from a completion or chat response. Args: diff --git a/autogen/oai/client_utils.py b/autogen/oai/client_utils.py index fac01286cc..6f417c90ba 100644 --- a/autogen/oai/client_utils.py +++ b/autogen/oai/client_utils.py @@ -12,12 +12,12 @@ def validate_parameter( - params: Dict[str, Any], + params: dict[str, Any], param_name: str, - allowed_types: Tuple, + allowed_types: tuple, allow_None: bool, default_value: Any, - numerical_bound: Tuple, + numerical_bound: tuple, allowed_values: list, ) -> Any: """ @@ -106,7 +106,7 @@ def validate_parameter( return param_value -def should_hide_tools(messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], hide_tools_param: str) -> bool: +def should_hide_tools(messages: list[dict[str, Any]], tools: list[dict[str, Any]], hide_tools_param: str) -> bool: """ Determines if tools should be hidden. This function is used to hide tools when they have been run, minimising the chance of the LLM choosing them when they shouldn't. Parameters: diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index b7d411454d..e725d8e156 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -83,7 +83,7 @@ def __init__(self, **kwargs): if "response_format" in kwargs and kwargs["response_format"] is not None: warnings.warn("response_format is not supported for Cohere, it will be ignored.", UserWarning) - def message_retrieval(self, response) -> List: + def message_retrieval(self, response) -> list: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -96,7 +96,7 @@ def cost(self, response) -> float: return response.cost @staticmethod - def get_usage(response) -> Dict: + def get_usage(response) -> dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" # ... # pragma: no cover return { @@ -107,7 +107,7 @@ def get_usage(response) -> Dict: "model": response.model, } - def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: """Loads the parameters for Cohere API from the passed in parameters and returns a validated set. 
Checks types, ranges, and sets defaults""" cohere_params = {} @@ -151,7 +151,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return cohere_params - def create(self, params: Dict) -> ChatCompletion: + def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) client_name = params.get("client_name") or "autogen-cohere" # Parse parameters to the Cohere API's parameters @@ -263,7 +263,7 @@ def create(self, params: Dict) -> ChatCompletion: return response_oai -def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]: +def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> list[dict[str, Any]]: temp_tool_results = [] for tool_call in all_tool_calls: @@ -281,7 +281,7 @@ def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_t def oai_messages_to_cohere_messages( - messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any] + messages: list[dict[str, Any]], params: dict[str, Any], cohere_params: dict[str, Any] ) -> tuple[list[dict[str, Any]], str, str]: """Convert messages from OAI format to Cohere's format. We correct for any specific role orders and types. diff --git a/autogen/oai/completion.py b/autogen/oai/completion.py index 72886dd857..300abfda53 100644 --- a/autogen/oai/completion.py +++ b/autogen/oai/completion.py @@ -172,7 +172,7 @@ def clear_cache(cls, seed: Optional[int] = None, cache_path_root: Optional[str] cache.clear() @classmethod - def _book_keeping(cls, config: Dict, response): + def _book_keeping(cls, config: dict, response): """Book keeping for the created completions.""" if response != -1 and "cost" not in response: response["cost"] = cls.cost(response) @@ -212,7 +212,7 @@ def _book_keeping(cls, config: Dict, response): cls._count_create += 1 @classmethod - def _get_response(cls, config: Dict, raise_on_ratelimit_or_timeout=False, use_cache=True): + def _get_response(cls, config: dict, raise_on_ratelimit_or_timeout=False, use_cache=True): """Get the response from the openai api call. Try cache first. If not found, call the openai api. If the api call fails, retry after retry_wait_time. 
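The `parse_params` methods above delegate to `client_utils.validate_parameter` for type and range checks with defaults. A much-reduced sketch of that idiom (not the real function, which per its signature also handles None, numerical bounds, and allowed values):

from typing import Any


def pick_param(params: dict[str, Any], name: str, allowed_types: tuple, default: Any) -> Any:
    # Fall back to the default when the value is missing or mistyped.
    value = params.get(name, default)
    return value if isinstance(value, allowed_types) else default


print(pick_param({"temperature": "hot"}, "temperature", (int, float), 1.0))  # -> 1.0
print(pick_param({"temperature": 0.7}, "temperature", (int, float), 1.0))   # -> 0.7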
@@ -335,7 +335,7 @@ def _pop_subspace(cls, config, always_copy=True): return config.copy() if always_copy else config @classmethod - def _get_params_for_create(cls, config: Dict) -> Dict: + def _get_params_for_create(cls, config: dict) -> dict: """Get the params for the openai api call from a config in the search space.""" params = cls._pop_subspace(config) if cls._prompts: @@ -526,7 +526,7 @@ def _eval(cls, config: dict, prune=True, eval_only=False): @classmethod def tune( cls, - data: List[Dict], + data: list[dict], metric: str, mode: str, eval_func: Callable, @@ -726,10 +726,10 @@ def eval_func(responses, **data): @classmethod def create( cls, - context: Optional[Dict] = None, + context: Optional[dict] = None, use_cache: Optional[bool] = True, - config_list: Optional[List[Dict]] = None, - filter_func: Optional[Callable[[Dict, Dict], bool]] = None, + config_list: Optional[list[dict]] = None, + filter_func: Optional[Callable[[dict, dict], bool]] = None, raise_on_ratelimit_or_timeout: Optional[bool] = True, allow_format_str_template: Optional[bool] = False, **config, @@ -861,7 +861,7 @@ def yes_or_no_filter(context, config, response): def instantiate( cls, template: Union[str, None], - context: Optional[Dict] = None, + context: Optional[dict] = None, allow_format_str_template: Optional[bool] = False, ): if not context or template is None: @@ -1069,7 +1069,7 @@ def cost(cls, response: dict): return price1K * (n_input_tokens + n_output_tokens) / 1000 @classmethod - def extract_text(cls, response: dict) -> List[str]: + def extract_text(cls, response: dict) -> list[str]: """Extract the text from a completion or chat response. Args: @@ -1084,7 +1084,7 @@ def extract_text(cls, response: dict) -> List[str]: return [choice["message"].get("content", "") for choice in choices] @classmethod - def extract_text_or_function_call(cls, response: dict) -> List[str]: + def extract_text_or_function_call(cls, response: dict) -> list[str]: """Extract the text or function calls from a completion or chat response. Args: @@ -1103,12 +1103,12 @@ def extract_text_or_function_call(cls, response: dict) -> List[str]: @classmethod @property - def logged_history(cls) -> Dict: + def logged_history(cls) -> dict: """Return the book keeping dictionary.""" return cls._history_dict @classmethod - def print_usage_summary(cls) -> Dict: + def print_usage_summary(cls) -> dict: """Return the usage summary.""" if cls._history_dict is None: print("No usage summary available.", flush=True) @@ -1147,7 +1147,7 @@ def print_usage_summary(cls) -> Dict: @classmethod def start_logging( - cls, history_dict: Optional[Dict] = None, compact: Optional[bool] = True, reset_counter: Optional[bool] = True + cls, history_dict: Optional[dict] = None, compact: Optional[bool] = True, reset_counter: Optional[bool] = True ): """Start book keeping. diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index f89e40cf84..02fa5df54f 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -46,8 +46,9 @@ import re import time import warnings +from collections.abc import Mapping from io import BytesIO -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import google.generativeai as genai import PIL @@ -151,7 +152,7 @@ def __init__(self, **kwargs): if "response_format" in kwargs and kwargs["response_format"] is not None: warnings.warn("response_format is not supported for Gemini. 
It will be ignored.", UserWarning) - def message_retrieval(self, response) -> List: + def message_retrieval(self, response) -> list: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -164,7 +165,7 @@ def cost(self, response) -> float: return response.cost @staticmethod - def get_usage(response) -> Dict: + def get_usage(response) -> dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" # ... # pragma: no cover return { @@ -175,7 +176,7 @@ def get_usage(response) -> Dict: "model": response.model, } - def create(self, params: Dict) -> ChatCompletion: + def create(self, params: dict) -> ChatCompletion: if self.use_vertexai: self._initialize_vertexai(**params) @@ -230,7 +231,7 @@ def create(self, params: Dict) -> ChatCompletion: autogen_tool_calls = [] # Maps the function call ids to function names so we can inject it into FunctionResponse messages - self.tool_call_function_map: Dict[str, str] = {} + self.tool_call_function_map: dict[str, str] = {} # A. create and call the chat model. gemini_messages = self._oai_messages_to_gemini_messages(messages) @@ -325,7 +326,7 @@ def create(self, params: Dict) -> ChatCompletion: return response_oai - def _oai_content_to_gemini_content(self, message: Dict[str, Any]) -> Tuple[List, str]: + def _oai_content_to_gemini_content(self, message: dict[str, Any]) -> tuple[list, str]: """Convert AutoGen content to Gemini parts, catering for text and tool calls""" rst = [] @@ -420,7 +421,7 @@ def _oai_content_to_gemini_content(self, message: Dict[str, Any]) -> Tuple[List, else: raise Exception("Unable to convert content to Gemini format.") - def _concat_parts(self, parts: List[Part]) -> List: + def _concat_parts(self, parts: list[Part]) -> list: """Concatenate parts with the same type. If two adjacent parts both have the "text" attribute, then it will be joined into one part. """ @@ -449,7 +450,7 @@ def _concat_parts(self, parts: List[Part]) -> List: return concatenated_parts - def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: + def _oai_messages_to_gemini_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert messages from OAI format to Gemini format. Make sure the "user" role and "model" role are interleaved. Also, make sure the last item is from the "user" role. 
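`_oai_messages_to_gemini_messages` above promises strict user/model alternation. A simplified sketch of one way to enforce that rule (the real conversion also handles tool calls and multi-part content):

def interleave(messages: list[dict[str, str]]) -> list[dict[str, str]]:
    merged: list[dict[str, str]] = []
    for msg in messages:
        if merged and merged[-1]["role"] == msg["role"]:
            # Same role twice in a row: fold into the previous message.
            merged[-1]["content"] += "\n" + msg["content"]
        else:
            merged.append(dict(msg))
    return merged


print(interleave([
    {"role": "user", "content": "hi"},
    {"role": "user", "content": "still there?"},
    {"role": "model", "content": "yes"},
]))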
@@ -522,7 +523,7 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li return rst - def _tools_to_gemini_tools(self, tools: List[Dict[str, Any]]) -> List[Tool]: + def _tools_to_gemini_tools(self, tools: list[dict[str, Any]]) -> list[Tool]: """Create Gemini tools (as typically requires Callables)""" functions = [] @@ -543,7 +544,7 @@ def _tools_to_gemini_tools(self, tools: List[Dict[str, Any]]) -> List[Tool]: return [Tool(function_declarations=functions)] @staticmethod - def _create_gemini_function_declaration(tool: Dict) -> FunctionDeclaration: + def _create_gemini_function_declaration(tool: dict) -> FunctionDeclaration: function_declaration = FunctionDeclaration() function_declaration.name = tool["function"]["name"] function_declaration.description = tool["function"]["description"] @@ -657,7 +658,7 @@ def _to_vertexai_safety_settings(safety_settings): return safety_settings @staticmethod - def _to_json_or_str(data: str) -> Union[Dict, str]: + def _to_json_or_str(data: str) -> dict | str: try: json_data = json.loads(data) return json_data diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py index e3112619fa..65cc29ec97 100644 --- a/autogen/oai/groq.py +++ b/autogen/oai/groq.py @@ -70,7 +70,7 @@ def __init__(self, **kwargs): warnings.warn("response_format is not supported for Groq API, it will be ignored.", UserWarning) self.base_url = kwargs.get("base_url", None) - def message_retrieval(self, response) -> List: + def message_retrieval(self, response) -> list: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -83,7 +83,7 @@ def cost(self, response) -> float: return response.cost @staticmethod - def get_usage(response) -> Dict: + def get_usage(response) -> dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" # ... # pragma: no cover return { @@ -94,7 +94,7 @@ def get_usage(response) -> Dict: "model": response.model, } - def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults""" groq_params = {} @@ -130,7 +130,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return groq_params - def create(self, params: Dict) -> ChatCompletion: + def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) # Convert AutoGen messages to Groq messages @@ -255,7 +255,7 @@ def create(self, params: Dict) -> ChatCompletion: return response_oai -def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: +def oai_messages_to_groq_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert messages from OAI format to Groq's format. We correct for any specific role orders and types. 
""" diff --git a/autogen/oai/mistral.py b/autogen/oai/mistral.py index 022210c9aa..4904583b56 100644 --- a/autogen/oai/mistral.py +++ b/autogen/oai/mistral.py @@ -76,7 +76,7 @@ def __init__(self, **kwargs): self._client = Mistral(api_key=self.api_key) - def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]: + def message_retrieval(self, response: ChatCompletion) -> Union[list[str], list[ChatCompletionMessage]]: """Retrieve the messages from the response.""" return [choice.message for choice in response.choices] @@ -84,7 +84,7 @@ def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[C def cost(self, response) -> float: return response.cost - def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: """Loads the parameters for Mistral.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults""" mistral_params = {} @@ -173,7 +173,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return mistral_params - def create(self, params: Dict[str, Any]) -> ChatCompletion: + def create(self, params: dict[str, Any]) -> ChatCompletion: # 1. Parse parameters to Mistral.AI API's parameters mistral_params = self.parse_params(params) @@ -224,7 +224,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: return response_oai @staticmethod - def get_usage(response: ChatCompletion) -> Dict: + def get_usage(response: ChatCompletion) -> dict: return { "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0, "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0, @@ -236,7 +236,7 @@ def get_usage(response: ChatCompletion) -> Dict: } -def tool_def_to_mistral(tool_definitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +def tool_def_to_mistral(tool_definitions: list[dict[str, Any]]) -> list[dict[str, Any]]: """Converts AutoGen tool definition to a mistral tool format""" mistral_tools = [] diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 1b7c3ced79..6c431619c6 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -88,7 +88,7 @@ def __init__(self, **kwargs): if "response_format" in kwargs and kwargs["response_format"] is not None: warnings.warn("response_format is not supported for Ollama, it will be ignored.", UserWarning) - def message_retrieval(self, response) -> List: + def message_retrieval(self, response) -> list: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -101,7 +101,7 @@ def cost(self, response) -> float: return response.cost @staticmethod - def get_usage(response) -> Dict: + def get_usage(response) -> dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" # ... # pragma: no cover return { @@ -112,7 +112,7 @@ def get_usage(response) -> Dict: "model": response.model, } - def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + def parse_params(self, params: dict[str, Any]) -> dict[str, Any]: """Loads the parameters for Ollama API from the passed in parameters and returns a validated set. 
Checks types, ranges, and sets defaults""" ollama_params = {} @@ -180,7 +180,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return ollama_params - def create(self, params: Dict) -> ChatCompletion: + def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) # Are tools involved in this conversation? @@ -289,7 +289,7 @@ def create(self, params: Dict) -> ChatCompletion: for tool_call in response["message"]["tool_calls"]: tool_calls.append( ChatCompletionMessageToolCall( - id="ollama_func_{}".format(random_id), + id=f"ollama_func_{random_id}", function={ "name": tool_call["function"]["name"], "arguments": json.dumps(tool_call["function"]["arguments"]), @@ -314,7 +314,7 @@ def create(self, params: Dict) -> ChatCompletion: for json_function in response_toolcalls: tool_calls.append( ChatCompletionMessageToolCall( - id="ollama_manual_func_{}".format(random_id), + id=f"ollama_manual_func_{random_id}", function={ "name": json_function["name"], "arguments": ( @@ -360,7 +360,7 @@ def create(self, params: Dict) -> ChatCompletion: return response_oai - def oai_messages_to_ollama_messages(self, messages: list[Dict[str, Any]], tools: list) -> list[dict[str, Any]]: + def oai_messages_to_ollama_messages(self, messages: list[dict[str, Any]], tools: list) -> list[dict[str, Any]]: """Convert messages from OAI format to Ollama's format. We correct for any specific role orders and types, and convert tools to messages (as Ollama can't use tool messages) """ @@ -526,7 +526,7 @@ def response_to_tool_call(response_string: str) -> Any: return None -def _object_to_tool_call(data_object: Any) -> List[Dict]: +def _object_to_tool_call(data_object: Any) -> list[dict]: """Attempts to convert an object to a valid tool call object List[Dict] and returns it, if it can, otherwise None""" # If it's a dictionary and not a list then wrap in a list diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py index e26096199f..77da5a6279 100644 --- a/autogen/oai/openai_utils.py +++ b/autogen/oai/openai_utils.py @@ -84,7 +84,7 @@ } -def get_key(config: Dict[str, Any]) -> str: +def get_key(config: dict[str, Any]) -> str: """Get a unique identifier of a configuration. Args: @@ -122,11 +122,11 @@ def is_valid_api_key(api_key: str) -> bool: def get_config_list( - api_keys: List[str], - base_urls: Optional[List[str]] = None, + api_keys: list[str], + base_urls: Optional[list[str]] = None, api_type: Optional[str] = None, api_version: Optional[str] = None, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Get a list of configs for OpenAI API client. Args: @@ -179,7 +179,7 @@ def config_list_openai_aoai( openai_api_base_file: Optional[str] = "base_openai.txt", aoai_api_base_file: Optional[str] = "base_aoai.txt", exclude: Optional[str] = None, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Get a list of configs for OpenAI API client (including Azure or local model deployments that support OpenAI's chat completion API). This function constructs configurations by reading API keys and base URLs from environment variables or text files. @@ -307,8 +307,8 @@ def config_list_from_models( aoai_api_key_file: Optional[str] = "key_aoai.txt", aoai_api_base_file: Optional[str] = "base_aoai.txt", exclude: Optional[str] = None, - model_list: Optional[List[str]] = None, -) -> List[Dict[str, Any]]: + model_list: Optional[list[str]] = None, +) -> list[dict[str, Any]]: """ Get a list of configs for API calls with models specified in the model list. 
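`get_config_list` above builds one config dict per API key. A hedged sketch of that shape (field names follow the signature shown; the body is illustrative, not the library's exact logic):

from typing import Any, Optional


def make_config_list(
    api_keys: list[str],
    base_urls: Optional[list[str]] = None,
    api_type: Optional[str] = None,
) -> list[dict[str, Any]]:
    configs: list[dict[str, Any]] = []
    for i, key in enumerate(api_keys):
        cfg: dict[str, Any] = {"api_key": key}
        if base_urls is not None:
            cfg["base_url"] = base_urls[i]
        if api_type is not None:
            cfg["api_type"] = api_type
        configs.append(cfg)
    return configs


print(make_config_list(["sk-a", "sk-b"], api_type="openai"))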
diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py
index e26096199f..77da5a6279 100644
--- a/autogen/oai/openai_utils.py
+++ b/autogen/oai/openai_utils.py
@@ -84,7 +84,7 @@
 }
 
 
-def get_key(config: Dict[str, Any]) -> str:
+def get_key(config: dict[str, Any]) -> str:
     """Get a unique identifier of a configuration.
 
     Args:
@@ -122,11 +122,11 @@ def is_valid_api_key(api_key: str) -> bool:
 
 
 def get_config_list(
-    api_keys: List[str],
-    base_urls: Optional[List[str]] = None,
+    api_keys: list[str],
+    base_urls: Optional[list[str]] = None,
     api_type: Optional[str] = None,
     api_version: Optional[str] = None,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """Get a list of configs for OpenAI API client.
 
     Args:
@@ -179,7 +179,7 @@ def config_list_openai_aoai(
     openai_api_base_file: Optional[str] = "base_openai.txt",
     aoai_api_base_file: Optional[str] = "base_aoai.txt",
     exclude: Optional[str] = None,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """Get a list of configs for OpenAI API client (including Azure or local model deployments that support OpenAI's chat completion API).
 
     This function constructs configurations by reading API keys and base URLs from environment variables or text files.
@@ -307,8 +307,8 @@ def config_list_from_models(
     aoai_api_key_file: Optional[str] = "key_aoai.txt",
     aoai_api_base_file: Optional[str] = "base_aoai.txt",
     exclude: Optional[str] = None,
-    model_list: Optional[List[str]] = None,
-) -> List[Dict[str, Any]]:
+    model_list: Optional[list[str]] = None,
+) -> list[dict[str, Any]]:
     """
     Get a list of configs for API calls with models specified in the model list.
@@ -374,7 +374,7 @@ def config_list_gpt4_gpt35(
     aoai_api_key_file: Optional[str] = "key_aoai.txt",
     aoai_api_base_file: Optional[str] = "base_aoai.txt",
     exclude: Optional[str] = None,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """Get a list of configs for 'gpt-4' followed by 'gpt-3.5-turbo' API calls.
 
     Args:
@@ -398,10 +398,10 @@ def config_list_gpt4_gpt35(
 
 
 def filter_config(
-    config_list: List[Dict[str, Any]],
-    filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
+    config_list: list[dict[str, Any]],
+    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]],
     exclude: bool = False,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """This function filters `config_list` by checking each configuration dictionary against the criteria specified in
     `filter_dict`. A configuration dictionary is retained if for every key in `filter_dict`, see example below.
 
@@ -479,8 +479,8 @@ def _satisfies_criteria(value: Any, criteria_values: Any) -> bool:
 def config_list_from_json(
     env_or_file: str,
     file_location: Optional[str] = "",
-    filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]] = None,
-) -> List[Dict[str, Any]]:
+    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]] = None,
+) -> list[dict[str, Any]]:
     """
     Retrieves a list of API configurations from a JSON stored in an environment variable or a file.
 
@@ -523,7 +523,7 @@ def config_list_from_json(
     # The environment variable exists. We should use information from it.
     if os.path.exists(env_str):
         # It is a file location, and we need to load the json from the file.
-        with open(env_str, "r") as file:
+        with open(env_str) as file:
             json_str = file.read()
     else:
         # Else, it should be a JSON string by itself.
@@ -547,7 +547,7 @@ def get_config(
     base_url: Optional[str] = None,
     api_type: Optional[str] = None,
     api_version: Optional[str] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Constructs a configuration dictionary for a single model with the provided API configurations.
 
@@ -587,9 +587,9 @@ def get_config(
 
 def config_list_from_dotenv(
     dotenv_file_path: Optional[str] = None,
-    model_api_key_map: Optional[Dict[str, Any]] = None,
-    filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]] = None,
-) -> List[Dict[str, Union[str, Set[str]]]]:
+    model_api_key_map: Optional[dict[str, Any]] = None,
+    filter_dict: Optional[dict[str, Union[list[Union[str, None]], set[Union[str, None]]]]] = None,
+) -> list[dict[str, Union[str, set[str]]]]:
     """
     Load API configurations from a specified .env file or environment variables and construct a list of configurations.
@@ -688,7 +688,7 @@ def config_list_from_dotenv(
 
     return config_list
 
 
-def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
+def retrieve_assistants_by_name(client: OpenAI, name: str) -> list[Assistant]:
     """
     Return the assistants with the given name from OAI assistant API
     """
@@ -709,7 +709,7 @@ def detect_gpt_assistant_api_version() -> str:
         return "v2"
 
 
-def create_gpt_vector_store(client: OpenAI, name: str, fild_ids: List[str]) -> Any:
+def create_gpt_vector_store(client: OpenAI, name: str, fild_ids: list[str]) -> Any:
     """Create a openai vector store for gpt assistant"""
 
     try:
@@ -732,7 +732,7 @@ def create_gpt_vector_store(client: OpenAI, name: str, fild_ids: List[str]) -> A
 
 
 def create_gpt_assistant(
-    client: OpenAI, name: str, instructions: str, model: str, assistant_config: Dict[str, Any]
+    client: OpenAI, name: str, instructions: str, model: str, assistant_config: dict[str, Any]
 ) -> Assistant:
     """Create a openai gpt assistant"""
 
@@ -782,7 +782,7 @@ def create_gpt_assistant(
     return client.beta.assistants.create(name=name, instructions=instructions, model=model, **assistant_create_kwargs)
 
 
-def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Dict[str, Any]) -> Assistant:
+def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: dict[str, Any]) -> Assistant:
     """Update openai gpt assistant"""
 
     gpt_assistant_api_version = detect_gpt_assistant_api_version()
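The `open(env_str, "r")` → `open(env_str)` change above relies on `"r"` (text read) being `open()`'s default mode, so dropping it is purely cosmetic. A self-contained check using a throwaway temp file rather than any real config:

```python
import tempfile

# Hypothetical scratch file, for illustration only.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    tmp.write('{"model": "gpt-4"}')
    path = tmp.name

with open(path, "r") as f:  # before
    explicit = f.read()
with open(path) as f:       # after: mode defaults to "rt"
    implicit = f.read()

assert explicit == implicit
```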
Checks types, ranges, and sets defaults""" together_params = {} @@ -133,7 +134,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return together_params - def create(self, params: Dict) -> ChatCompletion: + def create(self, params: dict) -> ChatCompletion: messages = params.get("messages", []) # Convert AutoGen messages to Together.AI messages @@ -218,7 +219,7 @@ def create(self, params: Dict) -> ChatCompletion: return response_oai -def oai_messages_to_together_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: +def oai_messages_to_together_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert messages from OAI format to Together.AI format. We correct for any specific role orders and types. """ diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index 774e5e57d3..b6daed232c 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -164,7 +164,7 @@ def split_files_to_chunks( chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True, custom_text_split_function: Callable = None, -) -> Tuple[List[str], List[dict]]: +) -> tuple[list[str], list[dict]]: """Split a list of files into chunks of max_tokens.""" chunks = [] @@ -185,7 +185,7 @@ def split_files_to_chunks( elif file_extension == ".pdf": text = extract_text_from_pdf(file) else: # For non-PDF text-based files - with open(file, "r", encoding="utf-8", errors="ignore") as f: + with open(file, encoding="utf-8", errors="ignore") as f: text = f.read() if not text.strip(): # Debugging line to check if text is empty after reading @@ -202,7 +202,7 @@ def split_files_to_chunks( return chunks, sources -def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True): +def get_files_from_dir(dir_path: Union[str, list[str]], types: list = TEXT_FORMATS, recursive: bool = True): """Return a list of all the files in a given directory, a url, a file path or a list of them.""" if len(types) == 0: raise ValueError("types cannot be empty.") @@ -292,7 +292,7 @@ def _generate_file_name_from_url(url: str, max_length=255) -> str: return file_name -def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]: +def get_file_from_url(url: str, save_path: str = None) -> tuple[str, str]: """Download a file from a URL.""" if save_path is None: save_path = "tmp/chromadb" @@ -339,7 +339,7 @@ def is_url(string: str): def create_vector_db_from_dir( - dir_path: Union[str, List[str]], + dir_path: Union[str, list[str]], max_tokens: int = 4000, client: API = None, db_path: str = "tmp/chromadb.db", @@ -350,7 +350,7 @@ def create_vector_db_from_dir( embedding_model: str = "all-MiniLM-L6-v2", embedding_function: Callable = None, custom_text_split_function: Callable = None, - custom_text_types: List[str] = TEXT_FORMATS, + custom_text_types: list[str] = TEXT_FORMATS, recursive: bool = True, extra_docs: bool = False, ) -> API: @@ -432,7 +432,7 @@ def create_vector_db_from_dir( def query_vector_db( - query_texts: List[str], + query_texts: list[str], n_results: int = 10, client: API = None, db_path: str = "tmp/chromadb.db", diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index a4430a4f91..02abf2e80c 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -38,9 +38,9 @@ def start( - logger: Optional[BaseLogger] = None, + logger: BaseLogger | None = None, logger_type: Literal["sqlite", "file"] = "sqlite", - config: Optional[Dict[str, Any]] = None, + config: dict[str, Any] | None = None, ) -> 
str: """ Start logging for the runtime. @@ -72,9 +72,9 @@ def log_chat_completion( invocation_id: uuid.UUID, client_id: int, wrapper_id: int, - agent: Union[str, Agent], - request: Dict[str, Union[float, str, List[Dict[str, str]]]], - response: Union[str, ChatCompletion], + agent: str | Agent, + request: dict[str, float | str | list[dict[str, str]]], + response: str | ChatCompletion, is_cached: int, cost: float, start_time: str, @@ -88,7 +88,7 @@ def log_chat_completion( ) -def log_new_agent(agent: ConversableAgent, init_args: Dict[str, Any]) -> None: +def log_new_agent(agent: ConversableAgent, init_args: dict[str, Any]) -> None: if autogen_logger is None: logger.error("[runtime logging] log_new_agent: autogen logger is None") return @@ -96,7 +96,7 @@ def log_new_agent(agent: ConversableAgent, init_args: Dict[str, Any]) -> None: autogen_logger.log_new_agent(agent, init_args) -def log_event(source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None: +def log_event(source: str | Agent, name: str, **kwargs: dict[str, Any]) -> None: if autogen_logger is None: logger.error("[runtime logging] log_event: autogen logger is None") return @@ -104,7 +104,7 @@ def log_event(source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> autogen_logger.log_event(source, name, **kwargs) -def log_function_use(agent: Union[str, Agent], function: F, args: Dict[str, Any], returns: any): +def log_function_use(agent: str | Agent, function: F, args: dict[str, Any], returns: any): if autogen_logger is None: logger.error("[runtime logging] log_function_use: autogen logger is None") return @@ -112,7 +112,7 @@ def log_function_use(agent: Union[str, Agent], function: F, args: Dict[str, Any] autogen_logger.log_function_use(agent, function, args, returns) -def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None: +def log_new_wrapper(wrapper: OpenAIWrapper, init_args: dict[str, LLMConfig | list[LLMConfig]]) -> None: if autogen_logger is None: logger.error("[runtime logging] log_new_wrapper: autogen logger is None") return @@ -121,21 +121,21 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig def log_new_client( - client: Union[ - AzureOpenAI, - OpenAI, - CerebrasClient, - GeminiClient, - AnthropicClient, - MistralAIClient, - TogetherClient, - GroqClient, - CohereClient, - OllamaClient, - BedrockClient, - ], + client: ( + AzureOpenAI + | OpenAI + | CerebrasClient + | GeminiClient + | AnthropicClient + | MistralAIClient + | TogetherClient + | GroqClient + | CohereClient + | OllamaClient + | BedrockClient + ), wrapper: OpenAIWrapper, - init_args: Dict[str, Any], + init_args: dict[str, Any], ) -> None: if autogen_logger is None: logger.error("[runtime logging] log_new_client: autogen logger is None") @@ -151,7 +151,7 @@ def stop() -> None: is_logging = False -def get_connection() -> Union[None, sqlite3.Connection]: +def get_connection() -> None | sqlite3.Connection: if autogen_logger is None: logger.error("[runtime logging] get_connection: autogen logger is None") return None diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py index 56975a279b..defb163674 100644 --- a/autogen/token_count_utils.py +++ b/autogen/token_count_utils.py @@ -67,7 +67,7 @@ def percentile_used(input, model="gpt-3.5-turbo-0613"): return count_token(input) / get_max_token_limit(model) -def token_left(input: Union[str, List, Dict], model="gpt-3.5-turbo-0613") -> int: +def token_left(input: Union[str, list, dict], 
model="gpt-3.5-turbo-0613") -> int: """Count number of tokens left for an OpenAI model. Args: @@ -80,7 +80,7 @@ def token_left(input: Union[str, List, Dict], model="gpt-3.5-turbo-0613") -> int return get_max_token_limit(model) - count_token(input, model=model) -def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613") -> int: +def count_token(input: Union[str, list, dict], model: str = "gpt-3.5-turbo-0613") -> int: """Count number of tokens used by an OpenAI model. Args: input: (str, list, dict): Input to the model. @@ -107,7 +107,7 @@ def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"): return len(encoding.encode(text)) -def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0613"): +def _num_token_from_messages(messages: Union[list, dict], model="gpt-3.5-turbo-0613"): """Return the number of tokens used by a list of messages. retrieved from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb/ diff --git a/autogen/types.py b/autogen/types.py index 99546672cd..be865907ab 100644 --- a/autogen/types.py +++ b/autogen/types.py @@ -6,7 +6,7 @@ # SPDX-License-Identifier: MIT from typing import Dict, List, Literal, TypedDict, Union -MessageContentType = Union[str, List[Union[Dict, str]], None] +MessageContentType = Union[str, list[Union[dict, str]], None] class UserMessageTextContentPart(TypedDict): @@ -17,4 +17,4 @@ class UserMessageTextContentPart(TypedDict): class UserMessageImageContentPart(TypedDict): type: Literal["image_url"] # Ignoring the other "detail param for now" - image_url: Dict[Literal["url"], str] + image_url: dict[Literal["url"], str] diff --git a/test/agentchat/contrib/agent_eval/test_agent_eval.py b/test/agentchat/contrib/agent_eval/test_agent_eval.py index 65e03af36e..d5f7306528 100644 --- a/test/agentchat/contrib/agent_eval/test_agent_eval.py +++ b/test/agentchat/contrib/agent_eval/test_agent_eval.py @@ -54,9 +54,9 @@ def remove_ground_truth(test_case: str): filter_dict={"api_type": ["azure"]}, ) - success_str = open("test/test_files/agenteval-in-out/samples/sample_math_response_successful.txt", "r").read() + success_str = open("test/test_files/agenteval-in-out/samples/sample_math_response_successful.txt").read() response_successful = remove_ground_truth(success_str)[0] - failed_str = open("test/test_files/agenteval-in-out/samples/sample_math_response_failed.txt", "r").read() + failed_str = open("test/test_files/agenteval-in-out/samples/sample_math_response_failed.txt").read() response_failed = remove_ground_truth(failed_str)[0] task = Task( **{ @@ -87,10 +87,10 @@ def test_generate_criteria(): ) def test_quantify_criteria(): criteria_file = "test/test_files/agenteval-in-out/samples/sample_math_criteria.json" - criteria = open(criteria_file, "r").read() + criteria = open(criteria_file).read() criteria = Criterion.parse_json_str(criteria) - test_case = open("test/test_files/agenteval-in-out/samples/sample_test_case.json", "r").read() + test_case = open("test/test_files/agenteval-in-out/samples/sample_test_case.json").read() test_case, ground_truth = remove_ground_truth(test_case) quantified = quantify_criteria( diff --git a/test/agentchat/contrib/agent_eval/test_criterion.py b/test/agentchat/contrib/agent_eval/test_criterion.py index f36ccdfd24..af0600cd3c 100644 --- a/test/agentchat/contrib/agent_eval/test_criterion.py +++ b/test/agentchat/contrib/agent_eval/test_criterion.py @@ -11,7 +11,7 @@ def test_parse_json_str(): criteria_file = 
"test/test_files/agenteval-in-out/samples/sample_math_criteria.json" - criteria = open(criteria_file, "r").read() + criteria = open(criteria_file).read() criteria = Criterion.parse_json_str(criteria) assert criteria assert len(criteria) == 6 diff --git a/test/agentchat/contrib/capabilities/test_image_generation_capability.py b/test/agentchat/contrib/capabilities/test_image_generation_capability.py index c050a775af..39f5e4daf9 100644 --- a/test/agentchat/contrib/capabilities/test_image_generation_capability.py +++ b/test/agentchat/contrib/capabilities/test_image_generation_capability.py @@ -57,7 +57,7 @@ def create_test_agent(name: str = "test_agent", default_auto_reply: str = "") -> return ConversableAgent(name=name, llm_config=False, default_auto_reply=default_auto_reply) -def dalle_image_generator(dalle_config: Dict[str, Any], resolution: str, quality: str): +def dalle_image_generator(dalle_config: dict[str, Any], resolution: str, quality: str): return generate_images.DalleImageGenerator(dalle_config, resolution=resolution, quality=quality, num_images=1) @@ -66,7 +66,7 @@ def api_key(): @pytest.fixture -def dalle_config() -> Dict[str, Any]: +def dalle_config() -> dict[str, Any]: config_list = openai_utils.config_list_from_models(model_list=["dall-e-3"], exclude="aoai") if not config_list: config_list = [{"model": "dall-e-3", "api_key": api_key()}] @@ -74,7 +74,7 @@ def dalle_config() -> Dict[str, Any]: @pytest.fixture -def gpt4_config() -> Dict[str, Any]: +def gpt4_config() -> dict[str, Any]: config_list = [ { "model": "gpt-4o-mini", @@ -96,7 +96,7 @@ def image_gen_capability(): @pytest.mark.skipif(skip_openai, reason="Requested to skip.") @pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.") -def test_dalle_image_generator(dalle_config: Dict[str, Any]): +def test_dalle_image_generator(dalle_config: dict[str, Any]): """Tests DalleImageGenerator capability to generate images by calling the OpenAI API.""" dalle_generator = dalle_image_generator(dalle_config, RESOLUTIONS[0], QUALITIES[0]) image = dalle_generator.generate_image(PROMPTS[0]) @@ -109,7 +109,7 @@ def test_dalle_image_generator(dalle_config: Dict[str, Any]): @pytest.mark.parametrize("gen_config_2", itertools.product(RESOLUTIONS, QUALITIES, PROMPTS)) @pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.") def test_dalle_image_generator_cache_key( - dalle_config: Dict[str, Any], gen_config_1: Tuple[str, str, str], gen_config_2: Tuple[str, str, str] + dalle_config: dict[str, Any], gen_config_1: tuple[str, str, str], gen_config_2: tuple[str, str, str] ): """Tests if DalleImageGenerator creates unique cache keys. diff --git a/test/agentchat/contrib/capabilities/test_teachable_agent.py b/test/agentchat/contrib/capabilities/test_teachable_agent.py index 82252f07f6..b5ed584e3e 100755 --- a/test/agentchat/contrib/capabilities/test_teachable_agent.py +++ b/test/agentchat/contrib/capabilities/test_teachable_agent.py @@ -202,7 +202,7 @@ def test_teachability_accuracy(): return # All trials failed. - assert False, "test_teachability_accuracy() failed on all {} trials.".format(num_trials) + assert False, f"test_teachability_accuracy() failed on all {num_trials} trials." 
if __name__ == "__main__": diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py index 744727f65e..2038991bfc 100644 --- a/test/agentchat/contrib/capabilities/test_transforms.py +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -21,11 +21,11 @@ class _MockTextCompressor: - def compress_text(self, text: str, **compression_params) -> Dict[str, Any]: + def compress_text(self, text: str, **compression_params) -> dict[str, Any]: return {"compressed_prompt": ""} -def get_long_messages() -> List[Dict]: +def get_long_messages() -> list[dict]: return [ {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, {"role": "user", "content": "very very very very very very long string"}, @@ -35,7 +35,7 @@ def get_long_messages() -> List[Dict]: ] -def get_short_messages() -> List[Dict]: +def get_short_messages() -> list[dict]: return [ {"role": "user", "content": "hello"}, {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, @@ -43,11 +43,11 @@ def get_short_messages() -> List[Dict]: ] -def get_no_content_messages() -> List[Dict]: +def get_no_content_messages() -> list[dict]: return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}] -def get_tool_messages() -> List[Dict]: +def get_tool_messages() -> list[dict]: return [ {"role": "user", "content": "hello"}, {"role": "tool_calls", "content": "calling_tool"}, @@ -57,7 +57,7 @@ def get_tool_messages() -> List[Dict]: ] -def get_tool_messages_kept() -> List[Dict]: +def get_tool_messages_kept() -> list[dict]: return [ {"role": "user", "content": "hello"}, {"role": "tool_calls", "content": "calling_tool"}, @@ -67,7 +67,7 @@ def get_tool_messages_kept() -> List[Dict]: ] -def get_messages_with_names() -> List[Dict]: +def get_messages_with_names() -> list[dict]: return [ {"role": "system", "content": "I am the system."}, {"role": "user", "name": "charlie", "content": "I think the sky is blue."}, @@ -76,7 +76,7 @@ def get_messages_with_names() -> List[Dict]: ] -def get_messages_with_names_post_start() -> List[Dict]: +def get_messages_with_names_post_start() -> list[dict]: return [ {"role": "system", "content": "I am the system."}, {"role": "user", "name": "charlie", "content": "'charlie' said:\nI think the sky is blue."}, @@ -85,7 +85,7 @@ def get_messages_with_names_post_start() -> List[Dict]: ] -def get_messages_with_names_post_end() -> List[Dict]: +def get_messages_with_names_post_end() -> list[dict]: return [ {"role": "system", "content": "I am the system."}, {"role": "user", "name": "charlie", "content": "I think the sky is blue.\n(said 'charlie')"}, @@ -94,7 +94,7 @@ def get_messages_with_names_post_end() -> List[Dict]: ] -def get_messages_with_names_post_filtered() -> List[Dict]: +def get_messages_with_names_post_filtered() -> list[dict]: return [ {"role": "system", "content": "I am the system."}, {"role": "user", "name": "charlie", "content": "I think the sky is blue."}, @@ -103,8 +103,8 @@ def get_messages_with_names_post_filtered() -> List[Dict]: ] -def get_text_compressors() -> List[TextCompressor]: - compressors: List[TextCompressor] = [_MockTextCompressor()] +def get_text_compressors() -> list[TextCompressor]: + compressors: list[TextCompressor] = [_MockTextCompressor()] try: from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua @@ -136,7 +136,7 @@ def message_token_limiter_with_threshold() -> MessageTokenLimiter: def _filter_dict_test( - post_transformed_message: Dict, 
pre_transformed_messages: Dict, roles: List[str], exclude_filter: bool + post_transformed_message: dict, pre_transformed_messages: dict, roles: list[str], exclude_filter: bool ) -> bool: is_role = post_transformed_message["role"] in roles if exclude_filter: diff --git a/test/agentchat/contrib/capabilities/test_transforms_util.py b/test/agentchat/contrib/capabilities/test_transforms_util.py index 31c5ac223e..6647226f0b 100644 --- a/test/agentchat/contrib/capabilities/test_transforms_util.py +++ b/test/agentchat/contrib/capabilities/test_transforms_util.py @@ -28,7 +28,7 @@ @pytest.mark.parametrize("message", MESSAGES.values()) -def test_cache_content(message: Dict[str, MessageContentType]) -> None: +def test_cache_content(message: dict[str, MessageContentType]) -> None: with tempfile.TemporaryDirectory() as tmpdirname: cache = Cache.disk(tmpdirname) cache_key_1 = "test_string" @@ -51,7 +51,7 @@ def test_cache_content(message: Dict[str, MessageContentType]) -> None: @pytest.mark.parametrize("messages", itertools.product(MESSAGES.values(), MESSAGES.values())) -def test_cache_key(messages: Tuple[Dict[str, MessageContentType], Dict[str, MessageContentType]]) -> None: +def test_cache_key(messages: tuple[dict[str, MessageContentType], dict[str, MessageContentType]]) -> None: message_1, message_2 = messages cache_1 = transforms_util.cache_key(message_1["content"], 10) cache_2 = transforms_util.cache_key(message_2["content"], 10) @@ -62,17 +62,17 @@ def test_cache_key(messages: Tuple[Dict[str, MessageContentType], Dict[str, Mess @pytest.mark.parametrize("message", MESSAGES.values()) -def test_min_tokens_reached(message: Dict[str, MessageContentType]): +def test_min_tokens_reached(message: dict[str, MessageContentType]): assert transforms_util.min_tokens_reached([message], None) assert transforms_util.min_tokens_reached([message], 0) assert not transforms_util.min_tokens_reached([message], message["text_tokens"] + 1) @pytest.mark.parametrize("message", MESSAGES.values()) -def test_count_text_tokens(message: Dict[str, MessageContentType]): +def test_count_text_tokens(message: dict[str, MessageContentType]): assert transforms_util.count_text_tokens(message["content"]) == message["text_tokens"] @pytest.mark.parametrize("message", MESSAGES.values()) -def test_is_content_text_empty(message: Dict[str, MessageContentType]): +def test_is_content_text_empty(message: dict[str, MessageContentType]): assert transforms_util.is_content_text_empty(message["content"]) == (message["text_tokens"] == 0) diff --git a/test/agentchat/contrib/test_agent_builder.py b/test/agentchat/contrib/test_agent_builder.py index cab4a051b5..e16468ad62 100755 --- a/test/agentchat/contrib/test_agent_builder.py +++ b/test/agentchat/contrib/test_agent_builder.py @@ -180,7 +180,7 @@ def test_load(): ) config_save_path = f"{here}/example_test_agent_builder_config.json" - json.load(open(config_save_path, "r")) + json.load(open(config_save_path)) agent_list, loaded_agent_configs = builder.load( config_save_path, diff --git a/test/agentchat/contrib/test_society_of_mind_agent.py b/test/agentchat/contrib/test_society_of_mind_agent.py index 376bddfd70..1e70ebf47f 100755 --- a/test/agentchat/contrib/test_society_of_mind_agent.py +++ b/test/agentchat/contrib/test_society_of_mind_agent.py @@ -8,9 +8,9 @@ import os import sys +from typing import Annotated import pytest -from typing_extensions import Annotated import autogen from autogen.agentchat.contrib.society_of_mind_agent import SocietyOfMindAgent diff --git a/test/agentchat/contrib/test_swarm.py 
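This hunk and several below drop the `typing_extensions` import because `Annotated` has lived in the standard `typing` module since Python 3.9. A small sketch of the modern import in a typical parameter-annotation role (the function itself is hypothetical):

```python
from typing import Annotated


def exchange_rate(
    base: Annotated[str, "Base currency symbol"],
    quote: Annotated[str, "Quote currency symbol"] = "EUR",
) -> str:
    """Illustrative only; mirrors how tool parameters are annotated in these tests."""
    return f"{base}/{quote}"


print(exchange_rate("USD"))  # USD/EUR
```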
diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py
index 55fae1826d..85c24110a0 100644
--- a/test/agentchat/contrib/test_swarm.py
+++ b/test/agentchat/contrib/test_swarm.py
@@ -306,12 +306,12 @@ def test_context_variables_updating_multi_tools():
     test_context_variables = {"my_key": 0}
 
     # Increment the context variable
-    def test_func_1(context_variables: Dict[str, Any], param1: str) -> str:
+    def test_func_1(context_variables: dict[str, Any], param1: str) -> str:
         context_variables["my_key"] += 1
         return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1)
 
     # Increment the context variable
-    def test_func_2(context_variables: Dict[str, Any], param2: str) -> str:
+    def test_func_2(context_variables: dict[str, Any], param2: str) -> str:
         context_variables["my_key"] += 100
         return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1)
 
@@ -367,7 +367,7 @@ def test_function_transfer():
     test_context_variables = {"my_key": 0}
 
     # Increment the context variable
-    def test_func_1(context_variables: Dict[str, Any], param1: str) -> str:
+    def test_func_1(context_variables: dict[str, Any], param1: str) -> str:
         context_variables["my_key"] += 1
         return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1)
 
@@ -474,7 +474,7 @@ def __init__(self):
     message_container = MessageContainer()
 
     # 1. Test with a callable function
-    def custom_update_function(agent: ConversableAgent, messages: List[Dict]) -> str:
+    def custom_update_function(agent: ConversableAgent, messages: list[dict]) -> str:
         return f"System message with {agent.get_context('test_var')} and {len(messages)} messages"
 
     # 2. Test with a string template
@@ -537,7 +537,7 @@ def invalid_return_function(context_variables, messages) -> dict:
         SwarmAgent("agent5", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function))
 
     # Test multiple update functions
-    def another_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str:
+    def another_update_function(context_variables: dict[str, Any], messages: list[dict]) -> str:
         return "Another update"
 
     agent6 = SwarmAgent(
@@ -673,17 +673,17 @@ def test_after_work_callable():
     agent3 = SwarmAgent("agent3", llm_config=testing_llm_config)
 
     def return_agent(
-        last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat
+        last_speaker: SwarmAgent, messages: list[dict[str, Any]], groupchat: GroupChat
     ) -> Union[AfterWorkOption, SwarmAgent, str]:
         return agent2
 
     def return_agent_str(
-        last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat
+        last_speaker: SwarmAgent, messages: list[dict[str, Any]], groupchat: GroupChat
     ) -> Union[AfterWorkOption, SwarmAgent, str]:
         return "agent3"
 
     def return_after_work_option(
-        last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat
+        last_speaker: SwarmAgent, messages: list[dict[str, Any]], groupchat: GroupChat
     ) -> Union[AfterWorkOption, SwarmAgent, str]:
         return AfterWorkOption.TERMINATE
 
diff --git a/test/agentchat/contrib/vectordb/test_mongodb.py b/test/agentchat/contrib/vectordb/test_mongodb.py
index 536da417fc..055ec22b5f 100644
--- a/test/agentchat/contrib/vectordb/test_mongodb.py
+++ b/test/agentchat/contrib/vectordb/test_mongodb.py
@@ -103,7 +103,7 @@ def db():
 
 
 @pytest.fixture
-def example_documents() -> List[Document]:
+def example_documents() -> list[Document]:
     """Note mix of integers and strings as ids"""
     return [
         Document(id=1, content="Dogs are tough.", metadata={"a": 1}),
diff --git a/test/agentchat/contrib/vectordb/test_pgvectordb.py b/test/agentchat/contrib/vectordb/test_pgvectordb.py
index e158cf1678..15f96809b8 100644
--- a/test/agentchat/contrib/vectordb/test_pgvectordb.py
+++ b/test/agentchat/contrib/vectordb/test_pgvectordb.py
@@ -133,7 +133,7 @@ def test_pgvector():
     res = db.get_docs_by_ids(["1", "2"], collection_name)
     assert [r["id"] for r in res] == ["2"]  # "1" has been deleted
     res = db.get_docs_by_ids(collection_name=collection_name)
-    assert set([r["id"] for r in res]) == set(["2", "3"])  # All Docs returned
+    assert {r["id"] for r in res} == {"2", "3"}  # All Docs returned
 
 
 if __name__ == "__main__":
diff --git a/test/agentchat/test_agent_file_logging.py b/test/agentchat/test_agent_file_logging.py
index d68c5dea9c..78755ab270 100644
--- a/test/agentchat/test_agent_file_logging.py
+++ b/test/agentchat/test_agent_file_logging.py
@@ -74,7 +74,7 @@ def test_log_chat_completion(logger: FileLogger):
         source=agent,
     )
 
-    with open(logger.log_file, "r") as f:
+    with open(logger.log_file) as f:
         lines = f.readlines()
         assert len(lines) == 1
         log_data = json.loads(lines[0])
@@ -98,7 +98,7 @@ def test_log_function_use(logger: FileLogger):
 
     logger.log_function_use(source=source, function=func, args=args, returns=returns)
 
-    with open(logger.log_file, "r") as f:
+    with open(logger.log_file) as f:
         lines = f.readlines()
         assert len(lines) == 1
         log_data = json.loads(lines[0])
@@ -118,7 +118,7 @@ def test_log_new_agent(logger: FileLogger):
     agent = autogen.UserProxyAgent(name="user_proxy", code_execution_config=False)
     logger.log_new_agent(agent)
 
-    with open(logger.log_file, "r") as f:
+    with open(logger.log_file) as f:
         lines = f.readlines()
         log_data = json.loads(lines[0])  # the first line is the session id
         assert log_data["agent_name"] == "user_proxy"
@@ -131,7 +131,7 @@ def test_log_event(logger: FileLogger):
     kwargs = {"key": "value"}
     logger.log_event(source, name, **kwargs)
 
-    with open(logger.log_file, "r") as f:
+    with open(logger.log_file) as f:
         lines = f.readlines()
         log_data = json.loads(lines[0])
         assert log_data["source_name"] == "TestAgent"
@@ -145,7 +145,7 @@ def test_log_new_wrapper(logger: FileLogger):
     wrapper = TestWrapper(init_args={"foo": "bar"})
     logger.log_new_wrapper(wrapper, wrapper.init_args)
 
-    with open(logger.log_file, "r") as f:
+    with open(logger.log_file) as f:
         lines = f.readlines()
         log_data = json.loads(lines[0])
         assert log_data["wrapper_id"] == id(wrapper)
@@ -160,7 +160,7 @@ def test_log_new_client(logger: FileLogger):
     init_args = {"foo": "bar"}
     logger.log_new_client(client, wrapper, init_args)
 
-    with open(logger.log_file, "r") as f:
+    with open(logger.log_file) as f:
         lines = f.readlines()
         log_data = json.loads(lines[0])
         assert log_data["client_id"] == id(client)
diff --git a/test/agentchat/test_agentchat_utils.py b/test/agentchat/test_agentchat_utils.py
index 805411f9c2..8c38a6855e 100644
--- a/test/agentchat/test_agentchat_utils.py
+++ b/test/agentchat/test_agentchat_utils.py
@@ -48,13 +48,13 @@
 ]
 
 
-def _delete_unused_keys(d: Dict) -> None:
+def _delete_unused_keys(d: dict) -> None:
     if "match" in d:
         del d["match"]
 
 
 @pytest.mark.parametrize("test_case", TAG_PARSING_TESTS)
-def test_tag_parsing(test_case: Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]) -> None:
+def test_tag_parsing(test_case: dict[str, Union[str, list[dict[str, Union[str, dict[str, str]]]]]]) -> None:
     """Test the tag_parsing function."""
     message = test_case["message"]
     expected = test_case["expected"]
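The pgvector test above also swaps `set([...])` for a set comprehension, which skips building the intermediate list and reads more directly. An equivalence check on toy data:

```python
res = [{"id": "2"}, {"id": "3"}, {"id": "2"}]

ids_legacy = set([r["id"] for r in res])  # before: temporary list, then set()
ids = {r["id"] for r in res}              # after: set comprehension

assert ids == ids_legacy == {"2", "3"}
```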
diff --git a/test/agentchat/test_assistant_agent.py b/test/agentchat/test_assistant_agent.py
index ee7f5b88bd..3e96ecee14 100755
--- a/test/agentchat/test_assistant_agent.py
+++ b/test/agentchat/test_assistant_agent.py
@@ -181,7 +181,7 @@ def test_tsp(human_input_mode="NEVER", max_consecutive_auto_reply=2):
 
     def tsp_message(sender, recipient, context):
         filename = context.get("prompt_filename", "")
-        with open(filename, "r") as f:
+        with open(filename) as f:
             prompt = f.read()
         question = context.get("question", "")
         return prompt.format(question=question)
diff --git a/test/agentchat/test_chats.py b/test/agentchat/test_chats.py
index a39162debf..6f3504bbdb 100755
--- a/test/agentchat/test_chats.py
+++ b/test/agentchat/test_chats.py
@@ -8,11 +8,10 @@
 
 import os
 import sys
-from typing import Literal
+from typing import Annotated, Literal
 
 import pytest
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
-from typing_extensions import Annotated
 
 import autogen
 from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent, filter_config, initiate_chats
@@ -568,7 +567,7 @@ def my_writing_task(sender, recipient, context):
 
     try:
         filename = context.get("work_dir", "") + "/stock_prices.md"
-        with open(filename, "r") as file:
+        with open(filename) as file:
             data = file.read()
     except Exception as e:
         data = f"An error occurred while reading the file: {e}"
diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py
index 93866c81a0..d43f2dba3f 100755
--- a/test/agentchat/test_conversable_agent.py
+++ b/test/agentchat/test_conversable_agent.py
@@ -13,13 +13,12 @@
 import sys
 import time
 import unittest
-from typing import Any, Callable, Dict, Literal
+from typing import Annotated, Any, Callable, Dict, Literal
 from unittest.mock import MagicMock
 
 import pytest
 from pydantic import BaseModel, Field
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
-from typing_extensions import Annotated
 
 import autogen
 from autogen.agentchat import ConversableAgent, UserProxyAgent
@@ -660,7 +659,7 @@ async def currency_calculator(
     assert inspect.iscoroutinefunction(currency_calculator)
 
 
-def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]]:
+def get_origin(d: dict[str, Callable[..., Any]]) -> dict[str, Callable[..., Any]]:
     return {k: v._origin for k, v in d.items()}
 
diff --git a/test/agentchat/test_function_and_tool_calling.py b/test/agentchat/test_function_and_tool_calling.py
index eaaea6a8a9..d3776ce206 100644
--- a/test/agentchat/test_function_and_tool_calling.py
+++ b/test/agentchat/test_function_and_tool_calling.py
@@ -203,7 +203,7 @@ async def _a_tool_func_error(arg1: str, arg2: str) -> str:
 _text_message = {"content": "Hi!", "role": "user"}
 
 
-def _get_function_map(is_function_async: bool, drop_tool_2: bool = False) -> Dict[str, Callable[..., Any]]:
+def _get_function_map(is_function_async: bool, drop_tool_2: bool = False) -> dict[str, Callable[..., Any]]:
     if is_function_async:
         return (
             {
@@ -230,7 +230,7 @@ def _get_function_map(is_function_async: bool, drop_tool_2: bool = False) -> Dic
 
 def _get_error_function_map(
     is_function_async: bool, error_on_tool_func_2: bool = True
-) -> Dict[str, Callable[..., Any]]:
+) -> dict[str, Callable[..., Any]]:
     if is_function_async:
         return {
             "_tool_func_1": _a_tool_func_1 if error_on_tool_func_2 else _a_tool_func_error,
@@ -280,7 +280,7 @@ def test_generate_function_call_reply_on_function_call_message(is_function_async
     assert (finished, retval) == (False, None)
 
     # text message
-    messages: List[Dict[str, str]] = [_text_message]
+    messages: list[dict[str, str]] = [_text_message]
     finished, retval = agent.generate_function_call_reply(messages)
     assert (finished, retval) == (False, None)
 
@@ -329,7 +329,7 @@ async def test_a_generate_function_call_reply_on_function_call_message(is_functi
     assert (finished, retval) == (False, None)
 
     # text message
-    messages: List[Dict[str, str]] = [_text_message]
+    messages: list[dict[str, str]] = [_text_message]
     finished, retval = await agent.a_generate_function_call_reply(messages)
     assert (finished, retval) == (False, None)
 
@@ -377,7 +377,7 @@ def test_generate_tool_calls_reply_on_function_call_message(is_function_async: b
     assert (finished, retval) == (False, None)
 
     # text message
-    messages: List[Dict[str, str]] = [_text_message]
+    messages: list[dict[str, str]] = [_text_message]
     finished, retval = agent.generate_tool_calls_reply(messages)
     assert (finished, retval) == (False, None)
 
@@ -426,7 +426,7 @@ async def test_a_generate_tool_calls_reply_on_function_call_message(is_function_
     assert (finished, retval) == (False, None)
 
     # text message
-    messages: List[Dict[str, str]] = [_text_message]
+    messages: list[dict[str, str]] = [_text_message]
     finished, retval = await agent.a_generate_tool_calls_reply(messages)
     assert (finished, retval) == (False, None)
 
diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py
index de5f8bab11..a3063011a9 100755
--- a/test/agentchat/test_groupchat.py
+++ b/test/agentchat/test_groupchat.py
@@ -602,9 +602,7 @@ def test_init_default_parameters():
     agents = [autogen.ConversableAgent(name=f"Agent{i}", llm_config=False) for i in range(3)]
     group_chat = GroupChat(agents=agents, messages=[], max_round=3)
     for agent in agents:
-        assert set([a.name for a in group_chat.allowed_speaker_transitions_dict[agent]]) == set(
-            [a.name for a in agents]
-        )
+        assert {a.name for a in group_chat.allowed_speaker_transitions_dict[agent]} == {a.name for a in agents}
 
 
 def test_graph_parameters():
@@ -889,7 +887,7 @@ def agent(name: str) -> autogen.ConversableAgent:
         llm_config=False,
     )
 
-    def team(members: List[autogen.Agent], name: str) -> autogen.Agent:
+    def team(members: list[autogen.Agent], name: str) -> autogen.Agent:
         gc = autogen.GroupChat(agents=members, messages=[])
         return autogen.GroupChatManager(groupchat=gc, name=name, llm_config=False)
 
@@ -963,7 +961,7 @@ def test_nested_teams_chat():
     team1_msg = {"content": "Hello from team 1"}
     team2_msg = {"content": "Hello from team 2"}
 
-    def agent(name: str, auto_reply: Optional[Dict[str, Any]] = None) -> autogen.ConversableAgent:
+    def agent(name: str, auto_reply: Optional[dict[str, Any]] = None) -> autogen.ConversableAgent:
         return autogen.ConversableAgent(
             name=name,
             max_consecutive_auto_reply=10,
@@ -972,7 +970,7 @@ def agent(name: str, auto_reply: Optional[Dict[str, Any]] = None) -> autogen.Con
             default_auto_reply=auto_reply,
         )
 
-    def team(name: str, auto_reply: Optional[Dict[str, Any]] = None) -> autogen.ConversableAgent:
+    def team(name: str, auto_reply: Optional[dict[str, Any]] = None) -> autogen.ConversableAgent:
         member1 = agent(f"member1_{name}", auto_reply=auto_reply)
         member2 = agent(f"member2_{name}", auto_reply=auto_reply)
 
diff --git a/test/agentchat/test_nested.py b/test/agentchat/test_nested.py
index 24c86a7fed..f47371628f 100755
--- a/test/agentchat/test_nested.py
+++ b/test/agentchat/test_nested.py
@@ -22,7 +22,7 @@
 
 
 class MockAgentReplies(AgentCapability):
-    def __init__(self, mock_messages: List[str]):
+    def __init__(self, mock_messages: list[str]):
         self.mock_messages = mock_messages
         self.mock_message_index = 0
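As the tool-calling tests above show, local variable annotations take the built-in generics too, with `Optional` retained inside them where the project still supports pre-3.10 runtimes. A compact, runnable sketch:

```python
from typing import Any, Optional

# Local annotations use the built-in generics on Python 3.9+.
messages: list[dict[str, str]] = [{"content": "Hi!", "role": "user"}]
full_tool_calls: list[Optional[dict[str, Any]]] = [None, None]

messages.append({"content": "Hello!", "role": "assistant"})
print(len(messages), full_tool_calls)
```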
diff --git a/test/agentchat/test_structured_output.py b/test/agentchat/test_structured_output.py
index d99b2a63d6..4c1c671cb7 100644
--- a/test/agentchat/test_structured_output.py
+++ b/test/agentchat/test_structured_output.py
@@ -73,7 +73,7 @@ class Step(BaseModel):
 
 
 class MathReasoning(BaseModel):
-    steps: List[Step]
+    steps: list[Step]
     final_answer: str
 
     def format(self) -> str:
diff --git a/test/coding/test_embedded_ipython_code_executor.py b/test/coding/test_embedded_ipython_code_executor.py
index e009981779..df0d315161 100644
--- a/test/coding/test_embedded_ipython_code_executor.py
+++ b/test/coding/test_embedded_ipython_code_executor.py
@@ -61,7 +61,7 @@ def test_is_code_executor(cls) -> None:
 
 @pytest.mark.skipif(skip, reason=skip_reason)
 def test_create_dict() -> None:
-    config: Dict[str, Union[str, CodeExecutor]] = {"executor": "ipython-embedded"}
+    config: dict[str, Union[str, CodeExecutor]] = {"executor": "ipython-embedded"}
     executor = CodeExecutorFactory.create(config)
     assert isinstance(executor, EmbeddedIPythonCodeExecutor)
 
@@ -190,7 +190,7 @@ def test_save_image(cls) -> None:
 
 @pytest.mark.skipif(skip, reason=skip_reason)
 @pytest.mark.parametrize("cls", classes_to_test)
-def test_timeout_preserves_kernel_state(cls: Type[CodeExecutor]) -> None:
+def test_timeout_preserves_kernel_state(cls: type[CodeExecutor]) -> None:
     executor = cls(timeout=1)
     code_blocks = [CodeBlock(code="x = 123", language="python")]
     code_result = executor.execute_code_blocks(code_blocks)
diff --git a/test/interop/langchain/test_langchain.py b/test/interop/langchain/test_langchain.py
index be0a2f6bfc..fe129a6e1c 100644
--- a/test/interop/langchain/test_langchain.py
+++ b/test/interop/langchain/test_langchain.py
@@ -8,16 +8,11 @@
 
 import pytest
 from conftest import reason, skip_openai
+from langchain.tools import tool as langchain_tool
 from pydantic import BaseModel, Field
 
 from autogen import AssistantAgent, UserProxyAgent
 from autogen.interop import Interoperable
-
-if sys.version_info >= (3, 9):
-    from langchain.tools import tool as langchain_tool
-else:
-    langchain_tool = unittest.mock.MagicMock()
-
 from autogen.interop.langchain import LangChainInteroperability
diff --git a/test/interop/pydantic_ai/test_pydantic_ai.py b/test/interop/pydantic_ai/test_pydantic_ai.py
index 2840cdbc9a..605e40923b 100644
--- a/test/interop/pydantic_ai/test_pydantic_ai.py
+++ b/test/interop/pydantic_ai/test_pydantic_ai.py
@@ -12,18 +12,11 @@
 
 import pytest
 from conftest import reason, skip_openai
 from pydantic import BaseModel
+from pydantic_ai import RunContext
+from pydantic_ai.tools import Tool as PydanticAITool
 
 from autogen import AssistantAgent, UserProxyAgent
 from autogen.interop import Interoperable
-
-if sys.version_info >= (3, 9):
-    from pydantic_ai import RunContext
-    from pydantic_ai.tools import Tool as PydanticAITool
-
-else:
-    RunContext = unittest.mock.MagicMock()
-    PydanticAITool = unittest.mock.MagicMock()
-
 from autogen.interop.pydantic_ai import PydanticAIInteroperability
@@ -104,7 +97,7 @@ def f(
         tool=pydantic_ai_tool,
     )
     assert list(signature(g).parameters.keys()) == ["city", "date"]
-    kwargs: Dict[str, Any] = {"city": "Zagreb", "date": "2021-01-01"}
+    kwargs: dict[str, Any] = {"city": "Zagreb", "date": "2021-01-01"}
     assert g(**kwargs) == "Zagreb 2021-01-01 123"
 
 def test_dependency_injection_with_retry(self) -> None:
diff --git a/test/interop/pydantic_ai/test_pydantic_ai_tool.py b/test/interop/pydantic_ai/test_pydantic_ai_tool.py
index f1ae38389e..0f4eb92577 100644
--- a/test/interop/pydantic_ai/test_pydantic_ai_tool.py
+++ b/test/interop/pydantic_ai/test_pydantic_ai_tool.py
@@ -6,15 +6,9 @@
 import unittest
 
 import pytest
+from pydantic_ai.tools import Tool as PydanticAITool
 
 from autogen import AssistantAgent
-
-if sys.version_info >= (3, 9):
-    from pydantic_ai.tools import Tool as PydanticAITool
-
-else:
-    PydanticAITool = unittest.mock.MagicMock()
-
 from autogen.interop.pydantic_ai.pydantic_ai_tool import PydanticAITool as AG2PydanticAITool
diff --git a/test/io/test_base.py b/test/io/test_base.py
index 8083f0d811..c4c77d8f66 100644
--- a/test/io/test_base.py
+++ b/test/io/test_base.py
@@ -35,9 +35,9 @@ def input(self, prompt: str = "", *, password: bool = False) -> str:
         assert isinstance(IOStream.get_default(), IOConsole)
 
     def test_get_default_on_new_thread(self) -> None:
-        exceptions: List[Exception] = []
+        exceptions: list[Exception] = []
 
-        def on_new_thread(exceptions: List[Exception] = exceptions) -> None:
+        def on_new_thread(exceptions: list[Exception] = exceptions) -> None:
             try:
                 assert isinstance(IOStream.get_default(), IOConsole)
             except Exception as e:
diff --git a/test/io/test_websockets.py b/test/io/test_websockets.py
index 6c4b4662e3..1c84eebc79 100644
--- a/test/io/test_websockets.py
+++ b/test/io/test_websockets.py
@@ -92,7 +92,7 @@ def test_chat(self) -> None:
 
         success_dict = {"success": False}
 
-        def on_connect(iostream: IOWebsockets, success_dict: Dict[str, bool] = success_dict) -> None:
+        def on_connect(iostream: IOWebsockets, success_dict: dict[str, bool] = success_dict) -> None:
             print(f" - on_connect(): Connected to client using IOWebsockets {iostream}", flush=True)
             print(" - on_connect(): Receiving message from client.", flush=True)
 
diff --git a/test/oai/test_client_stream.py b/test/oai/test_client_stream.py
index abb7e18c72..1ce72e0e77 100755
--- a/test/oai/test_client_stream.py
+++ b/test/oai/test_client_stream.py
@@ -68,13 +68,13 @@ def test_chat_completion_stream() -> None:
 
 def test__update_dict_from_chunk() -> None:
     # dictionaries and lists are not supported
     mock = MagicMock()
-    empty_collections: List[Union[List[Any], Dict[str, Any]]] = [{}, []]
+    empty_collections: list[Union[list[Any], dict[str, Any]]] = [{}, []]
     for c in empty_collections:
         mock.c = c
         with pytest.raises(NotImplementedError):
             OpenAIWrapper._update_dict_from_chunk(mock, {}, "c")
 
-    org_d: Dict[str, Any] = {}
+    org_d: dict[str, Any] = {}
     for i, v in enumerate([0, 1, False, True, 0.0, 1.0]):
         field = "abcedfghijklmnopqrstuvwxyz"[i]
         setattr(mock, field, v)
@@ -186,7 +186,7 @@ def test__update_tool_calls_from_chunk() -> None:
         ),
     ]
 
-    full_tool_calls: List[Optional[Dict[str, Any]]] = [None, None]
+    full_tool_calls: list[Optional[dict[str, Any]]] = [None, None]
     completion_tokens = 0
     for tool_calls_chunk in tool_calls_chunks:
         index = tool_calls_chunk.index
diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py
index 5976b7a46f..0e7fea224d 100644
--- a/test/oai/test_custom_client.py
+++ b/test/oai/test_custom_client.py
@@ -28,7 +28,7 @@ def test_custom_model_client():
     TEST_MAX_LENGTH = 1000
 
     class CustomModel:
-        def __init__(self, config: Dict, test_hook):
+        def __init__(self, config: dict, test_hook):
             self.test_hook = test_hook
             self.device = config["device"]
             self.model = config["model"]
@@ -63,7 +63,7 @@ def cost(self, response) -> float:
             return TEST_COST
 
         @staticmethod
-        def get_usage(response) -> Dict:
+        def get_usage(response) -> dict:
             return {}
 
     config_list = [
@@ -96,7 +96,7 @@ def get_usage(response) -> Dict:
 
 def test_registering_with_wrong_class_name_raises_error():
     class CustomModel:
-        def __init__(self, config: Dict):
+        def __init__(self, config: dict):
             pass
 
         def create(self, params):
@@ -109,7 +109,7 @@ def cost(self, response) -> float:
             return 0
 
         @staticmethod
-        def get_usage(response) -> Dict:
+        def get_usage(response) -> dict:
             return {}
 
     config_list = [
@@ -126,7 +126,7 @@ def get_usage(response) -> Dict:
 
 def test_not_all_clients_registered_raises_error():
     class CustomModel:
-        def __init__(self, config: Dict):
+        def __init__(self, config: dict):
             pass
 
         def create(self, params):
@@ -139,7 +139,7 @@ def cost(self, response) -> float:
             return 0
 
        @staticmethod
-        def get_usage(response) -> Dict:
+        def get_usage(response) -> dict:
             return {}
 
     config_list = [
@@ -173,7 +173,7 @@ def get_usage(response) -> Dict:
 
 def test_registering_with_extra_config_args():
     class CustomModel:
-        def __init__(self, config: Dict, test_hook):
+        def __init__(self, config: dict, test_hook):
             self.test_hook = test_hook
             self.test_hook["called"] = True
 
@@ -192,7 +192,7 @@ def cost(self, response) -> float:
             return 0
 
         @staticmethod
-        def get_usage(response) -> Dict:
+        def get_usage(response) -> dict:
             return {}
 
     config_list = [
diff --git a/test/oai/test_utils.py b/test/oai/test_utils.py
index 599254b47d..1f8d1f4855 100755
--- a/test/oai/test_utils.py
+++ b/test/oai/test_utils.py
@@ -96,7 +96,7 @@
 ]
 
 
-def _compare_lists_of_dicts(list1: List[Dict], list2: List[Dict]) -> bool:
+def _compare_lists_of_dicts(list1: list[dict], list2: list[dict]) -> bool:
     dump1 = sorted(json.dumps(d, sort_keys=True) for d in list1)
     dump2 = sorted(json.dumps(d, sort_keys=True) for d in list2)
     return dump1 == dump2
diff --git a/test/test_function_utils.py b/test/test_function_utils.py
index fce7e819b8..7563979d57 100644
--- a/test/test_function_utils.py
+++ b/test/test_function_utils.py
@@ -7,11 +7,10 @@
 import asyncio
 import inspect
 import unittest.mock
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple
 
 import pytest
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from autogen._pydantic import PYDANTIC_V1, model_dump
 from autogen.function_utils import (
@@ -40,7 +39,7 @@ def g(  # type: ignore[empty-body]
     b: int = 2,
     c: Annotated[float, "Parameter c"] = 0.1,
     *,
-    d: Dict[str, Tuple[Optional[int], List[float]]],
+    d: dict[str, tuple[Optional[int], list[float]]],
 ) -> str:
     pass
 
@@ -50,7 +49,7 @@ async def a_g(  # type: ignore[empty-body]
     b: int = 2,
     c: Annotated[float, "Parameter c"] = 0.1,
     *,
-    d: Dict[str, Tuple[Optional[int], List[float]]],
+    d: dict[str, tuple[Optional[int], list[float]]],
 ) -> str:
     pass
 
@@ -89,7 +88,7 @@ class B(BaseModel):
         b: float
         c: str
 
-    expected: Dict[str, Any] = {
+    expected: dict[str, Any] = {
         "description": "b",
         "properties": {"b": {"title": "B", "type": "number"}, "c": {"title": "C", "type": "string"}},
         "required": ["b", "c"],
@@ -367,7 +366,7 @@ def test_load_basemodels_if_needed_sync() -> None:
     def f(
         base: Annotated[Currency, "Base currency"],
         quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
-    ) -> Tuple[Currency, CurrencySymbol]:
+    ) -> tuple[Currency, CurrencySymbol]:
         return base, quote_currency
 
     assert not inspect.iscoroutinefunction(f)
@@ -385,7 +384,7 @@ async def test_load_basemodels_if_needed_async() -> None:
     async def f(
         base: Annotated[Currency, "Base currency"],
         quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
-    ) -> Tuple[Currency, CurrencySymbol]:
+    ) -> tuple[Currency, CurrencySymbol]:
         return base, quote_currency
 
     assert inspect.iscoroutinefunction(f)
diff --git a/test/test_logging.py b/test/test_logging.py
index ca2db497ee..f055850953 100644
--- a/test/test_logging.py
+++ b/test/test_logging.py
@@ -264,7 +264,7 @@ def __init__(self):
             self.extra_key = "remove this key"
             self.path = Path("/to/something")
 
-    class Bar(object):
+    class Bar:
         def init(self):
             pass
 
diff --git a/test/test_pydantic.py b/test/test_pydantic.py
index 256b30e335..0006a605b0 100644
--- a/test/test_pydantic.py
+++ b/test/test_pydantic.py
@@ -4,10 +4,9 @@
 #
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Annotated, Dict, List, Optional, Tuple, Union
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from autogen._pydantic import model_dump, model_dump_json, type2schema
 
@@ -19,14 +18,14 @@ def test_type2schema() -> None:
     assert type2schema(bool) == {"type": "boolean"}
     assert type2schema(None) == {"type": "null"}
     assert type2schema(Optional[int]) == {"anyOf": [{"type": "integer"}, {"type": "null"}]}
-    assert type2schema(List[int]) == {"items": {"type": "integer"}, "type": "array"}
-    assert type2schema(Tuple[int, float, str]) == {
+    assert type2schema(list[int]) == {"items": {"type": "integer"}, "type": "array"}
+    assert type2schema(tuple[int, float, str]) == {
         "maxItems": 3,
         "minItems": 3,
         "prefixItems": [{"type": "integer"}, {"type": "number"}, {"type": "string"}],
         "type": "array",
     }
-    assert type2schema(Dict[str, int]) == {"additionalProperties": {"type": "integer"}, "type": "object"}
+    assert type2schema(dict[str, int]) == {"additionalProperties": {"type": "integer"}, "type": "object"}
     assert type2schema(Annotated[str, "some text"]) == {"type": "string"}
     assert type2schema(Union[int, float]) == {"anyOf": [{"type": "integer"}, {"type": "number"}]}
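The `test_pydantic.py` assertions above confirm that the built-in generics produce the same JSON schemas as their `typing` counterparts. Assuming pydantic v2, where `type2schema` wraps `TypeAdapter(t).json_schema()`, this can be checked directly:

```python
from pydantic import TypeAdapter  # assumes pydantic v2

assert TypeAdapter(list[int]).json_schema() == {"items": {"type": "integer"}, "type": "array"}
assert TypeAdapter(dict[str, int]).json_schema() == {
    "additionalProperties": {"type": "integer"},
    "type": "object",
}
print("builtin generics and typing aliases yield identical schemas")
```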
diff --git a/website/process_api_reference.py b/website/process_api_reference.py
index 9652b8d7fe..3405f360fc 100644
--- a/website/process_api_reference.py
+++ b/website/process_api_reference.py
@@ -55,7 +55,7 @@ def read_file_content(file_path: str) -> str:
     Returns:
         str: Content of the file
     """
-    with open(file_path, "r", encoding="utf-8") as f:
+    with open(file_path, encoding="utf-8") as f:
         return f.read()
 
@@ -100,18 +100,18 @@ def convert_md_to_mdx(input_dir: Path) -> None:
             print(f"Converted: {md_file} -> {mdx_file}")
 
 
-def get_mdx_files(directory: Path) -> List[str]:
+def get_mdx_files(directory: Path) -> list[str]:
     """Get all MDX files in directory and subdirectories."""
     return [f"{str(p.relative_to(directory).with_suffix(''))}".replace("\\", "/") for p in directory.rglob("*.mdx")]
 
 
-def add_prefix(path: str, parent_groups: List[str] = None) -> str:
+def add_prefix(path: str, parent_groups: list[str] = None) -> str:
     """Create full path with prefix and parent groups."""
     groups = parent_groups or []
     return f"docs/reference/{'/'.join(groups + [path])}"
 
 
-def create_nav_structure(paths: List[str], parent_groups: List[str] = None) -> List[Any]:
+def create_nav_structure(paths: list[str], parent_groups: list[str] = None) -> list[Any]:
     """Convert list of file paths into nested navigation structure."""
     groups = {}
     pages = []
@@ -142,7 +142,7 @@ def create_nav_structure(paths: List[str], parent_groups: List[str] = None) -> L
     return sorted_groups + sorted_pages
 
 
-def update_nav(mint_json_path: Path, new_nav_pages: List[Any]) -> None:
+def update_nav(mint_json_path: Path, new_nav_pages: list[Any]) -> None:
     """
     Update the 'API Reference' section in mint.json navigation with new pages.
 
@@ -152,7 +152,7 @@ def update_nav(mint_json_path: Path, new_nav_pages: List[Any]) -> None:
     """
     try:
         # Read the current mint.json
-        with open(mint_json_path, "r") as f:
+        with open(mint_json_path) as f:
             mint_config = json.load(f)
 
         # Find and update the API Reference section
diff --git a/website/process_notebooks.py b/website/process_notebooks.py
index 0e5b903f77..3b276beabd 100755
--- a/website/process_notebooks.py
+++ b/website/process_notebooks.py
@@ -20,7 +20,6 @@
 import tempfile
 import threading
 import time
-import typing
 from dataclasses import dataclass
 from multiprocessing import current_process
 from pathlib import Path
@@ -81,12 +80,12 @@ def notebooks_target_dir(website_directory: Path) -> Path:
     return website_directory / "notebooks"
 
 
-def load_metadata(notebook: Path) -> typing.Dict:
+def load_metadata(notebook: Path) -> dict:
     content = json.load(notebook.open(encoding="utf-8"))
     return content["metadata"]
 
 
-def skip_reason_or_none_if_ok(notebook: Path) -> typing.Optional[str]:
+def skip_reason_or_none_if_ok(notebook: Path) -> str | None:
     """Return a reason to skip the notebook, or None if it should not be skipped."""
 
     if notebook.suffix != ".ipynb":
@@ -99,7 +98,7 @@ def skip_reason_or_none_if_ok(notebook: Path) -> typing.Optional[str]:
     if "notebook" not in notebook.parts:
         return None
 
-    with open(notebook, "r", encoding="utf-8") as f:
+    with open(notebook, encoding="utf-8") as f:
         content = f.read()
 
     # Load the json and get the first cell
@@ -139,9 +138,9 @@ def skip_reason_or_none_if_ok(notebook: Path) -> typing.Optional[str]:
     return None
 
 
-def extract_title(notebook: Path) -> Optional[str]:
+def extract_title(notebook: Path) -> str | None:
     """Extract the title of the notebook."""
-    with open(notebook, "r", encoding="utf-8") as f:
+    with open(notebook, encoding="utf-8") as f:
         content = f.read()
 
     # Load the json and get the first cell
@@ -202,9 +201,7 @@ def process_notebook(src_notebook: Path, website_dir: Path, notebook_dir: Path,
             shutil.copy(src_notebook.parent / file, dest_dir / file)
 
         # Capture output
-        result = subprocess.run(
-            [quarto_bin, "render", intermediate_notebook], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
-        )
+        result = subprocess.run([quarto_bin, "render", intermediate_notebook], capture_output=True, text=True)
         if result.returncode != 0:
             return fmt_error(
                 src_notebook, f"Failed to render {src_notebook}\n\nstderr:\n{result.stderr}\nstdout:\n{result.stdout}"
@@ -223,9 +220,7 @@ def process_notebook(src_notebook: Path, website_dir: Path, notebook_dir: Path,
         if dry_run:
             return colored(f"Would process {src_notebook.name}", "green")
 
-        result = subprocess.run(
-            [quarto_bin, "render", src_notebook], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
-        )
+        result = subprocess.run([quarto_bin, "render", src_notebook], capture_output=True, text=True)
         if result.returncode != 0:
             return fmt_error(
                 src_notebook, f"Failed to render {src_notebook}\n\nstderr:\n{result.stderr}\nstdout:\n{result.stdout}"
@@ -240,7 +235,7 @@ def process_notebook(src_notebook: Path, website_dir: Path, notebook_dir: Path,
 @dataclass
 class NotebookError:
     error_name: str
-    error_value: Optional[str]
+    error_value: str | None
     traceback: str
     cell_source: str
 
@@ -253,7 +248,7 @@ class NotebookSkip:
 NB_VERSION = 4
 
 
-def test_notebook(notebook_path: Path, timeout: int = 300) -> Tuple[Path, Optional[Union[NotebookError, NotebookSkip]]]:
+def test_notebook(notebook_path: Path, timeout: int = 300) -> tuple[Path, NotebookError | NotebookSkip | None]:
     nb = nbformat.read(str(notebook_path), NB_VERSION)
 
     if "skip_test" in nb.metadata:
@@ -285,7 +280,7 @@ def test_notebook(notebook_path: Path, timeout: int = 300) -> Tuple[Path, Option
 # Find the first code cell which did not complete.
 def get_timeout_info(
     nb: NotebookNode,
-) -> Optional[NotebookError]:
+) -> NotebookError | None:
     for i, cell in enumerate(nb.cells):
         if cell.cell_type != "code":
             continue
@@ -300,7 +295,7 @@ def get_timeout_info(
     return None
 
 
-def get_error_info(nb: NotebookNode) -> Optional[NotebookError]:
+def get_error_info(nb: NotebookNode) -> NotebookError | None:
     for cell in nb["cells"]:  # get LAST error
         if cell["cell_type"] != "code":
             continue
@@ -318,13 +313,13 @@ def get_error_info(nb: NotebookNode) -> Optional[NotebookError]:
 
 
 def add_front_matter_to_metadata_mdx(
-    front_matter: Dict[str, Union[str, List[str]]], website_dir: Path, rendered_mdx: Path
+    front_matter: dict[str, str | list[str]], website_dir: Path, rendered_mdx: Path
 ) -> None:
     metadata_mdx = website_dir / "snippets" / "data" / "NotebooksMetadata.mdx"
 
     metadata = []
     if metadata_mdx.exists():
-        with open(metadata_mdx, "r", encoding="utf-8") as f:
+        with open(metadata_mdx, encoding="utf-8") as f:
             content = f.read()
             if content:
                 start = content.find("export const notebooksMetadata = [")
@@ -384,8 +379,8 @@ def resolve_path(match):
 
 
 # rendered_notebook is the final mdx file
-def post_process_mdx(rendered_mdx: Path, source_notebooks: Path, front_matter: Dict, website_dir: Path) -> None:
-    with open(rendered_mdx, "r", encoding="utf-8") as f:
+def post_process_mdx(rendered_mdx: Path, source_notebooks: Path, front_matter: dict, website_dir: Path) -> None:
+    with open(rendered_mdx, encoding="utf-8") as f:
         content = f.read()
 
     # If there is front matter in the mdx file, we need to remove it
@@ -465,7 +460,7 @@ def path(path_str: str) -> Path:
     return Path(path_str)
 
 
-def collect_notebooks(notebook_directory: Path, website_directory: Path) -> typing.List[Path]:
+def collect_notebooks(notebook_directory: Path, website_directory: Path) -> list[Path]:
     notebooks = list(notebook_directory.glob("*.ipynb"))
     notebooks.extend(list(website_directory.glob("docs/**/*.ipynb")))
     return notebooks
@@ -479,7 +474,7 @@ def fmt_ok(notebook: Path) -> str:
     return f"{colored('[OK]', 'green')} {colored(notebook.name, 'blue')} ✅"
 
 
-def fmt_error(notebook: Path, error: Union[NotebookError, str]) -> str:
+def fmt_error(notebook: Path, error: NotebookError | str) -> str:
     if isinstance(error, str):
         return f"{colored('[Error]', 'red')} {colored(notebook.name, 'blue')}: {error}"
     elif isinstance(error, NotebookError):
@@ -538,11 +533,11 @@ def update_navigation_with_notebooks(website_dir: Path) -> None:
         return
 
     # Read mint.json
-    with open(mint_json_path, "r", encoding="utf-8") as f:
+    with open(mint_json_path, encoding="utf-8") as f:
         mint_config = json.load(f)
 
     # Read NotebooksMetadata.mdx and extract metadata links
-    with open(metadata_path, "r", encoding="utf-8") as f:
+    with open(metadata_path, encoding="utf-8") as f:
         content = f.read()
         # Extract the array between the brackets
         start = content.find("export const notebooksMetadata = [")
@@ -622,7 +617,7 @@ def fix_internal_references_in_mdx_files(website_dir: Path) -> None:
     """Process all MDX files in directory to fix internal references."""
     for file_path in website_dir.glob("**/*.mdx"):
         try:
-            with open(file_path, "r", encoding="utf-8") as f:
+            with open(file_path, encoding="utf-8") as f:
                 content = f.read()
 
             fixed_content = fix_internal_references(content, website_dir, file_path)