Merge remote-tracking branch 'origin/main' into convagentcontextvar
Signed-off-by: Mark Sze <[email protected]>
marklysze committed Dec 3, 2024
2 parents c072335 + 321e833 commit aca0346
Showing 22 changed files with 688 additions and 180 deletions.
13 changes: 12 additions & 1 deletion autogen/agentchat/contrib/swarm_agent.py
@@ -122,6 +122,17 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
if tool_execution._next_agent is not None:
next_agent = tool_execution._next_agent
tool_execution._next_agent = None

# Check for string, access agent from group chat.

if isinstance(next_agent, str):
if next_agent in swarm_agent_names:
next_agent = groupchat.agent_by_name(name=next_agent)
else:
raise ValueError(
f"No agent found with the name '{next_agent}'. Ensure the agent exists in the swarm."
)

return next_agent

# get the last swarm agent
@@ -228,7 +239,7 @@ class SwarmResult(BaseModel):
"""

values: str = ""
agent: Optional["SwarmAgent"] = None
agent: Optional[Union["SwarmAgent", str]] = None
context_variables: Dict[str, Any] = {}

class Config: # Add this inner class
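
The two hunks above work together: SwarmResult.agent may now be an agent name, and swarm_transition resolves that name through groupchat.agent_by_name, raising the ValueError shown above if it is unknown. A minimal sketch of a tool hand-off using this, with an illustrative agent name:

def transfer_to_triage(context_variables: dict) -> SwarmResult:
    # "triage_agent" is an illustrative name; it must match an agent registered in
    # the swarm, or swarm_transition raises the ValueError shown above.
    return SwarmResult(
        values="Escalating to triage.",
        agent="triage_agent",
        context_variables=context_variables,
    )
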
3 changes: 2 additions & 1 deletion autogen/agentchat/contrib/vectordb/chromadb.py
@@ -15,6 +15,7 @@

if chromadb.__version__ < "0.4.15":
raise ImportError("Please upgrade chromadb to version 0.4.15 or later.")
import chromadb.errors
import chromadb.utils.embedding_functions as ef
from chromadb.api.models.Collection import Collection
except ImportError:
@@ -90,7 +91,7 @@ def create_collection(
collection = self.active_collection
else:
collection = self.client.get_collection(collection_name, embedding_function=self.embedding_function)
except ValueError:
except (ValueError, chromadb.errors.ChromaError):
collection = None
if collection is None:
return self.client.create_collection(
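
A rough sketch of the call path this guards (the constructor arguments are assumptions, not taken from this diff): newer chromadb releases raise ChromaError rather than ValueError when get_collection misses, and create_collection now treats both the same way and falls through to creating the collection.

from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB

db = ChromaVectorDB(path="/tmp/chroma_db")  # assumed constructor arguments
# If "docs" does not exist yet, get_collection raises ValueError or ChromaError
# depending on the chromadb version; either way the collection is created.
collection = db.create_collection("docs", overwrite=False, get_or_create=True)
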
11 changes: 9 additions & 2 deletions autogen/agentchat/conversable_agent.py
@@ -20,7 +20,7 @@
from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from .._pydantic import model_dump
from .._pydantic import BaseModel, model_dump
from ..cache.cache import AbstractCache
from ..code_utils import (
PYTHON_VARIANTS,
@@ -86,6 +86,7 @@ def __init__(
chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
silent: Optional[bool] = None,
context_variables: Optional[Dict[str, Any]] = None,
response_format: Optional[BaseModel] = None,
):
"""
Args:
@@ -139,6 +140,7 @@ def __init__(
context_variables (dict or None): Context variables that provide a persistent context for the agent.
The passed in context variables will be deep-copied, not referenced.
Only used in Swarms at this stage.
response_format (BaseModel): Used to specify structured response format for the agent. Currently only available for the OpenAI client.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
@@ -161,6 +163,7 @@ def __init__(
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
self.silent = silent
self._response_format = response_format
# Take a copy to avoid modifying the given dict
if isinstance(llm_config, dict):
try:
@@ -1490,7 +1493,11 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[

# TODO: #1143 handle token limit exceeded error
response = llm_client.create(
context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self
context=messages[-1].pop("context", None),
messages=all_messages,
cache=cache,
agent=self,
response_format=self._response_format,
)
extracted_response = llm_client.extract_text_or_completion_object(response)[0]

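
A minimal sketch of the new constructor argument in use (the model name and llm_config shape are assumptions); per the docstring it is currently only honoured by the OpenAI client:

from pydantic import BaseModel

from autogen import ConversableAgent

class TicketSummary(BaseModel):
    title: str
    priority: str

summarizer = ConversableAgent(
    name="summarizer",
    llm_config={"config_list": [{"model": "gpt-4o", "api_key": "sk-..."}]},
    # Stored as self._response_format and passed to llm_client.create() above.
    response_format=TicketSummary,
)
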
7 changes: 5 additions & 2 deletions autogen/logger/logger_utils.py
@@ -4,15 +4,16 @@
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import datetime
import inspect
from datetime import datetime, timezone
from pathlib import Path, PurePath
from typing import Any, Dict, List, Tuple, Union

__all__ = ("get_current_ts", "to_dict")


def get_current_ts() -> str:
return datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")
return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")


def to_dict(
@@ -22,6 +23,8 @@ def to_dict(
) -> Any:
if isinstance(obj, (int, float, str, bool)):
return obj
elif isinstance(obj, (Path, PurePath)):
return str(obj)
elif callable(obj):
return inspect.getsource(obj).strip()
elif isinstance(obj, dict):
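
Both behaviours in this hunk are easy to spot-check (assuming the import path stays autogen.logger.logger_utils): get_current_ts now uses an aware UTC datetime, and to_dict stringifies Path objects instead of trying to serialise them.

from pathlib import Path

from autogen.logger.logger_utils import get_current_ts, to_dict

print(get_current_ts())                        # e.g. "2024-12-03 01:23:45.678901"
print(to_dict({"log_dir": Path("/tmp/run")}))  # {'log_dir': '/tmp/run'}
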
8 changes: 6 additions & 2 deletions autogen/oai/anthropic.py
@@ -53,14 +53,15 @@
import os
import time
import warnings
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from anthropic import Anthropic, AnthropicBedrock
from anthropic import __version__ as anthropic_version
from anthropic.types import Completion, Message, TextBlock, ToolUseBlock
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel
from typing_extensions import Annotated

from autogen.oai.client_utils import validate_parameter
@@ -174,7 +175,10 @@ def aws_session_token(self):
def aws_region(self):
return self._aws_region

def create(self, params: Dict[str, Any]) -> Completion:
def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Anthropic API.")

if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
params["functions"] = params.get("functions", []) + converted_functions
8 changes: 6 additions & 2 deletions autogen/oai/bedrock.py
@@ -36,14 +36,15 @@
import re
import time
import warnings
from typing import Any, Dict, List, Literal, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

import boto3
import requests
from botocore.config import Config
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel

from autogen.oai.client_utils import validate_parameter

@@ -178,9 +179,12 @@ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str

return base_params, additional_params

def create(self, params):
def create(self, params, response_format: Optional[BaseModel] = None) -> ChatCompletion:
"""Run Amazon Bedrock inference and return AutoGen response"""

if response_format is not None:
raise NotImplementedError("Response format is not supported by Amazon Bedrock's API.")

# Set custom client class settings
self.parse_custom_params(params)

7 changes: 5 additions & 2 deletions autogen/oai/cerebras.py
@@ -29,12 +29,13 @@
import os
import time
import warnings
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from cerebras.cloud.sdk import Cerebras, Stream
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel

from autogen.oai.client_utils import should_hide_tools, validate_parameter

@@ -111,7 +112,9 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cerebras_params

def create(self, params: Dict) -> ChatCompletion:
def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Cerebras' API.")

messages = params.get("messages", [])

44 changes: 34 additions & 10 deletions autogen/oai/client.py
@@ -10,9 +10,10 @@
import logging
import sys
import uuid
from pickle import PickleError, PicklingError
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union

from pydantic import BaseModel
from pydantic import BaseModel, schema_json_of

from autogen.cache import Cache
from autogen.io.base import IOStream
@@ -206,7 +207,9 @@ class Message(Protocol):
choices: List[Choice]
model: str

def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover
def create(
self, params: Dict[str, Any], response_format: Optional[BaseModel] = None
) -> ModelClientResponseProtocol: ... # pragma: no cover

def message_retrieval(
self, response: ModelClientResponseProtocol
@@ -269,7 +272,7 @@ def message_retrieval(
for choice in choices
]

def create(self, params: Dict[str, Any]) -> ChatCompletion:
def create(self, params: Dict[str, Any], response_format: Optional[BaseModel] = None) -> ChatCompletion:
"""Create a completion for a given config using openai's client.
Args:
@@ -281,7 +284,19 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
"""
iostream = IOStream.get_default()

completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
if response_format is not None:

def _create_or_parse(*args, **kwargs):
if "stream" in kwargs:
kwargs.pop("stream")
kwargs["response_format"] = response_format
return self._oai_client.beta.chat.completions.parse(*args, **kwargs)

create_or_parse = _create_or_parse
else:
completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
create_or_parse = completions.create

# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
response_contents = [""] * params.get("n", 1)
@@ -296,7 +311,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None

# Send the chat completion request to OpenAI's API and process the response in chunks
for chunk in completions.create(**params):
for chunk in create_or_parse(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
@@ -398,7 +413,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
# If streaming is not enabled, send a regular chat completion request
params = params.copy()
params["stream"] = False
response = completions.create(**params)
response = create_or_parse(**params)

return response

@@ -700,7 +715,9 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
]
return params

def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
def create(
self, response_format: Optional[BaseModel] = None, **config: Any
) -> ModelClient.ModelClientResponseProtocol:
"""Make a completion for a given config using available clients.
Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
The config in each client will be overridden by the config.
@@ -788,7 +805,11 @@ def yes_or_no_filter(context, response):
if cache_client is not None:
with cache_client as cache:
# Try to get the response from cache
key = get_key(params)
key = get_key(
{**params, **{"response_format": schema_json_of(response_format)}}
if response_format
else params
)
request_ts = get_current_ts()

response: ModelClient.ModelClientResponseProtocol = cache.get(key, None)
@@ -829,7 +850,7 @@ def yes_or_no_filter(context, response):
continue # filter is not passed; try the next config
try:
request_ts = get_current_ts()
response = client.create(params)
response = client.create(params, response_format=response_format)
except APITimeoutError as err:
logger.debug(f"config {i} timed out", exc_info=True)
if i == last:
@@ -894,7 +915,10 @@ def yes_or_no_filter(context, response):
if cache_client is not None:
# Cache the response
with cache_client as cache:
cache.set(key, response)
try:
cache.set(key, response)
except (PicklingError, AttributeError) as e:
logger.info(f"Failed to cache response: {e}")

if logging_enabled():
# TODO: log the config_id and pass_filter etc.
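
Putting the client.py pieces together, a sketch of a wrapper-level call (the config_list contents are assumptions): response_format is threaded into OpenAIClient.create, OpenAI requests are routed through beta.chat.completions.parse, and the Pydantic schema is folded into the cache key via schema_json_of.

from pydantic import BaseModel

from autogen import OpenAIWrapper

class Verdict(BaseModel):
    decision: str
    rationale: str

client = OpenAIWrapper(config_list=[{"model": "gpt-4o", "api_key": "sk-..."}])
response = client.create(
    messages=[{"role": "user", "content": "Approve or reject this change request."}],
    response_format=Verdict,
)
print(client.extract_text_or_completion_object(response)[0])
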
7 changes: 5 additions & 2 deletions autogen/oai/cohere.py
@@ -33,13 +33,14 @@
import sys
import time
import warnings
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from cohere import Client as Cohere
from cohere.types import ToolParameterDefinitionsValue, ToolResult
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel

from autogen.oai.client_utils import logging_formatter, validate_parameter

@@ -147,7 +148,9 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:

return cohere_params

def create(self, params: Dict) -> ChatCompletion:
def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Cohere's API.")

messages = params.get("messages", [])
client_name = params.get("client_name") or "autogen-cohere"
8 changes: 6 additions & 2 deletions autogen/oai/gemini.py
@@ -45,7 +45,7 @@
import time
import warnings
from io import BytesIO
from typing import Any, Dict, List, Mapping, Union
from typing import Any, Dict, List, Mapping, Optional, Union

import google.generativeai as genai
import requests
@@ -56,6 +56,7 @@
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from pydantic import BaseModel
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
@@ -159,7 +160,10 @@ def get_usage(response) -> Dict:
"model": response.model,
}

def create(self, params: Dict) -> ChatCompletion:
def create(self, params: Dict, response_format: Optional[BaseModel] = None) -> ChatCompletion:
if response_format is not None:
raise NotImplementedError("Response format is not supported by Gemini's API.")

if self.use_vertexai:
self._initialize_vertexai(**params)
else:
Expand Down
