Implement custom formatting in response_format #145

Merged
4 changes: 0 additions & 4 deletions autogen/agentchat/conversable_agent.py
@@ -86,7 +86,6 @@ def __init__(
chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
silent: Optional[bool] = None,
context_variables: Optional[Dict[str, Any]] = None,
response_format: Optional[BaseModel] = None,
):
"""
Args:
@@ -141,7 +140,6 @@ def __init__(
Note: Will maintain a reference to the passed in context variables (enabling a shared context)
Only used in Swarms at this stage:
https://ag2ai.github.io/ag2/docs/reference/agentchat/contrib/swarm_agent
response_format (BaseModel): Used to specify structured response format for the agent. Currently only available for the OpenAI client.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
@@ -164,7 +162,6 @@ def __init__(
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
self.silent = silent
self._response_format = response_format
# Take a copy to avoid modifying the given dict
if isinstance(llm_config, dict):
try:
@@ -1498,7 +1495,6 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
messages=all_messages,
cache=cache,
agent=self,
response_format=self._response_format,
)
extracted_response = llm_client.extract_text_or_completion_object(response)[0]

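For context on the lines removed above: after this PR, response_format no longer rides on the ConversableAgent constructor; it is supplied per config-list entry and picked up by _register_default_client in autogen/oai/client.py (diff below). A minimal sketch of the new call site, with an illustrative schema, placeholder model name, and placeholder key:

from pydantic import BaseModel

from autogen import ConversableAgent


class Answer(BaseModel):  # illustrative schema, not part of this PR
    summary: str
    confidence: float


agent = ConversableAgent(
    name="assistant",
    llm_config={
        "config_list": [
            {
                "model": "gpt-4o",  # placeholder
                "api_key": "sk-...",  # placeholder
                # forwarded to the client class by _register_default_client
                "response_format": Answer,
            }
        ]
    },
)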
8 changes: 4 additions & 4 deletions autogen/oai/anthropic.py
@@ -116,6 +116,9 @@ def __init__(self, **kwargs: Any):
):
raise ValueError("API key or AWS credentials are required to use the Anthropic API.")

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Anthropic, it will be ignored.", UserWarning)

if self._api_key is not None:
self._client = Anthropic(api_key=self._api_key)
else:
@@ -177,10 +180,7 @@ def aws_session_token(self):
def aws_region(self):
return self._aws_region

def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Anthropic API.")

def create(self, params: Dict[str, Any]) -> ChatCompletion:
if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
params["functions"] = params.get("functions", []) + converted_functions
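The same raise-to-warn change recurs for Bedrock, Cerebras, and Cohere below: an unsupported response_format now warns once at construction and is then ignored, instead of raising NotImplementedError inside create(). A rough sketch of that behavior, assuming the anthropic extra is installed and using a placeholder key:

import warnings

from pydantic import BaseModel

from autogen.oai.anthropic import AnthropicClient


class Answer(BaseModel):  # illustrative
    text: str


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    AnthropicClient(api_key="sk-ant-...", response_format=Answer)  # placeholder key
assert any("response_format is not supported" in str(w.message) for w in caught)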
9 changes: 4 additions & 5 deletions autogen/oai/bedrock.py
@@ -93,6 +93,9 @@ def __init__(self, **kwargs: Any):
profile_name=self._aws_profile_name,
)

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)

self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)

def message_retrieval(self, response):
@@ -179,12 +182,8 @@ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str

return base_params, additional_params

def create(self, params, response_format: Optional[BaseModel] = None) -> ChatCompletion:
def create(self, params) -> ChatCompletion:
"""Run Amazon Bedrock inference and return AutoGen response"""

if response_format is not None:
raise NotImplementedError("Response format is not supported by Amazon Bedrock's API.")

# Set custom client class settings
self.parse_custom_params(params)

8 changes: 4 additions & 4 deletions autogen/oai/cerebras.py
@@ -64,6 +64,9 @@ def __init__(self, api_key=None, **kwargs):
self.api_key
), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Cerebras, it will be ignored.", UserWarning)

def message_retrieval(self, response: ChatCompletion) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -112,10 +115,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cerebras_params

def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Cerebras' API.")

def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])

# Convert AutoGen messages to Cerebras messages
107 changes: 46 additions & 61 deletions autogen/oai/client.py
@@ -10,7 +10,7 @@
import logging
import sys
import uuid
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union, runtime_checkable

from pydantic import BaseModel, schema_json_of

@@ -33,6 +33,7 @@
# raises exception if openai>=1 is installed and something is wrong with imports
from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI
from openai import __version__ as OPENAIVERSION
from openai.lib._parsing._completions import type_to_response_format_param
from openai.resources import Completions
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
@@ -207,9 +208,7 @@ class Message(Protocol):
choices: List[Choice]
model: str

def create(
self, params: Dict[str, Any], response_format: Optional[BaseModel] = None
) -> ModelClientResponseProtocol: ... # pragma: no cover
def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover

def message_retrieval(
self, response: ModelClientResponseProtocol
@@ -238,38 +237,9 @@ def __init__(self, config):
class OpenAIClient:
"""Follows the Client protocol and wraps the OpenAI client."""

@staticmethod
def _convert_to_chat_completion(parsed: ParsedChatCompletion) -> ChatCompletion:
# Helper function to convert ParsedChatCompletionMessage to ChatCompletionMessage
def convert_message(parsed_message: ParsedChatCompletionMessage) -> ChatCompletionMessage:
return ChatCompletionMessage(
role=parsed_message.role,
content=parsed_message.content,
function_call=parsed_message.function_call,
)

# Convert ParsedChatCompletion to ChatCompletion
return ChatCompletion(
id=parsed.id,
choices=[
Choice(
finish_reason=choice.finish_reason,
index=choice.index,
logprobs=choice.logprobs,
message=convert_message(choice.message), # Parse the message
)
for choice in parsed.choices
],
created=parsed.created,
model=parsed.model,
object=parsed.object,
service_tier=parsed.service_tier,
system_fingerprint=parsed.system_fingerprint,
usage=parsed.usage,
)

def __init__(self, client: Union[OpenAI, AzureOpenAI]):
def __init__(self, client: Union[OpenAI, AzureOpenAI], response_format: Optional[BaseModel] = None):
self._oai_client = client
self.response_format = response_format
if (
not isinstance(client, openai.AzureOpenAI)
and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX)
@@ -287,22 +257,29 @@ def message_retrieval(
if isinstance(response, Completion):
return [choice.text for choice in choices] # type: ignore [union-attr]

def _format_content(content: str) -> str:
return (
self.response_format.model_validate_json(content).format()
if isinstance(self.response_format, FormatterProtocol)
else content
)

if TOOL_ENABLED:
return [ # type: ignore [return-value]
(
choice.message # type: ignore [union-attr]
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
else choice.message.content
else _format_content(choice.message.content)
) # type: ignore [union-attr]
for choice in choices
]
else:
return [ # type: ignore [return-value]
choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr]
choice.message if choice.message.function_call is not None else _format_content(choice.message.content) # type: ignore [union-attr]
for choice in choices
]

def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
def create(self, params: Dict[str, Any]) -> ChatCompletion:
"""Create a completion for a given config using openai's client.
Args:
@@ -314,15 +291,13 @@ def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] =
"""
iostream = IOStream.get_default()

if response_format is not None:
if self.response_format is not None:

def _create_or_parse(*args, **kwargs):
if "stream" in kwargs:
kwargs.pop("stream")
kwargs["response_format"] = response_format
return OpenAIClient._convert_to_chat_completion(
self._oai_client.beta.chat.completions.parse(*args, **kwargs)
)
kwargs["response_format"] = type_to_response_format_param(self.response_format)
return self._oai_client.chat.completions.create(*args, **kwargs)

create_or_parse = _create_or_parse
else:
@@ -480,6 +455,11 @@ def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
}


@runtime_checkable
class FormatterProtocol(Protocol):
def format(self) -> str: ...


class OpenAIWrapper:
"""A wrapper class for openai client."""

@@ -502,7 +482,12 @@ class OpenAIWrapper:
total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None

def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
def __init__(
self,
*,
config_list: Optional[List[Dict[str, Any]]] = None,
**base_config: Any,
):
"""
Args:
config_list: a list of config dicts to override the base_config.
@@ -605,6 +590,7 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
api_type = config.get("api_type")
model_client_cls_name = config.get("model_client_cls")
response_format = config.get("response_format")
if model_client_cls_name is not None:
# a config for a custom client is set
# adding placeholder until the register_model_client is called with the appropriate class
@@ -617,58 +603,58 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
if api_type is not None and api_type.startswith("azure"):
self._configure_azure_openai(config, openai_config)
client = AzureOpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
self._clients.append(OpenAIClient(client, response_format=response_format))
elif api_type is not None and api_type.startswith("cerebras"):
if cerebras_import_exception:
raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.")
client = CerebrasClient(**openai_config)
client = CerebrasClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("google"):
if gemini_import_exception:
raise ImportError("Please install `google-generativeai` and 'vertexai' to use Google's API.")
client = GeminiClient(**openai_config)
client = GeminiClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("anthropic"):
if "api_key" not in config:
self._configure_openai_config_for_bedrock(config, openai_config)
if anthropic_import_exception:
raise ImportError("Please install `anthropic` to use Anthropic API.")
client = AnthropicClient(**openai_config)
client = AnthropicClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("mistral"):
if mistral_import_exception:
raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
client = MistralAIClient(**openai_config)
client = MistralAIClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("together"):
if together_import_exception:
raise ImportError("Please install `together` to use the Together.AI API.")
client = TogetherClient(**openai_config)
client = TogetherClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("groq"):
if groq_import_exception:
raise ImportError("Please install `groq` to use the Groq API.")
client = GroqClient(**openai_config)
client = GroqClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("cohere"):
if cohere_import_exception:
raise ImportError("Please install `cohere` to use the Cohere API.")
client = CohereClient(**openai_config)
client = CohereClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("ollama"):
if ollama_import_exception:
raise ImportError("Please install `ollama` and `fix-busted-json` to use the Ollama API.")
client = OllamaClient(**openai_config)
client = OllamaClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("bedrock"):
self._configure_openai_config_for_bedrock(config, openai_config)
if bedrock_import_exception:
raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
client = BedrockClient(**openai_config)
client = BedrockClient(response_format=response_format, **openai_config)
self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
self._clients.append(OpenAIClient(client, response_format))

if logging_enabled():
log_new_client(client, self, openai_config)
@@ -747,9 +733,7 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
]
return params

def create(
self, response_format: Optional[BaseModel] = None, **config: Any
) -> ModelClient.ModelClientResponseProtocol:
def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
"""Make a completion for a given config using available clients.
Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
The config in each client will be overridden by the config.
@@ -838,8 +822,8 @@ def yes_or_no_filter(context, response):
with cache_client as cache:
# Try to get the response from cache
key = get_key(
{**params, **{"response_format": schema_json_of(response_format)}}
if response_format
{**params, **{"response_format": schema_json_of(params["response_format"])}}
if "response_format" in params
else params
)
request_ts = get_current_ts()
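A note on the cache-key hunk above: a BaseModel subclass is not a stable, serializable key component on its own, so the key folds in the schema's JSON via schema_json_of (imported at the top of this file). A tiny illustration with invented models:

from pydantic import BaseModel, schema_json_of


class A(BaseModel):
    x: int


class B(BaseModel):
    y: str


# Different schemas must produce different cache keys; otherwise a reply
# cached under one response_format would be replayed for another.
assert schema_json_of(A) != schema_json_of(B)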
@@ -882,7 +866,7 @@ def yes_or_no_filter(context, response):
continue # filter is not passed; try the next config
try:
request_ts = get_current_ts()
response = client.create(params, response_format=response_format)
response = client.create(params)
except APITimeoutError as err:
logger.debug(f"config {i} timed out", exc_info=True)
if i == last:
@@ -944,6 +928,7 @@ def yes_or_no_filter(context, response):
actual_usage = client.get_usage(response)
total_usage = actual_usage.copy() if actual_usage is not None else total_usage
self._update_usage(actual_usage=actual_usage, total_usage=total_usage)

if cache_client is not None:
# Cache the response
with cache_client as cache:
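The FormatterProtocol added above is the "custom formatting" of the PR title: when the configured response_format class also defines format(), OpenAIClient.message_retrieval validates the raw JSON reply into the model and returns the format() string rather than the JSON itself. A sketch of a conforming model, with schema and rendering invented for illustration:

from pydantic import BaseModel


class MathReasoning(BaseModel):  # illustrative schema
    steps: list[str]
    final_answer: str

    def format(self) -> str:
        # Makes the class satisfy FormatterProtocol, so _format_content
        # calls model_validate_json(content).format() on the reply.
        rendered = "\n".join(f"{i + 1}. {step}" for i, step in enumerate(self.steps))
        return f"Steps:\n{rendered}\nAnswer: {self.final_answer}"


raw = '{"steps": ["divide both sides by 2"], "final_answer": "x = 4"}'
print(MathReasoning.model_validate_json(raw).format())

Because the protocol is @runtime_checkable, the isinstance check in _format_content passes for any class exposing a format method; classes without one fall through to the plain content string.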
8 changes: 4 additions & 4 deletions autogen/oai/cohere.py
@@ -80,6 +80,9 @@ def __init__(self, **kwargs):
self.api_key
), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Cohere, it will be ignored.", UserWarning)

def message_retrieval(self, response) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -148,10 +151,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cohere_params

def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Cohere's API.")

def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters