diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml
index 7d8a932b0254..f13bfdbb985c 100644
--- a/.github/workflows/contrib-tests.yml
+++ b/.github/workflows/contrib-tests.yml
@@ -9,6 +9,8 @@ on:
paths:
- "autogen/**"
- "test/agentchat/contrib/**"
+ - "test/test_browser_utils.py"
+ - "test/test_retrieve_utils.py"
- ".github/workflows/contrib-tests.yml"
- "setup.py"
@@ -598,3 +600,79 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
+
+ GroqTest:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-2019]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ exclude:
+ - os: macos-latest
+ python-version: "3.9"
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install "pytest-cov>=5"
+ - name: Install packages and dependencies for Groq
+ run: |
+ pip install -e .[groq,test]
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+ fi
+ - name: Coverage
+ run: |
+ pytest test/oai/test_groq.py --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+
+ CohereTest:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install "pytest-cov>=5"
+ - name: Install packages and dependencies for Cohere
+ run: |
+ pip install -e .[cohere,test]
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+ fi
+ - name: Coverage
+ run: |
+ pytest test/oai/test_cohere.py --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
diff --git a/.github/workflows/dotnet-build.yml b/.github/workflows/dotnet-build.yml
index f4074b061693..7e50025917de 100644
--- a/.github/workflows/dotnet-build.yml
+++ b/.github/workflows/dotnet-build.yml
@@ -56,11 +56,16 @@ jobs:
- name: Setup .NET
uses: actions/setup-dotnet@v4
with:
- global-json-file: dotnet/global.json
+ dotnet-version: '8.0.x'
- name: Restore dependencies
run: |
# dotnet nuget add source --name dotnet-tool https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-tools/nuget/v3/index.json --configfile NuGet.config
dotnet restore -bl
+ - name: Format check
+ run: |
+ echo "Format check"
+ echo "If you see any error in this step, please run 'dotnet format' locally to format the code."
+ dotnet format --verify-no-changes -v diag --no-restore
- name: Build
run: |
echo "Build AutoGen"
diff --git a/.github/workflows/dotnet-release.yml b/.github/workflows/dotnet-release.yml
index 2877d058377b..aacfd115bb7e 100644
--- a/.github/workflows/dotnet-release.yml
+++ b/.github/workflows/dotnet-release.yml
@@ -32,7 +32,7 @@ jobs:
- name: Setup .NET
uses: actions/setup-dotnet@v4
with:
- global-json-file: dotnet/global.json
+ dotnet-version: '8.0.x'
- name: Restore dependencies
run: |
dotnet restore -bl
diff --git a/README.md b/README.md
index 5bff3300a50e..7c7ac4b85c59 100644
--- a/README.md
+++ b/README.md
@@ -66,7 +66,12 @@
## What is AutoGen
-AutoGen is a framework that enables the development of LLM applications using multiple agents that can converse with each other to solve tasks. AutoGen agents are customizable, conversable, and seamlessly allow human participation. They can operate in various modes that employ combinations of LLMs, human inputs, and tools.
+AutoGen is an open-source programming framework for building AI agents and facilitating cooperation among multiple agents to solve tasks. AutoGen aims to streamline the development and research of agentic AI, much like PyTorch does for Deep Learning. It offers features such as agents that can interact with each other, support for various large language models (LLMs) and tool use, autonomous and human-in-the-loop workflows, and multi-agent conversation patterns.
+
+**Open Source Statement**: The project welcomes contributions from developers and organizations worldwide. Our goal is to foster a collaborative and inclusive community where diverse perspectives and expertise can drive innovation and enhance the project's capabilities. Whether you are an individual contributor or represent an organization, we invite you to join us in shaping the future of this project. Together, we can build something truly remarkable.
+
+The project is currently maintained by a [dynamic group of volunteers](https://butternut-swordtail-8a5.notion.site/410675be605442d3ada9a42eb4dfef30?v=fa5d0a79fd3d4c0f9c112951b2831cbb&pvs=4) from several different organizations. Contact project administrators Chi Wang and Qingyun Wu via auto-gen@outlook.com if you are interested in becoming a maintainer.
+
![AutoGen Overview](https://github.com/microsoft/autogen/blob/main/website/static/img/autogen_agentchat.png)
@@ -288,6 +293,16 @@ In addition, you can find:
}
```
+[StateFlow](https://arxiv.org/abs/2403.11322)
+```
+@article{wu2024stateflow,
+ title={StateFlow: Enhancing LLM Task-Solving through State-Driven Workflows},
+ author={Wu, Yiran and Yue, Tianwei and Zhang, Shaokun and Wang, Chi and Wu, Qingyun},
+ journal={arXiv preprint arXiv:2403.11322},
+ year={2024}
+}
+```
+
↑ Back to Top ↑
diff --git a/autogen/agentchat/contrib/agent_eval/README.md b/autogen/agentchat/contrib/agent_eval/README.md
index 6588a1ec6113..478f28fd74ec 100644
--- a/autogen/agentchat/contrib/agent_eval/README.md
+++ b/autogen/agentchat/contrib/agent_eval/README.md
@@ -1,7 +1,9 @@
-Agents for running the AgentEval pipeline.
+Agents for running the [AgentEval](https://microsoft.github.io/autogen/blog/2023/11/20/AgentEval/) pipeline.
AgentEval is a process for evaluating an LLM-based system's performance on a given task.
When given a task to evaluate and a few example runs, the critic and subcritic agents create evaluation criteria for evaluating a system's solution. Once the criteria have been created, the quantifier agent can evaluate subsequent task solutions based on the generated criteria.
For more information see: [AgentEval Integration Roadmap](https://github.com/microsoft/autogen/issues/2162)
+
+See our [blog post](https://microsoft.github.io/autogen/blog/2024/06/21/AgentEval) for usage examples and general explanations.
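For orientation, here is a minimal sketch of driving the two phases. The `Task` model and the `generate_criteria`/`quantify_criteria` helpers (and their parameter names) are assumptions about the agent_eval module layout; consult the blog post for the exact API.

```python
# Hedged sketch of an AgentEval run; helper names and parameters are assumptions.
from autogen.agentchat.contrib.agent_eval.agent_eval import generate_criteria, quantify_criteria
from autogen.agentchat.contrib.agent_eval.task import Task

llm_config = {"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]}

task = Task(
    name="math_problem_solving",
    description="Solve grade-school math word problems.",
    successful_response="<an example successful run>",
    failed_response="<an example failed run>",
)

# Phase 1: the critic (and optionally subcritic) agents propose evaluation criteria.
criteria = generate_criteria(llm_config=llm_config, task=task, use_subcritic=True)

# Phase 2: the quantifier agent scores a new solution against those criteria.
result = quantify_criteria(
    llm_config=llm_config,
    criteria=criteria,
    task=task,
    test_case="<a task solution to evaluate>",
    ground_truth="<optional reference answer>",
)
print(result)
```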
diff --git a/autogen/agentchat/contrib/llamaindex_conversable_agent.py b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
index f7a9c3e615dc..dbf6f274ae87 100644
--- a/autogen/agentchat/contrib/llamaindex_conversable_agent.py
+++ b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
@@ -8,15 +8,14 @@
try:
from llama_index.core.agent.runner.base import AgentRunner
+ from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.chat_engine.types import AgentChatResponse
- from llama_index_client import ChatMessage
except ImportError as e:
logger.fatal("Failed to import llama-index. Try running 'pip install llama-index'")
raise e
class LLamaIndexConversableAgent(ConversableAgent):
-
def __init__(
self,
name: str,
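The import fix above tracks llama-index's reorganized package layout: `ChatMessage` now comes from `llama_index.core` rather than the deprecated `llama_index_client` package. For context, a minimal wiring sketch; the `llama_index_agent` parameter name follows the class constructor, while the ReAct agent setup itself is illustrative:

```python
# Illustrative use of LLamaIndexConversableAgent; model and key are placeholders.
from llama_index.core.agent import ReActAgent
from llama_index.llms.openai import OpenAI

from autogen.agentchat.contrib.llamaindex_conversable_agent import LLamaIndexConversableAgent

llm = OpenAI(model="gpt-4", api_key="sk-...")
react_agent = ReActAgent.from_tools(tools=[], llm=llm, max_iterations=10)

assistant = LLamaIndexConversableAgent(
    "llamaindex_assistant",
    llama_index_agent=react_agent,
    description="Answers questions via a llama-index ReAct agent.",
)
```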
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index 59a4abccb1d6..4842bd4e9f53 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -1,6 +1,7 @@
import hashlib
import os
import re
+import uuid
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from IPython import get_ipython
@@ -135,7 +136,7 @@ def __init__(
- `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
default client `chromadb.Client()` will be used. If you want to use other
vector db, extend this class and override the `retrieve_docs` function.
- **Deprecated**: use `vector_db` instead.
+ *[Deprecated]* use `vector_db` instead.
- `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It
can also be the path to a single file, the url to a single file or a list
of directories, files and urls. Default is None, which works only if the
@@ -149,7 +150,7 @@ def __init__(
By default, "extra_docs" is set to false, starting document IDs from zero.
This poses a risk as new documents might overwrite existing ones, potentially
causing unintended loss or alteration of data in the collection.
- **Deprecated**: use `new_docs` when use `vector_db` instead of `client`.
+ *[Deprecated]* use `new_docs` when using `vector_db` instead of `client`.
- `new_docs` (Optional, bool) - when True, only adds new documents to the collection;
when False, updates existing documents and adds new ones. Default is True.
Document id is used to determine if a document is new or existing. By default, the
@@ -172,7 +173,7 @@ def __init__(
models can be found at `https://www.sbert.net/docs/pretrained_models.html`.
The default model is a fast model. If you want to use a high performance model,
`all-mpnet-base-v2` is recommended.
- **Deprecated**: no need when use `vector_db` instead of `client`.
+ *[Deprecated]* no need when using `vector_db` instead of `client`.
- `embedding_function` (Optional, Callable) - the embedding function for creating the
vector db. Default is None, SentenceTransformer with the given `embedding_model`
will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
@@ -219,7 +220,7 @@ def __init__(
Example of overriding retrieve_docs - If you have set up a customized vector db, and it's
not compatible with chromadb, you can easily plug it in with the code below.
- **Deprecated**: Use `vector_db` instead. You can extend VectorDB and pass it to the agent.
+ *[Deprecated]* use `vector_db` instead. You can extend VectorDB and pass it to the agent.
```python
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def query_vector_db(
@@ -365,7 +366,11 @@ def _init_db(self):
else:
all_docs_ids = set()
- chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
+ chunk_ids = (
+ [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
+ if self._vector_db.type != "qdrant"
+ else [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]
+ )
chunk_ids_set = set(chunk_ids)
chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
docs = [
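Qdrant only accepts unsigned integers or UUIDs as point IDs, so the hunk above derives chunk IDs for the qdrant backend from an MD5 digest reinterpreted as a UUID, rather than the truncated blake2b hex string used for the other backends. A standalone illustration (the `HASH_LENGTH` value here is assumed; use the module's constant):

```python
import hashlib
import uuid

chunk = "some document chunk"
HASH_LENGTH = 64  # illustrative; the module defines the real constant

# Default backends: truncated blake2b hex digest.
default_id = hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH]

# Qdrant backend: a deterministic UUID built from the 128-bit MD5 digest.
qdrant_id = str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest()))

print(default_id)
print(qdrant_id)
```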
diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py
index 29a080086193..20b6376d01d9 100644
--- a/autogen/agentchat/contrib/vectordb/base.py
+++ b/autogen/agentchat/contrib/vectordb/base.py
@@ -1,4 +1,16 @@
-from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable
+from typing import (
+ Any,
+ Callable,
+ List,
+ Mapping,
+ Optional,
+ Protocol,
+ Sequence,
+ Tuple,
+ TypedDict,
+ Union,
+ runtime_checkable,
+)
Metadata = Union[Mapping[str, Any], None]
Vector = Union[Sequence[float], Sequence[int]]
@@ -49,6 +61,9 @@ class VectorDB(Protocol):
active_collection: Any = None
type: str = ""
+ embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
+ None # embeddings = embedding_function(sentences)
+ )
def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
"""
@@ -185,7 +200,7 @@ class VectorDBFactory:
Factory class for creating vector databases.
"""
- PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
+ PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "qdrant"]
@staticmethod
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@@ -207,6 +222,10 @@ def create_vector_db(db_type: str, **kwargs) -> VectorDB:
from .pgvectordb import PGVectorDB
return PGVectorDB(**kwargs)
+ if db_type.lower() in ["qdrant", "qdrantdb"]:
+ from .qdrant import QdrantVectorDB
+
+ return QdrantVectorDB(**kwargs)
else:
raise ValueError(
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
diff --git a/autogen/agentchat/contrib/vectordb/qdrant.py b/autogen/agentchat/contrib/vectordb/qdrant.py
new file mode 100644
index 000000000000..398734eb0334
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/qdrant.py
@@ -0,0 +1,328 @@
+import abc
+import logging
+import os
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import get_logger
+
+try:
+ from qdrant_client import QdrantClient, models
+except ImportError:
+ raise ImportError("Please install qdrant-client: `pip install qdrant-client`")
+
+logger = get_logger(__name__)
+
+Embeddings = Union[Sequence[float], Sequence[int]]
+
+
+class EmbeddingFunction(abc.ABC):
+ @abc.abstractmethod
+ def __call__(self, inputs: List[str]) -> List[Embeddings]:
+ raise NotImplementedError
+
+
+class FastEmbedEmbeddingFunction(EmbeddingFunction):
+ """Embedding function implementation using FastEmbed - https://qdrant.github.io/fastembed."""
+
+ def __init__(
+ self,
+ model_name: str = "BAAI/bge-small-en-v1.5",
+ batch_size: int = 256,
+ cache_dir: Optional[str] = None,
+ threads: Optional[int] = None,
+ parallel: Optional[int] = None,
+ **kwargs,
+ ):
+ """Initialize fastembed.TextEmbedding.
+
+ Args:
+ model_name (str): The name of the model to use. Defaults to `"BAAI/bge-small-en-v1.5"`.
+ batch_size (int): Batch size for encoding. Higher values will use more memory, but be faster.\
+ Defaults to 256.
+ cache_dir (str, optional): The path to the model cache directory.\
+ Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
+ threads (int, optional): The number of threads single onnxruntime session can use.
+ parallel (int, optional): If `>1`, data-parallel encoding will be used, recommended for large datasets.\
+ If `0`, use all available cores.\
+ If `None`, don't use data-parallel processing, use default onnxruntime threading.\
+ Defaults to None.
+ **kwargs: Additional options to pass to fastembed.TextEmbedding
+ Raises:
+ ValueError: If the model_name is not in the format <org>/<model>, e.g. BAAI/bge-small-en-v1.5.
+ """
+ try:
+ from fastembed import TextEmbedding
+ except ImportError as e:
+ raise ValueError(
+ "The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
+ ) from e
+ self._batch_size = batch_size
+ self._parallel = parallel
+ self._model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs)
+
+ def __call__(self, inputs: List[str]) -> List[Embeddings]:
+ embeddings = self._model.embed(inputs, batch_size=self._batch_size, parallel=self._parallel)
+
+ return [embedding.tolist() for embedding in embeddings]
+
+
+class QdrantVectorDB(VectorDB):
+ """
+ A vector database implementation that uses Qdrant as the backend.
+ """
+
+ def __init__(
+ self,
+ *,
+ client=None,
+ embedding_function: EmbeddingFunction = None,
+ content_payload_key: str = "_content",
+ metadata_payload_key: str = "_metadata",
+ collection_options: dict = {},
+ **kwargs,
+ ) -> None:
+ """
+ Initialize the vector database.
+
+ Args:
+ client: qdrant_client.QdrantClient | An instance of QdrantClient.
+ embedding_function: Callable | The embedding function used to generate the vector representation
+ of the documents. Defaults to FastEmbedEmbeddingFunction.
+ collection_options: dict | The options for creating the collection.
+ kwargs: dict | Additional keyword arguments.
+ """
+ self.client: QdrantClient = client or QdrantClient(location=":memory:")
+ self.embedding_function = embedding_function or FastEmbedEmbeddingFunction()
+ self.collection_options = collection_options
+ self.content_payload_key = content_payload_key
+ self.metadata_payload_key = metadata_payload_key
+ self.type = "qdrant"
+
+ def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> None:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raises a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Any | The collection object.
+ """
+ embeddings_size = len(self.embedding_function(["test"])[0])
+
+ if self.client.collection_exists(collection_name) and overwrite:
+ self.client.delete_collection(collection_name)
+
+ if not self.client.collection_exists(collection_name):
+ self.client.create_collection(
+ collection_name,
+ vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
+ **self.collection_options,
+ )
+ elif not get_or_create:
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def get_collection(self, collection_name: str = None):
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ Any | The collection object.
+ """
+ if collection_name is None:
+ raise ValueError("The collection name is required.")
+
+ return self.client.get_collection(collection_name)
+
+ def delete_collection(self, collection_name: str) -> None:
+ """Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ Any
+ """
+ return self.client.delete_collection(collection_name)
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ if not docs:
+ return
+ if any(doc.get("content") is None for doc in docs):
+ raise ValueError("The document content is required.")
+ if any(doc.get("id") is None for doc in docs):
+ raise ValueError("The document id is required.")
+
+ if not upsert and not self._validate_upsert_ids(collection_name, [doc["id"] for doc in docs]):
+ logger.log("Some IDs already exist. Skipping insert", level=logging.WARN)
+
+ self.client.upsert(collection_name, points=self._documents_to_points(docs))
+
+ def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+ if not docs:
+ return
+ if any(doc.get("id") is None for doc in docs):
+ raise ValueError("The document id is required.")
+ if any(doc.get("content") is None for doc in docs):
+ raise ValueError("The document content is required.")
+ if self._validate_update_ids(collection_name, [doc["id"] for doc in docs]):
+ return self.client.upsert(collection_name, points=self._documents_to_points(docs))
+
+ raise ValueError("Some IDs do not exist. Skipping update")
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ self.client.delete(collection_name, ids)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = 0,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+ distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
+ returned. Don't filter with it if < 0. Default is 0.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ QueryResults | The query results. Each query result is a list of list of tuples containing the document and
+ the distance.
+ """
+ embeddings = self.embedding_function(queries)
+ requests = [
+ models.SearchRequest(
+ vector=embedding,
+ limit=n_results,
+ score_threshold=distance_threshold,
+ with_payload=True,
+ with_vector=False,
+ )
+ for embedding in embeddings
+ ]
+
+ batch_results = self.client.search_batch(collection_name, requests)
+ return [self._scored_points_to_documents(results) for results in batch_results]
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: Union[bool, List[str]] | The payload fields to include, or True for the full payload.
+ Ids are always included. Default is True.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ if ids is None:
+ results = self.client.scroll(collection_name=collection_name, with_payload=include, with_vectors=True)[0]
+ else:
+ results = self.client.retrieve(collection_name, ids=ids, with_payload=include, with_vectors=True)
+ return [self._point_to_document(result) for result in results]
+
+ def _point_to_document(self, point) -> Document:
+ return {
+ "id": point.id,
+ "content": point.payload.get(self.content_payload_key, ""),
+ "metadata": point.payload.get(self.metadata_payload_key, {}),
+ "embedding": point.vector,
+ }
+
+ def _points_to_documents(self, points) -> List[Document]:
+ return [self._point_to_document(point) for point in points]
+
+ def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> Tuple[Document, float]:
+ return self._point_to_document(scored_point), scored_point.score
+
+ def _documents_to_points(self, documents: List[Document]):
+ contents = [document["content"] for document in documents]
+ embeddings = self.embedding_function(contents)
+ points = [
+ models.PointStruct(
+ id=documents[i]["id"],
+ vector=embeddings[i],
+ payload={
+ self.content_payload_key: documents[i].get("content"),
+ self.metadata_payload_key: documents[i].get("metadata"),
+ },
+ )
+ for i in range(len(documents))
+ ]
+ return points
+
+ def _scored_points_to_documents(self, scored_points: List[models.ScoredPoint]) -> List[Tuple[Document, float]]:
+ return [self._scored_point_to_document(scored_point) for scored_point in scored_points]
+
+ def _validate_update_ids(self, collection_name: str, ids: List[str]) -> bool:
+ """
+ Validates all the IDs exist in the collection
+ """
+ retrieved_ids = [
+ point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
+ ]
+
+ if missing_ids := set(ids) - set(retrieved_ids):
+ logger.log(f"Missing IDs: {missing_ids}. Skipping update", level=logging.WARN)
+ return False
+
+ return True
+
+ def _validate_upsert_ids(self, collection_name: str, ids: List[str]) -> bool:
+ """
+ Validate none of the IDs exist in the collection
+ """
+ retrieved_ids = [
+ point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
+ ]
+
+ if existing_ids := set(ids) & set(retrieved_ids):
+ logger.log(f"Existing IDs: {existing_ids}.", level=logging.WARN)
+ return False
+
+ return True
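An end-to-end sketch of the new backend using only the methods defined above (document IDs follow Qdrant's unsigned-integer/UUID requirement; requires `qdrant-client` and `fastembed`):

```python
from autogen.agentchat.contrib.vectordb.qdrant import QdrantVectorDB

db = QdrantVectorDB()  # defaults: in-memory QdrantClient + FastEmbed embeddings
db.create_collection("demo", overwrite=True)

db.insert_docs(
    [
        {"id": 1, "content": "AutoGen supports multi-agent conversation."},
        {"id": 2, "content": "Qdrant stores vectors alongside payloads."},
    ],
    collection_name="demo",
)

results = db.retrieve_docs(["What does Qdrant store?"], collection_name="demo", n_results=1)
for doc, score in results[0]:
    print(doc["id"], doc["content"], score)
```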
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index b434fc648eb1..81c666de022c 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -2526,14 +2526,16 @@ def _wrap_function(self, func: F) -> F:
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
retval = func(*args, **kwargs)
- log_function_use(self, func, kwargs, retval)
+ if logging_enabled():
+ log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
@load_basemodels_if_needed
@functools.wraps(func)
async def _a_wrapped_func(*args, **kwargs):
retval = await func(*args, **kwargs)
- log_function_use(self, func, kwargs, retval)
+ if logging_enabled():
+ log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func
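With this change, the tool-use log call is skipped entirely unless runtime logging is active. From the caller's side the toggle looks like the following sketch (assuming the `autogen.runtime_logging` start/stop API with its sqlite default):

```python
# Hedged sketch: enable runtime logging so wrapped tool calls are recorded.
import autogen.runtime_logging

session_id = autogen.runtime_logging.start(config={"dbname": "logs.db"})

# ... run agents that invoke registered tools; log_function_use now fires ...

autogen.runtime_logging.stop()  # after this, wrapped tools skip logging again
```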
diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
index af5583587f66..61a8a6335284 100644
--- a/autogen/logger/file_logger.py
+++ b/autogen/logger/file_logger.py
@@ -18,7 +18,9 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@@ -204,7 +206,16 @@ def log_new_wrapper(
def log_new_client(
self,
- client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient,
+ client: (
+ AzureOpenAI
+ | OpenAI
+ | GeminiClient
+ | AnthropicClient
+ | MistralAIClient
+ | TogetherClient
+ | GroqClient
+ | CohereClient
+ ),
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index 969a943017e3..2cf176ebb8f2 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -19,7 +19,9 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@@ -391,7 +393,16 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st
def log_new_client(
self,
- client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
+ client: Union[
+ AzureOpenAI,
+ OpenAI,
+ GeminiClient,
+ AnthropicClient,
+ MistralAIClient,
+ TogetherClient,
+ GroqClient,
+ CohereClient,
+ ],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py
index 9faa4e2cb808..62078d42631d 100644
--- a/autogen/oai/anthropic.py
+++ b/autogen/oai/anthropic.py
@@ -16,6 +16,27 @@
]
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+
+Example usage for Anthropic Bedrock:
+
+Install the `anthropic` package by running `pip install --upgrade anthropic`.
+- https://docs.anthropic.com/en/docs/quickstart-guide
+
+import autogen
+
+config_list = [
+ {
+ "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ "aws_access_key":,
+ "aws_secret_key":,
+ "aws_session_token":,
+ "aws_region":"us-east-1",
+ "api_type": "anthropic",
+ }
+]
+
+assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+
"""
from __future__ import annotations
@@ -28,7 +49,7 @@
import warnings
from typing import Any, Dict, List, Tuple, Union
-from anthropic import Anthropic
+from anthropic import Anthropic, AnthropicBedrock
from anthropic import __version__ as anthropic_version
from anthropic.types import Completion, Message, TextBlock, ToolUseBlock
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
@@ -49,10 +70,10 @@
"claude-3-5-sonnet-20240620": (0.003, 0.015),
"claude-3-sonnet-20240229": (0.003, 0.015),
"claude-3-opus-20240229": (0.015, 0.075),
- "claude-2.0": (0.008, 0.024),
+ "claude-3-haiku-20240307": (0.00025, 0.00125),
"claude-2.1": (0.008, 0.024),
- "claude-3.0-opus": (0.015, 0.075),
- "claude-3.0-haiku": (0.00025, 0.00125),
+ "claude-2.0": (0.008, 0.024),
+ "claude-instant-1.2": (0.008, 0.024),
}
@@ -64,14 +85,44 @@ def __init__(self, **kwargs: Any):
api_key (str): The API key for the Anthropic API or set the `ANTHROPIC_API_KEY` environment variable.
"""
self._api_key = kwargs.get("api_key", None)
+ self._aws_access_key = kwargs.get("aws_access_key", None)
+ self._aws_secret_key = kwargs.get("aws_secret_key", None)
+ self._aws_session_token = kwargs.get("aws_session_token", None)
+ self._aws_region = kwargs.get("aws_region", None)
if not self._api_key:
self._api_key = os.getenv("ANTHROPIC_API_KEY")
- if self._api_key is None:
- raise ValueError("API key is required to use the Anthropic API.")
+ if not self._aws_access_key:
+ self._aws_access_key = os.getenv("AWS_ACCESS_KEY")
+
+ if not self._aws_secret_key:
+ self._aws_secret_key = os.getenv("AWS_SECRET_KEY")
+
+ if not self._aws_session_token:
+ self._aws_session_token = os.getenv("AWS_SESSION_TOKEN")
+
+ if not self._aws_region:
+ self._aws_region = os.getenv("AWS_REGION")
+
+ if self._api_key is None and (
+ self._aws_access_key is None
+ or self._aws_secret_key is None
+ or self._aws_session_token is None
+ or self._aws_region is None
+ ):
+ raise ValueError("API key or AWS credentials are required to use the Anthropic API.")
+
+ if self._api_key is not None:
+ self._client = Anthropic(api_key=self._api_key)
+ else:
+ self._client = AnthropicBedrock(
+ aws_access_key=self._aws_access_key,
+ aws_secret_key=self._aws_secret_key,
+ aws_session_token=self._aws_session_token,
+ aws_region=self._aws_region,
+ )
- self._client = Anthropic(api_key=self._api_key)
self._last_tooluse_status = {}
def load_config(self, params: Dict[str, Any]):
@@ -107,6 +158,22 @@ def cost(self, response) -> float:
def api_key(self):
return self._api_key
+ @property
+ def aws_access_key(self):
+ return self._aws_access_key
+
+ @property
+ def aws_secret_key(self):
+ return self._aws_secret_key
+
+ @property
+ def aws_session_token(self):
+ return self._aws_session_token
+
+ @property
+ def aws_region(self):
+ return self._aws_region
+
def create(self, params: Dict[str, Any]) -> Completion:
if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
@@ -250,6 +317,7 @@ def oai_messages_to_anthropic_messages(params: Dict[str, Any]) -> list[dict[str,
tool_use_messages = 0
tool_result_messages = 0
last_tool_use_index = -1
+ last_tool_result_index = -1
for message in params["messages"]:
if message["role"] == "system":
params["system"] = message["content"]
@@ -290,25 +358,26 @@ def oai_messages_to_anthropic_messages(params: Dict[str, Any]) -> list[dict[str,
}
)
elif "tool_call_id" in message:
-
- if expected_role == "assistant":
- # Insert an extra assistant message as we will append a user message
- processed_messages.append(assistant_continue_message)
-
if has_tools:
# Map the tool usage call to tool_result for Anthropic
- processed_messages.append(
- {
- "role": "user",
- "content": [
- {
- "type": "tool_result",
- "tool_use_id": message["tool_call_id"],
- "content": message["content"],
- }
- ],
- }
- )
+ tool_result = {
+ "type": "tool_result",
+ "tool_use_id": message["tool_call_id"],
+ "content": message["content"],
+ }
+
+ # If the previous message also had a tool_result, add it to that
+ # Otherwise append a new message
+ if last_tool_result_index == len(processed_messages) - 1:
+ processed_messages[-1]["content"].append(tool_result)
+ else:
+ if expected_role == "assistant":
+ # Insert an extra assistant message as we will append a user message
+ processed_messages.append(assistant_continue_message)
+
+ processed_messages.append({"role": "user", "content": [tool_result]})
+ last_tool_result_index = len(processed_messages) - 1
+
tool_result_messages += 1
else:
# Not using tools, so put in a plain text message
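The rework above batches consecutive `tool_result` blocks into a single user message instead of emitting one user message per result, matching what Anthropic's Messages API expects when several tool calls are answered in a row. Illustratively, two adjacent results now collapse to one message of this shape (ids and content are placeholders):

```python
# Shape of the merged tool-result message produced by the code above.
merged = {
    "role": "user",
    "content": [
        {"type": "tool_result", "tool_use_id": "toolu_01", "content": "result one"},
        {"type": "tool_result", "tool_use_id": "toolu_02", "content": "result two"},
    ],
}
```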
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 2c14ca0d4a0c..4e9d794a1f75 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -70,6 +70,20 @@
except ImportError as e:
together_import_exception = e
+try:
+ from autogen.oai.groq import GroqClient
+
+ groq_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ groq_import_exception = e
+
+try:
+ from autogen.oai.cohere import CohereClient
+
+ cohere_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ cohere_import_exception = e
+
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -483,7 +497,18 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
elif api_type is not None and api_type.startswith("together"):
if together_import_exception:
raise ImportError("Please install `together` to use the Together.AI API.")
- self._clients.append(TogetherClient(**config))
+ client = TogetherClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("groq"):
+ if groq_import_exception:
+ raise ImportError("Please install `groq` to use the Groq API.")
+ client = GroqClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("cohere"):
+ if cohere_import_exception:
+ raise ImportError("Please install `cohere` to use the Groq API.")
+ client = CohereClient(**openai_config)
+ self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
@@ -770,7 +795,7 @@ def _cost_with_customized_price(
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
if n_output_tokens is None:
n_output_tokens = 0
- return n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]
+ return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000
@staticmethod
def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
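The `/1000` fix rescales the per-1K-token prices to per-token before summing; a quick worked check with illustrative prices:

```python
# price_1k holds (input, output) cost per 1,000 tokens.
price_1k = (0.003, 0.015)
n_input_tokens, n_output_tokens = 1000, 1000

cost = (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000
print(cost)  # 0.018 -- without the division the result would be overstated 1000x
```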
diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py
new file mode 100644
index 000000000000..e04d07327203
--- /dev/null
+++ b/autogen/oai/cohere.py
@@ -0,0 +1,459 @@
+"""Create an OpenAI-compatible client using Cohere's API.
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "cohere",
+ "model": "command-r-plus",
+ "api_key": os.environ.get("COHERE_API_KEY")
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Cohere's python library using: pip install --upgrade cohere
+
+Resources:
+- https://docs.cohere.com/reference/chat
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+import random
+import sys
+import time
+import warnings
+from typing import Any, Dict, List
+
+from cohere import Client as Cohere
+from cohere.types import ToolParameterDefinitionsValue, ToolResult
+from flaml.automl.logger import logger_formatter
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import validate_parameter
+
+logger = logging.getLogger(__name__)
+if not logger.handlers:
+ # Add the console handler.
+ _ch = logging.StreamHandler(stream=sys.stdout)
+ _ch.setFormatter(logger_formatter)
+ logger.addHandler(_ch)
+
+
+COHERE_PRICING_1K = {
+ "command-r-plus": (0.003, 0.015),
+ "command-r": (0.0005, 0.0015),
+ "command-nightly": (0.00025, 0.00125),
+ "command": (0.015, 0.075),
+ "command-light": (0.008, 0.024),
+ "command-light-nightly": (0.008, 0.024),
+}
+
+
+class CohereClient:
+ """Client for Cohere's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Cohere (or environment variable COHERE_API_KEY needs to be set)
+ """
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("COHERE_API_KEY")
+
+ assert (
+ self.api_key
+ ), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ # ... # pragma: no cover
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Cohere API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ cohere_params = {}
+
+ # Check that we have what we need to use Cohere's API
+ # We won't enforce the available models as they are likely to change
+ cohere_params["model"] = params.get("model", None)
+ assert cohere_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Cohere model to use."
+
+ # Validate allowed Cohere parameters
+ # https://docs.cohere.com/reference/chat
+ cohere_params["temperature"] = validate_parameter(
+ params, "temperature", (int, float), False, 0.3, (0, None), None
+ )
+ cohere_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+ cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)
+ cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)
+ cohere_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
+ cohere_params["frequency_penalty"] = validate_parameter(
+ params, "frequency_penalty", (int, float), True, 0, (0, 1), None
+ )
+ cohere_params["presence_penalty"] = validate_parameter(
+ params, "presence_penalty", (int, float), True, 0, (0, 1), None
+ )
+
+ # Cohere parameters we are ignoring:
+ # preamble - we will put the system prompt in here.
+ # parallel_tool_calls (defaults to True), perfect as is.
+ # conversation_id - allows resuming a previous conversation, we don't support this.
+ logging.info("Conversation ID: %s", params.get("conversation_id", "None"))
+ # connectors - allows web search or other custom connectors, not implementing for now but could be useful in the future.
+ logging.info("Connectors: %s", params.get("connectors", "None"))
+ # search_queries_only - to control whether only search queries are used, we're not using connectors so ignoring.
+ # documents - a list of documents that can be used to support the chat. Perhaps useful in the future for RAG.
+ # citation_quality - used for RAG flows and dependent on other parameters we're ignoring.
+ # max_input_tokens - limits input tokens, not needed.
+ logging.info("Max Input Tokens: %s", params.get("max_input_tokens", "None"))
+ # stop_sequences - used to stop generation, not needed.
+ logging.info("Stop Sequences: %s", params.get("stop_sequences", "None"))
+
+ return cohere_params
+
+ def create(self, params: Dict) -> ChatCompletion:
+
+ messages = params.get("messages", [])
+
+ # Parse parameters to the Cohere API's parameters
+ cohere_params = self.parse_params(params)
+
+ # Convert AutoGen messages to Cohere messages
+ cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params)
+
+ cohere_params["chat_history"] = cohere_messages
+ cohere_params["message"] = final_message
+ cohere_params["preamble"] = preamble
+
+ # We use chat model by default
+ client = Cohere(api_key=self.api_key)
+
+ # Token counts will be returned
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+
+ # Stream if in parameters
+ streaming = bool(params.get("stream", False))
+ cohere_finish = ""
+
+ max_retries = 5
+ for attempt in range(max_retries):
+ ans = None
+ try:
+ if streaming:
+ response = client.chat_stream(**cohere_params)
+ else:
+ response = client.chat(**cohere_params)
+ except CohereRateLimitError as e:
+ raise RuntimeError(f"Cohere exception occurred: {e}")
+ else:
+
+ if streaming:
+ # Streaming...
+ ans = ""
+ for event in response:
+ if event.event_type == "text-generation":
+ ans = ans + event.text
+ elif event.event_type == "tool-calls-generation":
+ # When streaming, tool calls are compiled at the end into a single event_type
+ ans = event.text
+ cohere_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in event.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=str(random.randint(0, 100000)),
+ function={
+ "name": tool_call.name,
+ "arguments": (
+ "" if tool_call.parameters is None else json.dumps(tool_call.parameters)
+ ),
+ },
+ type="function",
+ )
+ )
+
+ # Not using billed_units, but that may be better for cost purposes
+ prompt_tokens = event.response.meta.tokens.input_tokens
+ completion_tokens = event.response.meta.tokens.output_tokens
+ total_tokens = prompt_tokens + completion_tokens
+
+ response_id = event.response.response_id
+ else:
+ # Non-streaming finished
+ ans: str = response.text
+
+ # Not using billed_units, but that may be better for cost purposes
+ prompt_tokens = response.meta.tokens.input_tokens
+ completion_tokens = response.meta.tokens.output_tokens
+ total_tokens = prompt_tokens + completion_tokens
+
+ response_id = response.response_id
+ break
+
+ if response is not None:
+
+ response_content = ans
+
+ if streaming:
+ # Streaming response
+ if cohere_finish == "":
+ cohere_finish = "stop"
+ tool_calls = None
+ else:
+ # Non-streaming response
+ # If we have tool calls as the response, populate completed tool calls for our return OAI response
+ if response.tool_calls is not None:
+ cohere_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in response.tool_calls:
+
+ # if parameters are null, clear them out (Cohere can return a string "null" if no parameter values)
+
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=str(random.randint(0, 100000)),
+ function={
+ "name": tool_call.name,
+ "arguments": (
+ "" if tool_call.parameters is None else json.dumps(tool_call.parameters)
+ ),
+ },
+ type="function",
+ )
+ )
+ else:
+ cohere_finish = "stop"
+ tool_calls = None
+ else:
+ raise RuntimeError(f"Failed to get response from Cohere after retrying {attempt + 1} times.")
+
+ # 3. convert output
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=response_content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=cohere_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=response_id,
+ model=cohere_params["model"],
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ cost=calculate_cohere_cost(prompt_tokens, completion_tokens, cohere_params["model"]),
+ )
+
+ return response_oai
+
+
+def oai_messages_to_cohere_messages(
+ messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
+) -> tuple[list[dict[str, Any]], str, str]:
+ """Convert messages from OAI format to Cohere's format.
+ We correct for any specific role orders and types.
+
+ Parameters:
+ messages: list[Dict[str, Any]]: AutoGen messages
+ params: Dict[str, Any]: AutoGen parameters dictionary
+ cohere_params: Dict[str, Any]: Cohere parameters dictionary
+
+ Returns:
+ List[Dict[str, Any]]: Chat History messages
+ str: Preamble (system message)
+ str: Message (the final user message)
+ """
+
+ cohere_messages = []
+ preamble = ""
+
+ # Tools
+ if "tools" in params:
+ cohere_tools = []
+ for tool in params["tools"]:
+
+ # build list of properties
+ parameters = {}
+
+ for key, value in tool["function"]["parameters"]["properties"].items():
+ type_str = value["type"]
+ required = True # Defaults to False, we could consider leaving it as default.
+ description = value["description"]
+
+ # If we have an 'enum' key, add that to the description (as not allowed to pass in enum as a field)
+ if "enum" in value:
+ # Access the enum list
+ enum_values = value["enum"]
+ enum_strings = [str(value) for value in enum_values]
+ enum_string = ", ".join(enum_strings)
+ description = description + ". Possible values are " + enum_string + "."
+
+ parameters[key] = ToolParameterDefinitionsValue(
+ description=description, type=type_str, required=required
+ )
+
+ cohere_tool = {
+ "name": tool["function"]["name"],
+ "description": tool["function"]["description"],
+ "parameter_definitions": parameters,
+ }
+
+ cohere_tools.append(cohere_tool)
+
+ if len(cohere_tools) > 0:
+ cohere_params["tools"] = cohere_tools
+
+ tool_calls = []
+ tool_results = []
+
+ # Rules for cohere messages:
+ # no 'name' field
+ # 'system' messages go into the preamble parameter
+ # user role = 'USER'
+ # assistant role = 'CHATBOT'
+ # 'content' field renamed to 'message'
+ # tools go into tools parameter
+ # tool_results go into tool_results parameter
+ for message in messages:
+
+ if "role" in message and message["role"] == "system":
+ # System message
+ if preamble == "":
+ preamble = message["content"]
+ else:
+ preamble = preamble + "\n" + message["content"]
+ elif "tool_calls" in message:
+ # Suggested tool calls, build up the list before we put it into the tool_results
+ for tool_call in message["tool_calls"]:
+ tool_calls.append(tool_call)
+
+ # We also add the suggested tool call as a message
+ new_message = {
+ "role": "CHATBOT",
+ "message": message["content"],
+ # Not including tools in this message, may need to. Testing required.
+ }
+
+ cohere_messages.append(new_message)
+ elif "role" in message and message["role"] == "tool":
+ if "tool_call_id" in message:
+ # Convert the tool call to a result
+
+ tool_call_id = message["tool_call_id"]
+ content_output = message["content"]
+
+ # Find the original tool
+ for tool_call in tool_calls:
+ if tool_call["id"] == tool_call_id:
+
+ call = {
+ "name": tool_call["function"]["name"],
+ "parameters": json.loads(
+ tool_call["function"]["arguments"]
+ if not tool_call["function"]["arguments"] == ""
+ else "{}"
+ ),
+ }
+ output = [{"value": content_output}]
+
+ tool_results.append(ToolResult(call=call, outputs=output))
+
+ break
+ elif "content" in message and isinstance(message["content"], str):
+ # Standard text message
+ new_message = {
+ "role": "USER" if message["role"] == "user" else "CHATBOT",
+ "message": message["content"],
+ }
+
+ cohere_messages.append(new_message)
+
+ # Append any Tool Results
+ if len(tool_results) != 0:
+ cohere_params["tool_results"] = tool_results
+
+ # Enable multi-step tool use: https://docs.cohere.com/docs/multi-step-tool-use
+ cohere_params["force_single_step"] = False
+
+ # If we're adding tool_results, like we are, the last message can't be a USER message
+ # So, we add a CHATBOT 'continue' message, if so.
+ if cohere_messages[-1]["role"] == "USER":
+ cohere_messages.append({"role": "CHATBOT", "content": "Please continue."})
+
+ # We return a blank message when we have tool results
+ # TODO: Check what happens if tool_results aren't the latest message
+ return cohere_messages, preamble, ""
+
+ else:
+
+ # We need to get the last message to assign to the message field for Cohere,
+ # if the last message is a user message, use that, otherwise put in 'continue'.
+ if cohere_messages[-1]["role"] == "USER":
+ return cohere_messages[0:-1], preamble, cohere_messages[-1]["message"]
+ else:
+ return cohere_messages, preamble, "Please continue."
+
+
+def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
+ """Calculate the cost of the completion using the Cohere pricing."""
+ total = 0.0
+
+ if model in COHERE_PRICING_1K:
+ input_cost_per_k, output_cost_per_k = COHERE_PRICING_1K[model]
+ input_cost = (input_tokens / 1000) * input_cost_per_k
+ output_cost = (output_tokens / 1000) * output_cost_per_k
+ total = input_cost + output_cost
+ else:
+ warnings.warn(f"Cost calculation not available for {model} model", UserWarning)
+
+ return total
+
+
+class CohereError(Exception):
+ """Base class for other Cohere exceptions"""
+
+ pass
+
+
+class CohereRateLimitError(CohereError):
+ """Raised when rate limit is exceeded"""
+
+ pass
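For reference, a sketch of how `oai_messages_to_cohere_messages` above reshapes an OpenAI-style tool definition, including folding `enum` values into the description (Cohere's schema has no enum field; the tool itself is illustrative):

```python
# OpenAI-style tool schema (input):
oai_tool = {
    "function": {
        "name": "get_weather",
        "description": "Get the current weather",
        "parameters": {
            "properties": {
                "unit": {"type": "str", "description": "Temperature unit", "enum": ["C", "F"]},
            }
        },
    }
}

# Resulting Cohere tool (output), per the conversion loop above:
# {
#     "name": "get_weather",
#     "description": "Get the current weather",
#     "parameter_definitions": {
#         "unit": ToolParameterDefinitionsValue(
#             description="Temperature unit. Possible values are C, F.",
#             type="str",
#             required=True,
#         ),
#     },
# }
```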
diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
index 8babb8727e3c..73d41cddbf53 100644
--- a/autogen/oai/gemini.py
+++ b/autogen/oai/gemini.py
@@ -72,7 +72,7 @@ class GeminiClient:
"max_output_tokens": "max_output_tokens",
}
- def _initialize_vartexai(self, **params):
+ def _initialize_vertexai(self, **params):
if "google_application_credentials" in params:
# Path to JSON Keyfile
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
@@ -106,7 +106,7 @@ def __init__(self, **kwargs):
self.api_key = os.getenv("GOOGLE_API_KEY")
if self.api_key is None:
self.use_vertexai = True
- self._initialize_vartexai(**kwargs)
+ self._initialize_vertexai(**kwargs)
else:
self.use_vertexai = False
else:
@@ -142,7 +142,7 @@ def get_usage(response) -> Dict:
def create(self, params: Dict) -> ChatCompletion:
if self.use_vertexai:
- self._initialize_vartexai(**params)
+ self._initialize_vertexai(**params)
else:
assert ("project_id" not in params) and (
"location" not in params
diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py
new file mode 100644
index 000000000000..d2abe5116a25
--- /dev/null
+++ b/autogen/oai/groq.py
@@ -0,0 +1,282 @@
+"""Create an OpenAI-compatible client using Groq's API.
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "groq",
+ "model": "mixtral-8x7b-32768",
+ "api_key": os.environ.get("GROQ_API_KEY")
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Groq's python library using: pip install --upgrade groq
+
+Resources:
+- https://console.groq.com/docs/quickstart
+"""
+
+from __future__ import annotations
+
+import copy
+import os
+import time
+import warnings
+from typing import Any, Dict, List
+
+from groq import Groq, Stream
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+
+# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
+GROQ_PRICING_1K = {
+ "llama3-70b-8192": (0.00059, 0.00079),
+ "mixtral-8x7b-32768": (0.00024, 0.00024),
+ "llama3-8b-8192": (0.00005, 0.00008),
+ "gemma-7b-it": (0.00007, 0.00007),
+}
+
+
+class GroqClient:
+ """Client for Groq's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Groq (or environment variable GROQ_API_KEY needs to be set)
+ """
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("GROQ_API_KEY")
+
+ assert (
+ self.api_key
+ ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ # ... # pragma: no cover
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ groq_params = {}
+
+ # Check that we have what we need to use Groq's API
+ # We won't enforce the available models as they are likely to change
+ groq_params["model"] = params.get("model", None)
+ assert groq_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Groq model to use."
+
+ # Validate allowed Groq parameters
+ # https://console.groq.com/docs/api-reference#chat
+ groq_params["frequency_penalty"] = validate_parameter(
+ params, "frequency_penalty", (int, float), True, None, (-2, 2), None
+ )
+ groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+ groq_params["presence_penalty"] = validate_parameter(
+ params, "presence_penalty", (int, float), True, None, (-2, 2), None
+ )
+ groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
+ groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
+ groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
+ groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
+
+ # Groq parameters not supported by their models yet, ignoring
+ # logit_bias, logprobs, top_logprobs
+
+ # Groq parameters we are ignoring:
+ # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
+ # parallel_tool_calls (defaults to True), stop
+ # function_call (deprecated), functions (deprecated)
+ # tool_choice (none if no tools, auto if there are tools)
+
+ return groq_params
+
+ def create(self, params: Dict) -> ChatCompletion:
+
+ messages = params.get("messages", [])
+
+ # Convert AutoGen messages to Groq messages
+ groq_messages = oai_messages_to_groq_messages(messages)
+
+ # Parse parameters to the Groq API's parameters
+ groq_params = self.parse_params(params)
+
+ # Add tools to the call if we have them and aren't hiding them
+ if "tools" in params:
+ hide_tools = validate_parameter(
+ params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
+ )
+ if not should_hide_tools(groq_messages, params["tools"], hide_tools):
+ groq_params["tools"] = params["tools"]
+
+ groq_params["messages"] = groq_messages
+
+ # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
+ client = Groq(api_key=self.api_key, max_retries=5)
+
+ # Token counts will be returned
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+
+ # Streaming tool call recommendations
+ streaming_tool_calls = []
+
+ ans = None
+ try:
+ response = client.chat.completions.create(**groq_params)
+ except Exception as e:
+ raise RuntimeError(f"Groq exception occurred: {e}")
+ else:
+
+ if groq_params["stream"]:
+ # Read in the chunks as they stream, taking in tool_calls which may be across
+ # multiple chunks if more than one suggested
+ ans = ""
+ for chunk in response:
+ ans = ans + (chunk.choices[0].delta.content or "")
+
+ if chunk.choices[0].delta.tool_calls:
+ # We have a tool call recommendation
+ for tool_call in chunk.choices[0].delta.tool_calls:
+ streaming_tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={
+ "name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ },
+ type="function",
+ )
+ )
+
+ if chunk.choices[0].finish_reason:
+ prompt_tokens = chunk.x_groq.usage.prompt_tokens
+ completion_tokens = chunk.x_groq.usage.completion_tokens
+ total_tokens = chunk.x_groq.usage.total_tokens
+ else:
+ # Non-streaming finished
+ ans: str = response.choices[0].message.content
+
+ prompt_tokens = response.usage.prompt_tokens
+ completion_tokens = response.usage.completion_tokens
+ total_tokens = response.usage.total_tokens
+
+ if response is not None:
+
+ if isinstance(response, Stream):
+ # Streaming response
+ if chunk.choices[0].finish_reason == "tool_calls":
+ groq_finish = "tool_calls"
+ tool_calls = streaming_tool_calls
+ else:
+ groq_finish = "stop"
+ tool_calls = None
+
+ response_content = ans
+ response_id = chunk.id
+ else:
+ # Non-streaming response
+ # If we have tool calls as the response, populate completed tool calls for our return OAI response
+ if response.choices[0].finish_reason == "tool_calls":
+ groq_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in response.choices[0].message.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+ type="function",
+ )
+ )
+ else:
+ groq_finish = "stop"
+ tool_calls = None
+
+ response_content = response.choices[0].message.content
+ response_id = response.id
+ else:
+ raise RuntimeError("Failed to get response from Groq after retrying 5 times.")
+
+ # 3. convert output
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=response_content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=groq_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=response_id,
+ model=groq_params["model"],
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
+ )
+
+ return response_oai
+
+
+def oai_messages_to_groq_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """Convert messages from OAI format to Groq's format.
+ We correct for any specific role orders and types.
+ """
+
+ groq_messages = copy.deepcopy(messages)
+
+ # Remove the name field
+ for message in groq_messages:
+ if "name" in message:
+ message.pop("name", None)
+
+ return groq_messages
+
+
+def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
+ """Calculate the cost of the completion using the Groq pricing."""
+ total = 0.0
+
+ if model in GROQ_PRICING_1K:
+ input_cost_per_k, output_cost_per_k = GROQ_PRICING_1K[model]
+ input_cost = (input_tokens / 1000) * input_cost_per_k
+ output_cost = (output_tokens / 1000) * output_cost_per_k
+ total = input_cost + output_cost
+ else:
+ warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
+
+ return total
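+# Worked example (hypothetical prices): if GROQ_PRICING_1K contained
+# {"llama3-8b-8192": (0.00005, 0.00008)}, a call with 1500 input and 500 output
+# tokens would cost (1500 / 1000) * 0.00005 + (500 / 1000) * 0.00008 = 0.000115.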
diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py
index 0c8a0a413375..749727d952c0 100644
--- a/autogen/oai/openai_utils.py
+++ b/autogen/oai/openai_utils.py
@@ -96,7 +96,7 @@ def is_valid_api_key(api_key: str) -> bool:
Returns:
bool: A boolean that indicates if input is valid OpenAI API key.
"""
- api_key_re = re.compile(r"^sk-(proj-)?[A-Za-z0-9]{32,}$")
+ api_key_re = re.compile(r"^sk-([A-Za-z0-9]+(-+[A-Za-z0-9]+)*-)?[A-Za-z0-9]{32,}$")
return bool(re.fullmatch(api_key_re, api_key))
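For reference, the widened pattern now accepts both legacy and project-scoped key shapes; a quick sanity check (dummy keys, for illustration only):

    import re

    api_key_re = re.compile(r"^sk-([A-Za-z0-9]+(-+[A-Za-z0-9]+)*-)?[A-Za-z0-9]{32,}$")
    assert api_key_re.fullmatch("sk-" + "a" * 48)        # legacy key shape
    assert api_key_re.fullmatch("sk-proj-" + "a" * 48)   # project-scoped key shape
    assert not api_key_re.fullmatch("sk-proj-tooshort")  # fewer than 32 trailing chars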
diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py
index adb55ba63b4f..1ffc8b622f0a 100644
--- a/autogen/runtime_logging.py
+++ b/autogen/runtime_logging.py
@@ -14,7 +14,9 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
from autogen.oai.together import TogetherClient
@@ -110,7 +112,9 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig
def log_new_client(
- client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
+ client: Union[
+ AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient, CohereClient
+ ],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index 2842a7494536..365285e09551 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -95,7 +95,7 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
- print("Warning: model not found. Using cl100k_base encoding.")
+ logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
@@ -166,7 +166,7 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
- print("Warning: model not found. Using cl100k_base encoding.")
+ logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
@@ -193,7 +193,7 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
function_tokens += 3
function_tokens += len(encoding.encode(o))
else:
- print(f"Warning: not supported field {field}")
+ logger.warning(f"Not supported field {field}")
function_tokens += 11
if len(parameters["properties"]) == 0:
function_tokens -= 2
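The switch from print to logger.warning above keeps the existing fallback behaviour; a minimal sketch of that path (model name invented):

    import tiktoken

    try:
        encoding = tiktoken.encoding_for_model("not-a-real-model")
    except KeyError:
        # unknown model: fall back to the generic cl100k_base encoding
        encoding = tiktoken.get_encoding("cl100k_base")

    print(len(encoding.encode("hello world")))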
diff --git a/autogen/version.py b/autogen/version.py
index 110d3e10d2f0..93824aa1f87c 100644
--- a/autogen/version.py
+++ b/autogen/version.py
@@ -1 +1 @@
-__version__ = "0.2.30"
+__version__ = "0.2.32"
diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln
index 5ecfe1938873..1218cf129821 100644
--- a/dotnet/AutoGen.sln
+++ b/dotnet/AutoGen.sln
@@ -1,4 +1,3 @@
-
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 17
VisualStudioVersion = 17.8.34322.80
@@ -33,6 +32,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral", "src\Auto
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral.Tests", "test\AutoGen.Mistral.Tests\AutoGen.Mistral.Tests.csproj", "{15441693-3659-4868-B6C1-B106F52FF3BA}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI", "src\AutoGen.WebAPI\AutoGen.WebAPI.csproj", "{257FFD71-08E5-40C7-AB04-6A81A78EB410}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Tests", "test\AutoGen.WebAPI.Tests\AutoGen.WebAPI.Tests.csproj", "{E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}"
+EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Tests", "test\AutoGen.SemanticKernel.Tests\AutoGen.SemanticKernel.Tests.csproj", "{1DFABC4A-8458-4875-8DCB-59F3802DAC65}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Tests", "test\AutoGen.OpenAI.Tests\AutoGen.OpenAI.Tests.csproj", "{D36A85F9-C172-487D-8192-6BFE5D05B4A7}"
@@ -61,6 +64,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini.Sample", "sa
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AotCompatibility.Tests", "test\AutoGen.AotCompatibility.Tests\AutoGen.AotCompatibility.Tests.csproj", "{6B82F26D-5040-4453-B21B-C8D1F913CE4C}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Sample", "sample\AutoGen.OpenAI.Sample\AutoGen.OpenAI.Sample.csproj", "{0E635268-351C-4A6B-A28D-593D868C2CA4}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Sample", "sample\AutoGen.WebAPI.Sample\AutoGen.WebAPI.Sample.csproj", "{12079C18-A519-403F-BBFD-200A36A0C083}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -115,6 +122,14 @@ Global
{15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.Build.0 = Release|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Release|Any CPU.Build.0 = Release|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Release|Any CPU.Build.0 = Release|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -171,6 +186,14 @@ Global
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Release|Any CPU.Build.0 = Release|Any CPU
+ {0E635268-351C-4A6B-A28D-593D868C2CA4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {0E635268-351C-4A6B-A28D-593D868C2CA4}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {0E635268-351C-4A6B-A28D-593D868C2CA4}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {0E635268-351C-4A6B-A28D-593D868C2CA4}.Release|Any CPU.Build.0 = Release|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -188,6 +211,8 @@ Global
{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{D36A85F9-C172-487D-8192-6BFE5D05B4A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
@@ -202,6 +227,8 @@ Global
{8EA16BAB-465A-4C07-ABC4-1070D40067E9} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{19679B75-CE3A-4DF0-A3F0-CA369D2760A4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
{6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {0E635268-351C-4A6B-A28D-593D868C2CA4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {12079C18-A519-403F-BBFD-200A36A0C083} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}
diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props
index 4b3e9441f1ee..29e40fff384c 100644
--- a/dotnet/Directory.Build.props
+++ b/dotnet/Directory.Build.props
@@ -31,6 +31,7 @@
+
diff --git a/dotnet/eng/MetaInfo.props b/dotnet/eng/MetaInfo.props
index 041ee0ec6c97..f43a47c8ce27 100644
--- a/dotnet/eng/MetaInfo.props
+++ b/dotnet/eng/MetaInfo.props
@@ -1,7 +1,7 @@
- 0.0.15
+ 0.0.16
AutoGen
https://microsoft.github.io/autogen-for-net/
https://github.com/microsoft/autogen
diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props
index 0b8dcaa565cb..20be183219e5 100644
--- a/dotnet/eng/Version.props
+++ b/dotnet/eng/Version.props
@@ -2,8 +2,8 @@
1.0.0-beta.17
- 1.10.0
- 1.10.0-alpha
+ 1.15.1
+ 1.15.1-alpha
5.0.0
4.3.0
6.0.0
@@ -12,6 +12,7 @@
17.7.0
1.0.0-beta.24229.4
8.0.0
+ 8.0.4
3.0.0
4.3.0.2
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj
index 33a5aa7f16b6..2948c9bf283c 100644
--- a/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj
@@ -13,6 +13,7 @@
+
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs
similarity index 93%
rename from dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs
rename to dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs
index 94b5f37511e6..6f32c3cb4a21 100644
--- a/dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// AnthropicSamples.cs
+// Create_Anthropic_Agent.cs
using AutoGen.Anthropic.Extensions;
using AutoGen.Anthropic.Utils;
@@ -7,7 +7,7 @@
namespace AutoGen.Anthropic.Samples;
-public static class AnthropicSamples
+public static class Create_Anthropic_Agent
{
public static async Task RunAsync()
{
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs
new file mode 100644
index 000000000000..0324a39ffa59
--- /dev/null
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs
@@ -0,0 +1,100 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Create_Anthropic_Agent_With_Tool.cs
+
+using AutoGen.Anthropic.DTO;
+using AutoGen.Anthropic.Extensions;
+using AutoGen.Anthropic.Utils;
+using AutoGen.Core;
+using FluentAssertions;
+
+namespace AutoGen.Anthropic.Samples;
+
+#region WeatherFunction
+
+public partial class WeatherFunction
+{
+ /// <summary>
+ /// Gets the weather based on the location and the unit
+ /// </summary>
+ /// <param name="location"></param>
+ /// <param name="unit"></param>
+ /// <returns></returns>
+ [Function]
+ public async Task<string> GetWeather(string location, string unit)
+ {
+ // dummy implementation
+ return $"The weather in {location} is currently sunny with a tempature of {unit} (s)";
+ }
+}
+#endregion
+public class Create_Anthropic_Agent_With_Tool
+{
+ public static async Task RunAsync()
+ {
+ #region define_tool
+ var tool = new Tool
+ {
+ Name = "GetWeather",
+ Description = "Get the current weather in a given location",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary<string, SchemaProperty>
+ {
+ { "location", new SchemaProperty { Type = "string", Description = "The city and state, e.g. San Francisco, CA" } },
+ { "unit", new SchemaProperty { Type = "string", Description = "The unit of temperature, either \"celsius\" or \"fahrenheit\"" } }
+ },
+ Required = new List<string> { "location" }
+ }
+ };
+
+ var weatherFunction = new WeatherFunction();
+ var functionMiddleware = new FunctionCallMiddleware(
+ functions: [
+ weatherFunction.GetWeatherFunctionContract,
+ ],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
+ });
+
+ #endregion
+
+ #region create_anthropic_agent
+
+ var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
+ throw new Exception("Missing ANTHROPIC_API_KEY environment variable.");
+
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
+ var agent = new AnthropicClientAgent(anthropicClient, "assistant", AnthropicConstants.Claude3Haiku,
+ tools: [tool]); // Define tools for AnthropicClientAgent
+ #endregion
+
+ #region register_middleware
+
+ var agentWithConnector = agent
+ .RegisterMessageConnector()
+ .RegisterPrintMessage()
+ .RegisterStreamingMiddleware(functionMiddleware);
+ #endregion register_middleware
+
+ #region single_turn
+ var question = new TextMessage(Role.Assistant,
+ "What is the weather like in San Francisco?",
+ from: "user");
+ var functionCallReply = await agentWithConnector.SendAsync(question);
+ #endregion
+
+ #region Single_turn_verify_reply
+ functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
+ #endregion Single_turn_verify_reply
+
+ #region Multi_turn
+ var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
+ #endregion Multi_turn
+
+ #region Multi_turn_verify_reply
+ finalReply.Should().BeOfType<TextMessage>();
+ #endregion Multi_turn_verify_reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
index f3c615088610..6d1e4e594b99 100644
--- a/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
@@ -7,6 +7,6 @@ internal static class Program
{
public static async Task Main(string[] args)
{
- await AnthropicSamples.RunAsync();
+ await Create_Anthropic_Agent_With_Tool.RunAsync();
}
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
index 4833c6195c9d..a103f4ec2d4d 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
@@ -129,7 +129,7 @@ public async Task CodeSnippet5()
},
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
- { this.UpperCaseFunction.Name, this.UpperCaseWrapper }, // The wrapper function for the UpperCase function
+ { this.UpperCaseFunctionContract.Name, this.UpperCaseWrapper }, // The wrapper function for the UpperCase function
});
var response = await assistantAgent.SendAsync("hello");
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
index 320afd0de679..1b5a9a903207 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
@@ -13,38 +13,46 @@ public class MiddlewareAgentCodeSnippet
public async Task CreateMiddlewareAgentAsync()
{
#region create_middleware_agent_with_original_agent
- // Create an agent that always replies "Hello World"
- IAgent agent = new DefaultReplyAgent(name: "assistant", defaultReply: "Hello World");
+ // Create an agent that always replies "Hi!"
+ IAgent agent = new DefaultReplyAgent(name: "assistant", defaultReply: "Hi!");
// Create a middleware agent on top of default reply agent
var middlewareAgent = new MiddlewareAgent(innerAgent: agent);
middlewareAgent.Use(async (messages, options, agent, ct) =>
{
- var lastMessage = messages.Last() as TextMessage;
- lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ if (messages.Last() is TextMessage lastMessage && lastMessage.Content.Contains("Hello World"))
+ {
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return lastMessage;
+ }
+
return await agent.GenerateReplyAsync(messages, options, ct);
});
var reply = await middlewareAgent.SendAsync("Hello World");
reply.GetContent().Should().Be("[middleware 0] Hello World");
+ reply = await middlewareAgent.SendAsync("Hello AI!");
+ reply.GetContent().Should().Be("Hi!");
#endregion create_middleware_agent_with_original_agent
#region register_middleware_agent
middlewareAgent = agent.RegisterMiddleware(async (messages, options, agent, ct) =>
{
- var lastMessage = messages.Last() as TextMessage;
- lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ if (messages.Last() is TextMessage lastMessage && lastMessage.Content.Contains("Hello World"))
+ {
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return lastMessage;
+ }
+
return await agent.GenerateReplyAsync(messages, options, ct);
});
#endregion register_middleware_agent
#region short_circuit_middleware_agent
- // This middleware will short circuit the agent and return the last message directly.
+ // This middleware will short circuit the agent and return a message directly.
middlewareAgent.Use(async (messages, options, agent, ct) =>
{
- var lastMessage = messages.Last() as TextMessage;
- lastMessage.Content = $"[middleware shortcut]";
- return lastMessage;
+ return new TextMessage(Role.Assistant, $"[middleware shortcut]");
});
#endregion short_circuit_middleware_agent
}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
index 47dd8ce66c90..216059928408 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
@@ -17,14 +17,18 @@ public static async Task RunAsync()
// setup dotnet interactive
var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
if (!Directory.Exists(workDir))
+ {
Directory.CreateDirectory(workDir);
+ }
using var service = new InteractiveService(workDir);
var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
var result = Path.Combine(workDir, "result.txt");
if (File.Exists(result))
+ {
File.Delete(result);
+ }
await service.StartAsync(workDir, default);
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
index 6584baa5fae5..004e0f055449 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
@@ -8,6 +8,7 @@
using AutoGen.Core;
using AutoGen.DotnetInteractive;
using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
using FluentAssertions;
public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci
@@ -138,7 +139,7 @@ public static async Task CreateReviewerAgentAsync()
name: "code_reviewer",
systemMessage: @"You review code block from coder",
config: gpt3Config,
- functions: [functions.ReviewCodeBlockFunction],
+ functions: [functions.ReviewCodeBlockFunctionContract.ToOpenAIFunctionDefinition()],
functionMap: new Dictionary<string, Func<string, Task<string>>>()
{
{ nameof(ReviewCodeBlock), functions.ReviewCodeBlockWrapper },
@@ -224,7 +225,9 @@ public static async Task RunWorkflowAsync()
long the39thFibonacciNumber = 63245986;
var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
if (!Directory.Exists(workDir))
+ {
Directory.CreateDirectory(workDir);
+ }
using var service = new InteractiveService(workDir);
var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
@@ -328,7 +331,9 @@ public static async Task RunAsync()
long the39thFibonacciNumber = 63245986;
var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
if (!Directory.Exists(workDir))
+ {
Directory.CreateDirectory(workDir);
+ }
using var service = new InteractiveService(workDir);
var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs
index 9a62144df2bd..c9dda27d2e23 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs
@@ -5,6 +5,7 @@
using System.Text.Json.Serialization;
using AutoGen.Core;
using AutoGen.LMStudio;
+using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
namespace AutoGen.BasicSample;
@@ -69,8 +70,8 @@ public static async Task RunAsync()
// And ask agent to response in function call object format using few-shot example
object[] functionList =
[
- SerializeFunctionDefinition(instance.GetWeatherFunction),
- SerializeFunctionDefinition(instance.GoogleSearchFunction)
+ SerializeFunctionDefinition(instance.GetWeatherFunctionContract.ToOpenAIFunctionDefinition()),
+ SerializeFunctionDefinition(instance.GoogleSearchFunctionContract.ToOpenAIFunctionDefinition())
];
var functionListString = JsonSerializer.Serialize(functionList, new JsonSerializerOptions { WriteIndented = true });
var lmAgent = new LMStudioAgent(
@@ -98,12 +99,12 @@ You are a helpful AI assistant
{
var arguments = JsonSerializer.Serialize(functionCall.Arguments);
// invoke function wrapper
- if (functionCall.Name == instance.GetWeatherFunction.Name)
+ if (functionCall.Name == instance.GetWeatherFunctionContract.Name)
{
var result = await instance.GetWeatherWrapper(arguments);
return new TextMessage(Role.Assistant, result);
}
- else if (functionCall.Name == instance.GoogleSearchFunction.Name)
+ else if (functionCall.Name == instance.GoogleSearchFunctionContract.Name)
{
var result = await instance.GoogleSearchWrapper(arguments);
return new TextMessage(Role.Assistant, result);
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
index dadad7f00b99..596ab08d02a1 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
@@ -1,68 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Example13_OpenAIAgent_JsonMode.cs
-using System.Text.Json;
-using System.Text.Json.Serialization;
-using AutoGen.Core;
-using AutoGen.OpenAI;
-using AutoGen.OpenAI.Extension;
-using Azure.AI.OpenAI;
-using FluentAssertions;
+// this example has been moved to https://github.com/microsoft/autogen/blob/main/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
-namespace AutoGen.BasicSample;
-
-public class Example13_OpenAIAgent_JsonMode
-{
- public static async Task RunAsync()
- {
- #region create_agent
- var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(deployName: "gpt-35-turbo"); // json mode only works with 0125 and later model.
- var apiKey = config.ApiKey;
- var endPoint = new Uri(config.Endpoint);
-
- var openAIClient = new OpenAIClient(endPoint, new Azure.AzureKeyCredential(apiKey));
- var openAIClientAgent = new OpenAIChatAgent(
- openAIClient: openAIClient,
- name: "assistant",
- modelName: config.DeploymentName,
- systemMessage: "You are a helpful assistant designed to output JSON.",
- seed: 0, // explicitly set a seed to enable deterministic output
- responseFormat: ChatCompletionsResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode
- .RegisterMessageConnector()
- .RegisterPrintMessage();
- #endregion create_agent
-
- #region chat_with_agent
- var reply = await openAIClientAgent.SendAsync("My name is John, I am 25 years old, and I live in Seattle.");
-
- var person = JsonSerializer.Deserialize<Person>(reply.GetContent());
- Console.WriteLine($"Name: {person.Name}");
- Console.WriteLine($"Age: {person.Age}");
-
- if (!string.IsNullOrEmpty(person.Address))
- {
- Console.WriteLine($"Address: {person.Address}");
- }
-
- Console.WriteLine("Done.");
- #endregion chat_with_agent
-
- person.Name.Should().Be("John");
- person.Age.Should().Be(25);
- person.Address.Should().BeNullOrEmpty();
- }
-}
-
-#region person_class
-public class Person
-{
- [JsonPropertyName("name")]
- public string Name { get; set; }
-
- [JsonPropertyName("age")]
- public int Age { get; set; }
-
- [JsonPropertyName("address")]
- public string Address { get; set; }
-}
-#endregion person_class
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs b/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
index 788122d3f383..dee9915511d6 100644
--- a/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
@@ -50,7 +50,9 @@ private static void AddMessagesFromResource(string imageResourcePath, List<IMessage> messages)
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs b/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
- protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
- {
- request.RequestUri = new Uri($"{_modelServiceUrl}{request.RequestUri.PathAndQuery}");
-
- return base.SendAsync(request, cancellationToken);
- }
-}
-#endregion CustomHttpClientHandler
-
-public class Example16_OpenAIChatAgent_ConnectToThirdPartyBackend
-{
- public static async Task RunAsync()
- {
- #region create_agent
- using var client = new HttpClient(new CustomHttpClientHandler("http://localhost:11434"));
- var option = new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2024_04_01_Preview)
- {
- Transport = new HttpClientTransport(client),
- };
-
- // api-key is not required for local server
- // so you can use any string here
- var openAIClient = new OpenAIClient("api-key", option);
- var model = "llama3";
-
- var agent = new OpenAIChatAgent(
- openAIClient: openAIClient,
- name: "assistant",
- modelName: model,
- systemMessage: "You are a helpful assistant designed to output JSON.",
- seed: 0)
- .RegisterMessageConnector()
- .RegisterPrintMessage();
- #endregion create_agent
-
- #region send_message
- await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
- #endregion send_message
- }
-}
+// this example has been moved to https://github.com/microsoft/autogen/blob/main/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs
index 9d21bbde7d30..7acaae4b1f82 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// Dynamic_GroupChat.cs
+// Dynamic_Group_Chat.cs
using AutoGen.Core;
using AutoGen.OpenAI;
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs
index 3352f90d9211..5b94a238bbe8 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs
@@ -1,10 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Image_Chat_With_Agent.cs
+#region Using
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
+#endregion Using
using FluentAssertions;
namespace AutoGen.BasicSample;
@@ -33,16 +35,17 @@ public static async Task RunAsync()
var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(imageBytes, "image/png"));
#endregion Prepare_Image_Input
- #region Chat_With_Agent
- var reply = await agent.SendAsync("what's in the picture", chatHistory: [imageMessage]);
- #endregion Chat_With_Agent
-
#region Prepare_Multimodal_Input
var textMessage = new TextMessage(Role.User, "what's in the picture");
var multimodalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]);
- reply = await agent.SendAsync(multimodalMessage);
#endregion Prepare_Multimodal_Input
+ #region Chat_With_Agent
+ var reply = await agent.SendAsync("what's in the picture", chatHistory: [imageMessage]);
+ // or use multimodal message to generate reply
+ reply = await agent.SendAsync(multimodalMessage);
+ #endregion Chat_With_Agent
+
#region verify_reply
reply.Should().BeOfType<TextMessage>();
#endregion verify_reply
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs
new file mode 100644
index 000000000000..48ebd127b562
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs
@@ -0,0 +1,56 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Streaming_Tool_Call.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+
+namespace AutoGen.BasicSample.GettingStart;
+
+internal class Streaming_Tool_Call
+{
+ public static async Task RunAsync()
+ {
+ #region Create_tools
+ var tools = new Tools();
+ #endregion Create_tools
+
+ #region Create_auto_invoke_middleware
+ var autoInvokeMiddleware = new FunctionCallMiddleware(
+ functions: [tools.GetWeatherFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>()
+ {
+ { tools.GetWeatherFunctionContract.Name, tools.GetWeatherWrapper },
+ });
+ #endregion Create_auto_invoke_middleware
+
+ #region Create_Agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o";
+ var openaiClient = new OpenAIClient(apiKey);
+ var agent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "agent",
+ modelName: model,
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(autoInvokeMiddleware)
+ .RegisterPrintMessage();
+ #endregion Create_Agent
+
+ IMessage finalReply = null;
+ var question = new TextMessage(Role.User, "What's the weather in Seattle");
+
+ // In a streaming function call, the function can only be invoked once all the
+ // chunks are collected, so only one ToolCallAggregateMessage chunk is returned here.
+ await foreach (var message in agent.GenerateStreamingReplyAsync([question]))
+ {
+ finalReply = message;
+ }
+
+ finalReply?.GetContent().Should().Be("The weather in Seattle is sunny.");
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs
index f1a230c123b1..b441fe389da2 100644
--- a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs
@@ -11,6 +11,7 @@
namespace AutoGen.BasicSample;
+#region Tools
public partial class Tools
{
///
@@ -23,6 +24,8 @@ public async Task GetWeather(string city)
return $"The weather in {city} is sunny.";
}
}
+#endregion Tools
+
public class Use_Tools_With_Agent
{
public static async Task RunAsync()
@@ -31,37 +34,53 @@ public static async Task RunAsync()
var tools = new Tools();
#endregion Create_tools
- #region Create_Agent
- var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
- var model = "gpt-3.5-turbo";
- var openaiClient = new OpenAIClient(apiKey);
- var functionCallMiddleware = new FunctionCallMiddleware(
+ #region Create_auto_invoke_middleware
+ var autoInvokeMiddleware = new FunctionCallMiddleware(
functions: [tools.GetWeatherFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>()
{
{ tools.GetWeatherFunctionContract.Name!, tools.GetWeatherWrapper },
});
+ #endregion Create_auto_invoke_middleware
+
+ #region Create_no_invoke_middleware
+ var noInvokeMiddleware = new FunctionCallMiddleware(
+ functions: [tools.GetWeatherFunctionContract]);
+ #endregion Create_no_invoke_middleware
+
+ #region Create_Agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-3.5-turbo";
+ var openaiClient = new OpenAIClient(apiKey);
var agent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "agent",
modelName: model,
systemMessage: "You are a helpful AI assistant")
- .RegisterMessageConnector() // convert OpenAI message to AutoGen message
- .RegisterMiddleware(functionCallMiddleware) // pass function definition to agent.
- .RegisterPrintMessage(); // print the message content
+ .RegisterMessageConnector(); // convert OpenAI message to AutoGen message
#endregion Create_Agent
- #region Single_Turn_Tool_Call
+ #region Single_Turn_Auto_Invoke
+ var autoInvokeAgent = agent
+ .RegisterMiddleware(autoInvokeMiddleware) // pass function definition to agent.
+ .RegisterPrintMessage(); // print the message content
var question = new TextMessage(Role.User, "What is the weather in Seattle?");
- var toolCallReply = await agent.SendAsync(question);
- #endregion Single_Turn_Tool_Call
+ var reply = await autoInvokeAgent.SendAsync(question);
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
+ #endregion Single_Turn_Auto_Invoke
+
+ #region Single_Turn_No_Invoke
+ var noInvokeAgent = agent
+ .RegisterMiddleware(noInvokeMiddleware) // pass function definition to agent.
+ .RegisterPrintMessage(); // print the message content
- #region verify_too_call_reply
- toolCallReply.Should().BeOfType<ToolCallAggregateMessage>();
- #endregion verify_too_call_reply
+ question = new TextMessage(Role.User, "What is the weather in Seattle?");
+ reply = await noInvokeAgent.SendAsync(question);
+ reply.Should().BeOfType<ToolCallMessage>();
+ #endregion Single_Turn_No_Invoke
#region Multi_Turn_Tool_Call
- var finalReply = await agent.SendAsync(chatHistory: [question, toolCallReply]);
+ var finalReply = await agent.SendAsync(chatHistory: [question, reply]);
#endregion Multi_Turn_Tool_Call
#region verify_reply
@@ -70,16 +89,19 @@ public static async Task RunAsync()
#region parallel_tool_call
question = new TextMessage(Role.User, "What is the weather in Seattle, New York and Vancouver");
- toolCallReply = await agent.SendAsync(question);
+ reply = await agent.SendAsync(question);
#endregion parallel_tool_call
#region verify_parallel_tool_call_reply
- toolCallReply.Should().BeOfType<ToolCallAggregateMessage>();
- (toolCallReply as ToolCallAggregateMessage)!.Message1.ToolCalls.Count().Should().Be(3);
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
+ (reply as ToolCallAggregateMessage)!.Message1.ToolCalls.Count().Should().Be(3);
#endregion verify_parallel_tool_call_reply
#region Multi_Turn_Parallel_Tool_Call
- finalReply = await agent.SendAsync(chatHistory: [question, toolCallReply]);
+ finalReply = await agent.SendAsync(chatHistory: [question, reply]);
+ finalReply.Should().BeOfType<ToolCallAggregateMessage>();
+ (finalReply as ToolCallAggregateMessage)!.Message1.ToolCalls.Count().Should().Be(3);
#endregion Multi_Turn_Parallel_Tool_Call
}
+
}
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj b/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj
new file mode 100644
index 000000000000..ffe18f8a616a
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj
@@ -0,0 +1,21 @@
+
+
+
+ Exe
+ net8.0
+ enable
+ enable
+ True
+ $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
+ true
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
new file mode 100644
index 000000000000..3823de2a5284
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
@@ -0,0 +1,62 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Connect_To_Ollama.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+using Azure.Core.Pipeline;
+#endregion using_statement
+
+namespace AutoGen.OpenAI.Sample;
+
+#region CustomHttpClientHandler
+public sealed class CustomHttpClientHandler : HttpClientHandler
+{
+ private string _modelServiceUrl;
+
+ public CustomHttpClientHandler(string modelServiceUrl)
+ {
+ _modelServiceUrl = modelServiceUrl;
+ }
+
+ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ request.RequestUri = new Uri($"{_modelServiceUrl}{request.RequestUri.PathAndQuery}");
+
+ return base.SendAsync(request, cancellationToken);
+ }
+}
+#endregion CustomHttpClientHandler
+
+public class Connect_To_Ollama
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ using var client = new HttpClient(new CustomHttpClientHandler("http://localhost:11434"));
+ var option = new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2024_04_01_Preview)
+ {
+ Transport = new HttpClientTransport(client),
+ };
+
+ // api-key is not required for local server
+ // so you can use any string here
+ var openAIClient = new OpenAIClient("api-key", option);
+ var model = "llama3";
+
+ var agent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "assistant",
+ modelName: model,
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region send_message
+ await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion send_message
+ }
+}
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs
new file mode 100644
index 000000000000..5a38a3ff03b9
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs
@@ -0,0 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using AutoGen.OpenAI.Sample;
+
+Tool_Call_With_Ollama_And_LiteLLM.RunAsync().Wait();
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs
new file mode 100644
index 000000000000..b0b0adc0e6f5
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs
@@ -0,0 +1,68 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Tool_Call_With_Ollama_And_LiteLLM.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+using Azure.Core.Pipeline;
+
+namespace AutoGen.OpenAI.Sample;
+
+#region Function
+public partial class Function
+{
+ [Function]
+ public async Task GetWeatherAsync(string city)
+ {
+ return await Task.FromResult("The weather in " + city + " is 72 degrees and sunny.");
+ }
+}
+#endregion Function
+
+public class Tool_Call_With_Ollama_And_LiteLLM
+{
+ public static async Task RunAsync()
+ {
+ // Before running this code, make sure you have
+ // - Ollama:
+ // - Install dolphincoder:latest in Ollama
+ // - Ollama running on http://localhost:11434
+ // - LiteLLM
+ // - Install LiteLLM
+ // - Start LiteLLM with the following command:
+ // - litellm --model ollama_chat/dolphincoder --port 4000
+
+ #region Create_tools
+ var functions = new Function();
+ var functionMiddleware = new FunctionCallMiddleware(
+ functions: [functions.GetWeatherAsyncFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { functions.GetWeatherAsyncFunctionContract.Name!, functions.GetWeatherAsyncWrapper },
+ });
+ #endregion Create_tools
+ #region Create_Agent
+ var liteLLMUrl = "http://localhost:4000";
+ using var httpClient = new HttpClient(new CustomHttpClientHandler(liteLLMUrl));
+ var option = new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2024_04_01_Preview)
+ {
+ Transport = new HttpClientTransport(httpClient),
+ };
+
+ // api-key is not required for local server
+ // so you can use any string here
+ var openAIClient = new OpenAIClient("api-key", option);
+
+ var agent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "assistant",
+ modelName: "dolphincoder:latest",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterMiddleware(functionMiddleware)
+ .RegisterPrintMessage();
+
+ var reply = await agent.SendAsync("what's the weather in new york");
+ #endregion Create_Agent
+ }
+}
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
new file mode 100644
index 000000000000..d92983c5050f
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
@@ -0,0 +1,67 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Use_Json_Mode.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+
+namespace AutoGen.BasicSample;
+
+public class Use_Json_Mode
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-3.5-turbo";
+
+ var openAIClient = new OpenAIClient(apiKey);
+ var openAIClientAgent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "assistant",
+ modelName: model,
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0, // explicitly set a seed to enable deterministic output
+ responseFormat: ChatCompletionsResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region chat_with_agent
+ var reply = await openAIClientAgent.SendAsync("My name is John, I am 25 years old, and I live in Seattle.");
+
+ var person = JsonSerializer.Deserialize<Person>(reply.GetContent());
+ Console.WriteLine($"Name: {person.Name}");
+ Console.WriteLine($"Age: {person.Age}");
+
+ if (!string.IsNullOrEmpty(person.Address))
+ {
+ Console.WriteLine($"Address: {person.Address}");
+ }
+
+ Console.WriteLine("Done.");
+ #endregion chat_with_agent
+
+ person.Name.Should().Be("John");
+ person.Age.Should().Be(25);
+ person.Address.Should().BeNullOrEmpty();
+ }
+}
+
+#region person_class
+public class Person
+{
+ [JsonPropertyName("name")]
+ public string Name { get; set; }
+
+ [JsonPropertyName("age")]
+ public int Age { get; set; }
+
+ [JsonPropertyName("address")]
+ public string Address { get; set; }
+}
+#endregion person_class
diff --git a/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj b/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj
new file mode 100644
index 000000000000..41f3b7d1d381
--- /dev/null
+++ b/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj
@@ -0,0 +1,13 @@
+
+
+
+ net8.0
+ enable
+ enable
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs b/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs
new file mode 100644
index 000000000000..dbeb8494363d
--- /dev/null
+++ b/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using System.Runtime.CompilerServices;
+using AutoGen.Core;
+using AutoGen.WebAPI;
+
+var alice = new DummyAgent("alice");
+var bob = new DummyAgent("bob");
+
+var builder = WebApplication.CreateBuilder(args);
+// Add services to the container.
+
+// run endpoint at port 5000
+builder.WebHost.UseUrls("http://localhost:5000");
+var app = builder.Build();
+
+app.UseAgentAsOpenAIChatCompletionEndpoint(alice);
+app.UseAgentAsOpenAIChatCompletionEndpoint(bob);
+
+app.Run();
+
+public class DummyAgent : IStreamingAgent
+{
+ public DummyAgent(string name = "dummy")
+ {
+ Name = name;
+ }
+
+ public string Name { get; }
+
+ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ return new TextMessage(Role.Assistant, $"I am dummy {this.Name}", this.Name);
+ }
+
+ public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var reply = $"I am dummy {this.Name}";
+ foreach (var c in reply)
+ {
+ yield return new TextMessageUpdate(Role.Assistant, c.ToString(), this.Name);
+ };
+ }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
index e395bb4a225f..73510baeb71c 100644
--- a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
+++ b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
@@ -1,5 +1,9 @@
-using System;
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicClientAgent.cs
+
+using System;
using System.Collections.Generic;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -16,6 +20,8 @@ public class AnthropicClientAgent : IStreamingAgent
private readonly string _systemMessage;
private readonly decimal _temperature;
private readonly int _maxTokens;
+ private readonly Tool[]? _tools;
+ private readonly ToolChoice? _toolChoice;
public AnthropicClientAgent(
AnthropicClient anthropicClient,
@@ -23,7 +29,9 @@ public AnthropicClientAgent(
string modelName,
string systemMessage = "You are a helpful AI assistant",
decimal temperature = 0.7m,
- int maxTokens = 1024)
+ int maxTokens = 1024,
+ Tool[]? tools = null,
+ ToolChoice? toolChoice = null)
{
Name = name;
_anthropicClient = anthropicClient;
@@ -31,6 +39,8 @@ public AnthropicClientAgent(
_systemMessage = systemMessage;
_temperature = temperature;
_maxTokens = maxTokens;
+ _tools = tools;
+ _toolChoice = toolChoice;
}
public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null,
@@ -40,7 +50,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, G
return new MessageEnvelope(response, from: this.Name);
}
- public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages,
+ public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages,
GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (var message in _anthropicClient.StreamingChatCompletionsAsync(
@@ -59,6 +69,9 @@ private ChatCompletionRequest CreateParameters(IEnumerable messages, G
Model = _modelName,
Stream = shouldStream,
Temperature = (decimal?)options?.Temperature ?? _temperature,
+ Tools = _tools?.ToList(),
+ ToolChoice = _toolChoice ?? (_tools is { Length: > 0 } ? ToolChoice.Auto : null),
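+ // i.e. when tools are supplied and the caller gives no explicit choice,
+ // default to ToolChoice.Auto; when there are no tools, leave tool_choice unset.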
+ StopSequences = options?.StopSequence?.ToArray(),
};
chatCompletionRequest.Messages = BuildMessages(messages);
@@ -86,6 +99,22 @@ private List<ChatMessage> BuildMessages(IEnumerable<IMessage> messages)
}
}
- return chatMessages;
+ // merge messages with the same role
+ // fixing #2884
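+ // e.g. two consecutive user messages (say, text followed by a tool result)
+ // collapse into one user message holding both content blocks, since the API
+ // expects user/assistant roles to alternate.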
+ var mergedMessages = chatMessages.Aggregate(new List<ChatMessage>(), (acc, message) =>
+ {
+ if (acc.Count > 0 && acc.Last().Role == message.Role)
+ {
+ acc.Last().Content.AddRange(message.Content);
+ }
+ else
+ {
+ acc.Add(message);
+ }
+
+ return acc;
+ });
+
+ return mergedMessages;
}
}
diff --git a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
index 90bd33683f20..c58b2c1952ed 100644
--- a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
+++ b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicClient.cs
using System;
@@ -24,12 +24,12 @@ public sealed class AnthropicClient : IDisposable
private static readonly JsonSerializerOptions JsonSerializerOptions = new()
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
- Converters = { new ContentBaseConverter() }
+ Converters = { new ContentBaseConverter(), new JsonPropertyNameEnumConverter() }
};
private static readonly JsonSerializerOptions JsonDeserializerOptions = new()
{
- Converters = { new ContentBaseConverter() }
+ Converters = { new ContentBaseConverter(), new JsonPropertyNameEnumConverter() }
};
public AnthropicClient(HttpClient httpClient, string baseUrl, string apiKey)
@@ -48,7 +48,9 @@ public async Task CreateChatCompletionsAsync(ChatComplet
var responseStream = await httpResponseMessage.Content.ReadAsStreamAsync();
if (httpResponseMessage.IsSuccessStatusCode)
+ {
+ return await DeserializeResponseAsync<ChatCompletionResponse>(responseStream, cancellationToken);
+ }
ErrorResponse res = await DeserializeResponseAsync<ErrorResponse>(responseStream, cancellationToken);
throw new Exception(res.Error?.Message);
@@ -61,24 +63,58 @@ public async IAsyncEnumerable StreamingChatCompletionsAs
using var reader = new StreamReader(await httpResponseMessage.Content.ReadAsStreamAsync());
var currentEvent = new SseEvent();
+
while (await reader.ReadLineAsync() is { } line)
{
if (!string.IsNullOrEmpty(line))
{
- currentEvent.Data = line.Substring("data:".Length).Trim();
+ if (line.StartsWith("event:"))
+ {
+ currentEvent.EventType = line.Substring("event:".Length).Trim();
+ }
+ else if (line.StartsWith("data:"))
+ {
+ currentEvent.Data = line.Substring("data:".Length).Trim();
+ }
}
- else
+ else // an empty line indicates the end of an event
{
- if (currentEvent.Data == "[DONE]")
- continue;
+ if (currentEvent.EventType == "content_block_start" && !string.IsNullOrEmpty(currentEvent.Data))
+ {
+ var dataBlock = JsonSerializer.Deserialize<DataBlock>(currentEvent.Data!);
+ if (dataBlock != null && dataBlock.ContentBlock?.Type == "tool_use")
+ {
+ currentEvent.ContentBlock = dataBlock.ContentBlock;
+ }
+ }
- if (currentEvent.Data != null)
+ if (currentEvent.EventType is "message_start" or "content_block_delta" or "message_delta" && currentEvent.Data != null)
{
- yield return await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
+ var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
+ if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
+ currentEvent.ContentBlock != null)
+ {
+ currentEvent.ContentBlock.AppendDeltaParameters(res.Delta.PartialJson!);
+ }
+ else if (res.Delta is { StopReason: "tool_use" } && currentEvent.ContentBlock != null)
+ {
+ if (res.Content == null)
+ {
+ res.Content = [currentEvent.ContentBlock.CreateToolUseContent()];
+ }
+ else
+ {
+ res.Content.Add(currentEvent.ContentBlock.CreateToolUseContent());
+ }
+
+ currentEvent = new SseEvent();
+ }
+
+ yield return res;
}
- else if (currentEvent.Data != null)
+ else if (currentEvent.EventType == "error" && currentEvent.Data != null)
{
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken);
@@ -86,8 +122,10 @@ public async IAsyncEnumerable StreamingChatCompletionsAs
throw new Exception(res?.Error?.Message);
}
- // Reset the current event for the next one
- currentEvent = new SseEvent();
+ if (currentEvent.ContentBlock == null)
+ {
+ currentEvent = new SseEvent();
+ }
}
}
}
@@ -113,11 +151,50 @@ public void Dispose()
private struct SseEvent
{
+ public string EventType { get; set; }
public string? Data { get; set; }
+ public ContentBlock? ContentBlock { get; set; }
- public SseEvent(string? data = null)
+ public SseEvent(string eventType, string? data = null, ContentBlock? contentBlock = null)
{
+ EventType = eventType;
Data = data;
+ ContentBlock = contentBlock;
}
}
+
+ private class ContentBlock
+ {
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("input")]
+ public object? Input { get; set; }
+
+ public string? parameters { get; set; }
+
+ public void AppendDeltaParameters(string deltaParams)
+ {
+ StringBuilder sb = new StringBuilder(parameters);
+ sb.Append(deltaParams);
+ parameters = sb.ToString();
+ }
+
+ public ToolUseContent CreateToolUseContent()
+ {
+ return new ToolUseContent { Id = Id, Name = Name, Input = parameters };
+ }
+ }
+
+ private class DataBlock
+ {
+ [JsonPropertyName("content_block")]
+ public ContentBlock? ContentBlock { get; set; }
+ }
}
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
index 4cb8fdbb34e0..3e620f934c28 100644
--- a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
+++ b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
@@ -1,12 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// ContentConverter.cs
-
-using AutoGen.Anthropic.DTO;
-
+// ContentBaseConverter.cs
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
+using AutoGen.Anthropic.DTO;
namespace AutoGen.Anthropic.Converters;
public sealed class ContentBaseConverter : JsonConverter
@@ -24,6 +22,10 @@ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert,
return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
case "image":
return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
+ case "tool_use":
+ return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
+ case "tool_result":
+ return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
}
}
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs
new file mode 100644
index 000000000000..cd95d837cffd
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// JsonPropertyNameEnumCoverter.cs
+
+using System;
+using System.Reflection;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.Converters;
+
+internal class JsonPropertyNameEnumConverter<T> : JsonConverter<T> where T : struct, Enum
+{
+ public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ string value = reader.GetString() ?? throw new JsonException("Value was null.");
+
+ foreach (var field in typeToConvert.GetFields())
+ {
+            var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();
+ if (attribute?.Name == value)
+ {
+ return (T)Enum.Parse(typeToConvert, field.Name);
+ }
+ }
+
+ throw new JsonException($"Unable to convert \"{value}\" to enum {typeToConvert}.");
+ }
+
+ public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
+ {
+        var field = value.GetType().GetField(value.ToString());
+        var attribute = field?.GetCustomAttribute<JsonPropertyNameAttribute>();
+
+ if (attribute != null)
+ {
+ writer.WriteStringValue(attribute.Name);
+ }
+ else
+ {
+ writer.WriteStringValue(value.ToString());
+ }
+ }
+}
+
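
The converter round-trips enum members through the wire names declared with JsonPropertyName, throwing JsonException for unknown values. A minimal usage sketch, assuming System.Text.Json.Serialization is in scope (the Direction enum is illustrative, not part of this change):

    [JsonConverter(typeof(JsonPropertyNameEnumConverter<Direction>))]
    internal enum Direction
    {
        [JsonPropertyName("up")]
        Up,

        [JsonPropertyName("down")]
        Down,
    }

    // JsonSerializer.Serialize(Direction.Up) -> "up"
    // JsonSerializer.Deserialize<Direction>("\"sideways\"") -> JsonException
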
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
index 0c1749eaa989..463ee7fc2595 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatCompletionRequest.cs
-using System.Text.Json.Serialization;
using System.Collections.Generic;
+using System.Text.Json.Serialization;
namespace AutoGen.Anthropic.DTO;
@@ -37,6 +37,12 @@ public class ChatCompletionRequest
[JsonPropertyName("top_p")]
public decimal? TopP { get; set; }
+ [JsonPropertyName("tools")]
+    public List<Tool>? Tools { get; set; }
+
+ [JsonPropertyName("tool_choice")]
+ public ToolChoice? ToolChoice { get; set; }
+
public ChatCompletionRequest()
{
        Messages = new List<ChatMessage>();
@@ -62,4 +68,6 @@ public ChatMessage(string role, List<ContentBase> content)
Role = role;
Content = content;
}
+
+ public void AddContent(ContentBase content) => Content.Add(content);
}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs
index c6861f9c3150..fc33aa0e26b1 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs
@@ -1,10 +1,11 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatCompletionResponse.cs
-namespace AutoGen.Anthropic.DTO;
using System.Collections.Generic;
using System.Text.Json.Serialization;
+namespace AutoGen.Anthropic.DTO;
public class ChatCompletionResponse
{
[JsonPropertyName("content")]
@@ -49,9 +50,6 @@ public class StreamingMessage
[JsonPropertyName("role")]
public string? Role { get; set; }
- [JsonPropertyName("content")]
-    public List<ContentBase>? Content { get; set; }
-
[JsonPropertyName("model")]
public string? Model { get; set; }
@@ -85,6 +83,9 @@ public class Delta
[JsonPropertyName("text")]
public string? Text { get; set; }
+ [JsonPropertyName("partial_json")]
+ public string? PartialJson { get; set; }
+
[JsonPropertyName("usage")]
public Usage? Usage { get; set; }
}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
index dd2481bd58f3..353cf6ae824b 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
@@ -1,6 +1,7 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// Content.cs
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
namespace AutoGen.Anthropic.DTO;
@@ -40,3 +41,30 @@ public class ImageSource
[JsonPropertyName("data")]
public string? Data { get; set; }
}
+
+public class ToolUseContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "tool_use";
+
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("input")]
+ public JsonNode? Input { get; set; }
+}
+
+public class ToolResultContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "tool_result";
+
+ [JsonPropertyName("tool_use_id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs b/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs
index d02a8f6d1cfc..1a94334c88ff 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// ErrorResponse.cs
using System.Text.Json.Serialization;
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
new file mode 100644
index 000000000000..2a46bc42a35b
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Tool.cs
+
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.DTO;
+
+public class Tool
+{
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+
+ [JsonPropertyName("input_schema")]
+ public InputSchema? InputSchema { get; set; }
+}
+
+public class InputSchema
+{
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("properties")]
+    public Dictionary<string, SchemaProperty>? Properties { get; set; }
+
+ [JsonPropertyName("required")]
+    public List<string>? Required { get; set; }
+}
+
+public class SchemaProperty
+{
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+}
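
These DTOs mirror Anthropic's tool-definition JSON (name, description, input_schema). A hedged sketch of a tool description (the get_weather tool and its fields are illustrative):

    var weatherTool = new Tool
    {
        Name = "get_weather",
        Description = "Get the current weather for a given city",
        InputSchema = new InputSchema
        {
            Type = "object",
            Properties = new Dictionary<string, SchemaProperty>
            {
                ["city"] = new SchemaProperty { Type = "string", Description = "The city name" },
            },
            Required = new List<string> { "city" },
        },
    };
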
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs b/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs
new file mode 100644
index 000000000000..0a5c3790e1de
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ToolChoice.cs
+
+using System.Text.Json.Serialization;
+using AutoGen.Anthropic.Converters;
+
+namespace AutoGen.Anthropic.DTO;
+
+[JsonConverter(typeof(JsonPropertyNameEnumConverter<ToolChoiceType>))]
+public enum ToolChoiceType
+{
+ [JsonPropertyName("auto")]
+ Auto, // Default behavior
+
+ [JsonPropertyName("any")]
+ Any, // Use any provided tool
+
+ [JsonPropertyName("tool")]
+ Tool // Force a specific tool
+}
+
+public class ToolChoice
+{
+ [JsonPropertyName("type")]
+ public ToolChoiceType Type { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ private ToolChoice(ToolChoiceType type, string? name = null)
+ {
+ Type = type;
+ Name = name;
+ }
+
+ public static ToolChoice Auto => new(ToolChoiceType.Auto);
+ public static ToolChoice Any => new(ToolChoiceType.Any);
+ public static ToolChoice ToolUse(string name) => new(ToolChoiceType.Tool, name);
+}
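
Together with the new Tools and ToolChoice fields on ChatCompletionRequest, a tool call can be forced or left to the model. A hedged sketch reusing weatherTool from above (Model and MaxTokens are assumed to exist on the request DTO; the model id is illustrative):

    var request = new ChatCompletionRequest
    {
        Model = "claude-3-haiku-20240307",
        MaxTokens = 1024,
        Tools = new List<Tool> { weatherTool },
        ToolChoice = ToolChoice.ToolUse("get_weather"), // or ToolChoice.Auto / ToolChoice.Any
    };
    request.Messages.Add(new ChatMessage("user",
        new List<ContentBase> { new TextContent { Text = "What's the weather in Paris?" } }));
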
diff --git a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
index bb2f5820f74c..af06a0547849 100644
--- a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
+++ b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
@@ -6,6 +6,7 @@
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
+using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Anthropic.DTO;
@@ -28,7 +29,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
: response;
}
-    public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
+    public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
@@ -36,7 +37,7 @@ public async IAsyncEnumerable InvokeAsync(MiddlewareContext c
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
-            if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
+            if (reply is IMessage<ChatCompletionResponse> chatMessage)
{
var response = ProcessChatCompletionResponse(chatMessage, agent);
if (response is not null)
@@ -51,9 +52,20 @@ public async IAsyncEnumerable InvokeAsync(MiddlewareContext c
}
}
-    private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> chatMessage,
+    private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> chatMessage,
IStreamingAgent agent)
{
+ if (chatMessage.Content.Content is { Count: 1 } &&
+ chatMessage.Content.Content[0] is ToolUseContent toolUseContent)
+ {
+ return new ToolCallMessage(
+ toolUseContent.Name ??
+ throw new InvalidOperationException($"Expected {nameof(toolUseContent.Name)} to be specified"),
+ toolUseContent.Input?.ToString() ??
+ throw new InvalidOperationException($"Expected {nameof(toolUseContent.Input)} to be specified"),
+ from: agent.Name);
+ }
+
var delta = chatMessage.Content.Delta;
return delta != null && !string.IsNullOrEmpty(delta.Text)
? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name)
@@ -71,16 +83,20 @@ private async Task<IEnumerable<IMessage>> ProcessMessageAsync(IEnumerable<IMessage> messages, IAgent agent)
            TextMessage textMessage => ProcessTextMessage(textMessage, agent),
ImageMessage imageMessage =>
-                new MessageEnvelope<ChatMessage>(new ChatMessage("user",
+                (MessageEnvelope<ChatMessage>[])[new MessageEnvelope<ChatMessage>(new ChatMessage("user",
new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } }
.ToList()),
- from: agent.Name),
+ from: agent.Name)],
MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent),
- _ => message,
+
+ ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
+ ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
+            AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
+ _ => [message],
};
- processedMessages.Add(processedMessage);
+ processedMessages.AddRange(processedMessage);
}
return processedMessages;
@@ -93,15 +109,42 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
throw new ArgumentNullException(nameof(response.Content));
}
- if (response.Content.Count != 1)
+ // When expecting a tool call, sometimes the response will contain two messages, one chat and one tool.
+        // The first message is typically a TextContent in which the LLM explains what it is trying to do.
+ // The second message contains the tool call.
+ if (response.Content.Count > 1)
{
- throw new NotSupportedException($"{nameof(response.Content)} != 1");
+ if (response.Content.Count == 2 && response.Content[0] is TextContent &&
+ response.Content[1] is ToolUseContent toolUseContent)
+ {
+ return new ToolCallMessage(toolUseContent.Name ?? string.Empty,
+ toolUseContent.Input?.ToJsonString() ?? string.Empty,
+ from: from.Name);
+ }
+
+ throw new NotSupportedException($"Expected {nameof(response.Content)} to have one output");
}
- return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? string.Empty, from: from.Name);
+ var content = response.Content[0];
+ switch (content)
+ {
+ case TextContent textContent:
+ return new TextMessage(Role.Assistant, textContent.Text ?? string.Empty, from: from.Name);
+
+ case ToolUseContent toolUseContent:
+ return new ToolCallMessage(toolUseContent.Name ?? string.Empty,
+ toolUseContent.Input?.ToJsonString() ?? string.Empty,
+ from: from.Name);
+
+ case ImageContent:
+ throw new InvalidOperationException(
+ "Claude is an image understanding model only. It can interpret and analyze images, but it cannot generate, produce, edit, manipulate or create images");
+ default:
+ throw new ArgumentOutOfRangeException(nameof(content));
+ }
}
-    private IMessage<ChatMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
+    private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
ChatMessage messages;
@@ -139,10 +182,10 @@ private IMessage ProcessTextMessage(TextMessage textMessage, IAgent
"user", textMessage.Content);
}
-        return new MessageEnvelope<ChatMessage>(messages, from: textMessage.From);
+        return [new MessageEnvelope<ChatMessage>(messages, from: textMessage.From)];
}
-    private async Task<IMessage> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
+    private async Task<IEnumerable<IMessage>> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
{
        var content = new List<ContentBase>();
foreach (var message in multiModalMessage.Content)
@@ -158,8 +201,7 @@ private async Task ProcessMultiModalMessageAsync(MultiModalMessage mul
}
}
- var chatMessage = new ChatMessage("user", content);
- return MessageEnvelope.Create(chatMessage, agent.Name);
+ return [MessageEnvelope.Create(new ChatMessage("user", content), agent.Name)];
}
private async Task ProcessImageSourceAsync(ImageMessage imageMessage)
@@ -192,4 +234,52 @@ private async Task ProcessImageSourceAsync(ImageMessage imageMessag
Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync())
};
}
+
+    private IEnumerable<IMessage> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
+    {
+        var chatMessage = new ChatMessage("assistant", new List<ContentBase>());
+ foreach (var toolCall in toolCallMessage.ToolCalls)
+ {
+ chatMessage.AddContent(new ToolUseContent
+ {
+ Id = toolCall.ToolCallId,
+ Name = toolCall.FunctionName,
+ Input = JsonNode.Parse(toolCall.FunctionArguments)
+ });
+ }
+
+ return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)];
+ }
+
+    private IEnumerable<IMessage> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage)
+    {
+        var chatMessage = new ChatMessage("user", new List<ContentBase>());
+ foreach (var toolCall in toolCallResultMessage.ToolCalls)
+ {
+ chatMessage.AddContent(new ToolResultContent
+ {
+ Id = toolCall.ToolCallId ?? string.Empty,
+ Content = toolCall.Result,
+ });
+ }
+
+ return [MessageEnvelope.Create(chatMessage, toolCallResultMessage.From)];
+ }
+
+    private IEnumerable<IMessage> ProcessToolCallAggregateMessage(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage, IAgent agent)
+ {
+ if (aggregateMessage.From is { } from && from != agent.Name)
+ {
+ var contents = aggregateMessage.Message2.ToolCalls.Select(t => t.Result);
+ var messages = contents.Select(c =>
+ new ChatMessage("assistant", c ?? throw new ArgumentNullException(nameof(c))));
+
+            return messages.Select(m => new MessageEnvelope<ChatMessage>(m, from: from));
+ }
+
+ var toolCallMessage = ProcessToolCallMessage(aggregateMessage.Message1, agent);
+ var toolCallResult = ProcessToolCallResultMessage(aggregateMessage.Message2);
+
+ return toolCallMessage.Concat(toolCallResult);
+ }
}
diff --git a/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs b/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs
index e70572cbddf2..6fd70cb4ee3e 100644
--- a/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs
+++ b/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// Constants.cs
+// AnthropicConstants.cs
namespace AutoGen.Anthropic.Utils;
diff --git a/dotnet/src/AutoGen.Core/Agent/IAgent.cs b/dotnet/src/AutoGen.Core/Agent/IAgent.cs
index b9149008480d..34a31055d1bf 100644
--- a/dotnet/src/AutoGen.Core/Agent/IAgent.cs
+++ b/dotnet/src/AutoGen.Core/Agent/IAgent.cs
@@ -7,10 +7,14 @@
using System.Threading.Tasks;
namespace AutoGen.Core;
-public interface IAgent
+
+public interface IAgentMetaInformation
{
public string Name { get; }
+}
+public interface IAgent : IAgentMetaInformation
+{
///
/// Generate reply
///
diff --git a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
index 665f18bac12a..6b7794c921ad 100644
--- a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
+++ b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
@@ -11,7 +11,7 @@ namespace AutoGen.Core;
///
public interface IStreamingAgent : IAgent
{
-    public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+    public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default);
diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
index 52967d6ff1ce..c7643b1e4735 100644
--- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
+++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
@@ -47,7 +47,7 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
}
-    public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+    public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
@@ -83,7 +83,7 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat
return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken);
}
-    public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+    public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (streamingMiddleware is null)
{
diff --git a/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
index e3e44622c817..45728023b96b 100644
--- a/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
+++ b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
@@ -100,8 +100,7 @@ internal static IEnumerable<IMessage> ProcessConversationsForRolePlay(
var msg = @$"From {x.From}:
{x.GetContent()}
-round #
- {i}";
+round # {i}";
return new TextMessage(Role.User, content: msg);
});
diff --git a/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs b/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
index 2c828c26d890..556c16436c63 100644
--- a/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
+++ b/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
@@ -35,7 +35,7 @@ public class FunctionContract
///
/// The name of the function.
///
- public string? Name { get; set; }
+ public string Name { get; set; } = null!;
///
/// The description of the function.
diff --git a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
index 02f4da50bae0..acff955a292c 100644
--- a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
+++ b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
@@ -12,9 +13,12 @@ public class Graph
{
    private readonly List<Transition> transitions = new List<Transition>();
-    public Graph(IEnumerable<Transition> transitions)
+    public Graph(IEnumerable<Transition>? transitions = null)
{
- this.transitions.AddRange(transitions);
+ if (transitions != null)
+ {
+ this.transitions.AddRange(transitions);
+ }
}
public void AddTransition(Transition transition)
@@ -33,13 +37,13 @@ public void AddTransition(Transition transition)
    /// <param name="fromAgent">the from agent</param>
    /// <param name="messages">messages</param>
    /// <returns>A list of agents that the messages can transit to</returns>
-    public async Task<IEnumerable<IAgent>> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable<IMessage> messages)
+    public async Task<IEnumerable<IAgent>> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable<IMessage> messages, CancellationToken ct = default)
{
var nextAgents = new List();
        var availableTransitions = transitions.FindAll(t => t.From == fromAgent) ?? Enumerable.Empty<Transition>();
foreach (var transition in availableTransitions)
{
- if (await transition.CanTransitionAsync(messages))
+ if (await transition.CanTransitionAsync(messages, ct))
{
nextAgents.Add(transition.To);
}
@@ -56,7 +60,7 @@ public class Transition
{
private readonly IAgent _from;
private readonly IAgent _to;
-    private readonly Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? _canTransition;
+    private readonly Func<IAgent, IAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>>? _canTransition;
///
    /// Create a new instance of <see cref="Transition"/>.
@@ -66,22 +70,44 @@ public class Transition
    /// <param name="from">from agent</param>
    /// <param name="to">to agent</param>
    /// <param name="canTransitionAsync">detect if the transition is allowed, default to be always true</param>
-    internal Transition(IAgent from, IAgent to, Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
+    internal Transition(IAgent from, IAgent to, Func<IAgent, IAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>>? canTransitionAsync = null)
{
_from = from;
_to = to;
_canTransition = canTransitionAsync;
}
+ ///
+    /// Create a new instance of <see cref="Transition"/> without transition condition check.
+ ///
+    /// <returns><see cref="Transition"/></returns>
+    public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to)
+ where TFromAgent : IAgent
+ where TToAgent : IAgent
+ {
+ return new Transition(from, to, (fromAgent, toAgent, messages, _) => Task.FromResult(true));
+ }
+
///
    /// Create a new instance of <see cref="Transition"/>.
///
    /// <returns><see cref="Transition"/></returns>
-    public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
+    public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, Task<bool>> canTransitionAsync)
+ where TFromAgent : IAgent
+ where TToAgent : IAgent
+ {
+ return new Transition(from, to, (fromAgent, toAgent, messages, _) => canTransitionAsync.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages));
+ }
+
+ ///
+    /// Create a new instance of <see cref="Transition"/> with cancellation token.
+ ///
+    /// <returns><see cref="Transition"/></returns>
+    public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>> canTransitionAsync)
where TFromAgent : IAgent
where TToAgent : IAgent
{
- return new Transition(from, to, (fromAgent, toAgent, messages) => canTransitionAsync?.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages) ?? Task.FromResult(true));
+ return new Transition(from, to, (fromAgent, toAgent, messages, ct) => canTransitionAsync.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages, ct));
}
public IAgent From => _from;
@@ -92,13 +118,13 @@ public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent
/// Check if the transition is allowed.
///
    /// <param name="messages">messages</param>
-    public Task<bool> CanTransitionAsync(IEnumerable<IMessage> messages)
+    public Task<bool> CanTransitionAsync(IEnumerable<IMessage> messages, CancellationToken ct = default)
{
if (_canTransition == null)
{
return Task.FromResult(true);
}
- return _canTransition(this.From, this.To, messages);
+ return _canTransition(this.From, this.To, messages, ct);
}
}
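
The overloads above give three ways to declare a transition: unconditional, with a predicate, or with a cancellation-aware predicate. A hedged sketch (alice, bob, and chatHistory are assumed to exist; System.Linq is assumed for Count()):

    var workflow = new Graph(); // transitions may now be added after construction
    workflow.AddTransition(Transition.Create(alice, bob)); // unconditional
    workflow.AddTransition(Transition.Create(bob, alice,
        (fromAgent, toAgent, messages, ct) => Task.FromResult(messages.Count() < 10)));

    var nextAgents = await workflow.TransitToNextAvailableAgentsAsync(alice, chatHistory);
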
diff --git a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
index 5e82931ab658..57e15c18ca62 100644
--- a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
+++ b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
@@ -15,6 +15,7 @@ public class GroupChat : IGroupChat
    private List<IAgent> agents = new List<IAgent>();
    private IEnumerable<IMessage> initializeMessages = new List<IMessage>();
private Graph? workflow = null;
+ private readonly IOrchestrator orchestrator;
    public IEnumerable<IMessage>? Messages { get; private set; }
@@ -36,6 +37,37 @@ public GroupChat(
this.initializeMessages = initializeMessages ?? new List();
this.workflow = workflow;
+ if (admin is not null)
+ {
+ this.orchestrator = new RolePlayOrchestrator(admin, workflow);
+ }
+ else if (workflow is not null)
+ {
+ this.orchestrator = new WorkflowOrchestrator(workflow);
+ }
+ else
+ {
+ this.orchestrator = new RoundRobinOrchestrator();
+ }
+
+ this.Validation();
+ }
+
+ ///
+    /// Create a group chat which uses the <see cref="IOrchestrator"/> to decide the next speaker(s).
+ ///
+ ///
+ ///
+ ///
+ public GroupChat(
+        IEnumerable<IAgent> members,
+        IOrchestrator orchestrator,
+        IEnumerable<IMessage>? initializeMessages = null)
+ {
+ this.agents = members.ToList();
+ this.initializeMessages = initializeMessages ?? new List();
+ this.orchestrator = orchestrator;
+
this.Validation();
}
@@ -64,12 +96,6 @@ private void Validation()
throw new Exception("All agents in the workflow must be in the group chat.");
}
}
-
- // must provide one of admin or workflow
- if (this.admin == null && this.workflow == null)
- {
- throw new Exception("Must provide one of admin or workflow.");
- }
}
///
@@ -81,6 +107,7 @@ private void Validation()
    /// <param name="currentSpeaker">current speaker</param>
    /// <param name="conversationHistory">conversation history</param>
    /// <returns>next speaker.</returns>
+    [Obsolete("Please use RolePlayOrchestrator or WorkflowOrchestrator")]
    public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable<IMessage> conversationHistory)
{
var agentNames = this.agents.Select(x => x.Name).ToList();
@@ -140,37 +167,40 @@ public void AddInitializeMessage(IMessage message)
}
    public async Task<IEnumerable<IMessage>> CallAsync(
-        IEnumerable<IMessage>? conversationWithName = null,
+        IEnumerable<IMessage>? chatHistory = null,
int maxRound = 10,
CancellationToken ct = default)
{
        var conversationHistory = new List<IMessage>();
- if (conversationWithName != null)
+ conversationHistory.AddRange(this.initializeMessages);
+ if (chatHistory != null)
{
- conversationHistory.AddRange(conversationWithName);
+ conversationHistory.AddRange(chatHistory);
}
+ var roundLeft = maxRound;
- var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
+ while (roundLeft > 0)
{
- null => this.agents.First(),
- _ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
- };
- var round = 0;
- while (round < maxRound)
- {
- var currentSpeaker = await this.SelectNextSpeakerAsync(lastSpeaker, conversationHistory);
- var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
- var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
+ var orchestratorContext = new OrchestrationContext
+ {
+ Candidates = this.agents,
+ ChatHistory = conversationHistory,
+ };
+ var nextSpeaker = await this.orchestrator.GetNextSpeakerAsync(orchestratorContext, ct);
+ if (nextSpeaker == null)
+ {
+ break;
+ }
+
+ var result = await nextSpeaker.GenerateReplyAsync(conversationHistory, cancellationToken: ct);
conversationHistory.Add(result);
- // if message is terminate message, then terminate the conversation
- if (result?.IsGroupChatTerminateMessage() ?? false)
+ if (result.IsGroupChatTerminateMessage())
{
- break;
+ return conversationHistory;
}
- lastSpeaker = currentSpeaker;
- round++;
+ roundLeft--;
}
return conversationHistory;
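
The removed admin/workflow requirement is covered by the orchestrator mapping above: an admin maps to RolePlayOrchestrator, a workflow alone to WorkflowOrchestrator, and neither to RoundRobinOrchestrator. A hedged sketch of the new constructor (alice, bob, and the initial task message are assumed to exist):

    var groupChat = new GroupChat(
        members: new[] { alice, bob },
        orchestrator: new RoundRobinOrchestrator());

    var history = await groupChat.CallAsync(chatHistory: new IMessage[] { task }, maxRound: 10);
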
diff --git a/dotnet/src/AutoGen.Core/IGroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/IGroupChat.cs
similarity index 100%
rename from dotnet/src/AutoGen.Core/IGroupChat.cs
rename to dotnet/src/AutoGen.Core/GroupChat/IGroupChat.cs
diff --git a/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs
index b8de89b834fe..b95cd1958fc5 100644
--- a/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs
+++ b/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs
@@ -3,9 +3,6 @@
using System;
using System.Collections.Generic;
-using System.Linq;
-using System.Threading;
-using System.Threading.Tasks;
namespace AutoGen.Core;
@@ -25,76 +22,12 @@ public SequentialGroupChat(IEnumerable<IAgent> agents, List<IMessage>? initializ
///
/// A group chat that allows agents to talk in a round-robin manner.
///
-public class RoundRobinGroupChat : IGroupChat
+public class RoundRobinGroupChat : GroupChat
{
-    private readonly List<IAgent> agents = new List<IAgent>();
-    private readonly List<IMessage> initializeMessages = new List<IMessage>();
-
public RoundRobinGroupChat(
        IEnumerable<IAgent> agents,
        List<IMessage>? initializeMessages = null)
+ : base(agents, initializeMessages: initializeMessages)
{
- this.agents.AddRange(agents);
- this.initializeMessages = initializeMessages ?? new List();
- }
-
- ///
- public void AddInitializeMessage(IMessage message)
- {
- this.SendIntroduction(message);
- }
-
-    public async Task<IEnumerable<IMessage>> CallAsync(
-        IEnumerable<IMessage>? conversationWithName = null,
- int maxRound = 10,
- CancellationToken ct = default)
- {
-        var conversationHistory = new List<IMessage>();
- if (conversationWithName != null)
- {
- conversationHistory.AddRange(conversationWithName);
- }
-
- var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
- {
- null => this.agents.First(),
- _ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
- };
- var round = 0;
- while (round < maxRound)
- {
- var currentSpeaker = this.SelectNextSpeaker(lastSpeaker);
- var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
- var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
- conversationHistory.Add(result);
-
- // if message is terminate message, then terminate the conversation
- if (result?.IsGroupChatTerminateMessage() ?? false)
- {
- break;
- }
-
- lastSpeaker = currentSpeaker;
- round++;
- }
-
- return conversationHistory;
- }
-
- public void SendIntroduction(IMessage message)
- {
- this.initializeMessages.Add(message);
- }
-
- private IAgent SelectNextSpeaker(IAgent currentSpeaker)
- {
- var index = this.agents.IndexOf(currentSpeaker);
- if (index == -1)
- {
- throw new ArgumentException("The agent is not in the group chat", nameof(currentSpeaker));
- }
-
- var nextIndex = (index + 1) % this.agents.Count;
- return this.agents[nextIndex];
}
}
diff --git a/dotnet/src/AutoGen.Core/Message/IMessage.cs b/dotnet/src/AutoGen.Core/Message/IMessage.cs
index ad215d510e3b..9952cbf06792 100644
--- a/dotnet/src/AutoGen.Core/Message/IMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/IMessage.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMessage.cs
+using System;
using System.Collections.Generic;
namespace AutoGen.Core;
@@ -35,19 +36,21 @@ namespace AutoGen.Core;
///
///
///
-public interface IMessage : IStreamingMessage
+public interface IMessage
{
+ string? From { get; set; }
}
-public interface IMessage<out T> : IMessage, IStreamingMessage<T>
+public interface IMessage<out T> : IMessage
{
+ T Content { get; }
}
///
/// The interface for messages that can get text content.
/// This interface will be used by <see cref="MessageExtension.GetContent(IMessage)"/> to get the content from the message.
///
-public interface ICanGetTextContent : IMessage, IStreamingMessage
+public interface ICanGetTextContent : IMessage
{
public string? GetContent();
}
@@ -55,17 +58,18 @@ public interface ICanGetTextContent : IMessage, IStreamingMessage
///
/// The interface for messages that can get a list of <see cref="ToolCall"/>
///
-public interface ICanGetToolCalls : IMessage, IStreamingMessage
+public interface ICanGetToolCalls : IMessage
{
    public IEnumerable<ToolCall> GetToolCalls();
}
-
+[Obsolete("Use IMessage instead")]
public interface IStreamingMessage
{
string? From { get; set; }
}
+[Obsolete("Use IMessage instead")]
public interface IStreamingMessage<out T> : IStreamingMessage
{
T Content { get; }
diff --git a/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs
index f83bea279260..dc9709bbde5b 100644
--- a/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs
+++ b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs
@@ -5,7 +5,7 @@
namespace AutoGen.Core;
-public abstract class MessageEnvelope : IMessage, IStreamingMessage
+public abstract class MessageEnvelope : IMessage
{
public MessageEnvelope(string? from = null, IDictionary? metadata = null)
{
@@ -23,7 +23,7 @@ public static MessageEnvelope<TContent> Create<TContent>(TContent content, strin
public IDictionary Metadata { get; set; }
}
-public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>, IStreamingMessage<T>
+public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>
{
public MessageEnvelope(T content, string? from = null, IDictionary? metadata = null)
: base(from, metadata)
diff --git a/dotnet/src/AutoGen.Core/Message/TextMessage.cs b/dotnet/src/AutoGen.Core/Message/TextMessage.cs
index addd8728a926..9419c2b3ba86 100644
--- a/dotnet/src/AutoGen.Core/Message/TextMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/TextMessage.cs
@@ -3,7 +3,7 @@
namespace AutoGen.Core;
-public class TextMessage : IMessage, IStreamingMessage, ICanGetTextContent
+public class TextMessage : IMessage, ICanGetTextContent
{
public TextMessage(Role role, string content, string? from = null)
{
@@ -51,7 +51,7 @@ public override string ToString()
}
}
-public class TextMessageUpdate : IStreamingMessage, ICanGetTextContent
+public class TextMessageUpdate : IMessage, ICanGetTextContent
{
public TextMessageUpdate(Role role, string? content, string? from = null)
{
diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
index 7781b785ef8c..7d46d56135aa 100644
--- a/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// FunctionCallAggregateMessage.cs
+// ToolCallAggregateMessage.cs
using System.Collections.Generic;
diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
index 396dba3d3a17..8660b323044f 100644
--- a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
+++ b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
@@ -36,7 +36,7 @@ public override string ToString()
}
}
-public class ToolCallMessage : IMessage, ICanGetToolCalls
+public class ToolCallMessage : IMessage, ICanGetToolCalls, ICanGetTextContent
{
    public ToolCallMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
{
@@ -80,6 +80,12 @@ public void Update(ToolCallMessageUpdate update)
public string? From { get; set; }
+ ///
+ /// Some LLMs might also include text content in a tool call response, like GPT.
+ /// This field is used to store the text content in that case.
+ ///
+ public string? Content { get; set; }
+
public override string ToString()
{
var sb = new StringBuilder();
@@ -96,9 +102,14 @@ public IEnumerable<ToolCall> GetToolCalls()
{
return this.ToolCalls;
}
+
+ public string? GetContent()
+ {
+ return this.Content;
+ }
}
-public class ToolCallMessageUpdate : IStreamingMessage
+public class ToolCallMessageUpdate : IMessage
{
public ToolCallMessageUpdate(string functionName, string functionArgumentUpdate, string? from = null)
{
diff --git a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
index d0788077b590..7d30f6d0928a 100644
--- a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
+++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
@@ -70,7 +70,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
return reply;
}
-    public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
+    public async IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -86,16 +86,16 @@ public async IAsyncEnumerable InvokeAsync(
var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
options.Functions = combinedFunctions?.ToArray();
- IStreamingMessage? initMessage = default;
+ IMessage? mergedFunctionCallMessage = default;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
{
if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null)
{
- if (initMessage is null)
+ if (mergedFunctionCallMessage is null)
{
- initMessage = new ToolCallMessage(toolCallMessageUpdate);
+ mergedFunctionCallMessage = new ToolCallMessage(toolCallMessageUpdate);
}
- else if (initMessage is ToolCallMessage toolCall)
+ else if (mergedFunctionCallMessage is ToolCallMessage toolCall)
{
toolCall.Update(toolCallMessageUpdate);
}
@@ -104,13 +104,17 @@ public async IAsyncEnumerable InvokeAsync(
throw new InvalidOperationException("The first message is ToolCallMessage, but the update message is not ToolCallMessageUpdate");
}
}
+ else if (message is ToolCallMessage toolCallMessage1)
+ {
+ mergedFunctionCallMessage = toolCallMessage1;
+ }
else
{
yield return message;
}
}
- if (initMessage is ToolCallMessage toolCallMsg)
+ if (mergedFunctionCallMessage is ToolCallMessage toolCallMsg)
{
yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent);
}
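
In streaming mode the middleware now also accepts a whole ToolCallMessage from the inner agent (not just ToolCallMessageUpdate chunks) and merges either form into one final message. A hedged consumption sketch (agent is assumed to be a streaming agent wrapped with this middleware):

    await foreach (var message in agent.GenerateStreamingReplyAsync(chatHistory))
    {
        if (message is TextMessageUpdate update)
        {
            Console.Write(update.Content); // partial text is forwarded unchanged
        }
        else
        {
            // tool-call updates are withheld and merged; the last message carries
            // the merged call (and its result when a function map is configured)
            Console.WriteLine(message);
        }
    }
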
diff --git a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs
index bc7aec57f52b..d550bdb519ce 100644
--- a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs
+++ b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs
@@ -14,7 +14,7 @@ public interface IStreamingMiddleware : IMiddleware
///
    /// The streaming version of <see cref="IMiddleware.InvokeAsync(MiddlewareContext, IAgent, CancellationToken)"/>.
///
-    public IAsyncEnumerable<IStreamingMessage> InvokeAsync(
+    public IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default);
diff --git a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs
index 099f78e5f176..a4e84de85a44 100644
--- a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs
+++ b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs
@@ -48,7 +48,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
}
}
-    public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IMessage? recentUpdate = null;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
diff --git a/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs
new file mode 100644
index 000000000000..777834871f65
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// IOrchestrator.cs
+
+using System;
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class OrchestrationContext
+{
+    public IEnumerable<IAgent> Candidates { get; set; } = Array.Empty<IAgent>();
+
+    public IEnumerable<IMessage> ChatHistory { get; set; } = Array.Empty<IMessage>();
+}
+
+public interface IOrchestrator
+{
+ ///
+    /// Return the next agent as the next speaker, or null if no agent is selected.
+    ///
+    /// <param name="context">orchestration context, such as candidate agents and chat history.</param>
+    /// <param name="cancellationToken">cancellation token</param>
+    public Task<IAgent?> GetNextSpeakerAsync(
+ OrchestrationContext context,
+ CancellationToken cancellationToken = default);
+}
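
Because GroupChat can now take any IOrchestrator, a custom speaker-selection policy is a small class. A minimal sketch (the first-candidate policy is purely illustrative; the usual System.Linq and System.Threading usings are assumed):

    public class FirstCandidateOrchestrator : IOrchestrator
    {
        public Task<IAgent?> GetNextSpeakerAsync(
            OrchestrationContext context,
            CancellationToken cancellationToken = default)
        {
            // Always pick the first candidate; returning null ends the chat loop.
            return Task.FromResult<IAgent?>(context.Candidates.FirstOrDefault());
        }
    }
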
diff --git a/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs
new file mode 100644
index 000000000000..6798f23f2df8
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// RolePlayOrchestrator.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class RolePlayOrchestrator : IOrchestrator
+{
+ private readonly IAgent admin;
+ private readonly Graph? workflow = null;
+ public RolePlayOrchestrator(IAgent admin, Graph? workflow = null)
+ {
+ this.admin = admin;
+ this.workflow = workflow;
+ }
+
+    public async Task<IAgent?> GetNextSpeakerAsync(
+ OrchestrationContext context,
+ CancellationToken cancellationToken = default)
+ {
+ var candidates = context.Candidates.ToList();
+
+ if (candidates.Count == 0)
+ {
+ return null;
+ }
+
+ if (candidates.Count == 1)
+ {
+ return candidates.First();
+ }
+
+ // if there's a workflow
+ // and the next available agent from the workflow is in the group chat
+ // then return the next agent from the workflow
+ if (this.workflow != null)
+ {
+ var lastMessage = context.ChatHistory.LastOrDefault();
+ if (lastMessage == null)
+ {
+ return null;
+ }
+            var currentSpeaker = candidates.First(candidate => candidate.Name == lastMessage.From);
+ var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory);
+ nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name));
+ candidates = nextAgents.ToList();
+ if (!candidates.Any())
+ {
+ return null;
+ }
+
+ if (candidates is { Count: 1 })
+ {
+ return candidates.First();
+ }
+ }
+
+        // In this case, since there is more than one available agent from the workflow for the next speaker,
+        // the admin will be invoked to decide the next speaker
+ var agentNames = candidates.Select(candidate => candidate.Name);
+ var rolePlayMessage = new TextMessage(Role.User,
+ content: $@"You are in a role play game. Carefully read the conversation history and carry on the conversation.
+The available roles are:
+{string.Join(",", agentNames)}
+
+Each message will start with 'From name:', e.g:
+From {agentNames.First()}:
+//your message//.");
+
+ var chatHistoryWithName = this.ProcessConversationsForRolePlay(context.ChatHistory);
+ var messages = new IMessage[] { rolePlayMessage }.Concat(chatHistoryWithName);
+
+ var response = await this.admin.GenerateReplyAsync(
+ messages: messages,
+ options: new GenerateReplyOptions
+ {
+ Temperature = 0,
+ MaxToken = 128,
+ StopSequence = [":"],
+ Functions = null,
+ },
+ cancellationToken: cancellationToken);
+
+ var name = response.GetContent() ?? throw new Exception("No name is returned.");
+
+        // strip the leading "From " prefix returned by the admin
+        name = name.Substring(5);
+ var candidate = candidates.FirstOrDefault(x => x.Name!.ToLower() == name.ToLower());
+
+ if (candidate != null)
+ {
+ return candidate;
+ }
+
+ var errorMessage = $"The response from admin is {name}, which is either not in the candidates list or not in the correct format.";
+ throw new Exception(errorMessage);
+ }
+
+    private IEnumerable<IMessage> ProcessConversationsForRolePlay(IEnumerable<IMessage> messages)
+ {
+ return messages.Select((x, i) =>
+ {
+ var msg = @$"From {x.From}:
+{x.GetContent()}
+
+round # {i}";
+
+ return new TextMessage(Role.User, content: msg);
+ });
+ }
+}
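
A hedged usage sketch (admin is assumed to be an LLM-backed IAgent; the Graph argument is optional and, when present, narrows the candidates first):

    var orchestrator = new RolePlayOrchestrator(admin);
    var nextSpeaker = await orchestrator.GetNextSpeakerAsync(new OrchestrationContext
    {
        Candidates = new[] { alice, bob },
        ChatHistory = chatHistory,
    });
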
diff --git a/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs
new file mode 100644
index 000000000000..0f8b8e483c63
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// RoundRobinOrchestrator.cs
+
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+///
+/// Return the next agent in a round-robin fashion.
+///
+/// If the last message is from one of the candidates, the next agent will be the next candidate in the list.
+///
+///
+/// Otherwise, no agent will be selected and the orchestrator will return null.
+///
+///
+/// This orchestrator always returns a single agent.
+///
+///
+public class RoundRobinOrchestrator : IOrchestrator
+{
+    public async Task<IAgent?> GetNextSpeakerAsync(
+ OrchestrationContext context,
+ CancellationToken cancellationToken = default)
+ {
+ var lastMessage = context.ChatHistory.LastOrDefault();
+
+ if (lastMessage == null)
+ {
+ return null;
+ }
+
+ var candidates = context.Candidates.ToList();
+ var lastAgentIndex = candidates.FindIndex(a => a.Name == lastMessage.From);
+ if (lastAgentIndex == -1)
+ {
+ return null;
+ }
+
+ var nextAgentIndex = (lastAgentIndex + 1) % candidates.Count;
+ return candidates[nextAgentIndex];
+ }
+}
diff --git a/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs
new file mode 100644
index 000000000000..b84850a07c75
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// WorkflowOrchestrator.cs
+
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class WorkflowOrchestrator : IOrchestrator
+{
+ private readonly Graph workflow;
+
+ public WorkflowOrchestrator(Graph workflow)
+ {
+ this.workflow = workflow;
+ }
+
+    public async Task<IAgent?> GetNextSpeakerAsync(
+ OrchestrationContext context,
+ CancellationToken cancellationToken = default)
+ {
+ var lastMessage = context.ChatHistory.LastOrDefault();
+ if (lastMessage == null)
+ {
+ return null;
+ }
+
+ var candidates = context.Candidates.ToList();
+        var currentSpeaker = candidates.FirstOrDefault(candidate => candidate.Name == lastMessage.From);
+
+ if (currentSpeaker == null)
+ {
+ return null;
+ }
+ var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory);
+ nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name));
+ candidates = nextAgents.ToList();
+ if (!candidates.Any())
+ {
+ return null;
+ }
+
+ if (candidates is { Count: 1 })
+ {
+ return candidates.First();
+ }
+ else
+ {
+            throw new System.Exception("There is more than one available agent from the workflow for the next speaker.");
+ }
+ }
+}
diff --git a/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs b/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
index 7490b64e1267..1ca19fcbcfff 100644
--- a/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
+++ b/dotnet/src/AutoGen.DotnetInteractive/InteractiveService.cs
@@ -19,7 +19,7 @@ public class InteractiveService : IDisposable
private bool disposedValue;
private const string DotnetInteractiveToolNotInstallMessage = "Cannot find a tool in the manifest file that has a command named 'dotnet-interactive'.";
//private readonly ProcessJobTracker jobTracker = new ProcessJobTracker();
- private string installingDirectory;
+ private string? installingDirectory;
public event EventHandler? DisplayEvent;
@@ -30,7 +30,11 @@ public class InteractiveService : IDisposable
public event EventHandler? HoverTextProduced;
///
- /// Create an instance of InteractiveService
+    /// Install dotnet interactive tool to <paramref name="installingDirectory"/>
+    /// and create an instance of <see cref="InteractiveService"/>.
+    ///
+    /// When using this constructor, you need to call <see cref="StartAsync(string, CancellationToken)"/> to install dotnet interactive tool
+    /// and start the kernel.
///
    /// <param name="installingDirectory">dotnet interactive installing directory</param>
public InteractiveService(string installingDirectory)
@@ -38,8 +42,23 @@ public InteractiveService(string installingDirectory)
this.installingDirectory = installingDirectory;
}
+ ///
+    /// Create an instance of <see cref="InteractiveService"/> with a running kernel.
+    /// When using this constructor, you don't need to call <see cref="StartAsync(string, CancellationToken)"/> to start the kernel.
+ ///
+ ///
+ public InteractiveService(Kernel kernel)
+ {
+ this.kernel = kernel;
+ }
+
public async Task StartAsync(string workingDirectory, CancellationToken ct = default)
{
+ if (this.kernel != null)
+ {
+ return true;
+ }
+
this.kernel = await this.CreateKernelAsync(workingDirectory, true, ct);
return true;
}
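
A hedged sketch of the two construction paths (the Kernel instance comes from Microsoft.DotNet.Interactive; the directories are illustrative):

    // Path 1: install the dotnet-interactive tool, then start a kernel.
    var service = new InteractiveService(installingDirectory: "./interactive");
    await service.StartAsync(workingDirectory: "./interactive");

    // Path 2: wrap an already-running kernel; StartAsync becomes a no-op.
    var serviceFromKernel = new InteractiveService(existingKernel);
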
diff --git a/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs b/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs
index b081faae8321..e759ba26d1e9 100644
--- a/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs
+++ b/dotnet/src/AutoGen.Gemini/GeminiChatAgent.cs
@@ -143,7 +143,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, G
return MessageEnvelope.Create(response, this.Name);
}
-    public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = this.client.GenerateContentStreamAsync(request);
diff --git a/dotnet/src/AutoGen.Gemini/IGeminiClient.cs b/dotnet/src/AutoGen.Gemini/IGeminiClient.cs
index 2e209e02b030..d391a4508398 100644
--- a/dotnet/src/AutoGen.Gemini/IGeminiClient.cs
+++ b/dotnet/src/AutoGen.Gemini/IGeminiClient.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// IVertexGeminiClient.cs
+// IGeminiClient.cs
using System.Collections.Generic;
using System.Threading;
diff --git a/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs b/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs
index cb18ba084d78..422fb4cd3458 100644
--- a/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs
+++ b/dotnet/src/AutoGen.Gemini/Middleware/GeminiMessageConnector.cs
@@ -39,7 +39,7 @@ public GeminiMessageConnector(bool strictMode = false)
public string Name => nameof(GeminiMessageConnector);
-    public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);
diff --git a/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs b/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs
index c54f2280dfd3..12a11993cd69 100644
--- a/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs
+++ b/dotnet/src/AutoGen.Gemini/VertexGeminiClient.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// IGeminiClient.cs
+// VertexGeminiClient.cs
using System.Collections.Generic;
using System.Threading;
diff --git a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs
index cc2c74145504..db14d68a1217 100644
--- a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs
+++ b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs
@@ -78,7 +78,7 @@ public async Task GenerateReplyAsync(
        return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
}
-    public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+    public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -97,6 +97,7 @@ private ChatCompletionRequest BuildChatRequest(IEnumerable<IMessage> messages, G
var chatHistory = BuildChatHistory(messages);
var chatRequest = new ChatCompletionRequest(model: _model, messages: chatHistory.ToList(), temperature: options?.Temperature, randomSeed: _randomSeed)
{
+ Stop = options?.StopSequence,
MaxTokens = options?.MaxToken,
ResponseFormat = _jsonOutput ? new ResponseFormat() { ResponseFormatType = "json_object" } : null,
};
diff --git a/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs
index 71a084673f13..affe2bb6dcc3 100644
--- a/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs
+++ b/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs
@@ -105,6 +105,9 @@ public class ChatCompletionRequest
[JsonPropertyName("random_seed")]
public int? RandomSeed { get; set; }
+ [JsonPropertyName("stop")]
+ public string[]? Stop { get; set; }
+
[JsonPropertyName("tools")]
public List? Tools { get; set; }
diff --git a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
index 95592e97fcc5..78de12a5c01e 100644
--- a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
+++ b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs
@@ -15,14 +15,14 @@ public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware
{
public string? Name => nameof(MistralChatMessageConnector);
-    public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
        var chunks = new List<ChatCompletionResponse>();
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
-            if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
+            if (reply is IMessage<ChatCompletionResponse> chatMessage)
{
chunks.Add(chatMessage.Content);
var response = ProcessChatCompletionResponse(chatMessage, agent);
@@ -167,7 +167,7 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
}
}
-    private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> message, IAgent agent)
+    private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> message, IAgent agent)
{
var response = message.Content;
if (response.VarObject != "chat.completion.chunk")
diff --git a/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs b/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
index 9ef68388d605..87b176d8bcc5 100644
--- a/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
+++ b/dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
@@ -53,7 +53,7 @@ public async Task GenerateReplyAsync(
}
}
-    public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+    public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
diff --git a/dotnet/src/AutoGen.Ollama/DTOs/Message.cs b/dotnet/src/AutoGen.Ollama/DTOs/Message.cs
index 2e0d891cc61e..75f622ff7f04 100644
--- a/dotnet/src/AutoGen.Ollama/DTOs/Message.cs
+++ b/dotnet/src/AutoGen.Ollama/DTOs/Message.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// ChatResponseUpdate.cs
+// Message.cs
using System.Collections.Generic;
using System.Text.Json.Serialization;
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs b/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs
index 5ce0dc8cc40a..cce6dbb83076 100644
--- a/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/ITextEmbeddingService.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// ITextEmbeddingService.cs
using System.Threading;
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs b/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs
index 2e431e7bcb81..ea4993eb813f 100644
--- a/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/OllamaTextEmbeddingService.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaTextEmbeddingService.cs
using System;
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs
index 7f2531c522ad..d776b183db0b 100644
--- a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsRequest.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// TextEmbeddingsRequest.cs
using System.Text.Json.Serialization;
diff --git a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs
index 580059c033b5..f3ce64b7032f 100644
--- a/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs
+++ b/dotnet/src/AutoGen.Ollama/Embeddings/TextEmbeddingsResponse.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// TextEmbeddingsResponse.cs
using System.Text.Json.Serialization;
diff --git a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
index a21ec3a1c991..3919b238d659 100644
--- a/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
+++ b/dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
@@ -30,14 +30,14 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
};
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ProcessMessage(context.Messages, agent);
var chunks = new List<ChatResponseUpdate>();
await foreach (var update in agent.GenerateStreamingReplyAsync(messages, context.Options, cancellationToken))
{
- if (update is IStreamingMessage<ChatResponseUpdate> chatResponseUpdate)
+ if (update is IMessage<ChatResponseUpdate> chatResponseUpdate)
{
var response = chatResponseUpdate.Content switch
{
diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
index cdc6cc464d17..5de481245b72 100644
--- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
+++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs
@@ -104,7 +104,7 @@ public async Task GenerateReplyAsync(
return await _innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
}
- public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
index 37a4882f69e1..b192cde1024b 100644
--- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
+++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs
@@ -87,7 +87,7 @@ public async Task GenerateReplyAsync(
return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
}
- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
index 246e50cc6c59..e1dd0757fcf3 100644
--- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
+++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs
@@ -47,7 +47,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
return PostProcessMessage(reply);
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
+ public async IAsyncEnumerable<IMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -57,7 +57,7 @@ public async IAsyncEnumerable InvokeAsync(
string? currentToolName = null;
await foreach (var reply in streamingReply)
{
- if (reply is IStreamingMessage update)
+ if (reply is IMessage update)
{
if (update.Content.FunctionName is string functionName)
{
@@ -98,7 +98,7 @@ public IMessage PostProcessMessage(IMessage message)
};
}
- public IStreamingMessage? PostProcessStreamingMessage(IStreamingMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
+ public IMessage? PostProcessStreamingMessage(IMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
{
if (update.Content.ContentUpdate is string contentUpdate)
{
@@ -136,14 +136,13 @@ private IMessage PostProcessChatCompletions(IMessage message)
private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from)
{
- if (chatResponseMessage.Content is string content && !string.IsNullOrEmpty(content))
- {
- return new TextMessage(Role.Assistant, content, from);
- }
-
+ var textContent = chatResponseMessage.Content;
if (chatResponseMessage.FunctionCall is FunctionCall functionCall)
{
- return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from);
+ return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from)
+ {
+ Content = textContent,
+ };
}
if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any())
@@ -154,7 +153,15 @@ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponse
var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });
- return new ToolCallMessage(toolCalls, from);
+ return new ToolCallMessage(toolCalls, from)
+ {
+ Content = textContent,
+ };
+ }
+
+ if (textContent is string content && !string.IsNullOrEmpty(content))
+ {
+ return new TextMessage(Role.Assistant, content, from);
}
throw new InvalidOperationException("Invalid ChatResponseMessage");
@@ -327,7 +334,8 @@ private IEnumerable ProcessToolCallMessage(IAgent agent, Too
}
var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
- var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From };
+ var textContent = message.GetContent() ?? string.Empty;
+ var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };
foreach (var tc in toolCall)
{
chatRequestMessage.ToolCalls.Add(tc);
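
The connector hunks above stop discarding assistant text when a reply also carries tool calls. A sketch of the now-observable behavior, with illustrative values:

```csharp
using System;
using AutoGen.Core;

// Illustrative values. Before this change, converting a reply that carried both
// text and tool calls dropped the text (ChatRequestAssistantMessage was built
// with string.Empty); after it, the text survives on both sides.
var toolCall = new ToolCall("GetWeather", "{\"city\":\"Seattle\"}");
var message = new ToolCallMessage([toolCall], from: "assistant")
{
    Content = "Let me check the weather first.",
};

Console.WriteLine(message.GetContent()); // Let me check the weather first.
```
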
diff --git a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
index 6ce242eb1abe..a055c0afcb6a 100644
--- a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
+++ b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs
@@ -47,7 +47,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent,
return PostProcessMessage(reply);
}
- public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatMessageContents = ProcessMessage(context.Messages, agent)
.Select(m => new MessageEnvelope<ChatMessageContent>(m));
@@ -67,11 +67,11 @@ private IMessage PostProcessMessage(IMessage input)
};
}
- private IStreamingMessage PostProcessStreamingMessage(IStreamingMessage input)
+ private IMessage PostProcessStreamingMessage(IMessage input)
{
return input switch
{
- IStreamingMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
+ IMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
IMessage<ChatMessageContent> msg => PostProcessMessage(msg),
_ => input,
};
@@ -98,7 +98,7 @@ private IMessage PostProcessMessage(IMessage messageEnvelope
}
}
- private IStreamingMessage PostProcessMessage(IStreamingMessage<StreamingChatMessageContent> streamingMessage)
+ private IMessage PostProcessMessage(IMessage<StreamingChatMessageContent> streamingMessage)
{
var chatMessageContent = streamingMessage.Content;
if (chatMessageContent.ChoiceIndex > 0)
diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs
index 21f652f56c4f..d12c54c1b3b2 100644
--- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs
+++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs
@@ -65,7 +65,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, G
return new MessageEnvelope<ChatMessageContent>(reply.First(), from: this.Name);
}
- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
diff --git a/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs b/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
index 24e42affa3bd..aa4980379f4f 100644
--- a/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
+++ b/dotnet/src/AutoGen.SourceGenerator/SourceGeneratorFunctionContract.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// FunctionContract.cs
+// SourceGeneratorFunctionContract.cs
namespace AutoGen.SourceGenerator
{
diff --git a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
index 40adbdcde47c..8eeb117141d8 100644
--- a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
+++ b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs
@@ -107,7 +107,7 @@ public virtual string TransformText()
}
if (functionContract.Description != null) {
this.Write(" Description = @\"");
- this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.Description));
+ this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.Description.Replace("\"", "\"\"")));
this.Write("\",\r\n");
}
if (functionContract.ReturnType != null) {
@@ -132,7 +132,7 @@ public virtual string TransformText()
}
if (parameter.Description != null) {
this.Write(" Description = @\"");
- this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Description));
+ this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Description.Replace("\"", "\"\"")));
this.Write("\",\r\n");
}
if (parameter.Type != null) {
@@ -152,12 +152,7 @@ public virtual string TransformText()
}
this.Write(" },\r\n");
}
- this.Write(" };\r\n }\r\n\r\n public global::Azure.AI.OpenAI.FunctionDefin" +
- "ition ");
- this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionDefinitionName()));
- this.Write("\r\n {\r\n get => this.");
- this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionContractName()));
- this.Write(".ToOpenAIFunctionDefinition();\r\n }\r\n");
+ this.Write(" };\r\n }\r\n");
}
this.Write(" }\r\n");
if (!String.IsNullOrEmpty(NameSpace)) {
diff --git a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
index 0d1b221c35c8..dc41f0af9d70 100644
--- a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
+++ b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt
@@ -63,7 +63,7 @@ namespace <#=NameSpace#>
Name = @"<#=functionContract.Name#>",
<#}#>
<#if (functionContract.Description != null) {#>
- Description = @"<#=functionContract.Description#>",
+ Description = @"<#=functionContract.Description.Replace("\"", "\"\"")#>",
<#}#>
<#if (functionContract.ReturnType != null) {#>
ReturnType = typeof(<#=functionContract.ReturnType#>),
@@ -81,7 +81,7 @@ namespace <#=NameSpace#>
Name = @"<#=parameter.Name#>",
<#}#>
<#if (parameter.Description != null) {#>
- Description = @"<#=parameter.Description#>",
+ Description = @"<#= parameter.Description.Replace("\"", "\"\"") #>",
<#}#>
<#if (parameter.Type != null) {#>
ParameterType = typeof(<#=parameter.Type#>),
@@ -96,11 +96,6 @@ namespace <#=NameSpace#>
<#}#>
};
}
-
- public global::Azure.AI.OpenAI.FunctionDefinition <#=functionContract.GetFunctionDefinitionName()#>
- {
- get => this.<#=functionContract.GetFunctionContractName()#>.ToOpenAIFunctionDefinition();
- }
<#}#>
}
<#if (!String.IsNullOrEmpty(NameSpace)) {#>
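
For context on the `Replace("\"", "\"\"")` calls in both files above: the emitted `Description` values are C# verbatim string literals, where a literal double quote is escaped by doubling it. A standalone illustration:

```csharp
using System;

// The generated code wraps Description in a verbatim string literal (@"..."),
// so a double quote inside the description must be doubled.
string description = "This is a \"test\" function";
string escaped = description.Replace("\"", "\"\"");

// Prints: Description = @"This is a ""test"" function",
Console.WriteLine($"Description = @\"{escaped}\",");
```
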
diff --git a/dotnet/src/AutoGen.WebAPI/AutoGen.WebAPI.csproj b/dotnet/src/AutoGen.WebAPI/AutoGen.WebAPI.csproj
new file mode 100644
index 000000000000..c5b720764761
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/AutoGen.WebAPI.csproj
@@ -0,0 +1,27 @@
+
+
+
+ net6.0;net8.0
+ true
+ $(NoWarn);CS1591;CS1573
+
+
+
+
+
+
+
+ AutoGen.WebAPI
+
+ Turn an `AutoGen.Core.IAgent` into a RESTful API.
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/src/AutoGen.WebAPI/Extension.cs b/dotnet/src/AutoGen.WebAPI/Extension.cs
new file mode 100644
index 000000000000..c8534e43e540
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/Extension.cs
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Extension.cs
+
+using AutoGen.Core;
+using Microsoft.AspNetCore.Builder;
+
+namespace AutoGen.WebAPI;
+
+public static class Extension
+{
+ /// <summary>
+ /// Serve the agent as an OpenAI chat completion endpoint using <see cref="OpenAIChatCompletionMiddleware"/>.
+ /// If the request path is /v1/chat/completions and the model name is the same as the agent name,
+ /// the request will be handled by the agent;
+ /// otherwise, the request will be passed to the next middleware.
+ /// </summary>
+ /// <param name="app">application builder</param>
+ /// <returns><see cref="IApplicationBuilder"/></returns>
+ public static IApplicationBuilder UseAgentAsOpenAIChatCompletionEndpoint(this IApplicationBuilder app, IAgent agent)
+ {
+ var middleware = new OpenAIChatCompletionMiddleware(agent);
+ return app.Use(middleware.InvokeAsync);
+ }
+}
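
A minimal hosting sketch for the new extension method; `CreateMyAgent` and the defaults are placeholders, not part of this PR:

```csharp
using AutoGen.Core;
using AutoGen.WebAPI;
using Microsoft.AspNetCore.Builder;

var builder = WebApplication.CreateBuilder(args);
var app = builder.Build();

// Hypothetical factory; any IAgent implementation works here.
IAgent agent = CreateMyAgent();

// POST /v1/chat/completions requests whose "model" equals agent.Name are
// answered by the agent; everything else falls through to the next middleware.
app.UseAgentAsOpenAIChatCompletionEndpoint(agent);

app.Run();
```
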
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/Converter/OpenAIMessageConverter.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/Converter/OpenAIMessageConverter.cs
new file mode 100644
index 000000000000..888a0f8dd8c8
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/Converter/OpenAIMessageConverter.cs
@@ -0,0 +1,56 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIMessageConverter.cs
+
+using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIMessageConverter : JsonConverter<OpenAIMessage>
+{
+ public override OpenAIMessage Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ using JsonDocument document = JsonDocument.ParseValue(ref reader);
+ var root = document.RootElement;
+ var role = root.GetProperty("role").GetString();
+ var contentDocument = root.GetProperty("content");
+ var isContentDocumentString = contentDocument.ValueKind == JsonValueKind.String;
+ switch (role)
+ {
+ case "system":
+ return JsonSerializer.Deserialize<OpenAISystemMessage>(root.GetRawText()) ?? throw new JsonException();
+ case "user" when isContentDocumentString:
+ return JsonSerializer.Deserialize<OpenAIUserMessage>(root.GetRawText()) ?? throw new JsonException();
+ case "user" when !isContentDocumentString:
+ return JsonSerializer.Deserialize<OpenAIUserMultiModalMessage>(root.GetRawText()) ?? throw new JsonException();
+ case "assistant":
+ return JsonSerializer.Deserialize<OpenAIAssistantMessage>(root.GetRawText()) ?? throw new JsonException();
+ case "tool":
+ return JsonSerializer.Deserialize<OpenAIToolMessage>(root.GetRawText()) ?? throw new JsonException();
+ default:
+ throw new JsonException();
+ }
+ }
+
+ public override void Write(Utf8JsonWriter writer, OpenAIMessage value, JsonSerializerOptions options)
+ {
+ switch (value)
+ {
+ case OpenAISystemMessage systemMessage:
+ JsonSerializer.Serialize(writer, systemMessage, options);
+ break;
+ case OpenAIUserMessage userMessage:
+ JsonSerializer.Serialize(writer, userMessage, options);
+ break;
+ case OpenAIAssistantMessage assistantMessage:
+ JsonSerializer.Serialize(writer, assistantMessage, options);
+ break;
+ case OpenAIToolMessage toolMessage:
+ JsonSerializer.Serialize(writer, toolMessage, options);
+ break;
+ default:
+ throw new JsonException();
+ }
+ }
+}
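
The converter dispatches on the `role` property and, for user messages, on whether `content` is a string or an array. Because `OpenAIMessage` carries a `[JsonConverter]` attribute, plain `JsonSerializer` calls pick it up automatically; a small sketch, under the assumption that the caller lives inside the assembly (the DTOs are `internal`):

```csharp
using System;
using System.Text.Json;
using AutoGen.WebAPI.OpenAI.DTO;

var json = """{"role":"user","content":"hello"}""";

// role is "user" and content is a string, so Read() yields OpenAIUserMessage.
OpenAIMessage message = JsonSerializer.Deserialize<OpenAIMessage>(json)!;
Console.WriteLine(message is OpenAIUserMessage); // True
```
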
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIAssistantMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIAssistantMessage.cs
new file mode 100644
index 000000000000..bfd090358453
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIAssistantMessage.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIAssistantMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIAssistantMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "assistant";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("tool_calls")]
+ public OpenAIToolCallObject[]? ToolCalls { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletion.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletion.cs
new file mode 100644
index 000000000000..041f4cfc848c
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletion.cs
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletion.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletion
+{
+ [JsonPropertyName("id")]
+ public string? ID { get; set; }
+
+ [JsonPropertyName("created")]
+ public long Created { get; set; }
+
+ [JsonPropertyName("choices")]
+ public OpenAIChatCompletionChoice[]? Choices { get; set; }
+
+ [JsonPropertyName("model")]
+ public string? Model { get; set; }
+
+ [JsonPropertyName("system_fingerprint")]
+ public string? SystemFingerprint { get; set; }
+
+ [JsonPropertyName("object")]
+ public string Object { get; set; } = "chat.completion";
+
+ [JsonPropertyName("usage")]
+ public OpenAIChatCompletionUsage? Usage { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionChoice.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionChoice.cs
new file mode 100644
index 000000000000..35b6fce59a8e
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionChoice.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionChoice.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletionChoice
+{
+ [JsonPropertyName("finish_reason")]
+ public string? FinishReason { get; set; }
+
+ [JsonPropertyName("index")]
+ public int Index { get; set; }
+
+ [JsonPropertyName("message")]
+ public OpenAIChatCompletionMessage? Message { get; set; }
+
+ [JsonPropertyName("delta")]
+ public OpenAIChatCompletionMessage? Delta { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionMessage.cs
new file mode 100644
index 000000000000..de6be0dbf7a5
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionMessage.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletionMessage
+{
+ [JsonPropertyName("role")]
+ public string Role { get; } = "assistant";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionOption.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionOption.cs
new file mode 100644
index 000000000000..0b9137d43a39
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionOption.cs
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionOption.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletionOption
+{
+ [JsonPropertyName("messages")]
+ public OpenAIMessage[]? Messages { get; set; }
+
+ [JsonPropertyName("model")]
+ public string? Model { get; set; }
+
+ [JsonPropertyName("max_tokens")]
+ public int? MaxTokens { get; set; }
+
+ [JsonPropertyName("temperature")]
+ public float Temperature { get; set; } = 1;
+
+ /// <summary>
+ /// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message
+ /// </summary>
+ [JsonPropertyName("stream")]
+ public bool? Stream { get; set; } = false;
+
+ [JsonPropertyName("stream_options")]
+ public OpenAIStreamOptions? StreamOptions { get; set; }
+
+ [JsonPropertyName("stop")]
+ public string[]? Stop { get; set; }
+}
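
The option DTO mirrors the OpenAI chat-completion request body, so the middleware can bind it directly from JSON. An illustrative request it accepts (same `internal` caveat as above):

```csharp
using System;
using System.Text.Json;
using AutoGen.WebAPI.OpenAI.DTO;

var requestJson = """
{
  "model": "my-agent",
  "messages": [ { "role": "user", "content": "hi" } ],
  "max_tokens": 256,
  "temperature": 0.7,
  "stream": true
}
""";

var option = JsonSerializer.Deserialize<OpenAIChatCompletionOption>(requestJson)!;
Console.WriteLine($"{option.Model}, stream: {option.Stream}"); // my-agent, stream: True
```
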
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionUsage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionUsage.cs
new file mode 100644
index 000000000000..f196ccb842ea
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIChatCompletionUsage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionUsage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIChatCompletionUsage
+{
+ [JsonPropertyName("completion_tokens")]
+ public int CompletionTokens { get; set; }
+
+ [JsonPropertyName("prompt_tokens")]
+ public int PromptTokens { get; set; }
+
+ [JsonPropertyName("total_tokens")]
+ public int TotalTokens { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIImageUrlObject.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIImageUrlObject.cs
new file mode 100644
index 000000000000..a50012c9fed1
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIImageUrlObject.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIImageUrlObject.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIImageUrlObject
+{
+ [JsonPropertyName("url")]
+ public string? Url { get; set; }
+
+ [JsonPropertyName("detail")]
+ public string? Detail { get; set; } = "auto";
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIMessage.cs
new file mode 100644
index 000000000000..deb729b72003
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIMessage.cs
@@ -0,0 +1,13 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+[JsonConverter(typeof(OpenAIMessageConverter))]
+internal abstract class OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public abstract string? Role { get; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIStreamOptions.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIStreamOptions.cs
new file mode 100644
index 000000000000..e95991388b7f
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIStreamOptions.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIStreamOptions.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIStreamOptions
+{
+ [JsonPropertyName("include_usage")]
+ public bool? IncludeUsage { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAISystemMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAISystemMessage.cs
new file mode 100644
index 000000000000..f29b10826c4f
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAISystemMessage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAISystemMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAISystemMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "system";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolCallObject.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolCallObject.cs
new file mode 100644
index 000000000000..f3fc37f9c44f
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolCallObject.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIToolCallObject.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIToolCallObject
+{
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("arguments")]
+ public string? Arguments { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolMessage.cs
new file mode 100644
index 000000000000..0c84c164cd96
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIToolMessage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIToolMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIToolMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "tool";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+
+ [JsonPropertyName("tool_call_id")]
+ public string? ToolCallId { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserImageContent.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserImageContent.cs
new file mode 100644
index 000000000000..28b83ffb3058
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserImageContent.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserImageContent.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIUserImageContent : OpenAIUserMessageItem
+{
+ [JsonPropertyName("type")]
+ public override string MessageType { get; } = "image";
+
+ [JsonPropertyName("image_url")]
+ public string? Url { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessage.cs
new file mode 100644
index 000000000000..b5f1e7c50c12
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIUserMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "user";
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessageItem.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessageItem.cs
new file mode 100644
index 000000000000..94e7d91534a5
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMessageItem.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserMessageItem.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal abstract class OpenAIUserMessageItem
+{
+ [JsonPropertyName("type")]
+ public abstract string MessageType { get; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMultiModalMessage.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMultiModalMessage.cs
new file mode 100644
index 000000000000..789df5afaaae
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserMultiModalMessage.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserMultiModalMessage.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIUserMultiModalMessage : OpenAIMessage
+{
+ [JsonPropertyName("role")]
+ public override string? Role { get; } = "user";
+
+ [JsonPropertyName("content")]
+ public OpenAIUserMessageItem[]? Content { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserTextContent.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserTextContent.cs
new file mode 100644
index 000000000000..d22d5aa4c7f3
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/DTO/OpenAIUserTextContent.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIUserTextContent.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.WebAPI.OpenAI.DTO;
+
+internal class OpenAIUserTextContent : OpenAIUserMessageItem
+{
+ [JsonPropertyName("type")]
+ public override string MessageType { get; } = "text";
+
+ [JsonPropertyName("text")]
+ public string? Content { get; set; }
+}
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAI/Service/OpenAIChatCompletionService.cs b/dotnet/src/AutoGen.WebAPI/OpenAI/Service/OpenAIChatCompletionService.cs
new file mode 100644
index 000000000000..27481da006a2
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAI/Service/OpenAIChatCompletionService.cs
@@ -0,0 +1,157 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionService.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using AutoGen.WebAPI.OpenAI.DTO;
+
+namespace AutoGen.Server;
+
+internal class OpenAIChatCompletionService
+{
+ private readonly IAgent agent;
+
+ public OpenAIChatCompletionService(IAgent agent)
+ {
+ this.agent = agent;
+ }
+
+ public async Task<OpenAIChatCompletion> GetChatCompletionAsync(OpenAIChatCompletionOption request)
+ {
+ var messages = this.ProcessMessages(request.Messages ?? Array.Empty<OpenAIMessage>());
+
+ var generateOption = this.ProcessReplyOptions(request);
+
+ var reply = await this.agent.GenerateReplyAsync(messages, generateOption);
+
+ var openAIChatCompletion = new OpenAIChatCompletion()
+ {
+ Created = DateTimeOffset.UtcNow.Ticks / TimeSpan.TicksPerMillisecond / 1000,
+ Model = this.agent.Name,
+ };
+
+ if (reply.GetContent() is string content)
+ {
+ var message = new OpenAIChatCompletionMessage()
+ {
+ Content = content,
+ };
+
+ var choice = new OpenAIChatCompletionChoice()
+ {
+ Message = message,
+ Index = 0,
+ FinishReason = "completed",
+ };
+
+ openAIChatCompletion.Choices = [choice];
+
+ return openAIChatCompletion;
+ }
+
+ throw new NotImplementedException("Unsupported reply content type");
+ }
+
+ public async IAsyncEnumerable<OpenAIChatCompletion> GetStreamingChatCompletionAsync(OpenAIChatCompletionOption request)
+ {
+ if (this.agent is IStreamingAgent streamingAgent)
+ {
+ var messages = this.ProcessMessages(request.Messages ?? Array.Empty<OpenAIMessage>());
+
+ var generateOption = this.ProcessReplyOptions(request);
+
+ await foreach (var reply in streamingAgent.GenerateStreamingReplyAsync(messages, generateOption))
+ {
+ var openAIChatCompletion = new OpenAIChatCompletion()
+ {
+ Created = DateTimeOffset.UtcNow.Ticks / TimeSpan.TicksPerMillisecond / 1000,
+ Model = this.agent.Name,
+ };
+
+ if (reply.GetContent() is string content)
+ {
+ var message = new OpenAIChatCompletionMessage()
+ {
+ Content = content,
+ };
+
+ var choice = new OpenAIChatCompletionChoice()
+ {
+ Delta = message,
+ Index = 0,
+ };
+
+ openAIChatCompletion.Choices = [choice];
+
+ yield return openAIChatCompletion;
+ }
+ else
+ {
+ throw new NotImplementedException("Unsupported reply content type");
+ }
+ }
+
+ var doneMessage = new OpenAIChatCompletion()
+ {
+ Created = DateTimeOffset.UtcNow.Ticks / TimeSpan.TicksPerMillisecond / 1000,
+ Model = this.agent.Name,
+ };
+
+ var doneChoice = new OpenAIChatCompletionChoice()
+ {
+ FinishReason = "stop",
+ Index = 0,
+ };
+
+ doneMessage.Choices = [doneChoice];
+
+ yield return doneMessage;
+ }
+ else
+ {
+ yield return await this.GetChatCompletionAsync(request);
+ }
+ }
+
+ private IEnumerable<IMessage> ProcessMessages(IEnumerable<OpenAIMessage> messages)
+ {
+ return messages.Select<OpenAIMessage, IMessage>(m => m switch
+ {
+ OpenAISystemMessage systemMessage when systemMessage.Content is string content => new TextMessage(Role.System, content, this.agent.Name),
+ OpenAIUserMessage userMessage when userMessage.Content is string content => new TextMessage(Role.User, content, this.agent.Name),
+ OpenAIAssistantMessage assistantMessage when assistantMessage.Content is string content => new TextMessage(Role.Assistant, content, this.agent.Name),
+ OpenAIUserMultiModalMessage userMultiModalMessage when userMultiModalMessage.Content is { Length: > 0 } => this.CreateMultiModaMessageFromOpenAIUserMultiModalMessage(userMultiModalMessage),
+ _ => throw new ArgumentException($"Unsupported message type {m.GetType()}")
+ });
+ }
+
+ private GenerateReplyOptions ProcessReplyOptions(OpenAIChatCompletionOption request)
+ {
+ return new GenerateReplyOptions()
+ {
+ Temperature = request.Temperature,
+ MaxToken = request.MaxTokens,
+ StopSequence = request.Stop,
+ };
+ }
+
+ private MultiModalMessage CreateMultiModaMessageFromOpenAIUserMultiModalMessage(OpenAIUserMultiModalMessage message)
+ {
+ if (message.Content is null)
+ {
+ throw new ArgumentNullException(nameof(message.Content));
+ }
+
+ IEnumerable<IMessage> items = message.Content.Select<OpenAIUserMessageItem, IMessage>(item => item switch
+ {
+ OpenAIUserImageContent imageContent when imageContent.Url is string url => new ImageMessage(Role.User, url, this.agent.Name),
+ OpenAIUserTextContent textContent when textContent.Content is string content => new TextMessage(Role.User, content, this.agent.Name),
+ _ => throw new ArgumentException($"Unsupported content type {item.GetType()}")
+ });
+
+ return new MultiModalMessage(Role.User, items, this.agent.Name);
+ }
+}
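
One note on the `Created` computation above: dividing `DateTimeOffset.UtcNow.Ticks` by `TimeSpan.TicksPerMillisecond` and 1000 yields seconds since 0001-01-01, whereas the OpenAI wire format documents `created` as Unix-epoch seconds. If epoch alignment matters to clients, `ToUnixTimeSeconds()` is the conventional alternative; a comparison sketch (an observation, not a change in this PR):

```csharp
using System;

// As written in the service: seconds since 0001-01-01T00:00:00Z.
long sinceYearOne = DateTimeOffset.UtcNow.Ticks / TimeSpan.TicksPerMillisecond / 1000;

// Unix-epoch seconds, which is what the OpenAI `created` field documents.
long unixSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds();

Console.WriteLine($"{sinceYearOne} vs {unixSeconds}");
```
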
diff --git a/dotnet/src/AutoGen.WebAPI/OpenAIChatCompletionMiddleware.cs b/dotnet/src/AutoGen.WebAPI/OpenAIChatCompletionMiddleware.cs
new file mode 100644
index 000000000000..53b3699fd62e
--- /dev/null
+++ b/dotnet/src/AutoGen.WebAPI/OpenAIChatCompletionMiddleware.cs
@@ -0,0 +1,92 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAIChatCompletionMiddleware.cs
+
+using System.Text.Json;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using AutoGen.Server;
+using AutoGen.WebAPI.OpenAI.DTO;
+using Microsoft.AspNetCore.Http;
+
+namespace AutoGen.WebAPI;
+
+public class OpenAIChatCompletionMiddleware : Microsoft.AspNetCore.Http.IMiddleware
+{
+ private readonly IAgent _agent;
+ private readonly OpenAIChatCompletionService chatCompletionService;
+
+ public OpenAIChatCompletionMiddleware(IAgent agent)
+ {
+ _agent = agent;
+ chatCompletionService = new OpenAIChatCompletionService(_agent);
+ }
+
+ public async Task InvokeAsync(HttpContext context, RequestDelegate next)
+ {
+ // if HttpPost and path is /v1/chat/completions
+ // get the request body
+ // call chatCompletionService.GetChatCompletionAsync(request)
+ // return the response
+
+ // else
+ // call next middleware
+ if (context.Request.Method == HttpMethods.Post && context.Request.Path == "/v1/chat/completions")
+ {
+ context.Request.EnableBuffering();
+ var body = await context.Request.ReadFromJsonAsync<OpenAIChatCompletionOption>();
+ context.Request.Body.Position = 0;
+ if (body is null)
+ {
+ // return 400 Bad Request
+ context.Response.StatusCode = 400;
+ return;
+ }
+
+ if (body.Model != _agent.Name)
+ {
+ await next(context);
+ return;
+ }
+
+ if (body.Stream is true)
+ {
+ // Send as server side events
+ context.Response.Headers.Append("Content-Type", "text/event-stream");
+ context.Response.Headers.Append("Cache-Control", "no-cache");
+ context.Response.Headers.Append("Connection", "keep-alive");
+ await foreach (var chatCompletion in chatCompletionService.GetStreamingChatCompletionAsync(body))
+ {
+ if (chatCompletion?.Choices?[0].FinishReason is "stop")
+ {
+ // the stream is done
+ // send Data: [DONE]\n\n
+ await context.Response.WriteAsync("data: [DONE]\n\n");
+ break;
+ }
+ else
+ {
+ // remove null
+ var option = new JsonSerializerOptions
+ {
+ DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull,
+ };
+ var data = JsonSerializer.Serialize(chatCompletion, option);
+ await context.Response.WriteAsync($"data: {data}\n\n");
+ }
+ }
+
+ return;
+ }
+ else
+ {
+ var chatCompletion = await chatCompletionService.GetChatCompletionAsync(body);
+ await context.Response.WriteAsJsonAsync(chatCompletion);
+ return;
+ }
+ }
+ else
+ {
+ await next(context);
+ }
+ }
+}
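
A client-side sketch of consuming the streaming branch above; the base address, agent name, and payload are placeholders:

```csharp
using System;
using System.IO;
using System.Net.Http;
using System.Text;

using var http = new HttpClient { BaseAddress = new Uri("http://localhost:5000") };

var request = new HttpRequestMessage(HttpMethod.Post, "/v1/chat/completions")
{
    Content = new StringContent(
        """{"model":"my-agent","messages":[{"role":"user","content":"hi"}],"stream":true}""",
        Encoding.UTF8,
        "application/json"),
};

// ResponseHeadersRead lets us read the event stream as it arrives.
using var response = await http.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
using var reader = new StreamReader(await response.Content.ReadAsStreamAsync());

// The middleware writes "data: {json}\n\n" per chunk and ends with "data: [DONE]".
string? line;
while ((line = await reader.ReadLineAsync()) is not null)
{
    if (line == "data: [DONE]") break;
    if (line.StartsWith("data: ")) Console.WriteLine(line["data: ".Length..]);
}
```
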
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
index d29025b44aff..085917d419e9 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
@@ -32,6 +32,30 @@ public async Task AnthropicAgentChatCompletionTestAsync()
reply.From.Should().Be(agent.Name);
}
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentMergeMessageWithSameRoleTests()
+ {
+ // this test is added to fix issue #2884
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant that convert user message to upper case")
+ .RegisterMessageConnector();
+
+ var uppCaseMessage = new TextMessage(Role.User, "abcdefg");
+ var anotherUserMessage = new TextMessage(Role.User, "hijklmn");
+ var assistantMessage = new TextMessage(Role.Assistant, "opqrst");
+ var anotherAssistantMessage = new TextMessage(Role.Assistant, "uvwxyz");
+ var yetAnotherUserMessage = new TextMessage(Role.User, "123456");
+
+ // just make sure it doesn't throw exception
+ var reply = await agent.SendAsync(chatHistory: [uppCaseMessage, anotherUserMessage, assistantMessage, anotherAssistantMessage, yetAnotherUserMessage]);
+ reply.GetContent().Should().NotBeNull();
+ }
+
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestProcessImageAsync()
{
@@ -105,4 +129,101 @@ public async Task AnthropicAgentTestImageMessageAsync()
reply.GetContent().Should().NotBeNullOrEmpty();
reply.From.Should().Be(agent.Name);
}
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentTestToolAsync()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var function = new TypeSafeFunctionCall();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: new[] { function.WeatherReportFunctionContract },
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name ?? string.Empty, function.WeatherReportWrapper },
+ });
+
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are an LLM that is specialized in finding the weather !",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var reply = await agent.SendAsync("What is the weather in Philadelphia?");
+ reply.GetContent().Should().Be("Weather report for Philadelphia on today is sunny");
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentFunctionCallMessageTest()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant.",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector();
+
+ var weatherFunctionArgumets = """
+ {
+ "city": "Philadelphia",
+ "date": "6/14/2024"
+ }
+ """;
+
+ var function = new AnthropicTestFunctionCalls();
+ var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArgumets);
+ var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArgumets)
+ {
+ ToolCallId = "get_weather",
+ Result = functionCallResult,
+ };
+
+ IMessage[] chatHistory = [
+ new TextMessage(Role.User, "what's the weather in Philadelphia?"),
+ new ToolCallMessage([toolCall], from: "assistant"),
+ new ToolCallResultMessage([toolCall], from: "user"),
+ ];
+
+ var reply = await agent.SendAsync(chatHistory: chatHistory);
+
+ reply.Should().BeOfType<TextMessage>();
+ reply.GetContent().Should().Be("The weather report for Philadelphia on 6/14/2024 is sunny.");
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentFunctionCallMiddlewareMessageTest()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+ var function = new AnthropicTestFunctionCalls();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [function.WeatherReportFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name!, function.GetWeatherReportWrapper }
+ });
+
+ var functionCallAgent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant.",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var question = new TextMessage(Role.User, "what's the weather in Philadelphia?");
+ var reply = await functionCallAgent.SendAsync(question);
+
+ var finalReply = await functionCallAgent.SendAsync(chatHistory: [question, reply]);
+ finalReply.Should().BeOfType<TextMessage>();
+ finalReply.GetContent()!.ToLower().Should().Contain("sunny");
+ }
}
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
index a0b1f60cfb95..102e48b9b8ac 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
@@ -1,5 +1,9 @@
-using System.Text;
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicClientTest.cs
+
+using System.Text;
using System.Text.Json;
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Utils;
@@ -58,7 +62,9 @@ public async Task AnthropicClientStreamingChatCompletionTestAsync()
foreach (ChatCompletionResponse result in results)
{
if (result.Delta is not null && !string.IsNullOrEmpty(result.Delta.Text))
+ {
sb.Append(result.Delta.Text);
+ }
}
string resultContent = sb.ToString();
@@ -108,6 +114,57 @@ public async Task AnthropicClientImageChatCompletionTestAsync()
response.Usage.OutputTokens.Should().BeGreaterThan(0);
}
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicClientTestToolsAsync()
+ {
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var request = new ChatCompletionRequest();
+ request.Model = AnthropicConstants.Claude3Haiku;
+ request.Stream = false;
+ request.MaxTokens = 100;
+ request.Messages = new List() { new("user", "Use the stock price tool to look for MSFT. Your response should only be the tool.") };
+ request.Tools = new List() { AnthropicTestUtils.StockTool };
+
+ ChatCompletionResponse response =
+ await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);
+
+ Assert.NotNull(response.Content);
+ Assert.True(response.Content.First() is ToolUseContent);
+ ToolUseContent toolUseContent = ((ToolUseContent)response.Content.First());
+ Assert.Equal("get_stock_price", toolUseContent.Name);
+ Assert.NotNull(toolUseContent.Input);
+ Assert.True(toolUseContent.Input is JsonNode);
+ JsonNode jsonNode = toolUseContent.Input;
+ Assert.Equal("{\"ticker\":\"MSFT\"}", jsonNode.ToJsonString());
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicClientTestToolChoiceAsync()
+ {
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var request = new ChatCompletionRequest();
+ request.Model = AnthropicConstants.Claude3Haiku;
+ request.Stream = false;
+ request.MaxTokens = 100;
+ request.Messages = new List() { new("user", "What is the weather today? Your response should only be the tool.") };
+ request.Tools = new List() { AnthropicTestUtils.StockTool, AnthropicTestUtils.WeatherTool };
+
+ // Force to use get_stock_price even though the prompt is about weather
+ request.ToolChoice = ToolChoice.ToolUse("get_stock_price");
+
+ ChatCompletionResponse response =
+ await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);
+
+ Assert.NotNull(response.Content);
+ Assert.True(response.Content.First() is ToolUseContent);
+ ToolUseContent toolUseContent = ((ToolUseContent)response.Content.First());
+ Assert.Equal("get_stock_price", toolUseContent.Name);
+ Assert.NotNull(toolUseContent.Input);
+ Assert.True(toolUseContent.Input is JsonNode);
+ }
+
private sealed class Person
{
[JsonPropertyName("name")]
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestFunctionCalls.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestFunctionCalls.cs
new file mode 100644
index 000000000000..8b5466e3a519
--- /dev/null
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestFunctionCalls.cs
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicTestFunctionCalls.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+
+namespace AutoGen.Anthropic.Tests;
+
+public partial class AnthropicTestFunctionCalls
+{
+ private class GetWeatherSchema
+ {
+ [JsonPropertyName("city")]
+ public string? City { get; set; }
+
+ [JsonPropertyName("date")]
+ public string? Date { get; set; }
+ }
+
+ /// <summary>
+ /// Get weather report
+ /// </summary>
+ /// <param name="city">city</param>
+ /// <param name="date">date</param>
+ [Function]
+ public async Task<string> WeatherReport(string city, string date)
+ {
+ return $"Weather report for {city} on {date} is sunny";
+ }
+
+ public Task<string> GetWeatherReportWrapper(string arguments)
+ {
+ var schema = JsonSerializer.Deserialize<GetWeatherSchema>(
+ arguments,
+ new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
+
+ return WeatherReport(schema?.City ?? string.Empty, schema?.Date ?? string.Empty);
+ }
+}
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
index de630da6d87c..a1faffec5344 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicTestUtils.cs
+using AutoGen.Anthropic.DTO;
+
namespace AutoGen.Anthropic.Tests;
public static class AnthropicTestUtils
@@ -13,4 +15,52 @@ public static async Task Base64FromImageAsync(string imageName)
return Convert.ToBase64String(
await File.ReadAllBytesAsync(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", imageName)));
}
+
+ public static Tool WeatherTool
+ {
+ get
+ {
+ return new Tool
+ {
+ Name = "WeatherReport",
+ Description = "Get the current weather",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary<string, SchemaProperty>
+ {
+ { "city", new SchemaProperty {Type = "string", Description = "The name of the city"} },
+ { "date", new SchemaProperty {Type = "string", Description = "date of the day"} }
+ }
+ }
+ };
+ }
+ }
+
+ public static Tool StockTool
+ {
+ get
+ {
+ return new Tool
+ {
+ Name = "get_stock_price",
+ Description = "Get the current stock price for a given ticker symbol.",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary<string, SchemaProperty>
+ {
+ {
+ "ticker", new SchemaProperty
+ {
+ Type = "string",
+ Description = "The stock ticker symbol, e.g. AAPL for Apple Inc."
+ }
+ }
+ },
+ Required = new List { "ticker" }
+ }
+ };
+ }
+ }
}
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj
index 0f22d9fe6764..ac479ed2e722 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj
@@ -12,6 +12,7 @@
+
diff --git a/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs b/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs
index 872cce5e645b..c076aee18376 100644
--- a/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs
+++ b/dotnet/test/AutoGen.Gemini.Tests/GeminiAgentTests.cs
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GeminiAgentTests.cs
-using AutoGen.Tests;
-using Google.Cloud.AIPlatform.V1;
using AutoGen.Core;
-using FluentAssertions;
using AutoGen.Gemini.Extension;
-using static Google.Cloud.AIPlatform.V1.Part;
+using AutoGen.Tests;
+using FluentAssertions;
+using Google.Cloud.AIPlatform.V1;
using Xunit.Abstractions;
+using static Google.Cloud.AIPlatform.V1.Part;
namespace AutoGen.Gemini.Tests;
public class GeminiAgentTests
@@ -86,8 +86,8 @@ public async Task VertexGeminiAgentGenerateStreamingReplyForTextContentAsync()
var message = MessageEnvelope.Create(textContent, from: agent.Name);
var completion = agent.GenerateStreamingReplyAsync([message]);
- var chunks = new List<IStreamingMessage>();
- IStreamingMessage finalReply = null!;
+ var chunks = new List<IMessage>();
+ IMessage finalReply = null!;
await foreach (var item in completion)
{
@@ -212,8 +212,8 @@ public async Task VertexGeminiAgentGenerateStreamingReplyWithToolsAsync()
var message = MessageEnvelope.Create(textContent, from: agent.Name);
- var chunks = new List<IStreamingMessage>();
- IStreamingMessage finalReply = null!;
+ var chunks = new List<IMessage>();
+ IMessage finalReply = null!;
var completion = agent.GenerateStreamingReplyAsync([message]);
diff --git a/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs b/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs
index 7ffb532ea9c1..12ba94734032 100644
--- a/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs
+++ b/dotnet/test/AutoGen.Gemini.Tests/GeminiMessageTests.cs
@@ -225,10 +225,10 @@ public async Task ItProcessStreamingTextMessageAsync()
})
.Select(m => MessageEnvelope.Create(m));
- IStreamingMessage? finalReply = null;
+ IMessage? finalReply = null;
await foreach (var reply in agent.GenerateStreamingReplyAsync(messageChunks))
{
- reply.Should().BeAssignableTo<IStreamingMessage>();
+ reply.Should().BeAssignableTo<IMessage>();
finalReply = reply;
}
diff --git a/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs b/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs
index 2f06305ed59f..fba97aa522d5 100644
--- a/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs
+++ b/dotnet/test/AutoGen.Gemini.Tests/VertexGeminiClientTests.cs
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
-// GeminiVertexClientTests.cs
+// VertexGeminiClientTests.cs
using AutoGen.Tests;
using FluentAssertions;
@@ -53,7 +53,7 @@ public async Task ItGenerateContentWithImageAsync()
var model = "gemini-1.5-flash-001";
var text = "what's in the image";
- var imagePath = Path.Combine("testData", "images", "image.png");
+ var imagePath = Path.Combine("testData", "images", "square.png");
var image = File.ReadAllBytes(imagePath);
var request = new GenerateContentRequest
{
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs
index c1fb466f0b09..8a416116ea92 100644
--- a/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaAgentTests.cs
@@ -65,8 +65,8 @@ public async Task GenerateStreamingReplyAsync_ReturnsValidMessages_WhenCalled()
var msg = new Message("user", "hey how are you");
var messages = new IMessage[] { MessageEnvelope.Create(msg, from: modelName) };
- IStreamingMessage? finalReply = default;
- await foreach (IStreamingMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
+ IMessage? finalReply = default;
+ await foreach (IMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
{
message.Should().NotBeNull();
message.From.Should().Be(ollamaAgent.Name);
@@ -171,8 +171,8 @@ public async Task ItReturnValidStreamingMessageUsingLLavaAsync()
var messages = new IMessage[] { MessageEnvelope.Create(imageMessage, from: modelName) };
- IStreamingMessage? finalReply = default;
- await foreach (IStreamingMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
+ IMessage? finalReply = default;
+ await foreach (IMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
{
message.Should().NotBeNull();
message.From.Should().Be(ollamaAgent.Name);
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs
index b19291e97671..82cc462061da 100644
--- a/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaMessageTests.cs
@@ -57,10 +57,10 @@ public async Task ItProcessStreamingTextMessageAsync()
})
.Select(m => MessageEnvelope.Create(m));
- IStreamingMessage? finalReply = null;
+ IMessage? finalReply = null;
await foreach (var reply in agent.GenerateStreamingReplyAsync(messageChunks))
{
- reply.Should().BeAssignableTo<IStreamingMessage>();
+ reply.Should().BeAssignableTo<IMessage>();
finalReply = reply;
}
diff --git a/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs b/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs
index 06522bdd8238..b7186a3c6ebc 100644
--- a/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs
+++ b/dotnet/test/AutoGen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaTextEmbeddingServiceTests.cs
using AutoGen.Tests;
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj
index ba499232beb9..04800a631ee6 100644
--- a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj
+++ b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj
@@ -8,6 +8,7 @@
+
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs
index aae314ff773e..01af3d4646c4 100644
--- a/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs
+++ b/dotnet/test/AutoGen.OpenAI.Tests/MathClassTest.cs
@@ -110,7 +110,7 @@ public async Task OpenAIAgentMathChatTestAsync()
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
- { this.UpdateProgressFunction.Name!, this.UpdateProgressWrapper },
+ { this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
openAIClient: openaiClient,
diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
index 81581d068ee7..a9b852e0d8c1 100644
--- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
+++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
@@ -278,9 +278,9 @@ public async Task ItProcessToolCallMessageAsync()
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestAssistantMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestAssistantMessage>)innerMessage!).Content;
- chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");
chatRequestMessage.ToolCalls.Count().Should().Be(1);
+ chatRequestMessage.Content.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
@@ -291,7 +291,10 @@ public async Task ItProcessToolCallMessageAsync()
.RegisterMiddleware(middleware);
// user message
- IMessage message = new ToolCallMessage("test", "test", "assistant");
+ IMessage message = new ToolCallMessage("test", "test", "assistant")
+ {
+ Content = "textContent",
+ };
await agent.GenerateReplyAsync([message]);
}
@@ -526,13 +529,14 @@ public async Task ItConvertChatResponseMessageToToolCallMessageAsync()
.RegisterMiddleware(middleware);
// tool call message
- var toolCallMessage = CreateInstance(ChatRole.Assistant, "", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance(), new Dictionary());
+ var toolCallMessage = CreateInstance(ChatRole.Assistant, "textContent", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance(), new Dictionary());
var chatRequestMessage = MessageEnvelope.Create(toolCallMessage);
var message = await agent.GenerateReplyAsync([chatRequestMessage]);
message.Should().BeOfType();
message.GetToolCalls()!.Count().Should().Be(1);
message.GetToolCalls()!.First().FunctionName.Should().Be("test");
message.GetToolCalls()!.First().FunctionArguments.Should().Be("test");
+ message.GetContent().Should().Be("textContent");
}
[Fact]
diff --git a/dotnet/test/AutoGen.SemanticKernel.Tests/ApprovalTests/KernelFunctionExtensionTests.ItCreateFunctionContractsFromMethod.approved.txt b/dotnet/test/AutoGen.SemanticKernel.Tests/ApprovalTests/KernelFunctionExtensionTests.ItCreateFunctionContractsFromMethod.approved.txt
index 677831d412b7..eb346da3b313 100644
--- a/dotnet/test/AutoGen.SemanticKernel.Tests/ApprovalTests/KernelFunctionExtensionTests.ItCreateFunctionContractsFromMethod.approved.txt
+++ b/dotnet/test/AutoGen.SemanticKernel.Tests/ApprovalTests/KernelFunctionExtensionTests.ItCreateFunctionContractsFromMethod.approved.txt
@@ -14,8 +14,7 @@
"Name": "message",
"Description": "",
"ParameterType": "System.String, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e",
- "IsRequired": true,
- "DefaultValue": ""
+ "IsRequired": true
}
],
"ReturnType": "System.String, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e",
diff --git a/dotnet/test/AutoGen.SemanticKernel.Tests/ApprovalTests/KernelFunctionExtensionTests.ItCreateFunctionContractsFromTestPlugin.approved.txt b/dotnet/test/AutoGen.SemanticKernel.Tests/ApprovalTests/KernelFunctionExtensionTests.ItCreateFunctionContractsFromTestPlugin.approved.txt
index ee835b1ba081..9ed3c675e4a0 100644
--- a/dotnet/test/AutoGen.SemanticKernel.Tests/ApprovalTests/KernelFunctionExtensionTests.ItCreateFunctionContractsFromTestPlugin.approved.txt
+++ b/dotnet/test/AutoGen.SemanticKernel.Tests/ApprovalTests/KernelFunctionExtensionTests.ItCreateFunctionContractsFromTestPlugin.approved.txt
@@ -16,8 +16,7 @@
"Name": "newState",
"Description": "new state",
"ParameterType": "System.Boolean, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e",
- "IsRequired": true,
- "DefaultValue": ""
+ "IsRequired": true
}
],
"ReturnType": "System.String, System.Private.CoreLib, Version=8.0.0.0, Culture=neutral, PublicKeyToken=7cec85d7bea7798e",
diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/ApprovalTests/FunctionCallTemplateTests.TestFunctionCallTemplate.approved.txt b/dotnet/test/AutoGen.SourceGenerator.Tests/ApprovalTests/FunctionCallTemplateTests.TestFunctionCallTemplate.approved.txt
index 0439febc52c7..f223d3124ddd 100644
--- a/dotnet/test/AutoGen.SourceGenerator.Tests/ApprovalTests/FunctionCallTemplateTests.TestFunctionCallTemplate.approved.txt
+++ b/dotnet/test/AutoGen.SourceGenerator.Tests/ApprovalTests/FunctionCallTemplateTests.TestFunctionCallTemplate.approved.txt
@@ -61,11 +61,6 @@ namespace AutoGen.SourceGenerator.Tests
},
};
}
-
- public global::Azure.AI.OpenAI.FunctionDefinition AddAsyncFunction
- {
- get => this.AddAsyncFunctionContract.ToOpenAIFunctionDefinition();
- }
}
}
diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/FunctionCallTemplateEncodingTests.cs b/dotnet/test/AutoGen.SourceGenerator.Tests/FunctionCallTemplateEncodingTests.cs
new file mode 100644
index 000000000000..0b2e211c6386
--- /dev/null
+++ b/dotnet/test/AutoGen.SourceGenerator.Tests/FunctionCallTemplateEncodingTests.cs
@@ -0,0 +1,94 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// FunctionCallTemplateEncodingTests.cs
+
+using System.Text.Json; // Needed for JsonSerializer
+using AutoGen.SourceGenerator.Template; // Needed for FunctionCallTemplate
+using Xunit; // Needed for Fact and Assert
+
+namespace AutoGen.SourceGenerator.Tests
+{
+ public class FunctionCallTemplateEncodingTests
+ {
+ private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
+ {
+ WriteIndented = true,
+ };
+
+ [Fact]
+ public void FunctionDescription_Should_Encode_DoubleQuotes()
+ {
+ // Arrange
+ var functionContracts = new List<SourceGeneratorFunctionContract>
+ {
+ new SourceGeneratorFunctionContract
+ {
+ Name = "TestFunction",
+ Description = "This is a \"test\" function",
+ Parameters = new SourceGeneratorParameterContract[]
+ {
+ new SourceGeneratorParameterContract
+ {
+ Name = "param1",
+ Description = "This is a \"parameter\" description",
+ Type = "string",
+ IsOptional = false
+ }
+ },
+ ReturnType = "void"
+ }
+ };
+
+ var template = new FunctionCallTemplate
+ {
+ NameSpace = "TestNamespace",
+ ClassName = "TestClass",
+ FunctionContracts = functionContracts
+ };
+
+ // Act
+ var result = template.TransformText();
+
+ // Assert
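+ // the expected output is a C# verbatim string (@"..."), where an embedded double quote is escaped by doubling it, hence the "" pairs below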
+ Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
+ Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
+ }
+
+ [Fact]
+ public void ParameterDescription_Should_Encode_DoubleQuotes()
+ {
+ // Arrange
+ var functionContracts = new List<SourceGeneratorFunctionContract>
+ {
+ new SourceGeneratorFunctionContract
+ {
+ Name = "TestFunction",
+ Description = "This is a test function",
+ Parameters = new SourceGeneratorParameterContract[]
+ {
+ new SourceGeneratorParameterContract
+ {
+ Name = "param1",
+ Description = "This is a \"parameter\" description",
+ Type = "string",
+ IsOptional = false
+ }
+ },
+ ReturnType = "void"
+ }
+ };
+
+ var template = new FunctionCallTemplate
+ {
+ NameSpace = "TestNamespace",
+ ClassName = "TestClass",
+ FunctionContracts = functionContracts
+ };
+
+ // Act
+ var result = template.TransformText();
+
+ // Assert
+ Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
+ }
+ }
+}
diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/FunctionExample.test.cs b/dotnet/test/AutoGen.SourceGenerator.Tests/FunctionExample.test.cs
index f7b90e0b96ff..0096f2c157ce 100644
--- a/dotnet/test/AutoGen.SourceGenerator.Tests/FunctionExample.test.cs
+++ b/dotnet/test/AutoGen.SourceGenerator.Tests/FunctionExample.test.cs
@@ -5,6 +5,7 @@
using ApprovalTests;
using ApprovalTests.Namers;
using ApprovalTests.Reporters;
+using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
using FluentAssertions;
using Xunit;
@@ -29,7 +30,7 @@ public void Add_Test()
};
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
- this.VerifyFunctionDefinition(functionExamples.AddFunction);
+ this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToOpenAIFunctionDefinition());
}
[Fact]
@@ -41,7 +42,7 @@ public void Sum_Test()
};
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
- this.VerifyFunctionDefinition(functionExamples.SumFunction);
+ this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToOpenAIFunctionDefinition());
}
[Fact]
@@ -57,7 +58,7 @@ public async Task DictionaryToString_Test()
};
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
- this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunction);
+ this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToOpenAIFunctionDefinition());
}
[Fact]
@@ -96,7 +97,7 @@ public void Query_Test()
};
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
- this.VerifyFunctionDefinition(functionExamples.QueryFunction);
+ this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToOpenAIFunctionDefinition());
}
[UseReporter(typeof(DiffReporter))]
diff --git a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj
index 4def281ed7b4..3dc669b5edd8 100644
--- a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj
+++ b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj
@@ -9,6 +9,7 @@
+
diff --git a/dotnet/test/AutoGen.Tests/BasicSampleTest.cs b/dotnet/test/AutoGen.Tests/BasicSampleTest.cs
index 8f2b9b2de51b..89925b7d3b39 100644
--- a/dotnet/test/AutoGen.Tests/BasicSampleTest.cs
+++ b/dotnet/test/AutoGen.Tests/BasicSampleTest.cs
@@ -37,11 +37,6 @@ public async Task AgentFunctionCallTestAsync()
await Example03_Agent_FunctionCall.RunAsync();
}
- [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
- public async Task OpenAIAgent_JsonMode()
- {
- await Example13_OpenAIAgent_JsonMode.RunAsync();
- }
[ApiKeyFact("MISTRAL_API_KEY")]
public async Task MistralClientAgent_TokenCount()
@@ -49,12 +44,6 @@ public async Task MistralClientAgent_TokenCount()
await Example14_MistralClientAgent_TokenCount.RunAsync();
}
- [ApiKeyFact("OPENAI_API_KEY")]
- public async Task DynamicGroupChatGetMLNetPRTestAsync()
- {
- await Example04_Dynamic_GroupChat_Coding_Task.RunAsync();
- }
-
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task DynamicGroupChatCalculateFibonacciAsync()
{
diff --git a/dotnet/test/AutoGen.Tests/EchoAgent.cs b/dotnet/test/AutoGen.Tests/EchoAgent.cs
index 9cead5ad2516..af5490218e8d 100644
--- a/dotnet/test/AutoGen.Tests/EchoAgent.cs
+++ b/dotnet/test/AutoGen.Tests/EchoAgent.cs
@@ -29,7 +29,7 @@ public Task<IMessage> GenerateReplyAsync(
return Task.FromResult(lastMessage);
}
- public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var message in messages)
{
diff --git a/dotnet/test/AutoGen.Tests/GroupChat/GraphTests.cs b/dotnet/test/AutoGen.Tests/GroupChat/GraphTests.cs
new file mode 100644
index 000000000000..7eeea6743f04
--- /dev/null
+++ b/dotnet/test/AutoGen.Tests/GroupChat/GraphTests.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GraphTests.cs
+
+using Xunit;
+
+namespace AutoGen.Tests
+{
+ public class GraphTests
+ {
+ [Fact]
+ public void GraphTest()
+ {
+ var graph1 = new Graph();
+ Assert.NotNull(graph1);
+
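+ // the constructor should also tolerate a null transition list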
+ var graph2 = new Graph(null);
+ Assert.NotNull(graph2);
+ }
+ }
+}
diff --git a/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs b/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs
new file mode 100644
index 000000000000..5a2cebb66cff
--- /dev/null
+++ b/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs
@@ -0,0 +1,362 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// RolePlayOrchestratorTests.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.Anthropic;
+using AutoGen.Anthropic.Extensions;
+using AutoGen.Anthropic.Utils;
+using AutoGen.Gemini;
+using AutoGen.Mistral;
+using AutoGen.Mistral.Extension;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+using Moq;
+using Xunit;
+
+namespace AutoGen.Tests;
+
+public class RolePlayOrchestratorTests
+{
+ [Fact]
+ public async Task ItReturnNextSpeakerTestAsync()
+ {
+ var admin = Mock.Of<IAgent>();
+ Mock.Get(admin).Setup(x => x.Name).Returns("Admin");
+ Mock.Get(admin).Setup(x => x.GenerateReplyAsync(
+ It.IsAny<IEnumerable<IMessage>>(),
+ It.IsAny<GenerateReplyOptions>(),
+ It.IsAny<CancellationToken>()))
+ .Callback<IEnumerable<IMessage>, GenerateReplyOptions, CancellationToken>((messages, option, _) =>
+ {
+ // verify prompt
+ var rolePlayPrompt = messages.First().GetContent();
+ rolePlayPrompt.Should().Contain("You are in a role play game. Carefully read the conversation history and carry on the conversation");
+ rolePlayPrompt.Should().Contain("The available roles are:");
+ rolePlayPrompt.Should().Contain("Alice,Bob");
+ rolePlayPrompt.Should().Contain("From Alice:");
+ option.StopSequence.Should().BeEquivalentTo([":"]);
+ option.Temperature.Should().Be(0);
+ option.MaxToken.Should().Be(128);
+ option.Functions.Should().BeNull();
+ })
+ .ReturnsAsync(new TextMessage(Role.Assistant, "From Alice"));
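+ // the orchestrator should resolve the admin's "From Alice" reply to the candidate named Alice (asserted below)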
+
+ var alice = new EchoAgent("Alice");
+ var bob = new EchoAgent("Bob");
+
+ var orchestrator = new RolePlayOrchestrator(admin);
+ var context = new OrchestrationContext
+ {
+ Candidates = [alice, bob],
+ ChatHistory = [],
+ };
+
+ var speaker = await orchestrator.GetNextSpeakerAsync(context);
+ speaker.Should().Be(alice);
+ }
+
+ [Fact]
+ public async Task ItReturnNullWhenNoCandidateIsAvailableAsync()
+ {
+ var admin = Mock.Of<IAgent>();
+ var orchestrator = new RolePlayOrchestrator(admin);
+ var context = new OrchestrationContext
+ {
+ Candidates = [],
+ ChatHistory = [],
+ };
+
+ var speaker = await orchestrator.GetNextSpeakerAsync(context);
+ speaker.Should().BeNull();
+ }
+
+ [Fact]
+ public async Task ItReturnCandidateWhenOnlyOneCandidateIsAvailableAsync()
+ {
+ var admin = Mock.Of<IAgent>();
+ var alice = new EchoAgent("Alice");
+ var orchestrator = new RolePlayOrchestrator(admin);
+ var context = new OrchestrationContext
+ {
+ Candidates = [alice],
+ ChatHistory = [],
+ };
+
+ var speaker = await orchestrator.GetNextSpeakerAsync(context);
+ speaker.Should().Be(alice);
+ }
+
+ [Fact]
+ public async Task ItThrowExceptionWhenAdminFailsToFollowPromptAsync()
+ {
+ var admin = Mock.Of<IAgent>();
+ Mock.Get(admin).Setup(x => x.Name).Returns("Admin");
+ Mock.Get(admin).Setup(x => x.GenerateReplyAsync(
+ It.IsAny<IEnumerable<IMessage>>(),
+ It.IsAny<GenerateReplyOptions>(),
+ It.IsAny<CancellationToken>()))
+ .ReturnsAsync(new TextMessage(Role.Assistant, "I don't know")); // admin fails to follow the prompt and returns an invalid message
+
+ var alice = new EchoAgent("Alice");
+ var bob = new EchoAgent("Bob");
+
+ var orchestrator = new RolePlayOrchestrator(admin);
+ var context = new OrchestrationContext
+ {
+ Candidates = [alice, bob],
+ ChatHistory = [],
+ };
+
+ var action = async () => await orchestrator.GetNextSpeakerAsync(context);
+
+ await action.Should().ThrowAsync<Exception>()
+ .WithMessage("The response from admin is 't know, which is either not in the candidates list or not in the correct format.");
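+ // the odd "'t know" fragment is what remains of "I don't know" after a "From "-sized prefix is stripped, which appears to be how the orchestrator parses the reply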
+ }
+
+ [Fact]
+ public async Task ItSelectNextSpeakerFromWorkflowIfProvided()
+ {
+ var workflow = new Graph();
+ var alice = new EchoAgent("Alice");
+ var bob = new EchoAgent("Bob");
+ var charlie = new EchoAgent("Charlie");
+ workflow.AddTransition(Transition.Create(alice, bob));
+ workflow.AddTransition(Transition.Create(bob, charlie));
+ workflow.AddTransition(Transition.Create(charlie, alice));
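+ // the transitions form a cycle Alice -> Bob -> Charlie -> Alice, so after Alice speaks the only legal next speaker is Bob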
+
+ var admin = Mock.Of<IAgent>();
+ var orchestrator = new RolePlayOrchestrator(admin, workflow);
+ var context = new OrchestrationContext
+ {
+ Candidates = [alice, bob, charlie],
+ ChatHistory =
+ [
+ new TextMessage(Role.User, "Hello, Bob", from: "Alice"),
+ ],
+ };
+
+ var speaker = await orchestrator.GetNextSpeakerAsync(context);
+ speaker.Should().Be(bob);
+ }
+
+ [Fact]
+ public async Task ItReturnNullIfNoAvailableAgentFromWorkflowAsync()
+ {
+ var workflow = new Graph();
+ var alice = new EchoAgent("Alice");
+ var bob = new EchoAgent("Bob");
+ workflow.AddTransition(Transition.Create(alice, bob));
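+ // Bob has no outgoing transition, so once Bob has spoken there is no legal next speaker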
+
+ var admin = Mock.Of<IAgent>();
+ var orchestrator = new RolePlayOrchestrator(admin, workflow);
+ var context = new OrchestrationContext
+ {
+ Candidates = [alice, bob],
+ ChatHistory =
+ [
+ new TextMessage(Role.User, "Hello, Alice", from: "Bob"),
+ ],
+ };
+
+ var speaker = await orchestrator.GetNextSpeakerAsync(context);
+ speaker.Should().BeNull();
+ }
+
+ [Fact]
+ public async Task ItUseCandidatesFromWorkflowAsync()
+ {
+ var workflow = new Graph();
+ var alice = new EchoAgent("Alice");
+ var bob = new EchoAgent("Bob");
+ var charlie = new EchoAgent("Charlie");
+ workflow.AddTransition(Transition.Create(alice, bob));
+ workflow.AddTransition(Transition.Create(alice, charlie));
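+ // Alice can hand off to either Bob or Charlie, so the admin is asked to choose between those two candidates only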
+
+ var admin = Mock.Of<IAgent>();
+ Mock.Get(admin).Setup(x => x.GenerateReplyAsync(
+ It.IsAny<IEnumerable<IMessage>>(),
+ It.IsAny<GenerateReplyOptions>(),
+ It.IsAny<CancellationToken>()))
+ .Callback<IEnumerable<IMessage>, GenerateReplyOptions, CancellationToken>((messages, option, _) =>
+ {
+ messages.First().IsSystemMessage().Should().BeTrue();
+
+ // verify prompt
+ var rolePlayPrompt = messages.First().GetContent();
+ rolePlayPrompt.Should().Contain("Bob,Charlie");
+ rolePlayPrompt.Should().Contain("From Bob:");
+ option.StopSequence.Should().BeEquivalentTo([":"]);
+ option.Temperature.Should().Be(0);
+ option.MaxToken.Should().Be(128);
+ option.Functions.Should().BeEmpty();
+ })
+ .ReturnsAsync(new TextMessage(Role.Assistant, "From Bob"));
+ var orchestrator = new RolePlayOrchestrator(admin, workflow);
+ var context = new OrchestrationContext
+ {
+ Candidates = [alice, bob],
+ ChatHistory =
+ [
+ new TextMessage(Role.User, "Hello, Bob", from: "Alice"),
+ ],
+ };
+
+ var speaker = await orchestrator.GetNextSpeakerAsync(context);
+ speaker.Should().Be(bob);
+ }
+
+ [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+ public async Task GPT_3_5_CoderReviewerRunnerTestAsync()
+ {
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+ var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+ var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ modelName: deployName)
+ .RegisterMessageConnector();
+
+ await CoderReviewerRunnerTestAsync(openAIChatAgent);
+ }
+
+ [ApiKeyFact("OPENAI_API_KEY")]
+ public async Task GPT_4o_CoderReviewerRunnerTestAsync()
+ {
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ var model = "gpt-4o";
+ var openaiClient = new OpenAIClient(apiKey);
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ modelName: model)
+ .RegisterMessageConnector();
+
+ await CoderReviewerRunnerTestAsync(openAIChatAgent);
+ }
+
+ [ApiKeyFact("OPENAI_API_KEY")]
+ public async Task GPT_4o_mini_CoderReviewerRunnerTestAsync()
+ {
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ var model = "gpt-4o-mini";
+ var openaiClient = new OpenAIClient(apiKey);
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ modelName: model)
+ .RegisterMessageConnector();
+
+ await CoderReviewerRunnerTestAsync(openAIChatAgent);
+ }
+
+ [ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
+ public async Task GoogleGemini_1_5_flash_001_CoderReviewerRunnerTestAsync()
+ {
+ var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
+ var geminiAgent = new GeminiChatAgent(
+ name: "gemini",
+ model: "gemini-1.5-flash-001",
+ apiKey: apiKey)
+ .RegisterMessageConnector();
+
+ await CoderReviewerRunnerTestAsync(geminiAgent);
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task Claude3_Haiku_CoderReviewerRunnerTestAsync()
+ {
+ var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? throw new Exception("Please set ANTHROPIC_API_KEY environment variable.");
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
+
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant that convert user message to upper case")
+ .RegisterMessageConnector();
+
+ await CoderReviewerRunnerTestAsync(agent);
+ }
+
+ [ApiKeyFact("MISTRAL_API_KEY")]
+ public async Task Mistral_7b_CoderReviewerRunnerTestAsync()
+ {
+ var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new InvalidOperationException("MISTRAL_API_KEY is not set.");
+ var client = new MistralClient(apiKey: apiKey);
+
+ var agent = new MistralClientAgent(
+ client: client,
+ name: "MistralClientAgent",
+ model: "open-mistral-7b")
+ .RegisterMessageConnector();
+
+ await CoderReviewerRunnerTestAsync(agent);
+ }
+
+ /// <summary>
+ /// This test mimics the conversation among coder, reviewer and runner:
+ /// the coder writes the code, the reviewer reviews it, and the runner runs it.
+ /// </summary>
+ /// <param name="admin"></param>
+ /// <returns></returns>
+ public async Task CoderReviewerRunnerTestAsync(IAgent admin)
+ {
+ var coder = new EchoAgent("Coder");
+ var reviewer = new EchoAgent("Reviewer");
+ var runner = new EchoAgent("Runner");
+ var user = new EchoAgent("User");
+ var initializeMessage = new List<IMessage>
+ {
+ new TextMessage(Role.User, "Hello, I am user, I will provide the coding task, please write the code first, then review and run it", from: "User"),
+ new TextMessage(Role.User, "Hello, I am coder, I will write the code", from: "Coder"),
+ new TextMessage(Role.User, "Hello, I am reviewer, I will review the code", from: "Reviewer"),
+ new TextMessage(Role.User, "Hello, I am runner, I will run the code", from: "Runner"),
+ new TextMessage(Role.User, "how to print 'hello world' using C#", from: user.Name),
+ };
+
+ var chatHistory = new List<IMessage>()
+ {
+ new TextMessage(Role.User, """
+ ```csharp
+ Console.WriteLine("Hello World");
+ ```
+ """, from: coder.Name),
+ new TextMessage(Role.User, "The code looks good", from: reviewer.Name),
+ new TextMessage(Role.User, "The code runs successfully, the output is 'Hello World'", from: runner.Name),
+ };
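+ // the loop below expects the speaking order to follow chatHistory: Coder, then Reviewer, then Runner, with User closing the round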
+
+ var orchestrator = new RolePlayOrchestrator(admin);
+ foreach (var message in chatHistory)
+ {
+ var context = new OrchestrationContext
+ {
+ Candidates = [coder, reviewer, runner, user],
+ ChatHistory = initializeMessage,
+ };
+
+ var speaker = await orchestrator.GetNextSpeakerAsync(context);
+ speaker!.Name.Should().Be(message.From);
+ initializeMessage.Add(message);
+ }
+
+ // the last next speaker should be the user
+ var lastSpeaker = await orchestrator.GetNextSpeakerAsync(new OrchestrationContext
+ {
+ Candidates = [coder, reviewer, runner, user],
+ ChatHistory = initializeMessage,
+ });
+
+ lastSpeaker!.Name.Should().Be(user.Name);
+ }
+}
diff --git a/dotnet/test/AutoGen.Tests/Orchestrator/RoundRobinOrchestratorTests.cs b/dotnet/test/AutoGen.Tests/Orchestrator/RoundRobinOrchestratorTests.cs
new file mode 100644
index 000000000000..e14bf85cf215
--- /dev/null
+++ b/dotnet/test/AutoGen.Tests/Orchestrator/RoundRobinOrchestratorTests.cs
@@ -0,0 +1,103 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// RoundRobinOrchestratorTests.cs
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using FluentAssertions;
+using Xunit;
+
+namespace AutoGen.Tests;
+
+public class RoundRobinOrchestratorTests
+{
+ [Fact]
+ public async Task ItReturnNextAgentAsync()
+ {
+ var orchestrator = new RoundRobinOrchestrator();
+ var context = new OrchestrationContext
+ {
Candidates = new List<IAgent>