Skip to content

Commit

Permalink
Merge pull request #145 from ag2ai/143-feature-request-add-formatting…
Browse files Browse the repository at this point in the history
…-option-to-response_format

Implement custom formatting in response_format
  • Loading branch information
marklysze authored Dec 5, 2024
2 parents a366da3 + 11fd1c2 commit 7736885
Show file tree
Hide file tree
Showing 14 changed files with 366 additions and 123 deletions.
4 changes: 0 additions & 4 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def __init__(
chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
silent: Optional[bool] = None,
context_variables: Optional[Dict[str, Any]] = None,
response_format: Optional[BaseModel] = None,
):
"""
Args:
Expand Down Expand Up @@ -141,7 +140,6 @@ def __init__(
Note: Will maintain a reference to the passed in context variables (enabling a shared context)
Only used in Swarms at this stage:
https://ag2ai.github.io/ag2/docs/reference/agentchat/contrib/swarm_agent
response_format (BaseModel): Used to specify structured response format for the agent. Currently only available for the OpenAI client.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
Expand All @@ -164,7 +162,6 @@ def __init__(
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
self.silent = silent
self._response_format = response_format
# Take a copy to avoid modifying the given dict
if isinstance(llm_config, dict):
try:
Expand Down Expand Up @@ -1498,7 +1495,6 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
messages=all_messages,
cache=cache,
agent=self,
response_format=self._response_format,
)
extracted_response = llm_client.extract_text_or_completion_object(response)[0]

Expand Down
8 changes: 4 additions & 4 deletions autogen/oai/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def __init__(self, **kwargs: Any):
):
raise ValueError("API key or AWS credentials are required to use the Anthropic API.")

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Anthropic, it will be ignored.", UserWarning)

if self._api_key is not None:
self._client = Anthropic(api_key=self._api_key)
else:
Expand Down Expand Up @@ -177,10 +180,7 @@ def aws_session_token(self):
def aws_region(self):
return self._aws_region

def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Anthropic API.")

def create(self, params: Dict[str, Any]) -> ChatCompletion:
if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
params["functions"] = params.get("functions", []) + converted_functions
Expand Down
9 changes: 4 additions & 5 deletions autogen/oai/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def __init__(self, **kwargs: Any):
profile_name=self._aws_profile_name,
)

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)

self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)

def message_retrieval(self, response):
Expand Down Expand Up @@ -179,12 +182,8 @@ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str

return base_params, additional_params

def create(self, params, response_format: Optional[BaseModel] = None) -> ChatCompletion:
def create(self, params) -> ChatCompletion:
"""Run Amazon Bedrock inference and return AutoGen response"""

if response_format is not None:
raise NotImplementedError("Response format is not supported by Amazon Bedrock's API.")

# Set custom client class settings
self.parse_custom_params(params)

Expand Down
8 changes: 4 additions & 4 deletions autogen/oai/cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def __init__(self, api_key=None, **kwargs):
self.api_key
), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."

if "response_format" in kwargs and kwargs["response_format"] is not None:
            warnings.warn("response_format is not supported for Cerebras, it will be ignored.", UserWarning)

def message_retrieval(self, response: ChatCompletion) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
Expand Down Expand Up @@ -112,10 +115,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cerebras_params

def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Cerebras' API.")

def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])

# Convert AutoGen messages to Cerebras messages
Expand Down
107 changes: 46 additions & 61 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
import sys
import uuid
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union, runtime_checkable

from pydantic import BaseModel, schema_json_of

Expand All @@ -33,6 +33,7 @@
# raises exception if openai>=1 is installed and something is wrong with imports
from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI
from openai import __version__ as OPENAIVERSION
from openai.lib._parsing._completions import type_to_response_format_param
from openai.resources import Completions
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
Expand Down Expand Up @@ -207,9 +208,7 @@ class Message(Protocol):
choices: List[Choice]
model: str

def create(
self, params: Dict[str, Any], response_format: Optional[BaseModel] = None
) -> ModelClientResponseProtocol: ... # pragma: no cover
def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover

def message_retrieval(
self, response: ModelClientResponseProtocol
Expand Down Expand Up @@ -238,38 +237,9 @@ def __init__(self, config):
class OpenAIClient:
"""Follows the Client protocol and wraps the OpenAI client."""

@staticmethod
def _convert_to_chat_completion(parsed: ParsedChatCompletion) -> ChatCompletion:
# Helper function to convert ParsedChatCompletionMessage to ChatCompletionMessage
def convert_message(parsed_message: ParsedChatCompletionMessage) -> ChatCompletionMessage:
return ChatCompletionMessage(
role=parsed_message.role,
content=parsed_message.content,
function_call=parsed_message.function_call,
)

# Convert ParsedChatCompletion to ChatCompletion
return ChatCompletion(
id=parsed.id,
choices=[
Choice(
finish_reason=choice.finish_reason,
index=choice.index,
logprobs=choice.logprobs,
message=convert_message(choice.message), # Parse the message
)
for choice in parsed.choices
],
created=parsed.created,
model=parsed.model,
object=parsed.object,
service_tier=parsed.service_tier,
system_fingerprint=parsed.system_fingerprint,
usage=parsed.usage,
)

def __init__(self, client: Union[OpenAI, AzureOpenAI]):
def __init__(self, client: Union[OpenAI, AzureOpenAI], response_format: Optional[BaseModel] = None):
self._oai_client = client
self.response_format = response_format
if (
not isinstance(client, openai.AzureOpenAI)
and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX)
Expand All @@ -287,22 +257,29 @@ def message_retrieval(
if isinstance(response, Completion):
return [choice.text for choice in choices] # type: ignore [union-attr]

def _format_content(content: str) -> str:
return (
self.response_format.model_validate_json(content).format()
if isinstance(self.response_format, FormatterProtocol)
else content
)

if TOOL_ENABLED:
return [ # type: ignore [return-value]
(
choice.message # type: ignore [union-attr]
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
else choice.message.content
else _format_content(choice.message.content)
) # type: ignore [union-attr]
for choice in choices
]
else:
return [ # type: ignore [return-value]
choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr]
choice.message if choice.message.function_call is not None else _format_content(choice.message.content) # type: ignore [union-attr]
for choice in choices
]

def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
def create(self, params: Dict[str, Any]) -> ChatCompletion:
"""Create a completion for a given config using openai's client.
Args:
Expand All @@ -314,15 +291,13 @@ def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] =
"""
iostream = IOStream.get_default()

if response_format is not None:
if self.response_format is not None:

def _create_or_parse(*args, **kwargs):
if "stream" in kwargs:
kwargs.pop("stream")
kwargs["response_format"] = response_format
return OpenAIClient._convert_to_chat_completion(
self._oai_client.beta.chat.completions.parse(*args, **kwargs)
)
kwargs["response_format"] = type_to_response_format_param(self.response_format)
return self._oai_client.chat.completions.create(*args, **kwargs)

create_or_parse = _create_or_parse
else:
Expand Down Expand Up @@ -480,6 +455,11 @@ def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
}


@runtime_checkable
class FormatterProtocol(Protocol):
def format(self) -> str: ...


class OpenAIWrapper:
"""A wrapper class for openai client."""

Expand All @@ -502,7 +482,12 @@ class OpenAIWrapper:
total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None

def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
def __init__(
self,
*,
config_list: Optional[List[Dict[str, Any]]] = None,
**base_config: Any,
):
"""
Args:
config_list: a list of config dicts to override the base_config.
Expand Down Expand Up @@ -605,6 +590,7 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
api_type = config.get("api_type")
model_client_cls_name = config.get("model_client_cls")
response_format = config.get("response_format")
if model_client_cls_name is not None:
# a config for a custom client is set
# adding placeholder until the register_model_client is called with the appropriate class
Expand All @@ -617,58 +603,58 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
if api_type is not None and api_type.startswith("azure"):
self._configure_azure_openai(config, openai_config)
client = AzureOpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
self._clients.append(OpenAIClient(client, response_format=response_format))
elif api_type is not None and api_type.startswith("cerebras"):
if cerebras_import_exception:
raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.")
client = CerebrasClient(**openai_config)
client = CerebrasClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("google"):
if gemini_import_exception:
raise ImportError("Please install `google-generativeai` and 'vertexai' to use Google's API.")
client = GeminiClient(**openai_config)
client = GeminiClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("anthropic"):
if "api_key" not in config:
self._configure_openai_config_for_bedrock(config, openai_config)
if anthropic_import_exception:
raise ImportError("Please install `anthropic` to use Anthropic API.")
client = AnthropicClient(**openai_config)
client = AnthropicClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("mistral"):
if mistral_import_exception:
raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
client = MistralAIClient(**openai_config)
client = MistralAIClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("together"):
if together_import_exception:
raise ImportError("Please install `together` to use the Together.AI API.")
client = TogetherClient(**openai_config)
client = TogetherClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("groq"):
if groq_import_exception:
raise ImportError("Please install `groq` to use the Groq API.")
client = GroqClient(**openai_config)
client = GroqClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("cohere"):
if cohere_import_exception:
raise ImportError("Please install `cohere` to use the Cohere API.")
client = CohereClient(**openai_config)
client = CohereClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("ollama"):
if ollama_import_exception:
raise ImportError("Please install `ollama` and `fix-busted-json` to use the Ollama API.")
client = OllamaClient(**openai_config)
client = OllamaClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("bedrock"):
self._configure_openai_config_for_bedrock(config, openai_config)
if bedrock_import_exception:
raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
client = BedrockClient(**openai_config)
client = BedrockClient(response_format=response_format, **openai_config)
self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
self._clients.append(OpenAIClient(client, response_format))

if logging_enabled():
log_new_client(client, self, openai_config)
Expand Down Expand Up @@ -747,9 +733,7 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
]
return params

def create(
self, response_format: Optional[BaseModel] = None, **config: Any
) -> ModelClient.ModelClientResponseProtocol:
def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
"""Make a completion for a given config using available clients.
Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
The config in each client will be overridden by the config.
Expand Down Expand Up @@ -838,8 +822,8 @@ def yes_or_no_filter(context, response):
with cache_client as cache:
# Try to get the response from cache
key = get_key(
{**params, **{"response_format": schema_json_of(response_format)}}
if response_format
{**params, **{"response_format": schema_json_of(params["response_format"])}}
if "response_format" in params
else params
)
request_ts = get_current_ts()
Expand Down Expand Up @@ -882,7 +866,7 @@ def yes_or_no_filter(context, response):
continue # filter is not passed; try the next config
try:
request_ts = get_current_ts()
response = client.create(params, response_format=response_format)
response = client.create(params)
except APITimeoutError as err:
logger.debug(f"config {i} timed out", exc_info=True)
if i == last:
Expand Down Expand Up @@ -944,6 +928,7 @@ def yes_or_no_filter(context, response):
actual_usage = client.get_usage(response)
total_usage = actual_usage.copy() if actual_usage is not None else total_usage
self._update_usage(actual_usage=actual_usage, total_usage=total_usage)

if cache_client is not None:
# Cache the response
with cache_client as cache:
Expand Down
8 changes: 4 additions & 4 deletions autogen/oai/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def __init__(self, **kwargs):
self.api_key
), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Cohere, it will be ignored.", UserWarning)

def message_retrieval(self, response) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
Expand Down Expand Up @@ -148,10 +151,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cohere_params

def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Cohere's API.")

def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters
Expand Down
Loading

0 comments on commit 7736885

Please sign in to comment.