From a0c49e311ff867340ba338220aae4219ac010ac3 Mon Sep 17 00:00:00 2001 From: Ciprian Jichici Date: Sun, 12 Jan 2025 23:22:42 +0200 Subject: [PATCH] Fix issues with retrieving completion responses from the semantic cache --- docs/release-notes/breaking-changes.md | 4 +- .../Interfaces/IAzureCosmosDBService.cs | 13 ++ .../Services/Azure/AzureCosmosDBService.cs | 42 +++++++ .../Interfaces/ISemanticCacheService.cs | 10 +- .../Orchestration/Models/SemanticCacheItem.cs | 6 +- .../Orchestration/AgentOrchestration.cs | 117 ++++++++++++++---- .../Services/SemanticCacheService.cs | 46 +++++-- .../Services/UserPromptRewriteService.cs | 6 +- 8 files changed, 201 insertions(+), 43 deletions(-) diff --git a/docs/release-notes/breaking-changes.md b/docs/release-notes/breaking-changes.md index 996b3a2ea..004870a75 100644 --- a/docs/release-notes/breaking-changes.md +++ b/docs/release-notes/breaking-changes.md @@ -37,7 +37,7 @@ az cosmosdb update --resource-group --name Create the `CompletionsCache` container in the Cosmos DB database with the following properties: - **Container id**: `CompletionsCache` -- **Partition key**: `/partitionKey` +- **Partition key**: `/operationId` - **Container Vector Policy**: a policy with the following properties: - **Path**: `/userPromptEmbedding` - **Data type**: `float32` @@ -45,6 +45,8 @@ Create the `CompletionsCache` container in the Cosmos DB database with the follo - **Dimensions**: 2048 - **Index type**: `diskANN` (leave the default values) +After the container is created, set the `Time to Live` property on the container to 300 seconds. + ## Starting with 0.9.1-rc105 ### Configuration changes diff --git a/src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs b/src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs index 7102609d9..738354f17 100644 --- a/src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs +++ b/src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs @@ -2,6 +2,7 @@ using FoundationaLLM.Common.Models.Configuration.Users; using FoundationaLLM.Common.Models.Conversation; using FoundationaLLM.Common.Models.Orchestration; +using FoundationaLLM.Common.Models.Orchestration.Response; using FoundationaLLM.Common.Models.ResourceProviders; using FoundationaLLM.Common.Models.ResourceProviders.Attachment; @@ -274,4 +275,16 @@ Task CreateVectorSearchContainerAsync( string vectorProperyPath, int vectorDimensions, CancellationToken cancellationToken = default); + + /// + /// Gets the completion response for a given user prompt embedding using vector search and a minimum threshold for similarity. + /// + /// The name of the container holding the vector index. + /// The reference embedding used for the vector search. + /// The threshold used for the similarity score. + /// A that matches the search criteria. If no item in the vector index matches the criteria, returns . + Task GetCompletionResponseAsync( + string containerName, + ReadOnlyMemory userPromptEmbedding, + decimal minimumSimilarityScore); } diff --git a/src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs b/src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs index 1fead245f..5117e2e19 100644 --- a/src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs +++ b/src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs @@ -5,9 +5,11 @@ using FoundationaLLM.Common.Models.Configuration.Users; using FoundationaLLM.Common.Models.Conversation; using FoundationaLLM.Common.Models.Orchestration; +using FoundationaLLM.Common.Models.Orchestration.Response; using FoundationaLLM.Common.Models.ResourceProviders; using FoundationaLLM.Common.Models.ResourceProviders.Attachment; using Microsoft.Azure.Cosmos; +using Microsoft.Azure.Cosmos.Serialization.HybridRow; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Polly; @@ -15,6 +17,7 @@ using System.Collections.ObjectModel; using System.Diagnostics; using System.Net; +using System.Text.Json; namespace FoundationaLLM.Common.Services.Azure { @@ -639,5 +642,44 @@ public async Task CreateVectorSearchContainerAsync( if (containerResponse.Container != null) _containers[containerName] = containerResponse.Container; } + + /// + public async Task GetCompletionResponseAsync( + string containerName, + ReadOnlyMemory userPromptEmbedding, + decimal minimumSimilarityScore) + { + var query = new QueryDefinition(""" + SELECT TOP 1 + x.serializedItem, x.similarityScore + FROM + ( + SELECT c.serializedItem, VectorDistance(c.userPromptEmbedding, @userPromptEmbedding) AS similarityScore FROM c + ) x + WHERE + x.similarityScore >= @minimumSimilarityScore + ORDER BY + x.similarityScore DESC + """); + query.WithParameter("@userPromptEmbedding", userPromptEmbedding.ToArray()); + query.WithParameter("@minimumSimilarityScore", (float)minimumSimilarityScore); + + using var feedIterator = _completionsCache.GetItemQueryIterator(query); + if (feedIterator.HasMoreResults) + { + var response = await feedIterator.ReadNextAsync(); + var result = response.Resource.FirstOrDefault(); + + if (result == null) + return null; + + var serializedCompletionResponse = (result as Newtonsoft.Json.Linq.JObject)!["serializedItem"]!.ToString(); + var completionResponse = JsonSerializer.Deserialize(serializedCompletionResponse); + + return completionResponse; + } + + return null; + } } } diff --git a/src/dotnet/Orchestration/Interfaces/ISemanticCacheService.cs b/src/dotnet/Orchestration/Interfaces/ISemanticCacheService.cs index 8338a3b2e..4fa8fc305 100644 --- a/src/dotnet/Orchestration/Interfaces/ISemanticCacheService.cs +++ b/src/dotnet/Orchestration/Interfaces/ISemanticCacheService.cs @@ -40,7 +40,7 @@ Task InitializeCacheForAgent( Task ResetCacheForAgent(string instanceId, string agentName); /// - /// Tries to get a cache item from the semantic cache for the specified agent in the specified FoundationaLLM instance. + /// Tries to get a from the semantic cache for the specified agent in the specified FoundationaLLM instance. /// /// The unique identifier of the FoundationaLLM instance. /// The name of the agent. @@ -52,15 +52,15 @@ Task InitializeCacheForAgent( CompletionRequest completionRequest); /// - /// Sets a cache item in the semantic cache for the specified agent in the specified FoundationaLLM instance. + /// Sets a in the semantic cache for the specified agent in the specified FoundationaLLM instance. /// /// The unique identifier of the FoundationaLLM instance. /// The name of the agent. - /// The to be set in the agent's cache. + /// The to be set in the agent's cache. /// - Task SetCacheItem( + Task SetCompletionResponseInCache( string instanceId, string agentName, - SemanticCacheItem cacheItem); + CompletionResponse completionResponse); } } diff --git a/src/dotnet/Orchestration/Models/SemanticCacheItem.cs b/src/dotnet/Orchestration/Models/SemanticCacheItem.cs index 326420ccf..155ea209e 100644 --- a/src/dotnet/Orchestration/Models/SemanticCacheItem.cs +++ b/src/dotnet/Orchestration/Models/SemanticCacheItem.cs @@ -9,14 +9,14 @@ public class SemanticCacheItem { public string Id { get; set; } - public string PartitionKey { get; set; } + public string OperationId { get; set; } public string UserPrompt { get; set; } public int UserPromptTokens { get; set; } - public ReadOnlyMemory UserPromptEmbedding { get; set; } + public float[] UserPromptEmbedding { get; set; } - public CompletionResponse CompletionResponse { get; set; } + public string SerializedItem { get; set; } } } diff --git a/src/dotnet/Orchestration/Orchestration/AgentOrchestration.cs b/src/dotnet/Orchestration/Orchestration/AgentOrchestration.cs index cc8bda4a7..5154d525c 100644 --- a/src/dotnet/Orchestration/Orchestration/AgentOrchestration.cs +++ b/src/dotnet/Orchestration/Orchestration/AgentOrchestration.cs @@ -14,6 +14,7 @@ using FoundationaLLM.Common.Models.ResourceProviders.Attachment; using FoundationaLLM.Common.Models.ResourceProviders.AzureOpenAI; using FoundationaLLM.Orchestration.Core.Interfaces; +using FoundationaLLM.Orchestration.Models; using Microsoft.Extensions.Logging; using System.Text.Json; using System.Text.RegularExpressions; @@ -91,15 +92,19 @@ public override async Task StartCompletionOperation(Comple Result = validationResponse }; - await HandlePromptRewrite(completionRequest); - var cachedResponse = await GetCompletionResponseFromCache(completionRequest); - if (cachedResponse != null) - return new LongRunningOperation - { - OperationId = completionRequest.OperationId!, - Status = OperationStatus.Completed, - Result = cachedResponse - }; + if (_agent!.CacheSettings != null + && _agent!.CacheSettings.SemanticCacheEnabled) + { + await HandlePromptRewrite(completionRequest); + var cachedResponse = await GetCompletionResponseFromCache(completionRequest); + if (cachedResponse != null) + return new LongRunningOperation + { + OperationId = completionRequest.OperationId!, + Status = OperationStatus.Completed, + Result = cachedResponse + }; + } var llmCompletionRequest = await GetLLMCompletionRequest(completionRequest); if (_completionRequestObserver != null) @@ -123,9 +128,24 @@ public override async Task GetCompletionOperationStatus(st // parse the LLM Completion response from JsonElement if (operationStatus.Result is JsonElement jsonElement) { - var completionResponse = JsonSerializer.Deserialize(jsonElement.ToString()); - if (completionResponse != null) - operationStatus.Result = await GetCompletionResponse(operationId, completionResponse); + var llmCompletionResponse = JsonSerializer.Deserialize(jsonElement.ToString()); + if (llmCompletionResponse != null) + { + var completionResponse = await GetCompletionResponse(operationId, llmCompletionResponse); + + if (_agent!.CacheSettings != null + && _agent!.CacheSettings.SemanticCacheEnabled + && ( + completionResponse.Errors == null + || completionResponse.Errors.Length == 0 + )) + { + // This is a valid response that can be cached. + await SetCompletionResponseInCache(completionResponse); + } + + operationStatus.Result = completionResponse; + } } } @@ -139,10 +159,14 @@ public override async Task GetCompletion(CompletionRequest c if (validationResponse != null) return validationResponse; - await HandlePromptRewrite(completionRequest); - var cachedResponse = await GetCompletionResponseFromCache(completionRequest); - if (cachedResponse != null) - return cachedResponse; + if (_agent!.CacheSettings != null + && _agent!.CacheSettings.SemanticCacheEnabled) + { + await HandlePromptRewrite(completionRequest); + var cachedResponse = await GetCompletionResponseFromCache(completionRequest); + if (cachedResponse != null) + return cachedResponse; + } var llmCompletionRequest = await GetLLMCompletionRequest(completionRequest); if (_completionRequestObserver != null) @@ -154,6 +178,17 @@ public override async Task GetCompletion(CompletionRequest c var completionResponse = await GetCompletionResponse(completionRequest.OperationId!, llmCompletionResponse); + if (_agent!.CacheSettings != null + && _agent!.CacheSettings.SemanticCacheEnabled + && ( + completionResponse.Errors == null + || completionResponse.Errors.Length == 0 + )) + { + // This is a valid response that can be cached. + await SetCompletionResponseInCache(completionResponse); + } + return completionResponse; } @@ -203,19 +238,49 @@ await _userPromptRewriteService.InitializeUserPromptRewriterForAgent( private async Task GetCompletionResponseFromCache(CompletionRequest completionRequest) { - if (_agent!.CacheSettings != null - && _agent!.CacheSettings.SemanticCacheEnabled) - { - if (!_semanticCacheService.HasCacheForAgent(_instanceId, _agent.Name)) - await _semanticCacheService.InitializeCacheForAgent( - _instanceId, - _agent.Name, - _agent.CacheSettings.SemanticCacheSettings!); + if (!_semanticCacheService.HasCacheForAgent(_instanceId, _agent!.Name)) + await _semanticCacheService.InitializeCacheForAgent( + _instanceId, + _agent.Name, + _agent.CacheSettings!.SemanticCacheSettings!); + var cachedResponse = await _semanticCacheService.GetCompletionResponseFromCache( + _instanceId, + _agent.Name, + completionRequest); + + if (cachedResponse == null) return null; - } - return null; + cachedResponse.OperationId = completionRequest.OperationId!; + var contentArtifactsList = new List(cachedResponse.ContentArtifacts ??= []) + { + new() { + Id = "CachedResponse", + Title = "Cached Response", + Source = "SemanticCache" + } + }; + cachedResponse.ContentArtifacts = [.. contentArtifactsList]; + + return cachedResponse; + } + + private async Task SetCompletionResponseInCache(CompletionResponse completionResponse) + { + try + { + await _semanticCacheService.SetCompletionResponseInCache(_instanceId, _agent!.Name, completionResponse); + } + catch (Exception ex) + { + _logger.LogError( + ex, + "An error occurred while setting the completion response in the semantic cache for operation {OperationId} and agent {AgentName} in instance {InstanceId}.", + completionResponse.OperationId, + _agent!.Name, + _instanceId); + } } private async Task GetLLMCompletionRequest(CompletionRequest completionRequest) => diff --git a/src/dotnet/Orchestration/Services/SemanticCacheService.cs b/src/dotnet/Orchestration/Services/SemanticCacheService.cs index 2d31c22e5..55f98d3a7 100644 --- a/src/dotnet/Orchestration/Services/SemanticCacheService.cs +++ b/src/dotnet/Orchestration/Services/SemanticCacheService.cs @@ -16,6 +16,7 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; using OpenAI.Embeddings; +using System.Text.Json; namespace FoundationaLLM.Orchestration.Core.Services { @@ -112,12 +113,37 @@ public async Task InitializeCacheForAgent( } /// - public Task ResetCacheForAgent(string instanceId, string agentName) => - Task.CompletedTask; + public async Task ResetCacheForAgent(string instanceId, string agentName) => + await Task.CompletedTask; /// - public Task SetCacheItem(string instanceId, string agentName, SemanticCacheItem cacheItem) => - Task.CompletedTask; + public async Task SetCompletionResponseInCache(string instanceId, string agentName, CompletionResponse completionResponse) + { + if (!_agentCaches.TryGetValue($"{instanceId}|{agentName}", out AgentSemanticCache? agentCache) + || agentCache == null) + throw new SemanticCacheException($"The semantic cache is not initialized for agent {agentName} in instance {instanceId}."); + + var cacheItem = new SemanticCacheItem + { + Id = Guid.NewGuid().ToString().ToLower(), + OperationId = completionResponse.OperationId!, + UserPrompt = completionResponse.UserPromptRewrite!, + SerializedItem = JsonSerializer.Serialize(completionResponse), + }; + + var embeddingResponse = await agentCache.EmbeddingClient.GenerateEmbeddingAsync( + cacheItem.UserPrompt, + new EmbeddingGenerationOptions + { + Dimensions = agentCache.Settings.EmbeddingDimensions + }); + cacheItem.UserPromptEmbedding = embeddingResponse.Value.ToFloats().ToArray(); + + await _cosmosDBService.UpsertItemAsync( + SEMANTIC_CACHE_CONTAINER_NAME, + cacheItem.OperationId, + cacheItem); + } /// public async Task GetCompletionResponseFromCache( @@ -125,18 +151,24 @@ public Task SetCacheItem(string instanceId, string agentName, SemanticCacheItem string agentName, CompletionRequest completionRequest) { - if (!_agentCaches.TryGetValue($"{instanceId}-{agentName}", out AgentSemanticCache? agentCache) + if (!_agentCaches.TryGetValue($"{instanceId}|{agentName}", out AgentSemanticCache? agentCache) || agentCache == null) throw new SemanticCacheException($"The semantic cache is not initialized for agent {agentName} in instance {instanceId}."); - var userPromptEmbedding = await agentCache.EmbeddingClient.GenerateEmbeddingAsync( + var embeddingResult = await agentCache.EmbeddingClient.GenerateEmbeddingAsync( completionRequest.UserPromptRewrite, new EmbeddingGenerationOptions { Dimensions = agentCache.Settings.EmbeddingDimensions }); + var userPromptEmbedding = embeddingResult.Value.ToFloats(); + + var cachedCompletionResponse = await _cosmosDBService.GetCompletionResponseAsync( + SEMANTIC_CACHE_CONTAINER_NAME, + userPromptEmbedding, + agentCache.Settings.MinimumSimilarityThreshold); - return null; + return cachedCompletionResponse; } private EmbeddingClient GetEmbeddingClient(string deploymentName, APIEndpointConfiguration apiEndpointConfiguration) => diff --git a/src/dotnet/Orchestration/Services/UserPromptRewriteService.cs b/src/dotnet/Orchestration/Services/UserPromptRewriteService.cs index 7003f5c44..55e3dd24c 100644 --- a/src/dotnet/Orchestration/Services/UserPromptRewriteService.cs +++ b/src/dotnet/Orchestration/Services/UserPromptRewriteService.cs @@ -124,7 +124,11 @@ public async Task RewriteUserPrompt( [ new SystemChatMessage(agentRewriter.RewriterSystemPrompt), new UserChatMessage($"QUESTIONS:{Environment.NewLine}{string.Join(Environment.NewLine, [.. userPromptsHistory])}") - ]); + ], + new ChatCompletionOptions + { + Temperature = 0 + }); completionRequest.UserPromptRewrite = completionResult.Value.Content[0].Text;