Skip to content

Commit

Permalink
Fix issues with retrieving completion responses from the semantic cache
Browse files Browse the repository at this point in the history
  • Loading branch information
ciprianjichici committed Jan 12, 2025
1 parent e493d5e commit a0c49e3
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 43 deletions.
4 changes: 3 additions & 1 deletion docs/release-notes/breaking-changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ az cosmosdb update --resource-group <resource-group-name> --name <account-name>
Create the `CompletionsCache` container in the Cosmos DB database with the following properties:

- **Container id**: `CompletionsCache`
- **Partition key**: `/partitionKey`
- **Partition key**: `/operationId`
- **Container Vector Policy**: a policy with the following properties:
- **Path**: `/userPromptEmbedding`
- **Data type**: `float32`
- **Distance function**: `Cosine`
- **Dimensions**: 2048
- **Index type**: `diskANN` (leave the default values)

After the container is created, set the `Time to Live` property on the container to 300 seconds.

## Starting with 0.9.1-rc105

### Configuration changes
Expand Down
13 changes: 13 additions & 0 deletions src/dotnet/Common/Interfaces/IAzureCosmosDBService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using FoundationaLLM.Common.Models.Configuration.Users;
using FoundationaLLM.Common.Models.Conversation;
using FoundationaLLM.Common.Models.Orchestration;
using FoundationaLLM.Common.Models.Orchestration.Response;
using FoundationaLLM.Common.Models.ResourceProviders;
using FoundationaLLM.Common.Models.ResourceProviders.Attachment;

Expand Down Expand Up @@ -274,4 +275,16 @@ Task CreateVectorSearchContainerAsync(
string vectorProperyPath,
int vectorDimensions,
CancellationToken cancellationToken = default);

/// <summary>
/// Gets the completion response for a given user prompt embedding using vector search and a minimum threshold for similarity.
/// </summary>
/// <param name="containerName">The name of the container holding the vector index.</param>
/// <param name="userPromptEmbedding">The reference embedding used for the vector search.</param>
/// <param name="minimumSimilarityScore">The threshold used for the similarity score.</param>
/// <returns>A <see cref="CompletionResponse"/> that matches the search criteria. If no item in the vector index matches the criteria, returns <see langword="null"/>.</returns>
Task<CompletionResponse?> GetCompletionResponseAsync(
string containerName,
ReadOnlyMemory<float> userPromptEmbedding,
decimal minimumSimilarityScore);
}
42 changes: 42 additions & 0 deletions src/dotnet/Common/Services/Azure/AzureCosmosDBService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
using FoundationaLLM.Common.Models.Configuration.Users;
using FoundationaLLM.Common.Models.Conversation;
using FoundationaLLM.Common.Models.Orchestration;
using FoundationaLLM.Common.Models.Orchestration.Response;
using FoundationaLLM.Common.Models.ResourceProviders;
using FoundationaLLM.Common.Models.ResourceProviders.Attachment;
using Microsoft.Azure.Cosmos;
using Microsoft.Azure.Cosmos.Serialization.HybridRow;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Polly;
using Polly.Retry;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Net;
using System.Text.Json;

