Skip to content

Commit

Permalink
Merge pull request #1358 from solliancenet/gg-gatekeeper-completions-080
Browse files Browse the repository at this point in the history
Implement async-completions controller methods for long running agents
  • Loading branch information
ciprianjichici authored Aug 7, 2024
2 parents 368d074 + 5c92cc8 commit 62c2e2f
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 7 deletions.
24 changes: 24 additions & 0 deletions src/dotnet/Common/Interfaces/IDownstreamAPIService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,29 @@ public interface IDownstreamAPIService
/// <param name="completionRequest">The completion request containing the user prompt and message history.</param>
/// <returns>The completion response.</returns>
Task<CompletionResponse> GetCompletion(string instanceId, CompletionRequest completionRequest);

/// <summary>
/// Begins a completion operation.
/// </summary>
/// <param name="instanceId">The FoundationaLLM instance id.</param>
/// <param name="completionRequest">The completion request containing the user prompt and message history.</param>
/// <returns>Returns an <see cref="LongRunningOperation"/> object containing the OperationId and Status.</returns>
Task<LongRunningOperation> StartCompletionOperation(string instanceId, CompletionRequest completionRequest);

/// <summary>
/// Gets the status of a completion operation.
/// </summary>
/// <param name="instanceId">The FoundationaLLM instance id.</param>
/// <param name="operationId">The OperationId for which to retrieve the status.</param>
/// <returns>Returns an <see cref="LongRunningOperation"/> object containing the OperationId and Status.</returns>
Task<LongRunningOperation> GetCompletionOperationStatus(string instanceId, string operationId);

/// <summary>
/// Gets a completion operation from the Orchestration service.
/// </summary>
/// <param name="instanceId">The FoundationaLLM instance id.</param>
/// <param name="operationId">The ID of the operation to retrieve.</param>
/// <returns>Returns a completion response</returns>
Task<CompletionResponse> GetCompletionOperationResult(string instanceId, string operationId);
}
}
110 changes: 110 additions & 0 deletions src/dotnet/Common/Services/API/DownstreamAPIService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,115 @@ public async Task<CompletionResponse> GetCompletion(string instanceId, Completio
return fallback;
}

/// <inheritdoc/>
public async Task<LongRunningOperation> StartCompletionOperation(string instanceId, CompletionRequest completionRequest)
{
var fallback = new LongRunningOperation
{
OperationId = completionRequest.OperationId,
Status = OperationStatus.Failed
};

var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity);

_logger.LogInformation(
"Created Http client {ClientName} with timeout {Timeout} seconds.",
_downstreamHttpClientName,
(int)client.Timeout.TotalSeconds);

var serializedRequest = JsonSerializer.Serialize(completionRequest, _jsonSerializerOptions);
var responseMessage = await client.PostAsync($"instances/{instanceId}/async-completions",
new StringContent(
serializedRequest,
Encoding.UTF8, "application/json"));

_logger.LogInformation(
"Http client {ClientName} returned a response with status code {HttpStatusCode}.",
_downstreamHttpClientName,
responseMessage.StatusCode);

if (responseMessage.IsSuccessStatusCode)
{
var responseContent = await responseMessage.Content.ReadAsStringAsync();
var longRunningOperationResponse = JsonSerializer.Deserialize<LongRunningOperation>(responseContent);

return longRunningOperationResponse ?? fallback;
}

return fallback;
}

/// <inheritdoc/>
public async Task<LongRunningOperation> GetCompletionOperationStatus(string instanceId, string operationId)
{
var fallback = new LongRunningOperation
{
OperationId = operationId,
Status = OperationStatus.Failed
};

var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity);

_logger.LogInformation(
"Created Http client {ClientName} with timeout {Timeout} seconds.",
_downstreamHttpClientName,
(int)client.Timeout.TotalSeconds);

var responseMessage = await client.GetAsync($"instances/{instanceId}/async-completions/{operationId}/status");

_logger.LogInformation(
"Http client {ClientName} returned a response with status code {HttpStatusCode}.",
_downstreamHttpClientName,
responseMessage.StatusCode);

if (responseMessage.IsSuccessStatusCode)
{
var responseContent = await responseMessage.Content.ReadAsStringAsync();
var longRunningOperationResponse = JsonSerializer.Deserialize<LongRunningOperation>(responseContent);

return longRunningOperationResponse ?? fallback;
}

return fallback;
}

