diff --git a/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs b/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs
index e08364ff82..69132b4014 100644
--- a/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs
+++ b/src/dotnet/Common/Interfaces/IDownstreamAPIService.cs
@@ -19,5 +19,29 @@ public interface IDownstreamAPIService
/// The completion request containing the user prompt and message history.
/// The completion response.
Task GetCompletion(string instanceId, CompletionRequest completionRequest);
+
+ ///
+ /// Begins a completion operation.
+ ///
+ /// The FoundationaLLM instance id.
+ /// The completion request containing the user prompt and message history.
+ /// Returns an object containing the OperationId and Status.
+ Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest);
+
+ ///
+ /// Gets the status of a completion operation.
+ ///
+ /// The FoundationaLLM instance id.
+ /// The OperationId for which to retrieve the status.
+ /// Returns an object containing the OperationId and Status.
+ Task GetCompletionOperationStatus(string instanceId, string operationId);
+
+ ///
+ /// Gets a completion operation from the Orchestration service.
+ ///
+ /// The FoundationaLLM instance id.
+ /// The ID of the operation to retrieve.
+ /// Returns a completion response
+ Task GetCompletionOperationResult(string instanceId, string operationId);
}
}
diff --git a/src/dotnet/Common/Services/API/DownstreamAPIService.cs b/src/dotnet/Common/Services/API/DownstreamAPIService.cs
index f7c255f76a..4148a41d02 100644
--- a/src/dotnet/Common/Services/API/DownstreamAPIService.cs
+++ b/src/dotnet/Common/Services/API/DownstreamAPIService.cs
@@ -74,5 +74,115 @@ public async Task GetCompletion(string instanceId, Completio
return fallback;
}
+ ///
+ public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest)
+ {
+ var fallback = new LongRunningOperation
+ {
+ OperationId = completionRequest.OperationId,
+ Status = OperationStatus.Failed
+ };
+
+ var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity);
+
+ _logger.LogInformation(
+ "Created Http client {ClientName} with timeout {Timeout} seconds.",
+ _downstreamHttpClientName,
+ (int)client.Timeout.TotalSeconds);
+
+ var serializedRequest = JsonSerializer.Serialize(completionRequest, _jsonSerializerOptions);
+ var responseMessage = await client.PostAsync($"instances/{instanceId}/async-completions",
+ new StringContent(
+ serializedRequest,
+ Encoding.UTF8, "application/json"));
+
+ _logger.LogInformation(
+ "Http client {ClientName} returned a response with status code {HttpStatusCode}.",
+ _downstreamHttpClientName,
+ responseMessage.StatusCode);
+
+ if (responseMessage.IsSuccessStatusCode)
+ {
+ var responseContent = await responseMessage.Content.ReadAsStringAsync();
+ var longRunningOperationResponse = JsonSerializer.Deserialize(responseContent);
+
+ return longRunningOperationResponse ?? fallback;
+ }
+
+ return fallback;
+ }
+
+ ///
+ public async Task GetCompletionOperationStatus(string instanceId, string operationId)
+ {
+ var fallback = new LongRunningOperation
+ {
+ OperationId = operationId,
+ Status = OperationStatus.Failed
+ };
+
+ var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity);
+
+ _logger.LogInformation(
+ "Created Http client {ClientName} with timeout {Timeout} seconds.",
+ _downstreamHttpClientName,
+ (int)client.Timeout.TotalSeconds);
+
+ var responseMessage = await client.GetAsync($"instances/{instanceId}/async-completions/{operationId}/status");
+
+ _logger.LogInformation(
+ "Http client {ClientName} returned a response with status code {HttpStatusCode}.",
+ _downstreamHttpClientName,
+ responseMessage.StatusCode);
+
+ if (responseMessage.IsSuccessStatusCode)
+ {
+ var responseContent = await responseMessage.Content.ReadAsStringAsync();
+ var longRunningOperationResponse = JsonSerializer.Deserialize(responseContent);
+
+ return longRunningOperationResponse ?? fallback;
+ }
+
+ return fallback;
+ }
+
+ ///
+ public async Task GetCompletionOperationResult(string instanceId, string operationId)
+ {
+ var fallback = new CompletionResponse
+ {
+ OperationId = operationId,
+ Completion = "A problem on my side prevented me from responding.",
+ UserPrompt = string.Empty,
+ PromptTokens = 0,
+ CompletionTokens = 0,
+ UserPromptEmbedding = [0f]
+ };
+
+ var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity);
+
+ _logger.LogInformation(
+ "Created Http client {ClientName} with timeout {Timeout} seconds.",
+ _downstreamHttpClientName,
+ (int)client.Timeout.TotalSeconds);
+
+ var responseMessage = await client.GetAsync($"instances/{instanceId}/async-completions/{operationId}/result");
+
+ _logger.LogInformation(
+ "Http client {ClientName} returned a response with status code {HttpStatusCode}.",
+ _downstreamHttpClientName,
+ responseMessage.StatusCode);
+
+ if (responseMessage.IsSuccessStatusCode)
+ {
+ var responseContent = await responseMessage.Content.ReadAsStringAsync();
+ var completionResponse = JsonSerializer.Deserialize(responseContent);
+
+ return completionResponse ?? fallback;
+ }
+
+ return fallback;
+ }
+
}
}
diff --git a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs
index 77f7e726bf..fab476cd80 100644
--- a/src/dotnet/Gatekeeper/Services/GatekeeperService.cs
+++ b/src/dotnet/Gatekeeper/Services/GatekeeperService.cs
@@ -92,17 +92,67 @@ public async Task GetCompletion(string instanceId, Completio
}
///
- public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest) =>
- // TODO: Need to call State API to start the operation.
- throw new NotImplementedException();
+ public async Task StartCompletionOperation(string instanceId, CompletionRequest completionRequest)
+ {
+ if (completionRequest.GatekeeperOptions != null && completionRequest.GatekeeperOptions.Length > 0)
+ {
+ _gatekeeperServiceSettings.EnableAzureContentSafety = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.AzureContentSafety);
+ _gatekeeperServiceSettings.EnableMicrosoftPresidio = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.MicrosoftPresidio);
+ _gatekeeperServiceSettings.EnableLakeraGuard = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.LakeraGuard);
+ _gatekeeperServiceSettings.EnableEnkryptGuardrails = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.EnkryptGuardrails);
+ _gatekeeperServiceSettings.EnableAzureContentSafetyPromptShield = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.AzureContentSafetyPromptShield);
+ }
+
+ if (_gatekeeperServiceSettings.EnableLakeraGuard)
+ {
+ var promptInjectionResult = await _lakeraGuardService.DetectPromptInjection(completionRequest.UserPrompt!);
+
+ if (!string.IsNullOrWhiteSpace(promptInjectionResult))
+ return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed };
+ }
+
+ if (_gatekeeperServiceSettings.EnableEnkryptGuardrails)
+ {
+ var promptInjectionResult = await _enkryptGuardrailsService.DetectPromptInjection(completionRequest.UserPrompt!);
+
+ if (!string.IsNullOrWhiteSpace(promptInjectionResult))
+ return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed };
+ }
+
+ if (_gatekeeperServiceSettings.EnableAzureContentSafetyPromptShield)
+ {
+ var promptInjectionResult = await _contentSafetyService.DetectPromptInjection(completionRequest.UserPrompt!);
+
+ if (!string.IsNullOrWhiteSpace(promptInjectionResult))
+ return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed };
+ }
+
+ if (_gatekeeperServiceSettings.EnableAzureContentSafety)
+ {
+ var contentSafetyResult = await _contentSafetyService.AnalyzeText(completionRequest.UserPrompt!);
+
+ if (!contentSafetyResult.Safe)
+ return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = contentSafetyResult.Reason, Status = OperationStatus.Failed };
+ }
+
+ var response = await _orchestrationAPIService.StartCompletionOperation(instanceId, completionRequest);
+
+ return response;
+ }
///
- public Task GetCompletionOperationStatus(string instanceId, string operationId) => throw new NotImplementedException();
+ public async Task GetCompletionOperationStatus(string instanceId, string operationId) =>
+ await _orchestrationAPIService.GetCompletionOperationStatus(instanceId, operationId);
///
- public async Task GetCompletionOperationResult(string instanceId, string operationId) =>
- // TODO: Need to call State API to get the operation.
- throw new NotImplementedException();
+ public async Task GetCompletionOperationResult(string instanceId, string operationId)
+ {
+ var completionResponse = await _orchestrationAPIService.GetCompletionOperationResult(instanceId, operationId);
+ if (_gatekeeperServiceSettings.EnableMicrosoftPresidio)
+ completionResponse.Completion = await _gatekeeperIntegrationAPIService.AnonymizeText(completionResponse.Completion);
+
+ return completionResponse;
+ }
}
}