namespace FoundationaLLM.Common.Services.Azure
{
Expand Down Expand Up @@ -639,5 +642,44 @@ public async Task CreateVectorSearchContainerAsync(
if (containerResponse.Container != null)
_containers[containerName] = containerResponse.Container;
}

/// <inheritdoc/>
/// <remarks>
/// The query keeps only items whose similarity score is at least
/// <paramref name="minimumSimilarityScore"/> and returns the single highest-scoring one.
/// The stored <c>serializedItem</c> string of that item is deserialized into a <see cref="CompletionResponse"/>.
/// </remarks>
public async Task<CompletionResponse?> GetCompletionResponseAsync(
    string containerName,
    ReadOnlyMemory<float> userPromptEmbedding,
    decimal minimumSimilarityScore)
{
    // NOTE(review): containerName is not used here; the query always runs against
    // _completionsCache. Confirm whether the container should be resolved by name instead.
    var query = new QueryDefinition("""
    SELECT TOP 1
    x.serializedItem, x.similarityScore
    FROM
    (
    SELECT c.serializedItem, VectorDistance(c.userPromptEmbedding, @userPromptEmbedding) AS similarityScore FROM c
    ) x
    WHERE
    x.similarityScore >= @minimumSimilarityScore
    ORDER BY
    x.similarityScore DESC
    """);
    query.WithParameter("@userPromptEmbedding", userPromptEmbedding.ToArray());
    query.WithParameter("@minimumSimilarityScore", (float)minimumSimilarityScore);

    using var feedIterator = _completionsCache.GetItemQueryIterator<Object>(query);

    // A single page may legitimately come back empty while more pages are still pending
    // (e.g. cross-partition queries), so keep reading until a result shows up or the
    // iterator is exhausted instead of inspecting only the first page.
    while (feedIterator.HasMoreResults)
    {
        var page = await feedIterator.ReadNextAsync();
        var result = page.Resource.FirstOrDefault();

        if (result is null)
            continue;

        // The query is TOP 1, so the first item found is the only candidate.
        // Anything that is not a JSON object cannot carry the serialized response.
        if (result is not Newtonsoft.Json.Linq.JObject item)
            return null;

        var serializedCompletionResponse = item["serializedItem"]?.ToString();

        // A missing or empty payload is treated as a cache miss rather than a failure.
        if (string.IsNullOrWhiteSpace(serializedCompletionResponse))
            return null;

        return JsonSerializer.Deserialize<CompletionResponse>(serializedCompletionResponse);
    }

    return null;
}
}
}
10 changes: 5 additions & 5 deletions src/dotnet/Orchestration/Interfaces/ISemanticCacheService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Task InitializeCacheForAgent(
Task ResetCacheForAgent(string instanceId, string agentName);

/// <summary>
/// Tries to get a cache item from the semantic cache for the specified agent in the specified FoundationaLLM instance.
/// Tries to get a <see cref="CompletionResponse"/> from the semantic cache for the specified agent in the specified FoundationaLLM instance.
/// </summary>
/// <param name="instanceId">The unique identifier of the FoundationaLLM instance.</param>
/// <param name="agentName">The name of the agent.</param>
Expand All @@ -52,15 +52,15 @@ Task InitializeCacheForAgent(
CompletionRequest completionRequest);

/// <summary>
/// Sets a cache item in the semantic cache for the specified agent in the specified FoundationaLLM instance.
/// Sets a <see cref="CompletionResponse"/> in the semantic cache for the specified agent in the specified FoundationaLLM instance.
/// </summary>
/// <param name="instanceId">The unique identifier of the FoundationaLLM instance.</param>
/// <param name="agentName">The name of the agent.</param>
/// <param name="cacheItem">The <see cref="SemanticCacheItem"/> to be set in the agent's cache.</param>
/// <param name="completionResponse">The <see cref="CompletionResponse"/> to be set in the agent's cache.</param>
/// <returns></returns>
Task SetCacheItem(
Task SetCompletionResponseInCache(
string instanceId,
string agentName,
SemanticCacheItem cacheItem);
CompletionResponse completionResponse);
}
}
6 changes: 3 additions & 3 deletions src/dotnet/Orchestration/Models/SemanticCacheItem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ public class SemanticCacheItem
{
public string Id { get; set; }

public string PartitionKey { get; set; }
public string OperationId { get; set; }

public string UserPrompt { get; set; }

public int UserPromptTokens { get; set; }

public ReadOnlyMemory<float> UserPromptEmbedding { get; set; }
public float[] UserPromptEmbedding { get; set; }

public CompletionResponse CompletionResponse { get; set; }
public string SerializedItem { get; set; }
}
}
117 changes: 91 additions & 26 deletions src/dotnet/Orchestration/Orchestration/AgentOrchestration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using FoundationaLLM.Common.Models.ResourceProviders.Attachment;
using FoundationaLLM.Common.Models.ResourceProviders.AzureOpenAI;
using FoundationaLLM.Orchestration.Core.Interfaces;
using FoundationaLLM.Orchestration.Models;
using Microsoft.Extensions.Logging;
using System.Text.Json;
using System.Text.RegularExpressions;
Expand Down Expand Up @@ -91,15 +92,19 @@ public override async Task<LongRunningOperation> StartCompletionOperation(Comple
Result = validationResponse
};

await HandlePromptRewrite(completionRequest);
var cachedResponse = await GetCompletionResponseFromCache(completionRequest);
if (cachedResponse != null)
return new LongRunningOperation
{
OperationId = completionRequest.OperationId!,
Status = OperationStatus.Completed,
Result = cachedResponse
};
if (_agent!.CacheSettings != null
&& _agent!.CacheSettings.SemanticCacheEnabled)
{
await HandlePromptRewrite(completionRequest);
var cachedResponse = await GetCompletionResponseFromCache(completionRequest);
if (cachedResponse != null)
return new LongRunningOperation
{
OperationId = completionRequest.OperationId!,
Status = OperationStatus.Completed,
Result = cachedResponse
};
}

