Merge pull request #112 from ag2ai/23-feature-request-structured-output
Add structured output
qingyun-wu authored Dec 3, 2024
2 parents 95431af + 6534772 commit 321e833
Showing 14 changed files with 364 additions and 32 deletions.
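
For orientation before the file-by-file diff: this PR lets an agent request schema-constrained replies by passing a Pydantic model as response_format when the agent is constructed. A minimal usage sketch under assumptions not shown in this commit (the MathAnswer schema, the model name, and the llm_config values are illustrative, and only backends with structured-output support honour the setting):

from pydantic import BaseModel

from autogen import ConversableAgent


class MathAnswer(BaseModel):
    # Illustrative schema; any Pydantic model class should work the same way.
    steps: list[str]
    final_answer: str


# Placeholder config; structured output needs a backend that supports it (OpenAI here).
llm_config = {"config_list": [{"model": "gpt-4o-2024-08-06", "api_key": "YOUR_KEY"}]}

assistant = ConversableAgent(
    name="math_tutor",
    llm_config=llm_config,
    response_format=MathAnswer,  # new keyword introduced by this PR
)

user = ConversableAgent(name="user", llm_config=False, human_input_mode="NEVER")
result = user.initiate_chat(assistant, message="What is 2 + 2?", max_turns=1)
# The assistant's reply is constrained to the MathAnswer JSON schema.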
11 changes: 9 additions & 2 deletions autogen/agentchat/conversable_agent.py
@@ -20,7 +20,7 @@
from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired

-from .._pydantic import model_dump
+from .._pydantic import BaseModel, model_dump
from ..cache.cache import AbstractCache
from ..code_utils import (
PYTHON_VARIANTS,
@@ -85,6 +85,7 @@ def __init__(
description: Optional[str] = None,
chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
silent: Optional[bool] = None,
+response_format: Optional[BaseModel] = None,
):
"""
Args:
@@ -135,6 +136,7 @@ def __init__(
resume previous had conversations. Defaults to an empty chat history.
silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of
silent in each function.
+response_format(BaseModel): Used to specify structured response format for the agent. Not available for all LLMs.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
@@ -157,6 +159,7 @@ def __init__(
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
self.silent = silent
+self._response_format = response_format
# Take a copy to avoid modifying the given dict
if isinstance(llm_config, dict):
try:
@@ -1445,7 +1448,11 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[

# TODO: #1143 handle token limit exceeded error
response = llm_client.create(
-context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self
+context=messages[-1].pop("context", None),
+messages=all_messages,
+cache=cache,
+agent=self,
+response_format=self._response_format,
)
extracted_response = llm_client.extract_text_or_completion_object(response)[0]

8 changes: 6 additions & 2 deletions autogen/oai/anthropic.py
@@ -53,14 +53,15 @@
import os
import time
import warnings
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

from anthropic import Anthropic, AnthropicBedrock
from anthropic import __version__ as anthropic_version
from anthropic.types import Completion, Message, TextBlock, ToolUseBlock
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
+from pydantic import BaseModel
from typing_extensions import Annotated

from autogen.oai.client_utils import validate_parameter
@@ -174,7 +175,10 @@ def aws_session_token(self):
def aws_region(self):
return self._aws_region

-def create(self, params: Dict[str, Any]) -> Completion:
+def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
+if response_format is not None:
+raise NotImplementedError("Response format is not supported by Anthropic API.")

if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
params["functions"] = params.get("functions", []) + converted_functions
8 changes: 6 additions & 2 deletions autogen/oai/bedrock.py
@@ -36,14 +36,15 @@
import re
import time
import warnings
-from typing import Any, Dict, List, Literal, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple

import boto3
import requests
from botocore.config import Config
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
+from pydantic import BaseModel

from autogen.oai.client_utils import validate_parameter

@@ -178,9 +179,12 @@ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str

return base_params, additional_params

-def create(self, params):
+def create(self, params, response_format: Optional[BaseModel] = None) -> ChatCompletion:
"""Run Amazon Bedrock inference and return AutoGen response"""

+if response_format is not None:
+raise NotImplementedError("Response format is not supported by Amazon Bedrock's API.")

# Set custom client class settings
self.parse_custom_params(params)

7 changes: 5 additions & 2 deletions autogen/oai/cerebras.py
@@ -29,12 +29,13 @@
import os
import time
import warnings
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

from cerebras.cloud.sdk import Cerebras, Stream
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
+from pydantic import BaseModel

from autogen.oai.client_utils import should_hide_tools, validate_parameter

@@ -111,7 +112,9 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cerebras_params

-def create(self, params: Dict) -> ChatCompletion:
+def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
+if response_format is not None:
+raise NotImplementedError("Response format is not supported by Cerebras' API.")

messages = params.get("messages", [])

44 changes: 34 additions & 10 deletions autogen/oai/client.py
@@ -10,9 +10,10 @@
import logging
import sys
import uuid
+from pickle import PickleError, PicklingError
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union

-from pydantic import BaseModel
+from pydantic import BaseModel, schema_json_of

from autogen.cache import Cache
from autogen.io.base import IOStream
@@ -206,7 +207,9 @@ class Message(Protocol):
choices: List[Choice]
model: str

-def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover
+def create(
+self, params: Dict[str, Any], response_format: Optional[BaseModel] = None
+) -> ModelClientResponseProtocol: ... # pragma: no cover

def message_retrieval(
self, response: ModelClientResponseProtocol
@@ -269,7 +272,7 @@ def message_retrieval(
for choice in choices
]

-def create(self, params: Dict[str, Any]) -> ChatCompletion:
+def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
"""Create a completion for a given config using openai's client.
Args:
@@ -281,7 +284,19 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
"""
iostream = IOStream.get_default()

-completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
+if response_format is not None:

+def _create_or_parse(*args, **kwargs):
+if "stream" in kwargs:
+kwargs.pop("stream")
+kwargs["response_format"] = response_format
+return self._oai_client.beta.chat.completions.parse(*args, **kwargs)

+create_or_parse = _create_or_parse
+else:
+completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
+create_or_parse = completions.create

# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
response_contents = [""] * params.get("n", 1)
@@ -296,7 +311,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None

# Send the chat completion request to OpenAI's API and process the response in chunks
-for chunk in completions.create(**params):
+for chunk in create_or_parse(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
@@ -398,7 +413,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
# If streaming is not enabled, send a regular chat completion request
params = params.copy()
params["stream"] = False
-response = completions.create(**params)
+response = create_or_parse(**params)

return response

@@ -700,7 +715,9 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
]
return params

-def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
+def create(
+self, response_format: Optional[BaseModel] = None, **config: Any
+) -> ModelClient.ModelClientResponseProtocol:
"""Make a completion for a given config using available clients.
Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
The config in each client will be overridden by the config.
@@ -788,7 +805,11 @@ def yes_or_no_filter(context, response):
if cache_client is not None:
with cache_client as cache:
# Try to get the response from cache
-key = get_key(params)
+key = get_key(
+{**params, **{"response_format": schema_json_of(response_format)}}
+if response_format
+else params
+)
request_ts = get_current_ts()

response: ModelClient.ModelClientResponseProtocol = cache.get(key, None)
@@ -829,7 +850,7 @@ def yes_or_no_filter(context, response):
continue # filter is not passed; try the next config
try:
request_ts = get_current_ts()
-response = client.create(params)
+response = client.create(params, response_format=response_format)
except APITimeoutError as err:
logger.debug(f"config {i} timed out", exc_info=True)
if i == last:
@@ -894,7 +915,10 @@ def yes_or_no_filter(context, response):
if cache_client is not None:
# Cache the response
with cache_client as cache:
-cache.set(key, response)
+try:
+cache.set(key, response)
+except (PicklingError, AttributeError) as e:
+logger.info(f"Failed to cache response: {e}")

if logging_enabled():
# TODO: log the config_id and pass_filter etc.
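One detail in the caching change above: the cache key now folds in the JSON schema of the requested response format, so the same prompt asked with two different schemas does not collide in the cache. A small sketch of the idea using the schema_json_of helper imported at the top of this file (the Answer model and the params dict are illustrative):

from pydantic import BaseModel, schema_json_of


class Answer(BaseModel):
    # Illustrative response format.
    text: str
    confidence: float


params = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}

# Serializing the schema to a string makes it hashable key material; two different
# BaseModel subclasses yield different strings and therefore different cache keys.
key_material = {**params, "response_format": schema_json_of(Answer)}
print(key_material["response_format"][:60], "...")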
7 changes: 5 additions & 2 deletions autogen/oai/cohere.py
@@ -33,13 +33,14 @@
import sys
import time
import warnings
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

from cohere import Client as Cohere
from cohere.types import ToolParameterDefinitionsValue, ToolResult
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
+from pydantic import BaseModel

from autogen.oai.client_utils import logging_formatter, validate_parameter

@@ -147,7 +148,9 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cohere_params

-def create(self, params: Dict) -> ChatCompletion:
+def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
+if response_format is not None:
+raise NotImplementedError("Response format is not supported by Cohere's API.")

messages = params.get("messages", [])
client_name = params.get("client_name") or "autogen-cohere"
8 changes: 6 additions & 2 deletions autogen/oai/gemini.py
@@ -45,7 +45,7 @@
import time
import warnings
from io import BytesIO
-from typing import Any, Dict, List, Mapping, Union
+from typing import Any, Dict, List, Mapping, Optional, Union

import google.generativeai as genai
import requests
@@ -56,6 +56,7 @@
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
+from pydantic import BaseModel
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
@@ -159,7 +160,10 @@ def get_usage(response) -> Dict:
"model": response.model,
}

-def create(self, params: Dict) -> ChatCompletion:
+def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
+if response_format is not None:
+raise NotImplementedError("Response format is not supported by Gemini's API.")

if self.use_vertexai:
self._initialize_vertexai(**params)
else:
7 changes: 5 additions & 2 deletions autogen/oai/groq.py
@@ -29,12 +29,13 @@
import os
import time
import warnings
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

from groq import Groq, Stream
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
+from pydantic import BaseModel

from autogen.oai.client_utils import should_hide_tools, validate_parameter

@@ -125,7 +126,9 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return groq_params

-def create(self, params: Dict) -> ChatCompletion:
+def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
+if response_format is not None:
+raise NotImplementedError("Response format is not supported by Groq's API.")

messages = params.get("messages", [])

8 changes: 6 additions & 2 deletions autogen/oai/mistral.py
@@ -30,7 +30,7 @@
import os
import time
import warnings
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

# Mistral libraries
# pip install mistralai
@@ -47,6 +47,7 @@
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
+from pydantic import BaseModel

from autogen.oai.client_utils import should_hide_tools, validate_parameter

@@ -169,7 +170,10 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return mistral_params

-def create(self, params: Dict[str, Any]) -> ChatCompletion:
+def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
+if response_format is not None:
+raise NotImplementedError("Response format is not supported by Mistral's API.")

# 1. Parse parameters to Mistral.AI API's parameters
mistral_params = self.parse_params(params)

7 changes: 5 additions & 2 deletions autogen/oai/ollama.py
@@ -31,14 +31,15 @@
import re
import time
import warnings
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

import ollama
from fix_busted_json import repair_json
from ollama import Client
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
+from pydantic import BaseModel

from autogen.oai.client_utils import should_hide_tools, validate_parameter

@@ -177,7 +178,9 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return ollama_params

-def create(self, params: Dict) -> ChatCompletion:
+def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
+if response_format is not None:
+raise NotImplementedError("Response format is not supported by Ollama's API.")

messages = params.get("messages", [])

