
Commit 52973fa

Fix caching issues with format, fix function registration issues when using response_format

Signed-off-by: Sternakt <[email protected]>
sternakt committed Dec 4, 2024
1 parent b22d6cf commit 52973fa
Showing 14 changed files with 129 additions and 117 deletions.
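Taken together, the change moves response_format from a per-call argument of create to a constructor argument of OpenAIWrapper and of each provider client, so the stored format survives client re-creation and participates in cache keys. A minimal usage sketch of the new construction path, assuming a valid OpenAI config (the model name and API key below are placeholders):

from pydantic import BaseModel

from autogen import OpenAIWrapper


class Answer(BaseModel):
    reasoning: str
    answer: str


# response_format is now bound once, at wrapper construction time.
wrapper = OpenAIWrapper(
    config_list=[{"model": "gpt-4o-mini", "api_key": "sk-PLACEHOLDER"}],
    response_format=Answer,
)
response = wrapper.create(messages=[{"role": "user", "content": "What is 2 + 2?"}])
print(wrapper.extract_text_or_completion_object(response))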
11 changes: 7 additions & 4 deletions autogen/agentchat/conversable_agent.py
@@ -274,7 +274,11 @@ def _validate_llm_config(self, llm_config):
raise ValueError(
"When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
)
-        self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config)
+        self.client = (
+            None
+            if self.llm_config is False
+            else OpenAIWrapper(response_format=self._response_format, **self.llm_config)
+        )

@staticmethod
def _is_silent(agent: Agent, silent: Optional[bool] = False) -> bool:
@@ -1452,7 +1456,6 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
messages=all_messages,
cache=cache,
agent=self,
-            response_format=self._response_format,
)
extracted_response = llm_client.extract_text_or_completion_object(response)[0]

@@ -2549,7 +2552,7 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)
if len(self.llm_config["functions"]) == 0:
del self.llm_config["functions"]

-        self.client = OpenAIWrapper(**self.llm_config)
+        self.client = OpenAIWrapper(**self.llm_config, response_format=self._response_format)

def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
"""update a tool_signature in the LLM configuration for tool_call.
@@ -2593,7 +2596,7 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
if len(self.llm_config["tools"]) == 0:
del self.llm_config["tools"]

-        self.client = OpenAIWrapper(**self.llm_config)
+        self.client = OpenAIWrapper(**self.llm_config, response_format=self._response_format)

def can_execute_function(self, name: Union[List[str], str]) -> bool:
"""Whether the agent can execute the function."""
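The two update_*_signature hunks are the function-registration fix from the commit title: registering or removing a tool rebuilds self.client, and before this change the rebuilt OpenAIWrapper silently dropped the agent's stored response format. A sketch of the flow this protects, assuming ConversableAgent exposes a response_format constructor argument that populates self._response_format (config and tool below are placeholders):

from pydantic import BaseModel

from autogen import ConversableAgent


class Answer(BaseModel):
    answer: str


agent = ConversableAgent(
    name="assistant",
    llm_config={"config_list": [{"model": "gpt-4o-mini", "api_key": "sk-PLACEHOLDER"}]},
    response_format=Answer,  # assumed constructor argument, per self._response_format above
)


@agent.register_for_llm(description="Add two integers.")
def add(a: int, b: int) -> int:
    return a + b


# register_for_llm calls update_tool_signature, which now re-creates
# OpenAIWrapper with response_format=self._response_format intact.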
8 changes: 4 additions & 4 deletions autogen/oai/anthropic.py
@@ -114,6 +114,9 @@ def __init__(self, **kwargs: Any):
):
raise ValueError("API key or AWS credentials are required to use the Anthropic API.")

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Anthropic, it will be ignored.", UserWarning)

if self._api_key is not None:
self._client = Anthropic(api_key=self._api_key)
else:
@@ -175,10 +178,7 @@ def aws_session_token(self):
def aws_region(self):
return self._aws_region

-    def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Anthropic API.")
-
+    def create(self, params: Dict[str, Any]) -> ChatCompletion:
if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
params["functions"] = params.get("functions", []) + converted_functions
9 changes: 4 additions & 5 deletions autogen/oai/bedrock.py
@@ -93,6 +93,9 @@ def __init__(self, **kwargs: Any):
profile_name=self._aws_profile_name,
)

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Bedrock, it will be ignored.", UserWarning)

self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)

def message_retrieval(self, response):
@@ -179,12 +182,8 @@ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str

return base_params, additional_params

-    def create(self, params, response_format: Optional[BaseModel] = None) -> ChatCompletion:
+    def create(self, params) -> ChatCompletion:
"""Run Amazon Bedrock inference and return AutoGen response"""

-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Amazon Bedrock's API.")

# Set custom client class settings
self.parse_custom_params(params)

14 changes: 9 additions & 5 deletions autogen/oai/cerebras.py
@@ -49,13 +49,17 @@
class CerebrasClient:
"""Client for Cerebras's API."""

-    def __init__(self, api_key=None, **kwargs):
+    def __init__(self, api_key=None, response_format: Optional[BaseModel] = None, **kwargs):
"""Requires api_key or environment variable to be set
Args:
api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set)
"""
# Ensure we have the api_key upon instantiation

+        if response_format is not None:
+            raise NotImplementedError("Response format is not supported by Cerebras' API.")

self.api_key = api_key
if not self.api_key:
self.api_key = os.getenv("CEREBRAS_API_KEY")
@@ -64,6 +68,9 @@ def __init__(self, api_key=None, **kwargs):
self.api_key
), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Crebras, it will be ignored.", UserWarning)

def message_retrieval(self, response: ChatCompletion) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -112,10 +119,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cerebras_params

-    def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Cerebras' API.")
-
+    def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])

# Convert AutoGen messages to Cerebras messages
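For providers without structured-output support, the commit also standardizes how response_format is rejected: Anthropic, Bedrock, and Cohere now emit a UserWarning at construction and ignore the value, while Cerebras still raises NotImplementedError from its constructor, as the hunk above shows. A small sketch of the warning behavior (the Anthropic key is a placeholder, and the anthropic package must be installed):

import warnings

from pydantic import BaseModel

from autogen.oai.anthropic import AnthropicClient


class Answer(BaseModel):
    answer: str


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    client = AnthropicClient(api_key="sk-ant-PLACEHOLDER", response_format=Answer)

assert any("response_format is not supported" in str(w.message) for w in caught)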
106 changes: 42 additions & 64 deletions autogen/oai/client.py
@@ -33,6 +33,7 @@
# raises exception if openai>=1 is installed and something is wrong with imports
from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI
from openai import __version__ as OPENAIVERSION
+from openai.lib._parsing._completions import type_to_response_format_param
from openai.resources import Completions
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
@@ -207,9 +208,7 @@ class Message(Protocol):
choices: List[Choice]
model: str

-    def create(
-        self, params: Dict[str, Any], response_format: Optional[BaseModel] = None
-    ) -> ModelClientResponseProtocol: ...  # pragma: no cover
+    def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ...  # pragma: no cover

def message_retrieval(
self, response: ModelClientResponseProtocol
@@ -238,42 +237,9 @@ def __init__(self, config):
class OpenAIClient:
"""Follows the Client protocol and wraps the OpenAI client."""

-    @staticmethod
-    def _convert_to_chat_completion(parsed: ParsedChatCompletion, response_format: BaseModel) -> ChatCompletion:
-        # Helper function to convert ParsedChatCompletionMessage to ChatCompletionMessage
-        def convert_message(parsed_message: ParsedChatCompletionMessage) -> ChatCompletionMessage:
-            return ChatCompletionMessage(
-                role=parsed_message.role,
-                content=(
-                    response_format.model_validate_json(parsed_message.content).format()
-                    if isinstance(response_format, FormatterProtocol)
-                    else parsed_message.content
-                ),
-                function_call=parsed_message.function_call,
-            )
-
-        # Convert ParsedChatCompletion to ChatCompletion
-        return ChatCompletion(
-            id=parsed.id,
-            choices=[
-                Choice(
-                    finish_reason=choice.finish_reason,
-                    index=choice.index,
-                    logprobs=choice.logprobs,
-                    message=convert_message(choice.message),  # Parse the message
-                )
-                for choice in parsed.choices
-            ],
-            created=parsed.created,
-            model=parsed.model,
-            object=parsed.object,
-            service_tier=parsed.service_tier,
-            system_fingerprint=parsed.system_fingerprint,
-            usage=parsed.usage,
-        )
-
-    def __init__(self, client: Union[OpenAI, AzureOpenAI]):
+    def __init__(self, client: Union[OpenAI, AzureOpenAI], response_format: Optional[BaseModel] = None):
        self._oai_client = client
+        self.response_format = response_format
if (
not isinstance(client, openai.AzureOpenAI)
and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX)
@@ -291,22 +257,29 @@ def message_retrieval(
if isinstance(response, Completion):
return [choice.text for choice in choices] # type: ignore [union-attr]

+        def _format_content(content: str) -> str:
+            return (
+                self.response_format.model_validate_json(content).format()
+                if isinstance(self.response_format, FormatterProtocol)
+                else content
+            )

if TOOL_ENABLED:
return [ # type: ignore [return-value]
(
choice.message # type: ignore [union-attr]
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
-                    else choice.message.content
+                    else _format_content(choice.message.content)
) # type: ignore [union-attr]
for choice in choices
]
else:
return [ # type: ignore [return-value]
-                choice.message if choice.message.function_call is not None else choice.message.content  # type: ignore [union-attr]
+                choice.message if choice.message.function_call is not None else _format_content(choice.message.content)  # type: ignore [union-attr]
for choice in choices
]

-    def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
+    def create(self, params: Dict[str, Any]) -> ChatCompletion:
"""Create a completion for a given config using openai's client.
Args:
Expand All @@ -318,15 +291,13 @@ def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] =
"""
iostream = IOStream.get_default()

-        if response_format is not None:
+        if self.response_format is not None:

def _create_or_parse(*args, **kwargs):
if "stream" in kwargs:
kwargs.pop("stream")
kwargs["response_format"] = response_format
return OpenAIClient._convert_to_chat_completion(
self._oai_client.beta.chat.completions.parse(*args, **kwargs), response_format
)
kwargs["response_format"] = type_to_response_format_param(self.response_format)
return self._oai_client.chat.completions.create(*args, **kwargs)

create_or_parse = _create_or_parse
else:
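Two pieces work together in the rewritten client. type_to_response_format_param (an openai-python helper) turns the stored Pydantic class into the plain JSON-schema response_format payload, so the ordinary chat.completions.create endpoint is used and the raw ChatCompletion can be cached; _format_content then pretty-prints validated content whenever the model class implements format() (the FormatterProtocol check above). A self-contained sketch of a formatter-aware model, using only pydantic:

from pydantic import BaseModel


class Verdict(BaseModel):
    verdict: str
    confidence: float

    # Defining format() is what makes the FormatterProtocol isinstance
    # check above succeed, so message_retrieval returns this string
    # instead of the raw JSON content.
    def format(self) -> str:
        return f"{self.verdict} ({self.confidence:.0%} confident)"


raw = '{"verdict": "pass", "confidence": 0.9}'
print(Verdict.model_validate_json(raw).format())  # pass (90% confident)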
@@ -451,6 +422,8 @@ def _create_or_parse(*args, **kwargs):
params["stream"] = False
response = create_or_parse(**params)

print("!" * 100)
print(response)
return response

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
@@ -511,7 +484,13 @@ class OpenAIWrapper:
total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None

-    def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
+    def __init__(
+        self,
+        *,
+        config_list: Optional[List[Dict[str, Any]]] = None,
+        response_format: Optional[BaseModel] = None,
+        **base_config: Any,
+    ):
"""
Args:
config_list: a list of config dicts to override the base_config.
@@ -552,6 +531,7 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base

self._clients: List[ModelClient] = []
self._config_list: List[Dict[str, Any]] = []
+        self.response_format = response_format

if config_list:
config_list = [config.copy() for config in config_list] # make a copy before modifying
@@ -626,58 +606,58 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
if api_type is not None and api_type.startswith("azure"):
self._configure_azure_openai(config, openai_config)
client = AzureOpenAI(**openai_config)
-            self._clients.append(OpenAIClient(client))
+            self._clients.append(OpenAIClient(client, response_format=self.response_format))
elif api_type is not None and api_type.startswith("cerebras"):
if cerebras_import_exception:
raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.")
-            client = CerebrasClient(**openai_config)
+            client = CerebrasClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("google"):
if gemini_import_exception:
raise ImportError("Please install `google-generativeai` and 'vertexai' to use Google's API.")
-            client = GeminiClient(**openai_config)
+            client = GeminiClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("anthropic"):
if "api_key" not in config:
self._configure_openai_config_for_bedrock(config, openai_config)
if anthropic_import_exception:
raise ImportError("Please install `anthropic` to use Anthropic API.")
-            client = AnthropicClient(**openai_config)
+            client = AnthropicClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("mistral"):
if mistral_import_exception:
raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
-            client = MistralAIClient(**openai_config)
+            client = MistralAIClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("together"):
if together_import_exception:
raise ImportError("Please install `together` to use the Together.AI API.")
-            client = TogetherClient(**openai_config)
+            client = TogetherClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("groq"):
if groq_import_exception:
raise ImportError("Please install `groq` to use the Groq API.")
-            client = GroqClient(**openai_config)
+            client = GroqClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("cohere"):
if cohere_import_exception:
raise ImportError("Please install `cohere` to use the Cohere API.")
-            client = CohereClient(**openai_config)
+            client = CohereClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("ollama"):
if ollama_import_exception:
raise ImportError("Please install `ollama` and `fix-busted-json` to use the Ollama API.")
-            client = OllamaClient(**openai_config)
+            client = OllamaClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("bedrock"):
self._configure_openai_config_for_bedrock(config, openai_config)
if bedrock_import_exception:
raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
-            client = BedrockClient(**openai_config)
+            client = BedrockClient(response_format=self.response_format, **openai_config)
self._clients.append(client)
else:
client = OpenAI(**openai_config)
-            self._clients.append(OpenAIClient(client))
+            self._clients.append(OpenAIClient(client, self.response_format))

if logging_enabled():
log_new_client(client, self, openai_config)
@@ -756,9 +736,7 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
]
return params

-    def create(
-        self, response_format: Optional[BaseModel] = None, **config: Any
-    ) -> ModelClient.ModelClientResponseProtocol:
+    def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
"""Make a completion for a given config using available clients.
Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
The config in each client will be overridden by the config.
@@ -847,8 +825,8 @@ def yes_or_no_filter(context, response):
with cache_client as cache:
# Try to get the response from cache
key = get_key(
{**params, **{"response_format": schema_json_of(response_format)}}
if response_format
{**params, **{"response_format": schema_json_of(self.response_format)}}
if self.response_format
else params
)
request_ts = get_current_ts()
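This is the caching fix from the commit title: the cache key now derives from the wrapper-level self.response_format instead of a per-call argument, so requests made with different schemas, or with none, can no longer collide on one cached entry. Roughly what the key material looks like (a sketch; schema_json_of is the pydantic helper the code above uses):

from pydantic import BaseModel, schema_json_of


class A(BaseModel):
    x: int


class B(BaseModel):
    x: str


params = {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]}
key_a = {**params, "response_format": schema_json_of(A)}
key_b = {**params, "response_format": schema_json_of(B)}
assert key_a != key_b  # different schemas produce different cache keys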
@@ -891,7 +869,7 @@ def yes_or_no_filter(context, response):
continue # filter is not passed; try the next config
try:
request_ts = get_current_ts()
-                    response = client.create(params, response_format=response_format)
+                    response = client.create(params)
except APITimeoutError as err:
logger.debug(f"config {i} timed out", exc_info=True)
if i == last:
8 changes: 4 additions & 4 deletions autogen/oai/cohere.py
@@ -80,6 +80,9 @@ def __init__(self, **kwargs):
self.api_key
), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Cohere, it will be ignored.", UserWarning)

def message_retrieval(self, response) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
@@ -148,10 +151,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cohere_params

-    def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
-        if response_format is not None:
-            raise NotImplementedError("Response format is not supported by Cohere's API.")
-
+    def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters
