Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gatekeeper API stabilization #1311

Merged
merged 2 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -335,5 +335,11 @@ public static partial class AppConfigurationKeyFilters
/// </summary>
public const string FoundationaLLM_Events_Profiles_VectorizationWorker =
"FoundationaLLM:Events:Profiles:VectorizationWorker:*";

/// <summary>
/// Filter for the configuration section used to identify the settings for the events infrastructure used by the Gatekeeper API.
/// </summary>
public const string FoundationaLLM_Events_Profiles_GatekeeperAPI =
"FoundationaLLM:Events:Profiles:GatekeeperAPI:*";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -335,5 +335,11 @@ public static partial class AppConfigurationKeySections
/// </summary>
public const string FoundationaLLM_Events_Profiles_VectorizationWorker =
"FoundationaLLM:Events:Profiles:VectorizationWorker";

/// <summary>
/// Configuration section used to identify the settings for the events infrastructure used by the Gatekeeper API.
/// </summary>
public const string FoundationaLLM_Events_Profiles_GatekeeperAPI =
"FoundationaLLM:Events:Profiles:GatekeeperAPI";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,13 @@ public static class AppConfigurationKeys
/// </summary>
public const string FoundationaLLM_Events_Profiles_VectorizationWorker =
"FoundationaLLM:Events:Profiles:VectorizationWorker";

/// <summary>
/// The app configuration key for the FoundationaLLM:Events:Profiles:GatekeeperAPI setting.
/// <para>Value description:<br/>The settings used by the Gatekeeper API to process Azure Event Grid events.</para>
/// </summary>
public const string FoundationaLLM_Events_Profiles_GatekeeperAPI =
"FoundationaLLM:Events:Profiles:GatekeeperAPI";

#endregion

Expand All @@ -1105,5 +1112,9 @@ public static class AppConfigurationKeys
#region FoundationaLLM:Events:Profiles:VectorizationWorker

#endregion

#region FoundationaLLM:Events:Profiles:GatekeeperAPI

#endregion
}
}
16 changes: 16 additions & 0 deletions src/dotnet/Common/Constants/Data/AppConfiguration.json
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,14 @@
"value": "${env:FOUNDATIONALLM_VECTORIZATION_WORKER_EVENT_GRID_PROFILE}",
"content_type": "application/json",
"first_version": "0.8.0"
},
{
"name": "GatekeeperAPI",
"description": "The settings used by the Gatekeeper API to process Azure Event Grid events.",
"secret": "",
"value": "${env:FOUNDATIONALLM_GATEKEEPER_API_EVENT_GRID_PROFILE}",
"content_type": "application/json",
"first_version": "0.8.0"
}
]
},
Expand Down Expand Up @@ -1522,5 +1530,13 @@
"description": "Configuration section used to identify the settings for the events infrastructure used by the Vectorization Worker service."
},
"configuration_keys": []
},
{
"namespace": "Events:Profiles:GatekeeperAPI",
"dependency_injection_key": null,
"configuration_section": {
"description": "Configuration section used to identify the settings for the events infrastructure used by the Gatekeeper API."
},
"configuration_keys": []
}
]
1 change: 0 additions & 1 deletion src/dotnet/Gatekeeper/Gatekeeper.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Logging" Version="8.0.0" />
<PackageReference Include="Azure.AI.ContentSafety" Version="1.0.0" />
</ItemGroup>

<ItemGroup>
Expand Down
54 changes: 54 additions & 0 deletions src/dotnet/Gatekeeper/Models/ContentSafety/AnalyzeTextResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System.Text.Json.Serialization;

namespace FoundationaLLM.Gatekeeper.Core.Models.ContentSafety
{
/// <summary>
/// Text analysis results.
/// </summary>
public class AnalyzeTextResult
{
/// <summary> Analysis result for categories. </summary>
[JsonPropertyName("categoriesAnalysis")]
public required List<TextCategoryResult> CategoriesAnalysis { get; set; }
}

/// <summary>
/// Text category results.
/// </summary>
public class TextCategoryResult
{
/// <summary> The text analysis category. </summary>
[JsonPropertyName("category")]
public required string Category { get; set; }

/// <summary> The value increases with the severity of the input content. The value of this field is determined by the output type specified in the request. The output type could be ‘FourSeverityLevels’ or ‘EightSeverity Levels’, and the output value can be 0, 2, 4, 6 or 0, 1, 2, 3, 4, 5, 6, or 7. </summary>
[JsonPropertyName("severity")]
public int? Severity { get; set; }
}

/// <summary>
/// Text category constants.
/// </summary>
public static class TextCategory
{
/// <summary>
/// Hate.
/// </summary>
public const string Hate = "Hate";

/// <summary>
/// Violence.
/// </summary>
public const string Violence = "Violence";

/// <summary>
/// SelfHarm.
/// </summary>
public const string SelfHarm = "SelfHarm";

/// <summary>
/// Sexual.
/// </summary>
public const string Sexual = "Sexual";
}
}
15 changes: 7 additions & 8 deletions src/dotnet/Gatekeeper/Services/AzureContentSafetyService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Azure;
using Azure.AI.ContentSafety;
using FoundationaLLM.Common.Constants;
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Gatekeeper.Core.Interfaces;
Expand Down Expand Up @@ -47,17 +46,17 @@ public async Task<AnalyzeTextFilterResult> AnalyzeText(string content)
{
var client = await _httpClientFactoryService.CreateClient(HttpClientNames.AzureContentSafety, _callContext.CurrentUserIdentity);

Response<AnalyzeTextResult>? results = null;
AnalyzeTextResult? results = null;
try
{
var response = await client.PostAsync("/contentsafety/text:analyze?api-version=2023-10-01",
new StringContent(JsonSerializer.Serialize(new AnalyzeTextOptions(content)),
new StringContent(JsonSerializer.Serialize(new { text = content }),
Encoding.UTF8, "application/json"));

if (response.IsSuccessStatusCode)
{
var responseContent = await response.Content.ReadAsStringAsync();
results = JsonSerializer.Deserialize<Response<AnalyzeTextResult>>(responseContent);
results = JsonSerializer.Deserialize<AnalyzeTextResult>(responseContent);
}
}
catch (RequestFailedException ex)
Expand All @@ -72,28 +71,28 @@ public async Task<AnalyzeTextFilterResult> AnalyzeText(string content)
var safe = true;
var reason = "The prompt text did not pass the content safety filter. Reason:";

var hateSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Hate)?.Severity ?? 0;
var hateSeverity = results.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Hate)?.Severity ?? 0;
if (hateSeverity > _settings.HateSeverity)
{
reason += $" hate";
safe = false;
}

var violenceSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Violence)?.Severity ?? 0;
var violenceSeverity = results.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Violence)?.Severity ?? 0;
if (violenceSeverity > _settings.ViolenceSeverity)
{
reason += $" violence";
safe = false;
}

var selfHarmSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.SelfHarm)?.Severity ?? 0;
var selfHarmSeverity = results.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.SelfHarm)?.Severity ?? 0;
if (selfHarmSeverity > _settings.SelfHarmSeverity)
{
reason += $" self-harm";
safe = false;
}

var sexualSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Sexual)?.Severity ?? 0;
var sexualSeverity = results.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Sexual)?.Severity ?? 0;
if (sexualSeverity > _settings.SexualSeverity)
{
reason += $" sexual";
Expand Down
4 changes: 2 additions & 2 deletions src/dotnet/Gatekeeper/Services/GatekeeperService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public async Task<CompletionResponse> GetCompletion(string instanceId, Completio
var promptInjectionResult = await _contentSafetyService.DetectPromptInjection(completionRequest.UserPrompt!);

if (!string.IsNullOrWhiteSpace(promptInjectionResult))
return new CompletionResponse() { OperationId=completionRequest.OperationId, Completion = promptInjectionResult };
return new CompletionResponse() { OperationId = completionRequest.OperationId, Completion = promptInjectionResult };
}

if (_gatekeeperServiceSettings.EnableAzureContentSafety)
Expand Down Expand Up @@ -103,6 +103,6 @@ public async Task<LongRunningOperation> StartCompletionOperation(string instance
public async Task<CompletionResponse> GetCompletionOperationResult(string instanceId, string operationId) =>
// TODO: Need to call State API to get the operation.
throw new NotImplementedException();

}
}
1 change: 1 addition & 0 deletions src/dotnet/GatekeeperAPI/GatekeeperAPI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

<ItemGroup>
<ProjectReference Include="..\Authorization\Authorization.csproj" />
<ProjectReference Include="..\Configuration\Configuration.csproj" />
<ProjectReference Include="..\Gatekeeper\Gatekeeper.csproj" />
</ItemGroup>

Expand Down
18 changes: 18 additions & 0 deletions src/dotnet/GatekeeperAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
using FoundationaLLM.Common.Models.Configuration.Instance;
using FoundationaLLM.Common.Models.Context;
using FoundationaLLM.Common.OpenAPI;
using FoundationaLLM.Common.Services.Azure;
using FoundationaLLM.Common.Services.Security;
using FoundationaLLM.Common.Validation;
using FoundationaLLM.Gatekeeper.Core.Interfaces;
using FoundationaLLM.Gatekeeper.Core.Models.ConfigurationOptions;
using FoundationaLLM.Gatekeeper.Core.Services;
Expand Down Expand Up @@ -46,6 +48,10 @@ public static void Main(string[] args)
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Instance);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Configuration);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIEndpoints_GatekeeperAPI_Configuration);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIEndpoints_GatekeeperAPI_Essentials);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_ResourceProviders_Configuration_Storage);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Events_Profiles_GatekeeperAPI);

//TODO: Replace this with a more granular approach that would only bring in the configuration namespaces that are actually needed.
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIEndpoints);
Expand Down Expand Up @@ -73,6 +79,18 @@ public static void Main(string[] args)
builder.Services.AddOptions<InstanceSettings>()
.Bind(builder.Configuration.GetSection(AppConfigurationKeySections.FoundationaLLM_Instance));

// Add Azure ARM services
builder.Services.AddAzureResourceManager();

// Add event services
builder.Services.AddAzureEventGridEvents(
builder.Configuration,
AppConfigurationKeySections.FoundationaLLM_Events_Profiles_GatekeeperAPI);

// Add resource providers
builder.Services.AddSingleton<IResourceValidatorFactory, ResourceValidatorFactory>();
builder.AddConfigurationResourceProvider();

// Register the downstream services and HTTP clients.
builder.AddHttpClientFactoryService();
builder.AddDownstreamAPIService(HttpClientNames.GatekeeperIntegrationAPI);
Expand Down
Loading