From 1c616d24e7073ce688233bdd32c54a1c734d38c9 Mon Sep 17 00:00:00 2001 From: gabrielgheorghescu Date: Tue, 6 Aug 2024 11:44:19 +0300 Subject: [PATCH 1/3] Implement async-completions controller methods for long running agents --- .../Interfaces/IDownstreamAPIService.cs | 24 ++++ .../Services/API/DownstreamAPIService.cs | 110 ++++++++++++++++++ .../Gatekeeper/Services/GatekeeperService.cs | 69 +++++++++-- 3 files changed, 196 insertions(+), 7 deletions(-) diff --git a/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs b/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs index e08364ff82..69132b4014 100644 --- a/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs +++ b/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs @@ -19,5 +19,29 @@ public interface IDownstreamAPIService /// The completion request containing the user prompt and message history. /// The completion response. Task GetCompletion(string instanceId, CompletionRequest completionRequest); + + /// + /// Begins a completion operation. + /// + /// The FoundationaLLM instance id. + /// The completion request containing the user prompt and message history. + /// Returns an object containing the OperationId and Status. + Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest); + + /// + /// Gets the status of a completion operation. + /// + /// The FoundationaLLM instance id. + /// The OperationId for which to retrieve the status. + /// Returns an object containing the OperationId and Status. + Task GetCompletionOperationStatus(string instanceId, string operationId); + + /// + /// Gets a completion operation from the Orchestration service. + /// + /// The FoundationaLLM instance id. + /// The ID of the operation to retrieve. + /// Returns a completion response + Task GetCompletionOperationResult(string instanceId, string operationId); } } diff --git a/src/dotnet/Common/Services/API/DownstreamAPIService.cs b/src/dotnet/Common/Services/API/DownstreamAPIService.cs index f7c255f76a..4148a41d02 100644 --- a/src/dotnet/Common/Services/API/DownstreamAPIService.cs +++ b/src/dotnet/Common/Services/API/DownstreamAPIService.cs @@ -74,5 +74,115 @@ public async Task GetCompletion(string instanceId, Completio return fallback; } + /// + public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest) + { + var fallback = new LongRunningOperation + { + OperationId = completionRequest.OperationId, + Status = OperationStatus.Failed + }; + + var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity); + + _logger.LogInformation( + "Created Http client {ClientName} with timeout {Timeout} seconds.", + _downstreamHttpClientName, + (int)client.Timeout.TotalSeconds); + + var serializedRequest = JsonSerializer.Serialize(completionRequest, _jsonSerializerOptions); + var responseMessage = await client.PostAsync($"instances/{instanceId}/async-completions", + new StringContent( + serializedRequest, + Encoding.UTF8, "application/json")); + + _logger.LogInformation( + "Http client {ClientName} returned a response with status code {HttpStatusCode}.", + _downstreamHttpClientName, + responseMessage.StatusCode); + + if (responseMessage.IsSuccessStatusCode) + { + var responseContent = await responseMessage.Content.ReadAsStringAsync(); + var longRunningOperationResponse = JsonSerializer.Deserialize(responseContent); + + return longRunningOperationResponse ?? fallback; + } + + return fallback; + } + + /// + public async Task GetCompletionOperationStatus(string instanceId, string operationId) + { + var fallback = new LongRunningOperation + { + OperationId = operationId, + Status = OperationStatus.Failed + }; + + var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity); + + _logger.LogInformation( + "Created Http client {ClientName} with timeout {Timeout} seconds.", + _downstreamHttpClientName, + (int)client.Timeout.TotalSeconds); + + var responseMessage = await client.GetAsync($"instances/{instanceId}/async-completions/{operationId}/status"); + + _logger.LogInformation( + "Http client {ClientName} returned a response with status code {HttpStatusCode}.", + _downstreamHttpClientName, + responseMessage.StatusCode); + + if (responseMessage.IsSuccessStatusCode) + { + var responseContent = await responseMessage.Content.ReadAsStringAsync(); + var longRunningOperationResponse = JsonSerializer.Deserialize(responseContent); + + return longRunningOperationResponse ?? fallback; + } + + return fallback; + } + + /// + public async Task GetCompletionOperationResult(string instanceId, string operationId) + { + var fallback = new CompletionResponse + { + OperationId = operationId, + Completion = "A problem on my side prevented me from responding.", + UserPrompt = string.Empty, + PromptTokens = 0, + CompletionTokens = 0, + UserPromptEmbedding = [0f] + }; + + var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity); + + _logger.LogInformation( + "Created Http client {ClientName} with timeout {Timeout} seconds.", + _downstreamHttpClientName, + (int)client.Timeout.TotalSeconds); + + var responseMessage = await client.GetAsync($"instances/{instanceId}/async-completions/{operationId}/result"); + + _logger.LogInformation( + "Http client {ClientName} returned a response with status code {HttpStatusCode}.", + _downstreamHttpClientName, + responseMessage.StatusCode); + + if (responseMessage.IsSuccessStatusCode) + { + var responseContent = await responseMessage.Content.ReadAsStringAsync(); + var completionResponse = JsonSerializer.Deserialize(responseContent); + + return completionResponse ?? fallback; + } + + return fallback; + } + } } diff --git a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs index 77f7e726bf..ba7aa72f63 100644 --- a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs +++ b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs @@ -3,6 +3,7 @@ using FoundationaLLM.Common.Models.Orchestration; using FoundationaLLM.Gatekeeper.Core.Interfaces; using FoundationaLLM.Gatekeeper.Core.Models.ConfigurationOptions; +using FoundationaLLM.Gatekeeper.Core.Models.LakeraGuard; using Microsoft.Extensions.Options; namespace FoundationaLLM.Gatekeeper.Core.Services @@ -92,17 +93,71 @@ public async Task GetCompletion(string instanceId, Completio } /// - public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest) => - // TODO: Need to call State API to start the operation. - throw new NotImplementedException(); + public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest) + { + if (completionRequest.GatekeeperOptions != null && completionRequest.GatekeeperOptions.Length > 0) + { + _gatekeeperServiceSettings.EnableAzureContentSafety = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.AzureContentSafety); + _gatekeeperServiceSettings.EnableMicrosoftPresidio = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.MicrosoftPresidio); + _gatekeeperServiceSettings.EnableLakeraGuard = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.LakeraGuard); + _gatekeeperServiceSettings.EnableEnkryptGuardrails = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.EnkryptGuardrails); + _gatekeeperServiceSettings.EnableAzureContentSafetyPromptShield = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.AzureContentSafetyPromptShield); + } + + if (_gatekeeperServiceSettings.EnableLakeraGuard) + { + var promptInjectionResult = await _lakeraGuardService.DetectPromptInjection(completionRequest.UserPrompt!); + + if (!string.IsNullOrWhiteSpace(promptInjectionResult)) + return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed }; + } + + if (_gatekeeperServiceSettings.EnableEnkryptGuardrails) + { + var promptInjectionResult = await _enkryptGuardrailsService.DetectPromptInjection(completionRequest.UserPrompt!); + + if (!string.IsNullOrWhiteSpace(promptInjectionResult)) + return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed }; + } + + if (_gatekeeperServiceSettings.EnableAzureContentSafetyPromptShield) + { + var promptInjectionResult = await _contentSafetyService.DetectPromptInjection(completionRequest.UserPrompt!); + + if (!string.IsNullOrWhiteSpace(promptInjectionResult)) + return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed }; + } + + if (_gatekeeperServiceSettings.EnableAzureContentSafety) + { + var contentSafetyResult = await _contentSafetyService.AnalyzeText(completionRequest.UserPrompt!); + + if (!contentSafetyResult.Safe) + return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = contentSafetyResult.Reason, Status = OperationStatus.Failed }; + } + + var response = await _orchestrationAPIService.StartCompletionOperation(instanceId, completionRequest); + + return response; + } /// - public Task GetCompletionOperationStatus(string instanceId, string operationId) => throw new NotImplementedException(); + public async Task GetCompletionOperationStatus(string instanceId, string operationId) + { + var response = await _orchestrationAPIService.GetCompletionOperationStatus(instanceId, operationId); + + return response; + } /// - public async Task GetCompletionOperationResult(string instanceId, string operationId) => - // TODO: Need to call State API to get the operation. - throw new NotImplementedException(); + public async Task GetCompletionOperationResult(string instanceId, string operationId) + { + var completionResponse = await _orchestrationAPIService.GetCompletionOperationResult(instanceId, operationId); + + if (_gatekeeperServiceSettings.EnableMicrosoftPresidio) + completionResponse.Completion = await _gatekeeperIntegrationAPIService.AnonymizeText(completionResponse.Completion); + return completionResponse; + } } } From 8cc3c431a0d333419ab7789a0ddd07b4063b3e77 Mon Sep 17 00:00:00 2001 From: gabrielgheorghescu Date: Tue, 6 Aug 2024 11:45:45 +0300 Subject: [PATCH 2/3] Code clean --- src/dotnet/Gatekeeper/Services/GatekeeperService.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs index ba7aa72f63..44c2adecb1 100644 --- a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs +++ b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs @@ -3,7 +3,6 @@ using FoundationaLLM.Common.Models.Orchestration; using FoundationaLLM.Gatekeeper.Core.Interfaces; using FoundationaLLM.Gatekeeper.Core.Models.ConfigurationOptions; -using FoundationaLLM.Gatekeeper.Core.Models.LakeraGuard; using Microsoft.Extensions.Options; namespace FoundationaLLM.Gatekeeper.Core.Services From 5c92cc8747ba20bf834a388a55da1c37998cf9ec Mon Sep 17 00:00:00 2001 From: gabrielgheorghescu Date: Tue, 6 Aug 2024 15:58:34 +0300 Subject: [PATCH 3/3] Update to use expression body for returning method --- src/dotnet/Gatekeeper/Services/GatekeeperService.cs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs index 44c2adecb1..fab476cd80 100644 --- a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs +++ b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs @@ -141,12 +141,8 @@ public async Task StartCompletionOperation(string instance } /// - public async Task GetCompletionOperationStatus(string instanceId, string operationId) - { - var response = await _orchestrationAPIService.GetCompletionOperationStatus(instanceId, operationId); - - return response; - } + public async Task GetCompletionOperationStatus(string instanceId, string operationId) => + await _orchestrationAPIService.GetCompletionOperationStatus(instanceId, operationId); /// public async Task GetCompletionOperationResult(string instanceId, string operationId)