diff --git a/docs/release-notes/breaking-changes.md b/docs/release-notes/breaking-changes.md index 42ede74678..004870a757 100644 --- a/docs/release-notes/breaking-changes.md +++ b/docs/release-notes/breaking-changes.md @@ -3,6 +3,50 @@ > [!NOTE] > This section is for changes that are not yet released but will affect future releases. +## Starting with 0.9.1-rc117 + +### Agent configuration changes + +```json +"text_rewrite_settings": { + "user_prompt_rewrite_enabled" : true, + "user_prompt_rewrite_settings": { + "user_prompt_rewrite_ai_model_object_id": "/instances/73fad442-f614-4510-811f-414cb3a3d34b/providers/FoundationaLLM.AIModel/aiModels/GPT4oCompletionAIModel", + "user_prompt_rewrite_prompt_object_id": "/instances/73fad442-f614-4510-811f-414cb3a3d34b/providers/FoundationaLLM.Prompt/prompts/FoundationaLLM-v2-Rewrite", + "user_prompts_window_size": 1 + } +}, +"cache_settings": { + "semantic_cache_enabled": true, + "semantic_cache_settings": { + "embedding_ai_model_object_id": "/instances/73fad442-f614-4510-811f-414cb3a3d34b/providers/FoundationaLLM.AIModel/aiModels/DefaultEmbeddingAIModel", + "embedding_dimensions": 2048, + "minimum_similarity_threshold": 0.975 + } +}, +``` + +### Semantic cache + +Enable vector search in the Cosmos DB database using the following CLI command: + +```cli +az cosmosdb update --resource-group --name --capabilities EnableNoSQLVectorSearch +``` + +Create the `CompletionsCache` container in the Cosmos DB database with the following properties: + +- **Container id**: `CompletionsCache` +- **Partition key**: `/operationId` +- **Container Vector Policy**: a policy with the following properties: + - **Path**: `/userPromptEmbedding` + - **Data type**: `float32` + - **Distance function**: `Cosine` + - **Dimensions**: 2048 + - **Index type**: `diskANN` (leave the default values) + +After the container is created, set the `Time to Live` property on the container to 300 seconds. + ## Starting with 0.9.1-rc105 ### Configuration changes diff --git a/src/dotnet/Common/Common.csproj b/src/dotnet/Common/Common.csproj index 811102afb3..d122a457d1 100644 --- a/src/dotnet/Common/Common.csproj +++ b/src/dotnet/Common/Common.csproj @@ -58,7 +58,7 @@ - + diff --git a/src/dotnet/Common/Constants/AzureCosmosDBContainers.cs b/src/dotnet/Common/Constants/AzureCosmosDBContainers.cs index 1374dd0244..22c4abff88 100644 --- a/src/dotnet/Common/Constants/AzureCosmosDBContainers.cs +++ b/src/dotnet/Common/Constants/AzureCosmosDBContainers.cs @@ -44,5 +44,10 @@ public static class AzureCosmosDBContainers /// Stores information about external resources (e.g., Azure OpenAI assistants threads and files). /// public const string ExternalResources = "ExternalResources"; + + /// + /// The vector store for cached completions (used by the semantic cache service). + /// + public const string CompletionsCache = "CompletionsCache"; } } diff --git a/src/dotnet/Common/Constants/Chat/MessageContentTypes.cs b/src/dotnet/Common/Constants/Chat/MessageContentTypes.cs deleted file mode 100644 index c11046d579..0000000000 --- a/src/dotnet/Common/Constants/Chat/MessageContentTypes.cs +++ /dev/null @@ -1,28 +0,0 @@ -namespace FoundationaLLM.Common.Constants.Chat -{ - /// - /// - /// - public static class MessageContentTypes - { - /// - /// Plaintext and formatted text, such as markdown. - /// - public const string Text = "text"; - - /// - /// Image file link. - /// - public const string Image = "image"; - - /// - /// General file link. - /// - public const string File = "file"; - - /// - /// HTML file link. - /// - public const string Html = "html"; - } -} diff --git a/src/dotnet/Common/Exceptions/SemanticCacheException.cs b/src/dotnet/Common/Exceptions/SemanticCacheException.cs new file mode 100644 index 0000000000..7dd9443a1e --- /dev/null +++ b/src/dotnet/Common/Exceptions/SemanticCacheException.cs @@ -0,0 +1,32 @@ +namespace FoundationaLLM.Common.Exceptions +{ + /// + /// Represents an error generated by the semantic cache. + /// + public class SemanticCacheException : Exception + { + /// + /// Initializes a new instance of the class with a default message. + /// + public SemanticCacheException() + { + } + + /// + /// Initializes a new instance of the class with its message set to . + /// + /// A string that describes the error. + public SemanticCacheException(string? message) : base(message) + { + } + + /// + /// Initializes a new instance of the class with its message set to . + /// + /// A string that describes the error. + /// The exception that is the cause of the current exception. + public SemanticCacheException(string? message, Exception? innerException) : base(message, innerException) + { + } + } +} diff --git a/src/dotnet/Common/Exceptions/UserPromptRewriteException.cs b/src/dotnet/Common/Exceptions/UserPromptRewriteException.cs new file mode 100644 index 0000000000..f20ee19669 --- /dev/null +++ b/src/dotnet/Common/Exceptions/UserPromptRewriteException.cs @@ -0,0 +1,32 @@ +namespace FoundationaLLM.Common.Exceptions +{ + /// + /// Represents an error generated by user prompt rewrite service. + /// + public class UserPromptRewriteException : Exception + { + /// + /// Initializes a new instance of the class with a default message. + /// + public UserPromptRewriteException() + { + } + + /// + /// Initializes a new instance of the class with its message set to . + /// + /// A string that describes the error. + public UserPromptRewriteException(string? message) : base(message) + { + } + + /// + /// Initializes a new instance of the class with its message set to . + /// + /// A string that describes the error. + /// The exception that is the cause of the current exception. + public UserPromptRewriteException(string? message, Exception? innerException) : base(message, innerException) + { + } + } +} diff --git a/src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs b/src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs index 651b3bdaae..738354f174 100644 --- a/src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs +++ b/src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs @@ -2,6 +2,7 @@ using FoundationaLLM.Common.Models.Configuration.Users; using FoundationaLLM.Common.Models.Conversation; using FoundationaLLM.Common.Models.Orchestration; +using FoundationaLLM.Common.Models.Orchestration.Response; using FoundationaLLM.Common.Models.ResourceProviders; using FoundationaLLM.Common.Models.ResourceProviders.Attachment; @@ -258,4 +259,32 @@ Task> PatchMultipleSessionsItemsInTransactionAsync( /// Cancellation token for async calls. /// Task DeleteAttachment(AttachmentReference attachment, CancellationToken cancellationToken = default); + + /// + /// Creates a new container for vector search. + /// + /// The name of the container to create. + /// The property path that contains the partition key. + /// The property path that contains the vectors. + /// The length of each vector (the number of dimensions used for embedding). + /// The cancellation token to signal the need to cancel the operation. + /// + Task CreateVectorSearchContainerAsync( + string containerName, + string partitionKeyPath, + string vectorProperyPath, + int vectorDimensions, + CancellationToken cancellationToken = default); + + /// + /// Gets the completion response for a given user prompt embedding using vector search and a minimum threshold for similarity. + /// + /// The name of the container holding the vector index. + /// The reference embedding used for the vector search. + /// The threshold used for the similarity score. + /// A that matches the search criteria. If no item in the vector index matches the criteria, returns . + Task GetCompletionResponseAsync( + string containerName, + ReadOnlyMemory userPromptEmbedding, + decimal minimumSimilarityScore); } diff --git a/src/dotnet/Common/Models/Conversation/Message.cs b/src/dotnet/Common/Models/Conversation/Message.cs index 4cd39d51ce..e4a7adba17 100644 --- a/src/dotnet/Common/Models/Conversation/Message.cs +++ b/src/dotnet/Common/Models/Conversation/Message.cs @@ -58,6 +58,11 @@ public record Message /// public string Text { get; set; } + /// + /// The optional rewrite of the text content of the message. + /// + public string? TextRewrite { get; set; } + /// /// The rating associated with the message, if any. /// diff --git a/src/dotnet/Common/Models/Orchestration/Request/CompletionRequestBase.cs b/src/dotnet/Common/Models/Orchestration/Request/CompletionRequestBase.cs index 297be97199..5dcc4a9332 100644 --- a/src/dotnet/Common/Models/Orchestration/Request/CompletionRequestBase.cs +++ b/src/dotnet/Common/Models/Orchestration/Request/CompletionRequestBase.cs @@ -9,31 +9,37 @@ namespace FoundationaLLM.Common.Models.Orchestration.Request public class CompletionRequestBase { /// - /// The Operation ID identifying the completion request. + /// Gets or sets the operation identifier of the completion request. /// [JsonPropertyName("operation_id")] public string? OperationId { get; set; } /// - /// Indicates whether this is a long-running operation. + /// Gets or sets a flag that indicates whether this is a long-running operation. /// [JsonPropertyName("long_running_operation")] public bool LongRunningOperation { get; set; } /// - /// The session ID. + /// Gets or sets the conversation identifier. /// [JsonPropertyName("session_id")] public string? SessionId { get; set; } /// - /// Represent the input or user prompt. + /// Gets or sets the user prompt. /// [JsonPropertyName("user_prompt")] public required string UserPrompt { get; set; } /// - /// The message history associated with the completion request. + /// Gets or sets the rewrite of the user prompt. + /// + [JsonPropertyName("user_prompt_rewrite")] + public string? UserPromptRewrite { get; set; } + + /// + /// Gets or sets the message history associated with the completion request. /// [JsonPropertyName("message_history")] public List? MessageHistory { get; set; } = []; diff --git a/src/dotnet/Common/Models/Orchestration/Response/CompletionResponseBase.cs b/src/dotnet/Common/Models/Orchestration/Response/CompletionResponseBase.cs index 9c21d43a84..f85ef424e9 100644 --- a/src/dotnet/Common/Models/Orchestration/Response/CompletionResponseBase.cs +++ b/src/dotnet/Common/Models/Orchestration/Response/CompletionResponseBase.cs @@ -43,6 +43,12 @@ public class CompletionResponseBase [JsonPropertyName("user_prompt")] public string UserPrompt { get; set; } + /// + /// The user prompt rewrite. + /// + [JsonPropertyName("user_prompt_rewrite")] + public string? UserPromptRewrite { get; set; } + /// /// The full prompt composed by the LLM. /// diff --git a/src/dotnet/Common/Models/ResourceProviders/Agent/AgentBase.cs b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentBase.cs index 7878bbb039..c6dc5b9a61 100644 --- a/src/dotnet/Common/Models/ResourceProviders/Agent/AgentBase.cs +++ b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentBase.cs @@ -22,17 +22,30 @@ public class AgentBase : ResourceBase /// [JsonPropertyName("sessions_enabled")] public bool SessionsEnabled { get; set; } + + /// + /// Gets or sets the agent's text rewrite settings. + /// + [JsonPropertyName("text_rewrite_settings")] + public AgentTextRewriteSettings? TextRewriteSettings { get; set; } + + /// + /// Gets or sets the agent's caching settings. + /// + [JsonPropertyName("cache_settings")] + public AgentCacheSettings? CacheSettings { get; set; } + /// /// The agent's conversation history configuration. /// [JsonPropertyName("conversation_history_settings")] - public ConversationHistorySettings? ConversationHistorySettings { get; set; } + public AgentConversationHistorySettings? ConversationHistorySettings { get; set; } + /// /// The agent's Gatekeeper configuration. /// [JsonPropertyName("gatekeeper_settings")] - public GatekeeperSettings? GatekeeperSettings { get; set; } - + public AgentGatekeeperSettings? GatekeeperSettings { get; set; } /// /// Settings for the orchestration service. @@ -103,41 +116,4 @@ public class AgentBase : ResourceBase public bool HasCapability(string capabilityName) => Capabilities?.Contains(capabilityName) ?? false; } - - /// - /// Agent conversation history settings. - /// - public class ConversationHistorySettings - { - /// - /// Indicates whether the conversation history is enabled. - /// - [JsonPropertyName("enabled")] - public bool Enabled { get; set; } - - /// - /// The maximum number of turns to store in the conversation history. - /// - [JsonPropertyName("max_history")] - public int MaxHistory { get; set; } - } - - /// - /// Agent Gatekeeper settings. - /// - public class GatekeeperSettings - { - /// - /// Indicates whether to abide by or override the system settings for the Gatekeeper. - /// - [JsonPropertyName("use_system_setting")] - public bool UseSystemSetting { get; set; } - - /// - /// If is false, provides Gatekeeper feature selection. - /// - [JsonPropertyName("options")] - public string[]? Options { get; set; } - } - } diff --git a/src/dotnet/Common/Models/ResourceProviders/Agent/AgentCacheSettings.cs b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentCacheSettings.cs new file mode 100644 index 0000000000..0e6b906843 --- /dev/null +++ b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentCacheSettings.cs @@ -0,0 +1,28 @@ +using System.Text.Json.Serialization; + +namespace FoundationaLLM.Common.Models.ResourceProviders.Agent +{ + /// + /// Provides agent-related caching settings. + /// + public class AgentCacheSettings + { + /// + /// Gets or sets a value indicating whether the agent's semantic cache is enabled. + /// + /// + /// When enabled, the agent's semantic cache settings are provided in . + /// + [JsonPropertyName("semantic_cache_enabled")] + public bool SemanticCacheEnabled { get; set; } = false; + + /// + /// Gets or sets the agent's semantic cache settings. + /// + /// + /// The values in this property are only valid when is . + /// + [JsonPropertyName("semantic_cache_settings")] + public AgentSemanticCacheSettings? SemanticCacheSettings { get; set; } + } +} diff --git a/src/dotnet/Common/Models/ResourceProviders/Agent/AgentConversationHistorySettings.cs b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentConversationHistorySettings.cs new file mode 100644 index 0000000000..1917aad4c9 --- /dev/null +++ b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentConversationHistorySettings.cs @@ -0,0 +1,22 @@ +using System.Text.Json.Serialization; + +namespace FoundationaLLM.Common.Models.ResourceProviders.Agent +{ + /// + /// Provides agent-related conversation history settings. + /// + public class AgentConversationHistorySettings + { + /// + /// Indicates whether the conversation history is enabled. + /// + [JsonPropertyName("enabled")] + public bool Enabled { get; set; } + + /// + /// The maximum number of turns to store in the conversation history. + /// + [JsonPropertyName("max_history")] + public int MaxHistory { get; set; } + } +} diff --git a/src/dotnet/Common/Models/ResourceProviders/Agent/AgentGatekeeperSettings.cs b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentGatekeeperSettings.cs new file mode 100644 index 0000000000..396b003ef2 --- /dev/null +++ b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentGatekeeperSettings.cs @@ -0,0 +1,22 @@ +using System.Text.Json.Serialization; + +namespace FoundationaLLM.Common.Models.ResourceProviders.Agent +{ + /// + /// Provides agent-related gatekeeper settings. + /// + public class AgentGatekeeperSettings + { + /// + /// Indicates whether to abide by or override the system settings for the Gatekeeper. + /// + [JsonPropertyName("use_system_setting")] + public bool UseSystemSetting { get; set; } + + /// + /// If is false, provides Gatekeeper feature selection. + /// + [JsonPropertyName("options")] + public string[]? Options { get; set; } + } +} diff --git a/src/dotnet/Common/Models/ResourceProviders/Agent/AgentSemanticCacheSettings.cs b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentSemanticCacheSettings.cs new file mode 100644 index 0000000000..3dcf65cad9 --- /dev/null +++ b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentSemanticCacheSettings.cs @@ -0,0 +1,32 @@ +using System.Text.Json.Serialization; + +namespace FoundationaLLM.Common.Models.ResourceProviders.Agent +{ + /// + /// Provides agent-related cache settings for the semantic cache. + /// + public class AgentSemanticCacheSettings + { + /// + /// Gets or sets the object identifier of the AI model to use for the embedding. + /// + [JsonPropertyName("embedding_ai_model_object_id")] + public required string EmbeddingAIModelObjectId { get; set; } + + /// + /// Gets or sets the number of dimensions to use for the embedding. + /// + [JsonPropertyName("embedding_dimensions")] + public int EmbeddingDimensions { get; set; } + + /// + /// Gets or sets the minimum similarity threshold for the semantic cache. + /// + /// + /// This value determines the minimum similarity between the current conversation context + /// and the context of the item in the cache for the item to be considered a match. + /// + [JsonPropertyName("minimum_similarity_threshold")] + public decimal MinimumSimilarityThreshold { get; set; } = 0.975m; + } +} diff --git a/src/dotnet/Common/Models/ResourceProviders/Agent/AgentTextRewriteSettings.cs b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentTextRewriteSettings.cs new file mode 100644 index 0000000000..59cdb964f2 --- /dev/null +++ b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentTextRewriteSettings.cs @@ -0,0 +1,28 @@ +using System.Text.Json.Serialization; + +namespace FoundationaLLM.Common.Models.ResourceProviders.Agent +{ + /// + /// Provides agent-related text rewrite settings. + /// + public class AgentTextRewriteSettings + { + /// + /// Gets or sets a value indicating whether user prompt rewrite is enabled for the agent. + /// + /// + /// When enabled, the agent's semantic cache settings are provided in . + /// + [JsonPropertyName("user_prompt_rewrite_enabled")] + public bool UserPromptRewriteEnabled { get; set; } = false; + + /// + /// Gets or sets the agent's semantic cache settings. + /// + /// + /// The values in this property are only valid when is . + /// + [JsonPropertyName("user_prompt_rewrite_settings")] + public AgentUserPromptRewriteSettings? UserPromptRewriteSettings { get; set; } + } +} diff --git a/src/dotnet/Common/Models/ResourceProviders/Agent/AgentUserPromptRewriteSettings.cs b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentUserPromptRewriteSettings.cs new file mode 100644 index 0000000000..714572bd69 --- /dev/null +++ b/src/dotnet/Common/Models/ResourceProviders/Agent/AgentUserPromptRewriteSettings.cs @@ -0,0 +1,28 @@ +using System.Text.Json.Serialization; + +namespace FoundationaLLM.Common.Models.ResourceProviders.Agent +{ + /// + /// Provides agent-related user prompt rewrite settings. + /// + public class AgentUserPromptRewriteSettings + { + /// + /// Gets or sets the object identifier of the AI model to use for the user prompt rewriting. + /// + [JsonPropertyName("user_prompt_rewrite_ai_model_object_id")] + public required string UserPromptRewriteAIModelObjectId { get; set; } + + /// + /// Gets or sets the object identifier of the prompt to use for the user prompt rewriting. + /// + [JsonPropertyName("user_prompt_rewrite_prompt_object_id")] + public required string UserPromptRewritePromptObjectId { get; set; } + + /// + /// Gets or sets the window size for the user prompt rewriting. + /// + [JsonPropertyName("user_prompts_window_size")] + public required int UserPromptsWindowSize { get; set; } + } +} diff --git a/src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs b/src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs index 409086bb0b..5117e2e199 100644 --- a/src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs +++ b/src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs @@ -5,17 +5,21 @@ using FoundationaLLM.Common.Models.Configuration.Users; using FoundationaLLM.Common.Models.Conversation; using FoundationaLLM.Common.Models.Orchestration; +using FoundationaLLM.Common.Models.Orchestration.Response; using FoundationaLLM.Common.Models.ResourceProviders; using FoundationaLLM.Common.Models.ResourceProviders.Attachment; using Microsoft.Azure.Cosmos; +using Microsoft.Azure.Cosmos.Serialization.HybridRow; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Polly; using Polly.Retry; +using System.Collections.ObjectModel; using System.Diagnostics; using System.Net; +using System.Text.Json; -namespace FoundationaLLM.Common.Services +namespace FoundationaLLM.Common.Services.Azure { /// /// Service to access Azure Cosmos DB for NoSQL. @@ -27,6 +31,7 @@ public class AzureCosmosDBService : IAzureCosmosDBService private Container _operations; private Container _attachments; private Container _externalResources; + private Container _completionsCache; private readonly Lazy> _userProfiles; private Task _userProfilesTask => _userProfiles.Value; private readonly Database _database; @@ -109,11 +114,17 @@ public AzureCosmosDBService( throw new ArgumentException( $"Unable to connect to existing Azure Cosmos DB container ({AzureCosmosDBContainers.ExternalResources})."); + + _completionsCache = database?.GetContainer(AzureCosmosDBContainers.CompletionsCache) ?? + throw new ArgumentException( + $"Unable to connect to existing Azure Cosmos DB container ({AzureCosmosDBContainers.CompletionsCache})."); + _containers[AzureCosmosDBContainers.Sessions] = _sessions; _containers[AzureCosmosDBContainers.UserSessions] = _userSessions; _containers[AzureCosmosDBContainers.Operations] = _operations; _containers[AzureCosmosDBContainers.Attachments] = _attachments; _containers[AzureCosmosDBContainers.ExternalResources] = _externalResources; + _containers[AzureCosmosDBContainers.CompletionsCache] = _completionsCache; _logger.LogInformation("Cosmos DB service initialized."); } @@ -299,7 +310,7 @@ public async Task PatchItemPropertiesAsync(string containerName, string pa public async Task> GetSessionMessagesAsync(string sessionId, string upn, CancellationToken cancellationToken = default) { var query = - new QueryDefinition($"SELECT * FROM c WHERE c.sessionId = @sessionId AND c.type = @type AND c.upn = @upn AND {SoftDeleteQueryRestriction}") + new QueryDefinition($"SELECT * FROM c WHERE c.sessionId = @sessionId AND c.type = @type AND c.upn = @upn AND {SoftDeleteQueryRestriction} ORDER BY c.timeStamp") .WithParameter("@sessionId", sessionId) .WithParameter("@type", nameof(Message)) .WithParameter("@upn", upn); @@ -588,5 +599,87 @@ await _attachments.UpsertItemAsync( partitionKey: partitionKey, cancellationToken: cancellationToken); } + + /// + public async Task CreateVectorSearchContainerAsync( + string containerName, + string partitionKeyPath, + string vectorPropertyPath, + int vectorDimensions, + CancellationToken cancellationToken = default) + { + var containerProperties = new ContainerProperties(containerName, partitionKeyPath) + { + VectorEmbeddingPolicy = new(new Collection( + [ + new Embedding() + { + Path = vectorPropertyPath, + DataType = VectorDataType.Float32, + DistanceFunction = DistanceFunction.Cosine, + Dimensions = vectorDimensions + } + ])), + + IndexingPolicy = new IndexingPolicy + { + VectorIndexes = + [ + new VectorIndexPath() + { + Path = vectorPropertyPath, + Type = VectorIndexType.DiskANN + } + ] + } + }; + containerProperties.IndexingPolicy.IncludedPaths.Add(new IncludedPath { Path = "/*" }); + containerProperties.IndexingPolicy.ExcludedPaths.Add(new ExcludedPath { Path = $"{vectorPropertyPath}/*" }); + + var containerResponse = await _database.CreateContainerIfNotExistsAsync( + containerProperties, + cancellationToken: cancellationToken); + if (containerResponse.Container != null) + _containers[containerName] = containerResponse.Container; + } + + /// + public async Task GetCompletionResponseAsync( + string containerName, + ReadOnlyMemory userPromptEmbedding, + decimal minimumSimilarityScore) + { + var query = new QueryDefinition(""" + SELECT TOP 1 + x.serializedItem, x.similarityScore + FROM + ( + SELECT c.serializedItem, VectorDistance(c.userPromptEmbedding, @userPromptEmbedding) AS similarityScore FROM c + ) x + WHERE + x.similarityScore >= @minimumSimilarityScore + ORDER BY + x.similarityScore DESC + """); + query.WithParameter("@userPromptEmbedding", userPromptEmbedding.ToArray()); + query.WithParameter("@minimumSimilarityScore", (float)minimumSimilarityScore); + + using var feedIterator = _completionsCache.GetItemQueryIterator(query); + if (feedIterator.HasMoreResults) + { + var response = await feedIterator.ReadNextAsync(); + var result = response.Resource.FirstOrDefault(); + + if (result == null) + return null; + + var serializedCompletionResponse = (result as Newtonsoft.Json.Linq.JObject)!["serializedItem"]!.ToString(); + var completionResponse = JsonSerializer.Deserialize(serializedCompletionResponse); + + return completionResponse; + } + + return null; + } } } diff --git a/src/dotnet/Core/Core.csproj b/src/dotnet/Core/Core.csproj index 798a023462..bfbe23f824 100644 --- a/src/dotnet/Core/Core.csproj +++ b/src/dotnet/Core/Core.csproj @@ -13,7 +13,6 @@ - diff --git a/src/dotnet/CoreWorker/Program.cs b/src/dotnet/CoreWorker/Program.cs index 9c01fb2c03..e9ae89e5dd 100644 --- a/src/dotnet/CoreWorker/Program.cs +++ b/src/dotnet/CoreWorker/Program.cs @@ -3,7 +3,7 @@ using FoundationaLLM.Common.Constants.Configuration; using FoundationaLLM.Common.Interfaces; using FoundationaLLM.Common.Models.Configuration.CosmosDB; -using FoundationaLLM.Common.Services; +using FoundationaLLM.Common.Services.Azure; using FoundationaLLM.Core.Interfaces; using FoundationaLLM.Core.Services; using FoundationaLLM.Core.Worker; diff --git a/src/dotnet/Orchestration/Interfaces/ISemanticCacheService.cs b/src/dotnet/Orchestration/Interfaces/ISemanticCacheService.cs new file mode 100644 index 0000000000..4fa8fc305c --- /dev/null +++ b/src/dotnet/Orchestration/Interfaces/ISemanticCacheService.cs @@ -0,0 +1,66 @@ +using FoundationaLLM.Common.Models.Authentication; +using FoundationaLLM.Common.Models.Orchestration.Request; +using FoundationaLLM.Common.Models.Orchestration.Response; +using FoundationaLLM.Common.Models.ResourceProviders.Agent; +using FoundationaLLM.Orchestration.Models; + +namespace FoundationaLLM.Orchestration.Core.Interfaces +{ + /// + /// Defines the interface for the semantic cache service. + /// + public interface ISemanticCacheService + { + /// + /// Determines whether the semantic cache for the specified agent in the specified FoundationaLLM instance exists. + /// + /// The unique identifier of the FoundationaLLM instance. + /// The name of the agent. + /// if the semantic cache for the specified agent exists, otherwise. + bool HasCacheForAgent(string instanceId, string agentName); + + /// + /// Initializes the semantic cache for the specified agent in the specified FoundationaLLM instance. + /// + /// The unique identifier of the FoundationaLLM instance. + /// The name of the agent. + /// The providing the agent's semantic cache settings. + /// + Task InitializeCacheForAgent( + string instanceId, + string agentName, + AgentSemanticCacheSettings agentSettings); + + /// + /// Resets the semantic cache for the specified agent in the specified FoundationaLLM instance. + /// + /// The unique identifier of the FoundationaLLM instance. + /// The name of the agent. + /// + Task ResetCacheForAgent(string instanceId, string agentName); + + /// + /// Tries to get a from the semantic cache for the specified agent in the specified FoundationaLLM instance. + /// + /// The unique identifier of the FoundationaLLM instance. + /// The name of the agent. + /// The for which to get the cache item. + /// A if a match exists. + Task GetCompletionResponseFromCache( + string instanceId, + string agentName, + CompletionRequest completionRequest); + + /// + /// Sets a in the semantic cache for the specified agent in the specified FoundationaLLM instance. + /// + /// The unique identifier of the FoundationaLLM instance. + /// The name of the agent. + /// The to be set in the agent's cache. + /// + Task SetCompletionResponseInCache( + string instanceId, + string agentName, + CompletionResponse completionResponse); + } +} diff --git a/src/dotnet/Orchestration/Interfaces/IUserPromptRewriteService.cs b/src/dotnet/Orchestration/Interfaces/IUserPromptRewriteService.cs new file mode 100644 index 0000000000..d802a48e7f --- /dev/null +++ b/src/dotnet/Orchestration/Interfaces/IUserPromptRewriteService.cs @@ -0,0 +1,43 @@ +using FoundationaLLM.Common.Models.Orchestration.Request; +using FoundationaLLM.Common.Models.ResourceProviders.Agent; + +namespace FoundationaLLM.Orchestration.Core.Interfaces +{ + /// + /// Defines the interface for the user prompt rewrite service. + /// + public interface IUserPromptRewriteService + { + /// + /// Determines whether user prompt rewrite is configured for the specified agent in the specified FoundationaLLM instance. + /// + /// The unique identifier of the FoundationaLLM instance. + /// The name of the agent. + /// if the user prompt rewrite for the specified agent is configured, otherwise. + bool HasUserPromptRewriterForAgent(string instanceId, string agentName); + + /// + /// Initializes user prompt rewrite for the specified agent in the specified FoundationaLLM instance. + /// + /// The unique identifier of the FoundationaLLM instance. + /// The name of the agent. + /// The providing the agent's user prompt rewrite settings. + /// + Task InitializeUserPromptRewriterForAgent( + string instanceId, + string agentName, + AgentUserPromptRewriteSettings agentSettings); + + /// + /// Rewrites the user prompt to a form that can be used by the AI model. + /// + /// The unique identifier of the FoundationaLLM instance. + /// The name of the agent. + /// The for which to rewrite the user prompt. + /// The rewritten user prompt. + Task RewriteUserPrompt( + string instanceId, + string agentName, + CompletionRequest completionRequest); + } +} diff --git a/src/dotnet/Orchestration/Models/AgentSemanticCache.cs b/src/dotnet/Orchestration/Models/AgentSemanticCache.cs new file mode 100644 index 0000000000..c3c6ea3861 --- /dev/null +++ b/src/dotnet/Orchestration/Models/AgentSemanticCache.cs @@ -0,0 +1,22 @@ +using Azure.AI.OpenAI; +using FoundationaLLM.Common.Models.ResourceProviders.Agent; +using OpenAI.Embeddings; + +namespace FoundationaLLM.Orchestration.Core.Models +{ + /// + /// Provides all dependencies for an agent-related semantic cache. + /// + public class AgentSemanticCache + { + /// + /// Gets or sets the agent's semantic cache settings. + /// + public required AgentSemanticCacheSettings Settings { get; set; } + + /// + /// Gets or sets the Azure OpenAI client. + /// + public required EmbeddingClient EmbeddingClient { get; set; } + } +} diff --git a/src/dotnet/Orchestration/Models/AgentUserPromptRewriter.cs b/src/dotnet/Orchestration/Models/AgentUserPromptRewriter.cs new file mode 100644 index 0000000000..bed48aa94f --- /dev/null +++ b/src/dotnet/Orchestration/Models/AgentUserPromptRewriter.cs @@ -0,0 +1,26 @@ +using FoundationaLLM.Common.Models.ResourceProviders.Agent; +using OpenAI.Chat; + +namespace FoundationaLLM.Orchestration.Core.Models +{ + /// + /// Provide the capability to rewrite user prompts for agents. + /// + public class AgentUserPromptRewriter + { + /// + /// Gets or sets the agent's user prompt rewrite settings. + /// + public required AgentUserPromptRewriteSettings Settings { get; set; } + + /// + /// Gets or sets the system prompt to be used for rewriting user prompts. + /// + public required string RewriterSystemPrompt { get; set; } + + /// + /// Gets or sets the Azure OpenAI chat client used for rewriting. + /// + public required ChatClient ChatClient { get; set; } + } +} diff --git a/src/dotnet/Orchestration/Models/SemanticCacheItem.cs b/src/dotnet/Orchestration/Models/SemanticCacheItem.cs new file mode 100644 index 0000000000..155ea209e5 --- /dev/null +++ b/src/dotnet/Orchestration/Models/SemanticCacheItem.cs @@ -0,0 +1,22 @@ +using FoundationaLLM.Common.Models.Orchestration.Response; + +namespace FoundationaLLM.Orchestration.Models +{ + /// + /// Represents an item in the semantic cache. + /// + public class SemanticCacheItem + { + public string Id { get; set; } + + public string OperationId { get; set; } + + public string UserPrompt { get; set; } + + public int UserPromptTokens { get; set; } + + public float[] UserPromptEmbedding { get; set; } + + public string SerializedItem { get; set; } + } +} diff --git a/src/dotnet/Orchestration/Orchestration/KnowledgeManagementOrchestration.cs b/src/dotnet/Orchestration/Orchestration/AgentOrchestration.cs similarity index 79% rename from src/dotnet/Orchestration/Orchestration/KnowledgeManagementOrchestration.cs rename to src/dotnet/Orchestration/Orchestration/AgentOrchestration.cs index 844c2f3487..5154d525c7 100644 --- a/src/dotnet/Orchestration/Orchestration/KnowledgeManagementOrchestration.cs +++ b/src/dotnet/Orchestration/Orchestration/AgentOrchestration.cs @@ -14,6 +14,7 @@ using FoundationaLLM.Common.Models.ResourceProviders.Attachment; using FoundationaLLM.Common.Models.ResourceProviders.AzureOpenAI; using FoundationaLLM.Orchestration.Core.Interfaces; +using FoundationaLLM.Orchestration.Models; using Microsoft.Extensions.Logging; using System.Text.Json; using System.Text.RegularExpressions; @@ -33,13 +34,15 @@ namespace FoundationaLLM.Orchestration.Core.Orchestration /// A dictionary of objects retrieved from various object ids related to the agent. For more details see . /// The call context of the request being handled. /// + /// The used to rewrite user prompts. + /// The used to cache and retrieve completion responses. /// The logger used for logging. /// The used to create HttpClient instances. /// The dictionary of /// Inidicates that access was denied to all underlying data sources. /// The OpenAI Assistants vector store id. /// An optional observer for completion requests. - public class KnowledgeManagementOrchestration( + public class AgentOrchestration( string instanceId, string agentObjectId, KnowledgeManagementAgent? agent, @@ -47,6 +50,8 @@ public class KnowledgeManagementOrchestration( Dictionary? explodedObjects, ICallContext callContext, ILLMOrchestrationService orchestrationService, + IUserPromptRewriteService userPromptRewriteService, + ISemanticCacheService semanticCacheService, ILogger logger, IHttpClientFactoryService httpClientFactoryService, Dictionary resourceProviderServices, @@ -72,6 +77,9 @@ public class KnowledgeManagementOrchestration( private readonly string? _openAIVectorStoreId = openAIVectorStoreId; private GatewayServiceClient _gatewayClient; + private readonly IUserPromptRewriteService _userPromptRewriteService = userPromptRewriteService; + private readonly ISemanticCacheService _semanticCacheService = semanticCacheService; + /// public override async Task StartCompletionOperation(CompletionRequest completionRequest) { @@ -84,6 +92,20 @@ public override async Task StartCompletionOperation(Comple Result = validationResponse }; + if (_agent!.CacheSettings != null + && _agent!.CacheSettings.SemanticCacheEnabled) + { + await HandlePromptRewrite(completionRequest); + var cachedResponse = await GetCompletionResponseFromCache(completionRequest); + if (cachedResponse != null) + return new LongRunningOperation + { + OperationId = completionRequest.OperationId!, + Status = OperationStatus.Completed, + Result = cachedResponse + }; + } + var llmCompletionRequest = await GetLLMCompletionRequest(completionRequest); if (_completionRequestObserver != null) await _completionRequestObserver(llmCompletionRequest); @@ -106,9 +128,24 @@ public override async Task GetCompletionOperationStatus(st // parse the LLM Completion response from JsonElement if (operationStatus.Result is JsonElement jsonElement) { - var completionResponse = JsonSerializer.Deserialize(jsonElement.ToString()); - if (completionResponse != null) - operationStatus.Result = await GetCompletionResponse(operationId, completionResponse); + var llmCompletionResponse = JsonSerializer.Deserialize(jsonElement.ToString()); + if (llmCompletionResponse != null) + { + var completionResponse = await GetCompletionResponse(operationId, llmCompletionResponse); + + if (_agent!.CacheSettings != null + && _agent!.CacheSettings.SemanticCacheEnabled + && ( + completionResponse.Errors == null + || completionResponse.Errors.Length == 0 + )) + { + // This is a valid response that can be cached. + await SetCompletionResponseInCache(completionResponse); + } + + operationStatus.Result = completionResponse; + } } } @@ -122,15 +159,37 @@ public override async Task GetCompletion(CompletionRequest c if (validationResponse != null) return validationResponse; + if (_agent!.CacheSettings != null + && _agent!.CacheSettings.SemanticCacheEnabled) + { + await HandlePromptRewrite(completionRequest); + var cachedResponse = await GetCompletionResponseFromCache(completionRequest); + if (cachedResponse != null) + return cachedResponse; + } + var llmCompletionRequest = await GetLLMCompletionRequest(completionRequest); if (_completionRequestObserver != null) await _completionRequestObserver(llmCompletionRequest); - var result = await _orchestrationService.GetCompletion( + var llmCompletionResponse = await _orchestrationService.GetCompletion( _instanceId, llmCompletionRequest); - return await GetCompletionResponse(completionRequest.OperationId!, result); + var completionResponse = await GetCompletionResponse(completionRequest.OperationId!, llmCompletionResponse); + + if (_agent!.CacheSettings != null + && _agent!.CacheSettings.SemanticCacheEnabled + && ( + completionResponse.Errors == null + || completionResponse.Errors.Length == 0 + )) + { + // This is a valid response that can be cached. + await SetCompletionResponseInCache(completionResponse); + } + + return completionResponse; } private async Task ValidateCompletionRequest(CompletionRequest completionRequest) @@ -162,11 +221,74 @@ await _httpClientFactoryService return null; } + private async Task HandlePromptRewrite(CompletionRequest completionRequest) + { + if (_agent!.TextRewriteSettings != null + && _agent!.TextRewriteSettings.UserPromptRewriteEnabled) + { + if (!_userPromptRewriteService.HasUserPromptRewriterForAgent(_instanceId, _agent.Name)) + await _userPromptRewriteService.InitializeUserPromptRewriterForAgent( + _instanceId, + _agent.Name, + _agent.TextRewriteSettings.UserPromptRewriteSettings!); + + await _userPromptRewriteService.RewriteUserPrompt(instanceId, _agent.Name, completionRequest); + } + } + + private async Task GetCompletionResponseFromCache(CompletionRequest completionRequest) + { + if (!_semanticCacheService.HasCacheForAgent(_instanceId, _agent!.Name)) + await _semanticCacheService.InitializeCacheForAgent( + _instanceId, + _agent.Name, + _agent.CacheSettings!.SemanticCacheSettings!); + + var cachedResponse = await _semanticCacheService.GetCompletionResponseFromCache( + _instanceId, + _agent.Name, + completionRequest); + + if (cachedResponse == null) + return null; + + cachedResponse.OperationId = completionRequest.OperationId!; + var contentArtifactsList = new List(cachedResponse.ContentArtifacts ??= []) + { + new() { + Id = "CachedResponse", + Title = "Cached Response", + Source = "SemanticCache" + } + }; + cachedResponse.ContentArtifacts = [.. contentArtifactsList]; + + return cachedResponse; + } + + private async Task SetCompletionResponseInCache(CompletionResponse completionResponse) + { + try + { + await _semanticCacheService.SetCompletionResponseInCache(_instanceId, _agent!.Name, completionResponse); + } + catch (Exception ex) + { + _logger.LogError( + ex, + "An error occurred while setting the completion response in the semantic cache for operation {OperationId} and agent {AgentName} in instance {InstanceId}.", + completionResponse.OperationId, + _agent!.Name, + _instanceId); + } + } + private async Task GetLLMCompletionRequest(CompletionRequest completionRequest) => new LLMCompletionRequest { OperationId = completionRequest.OperationId, UserPrompt = completionRequest.UserPrompt!, + UserPromptRewrite = completionRequest.UserPromptRewrite, MessageHistory = completionRequest.MessageHistory, Attachments = await GetAttachmentPaths(completionRequest.Attachments), Agent = _agent!, @@ -273,6 +395,7 @@ private async Task GetCompletionResponse(string operationId, Completion = llmCompletionResponse.Completion, Content = llmCompletionResponse.Content != null ? await TransformContentItems(llmCompletionResponse.Content) : null, UserPrompt = llmCompletionResponse.UserPrompt!, + UserPromptRewrite = llmCompletionResponse.UserPromptRewrite, ContentArtifacts = llmCompletionResponse.ContentArtifacts, FullPrompt = llmCompletionResponse.FullPrompt, PromptTemplate = llmCompletionResponse.PromptTemplate, diff --git a/src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs b/src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs index 8dc5b12852..3fa1098d26 100644 --- a/src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs +++ b/src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs @@ -42,6 +42,8 @@ public class OrchestrationBuilder /// The used to interact with the Cosmos DB database. /// The used to render templates. /// The used to execute code. + /// The used to rewrite user prompts. + /// The used to cache and retrieve completion responses. /// The provding dependency injection services for the current scope. /// The logger factory used to create new loggers. /// An optional observer for completion requests. @@ -58,6 +60,8 @@ public class OrchestrationBuilder IAzureCosmosDBService cosmosDBService, ITemplatingService templatingService, ICodeExecutionService codeExecutionService, + IUserPromptRewriteService userPromptRewriteService, + ISemanticCacheService semanticCacheService, IServiceProvider serviceProvider, ILoggerFactory loggerFactory, Func? completionRequestObserver = null) @@ -108,7 +112,7 @@ await cosmosDBService.PatchOperationsItemPropertiesAsync(), serviceProvider.GetRequiredService(), resourceProviderServices, @@ -158,7 +164,7 @@ await cosmosDBService.PatchOperationsItemPropertiesAsync(), serviceProvider.GetRequiredService(), resourceProviderServices, diff --git a/src/dotnet/Orchestration/Services/DependencyInjection.cs b/src/dotnet/Orchestration/Services/DependencyInjection.cs index f50f28b1ed..6337b445ba 100644 --- a/src/dotnet/Orchestration/Services/DependencyInjection.cs +++ b/src/dotnet/Orchestration/Services/DependencyInjection.cs @@ -39,5 +39,19 @@ public static void AddLLMOrchestrationServices(this IHostApplicationBuilder buil builder.Services.AddScoped(); builder.Services.AddScoped(); } + + /// + /// Adds the semantic cache service to the dependency injection container. + /// + /// + public static void AddSemanticCacheService(this IHostApplicationBuilder builder) => + builder.Services.AddSingleton(); + + /// + /// Adds the user prompt rewrite service to the dependency injection container. + /// + /// + public static void AddUserPromptRewriteService(this IHostApplicationBuilder builder) => + builder.Services.AddSingleton(); } } diff --git a/src/dotnet/Orchestration/Services/LangChainService.cs b/src/dotnet/Orchestration/Services/LangChainService.cs index b74efabc38..3da226c710 100644 --- a/src/dotnet/Orchestration/Services/LangChainService.cs +++ b/src/dotnet/Orchestration/Services/LangChainService.cs @@ -75,6 +75,7 @@ public async Task GetCompletion(string instanceId, LLMCom Completion = completionResponse!.Completion, ContentArtifacts = completionResponse.ContentArtifacts, UserPrompt = completionResponse.UserPrompt, + UserPromptRewrite = completionResponse.UserPromptRewrite, FullPrompt = completionResponse.FullPrompt, PromptTemplate = string.Empty, AgentName = request.Agent.Name, @@ -99,7 +100,8 @@ public async Task GetCompletion(string instanceId, LLMCom PromptTemplate = string.Empty, AgentName = request.Agent.Name, PromptTokens = 0, - CompletionTokens = 0 + CompletionTokens = 0, + Errors = [ "A problem on my side prevented me from responding." ] }; } } diff --git a/src/dotnet/Orchestration/Services/OrchestrationService.cs b/src/dotnet/Orchestration/Services/OrchestrationService.cs index 41c0b18e45..9a0fbf8e85 100644 --- a/src/dotnet/Orchestration/Services/OrchestrationService.cs +++ b/src/dotnet/Orchestration/Services/OrchestrationService.cs @@ -10,7 +10,6 @@ using FoundationaLLM.Common.Models.Orchestration.Response; using FoundationaLLM.Common.Models.ResourceProviders.Agent; using FoundationaLLM.Common.Services.Storage; -using FoundationaLLM.Common.Services.Users; using FoundationaLLM.Orchestration.Core.Interfaces; using FoundationaLLM.Orchestration.Core.Models; using FoundationaLLM.Orchestration.Core.Models.ConfigurationOptions; @@ -34,6 +33,8 @@ public class OrchestrationService : IOrchestrationService private readonly ITemplatingService _templatingService; private readonly ICodeExecutionService _codeExecutionService; private readonly IUserProfileService _userProfileService; + private readonly IUserPromptRewriteService _userPromptRewriteService; + private readonly ISemanticCacheService _semanticCacheService; private readonly ICallContext _callContext; private readonly IConfiguration _configuration; private readonly ILogger _logger; @@ -54,6 +55,8 @@ public class OrchestrationService : IOrchestrationService /// The used to render templates. /// The used to execute code. /// The used to interact with user profiles. + /// The used to rewrite user prompts. + /// The used to cache and retrieve completion responses. /// The call context of the request being handled. /// The used to retrieve app settings from configuration. /// The provding dependency injection services for the current scope. @@ -66,6 +69,8 @@ public OrchestrationService( ITemplatingService templatingService, ICodeExecutionService codeExecutionService, IUserProfileService userProfileService, + IUserPromptRewriteService userPromptRewriteService, + ISemanticCacheService semanticCacheService, ICallContext callContext, IConfiguration configuration, IServiceProvider serviceProvider, @@ -80,6 +85,9 @@ public OrchestrationService( _codeExecutionService = codeExecutionService; _userProfileService = userProfileService; + _userPromptRewriteService = userPromptRewriteService; + _semanticCacheService = semanticCacheService; + _callContext = callContext; _configuration = configuration; _serviceProvider = serviceProvider; @@ -129,6 +137,8 @@ public async Task GetCompletion(string instanceId, Completio _cosmosDBService, _templatingService, _codeExecutionService, + _userPromptRewriteService, + _semanticCacheService, _serviceProvider, _loggerFactory, ObserveCompletionRequest) @@ -171,6 +181,8 @@ public async Task StartCompletionOperation(string instance _cosmosDBService, _templatingService, _codeExecutionService, + _userPromptRewriteService, + _semanticCacheService, _serviceProvider, _loggerFactory, ObserveCompletionRequest) @@ -287,6 +299,8 @@ private async Task GetCompletionForAgentConversation( _cosmosDBService, _templatingService, _codeExecutionService, + _userPromptRewriteService, + _semanticCacheService, _serviceProvider, _loggerFactory); diff --git a/src/dotnet/Orchestration/Services/SemanticCacheService.cs b/src/dotnet/Orchestration/Services/SemanticCacheService.cs new file mode 100644 index 0000000000..55f98d3a7c --- /dev/null +++ b/src/dotnet/Orchestration/Services/SemanticCacheService.cs @@ -0,0 +1,199 @@ +using Azure; +using Azure.AI.OpenAI; +using FoundationaLLM.Common.Authentication; +using FoundationaLLM.Common.Constants.Authentication; +using FoundationaLLM.Common.Constants.ResourceProviders; +using FoundationaLLM.Common.Exceptions; +using FoundationaLLM.Common.Interfaces; +using FoundationaLLM.Common.Models.Orchestration.Request; +using FoundationaLLM.Common.Models.Orchestration.Response; +using FoundationaLLM.Common.Models.ResourceProviders.Agent; +using FoundationaLLM.Common.Models.ResourceProviders.AIModel; +using FoundationaLLM.Common.Models.ResourceProviders.Configuration; +using FoundationaLLM.Orchestration.Core.Interfaces; +using FoundationaLLM.Orchestration.Core.Models; +using FoundationaLLM.Orchestration.Models; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using OpenAI.Embeddings; +using System.Text.Json; + +namespace FoundationaLLM.Orchestration.Core.Services +{ + /// + /// Provides a service for managing the semantic cache. + /// + public class SemanticCacheService : ISemanticCacheService + { + private const string SEMANTIC_CACHE_CONTAINER_NAME = "CompletionsCache"; + + private readonly Dictionary _agentCaches = []; + private readonly SemaphoreSlim _syncLock = new SemaphoreSlim(1, 1); + + private readonly IAzureCosmosDBService _cosmosDBService; + private readonly IResourceProviderService _aiModelResourceProviderService; + private readonly IResourceProviderService _configurationResourceProviderService; + private readonly IConfiguration _configuration; + private readonly ILogger _logger; + + /// + /// Initializes a new instance of the class. + /// + /// The service providing access to the Cosmos DB vector store. + /// A list of resource providers hashed by resource provider name. + /// The used to retrieve app settings from configuration. + /// The logger used for logging.. + public SemanticCacheService( + IAzureCosmosDBService cosmosDBService, + IEnumerable resourceProviderServices, + IConfiguration configuration, + ILogger logger) + { + _cosmosDBService = cosmosDBService; + _aiModelResourceProviderService = resourceProviderServices + .Single(x => x.Name == ResourceProviderNames.FoundationaLLM_AIModel); + _configurationResourceProviderService = resourceProviderServices + .Single(x => x.Name == ResourceProviderNames.FoundationaLLM_Configuration); + _configuration = configuration; + _logger = logger; + } + + /// + public bool HasCacheForAgent(string instanceId, string agentName) => + _agentCaches.ContainsKey($"{instanceId}|{agentName}"); + + /// + public async Task InitializeCacheForAgent( + string instanceId, + string agentName, + AgentSemanticCacheSettings agentSettings) + { + try + { + await _syncLock.WaitAsync(); + + if (_agentCaches.Count == 0) + { + // This is the first time an agent attempts to initialize the cache. + // Ensure the proper container exists in Cosmos DB. + + // For now we are skipping the dynamic creation of the container as it looks like + // there is a bug with the Cosmos DB client when working with RBAC. + + //await _cosmosDBService.CreateVectorSearchContainerAsync( + // SEMANTIC_CACHE_CONTAINER_NAME, + // "/partitionKey", + // "/userPromptEmbedding", + // agentSettings.EmbeddingDimensions); + } + + if (HasCacheForAgent(instanceId, agentName)) + { + _logger.LogWarning("Semantic cache for agent {AgentName} in instance {InstanceId} already exists.", agentName, instanceId); + return; + } + + var embeddingAIModel = await _aiModelResourceProviderService.GetResourceAsync( + agentSettings.EmbeddingAIModelObjectId, + DefaultAuthentication.ServiceIdentity!); + var embeddingAPIEndpointConfiguration = await _configurationResourceProviderService.GetResourceAsync( + embeddingAIModel.EndpointObjectId!, + DefaultAuthentication.ServiceIdentity!); + + _agentCaches[$"{instanceId}|{agentName}"] = new AgentSemanticCache + { + Settings = agentSettings, + EmbeddingClient = GetEmbeddingClient(embeddingAIModel.DeploymentName!, embeddingAPIEndpointConfiguration) + }; + } + finally + { + _syncLock.Release(); + } + } + + /// + public async Task ResetCacheForAgent(string instanceId, string agentName) => + await Task.CompletedTask; + + /// + public async Task SetCompletionResponseInCache(string instanceId, string agentName, CompletionResponse completionResponse) + { + if (!_agentCaches.TryGetValue($"{instanceId}|{agentName}", out AgentSemanticCache? agentCache) + || agentCache == null) + throw new SemanticCacheException($"The semantic cache is not initialized for agent {agentName} in instance {instanceId}."); + + var cacheItem = new SemanticCacheItem + { + Id = Guid.NewGuid().ToString().ToLower(), + OperationId = completionResponse.OperationId!, + UserPrompt = completionResponse.UserPromptRewrite!, + SerializedItem = JsonSerializer.Serialize(completionResponse), + }; + + var embeddingResponse = await agentCache.EmbeddingClient.GenerateEmbeddingAsync( + cacheItem.UserPrompt, + new EmbeddingGenerationOptions + { + Dimensions = agentCache.Settings.EmbeddingDimensions + }); + cacheItem.UserPromptEmbedding = embeddingResponse.Value.ToFloats().ToArray(); + + await _cosmosDBService.UpsertItemAsync( + SEMANTIC_CACHE_CONTAINER_NAME, + cacheItem.OperationId, + cacheItem); + } + + /// + public async Task GetCompletionResponseFromCache( + string instanceId, + string agentName, + CompletionRequest completionRequest) + { + if (!_agentCaches.TryGetValue($"{instanceId}|{agentName}", out AgentSemanticCache? agentCache) + || agentCache == null) + throw new SemanticCacheException($"The semantic cache is not initialized for agent {agentName} in instance {instanceId}."); + + var embeddingResult = await agentCache.EmbeddingClient.GenerateEmbeddingAsync( + completionRequest.UserPromptRewrite, + new EmbeddingGenerationOptions + { + Dimensions = agentCache.Settings.EmbeddingDimensions + }); + var userPromptEmbedding = embeddingResult.Value.ToFloats(); + + var cachedCompletionResponse = await _cosmosDBService.GetCompletionResponseAsync( + SEMANTIC_CACHE_CONTAINER_NAME, + userPromptEmbedding, + agentCache.Settings.MinimumSimilarityThreshold); + + return cachedCompletionResponse; + } + + private EmbeddingClient GetEmbeddingClient(string deploymentName, APIEndpointConfiguration apiEndpointConfiguration) => + apiEndpointConfiguration.AuthenticationType switch + { + AuthenticationTypes.AzureIdentity => (new AzureOpenAIClient( + new Uri(apiEndpointConfiguration.Url), + DefaultAuthentication.AzureCredential)) + .GetEmbeddingClient(deploymentName), + AuthenticationTypes.APIKey => (new AzureOpenAIClient( + new Uri(apiEndpointConfiguration.Url), + new AzureKeyCredential(GetAPIKey(apiEndpointConfiguration)))) + .GetEmbeddingClient(deploymentName), + _ => throw new NotImplementedException($"API endpoint authentication type {apiEndpointConfiguration.AuthenticationType} is not supported.") + }; + + private string GetAPIKey(APIEndpointConfiguration apiEndpointConfiguration) + { + if (!apiEndpointConfiguration.AuthenticationParameters.TryGetValue( + AuthenticationParametersKeys.APIKeyConfigurationName, out var apiKeyConfigurationNameObj)) + throw new SemanticCacheException($"The {AuthenticationParametersKeys.APIKeyConfigurationName} key is missing from the endpoint's authentication parameters dictionary."); + + var apiKey = _configuration.GetValue(apiKeyConfigurationNameObj?.ToString()!)!; + + return apiKey; + } + } +} diff --git a/src/dotnet/Orchestration/Services/UserPromptRewriteService.cs b/src/dotnet/Orchestration/Services/UserPromptRewriteService.cs new file mode 100644 index 0000000000..55e3dd24c7 --- /dev/null +++ b/src/dotnet/Orchestration/Services/UserPromptRewriteService.cs @@ -0,0 +1,174 @@ +using Azure; +using Azure.AI.OpenAI; +using FoundationaLLM.Common.Authentication; +using FoundationaLLM.Common.Constants; +using FoundationaLLM.Common.Constants.Authentication; +using FoundationaLLM.Common.Constants.ResourceProviders; +using FoundationaLLM.Common.Exceptions; +using FoundationaLLM.Common.Interfaces; +using FoundationaLLM.Common.Models.Orchestration.Request; +using FoundationaLLM.Common.Models.ResourceProviders.Agent; +using FoundationaLLM.Common.Models.ResourceProviders.AIModel; +using FoundationaLLM.Common.Models.ResourceProviders.Configuration; +using FoundationaLLM.Common.Models.ResourceProviders.Prompt; +using FoundationaLLM.Orchestration.Core.Interfaces; +using FoundationaLLM.Orchestration.Core.Models; +using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using OpenAI.Chat; +using OpenAI.Embeddings; + +namespace FoundationaLLM.Orchestration.Core.Services +{ + /// + /// Provides a service for managing the semantic cache. + /// + public class UserPromptRewriteService : IUserPromptRewriteService + { + private readonly Dictionary _agentRewriters = []; + private readonly SemaphoreSlim _syncLock = new SemaphoreSlim(1, 1); + + private readonly IResourceProviderService _aiModelResourceProviderService; + private readonly IResourceProviderService _configurationResourceProviderService; + private readonly IResourceProviderService _promptResourceProviderService; + private readonly IConfiguration _configuration; + private readonly ILogger _logger; + + /// + /// Initializes a new instance of the class. + /// + /// A list of resource providers hashed by resource provider name. + /// The used to retrieve app settings from configuration. + /// The logger used for logging.. + public UserPromptRewriteService( + IEnumerable resourceProviderServices, + IConfiguration configuration, + ILogger logger) + { + _aiModelResourceProviderService = resourceProviderServices + .Single(x => x.Name == ResourceProviderNames.FoundationaLLM_AIModel); + _configurationResourceProviderService = resourceProviderServices + .Single(x => x.Name == ResourceProviderNames.FoundationaLLM_Configuration); + _promptResourceProviderService = resourceProviderServices + .Single(x => x.Name == ResourceProviderNames.FoundationaLLM_Prompt); + _configuration = configuration; + _logger = logger; + } + + /// + public bool HasUserPromptRewriterForAgent(string instanceId, string agentName) => + _agentRewriters.ContainsKey($"{instanceId}|{agentName}"); + + /// + public async Task InitializeUserPromptRewriterForAgent( + string instanceId, + string agentName, + AgentUserPromptRewriteSettings agentSettings) + { + try + { + await _syncLock.WaitAsync(); + + if (HasUserPromptRewriterForAgent(instanceId, agentName)) + { + _logger.LogWarning("A user prompt rewriter for agent {AgentName} in instance {InstanceId} already exists.", agentName, instanceId); + return; + } + + var userPromptRewriteAIModel = await _aiModelResourceProviderService.GetResourceAsync( + agentSettings.UserPromptRewriteAIModelObjectId, + DefaultAuthentication.ServiceIdentity!); + var userPromptRewriteAPIEndpointConfiguration = await _configurationResourceProviderService.GetResourceAsync( + userPromptRewriteAIModel.EndpointObjectId!, + DefaultAuthentication.ServiceIdentity!); + var userPromptRewritePrompt = await _promptResourceProviderService.GetResourceAsync( + agentSettings.UserPromptRewritePromptObjectId, + DefaultAuthentication.ServiceIdentity!); + + _agentRewriters[$"{instanceId}|{agentName}"] = new AgentUserPromptRewriter + { + Settings = agentSettings, + RewriterSystemPrompt = (userPromptRewritePrompt as MultipartPrompt)!.Prefix!, + ChatClient = GetChatClient( + userPromptRewriteAIModel.DeploymentName!, + userPromptRewriteAPIEndpointConfiguration) + }; + } + finally + { + _syncLock.Release(); + } + } + + /// + public async Task RewriteUserPrompt( + string instanceId, + string agentName, + CompletionRequest completionRequest) + { + if (!_agentRewriters.TryGetValue($"{instanceId}|{agentName}", out AgentUserPromptRewriter? agentRewriter) + || agentRewriter == null) + throw new UserPromptRewriteException($"The user prompt rewriter is not initialized for agent {agentName} in instance {instanceId}."); + + try + { + var userPromptsHistory = completionRequest.MessageHistory? + .Where(x => StringComparer.Ordinal.Equals(x.Sender, nameof(Participants.User))) + .Select(x => x.Text) + .TakeLast(agentRewriter.Settings.UserPromptsWindowSize) + .ToList() + ?? []; + userPromptsHistory.Add(completionRequest.UserPrompt); + var completionResult = await agentRewriter.ChatClient.CompleteChatAsync( + [ + new SystemChatMessage(agentRewriter.RewriterSystemPrompt), + new UserChatMessage($"QUESTIONS:{Environment.NewLine}{string.Join(Environment.NewLine, [.. userPromptsHistory])}") + ], + new ChatCompletionOptions + { + Temperature = 0 + }); + + completionRequest.UserPromptRewrite = completionResult.Value.Content[0].Text; + + } + catch (Exception ex) + { + _logger.LogError( + ex, + "An error occurred while rewriting the user prompt {UserPrompt} for agent {AgentName} in instance {InstanceId}.", + completionRequest.UserPrompt, + agentName, + instanceId); + + completionRequest.UserPromptRewrite = completionRequest.UserPrompt; + } + } + + private ChatClient GetChatClient(string deploymentName, APIEndpointConfiguration apiEndpointConfiguration) => + apiEndpointConfiguration.AuthenticationType switch + { + AuthenticationTypes.AzureIdentity => (new AzureOpenAIClient( + new Uri(apiEndpointConfiguration.Url), + DefaultAuthentication.AzureCredential)) + .GetChatClient(deploymentName), + AuthenticationTypes.APIKey => (new AzureOpenAIClient( + new Uri(apiEndpointConfiguration.Url), + new AzureKeyCredential(GetAPIKey(apiEndpointConfiguration)))) + .GetChatClient(deploymentName), + _ => throw new NotImplementedException($"API endpoint authentication type {apiEndpointConfiguration.AuthenticationType} is not supported.") + }; + + private string GetAPIKey(APIEndpointConfiguration apiEndpointConfiguration) + { + if (!apiEndpointConfiguration.AuthenticationParameters.TryGetValue( + AuthenticationParametersKeys.APIKeyConfigurationName, out var apiKeyConfigurationNameObj)) + throw new SemanticCacheException($"The {AuthenticationParametersKeys.APIKeyConfigurationName} key is missing from the endpoint's authentication parameters dictionary."); + + var apiKey = _configuration.GetValue(apiKeyConfigurationNameObj?.ToString()!)!; + + return apiKey; + } + } +} diff --git a/src/dotnet/OrchestrationAPI/Program.cs b/src/dotnet/OrchestrationAPI/Program.cs index 3c34b760a1..cf0be1a43c 100644 --- a/src/dotnet/OrchestrationAPI/Program.cs +++ b/src/dotnet/OrchestrationAPI/Program.cs @@ -110,6 +110,9 @@ public static void Main(string[] args) builder.Services.AddScoped(); builder.AddUserProfileService(); + builder.AddUserPromptRewriteService(); + builder.AddSemanticCacheService(); + builder.Services.AddSingleton(); builder.Services.AddTransient, ConfigureSwaggerOptions>(); diff --git a/src/dotnet/State/State.csproj b/src/dotnet/State/State.csproj index 3874b6fd78..a3ed3eef92 100644 --- a/src/dotnet/State/State.csproj +++ b/src/dotnet/State/State.csproj @@ -9,10 +9,6 @@ True - - - - diff --git a/src/python/PythonSDK/foundationallm/event_handlers/openai_assistant_async_event_handler.py b/src/python/PythonSDK/foundationallm/event_handlers/openai_assistant_async_event_handler.py index c8ea240cba..3d883a0afb 100644 --- a/src/python/PythonSDK/foundationallm/event_handlers/openai_assistant_async_event_handler.py +++ b/src/python/PythonSDK/foundationallm/event_handlers/openai_assistant_async_event_handler.py @@ -28,6 +28,7 @@ def __init__(self, client: AsyncAzureOpenAI, operations_manager: OperationsManag id = request.document_id, operation_id = request.operation_id, user_prompt = request.user_prompt, + user_prompt_rewrite = request.user_prompt_rewrite, content = [], analysis_results = [] ) @@ -48,7 +49,7 @@ async def on_event(self, event: AssistantStreamEvent) -> None: if details and details.type == "tool_calls": for tool in details.tool_calls or []: if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input and tool.code_interpreter.input.endswith(tuple(self.stop_tokens)): - self.run_steps[event.data.id] = event.data # Overwrite the run step with the final version. + self.run_steps[event.data.id] = event.data # Overwrite the run step with the final version. await self.update_state_api_analysis_results_async() if tool.type == "function": self.run_steps[event.data.id] = event.data @@ -62,7 +63,7 @@ async def on_event(self, event: AssistantStreamEvent) -> None: print(f'{event.event} ({event.data.id}): {event.data.last_error}') if event.event == "thread.run.failed": raise Exception(event.data.last_error.message) - + @override async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: @@ -93,15 +94,15 @@ async def on_requires_action(self, run_id: str): # Get data from the tool if call.type == "function": if call.function.name == "generate_image": - try: - tool_response = await self.image_service.generate_image_async(**json.loads(call.function.arguments)) + try: + tool_response = await self.image_service.generate_image_async(**json.loads(call.function.arguments)) tool_responses.append( { "tool_call_id": call.id, "output": json.dumps(tool_response) } ) - except Exception as ex: + except Exception as ex: print(f'Error getting tool response: {ex}') break try: @@ -130,7 +131,7 @@ async def update_state_api_analysis_results_async(self): if analysis_result: self.interim_result.analysis_results.append(analysis_result) await self.operations_manager.set_operation_result_async(self.request.operation_id, self.request.instance_id, self.interim_result) - + async def update_state_api_content_async(self): self.interim_result.content = [] # Clear the content list before adding new messages. for k, v in self.messages.items(): diff --git a/src/python/PythonSDK/foundationallm/langchain/agents/langchain_knowledge_management_agent.py b/src/python/PythonSDK/foundationallm/langchain/agents/langchain_knowledge_management_agent.py index e71b912303..0a249615a6 100644 --- a/src/python/PythonSDK/foundationallm/langchain/agents/langchain_knowledge_management_agent.py +++ b/src/python/PythonSDK/foundationallm/langchain/agents/langchain_knowledge_management_agent.py @@ -429,6 +429,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C operation_id = request.operation_id, full_prompt = self.prompt.prefix, user_prompt = request.user_prompt, + user_prompt_rewrite = request.user_prompt_rewrite, errors = [ "Assistants API response was None." ] ) @@ -443,6 +444,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C prompt_tokens = assistant_response.prompt_tokens + image_analysis_token_usage.prompt_tokens, total_tokens = assistant_response.total_tokens + image_analysis_token_usage.total_tokens, user_prompt = request.user_prompt, + user_prompt_rewrite = request.user_prompt_rewrite, errors = assistant_response.errors ) # End Assistants API implementation @@ -499,6 +501,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C content = [response_content], content_artifacts = content_artifacts, user_prompt = request.user_prompt, + user_prompt_rewrite = request.user_prompt_rewrite, full_prompt = self.prompt.prefix, completion_tokens = final_message.usage_metadata["output_tokens"] or 0, prompt_tokens = final_message.usage_metadata["input_tokens"] or 0, @@ -545,6 +548,8 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C user_prompt=parsed_user_prompt, message_history=messages ) + # Ensure the user prompt rewrite is returned in the response + response.user_prompt_rewrite = request.user_prompt_rewrite return response # End External Agent workflow implementation @@ -604,6 +609,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C operation_id = request.operation_id, content = [response_content], user_prompt = request.user_prompt, + user_prompt_rewrite = request.user_prompt_rewrite, full_prompt = self.full_prompt.text, completion_tokens = completion.usage_metadata["output_tokens"] + image_analysis_token_usage.completion_tokens, prompt_tokens = completion.usage_metadata["input_tokens"] + image_analysis_token_usage.prompt_tokens, @@ -629,6 +635,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C operation_id = request.operation_id, content = [response_content], user_prompt = request.user_prompt, + user_prompt_rewrite = request.user_prompt_rewrite, full_prompt = self.full_prompt.text, completion_tokens = cb.completion_tokens + image_analysis_token_usage.completion_tokens, prompt_tokens = cb.prompt_tokens + image_analysis_token_usage.prompt_tokens, diff --git a/src/python/PythonSDK/foundationallm/models/orchestration/completion_request_base.py b/src/python/PythonSDK/foundationallm/models/orchestration/completion_request_base.py index 4390d0e9c6..9aa840a1d5 100644 --- a/src/python/PythonSDK/foundationallm/models/orchestration/completion_request_base.py +++ b/src/python/PythonSDK/foundationallm/models/orchestration/completion_request_base.py @@ -10,5 +10,6 @@ class CompletionRequestBase(BaseModel): operation_id: str = Field(description="The operation ID for the completion request.") session_id: Optional[str] = Field(None, description="The session ID for the completion request.") user_prompt: str = Field(description="The user prompt for the completion request.") + user_prompt_rewrite: Optional[str] = Field(None, description="The user prompt rewrite for the completion request.") message_history: Optional[List[MessageHistoryItem]] = Field(list, description="The message history for the completion.") attachments: Optional[List[AttachmentProperties]] = Field(list, description="The attachments collection for the completion request.") diff --git a/src/python/PythonSDK/foundationallm/models/orchestration/completion_response.py b/src/python/PythonSDK/foundationallm/models/orchestration/completion_response.py index c2ebfe0e68..96cb20a7e4 100644 --- a/src/python/PythonSDK/foundationallm/models/orchestration/completion_response.py +++ b/src/python/PythonSDK/foundationallm/models/orchestration/completion_response.py @@ -15,11 +15,12 @@ class CompletionResponse(BaseModel): id: Optional[str] = None operation_id: str user_prompt: str + user_prompt_rewrite: Optional[str] = None full_prompt: Optional[str] = None content: Optional[ List[ Union[ - OpenAIImageFileMessageContentItem, + OpenAIImageFileMessageContentItem, OpenAITextMessageContentItem ] ] diff --git a/src/python/PythonSDK/foundationallm/models/services/openai_assistants_request.py b/src/python/PythonSDK/foundationallm/models/services/openai_assistants_request.py index 8b41eb7563..c393236a69 100644 --- a/src/python/PythonSDK/foundationallm/models/services/openai_assistants_request.py +++ b/src/python/PythonSDK/foundationallm/models/services/openai_assistants_request.py @@ -19,3 +19,4 @@ class OpenAIAssistantsAPIRequest(BaseModel): thread_id: str attachments: Optional[List[str]] = [] user_prompt: str + user_prompt_rewrite: Optional[str] = None diff --git a/tests/dotnet/Common.Tests/Models/Agents/AgentBaseTests.cs b/tests/dotnet/Common.Tests/Models/Agents/AgentBaseTests.cs index 3b28c9a225..eb2dc4d7f8 100644 --- a/tests/dotnet/Common.Tests/Models/Agents/AgentBaseTests.cs +++ b/tests/dotnet/Common.Tests/Models/Agents/AgentBaseTests.cs @@ -28,7 +28,7 @@ public void AgentType_UnsupportedType_ThrowsException() public void ConversationHistory_SetAndGet_ReturnsCorrectValue() { // Arrange - var conversationHistory = new ConversationHistorySettings { Enabled = true, MaxHistory = 100 }; + var conversationHistory = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 100 }; _agentBase.ConversationHistorySettings = conversationHistory; // Assert @@ -39,7 +39,7 @@ public void ConversationHistory_SetAndGet_ReturnsCorrectValue() public void Gatekeeper_SetAndGet_ReturnsCorrectValue() { // Arrange - var gatekeeper = new GatekeeperSettings { UseSystemSetting = false, Options = new string[] { "Option1", "Option2" } }; + var gatekeeper = new AgentGatekeeperSettings { UseSystemSetting = false, Options = new string[] { "Option1", "Option2" } }; _agentBase.GatekeeperSettings = gatekeeper; // Assert diff --git a/tests/dotnet/Core.Examples/Catalogs/AgentCatalog.cs b/tests/dotnet/Core.Examples/Catalogs/AgentCatalog.cs index 9819dc5555..a64444413b 100644 --- a/tests/dotnet/Core.Examples/Catalogs/AgentCatalog.cs +++ b/tests/dotnet/Core.Examples/Catalogs/AgentCatalog.cs @@ -29,12 +29,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -56,12 +56,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null, }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -84,12 +84,12 @@ public static class AgentCatalog TextEmbeddingProfileObjectId = null, DataSourceObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -112,12 +112,12 @@ public static class AgentCatalog TextEmbeddingProfileObjectId = null, DataSourceObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -139,12 +139,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -166,12 +166,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -193,12 +193,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -220,12 +220,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -247,12 +247,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -274,12 +274,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, @@ -301,12 +301,12 @@ public static class AgentCatalog IndexingProfileObjectIds = null, TextEmbeddingProfileObjectId = null }, - ConversationHistorySettings = new ConversationHistorySettings + ConversationHistorySettings = new AgentConversationHistorySettings { Enabled = true, MaxHistory = 10 }, - GatekeeperSettings = new GatekeeperSettings + GatekeeperSettings = new AgentGatekeeperSettings { UseSystemSetting = false }, diff --git a/tests/dotnet/Orchestration.Tests/Orchestration/KnowledgeManagementOrchestrationTests.cs b/tests/dotnet/Orchestration.Tests/Orchestration/KnowledgeManagementOrchestrationTests.cs index 69ff6fd4d8..55cf8421c0 100644 --- a/tests/dotnet/Orchestration.Tests/Orchestration/KnowledgeManagementOrchestrationTests.cs +++ b/tests/dotnet/Orchestration.Tests/Orchestration/KnowledgeManagementOrchestrationTests.cs @@ -13,7 +13,7 @@ namespace FoundationaLLM.Orchestration.Tests.Orchestration public class KnowledgeManagementOrchestrationTests { private readonly string _instanceId = "00000000-0000-0000-0000-000000000000"; - private KnowledgeManagementOrchestration _knowledgeManagementOrchestration; + private AgentOrchestration _knowledgeManagementOrchestration; private KnowledgeManagementAgent _agent = new KnowledgeManagementAgent() { Name = "Test_agent", ObjectId="Test_objctid", Type = AgentTypes.KnowledgeManagement }; private ICallContext _callContext = Substitute.For(); private ILLMOrchestrationService _orchestrationService = Substitute.For(); @@ -21,7 +21,7 @@ public class KnowledgeManagementOrchestrationTests public KnowledgeManagementOrchestrationTests() { - _knowledgeManagementOrchestration = new KnowledgeManagementOrchestration( + _knowledgeManagementOrchestration = new AgentOrchestration( _instanceId, _agent.ObjectId, _agent, @@ -29,6 +29,8 @@ public KnowledgeManagementOrchestrationTests() null, _callContext, _orchestrationService, + null, + null, _logger, null, null, diff --git a/tests/dotnet/Orchestration.Tests/Services/OrchestrationServiceTests.cs b/tests/dotnet/Orchestration.Tests/Services/OrchestrationServiceTests.cs index b8c9704664..7c941c84c9 100644 --- a/tests/dotnet/Orchestration.Tests/Services/OrchestrationServiceTests.cs +++ b/tests/dotnet/Orchestration.Tests/Services/OrchestrationServiceTests.cs @@ -39,6 +39,8 @@ public OrchestrationServiceTests() null, null, null, + null, + null, _callContext, _configuration, null,