Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Structured outputs for Anthropic / Gemini / Ollama #336

Merged
merged 12 commits into from
Jan 15, 2025
154 changes: 120 additions & 34 deletions autogen/oai/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
Expand Down Expand Up @@ -73,18 +73,20 @@
import inspect
import json
import os
import re
import time
import warnings
from typing import Any
from typing import Any, Optional, Type

from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
from anthropic import __version__ as anthropic_version
from anthropic.types import TextBlock, ToolUseBlock
from anthropic.types import Message, TextBlock, ToolUseBlock
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel

from autogen.oai.client_utils import validate_parameter
from autogen.oai.client_utils import FormatterProtocol, validate_parameter

TOOL_ENABLED = anthropic_version >= "0.23.1"
if TOOL_ENABLED:
Expand Down Expand Up @@ -145,9 +147,6 @@ def __init__(self, **kwargs: Any):
else:
raise ValueError("API key or AWS credentials or GCP credentials are required to use the Anthropic API.")

if "response_format" in kwargs and kwargs["response_format"] is not None:
warnings.warn("response_format is not supported for Anthropic, it will be ignored.", UserWarning)

if self._api_key is not None:
self._client = Anthropic(api_key=self._api_key)
elif self._gcp_region is not None:
Expand All @@ -166,6 +165,9 @@ def __init__(self, **kwargs: Any):

self._last_tooluse_status = {}

# Store the response format, if provided (for structured outputs)
self._response_format: Optional[Type[BaseModel]] = None

def load_config(self, params: dict[str, Any]):
"""Load the configuration for the Anthropic API client."""
anthropic_params = {}
Expand Down Expand Up @@ -228,14 +230,22 @@ def gcp_auth_token(self):
return self._gcp_auth_token

def create(self, params: dict[str, Any]) -> ChatCompletion:
"""Creates a completion using the Anthropic API."""
if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
params["functions"] = params.get("functions", []) + converted_functions

# Convert AutoGen messages to Anthropic messages
# Convert AG2 messages to Anthropic messages
anthropic_messages = oai_messages_to_anthropic_messages(params)
anthropic_params = self.load_config(params)

# If response_format exists, we want structured outputs
# Anthropic doesn't support response_format, so using Anthropic's "JSON Mode":
# https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
if params.get("response_format"):
self._response_format = params["response_format"]
self._add_response_format_to_system(params)

# TODO: support stream
params = params.copy()
if "functions" in params:
Expand All @@ -260,36 +270,46 @@ def create(self, params: dict[str, Any]) -> ChatCompletion:

response = self._client.messages.create(**anthropic_params)

# Calculate and save the cost onto the response
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens

