From 52973fa8fd8ea8c02cc5a93f4199c73725a536aa Mon Sep 17 00:00:00 2001
From: Sternakt
Date: Wed, 4 Dec 2024 21:14:12 +0100
Subject: [PATCH] Fix caching issues with response_format, fix function
 registration issues when using response_format

Signed-off-by: Sternakt
---
 autogen/agentchat/conversable_agent.py      |  11 +-
 autogen/oai/anthropic.py                    |   8 +-
 autogen/oai/bedrock.py                      |   9 +-
 autogen/oai/cerebras.py                     |  14 ++-
 autogen/oai/client.py                       | 104 ++++++++------
 autogen/oai/cohere.py                       |   8 +-
 autogen/oai/gemini.py                       |   8 +-
 autogen/oai/groq.py                         |   8 +-
 autogen/oai/mistral.py                      |   8 +-
 autogen/oai/ollama.py                       |   7 +-
 autogen/oai/together.py                     |   8 +-
 notebook/agentchat_structured_outputs.ipynb |  35 ++++++-
 test/agentchat/test_structured_output.py    |  12 +--
 test/oai/test_custom_client.py              |   4 +-
 14 files changed, 127 insertions(+), 117 deletions(-)

diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index 5d8bbdca8d..98b0d22d8d 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -274,7 +274,11 @@ def _validate_llm_config(self, llm_config):
             raise ValueError(
                 "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
             )
-        self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config)
+        self.client = (
+            None
+            if self.llm_config is False
+            else OpenAIWrapper(response_format=self._response_format, **self.llm_config)
+        )
 
     @staticmethod
     def _is_silent(agent: Agent, silent: Optional[bool] = False) -> bool:
@@ -1452,7 +1456,6 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
             messages=all_messages,
             cache=cache,
             agent=self,
-            response_format=self._response_format,
         )
 
         extracted_response = llm_client.extract_text_or_completion_object(response)[0]
@@ -2549,7 +2552,7 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)
             if len(self.llm_config["functions"]) == 0:
                 del self.llm_config["functions"]
 
-        self.client = OpenAIWrapper(**self.llm_config)
+        self.client = OpenAIWrapper(**self.llm_config, response_format=self._response_format)
 
     def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
         """update a tool_signature in the LLM configuration for tool_call.
@@ -2593,7 +2596,7 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
             if len(self.llm_config["tools"]) == 0:
                 del self.llm_config["tools"]
 
-        self.client = OpenAIWrapper(**self.llm_config)
+        self.client = OpenAIWrapper(**self.llm_config, response_format=self._response_format)
 
     def can_execute_function(self, name: Union[List[str], str]) -> bool:
         """Whether the agent can execute the function."""
diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py
index 5e367ef31a..3bfaa75309 100644
--- a/autogen/oai/anthropic.py
+++ b/autogen/oai/anthropic.py
@@ -114,6 +114,9 @@ def __init__(self, **kwargs: Any):
         ):
             raise ValueError("API key or AWS credentials are required to use the Anthropic API.")
 
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Anthropic, it will be ignored.", UserWarning)
+
         if self._api_key is not None:
             self._client = Anthropic(api_key=self._api_key)
         else:
@@ -175,10 +178,7 @@ def aws_session_token(self):
     def aws_region(self):
         return self._aws_region
 
-    def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Anthropic API.")
-
+    def create(self, params: Dict[str, Any]) -> ChatCompletion:
         if "tools" in params:
             converted_functions = self.convert_tools_to_functions(params["tools"])
             params["functions"] = params.get("functions", []) + converted_functions
diff --git a/autogen/oai/bedrock.py b/autogen/oai/bedrock.py
index 9c04b8c9a4..21253b92d1 100644
--- a/autogen/oai/bedrock.py
+++ b/autogen/oai/bedrock.py
@@ -93,6 +93,9 @@ def __init__(self, **kwargs: Any):
             profile_name=self._aws_profile_name,
         )
 
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)
+
         self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)
 
     def message_retrieval(self, response):
@@ -179,12 +182,8 @@ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str
 
         return base_params, additional_params
 
-    def create(self, params, response_format: Optional[BaseModel] = None) -> ChatCompletion:
+    def create(self, params) -> ChatCompletion:
         """Run Amazon Bedrock inference and return AutoGen response"""
-
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Amazon Bedrock's API.")
-
         # Set custom client class settings
         self.parse_custom_params(params)
diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py
index de3739767a..20e43ef59b 100644
--- a/autogen/oai/cerebras.py
+++ b/autogen/oai/cerebras.py
@@ -49,13 +49,17 @@ class CerebrasClient:
     """Client for Cerebras's API."""
 
-    def __init__(self, api_key=None, **kwargs):
+    def __init__(self, api_key=None, response_format: Optional[BaseModel] = None, **kwargs):
         """Requires api_key or environment variable to be set
 
         Args:
             api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set)
         """
         # Ensure we have the api_key upon instantiation
+
+        if response_format is not None:
+            raise NotImplementedError("Response format is not supported by Cerebras' API.")
+
         self.api_key = api_key
         if not self.api_key:
             self.api_key = os.getenv("CEREBRAS_API_KEY")
@@ -64,6 +68,9 @@ def __init__(self, api_key=None, **kwargs):
             self.api_key
         ), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."
 
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Cerebras, it will be ignored.", UserWarning)
+
     def message_retrieval(self, response: ChatCompletion) -> List:
         """
         Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -112,10 +119,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
 
         return cerebras_params
 
-    def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Cerebras' API.")
-
+    def create(self, params: Dict) -> ChatCompletion:
         messages = params.get("messages", [])
 
         # Convert AutoGen messages to Cerebras messages
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 13bb1b8c70..4b25ddff4c 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -33,6 +33,7 @@
     # raises exception if openai>=1 is installed and something is wrong with imports
     from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI
     from openai import __version__ as OPENAIVERSION
+    from openai.lib._parsing._completions import type_to_response_format_param
     from openai.resources import Completions
     from openai.types.chat import ChatCompletion
     from openai.types.chat.chat_completion import ChatCompletionMessage, Choice  # type: ignore [attr-defined]
@@ -207,9 +208,7 @@ class Message(Protocol):
         choices: List[Choice]
         model: str
 
-    def create(
-        self, params: Dict[str, Any], response_format: Optional[BaseModel] = None
-    ) -> ModelClientResponseProtocol: ...  # pragma: no cover
+    def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ...  # pragma: no cover
 
     def message_retrieval(
         self, response: ModelClientResponseProtocol
@@ -238,42 +237,9 @@ def __init__(self, config):
 class OpenAIClient:
     """Follows the Client protocol and wraps the OpenAI client."""
 
-    @staticmethod
-    def _convert_to_chat_completion(parsed: ParsedChatCompletion, response_format: BaseModel) -> ChatCompletion:
-        # Helper function to convert ParsedChatCompletionMessage to ChatCompletionMessage
-        def convert_message(parsed_message: ParsedChatCompletionMessage) -> ChatCompletionMessage:
-            return ChatCompletionMessage(
-                role=parsed_message.role,
-                content=(
-                    response_format.model_validate_json(parsed_message.content).format()
-                    if isinstance(response_format, FormatterProtocol)
-                    else parsed_message.content
-                ),
-                function_call=parsed_message.function_call,
-            )
-
-        # Convert ParsedChatCompletion to ChatCompletion
-        return ChatCompletion(
-            id=parsed.id,
-            choices=[
-                Choice(
-                    finish_reason=choice.finish_reason,
-                    index=choice.index,
-                    logprobs=choice.logprobs,
-                    message=convert_message(choice.message),  # Parse the message
-                )
-                for choice in parsed.choices
-            ],
-            created=parsed.created,
-            model=parsed.model,
-            object=parsed.object,
-            service_tier=parsed.service_tier,
-            system_fingerprint=parsed.system_fingerprint,
-            usage=parsed.usage,
-        )
-
-    def __init__(self, client: Union[OpenAI, AzureOpenAI]):
+    def __init__(self, client: Union[OpenAI, AzureOpenAI], response_format: Optional[BaseModel] = None):
         self._oai_client = client
+        self.response_format = response_format
         if (
             not isinstance(client, openai.AzureOpenAI)
             and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX)
@@ -291,22 +257,29 @@ def message_retrieval(
         if isinstance(response, Completion):
             return [choice.text for choice in choices]  # type: ignore [union-attr]
 
+        def _format_content(content: str) -> str:
+            return (
+                self.response_format.model_validate_json(content).format()
+                if isinstance(self.response_format, FormatterProtocol)
+                else content
+            )
+
         if TOOL_ENABLED:
             return [  # type: ignore [return-value]
                 (
                     choice.message  # type: ignore [union-attr]
                     if choice.message.function_call is not None or choice.message.tool_calls is not None  # type: ignore [union-attr]
-                    else choice.message.content
+                    else _format_content(choice.message.content)
                 )  # type: ignore [union-attr]
                 for choice in choices
             ]
         else:
            return [  # type: ignore [return-value]
-                choice.message if choice.message.function_call is not None else choice.message.content  # type: ignore [union-attr]
+                choice.message if choice.message.function_call is not None else _format_content(choice.message.content)  # type: ignore [union-attr]
                 for choice in choices
             ]
 
-    def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
+    def create(self, params: Dict[str, Any]) -> ChatCompletion:
         """Create a completion for a given config using openai's client.
 
         Args:
@@ -318,15 +291,13 @@ def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] =
         """
         iostream = IOStream.get_default()
 
-        if response_format is not None:
+        if self.response_format is not None:
 
             def _create_or_parse(*args, **kwargs):
                 if "stream" in kwargs:
                     kwargs.pop("stream")
-                kwargs["response_format"] = response_format
-                return OpenAIClient._convert_to_chat_completion(
-                    self._oai_client.beta.chat.completions.parse(*args, **kwargs), response_format
-                )
+                kwargs["response_format"] = type_to_response_format_param(self.response_format)
+                return self._oai_client.chat.completions.create(*args, **kwargs)
 
             create_or_parse = _create_or_parse
         else:
@@ -511,7 +482,13 @@ class OpenAIWrapper:
     total_usage_summary: Optional[Dict[str, Any]] = None
     actual_usage_summary: Optional[Dict[str, Any]] = None
 
-    def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
+    def __init__(
+        self,
+        *,
+        config_list: Optional[List[Dict[str, Any]]] = None,
+        response_format: Optional[BaseModel] = None,
+        **base_config: Any,
+    ):
         """
         Args:
             config_list: a list of config dicts to override the base_config.
@@ -552,6 +529,7 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
 
         self._clients: List[ModelClient] = []
         self._config_list: List[Dict[str, Any]] = []
+        self.response_format = response_format
 
         if config_list:
             config_list = [config.copy() for config in config_list]  # make a copy before modifying
@@ -626,58 +604,58 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
             if api_type is not None and api_type.startswith("azure"):
                 self._configure_azure_openai(config, openai_config)
                 client = AzureOpenAI(**openai_config)
-                self._clients.append(OpenAIClient(client))
+                self._clients.append(OpenAIClient(client, response_format=self.response_format))
             elif api_type is not None and api_type.startswith("cerebras"):
                 if cerebras_import_exception:
                     raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.")
-                client = CerebrasClient(**openai_config)
+                client = CerebrasClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             elif api_type is not None and api_type.startswith("google"):
                 if gemini_import_exception:
                     raise ImportError("Please install `google-generativeai` and 'vertexai' to use Google's API.")
-                client = GeminiClient(**openai_config)
+                client = GeminiClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             elif api_type is not None and api_type.startswith("anthropic"):
                 if "api_key" not in config:
                     self._configure_openai_config_for_bedrock(config, openai_config)
                 if anthropic_import_exception:
                     raise ImportError("Please install `anthropic` to use Anthropic API.")
-                client = AnthropicClient(**openai_config)
+                client = AnthropicClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             elif api_type is not None and api_type.startswith("mistral"):
                 if mistral_import_exception:
                     raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
-                client = MistralAIClient(**openai_config)
+                client = MistralAIClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             elif api_type is not None and api_type.startswith("together"):
                 if together_import_exception:
                     raise ImportError("Please install `together` to use the Together.AI API.")
-                client = TogetherClient(**openai_config)
+                client = TogetherClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             elif api_type is not None and api_type.startswith("groq"):
                 if groq_import_exception:
                     raise ImportError("Please install `groq` to use the Groq API.")
-                client = GroqClient(**openai_config)
+                client = GroqClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             elif api_type is not None and api_type.startswith("cohere"):
                 if cohere_import_exception:
                     raise ImportError("Please install `cohere` to use the Cohere API.")
-                client = CohereClient(**openai_config)
+                client = CohereClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             elif api_type is not None and api_type.startswith("ollama"):
                 if ollama_import_exception:
                     raise ImportError("Please install `ollama` and `fix-busted-json` to use the Ollama API.")
-                client = OllamaClient(**openai_config)
+                client = OllamaClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             elif api_type is not None and api_type.startswith("bedrock"):
                 self._configure_openai_config_for_bedrock(config, openai_config)
                 if bedrock_import_exception:
                     raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
-                client = BedrockClient(**openai_config)
+                client = BedrockClient(response_format=self.response_format, **openai_config)
                 self._clients.append(client)
             else:
                 client = OpenAI(**openai_config)
-                self._clients.append(OpenAIClient(client))
+                self._clients.append(OpenAIClient(client, response_format=self.response_format))
 
             if logging_enabled():
                 log_new_client(client, self, openai_config)
@@ -756,9 +734,7 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
         ]
         return params
 
-    def create(
-        self, response_format: Optional[BaseModel] = None, **config: Any
-    ) -> ModelClient.ModelClientResponseProtocol:
+    def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
         """Make a completion for a given config using available clients.
         Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
         The config in each client will be overridden by the config.
@@ -847,8 +823,8 @@ def yes_or_no_filter(context, response):
             with cache_client as cache:
                 # Try to get the response from cache
                 key = get_key(
-                    {**params, **{"response_format": schema_json_of(response_format)}}
-                    if response_format
+                    {**params, **{"response_format": schema_json_of(self.response_format)}}
+                    if self.response_format
                     else params
                 )
                 request_ts = get_current_ts()
@@ -891,7 +867,7 @@ def yes_or_no_filter(context, response):
                     continue  # filter is not passed; try the next config
                 try:
                     request_ts = get_current_ts()
-                    response = client.create(params, response_format=response_format)
+                    response = client.create(params)
                 except APITimeoutError as err:
                     logger.debug(f"config {i} timed out", exc_info=True)
                     if i == last:
diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py
index 6334288f40..b7d411454d 100644
--- a/autogen/oai/cohere.py
+++ b/autogen/oai/cohere.py
@@ -80,6 +80,9 @@ def __init__(self, **kwargs):
             self.api_key
         ), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."
 
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Cohere, it will be ignored.", UserWarning)
+
     def message_retrieval(self, response) -> List:
         """
         Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -148,10 +151,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
 
         return cohere_params
 
-    def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Cohere's API.")
-
+    def create(self, params: Dict) -> ChatCompletion:
         messages = params.get("messages", [])
         client_name = params.get("client_name") or "autogen-cohere"  # Parse parameters to the Cohere API's parameters
diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
index 00f8662ec3..074d8587cf 100644
--- a/autogen/oai/gemini.py
+++ b/autogen/oai/gemini.py
@@ -136,6 +136,9 @@ def __init__(self, **kwargs):
             "location" not in kwargs
         ), "Google Cloud project and compute location cannot be set when using an API Key!"
 
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Gemini, it will be ignored.", UserWarning)
+
     def message_retrieval(self, response) -> List:
         """
         Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -160,10 +163,7 @@ def get_usage(response) -> Dict:
             "model": response.model,
         }
 
-    def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Gemini's API.")
-
+    def create(self, params: Dict) -> ChatCompletion:
         if self.use_vertexai:
             self._initialize_vertexai(**params)
         else:
diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py
index 39f1af1c68..72972a13e8 100644
--- a/autogen/oai/groq.py
+++ b/autogen/oai/groq.py
@@ -66,6 +66,9 @@ def __init__(self, **kwargs):
             self.api_key
         ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."
 
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Groq, it will be ignored.", UserWarning)
+
     def message_retrieval(self, response) -> List:
         """
         Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -126,10 +129,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
 
         return groq_params
 
-    def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Groq's API.")
-
+    def create(self, params: Dict) -> ChatCompletion:
         messages = params.get("messages", [])
 
         # Convert AutoGen messages to Groq messages
diff --git a/autogen/oai/mistral.py b/autogen/oai/mistral.py
index efdf420ef1..c095522410 100644
--- a/autogen/oai/mistral.py
+++ b/autogen/oai/mistral.py
@@ -71,6 +71,9 @@ def __init__(self, **kwargs):
             self.api_key
         ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."
 
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Mistral.AI, it will be ignored.", UserWarning)
+
         self._client = Mistral(api_key=self.api_key)
 
     def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]:
@@ -170,10 +173,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
 
         return mistral_params
 
-    def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Mistral's API.")
-
+    def create(self, params: Dict[str, Any]) -> ChatCompletion:
         # 1. Parse parameters to Mistral.AI API's parameters
         mistral_params = self.parse_params(params)
 
diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py
index 2fc065d701..1b7c3ced79 100644
--- a/autogen/oai/ollama.py
+++ b/autogen/oai/ollama.py
@@ -85,6 +85,8 @@ def __init__(self, **kwargs):
         Args:
             None
         """
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Ollama, it will be ignored.", UserWarning)
 
     def message_retrieval(self, response) -> List:
         """
@@ -178,10 +180,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
 
         return ollama_params
 
-    def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Ollama's API.")
-
+    def create(self, params: Dict) -> ChatCompletion:
         messages = params.get("messages", [])
 
         # Are tools involved in this conversation?
diff --git a/autogen/oai/together.py b/autogen/oai/together.py
index 3f6d0289b3..2a228450e0 100644
--- a/autogen/oai/together.py
+++ b/autogen/oai/together.py
@@ -60,6 +60,9 @@ def __init__(self, **kwargs):
         if not self.api_key:
             self.api_key = os.getenv("TOGETHER_API_KEY")
 
+        if "response_format" in kwargs and kwargs["response_format"] is not None:
+            warnings.warn("response_format is not supported for Together.AI, it will be ignored.", UserWarning)
+
         assert (
             self.api_key
         ), "Please include the api_key in your config list entry for Together.AI or set the TOGETHER_API_KEY env variable."
@@ -130,10 +133,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
 
         return together_params
 
-    def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Together AI's API.")
-
+    def create(self, params: Dict) -> ChatCompletion:
         messages = params.get("messages", [])
 
         # Convert AutoGen messages to Together.AI messages
diff --git a/notebook/agentchat_structured_outputs.ipynb b/notebook/agentchat_structured_outputs.ipynb
index 3f5fc464b2..3b023cecdd 100644
--- a/notebook/agentchat_structured_outputs.ipynb
+++ b/notebook/agentchat_structured_outputs.ipynb
@@ -210,8 +210,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "llm_config = {\"config_list\": config_list, \"cache_seed\": 42}\n",
-    "\n",
     "user_proxy = autogen.UserProxyAgent(\n",
     "    name=\"User_proxy\",\n",
     "    system_message=\"A human admin.\",\n",
@@ -224,8 +222,39 @@
     "    response_format=MathReasoning,\n",
     ")\n",
     "\n",
-    "user_proxy.initiate_chat(assistant, message=\"how can I solve 4x + 8 = -40\", max_turns=1, summary_method=\"last_msg\")"
+    "user_proxy.initiate_chat(assistant, message=\"how can I solve 8x + 7 = -23\", max_turns=1, summary_method=\"last_msg\")"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Function calling still works alongside structured output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@assistant.register_for_execution()\n",
+    "@assistant.register_for_llm(description=\"You can use this function call to solve addition\")\n",
+    "def add(x: int, y: int) -> int:\n",
+    "    return x + y\n",
+    "\n",
+    "\n",
+    "user_proxy.initiate_chat(\n",
+    "    assistant, message=\"solve 3 + 4 by calling appropriate function\", max_turns=1, summary_method=\"last_msg\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
 "metadata": {
diff --git a/test/agentchat/test_structured_output.py b/test/agentchat/test_structured_output.py
index 614cb8f604..4b0eca3e37 100644
--- a/test/agentchat/test_structured_output.py
+++ b/test/agentchat/test_structured_output.py
@@ -11,7 +11,7 @@
 from unittest.mock import MagicMock
 
 import pytest
-from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
+from openai.types.chat.parsed_chat_completion import ChatCompletion, ChatCompletionMessage, Choice
 from pydantic import BaseModel, ValidationError
 
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
@@ -37,7 +37,7 @@ class ResponseModel(BaseModel):
         reasoning: str
         difficulty: float
 
-    llm_config = {"config_list": config_list, "cache_seed": 41}
+    llm_config = {"config_list": config_list, "cache_seed": 43}
 
     user_proxy = autogen.UserProxyAgent(
         name="User_proxy",
@@ -85,7 +85,7 @@ def format(self) -> str:
 def mock_assistant():
     """Set up a mocked AssistantAgent with a predefined response format."""
     config_list = [{"model": "gpt-4o", "api_key": MOCK_OPEN_AI_API_KEY}]
-    llm_config = {"config_list": config_list, "cache_seed": 41}
+    llm_config = {"config_list": config_list, "cache_seed": 43}
 
     assistant = autogen.AssistantAgent(
         name="Assistant",
@@ -94,16 +94,16 @@ def mock_assistant():
     )
 
     oai_client_mock = MagicMock()
-    oai_client_mock.beta.chat.completions.parse.return_value = ParsedChatCompletion[MathReasoning](
+    oai_client_mock.chat.completions.create.return_value = ChatCompletion(
        id="some-id",
         created=1733302346,
         model="gpt-4o",
         object="chat.completion",
         choices=[
-            ParsedChoice[MathReasoning](
+            Choice(
                 finish_reason="stop",
                 index=0,
-                message=ParsedChatCompletionMessage[MathReasoning](
+                message=ChatCompletionMessage(
                     content='{"steps":[{"explanation":"some explanation","output":"some output"}],"final_answer":"final answer"}',
                     role="assistant",
                 ),
diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py
index 9e05b8a606..5976b7a46f 100644
--- a/test/oai/test_custom_client.py
+++ b/test/oai/test_custom_client.py
@@ -41,7 +41,7 @@ def __init__(self, config: Dict, test_hook):
         self.test_hook["other_params"] = self.other_params
         self.test_hook["max_length"] = self.max_length
 
-    def create(self, params, response_format):
+    def create(self, params):
         from types import SimpleNamespace
 
         response = SimpleNamespace()
@@ -177,7 +177,7 @@ def __init__(self, config: Dict, test_hook):
         self.test_hook = test_hook
         self.test_hook["called"] = True
 
-    def create(self, params, response_format):
+    def create(self, params):
         from types import SimpleNamespace
 
         response = SimpleNamespace()
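---

Usage note (illustrative sketch, not part of the patch): after this change, response_format is bound once when OpenAIWrapper is constructed, cache keys include the schema via schema_json_of(self.response_format), and create() no longer takes a per-call response_format argument. A minimal sketch of the new flow; the model name and API key below are placeholders, and the optional format() hook is only honored because message_retrieval() checks the response model against FormatterProtocol:

    from pydantic import BaseModel

    from autogen import OpenAIWrapper


    class Answer(BaseModel):
        reasoning: str
        result: str

        # Optional: when the response model defines format(), the wrapper's
        # message_retrieval() returns this string instead of the raw JSON.
        def format(self) -> str:
            return f"{self.result} (because {self.reasoning})"


    # response_format is passed at construction time, not per create() call.
    wrapper = OpenAIWrapper(
        response_format=Answer,
        config_list=[{"model": "gpt-4o", "api_key": "sk-placeholder"}],
    )
    response = wrapper.create(messages=[{"role": "user", "content": "What is 2 + 2?"}])
    print(wrapper.extract_text_or_completion_object(response)[0])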
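Custom model clients follow the same simplification: the ModelClient protocol's create() now takes only params. A minimal conforming stub, with illustrative names modeled on the updated test/oai/test_custom_client.py:

    from types import SimpleNamespace


    class MyCustomClient:
        def __init__(self, config, **kwargs):
            self.model = config["model"]

        def create(self, params):
            # No response_format parameter anymore; a client that cannot honor
            # a response format is expected to warn and ignore it instead.
            message = SimpleNamespace(content="hello", function_call=None, tool_calls=None)
            response = SimpleNamespace()
            response.choices = [SimpleNamespace(message=message)]
            response.model = self.model
            return response

        def message_retrieval(self, response):
            return [choice.message.content for choice in response.choices]

        def cost(self, response) -> float:
            return 0.0

        @staticmethod
        def get_usage(response):
            return {}

Such a class would then be attached the usual way, via agent.register_model_client(model_client_cls=MyCustomClient).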