diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index b738e6821d..f1fadcca27 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -86,7 +86,6 @@ def __init__( chat_messages: Optional[Dict[Agent, List[Dict]]] = None, silent: Optional[bool] = None, context_variables: Optional[Dict[str, Any]] = None, - response_format: Optional[BaseModel] = None, ): """ Args: @@ -141,7 +140,6 @@ def __init__( Note: Will maintain a reference to the passed in context variables (enabling a shared context) Only used in Swarms at this stage: https://ag2ai.github.io/ag2/docs/reference/agentchat/contrib/swarm_agent - response_format (BaseModel): Used to specify structured response format for the agent. Currently only available for the OpenAI client. """ # we change code_execution_config below and we have to make sure we don't change the input # in case of UserProxyAgent, without this we could even change the default value {} @@ -164,7 +162,6 @@ def __init__( else (lambda x: content_str(x.get("content")) == "TERMINATE") ) self.silent = silent - self._response_format = response_format # Take a copy to avoid modifying the given dict if isinstance(llm_config, dict): try: @@ -1498,7 +1495,6 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[ messages=all_messages, cache=cache, agent=self, - response_format=self._response_format, ) extracted_response = llm_client.extract_text_or_completion_object(response)[0] diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py index 8f10f203e1..3d9e1ed8fc 100644 --- a/autogen/oai/anthropic.py +++ b/autogen/oai/anthropic.py @@ -116,6 +116,9 @@ def __init__(self, **kwargs: Any): ): raise ValueError("API key or AWS credentials are required to use the Anthropic API.") + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Anthropic, it will be ignored.", UserWarning) + if self._api_key is not None: self._client = Anthropic(api_key=self._api_key) else: @@ -177,10 +180,7 @@ def aws_session_token(self): def aws_region(self): return self._aws_region - def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion: - if response_format is not None: - raise NotImplementedError("Response format is not supported by Anthropic API.") - + def create(self, params: Dict[str, Any]) -> ChatCompletion: if "tools" in params: converted_functions = self.convert_tools_to_functions(params["tools"]) params["functions"] = params.get("functions", []) + converted_functions diff --git a/autogen/oai/bedrock.py b/autogen/oai/bedrock.py index 9c04b8c9a4..21253b92d1 100644 --- a/autogen/oai/bedrock.py +++ b/autogen/oai/bedrock.py @@ -93,6 +93,9 @@ def __init__(self, **kwargs: Any): profile_name=self._aws_profile_name, ) + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning) + self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config) def message_retrieval(self, response): @@ -179,12 +182,8 @@ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str return base_params, additional_params - def create(self, params, response_format: Optional[BaseModel] = None) -> ChatCompletion: + def create(self, params) -> ChatCompletion: """Run Amazon Bedrock inference and return AutoGen response""" - - if response_format is not None: - raise NotImplementedError("Response format is not supported by Amazon Bedrock's API.") - # Set custom client class settings self.parse_custom_params(params) diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py index de3739767a..7c02afcdca 100644 --- a/autogen/oai/cerebras.py +++ b/autogen/oai/cerebras.py @@ -64,6 +64,9 @@ def __init__(self, api_key=None, **kwargs): self.api_key ), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable." + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Crebras, it will be ignored.", UserWarning) + def message_retrieval(self, response: ChatCompletion) -> List: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -112,10 +115,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return cerebras_params - def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion: - if response_format is not None: - raise NotImplementedError("Response format is not supported by Cerebras' API.") - + def create(self, params: Dict) -> ChatCompletion: messages = params.get("messages", []) # Convert AutoGen messages to Cerebras messages diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 60ead265fc..de83c7c1b0 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -10,7 +10,7 @@ import logging import sys import uuid -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union, runtime_checkable from pydantic import BaseModel, schema_json_of @@ -33,6 +33,7 @@ # raises exception if openai>=1 is installed and something is wrong with imports from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI from openai import __version__ as OPENAIVERSION + from openai.lib._parsing._completions import type_to_response_format_param from openai.resources import Completions from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined] @@ -207,9 +208,7 @@ class Message(Protocol): choices: List[Choice] model: str - def create( - self, params: Dict[str, Any], response_format: Optional[BaseModel] = None - ) -> ModelClientResponseProtocol: ... # pragma: no cover + def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover def message_retrieval( self, response: ModelClientResponseProtocol @@ -238,38 +237,9 @@ def __init__(self, config): class OpenAIClient: """Follows the Client protocol and wraps the OpenAI client.""" - @staticmethod - def _convert_to_chat_completion(parsed: ParsedChatCompletion) -> ChatCompletion: - # Helper function to convert ParsedChatCompletionMessage to ChatCompletionMessage - def convert_message(parsed_message: ParsedChatCompletionMessage) -> ChatCompletionMessage: - return ChatCompletionMessage( - role=parsed_message.role, - content=parsed_message.content, - function_call=parsed_message.function_call, - ) - - # Convert ParsedChatCompletion to ChatCompletion - return ChatCompletion( - id=parsed.id, - choices=[ - Choice( - finish_reason=choice.finish_reason, - index=choice.index, - logprobs=choice.logprobs, - message=convert_message(choice.message), # Parse the message - ) - for choice in parsed.choices - ], - created=parsed.created, - model=parsed.model, - object=parsed.object, - service_tier=parsed.service_tier, - system_fingerprint=parsed.system_fingerprint, - usage=parsed.usage, - ) - - def __init__(self, client: Union[OpenAI, AzureOpenAI]): + def __init__(self, client: Union[OpenAI, AzureOpenAI], response_format: Optional[BaseModel] = None): self._oai_client = client + self.response_format = response_format if ( not isinstance(client, openai.AzureOpenAI) and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX) @@ -287,22 +257,29 @@ def message_retrieval( if isinstance(response, Completion): return [choice.text for choice in choices] # type: ignore [union-attr] + def _format_content(content: str) -> str: + return ( + self.response_format.model_validate_json(content).format() + if isinstance(self.response_format, FormatterProtocol) + else content + ) + if TOOL_ENABLED: return [ # type: ignore [return-value] ( choice.message # type: ignore [union-attr] if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] - else choice.message.content + else _format_content(choice.message.content) ) # type: ignore [union-attr] for choice in choices ] else: return [ # type: ignore [return-value] - choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr] + choice.message if choice.message.function_call is not None else _format_content(choice.message.content) # type: ignore [union-attr] for choice in choices ] - def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion: + def create(self, params: Dict[str, Any]) -> ChatCompletion: """Create a completion for a given config using openai's client. Args: @@ -314,15 +291,13 @@ def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = """ iostream = IOStream.get_default() - if response_format is not None: + if self.response_format is not None: def _create_or_parse(*args, **kwargs): if "stream" in kwargs: kwargs.pop("stream") - kwargs["response_format"] = response_format - return OpenAIClient._convert_to_chat_completion( - self._oai_client.beta.chat.completions.parse(*args, **kwargs) - ) + kwargs["response_format"] = type_to_response_format_param(self.response_format) + return self._oai_client.chat.completions.create(*args, **kwargs) create_or_parse = _create_or_parse else: @@ -480,6 +455,11 @@ def get_usage(response: Union[ChatCompletion, Completion]) -> Dict: } +@runtime_checkable +class FormatterProtocol(Protocol): + def format(self) -> str: ... + + class OpenAIWrapper: """A wrapper class for openai client.""" @@ -502,7 +482,12 @@ class OpenAIWrapper: total_usage_summary: Optional[Dict[str, Any]] = None actual_usage_summary: Optional[Dict[str, Any]] = None - def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any): + def __init__( + self, + *, + config_list: Optional[List[Dict[str, Any]]] = None, + **base_config: Any, + ): """ Args: config_list: a list of config dicts to override the base_config. @@ -605,6 +590,7 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} api_type = config.get("api_type") model_client_cls_name = config.get("model_client_cls") + response_format = config.get("response_format") if model_client_cls_name is not None: # a config for a custom client is set # adding placeholder until the register_model_client is called with the appropriate class @@ -617,58 +603,58 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s if api_type is not None and api_type.startswith("azure"): self._configure_azure_openai(config, openai_config) client = AzureOpenAI(**openai_config) - self._clients.append(OpenAIClient(client)) + self._clients.append(OpenAIClient(client, response_format=response_format)) elif api_type is not None and api_type.startswith("cerebras"): if cerebras_import_exception: raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.") - client = CerebrasClient(**openai_config) + client = CerebrasClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("google"): if gemini_import_exception: raise ImportError("Please install `google-generativeai` and 'vertexai' to use Google's API.") - client = GeminiClient(**openai_config) + client = GeminiClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("anthropic"): if "api_key" not in config: self._configure_openai_config_for_bedrock(config, openai_config) if anthropic_import_exception: raise ImportError("Please install `anthropic` to use Anthropic API.") - client = AnthropicClient(**openai_config) + client = AnthropicClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("mistral"): if mistral_import_exception: raise ImportError("Please install `mistralai` to use the Mistral.AI API.") - client = MistralAIClient(**openai_config) + client = MistralAIClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("together"): if together_import_exception: raise ImportError("Please install `together` to use the Together.AI API.") - client = TogetherClient(**openai_config) + client = TogetherClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("groq"): if groq_import_exception: raise ImportError("Please install `groq` to use the Groq API.") - client = GroqClient(**openai_config) + client = GroqClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("cohere"): if cohere_import_exception: raise ImportError("Please install `cohere` to use the Cohere API.") - client = CohereClient(**openai_config) + client = CohereClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("ollama"): if ollama_import_exception: raise ImportError("Please install `ollama` and `fix-busted-json` to use the Ollama API.") - client = OllamaClient(**openai_config) + client = OllamaClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("bedrock"): self._configure_openai_config_for_bedrock(config, openai_config) if bedrock_import_exception: raise ImportError("Please install `boto3` to use the Amazon Bedrock API.") - client = BedrockClient(**openai_config) + client = BedrockClient(response_format=response_format, **openai_config) self._clients.append(client) else: client = OpenAI(**openai_config) - self._clients.append(OpenAIClient(client)) + self._clients.append(OpenAIClient(client, response_format)) if logging_enabled(): log_new_client(client, self, openai_config) @@ -747,9 +733,7 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: ] return params - def create( - self, response_format: Optional[BaseModel] = None, **config: Any - ) -> ModelClient.ModelClientResponseProtocol: + def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol: """Make a completion for a given config using available clients. Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs. The config in each client will be overridden by the config. @@ -838,8 +822,8 @@ def yes_or_no_filter(context, response): with cache_client as cache: # Try to get the response from cache key = get_key( - {**params, **{"response_format": schema_json_of(response_format)}} - if response_format + {**params, **{"response_format": schema_json_of(params["response_format"])}} + if "response_format" in params else params ) request_ts = get_current_ts() @@ -882,7 +866,7 @@ def yes_or_no_filter(context, response): continue # filter is not passed; try the next config try: request_ts = get_current_ts() - response = client.create(params, response_format=response_format) + response = client.create(params) except APITimeoutError as err: logger.debug(f"config {i} timed out", exc_info=True) if i == last: @@ -944,6 +928,7 @@ def yes_or_no_filter(context, response): actual_usage = client.get_usage(response) total_usage = actual_usage.copy() if actual_usage is not None else total_usage self._update_usage(actual_usage=actual_usage, total_usage=total_usage) + if cache_client is not None: # Cache the response with cache_client as cache: diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 6334288f40..b7d411454d 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -80,6 +80,9 @@ def __init__(self, **kwargs): self.api_key ), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable." + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Cohere, it will be ignored.", UserWarning) + def message_retrieval(self, response) -> List: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -148,10 +151,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return cohere_params - def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion: - if response_format is not None: - raise NotImplementedError("Response format is not supported by Cohere's API.") - + def create(self, params: Dict) -> ChatCompletion: messages = params.get("messages", []) client_name = params.get("client_name") or "autogen-cohere" # Parse parameters to the Cohere API's parameters diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index 00f8662ec3..074d8587cf 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -136,6 +136,9 @@ def __init__(self, **kwargs): "location" not in kwargs ), "Google Cloud project and compute location cannot be set when using an API Key!" + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Gemini. It will be ignored.", UserWarning) + def message_retrieval(self, response) -> List: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -160,10 +163,7 @@ def get_usage(response) -> Dict: "model": response.model, } - def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion: - if response_format is not None: - raise NotImplementedError("Response format is not supported by Gemini's API.") - + def create(self, params: Dict) -> ChatCompletion: if self.use_vertexai: self._initialize_vertexai(**params) else: diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py index 39f1af1c68..ea560e7ea0 100644 --- a/autogen/oai/groq.py +++ b/autogen/oai/groq.py @@ -66,6 +66,9 @@ def __init__(self, **kwargs): self.api_key ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable." + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Groq API, it will be ignored.", UserWarning) + def message_retrieval(self, response) -> List: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -126,10 +129,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return groq_params - def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion: - if response_format is not None: - raise NotImplementedError("Response format is not supported by Groq's API.") - + def create(self, params: Dict) -> ChatCompletion: messages = params.get("messages", []) # Convert AutoGen messages to Groq messages diff --git a/autogen/oai/mistral.py b/autogen/oai/mistral.py index efdf420ef1..022210c9aa 100644 --- a/autogen/oai/mistral.py +++ b/autogen/oai/mistral.py @@ -71,6 +71,9 @@ def __init__(self, **kwargs): self.api_key ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable." + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Mistral.AI, it will be ignored.", UserWarning) + self._client = Mistral(api_key=self.api_key) def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]: @@ -170,10 +173,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return mistral_params - def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion: - if response_format is not None: - raise NotImplementedError("Response format is not supported by Mistral's API.") - + def create(self, params: Dict[str, Any]) -> ChatCompletion: # 1. Parse parameters to Mistral.AI API's parameters mistral_params = self.parse_params(params) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 2fc065d701..1b7c3ced79 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -85,6 +85,8 @@ def __init__(self, **kwargs): Args: None """ + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Ollama, it will be ignored.", UserWarning) def message_retrieval(self, response) -> List: """ @@ -178,10 +180,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return ollama_params - def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion: - if response_format is not None: - raise NotImplementedError("Response format is not supported by Ollama's API.") - + def create(self, params: Dict) -> ChatCompletion: messages = params.get("messages", []) # Are tools involved in this conversation? diff --git a/autogen/oai/together.py b/autogen/oai/together.py index 3f6d0289b3..a823155dd7 100644 --- a/autogen/oai/together.py +++ b/autogen/oai/together.py @@ -60,6 +60,9 @@ def __init__(self, **kwargs): if not self.api_key: self.api_key = os.getenv("TOGETHER_API_KEY") + if "response_format" in kwargs and kwargs["response_format"] is not None: + warnings.warn("response_format is not supported for Together.AI, it will be ignored.", UserWarning) + assert ( self.api_key ), "Please include the api_key in your config list entry for Together.AI or set the TOGETHER_API_KEY env variable." @@ -130,10 +133,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return together_params - def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion: - if response_format is not None: - raise NotImplementedError("Response format is not supported by Together AI's API.") - + def create(self, params: Dict) -> ChatCompletion: messages = params.get("messages", []) # Convert AutoGen messages to Together.AI messages diff --git a/notebook/agentchat_structured_outputs.ipynb b/notebook/agentchat_structured_outputs.ipynb index bc03af521a..180170879b 100644 --- a/notebook/agentchat_structured_outputs.ipynb +++ b/notebook/agentchat_structured_outputs.ipynb @@ -7,16 +7,16 @@ "source": [ "# Structured output\n", "\n", - "OpenAI offers a functionality for defining a structure of the messages generated by LLMs, AutoGen enables this functionality by propagating `response_format` passed to your agents to the underlying client.\n", + "OpenAI offers functionality for defining a structure of the messages generated by LLMs, AG2 enables this functionality by propagating `response_format`, in the LLM configuration for your agents, to the underlying client. This is currently only supported by OpenAI.\n", "\n", "For more info on structured output, please check [here](https://platform.openai.com/docs/guides/structured-outputs)\n", "\n", "\n", "````{=mdx}\n", ":::info Requirements\n", - "Install `pyautogen`:\n", + "Install `ag2`:\n", "```bash\n", - "pip install pyautogen\n", + "pip install ag2\n", "```\n", "\n", "For more information, please refer to the [installation guide](/docs/installation/).\n", @@ -31,14 +31,23 @@ "source": [ "## Set your API Endpoint\n", "\n", - "The [`config_list_from_json`](https://ag2ai.github.io/ag2/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file." + "The [`config_list_from_json`](https://ag2ai.github.io/ag2/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file. Structured Output is supported by OpenAI's models from gpt-4-0613 and gpt-3.5-turbo-0613." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import autogen\n", "\n", @@ -68,7 +77,7 @@ "source": [ "## Example: math reasoning\n", "\n", - "Using structured output, we can enforce chain-of-thought reasoning in the model to output an answer in a structured, step-by-step way" + "Using structured output, we can enforce chain-of-thought reasoning in the model to output an answer in a structured, step-by-step way." ] }, { @@ -77,7 +86,7 @@ "source": [ "### Define the reasoning model\n", "\n", - "First we will define the math reasoning model. This model will indirectly force the LLM to solve the posed math problems iteratively trough math reasoning steps." + "First we will define the math reasoning model. This model will indirectly force the LLM to solve the posed math problems iteratively through math reasoning steps." ] }, { @@ -93,11 +102,31 @@ " explanation: str\n", " output: str\n", "\n", + "\n", "class MathReasoning(BaseModel):\n", " steps: list[Step]\n", " final_answer: str" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Applying the Response Format\n", + "\n", + "The `response_format` is added to the LLM configuration and then this configuration is applied to the agent." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "for config in config_list:\n", + " config[\"response_format\"] = MathReasoning" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -112,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -126,8 +155,7 @@ "\n", "assistant = autogen.AssistantAgent(\n", " name=\"Math_solver\",\n", - " llm_config=llm_config,\n", - " response_format=MathReasoning,\n", + " llm_config=llm_config, # Response Format is in the configuration\n", ")" ] }, @@ -140,10 +168,100 @@ "Let's now start the chat and prompt the assistant to solve a simple equation. The assistant agent should return a response solving the equation using a step-by-step `MathReasoning` model." ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mUser_proxy\u001b[0m (to Math_solver):\n", + "\n", + "how can I solve 8x + 7 = -23\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mMath_solver\u001b[0m (to User_proxy):\n", + "\n", + "{\"steps\":[{\"explanation\":\"To isolate the term with x, we first subtract 7 from both sides of the equation.\",\"output\":\"8x + 7 - 7 = -23 - 7 -> 8x = -30.\"},{\"explanation\":\"Now that we have 8x = -30, we divide both sides by 8 to solve for x.\",\"output\":\"x = -30 / 8 -> x = -3.75.\"}],\"final_answer\":\"x = -3.75\"}\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "data": { + "text/plain": [ + "ChatResult(chat_id=None, chat_history=[{'content': 'how can I solve 8x + 7 = -23', 'role': 'assistant', 'name': 'User_proxy'}, {'content': '{\"steps\":[{\"explanation\":\"To isolate the term with x, we first subtract 7 from both sides of the equation.\",\"output\":\"8x + 7 - 7 = -23 - 7 -> 8x = -30.\"},{\"explanation\":\"Now that we have 8x = -30, we divide both sides by 8 to solve for x.\",\"output\":\"x = -30 / 8 -> x = -3.75.\"}],\"final_answer\":\"x = -3.75\"}', 'role': 'user', 'name': 'Math_solver'}], summary='{\"steps\":[{\"explanation\":\"To isolate the term with x, we first subtract 7 from both sides of the equation.\",\"output\":\"8x + 7 - 7 = -23 - 7 -> 8x = -30.\"},{\"explanation\":\"Now that we have 8x = -30, we divide both sides by 8 to solve for x.\",\"output\":\"x = -30 / 8 -> x = -3.75.\"}],\"final_answer\":\"x = -3.75\"}', cost={'usage_including_cached_inference': {'total_cost': 0.00015089999999999998, 'gpt-4o-mini-2024-07-18': {'cost': 0.00015089999999999998, 'prompt_tokens': 582, 'completion_tokens': 106, 'total_tokens': 688}}, 'usage_excluding_cached_inference': {'total_cost': 0}}, human_input=[])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "user_proxy.initiate_chat(assistant, message=\"how can I solve 8x + 7 = -23\", max_turns=1, summary_method=\"last_msg\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Formatting a response\n", + "\n", + "When defining a `response_format`, you have the flexibility to customize how the output is parsed and presented, making it more user-friendly. To demonstrate this, we’ll add a `format` method to our `MathReasoning` model. This method will define the logic for transforming the raw JSON response into a more human-readable and accessible format." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define the reasoning model\n", + "\n", + "Let’s redefine the `MathReasoning` model to include a `format` method. This method will allow the underlying client to parse the return value from the LLM into a more human-readable format. If the `format` method is not defined, the client will default to returning the model’s JSON representation, as demonstrated in the previous example." + ] + }, { "cell_type": "code", "execution_count": 6, "metadata": {}, + "outputs": [], + "source": [ + "from pydantic import BaseModel\n", + "\n", + "\n", + "class Step(BaseModel):\n", + " explanation: str\n", + " output: str\n", + "\n", + "\n", + "class MathReasoning(BaseModel):\n", + " steps: list[Step]\n", + " final_answer: str\n", + "\n", + " def format(self) -> str:\n", + " steps_output = \"\\n\".join(\n", + " f\"Step {i + 1}: {step.explanation}\\n Output: {step.output}\" for i, step in enumerate(self.steps)\n", + " )\n", + " return f\"{steps_output}\\n\\nFinal Answer: {self.final_answer}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define chat actors and start the chat\n", + "\n", + "The rest of the process is the same as in the previous example: define the actors and start the chat.\n", + "\n", + "Observe how the Math_solver agent now communicates using the format we have defined in our `MathReasoning.format` method." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -154,10 +272,14 @@ "how can I solve 8x + 7 = -23\n", "\n", "--------------------------------------------------------------------------------\n", - "[autogen.oai.client: 11-29 23:42:48] {921} INFO - Failed to cache response: Can't pickle : attribute lookup ParsedChatCompletion[MathReasoning] on openai.types.chat.parsed_chat_completion failed\n", "\u001b[33mMath_solver\u001b[0m (to User_proxy):\n", "\n", - "{\"steps\":[{\"explanation\":\"We need to isolate the variable x on one side of the equation. We start by getting rid of the constant on the left side of the equation. In this case, the constant is 7, which is added to 8x. We can subtract 7 from both sides of the equation to eliminate this constant.\",\"output\":\"8x + 7 - 7 = -23 - 7\\n8x = -30\"},{\"explanation\":\"Now that we have 8x = -30, we want to solve for x. To do this, we need to divide both sides of the equation by 8, which is the coefficient of x.\",\"output\":\"8x / 8 = -30 / 8\\nx = -30 / 8\"},{\"explanation\":\"Next, we simplify the fraction -30 / 8. Both the numerator and the denominator can be divided by 2 to simplify the fraction.\",\"output\":\"x = -30 / 8 = -15 / 4\"},{\"explanation\":\"The simplified fraction -15 / 4 can also be written as a decimal if needed.\",\"output\":\"x = -15 / 4 = -3.75\"}],\"final_answer\":\"x = -15/4 or x = -3.75\"}\n", + "Step 1: To isolate the term with x, we first subtract 7 from both sides of the equation.\n", + " Output: 8x + 7 - 7 = -23 - 7 -> 8x = -30.\n", + "Step 2: Now that we have 8x = -30, we divide both sides by 8 to solve for x.\n", + " Output: x = -30 / 8 -> x = -3.75.\n", + "\n", + "Final Answer: x = -3.75\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -165,17 +287,85 @@ { "data": { "text/plain": [ - "ChatResult(chat_id=None, chat_history=[{'content': 'how can I solve 8x + 7 = -23', 'role': 'assistant', 'name': 'User_proxy'}, {'content': '{\"steps\":[{\"explanation\":\"We need to isolate the variable x on one side of the equation. We start by getting rid of the constant on the left side of the equation. In this case, the constant is 7, which is added to 8x. We can subtract 7 from both sides of the equation to eliminate this constant.\",\"output\":\"8x + 7 - 7 = -23 - 7\\\\n8x = -30\"},{\"explanation\":\"Now that we have 8x = -30, we want to solve for x. To do this, we need to divide both sides of the equation by 8, which is the coefficient of x.\",\"output\":\"8x / 8 = -30 / 8\\\\nx = -30 / 8\"},{\"explanation\":\"Next, we simplify the fraction -30 / 8. Both the numerator and the denominator can be divided by 2 to simplify the fraction.\",\"output\":\"x = -30 / 8 = -15 / 4\"},{\"explanation\":\"The simplified fraction -15 / 4 can also be written as a decimal if needed.\",\"output\":\"x = -15 / 4 = -3.75\"}],\"final_answer\":\"x = -15/4 or x = -3.75\"}', 'tool_calls': [], 'role': 'user', 'name': 'Math_solver'}], summary='{\"steps\":[{\"explanation\":\"We need to isolate the variable x on one side of the equation. We start by getting rid of the constant on the left side of the equation. In this case, the constant is 7, which is added to 8x. We can subtract 7 from both sides of the equation to eliminate this constant.\",\"output\":\"8x + 7 - 7 = -23 - 7\\\\n8x = -30\"},{\"explanation\":\"Now that we have 8x = -30, we want to solve for x. To do this, we need to divide both sides of the equation by 8, which is the coefficient of x.\",\"output\":\"8x / 8 = -30 / 8\\\\nx = -30 / 8\"},{\"explanation\":\"Next, we simplify the fraction -30 / 8. Both the numerator and the denominator can be divided by 2 to simplify the fraction.\",\"output\":\"x = -30 / 8 = -15 / 4\"},{\"explanation\":\"The simplified fraction -15 / 4 can also be written as a decimal if needed.\",\"output\":\"x = -15 / 4 = -3.75\"}],\"final_answer\":\"x = -15/4 or x = -3.75\"}', cost={'usage_including_cached_inference': {'total_cost': 0.004085, 'gpt-4o-2024-08-06': {'cost': 0.004085, 'prompt_tokens': 582, 'completion_tokens': 263, 'total_tokens': 845}}, 'usage_excluding_cached_inference': {'total_cost': 0.004085, 'gpt-4o-2024-08-06': {'cost': 0.004085, 'prompt_tokens': 582, 'completion_tokens': 263, 'total_tokens': 845}}}, human_input=[])" + "ChatResult(chat_id=None, chat_history=[{'content': 'how can I solve 8x + 7 = -23', 'role': 'assistant', 'name': 'User_proxy'}, {'content': 'Step 1: To isolate the term with x, we first subtract 7 from both sides of the equation.\\n Output: 8x + 7 - 7 = -23 - 7 -> 8x = -30.\\nStep 2: Now that we have 8x = -30, we divide both sides by 8 to solve for x.\\n Output: x = -30 / 8 -> x = -3.75.\\n\\nFinal Answer: x = -3.75', 'role': 'user', 'name': 'Math_solver'}], summary='Step 1: To isolate the term with x, we first subtract 7 from both sides of the equation.\\n Output: 8x + 7 - 7 = -23 - 7 -> 8x = -30.\\nStep 2: Now that we have 8x = -30, we divide both sides by 8 to solve for x.\\n Output: x = -30 / 8 -> x = -3.75.\\n\\nFinal Answer: x = -3.75', cost={'usage_including_cached_inference': {'total_cost': 0.00015089999999999998, 'gpt-4o-mini-2024-07-18': {'cost': 0.00015089999999999998, 'prompt_tokens': 582, 'completion_tokens': 106, 'total_tokens': 688}}, 'usage_excluding_cached_inference': {'total_cost': 0}}, human_input=[])" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "for config in config_list:\n", + " config[\"response_format\"] = MathReasoning\n", + "\n", + "user_proxy = autogen.UserProxyAgent(\n", + " name=\"User_proxy\",\n", + " system_message=\"A human admin.\",\n", + " human_input_mode=\"NEVER\",\n", + ")\n", + "\n", + "assistant = autogen.AssistantAgent(\n", + " name=\"Math_solver\",\n", + " llm_config=llm_config,\n", + ")\n", + "\n", "user_proxy.initiate_chat(assistant, message=\"how can I solve 8x + 7 = -23\", max_turns=1, summary_method=\"last_msg\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Normal function calling still works alongside structured output, so your agent can have a response format while still calling tools." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mUser_proxy\u001b[0m (to Math_solver):\n", + "\n", + "solve 3 + 4 by calling appropriate function\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mMath_solver\u001b[0m (to User_proxy):\n", + "\n", + "\u001b[32m***** Suggested tool call (call_oTp96rVzs2kAOwGhBM5rJDcW): add *****\u001b[0m\n", + "Arguments: \n", + "{\"x\":3,\"y\":4}\n", + "\u001b[32m********************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "data": { + "text/plain": [ + "ChatResult(chat_id=None, chat_history=[{'content': 'solve 3 + 4 by calling appropriate function', 'role': 'assistant', 'name': 'User_proxy'}, {'tool_calls': [{'id': 'call_oTp96rVzs2kAOwGhBM5rJDcW', 'function': {'arguments': '{\"x\":3,\"y\":4}', 'name': 'add'}, 'type': 'function'}], 'content': None, 'role': 'assistant'}], summary='', cost={'usage_including_cached_inference': {'total_cost': 0.0001029, 'gpt-4o-mini-2024-07-18': {'cost': 0.0001029, 'prompt_tokens': 618, 'completion_tokens': 17, 'total_tokens': 635}}, 'usage_excluding_cached_inference': {'total_cost': 0}}, human_input=[])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@assistant.register_for_execution()\n", + "@assistant.register_for_llm(description=\"You can use this function call to solve addition\")\n", + "def add(x: int, y: int) -> int:\n", + " return x + y\n", + "\n", + "\n", + "user_proxy.initiate_chat(\n", + " assistant, message=\"solve 3 + 4 by calling appropriate function\", max_turns=1, summary_method=\"last_msg\"\n", + ")" + ] } ], "metadata": { @@ -186,7 +376,7 @@ ] }, "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -200,7 +390,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.20" + "version": "3.11.10" }, "orig_nbformat": 4 }, diff --git a/test/agentchat/test_structured_output.py b/test/agentchat/test_structured_output.py index 26a3fad62a..d99b2a63d6 100644 --- a/test/agentchat/test_structured_output.py +++ b/test/agentchat/test_structured_output.py @@ -7,15 +7,18 @@ import os import sys +from typing import List +from unittest.mock import MagicMock import pytest +from openai.types.chat.parsed_chat_completion import ChatCompletion, ChatCompletionMessage, Choice from pydantic import BaseModel, ValidationError from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST import autogen sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from conftest import reason, skip_openai # noqa: E402 +from conftest import MOCK_OPEN_AI_API_KEY, reason, skip_openai # noqa: E402 @pytest.mark.skipif(skip_openai, reason=reason) @@ -34,7 +37,10 @@ class ResponseModel(BaseModel): reasoning: str difficulty: float - llm_config = {"config_list": config_list, "cache_seed": 41} + for config in config_list: + config["response_format"] = ResponseModel + + llm_config = {"config_list": config_list, "cache_seed": 43} user_proxy = autogen.UserProxyAgent( name="User_proxy", @@ -45,7 +51,6 @@ class ResponseModel(BaseModel): assistant = autogen.AssistantAgent( name="Assistant", llm_config=llm_config, - response_format=ResponseModel, ) chat_result = user_proxy.initiate_chat( @@ -59,3 +64,72 @@ class ResponseModel(BaseModel): ResponseModel.model_validate_json(chat_result.chat_history[-1]["content"]) except ValidationError as e: raise AssertionError(f"Agent did not return a structured report. Exception: {e}") + + +# Helper classes for testing +class Step(BaseModel): + explanation: str + output: str + + +class MathReasoning(BaseModel): + steps: List[Step] + final_answer: str + + def format(self) -> str: + steps_output = "\n".join( + f"Step {i + 1}: {step.explanation}\n Output: {step.output}" for i, step in enumerate(self.steps) + ) + return f"{steps_output}\n\nFinal Answer: {self.final_answer}" + + +@pytest.fixture +def mock_assistant(): + """Set up a mocked AssistantAgent with a predefined response format.""" + config_list = [{"model": "gpt-4o", "api_key": MOCK_OPEN_AI_API_KEY, "response_format": MathReasoning}] + llm_config = {"config_list": config_list, "cache_seed": 43} + + assistant = autogen.AssistantAgent( + name="Assistant", + llm_config=llm_config, + ) + + oai_client_mock = MagicMock() + oai_client_mock.chat.completions.create.return_value = ChatCompletion( + id="some-id", + created=1733302346, + model="gpt-4o", + object="chat.completion", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content='{"steps":[{"explanation":"some explanation","output":"some output"}],"final_answer":"final answer"}', + role="assistant", + ), + ) + ], + ) + assistant.client._clients[0]._oai_client = oai_client_mock + + return assistant + + +def test_structured_output_formatting(mock_assistant): + """Test that the AssistantAgent correctly formats structured output.""" + user_proxy = autogen.UserProxyAgent( + name="User_proxy", + system_message="A human admin.", + human_input_mode="NEVER", + ) + + chat_result = user_proxy.initiate_chat( + mock_assistant, + message="What is the square root of 4?", + max_turns=1, + summary_method="last_msg", + ) + + expected_output = "Step 1: some explanation\n Output: some output\n\nFinal Answer: final answer" + assert chat_result.chat_history[-1]["content"] == expected_output diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 9e05b8a606..5976b7a46f 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -41,7 +41,7 @@ def __init__(self, config: Dict, test_hook): self.test_hook["other_params"] = self.other_params self.test_hook["max_length"] = self.max_length - def create(self, params, response_format): + def create(self, params): from types import SimpleNamespace response = SimpleNamespace() @@ -177,7 +177,7 @@ def __init__(self, config: Dict, test_hook): self.test_hook = test_hook self.test_hook["called"] = True - def create(self, params, response_format): + def create(self, params): from types import SimpleNamespace response = SimpleNamespace()