/// <inheritdoc/>
public async Task<CompletionResponse> GetCompletionOperationResult(string instanceId, string operationId)
{
var fallback = new CompletionResponse
{
OperationId = operationId,
Completion = "A problem on my side prevented me from responding.",
UserPrompt = string.Empty,
PromptTokens = 0,
CompletionTokens = 0,
UserPromptEmbedding = [0f]
};

var client = await _httpClientFactoryService.CreateClient(_downstreamHttpClientName, _callContext.CurrentUserIdentity);

_logger.LogInformation(
"Created Http client {ClientName} with timeout {Timeout} seconds.",
_downstreamHttpClientName,
(int)client.Timeout.TotalSeconds);

var responseMessage = await client.GetAsync($"instances/{instanceId}/async-completions/{operationId}/result");

_logger.LogInformation(
"Http client {ClientName} returned a response with status code {HttpStatusCode}.",
_downstreamHttpClientName,
responseMessage.StatusCode);

if (responseMessage.IsSuccessStatusCode)
{
var responseContent = await responseMessage.Content.ReadAsStringAsync();
var completionResponse = JsonSerializer.Deserialize<CompletionResponse>(responseContent);

return completionResponse ?? fallback;
}

return fallback;
}

}
}
64 changes: 57 additions & 7 deletions src/dotnet/Gatekeeper/Services/GatekeeperService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,67 @@ public async Task<CompletionResponse> GetCompletion(string instanceId, Completio
}

/// <inheritdoc/>
public async Task<LongRunningOperation> StartCompletionOperation(string instanceId, CompletionRequest completionRequest) =>
// TODO: Need to call State API to start the operation.
throw new NotImplementedException();
public async Task<LongRunningOperation> StartCompletionOperation(string instanceId, CompletionRequest completionRequest)
{
if (completionRequest.GatekeeperOptions != null && completionRequest.GatekeeperOptions.Length > 0)
{
_gatekeeperServiceSettings.EnableAzureContentSafety = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.AzureContentSafety);
_gatekeeperServiceSettings.EnableMicrosoftPresidio = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.MicrosoftPresidio);
_gatekeeperServiceSettings.EnableLakeraGuard = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.LakeraGuard);
_gatekeeperServiceSettings.EnableEnkryptGuardrails = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.EnkryptGuardrails);
_gatekeeperServiceSettings.EnableAzureContentSafetyPromptShield = completionRequest.GatekeeperOptions.Any(x => x == GatekeeperOptionNames.AzureContentSafetyPromptShield);
}

if (_gatekeeperServiceSettings.EnableLakeraGuard)
{
var promptInjectionResult = await _lakeraGuardService.DetectPromptInjection(completionRequest.UserPrompt!);

if (!string.IsNullOrWhiteSpace(promptInjectionResult))
return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed };
}

if (_gatekeeperServiceSettings.EnableEnkryptGuardrails)
{
var promptInjectionResult = await _enkryptGuardrailsService.DetectPromptInjection(completionRequest.UserPrompt!);

if (!string.IsNullOrWhiteSpace(promptInjectionResult))
return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed };
}

if (_gatekeeperServiceSettings.EnableAzureContentSafetyPromptShield)
{
var promptInjectionResult = await _contentSafetyService.DetectPromptInjection(completionRequest.UserPrompt!);

if (!string.IsNullOrWhiteSpace(promptInjectionResult))
return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = promptInjectionResult, Status = OperationStatus.Failed };
}

if (_gatekeeperServiceSettings.EnableAzureContentSafety)
{
var contentSafetyResult = await _contentSafetyService.AnalyzeText(completionRequest.UserPrompt!);

if (!contentSafetyResult.Safe)
return new LongRunningOperation() { OperationId = completionRequest.OperationId, StatusMessage = contentSafetyResult.Reason, Status = OperationStatus.Failed };
}

var response = await _orchestrationAPIService.StartCompletionOperation(instanceId, completionRequest);

return response;
}

/// <inheritdoc/>
public Task<LongRunningOperation> GetCompletionOperationStatus(string instanceId, string operationId) => throw new NotImplementedException();
public async Task<LongRunningOperation> GetCompletionOperationStatus(string instanceId, string operationId) =>
await _orchestrationAPIService.GetCompletionOperationStatus(instanceId, operationId);

/// <inheritdoc/>
public async Task<CompletionResponse> GetCompletionOperationResult(string instanceId, string operationId) =>
// TODO: Need to call State API to get the operation.
throw new NotImplementedException();
public async Task<CompletionResponse> GetCompletionOperationResult(string instanceId, string operationId)
{
var completionResponse = await _orchestrationAPIService.GetCompletionOperationResult(instanceId, operationId);

if (_gatekeeperServiceSettings.EnableMicrosoftPresidio)
completionResponse.Completion = await _gatekeeperIntegrationAPIService.AnonymizeText(completionResponse.Completion);

return completionResponse;
}
}
}

0 comments on commit 62c2e2f

Please sign in to comment.