tool_calls = []
message_text = ""
if response is not None:
# If we have tool use as the response, populate completed tool calls for our return OAI response
if response.stop_reason == "tool_use":
anthropic_finish = "tool_calls"
tool_calls = []
for content in response.content:
if type(content) == ToolUseBlock:
tool_calls.append(
ChatCompletionMessageToolCall(
id=content.id,
function={"name": content.name, "arguments": json.dumps(content.input)},
type="function",

if self._response_format:
try:
parsed_response = self._extract_json_response(response)
message_text = _format_json_response(parsed_response)
except ValueError as e:
message_text = str(e)

anthropic_finish = "stop"
else:
if response is not None:
# If we have tool use as the response, populate completed tool calls for our return OAI response
if response.stop_reason == "tool_use":
anthropic_finish = "tool_calls"
for content in response.content:
if type(content) == ToolUseBlock:
tool_calls.append(
ChatCompletionMessageToolCall(
id=content.id,
function={"name": content.name, "arguments": json.dumps(content.input)},
type="function",
)
)
)
else:
anthropic_finish = "stop"
tool_calls = None
else:
anthropic_finish = "stop"
tool_calls = None

# Retrieve any text content from the response
for content in response.content:
if type(content) == TextBlock:
message_text = content.text
break

# Retrieve any text content from the response
for content in response.content:
if type(content) == TextBlock:
message_text = content.text
break
# Calculate and save the cost onto the response
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens

# Convert output back to AutoGen response format
# Convert output back to AG2 response format
message = ChatCompletionMessage(
role="assistant",
content=message_text,
Expand Down Expand Up @@ -348,6 +368,72 @@ def convert_tools_to_functions(tools: list) -> list:

return functions

def _add_response_format_to_system(self, params: dict[str, Any]):
"""Add prompt that will generate properly formatted JSON for structured outputs to system parameter.

Based on Anthropic's JSON Mode cookbook, we ask the LLM to put the JSON within <json_response> tags.

Args:
params (dict): The client parameters
"""
if not params.get("system"):
return

# Get the schema of the Pydantic model
schema = self._response_format.model_json_schema()

# Add instructions for JSON formatting
format_content = f"""Please provide your response as a JSON object that matches the following schema:
{json.dumps(schema, indent=2)}

Format your response as valid JSON within <json_response> tags.
Do not include any text before or after the tags.
Ensure the JSON is properly formatted and matches the schema exactly."""

# Add formatting to last user message
params["system"] += "\n\n" + format_content

def _extract_json_response(self, response: Message) -> Any:
"""Extract and validate JSON response from the output for structured outputs.

Args:
response (Message): The response from the API.

Returns:
Any: The parsed JSON response.
"""
if not self._response_format:
return response

# Extract content from response
content = response.content[0].text if response.content else ""

# Try to extract JSON from tags first
json_match = re.search(r"<json_response>(.*?)</json_response>", content, re.DOTALL)
if json_match:
json_str = json_match.group(1).strip()
else:
# Fallback to finding first JSON object
json_start = content.find("{")
json_end = content.rfind("}")
if json_start == -1 or json_end == -1:
raise ValueError("No valid JSON found in response for Structured Output.")
json_str = content[json_start : json_end + 1]

try:
# Parse JSON and validate against the Pydantic model
json_data = json.loads(json_str)
return self._response_format.model_validate(json_data)
except Exception as e:
raise ValueError(
f"Failed to parse response as valid JSON matching the schema for Structured Output: {str(e)}"
)


def _format_json_response(response: Any) -> str:
    """Render a structured-output result as a string.

    When the object implements ``FormatterProtocol`` (i.e. exposes a
    ``format`` method), its custom rendering is used; otherwise the value
    is returned unchanged.
    """
    if isinstance(response, FormatterProtocol):
        return response.format()
    return response


def oai_messages_to_anthropic_messages(params: dict[str, Any]) -> list[dict[str, Any]]:
"""Convert messages from OAI format to Anthropic format.
Expand Down
11 changes: 3 additions & 8 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
Expand All @@ -10,14 +10,14 @@
import logging
import sys
import uuid
from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable
from typing import Any, Callable, Optional, Protocol, Union

from pydantic import BaseModel, schema_json_of

from autogen.cache import Cache
from autogen.io.base import IOStream
from autogen.logger.logger_utils import get_current_ts
from autogen.oai.client_utils import logging_formatter
from autogen.oai.client_utils import FormatterProtocol, logging_formatter
from autogen.oai.openai_utils import OAI_PRICE1K, get_key, is_valid_api_key
from autogen.runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
from autogen.token_count_utils import count_token
Expand Down Expand Up @@ -446,11 +446,6 @@ def get_usage(response: Union[ChatCompletion, Completion]) -> dict:
}


@runtime_checkable
class FormatterProtocol(Protocol):
def format(self) -> str: ...


class OpenAIWrapper:
"""A wrapper class for openai client."""

Expand Down
11 changes: 9 additions & 2 deletions autogen/oai/client_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
Expand All @@ -8,7 +8,14 @@

import logging
import warnings
from typing import Any
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class FormatterProtocol(Protocol):
    """Structural protocol for structured-output models that provide a
    custom string rendering via a ``format`` method.

    ``@runtime_checkable`` allows ``isinstance`` checks against this
    protocol, so clients can duck-type-detect formattable responses.
    """

    def format(self) -> str: ...


def validate_parameter(
Expand Down
Loading
Loading