.Net Agents - ChatCompletionAgent Pre-Graduation (microsoft#7900)
### Motivation and Context
Update the public signatures / interface of `ChatCompletionAgent` in preparation for removing the experimental tag.

Fixes: microsoft#6037 
Fixes: microsoft#6523


### Description

- Include optional `Kernel` and `KernelArguments` parameters in the agent invocation signatures for `ChatCompletionAgent` (see the caller-side sketch after this list)
- Rename `ChatCompletionAgent.ExecutionSettings` to `ChatCompletionAgent.Arguments` (typed as `KernelArguments`)
- Mirror the same pattern for `OpenAIAssistantAgent`
- Integrate `Kernel` and `KernelArguments.ExecutionSettings` as invocation-time overrides of the agent's defaults
- Add a new sample that explores the entire selection matrix: `ChatCompletion_ServiceSelection`
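
For reviewers, a minimal before/after sketch of how an agent definition changes under this PR; the kernel construction and instruction text below are placeholders for illustration, not code from the change set:

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Connectors.OpenAI;

// Placeholder kernel; a real application would register an IChatCompletionService here.
Kernel kernel = Kernel.CreateBuilder().Build();

// Before: execution settings were a direct property on the agent.
//   ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },

// After: execution settings travel inside KernelArguments via the new Arguments property.
ChatCompletionAgent agent = new()
{
    Instructions = "Answer questions about the menu.",
    Kernel = kernel,
    Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }),
};
```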

### Contribution Checklist

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
crickman authored Aug 8, 2024
1 parent f71d326 commit bdf15a8
Showing 15 changed files with 218 additions and 27 deletions.
@@ -23,7 +23,7 @@ public async Task UseAutoFunctionInvocationFilterWithAgentInvocationAsync()
{
Instructions = "Answer questions about the menu.",
Kernel = CreateKernelWithFilter(),
ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }),
};

KernelPlugin plugin = KernelPluginFactory.CreateFromType<MenuPlugin>();
@@ -76,7 +76,7 @@ public async Task UseAutoFunctionInvocationFilterWithAgentChatAsync()
{
Instructions = "Answer questions about the menu.",
Kernel = CreateKernelWithFilter(),
ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }),
};

KernelPlugin plugin = KernelPluginFactory.CreateFromType<MenuPlugin>();
128 changes: 128 additions & 0 deletions dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs
@@ -0,0 +1,128 @@
// Copyright (c) Microsoft. All rights reserved.
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;

namespace Agents;

/// <summary>
/// Demonstrate service selection for <see cref="ChatCompletionAgent"/> by setting a service-id
/// on <see cref="ChatCompletionAgent.Arguments"/> and also by providing override <see cref="KernelArguments"/>
/// when calling <see cref="ChatCompletionAgent.InvokeAsync"/>.
/// </summary>
public class ChatCompletion_ServiceSelection(ITestOutputHelper output) : BaseTest(output)
{
private const string ServiceKeyGood = "chat-good";
private const string ServiceKeyBad = "chat-bad";

[Fact]
public async Task UseServiceSelectionWithChatCompletionAgentAsync()
{
// Create kernel with two instances of IChatCompletionService
// One service is configured with a valid API key and the other with an
// invalid key that will result in a 401 Unauthorized error.
Kernel kernel = CreateKernelWithTwoServices();

// Define the agent targeting ServiceId = ServiceKeyGood
ChatCompletionAgent agentGood =
new()
{
Kernel = kernel,
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { ServiceId = ServiceKeyGood }),
};

// Define the agent targeting ServiceId = ServiceKeyBad
ChatCompletionAgent agentBad =
new()
{
Kernel = kernel,
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { ServiceId = ServiceKeyBad }),
};

// Define the agent with no explicit ServiceId defined
ChatCompletionAgent agentDefault = new() { Kernel = kernel };

// Invoke agent as initialized with ServiceId = ServiceKeyGood: Expect agent response
Console.WriteLine("\n[Agent With Good ServiceId]");
await InvokeAgentAsync(agentGood);

// Invoke agent as initialized with ServiceId = ServiceKeyBad: Expect failure due to invalid service key
Console.WriteLine("\n[Agent With Bad ServiceId]");
await InvokeAgentAsync(agentBad);

// Invoke agent as initialized with no explicit ServiceId: Expect agent response
Console.WriteLine("\n[Agent With No ServiceId]");
await InvokeAgentAsync(agentDefault);

// Invoke agent with override arguments where ServiceId = ServiceKeyGood: Expect agent response
Console.WriteLine("\n[Bad Agent: Good ServiceId Override]");
await InvokeAgentAsync(agentBad, new(new OpenAIPromptExecutionSettings() { ServiceId = ServiceKeyGood }));

// Invoke agent with override arguments where ServiceId = ServiceKeyBad: Expect failure due to invalid service key
Console.WriteLine("\n[Good Agent: Bad ServiceId Override]");
await InvokeAgentAsync(agentGood, new(new OpenAIPromptExecutionSettings() { ServiceId = ServiceKeyBad }));
Console.WriteLine("\n[Default Agent: Bad ServiceId Override]");
await InvokeAgentAsync(agentDefault, new(new OpenAIPromptExecutionSettings() { ServiceId = ServiceKeyBad }));

// Invoke agent with override arguments with no explicit ServiceId: Expect agent response
Console.WriteLine("\n[Good Agent: No ServiceId Override]");
await InvokeAgentAsync(agentGood, new(new OpenAIPromptExecutionSettings()));
Console.WriteLine("\n[Bad Agent: No ServiceId Override]");
await InvokeAgentAsync(agentBad, new(new OpenAIPromptExecutionSettings()));
Console.WriteLine("\n[Default Agent: No ServiceId Override]");
await InvokeAgentAsync(agentDefault, new(new OpenAIPromptExecutionSettings()));

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(ChatCompletionAgent agent, KernelArguments? arguments = null)
{
ChatHistory chat = [new(AuthorRole.User, "Hello")];

try
{
await foreach (ChatMessageContent response in agent.InvokeAsync(chat, arguments))
{
Console.WriteLine(response.Content);
}
}
catch (HttpOperationException exception)
{
Console.WriteLine($"Status: {exception.StatusCode}");
}
}
}

private Kernel CreateKernelWithTwoServices()
{
IKernelBuilder builder = Kernel.CreateBuilder();

if (this.UseOpenAIConfig)
{
builder.AddOpenAIChatCompletion(
TestConfiguration.OpenAI.ChatModelId,
"bad-key",
serviceId: ServiceKeyBad);

builder.AddOpenAIChatCompletion(
TestConfiguration.OpenAI.ChatModelId,
TestConfiguration.OpenAI.ApiKey,
serviceId: ServiceKeyGood);
}
else
{
builder.AddAzureOpenAIChatCompletion(
TestConfiguration.AzureOpenAI.ChatDeploymentName,
TestConfiguration.AzureOpenAI.Endpoint,
"bad-key",
serviceId: ServiceKeyBad);

builder.AddAzureOpenAIChatCompletion(
TestConfiguration.AzureOpenAI.ChatDeploymentName,
TestConfiguration.AzureOpenAI.Endpoint,
TestConfiguration.AzureOpenAI.ApiKey,
serviceId: ServiceKeyGood);
}

return builder.Build();
}
}
2 changes: 1 addition & 1 deletion dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs
@@ -49,7 +49,7 @@ public async Task UseStreamingChatCompletionAgentWithPluginAsync()
Name = "Host",
Instructions = MenuInstructions,
Kernel = this.CreateKernelWithChatCompletion(),
ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }),
};

// Initialize plugin and add to the agent's Kernel (same as direct Kernel usage).
2 changes: 1 addition & 1 deletion dotnet/samples/GettingStartedWithAgents/Step2_Plugins.cs
@@ -26,7 +26,7 @@ public async Task UseChatCompletionWithPluginAgentAsync()
Instructions = HostInstructions,
Name = HostName,
Kernel = this.CreateKernelWithChatCompletion(),
ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }),
};

// Initialize plugin and add to the agent's Kernel (same as direct Kernel usage).
2 changes: 1 addition & 1 deletion dotnet/src/Agents/Abstractions/KernelAgent.cs
@@ -17,5 +17,5 @@ public abstract class KernelAgent : Agent
/// <remarks>
/// Defaults to empty Kernel, but may be overridden.
/// </remarks>
public Kernel Kernel { get; init; } = new Kernel();
public Kernel Kernel { get; init; } = new();
}
44 changes: 33 additions & 11 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
@@ -7,6 +7,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Services;

namespace Microsoft.SemanticKernel.Agents;

@@ -15,24 +16,29 @@ namespace Microsoft.SemanticKernel.Agents;
/// </summary>
/// <remarks>
/// NOTE: Enable OpenAIPromptExecutionSettings.ToolCallBehavior for agent plugins.
/// (<see cref="ChatCompletionAgent.ExecutionSettings"/>)
/// (<see cref="ChatCompletionAgent.Arguments"/>)
/// </remarks>
public sealed class ChatCompletionAgent : KernelAgent, IChatHistoryHandler
{
/// <summary>
/// Optional execution settings for the agent.
/// Optional arguments for the agent.
/// </summary>
public PromptExecutionSettings? ExecutionSettings { get; set; }
public KernelArguments? Arguments { get; init; }

/// <inheritdoc/>
public IChatHistoryReducer? HistoryReducer { get; init; }

/// <inheritdoc/>
public async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IChatCompletionService chatCompletionService = this.Kernel.GetRequiredService<IChatCompletionService>();
kernel ??= this.Kernel;
arguments ??= this.Arguments;

(IChatCompletionService chatCompletionService, PromptExecutionSettings? executionSettings) = this.GetChatCompletionService(kernel, arguments);

ChatHistory chat = this.SetupAgentChatHistory(history);

@@ -43,8 +49,8 @@ public async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> messages =
await chatCompletionService.GetChatMessageContentsAsync(
chat,
this.ExecutionSettings,
this.Kernel,
executionSettings,
kernel,
cancellationToken).ConfigureAwait(false);

this.Logger.LogAgentChatServiceInvokedAgent(nameof(InvokeAsync), this.Id, chatCompletionService.GetType(), messages.Count);
@@ -61,7 +67,6 @@ await chatCompletionService.GetChatMessageContentsAsync(

foreach (ChatMessageContent message in messages ?? [])
{
// TODO: MESSAGE SOURCE - ISSUE #5731
message.AuthorName = this.Name;

yield return message;
@@ -71,9 +76,14 @@ await chatCompletionService.GetChatMessageContentsAsync(
/// <inheritdoc/>
public async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IChatCompletionService chatCompletionService = this.Kernel.GetRequiredService<IChatCompletionService>();
kernel ??= this.Kernel;
arguments ??= this.Arguments;

(IChatCompletionService chatCompletionService, PromptExecutionSettings? executionSettings) = this.GetChatCompletionService(kernel, arguments);

ChatHistory chat = this.SetupAgentChatHistory(history);

@@ -84,15 +94,14 @@ public async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
IAsyncEnumerable<StreamingChatMessageContent> messages =
chatCompletionService.GetStreamingChatMessageContentsAsync(
chat,
this.ExecutionSettings,
this.Kernel,
executionSettings,
kernel,
cancellationToken);

this.Logger.LogAgentChatServiceInvokedStreamingAgent(nameof(InvokeAsync), this.Id, chatCompletionService.GetType());

await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false))
{
// TODO: MESSAGE SOURCE - ISSUE #5731
message.AuthorName = this.Name;

yield return message;
@@ -136,6 +145,19 @@ protected override Task<AgentChannel> CreateChannelAsync(CancellationToken cance
return Task.FromResult<AgentChannel>(channel);
}

private (IChatCompletionService service, PromptExecutionSettings? executionSettings) GetChatCompletionService(Kernel kernel, KernelArguments? arguments)
{
// Need to provide a KernelFunction to the service selector as a container for the execution-settings.
KernelFunction nullPrompt = KernelFunctionFactory.CreateFromPrompt("placeholder", arguments?.ExecutionSettings?.Values);
(IChatCompletionService chatCompletionService, PromptExecutionSettings? executionSettings) =
kernel.ServiceSelector.SelectAIService<IChatCompletionService>(
kernel,
nullPrompt,
arguments ?? []);

return (chatCompletionService, executionSettings);
}

private ChatHistory SetupAgentChatHistory(IReadOnlyList<ChatMessageContent> history)
{
ChatHistory chat = [];
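
As additional context for the invocation changes in the hunk above, a hedged caller-side sketch of the new optional `arguments` and `kernel` parameters; the kernel instances, service-id, and message here are placeholders, not part of this change set:

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;

// Placeholder kernels; in practice each would carry one or more IChatCompletionService registrations.
Kernel defaultKernel = Kernel.CreateBuilder().Build();
Kernel alternateKernel = Kernel.CreateBuilder().Build();

ChatCompletionAgent agent = new() { Kernel = defaultKernel };

ChatHistory history = [new(AuthorRole.User, "Hello")];

// Per-call overrides: execution settings ride inside KernelArguments and a different Kernel
// may be supplied; when either is null, InvokeAsync falls back to agent.Arguments / agent.Kernel.
KernelArguments overrides = new(new OpenAIPromptExecutionSettings() { ServiceId = "chat-good" });

await foreach (ChatMessageContent response in agent.InvokeAsync(history, overrides, alternateKernel))
{
    Console.WriteLine(response.Content);
}
```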
2 changes: 1 addition & 1 deletion dotnet/src/Agents/Core/ChatHistoryChannel.cs
@@ -38,7 +38,7 @@ public sealed class ChatHistoryChannel : AgentChannel
Queue<ChatMessageContent> messageQueue = [];

ChatMessageContent? yieldMessage = null;
await foreach (ChatMessageContent responseMessage in historyHandler.InvokeAsync(this._history, cancellationToken).ConfigureAwait(false))
await foreach (ChatMessageContent responseMessage in historyHandler.InvokeAsync(this._history, null, null, cancellationToken).ConfigureAwait(false))
{
// Capture all messages that have been included in the mutated history.
for (int messageIndex = messageCount; messageIndex < this._history.Count; messageIndex++)
8 changes: 8 additions & 0 deletions dotnet/src/Agents/Core/IChatHistoryHandler.cs
@@ -20,19 +20,27 @@ public interface IChatHistoryHandler
/// Entry point for calling into an agent from a <see cref="ChatHistoryChannel"/>.
/// </summary>
/// <param name="history">The chat history at the point the channel is created.</param>
/// <param name="arguments">Optional arguments to pass to the agent's invocation, including any <see cref="PromptExecutionSettings"/>.</param>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use by the agent.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Asynchronous enumeration of messages.</returns>
IAsyncEnumerable<ChatMessageContent> InvokeAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
CancellationToken cancellationToken = default);

/// <summary>
/// Entry point for calling into an agent from a <see cref="ChatHistoryChannel"/> for streaming content.
/// </summary>
/// <param name="history">The chat history at the point the channel is created.</param>
/// <param name="arguments">Optional arguments to pass to the agent's invocation, including any <see cref="PromptExecutionSettings"/>.</param>
/// <param name="kernel">Optional <see cref="Kernel"/> override containing services, plugins, and other state for use by the agent.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Asynchronous enumeration of streaming content.</returns>
IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
CancellationToken cancellationToken = default);
}
9 changes: 8 additions & 1 deletion dotnet/src/Agents/OpenAI/AssistantThreadActions.cs
@@ -109,22 +109,29 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
/// <param name="threadId">The thread identifier</param>
/// <param name="pollingConfiguration">Config to utilize when polling for run state.</param>
/// <param name="logger">The logger to utilize (might be agent or channel scoped)</param>
/// <param name="kernel">The <see cref="Kernel"/> containing plugins and other state for use by the agent.</param>
/// <param name="arguments">Optional arguments to pass to the agent's invocation, including any <see cref="PromptExecutionSettings"/>.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Asynchronous enumeration of messages.</returns>
/// <remarks>
/// The `arguments` parameter is not currently used by the agent, but is provided for future extensibility.
/// </remarks>
public static async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
OpenAIAssistantAgent agent,
AssistantsClient client,
string threadId,
OpenAIAssistantConfiguration.PollingConfiguration pollingConfiguration,
ILogger logger,
Kernel kernel,
KernelArguments? arguments,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
if (agent.IsDeleted)
{
throw new KernelException($"Agent Failure - {nameof(OpenAIAssistantAgent)} agent is deleted: {agent.Id}.");
}

ToolDefinition[]? tools = [.. agent.Tools, .. agent.Kernel.Plugins.SelectMany(p => p.Select(f => f.ToToolDefinition(p.Name, FunctionDelimiter)))];
ToolDefinition[]? tools = [.. agent.Tools, .. kernel.Plugins.SelectMany(p => p.Select(f => f.ToToolDefinition(p.Name, FunctionDelimiter)))];

logger.LogOpenAIAssistantCreatingRun(nameof(InvokeAsync), threadId);