var llmCompletionRequest = await GetLLMCompletionRequest(completionRequest);
if (_completionRequestObserver != null)
Expand All @@ -123,9 +128,24 @@ public override async Task<LongRunningOperation> GetCompletionOperationStatus(st
// parse the LLM Completion response from JsonElement
if (operationStatus.Result is JsonElement jsonElement)
{
var completionResponse = JsonSerializer.Deserialize<LLMCompletionResponse>(jsonElement.ToString());
if (completionResponse != null)
operationStatus.Result = await GetCompletionResponse(operationId, completionResponse);
var llmCompletionResponse = JsonSerializer.Deserialize<LLMCompletionResponse>(jsonElement.ToString());
if (llmCompletionResponse != null)
{
var completionResponse = await GetCompletionResponse(operationId, llmCompletionResponse);

if (_agent!.CacheSettings != null
&& _agent!.CacheSettings.SemanticCacheEnabled
&& (
completionResponse.Errors == null
|| completionResponse.Errors.Length == 0
))
{
// This is a valid response that can be cached.
await SetCompletionResponseInCache(completionResponse);
}

operationStatus.Result = completionResponse;
}
}
}

Expand All @@ -139,10 +159,14 @@ public override async Task<CompletionResponse> GetCompletion(CompletionRequest c
if (validationResponse != null)
return validationResponse;

await HandlePromptRewrite(completionRequest);
var cachedResponse = await GetCompletionResponseFromCache(completionRequest);
if (cachedResponse != null)
return cachedResponse;
if (_agent!.CacheSettings != null
&& _agent!.CacheSettings.SemanticCacheEnabled)
{
await HandlePromptRewrite(completionRequest);
var cachedResponse = await GetCompletionResponseFromCache(completionRequest);
if (cachedResponse != null)
return cachedResponse;
}

var llmCompletionRequest = await GetLLMCompletionRequest(completionRequest);
if (_completionRequestObserver != null)
Expand All @@ -154,6 +178,17 @@ public override async Task<CompletionResponse> GetCompletion(CompletionRequest c

var completionResponse = await GetCompletionResponse(completionRequest.OperationId!, llmCompletionResponse);

if (_agent!.CacheSettings != null
&& _agent!.CacheSettings.SemanticCacheEnabled
&& (
completionResponse.Errors == null
|| completionResponse.Errors.Length == 0
))
{
// This is a valid response that can be cached.
await SetCompletionResponseInCache(completionResponse);
}

return completionResponse;
}

Expand Down Expand Up @@ -203,19 +238,49 @@ await _userPromptRewriteService.InitializeUserPromptRewriterForAgent(

private async Task<CompletionResponse?> GetCompletionResponseFromCache(CompletionRequest completionRequest)
{
if (_agent!.CacheSettings != null
&& _agent!.CacheSettings.SemanticCacheEnabled)
{
if (!_semanticCacheService.HasCacheForAgent(_instanceId, _agent.Name))
await _semanticCacheService.InitializeCacheForAgent(
_instanceId,
_agent.Name,
_agent.CacheSettings.SemanticCacheSettings!);
if (!_semanticCacheService.HasCacheForAgent(_instanceId, _agent!.Name))
await _semanticCacheService.InitializeCacheForAgent(
_instanceId,
_agent.Name,
_agent.CacheSettings!.SemanticCacheSettings!);

var cachedResponse = await _semanticCacheService.GetCompletionResponseFromCache(
_instanceId,
_agent.Name,
completionRequest);

if (cachedResponse == null)
return null;
}

return null;
cachedResponse.OperationId = completionRequest.OperationId!;
var contentArtifactsList = new List<ContentArtifact>(cachedResponse.ContentArtifacts ??= [])
{
new() {
Id = "CachedResponse",
Title = "Cached Response",
Source = "SemanticCache"
}
};
cachedResponse.ContentArtifacts = [.. contentArtifactsList];

return cachedResponse;
}

/// <summary>
/// Stores a completion response in the semantic cache on a best-effort basis.
/// </summary>
/// <param name="completionResponse">The <see cref="CompletionResponse"/> to store in the cache.</param>
/// <remarks>
/// Any exception raised by the semantic cache service is logged and not rethrown.
/// </remarks>
private async Task SetCompletionResponseInCache(CompletionResponse completionResponse)
{
    const string failureMessageTemplate =
        "An error occurred while setting the completion response in the semantic cache for operation {OperationId} and agent {AgentName} in instance {InstanceId}.";

    try
    {
        await _semanticCacheService.SetCompletionResponseInCache(
            _instanceId,
            _agent!.Name,
            completionResponse);
    }
    catch (Exception cacheFailure)
    {
        // Log and continue: the failure is intentionally not propagated.
        _logger.LogError(
            cacheFailure,
            failureMessageTemplate,
            completionResponse.OperationId,
            _agent!.Name,
            _instanceId);
    }
}

private async Task<LLMCompletionRequest> GetLLMCompletionRequest(CompletionRequest completionRequest) =>
Expand Down
46 changes: 39 additions & 7 deletions src/dotnet/Orchestration/Services/SemanticCacheService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using OpenAI.Embeddings;
using System.Text.Json;

namespace FoundationaLLM.Orchestration.Core.Services
{
Expand Down Expand Up @@ -112,31 +113,62 @@ public async Task InitializeCacheForAgent(
}

/// <inheritdoc/>
public Task ResetCacheForAgent(string instanceId, string agentName) =>
Task.CompletedTask;
public async Task ResetCacheForAgent(string instanceId, string agentName) =>
await Task.CompletedTask;

/// <inheritdoc/>
public Task SetCacheItem(string instanceId, string agentName, SemanticCacheItem cacheItem) =>
Task.CompletedTask;
public async Task SetCompletionResponseInCache(string instanceId, string agentName, CompletionResponse completionResponse)
{
    // Stores the completion response in the agent's semantic cache: the rewritten user prompt
    // is embedded with the agent's embedding client and the response itself is kept as a
    // JSON string next to that embedding.
    if (!_agentCaches.TryGetValue($"{instanceId}|{agentName}", out AgentSemanticCache? agentCache)
        || agentCache == null)
        throw new SemanticCacheException($"The semantic cache is not initialized for agent {agentName} in instance {instanceId}.");

    // The operation id is used as the partition key value and the rewritten prompt is the text
    // that gets embedded; fail with a descriptive error instead of passing null values on.
    if (string.IsNullOrWhiteSpace(completionResponse.OperationId))
        throw new SemanticCacheException($"The completion response for agent {agentName} in instance {instanceId} cannot be cached because it has no operation identifier.");

    if (string.IsNullOrWhiteSpace(completionResponse.UserPromptRewrite))
        throw new SemanticCacheException($"The completion response for agent {agentName} in instance {instanceId} cannot be cached because it has no rewritten user prompt.");

    var embeddingResponse = await agentCache.EmbeddingClient.GenerateEmbeddingAsync(
        completionResponse.UserPromptRewrite,
        new EmbeddingGenerationOptions
        {
            Dimensions = agentCache.Settings.EmbeddingDimensions
        });

    var cacheItem = new SemanticCacheItem
    {
        // Guid.ToString() already produces lowercase hexadecimal digits.
        Id = Guid.NewGuid().ToString(),
        OperationId = completionResponse.OperationId,
        UserPrompt = completionResponse.UserPromptRewrite,
        UserPromptEmbedding = embeddingResponse.Value.ToFloats().ToArray(),
        SerializedItem = JsonSerializer.Serialize(completionResponse),
    };

    await _cosmosDBService.UpsertItemAsync<SemanticCacheItem>(
        SEMANTIC_CACHE_CONTAINER_NAME,
        cacheItem.OperationId,
        cacheItem);
}

/// <inheritdoc/>
public async Task<CompletionResponse?> GetCompletionResponseFromCache(
string instanceId,
string agentName,
CompletionRequest completionRequest)
{
if (!_agentCaches.TryGetValue($"{instanceId}-{agentName}", out AgentSemanticCache? agentCache)
if (!_agentCaches.TryGetValue($"{instanceId}|{agentName}", out AgentSemanticCache? agentCache)
|| agentCache == null)
throw new SemanticCacheException($"The semantic cache is not initialized for agent {agentName} in instance {instanceId}.");

var userPromptEmbedding = await agentCache.EmbeddingClient.GenerateEmbeddingAsync(
var embeddingResult = await agentCache.EmbeddingClient.GenerateEmbeddingAsync(
completionRequest.UserPromptRewrite,
new EmbeddingGenerationOptions
{
Dimensions = agentCache.Settings.EmbeddingDimensions
});
var userPromptEmbedding = embeddingResult.Value.ToFloats();

var cachedCompletionResponse = await _cosmosDBService.GetCompletionResponseAsync(
SEMANTIC_CACHE_CONTAINER_NAME,
userPromptEmbedding,
agentCache.Settings.MinimumSimilarityThreshold);

return null;
return cachedCompletionResponse;
}

private EmbeddingClient GetEmbeddingClient(string deploymentName, APIEndpointConfiguration apiEndpointConfiguration) =>
Expand Down
Loading

0 comments on commit a0c49e3

Please sign in to comment.