diff --git a/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs b/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs index e08364ff82..69132b4014 100644 --- a/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs +++ b/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs @@ -19,5 +19,29 @@ public interface IDownstreamAPIService /// The completion request containing the user prompt and message history. /// The completion response. Task GetCompletion(string instanceId, CompletionRequest completionRequest); + + /// + /// Begins a completion operation. + /// + /// The FoundationaLLM instance id. + /// The completion request containing the user prompt and message history. + /// Returns an object containing the OperationId and Status. + Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest); + + /// + /// Gets the status of a completion operation. + /// + /// The FoundationaLLM instance id. + /// The OperationId for which to retrieve the status. + /// Returns an object containing the OperationId and Status. + Task GetCompletionOperationStatus(string instanceId, string operationId); + + /// + /// Gets a completion operation from the Orchestration service. + /// + /// The FoundationaLLM instance id. + /// The ID of the operation to retrieve. + /// Returns a completion response + Task GetCompletionOperationResult(string instanceId, string operationId); } } diff --git a/src/dotnet/Common/Services/API/DownstreamAPIService.cs b/src/dotnet/Common/Services/API/DownstreamAPIService.cs index f7c255f76a..4148a41d02 100644 --- a/src/dotnet/Common/Services/API/DownstreamAPIService.cs +++ b/src/dotnet/Common/Services/API/DownstreamAPIService.cs @@ -74,5 +74,115 @@ public async Task GetCompletion(string instanceId, Completio return fallback; } + /// + public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest) + { + var fallback = new LongRunningOperation + { + OperationId = completionRequest.OperationId, + Status = OperationStatus.Failed + }; + + var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity); + + _logger.LogInformation( + "Created Http client {ClientName} with timeout {Timeout} seconds.", + _downstreamHttpClientName, + (int)client.Timeout.TotalSeconds); + + var serializedRequest = JsonSerializer.Serialize(completionRequest, _jsonSerializerOptions); + var responseMessage = await client.PostAsync($"instances/{instanceId}/async-completions", + new StringContent( + serializedRequest, + Encoding.UTF8, "application/json")); + + _logger.LogInformation( + "Http client {ClientName} returned a response with status code {HttpStatusCode}.", + _downstreamHttpClientName, + responseMessage.StatusCode); + + if (responseMessage.IsSuccessStatusCode) + { + var responseContent = await responseMessage.Content.ReadAsStringAsync(); + var longRunningOperationResponse = JsonSerializer.Deserialize(responseContent); + + return longRunningOperationResponse ?? fallback; + } + + return fallback; + } + + /// + public async Task GetCompletionOperationStatus(string instanceId, string operationId) + { + var fallback = new LongRunningOperation + { + OperationId = operationId, + Status = OperationStatus.Failed + }; + + var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity); + + _logger.LogInformation( + "Created Http client {ClientName} with timeout {Timeout} seconds.", + _downstreamHttpClientName, + (int)client.Timeout.TotalSeconds); + + var responseMessage = await client.GetAsync($"instances/{instanceId}/async-completions/{operationId}/status"); + + _logger.LogInformation( + "Http client {ClientName} returned a response with status code {HttpStatusCode}.", + _downstreamHttpClientName, + responseMessage.StatusCode); + + if (responseMessage.IsSuccessStatusCode) + { + var responseContent = await responseMessage.Content.ReadAsStringAsync(); + var longRunningOperationResponse = JsonSerializer.Deserialize(responseContent); + + return longRunningOperationResponse ?? fallback; + } + + return fallback; + } + + /// + public async Task GetCompletionOperationResult(string instanceId, string operationId) + { + var fallback = new CompletionResponse + { + OperationId = operationId, + Completion = "A problem on my side prevented me from responding.", + UserPrompt = string.Empty, + PromptTokens = 0, + CompletionTokens = 0, + UserPromptEmbedding = [0f] + }; + + var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity); + + _logger.LogInformation( + "Created Http client {ClientName} with timeout {Timeout} seconds.", + _downstreamHttpClientName, + (int)client.Timeout.TotalSeconds); + + var responseMessage = await client.GetAsync($"instances/{instanceId}/async-completions/{operationId}/result"); + + _logger.LogInformation( + "Http client {ClientName} returned a response with status code {HttpStatusCode}.", + _downstreamHttpClientName, + responseMessage.StatusCode); + + if (responseMessage.IsSuccessStatusCode) + { + var responseContent = await responseMessage.Content.ReadAsStringAsync(); + var completionResponse = JsonSerializer.Deserialize(responseContent); + + return completionResponse ?? fallback; + } + + return fallback; + } + } } diff --git a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs index 77f7e726bf..fab476cd80 100644 --- a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs +++ b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs @@ -92,17 +92,67 @@ public async Task GetCompletion(string instanceId, Completio } /// - public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest) => - // TODO: Need to call State API to start the operation. - throw new NotImplementedException(); + public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest) + { + if (completionRequest.GatekeeperOptions != null && completionRequest.GatekeeperOptions.Length > 0) + { + _gatekeeperServiceSettings.EnableAzureContentSafety = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.AzureContentSafety); + _gatekeeperServiceSettings.EnableMicrosoftPresidio = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.MicrosoftPresidio); + _gatekeeperServiceSettings.EnableLakeraGuard = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.LakeraGuard); + _gatekeeperServiceSettings.EnableEnkryptGuardrails = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.EnkryptGuardrails); + _gatekeeperServiceSettings.EnableAzureContentSafetyPromptShield = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.AzureContentSafetyPromptShield); + } + + if (_gatekeeperServiceSettings.EnableLakeraGuard) + { + var promptInjectionResult = await _lakeraGuardService.DetectPromptInjection(completionRequest.UserPrompt!); + + if (!string.IsNullOrWhiteSpace(promptInjectionResult)) + return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed }; + } + + if (_gatekeeperServiceSettings.EnableEnkryptGuardrails) + { + var promptInjectionResult = await _enkryptGuardrailsService.DetectPromptInjection(completionRequest.UserPrompt!); + + if (!string.IsNullOrWhiteSpace(promptInjectionResult)) + return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed }; + } + + if (_gatekeeperServiceSettings.EnableAzureContentSafetyPromptShield) + { + var promptInjectionResult = await _contentSafetyService.DetectPromptInjection(completionRequest.UserPrompt!); + + if (!string.IsNullOrWhiteSpace(promptInjectionResult)) + return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed }; + } + + if (_gatekeeperServiceSettings.EnableAzureContentSafety) + { + var contentSafetyResult = await _contentSafetyService.AnalyzeText(completionRequest.UserPrompt!); + + if (!contentSafetyResult.Safe) + return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = contentSafetyResult.Reason, Status = OperationStatus.Failed }; + } + + var response = await _orchestrationAPIService.StartCompletionOperation(instanceId, completionRequest); + + return response; + } /// - public Task GetCompletionOperationStatus(string instanceId, string operationId) => throw new NotImplementedException(); + public async Task GetCompletionOperationStatus(string instanceId, string operationId) => + await _orchestrationAPIService.GetCompletionOperationStatus(instanceId, operationId); /// - public async Task GetCompletionOperationResult(string instanceId, string operationId) => - // TODO: Need to call State API to get the operation. - throw new NotImplementedException(); + public async Task GetCompletionOperationResult(string instanceId, string operationId) + { + var completionResponse = await _orchestrationAPIService.GetCompletionOperationResult(instanceId, operationId); + if (_gatekeeperServiceSettings.EnableMicrosoftPresidio) + completionResponse.Completion = await _gatekeeperIntegrationAPIService.AnonymizeText(completionResponse.Completion); + + return completionResponse; + } } }