Skip to content

Commit

Permalink
Merge pull request #1311 from solliancenet/aa-gatekeeper-api
Browse files Browse the repository at this point in the history
  • Loading branch information
ciprianjichici authored Aug 1, 2024
2 parents c52f0c2 + 64bda2f commit 827b071
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -335,5 +335,11 @@ public static partial class AppConfigurationKeyFilters
/// </summary>
public const string FoundationaLLM_Events_Profiles_VectorizationWorker =
"FoundationaLLM:Events:Profiles:VectorizationWorker:*";

/// <summary>
/// Filter for the configuration section used to identify the settings for the events infrastructure used by the Gatekeeper API.
/// </summary>
public const string FoundationaLLM_Events_Profiles_GatekeeperAPI =
"FoundationaLLM:Events:Profiles:GatekeeperAPI:*";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -335,5 +335,11 @@ public static partial class AppConfigurationKeySections
/// </summary>
public const string FoundationaLLM_Events_Profiles_VectorizationWorker =
"FoundationaLLM:Events:Profiles:VectorizationWorker";

/// <summary>
/// Configuration section used to identify the settings for the events infrastructure used by the Gatekeeper API.
/// </summary>
public const string FoundationaLLM_Events_Profiles_GatekeeperAPI =
"FoundationaLLM:Events:Profiles:GatekeeperAPI";
}
}
11 changes: 11 additions & 0 deletions src/dotnet/Common/Constants/Configuration/AppConfigurationKeys.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,13 @@ public static class AppConfigurationKeys
/// </summary>
public const string FoundationaLLM_Events_Profiles_VectorizationWorker =
"FoundationaLLM:Events:Profiles:VectorizationWorker";

/// <summary>
/// The app configuration key for the FoundationaLLM:Events:Profiles:GatekeeperAPI setting.
/// <para>Value description:<br/>The settings used by the Gatekeeper API to process Azure Event Grid events.</para>
/// </summary>
public const string FoundationaLLM_Events_Profiles_GatekeeperAPI =
"FoundationaLLM:Events:Profiles:GatekeeperAPI";

#endregion

Expand All @@ -1105,5 +1112,9 @@ public static class AppConfigurationKeys
#region FoundationaLLM:Events:Profiles:VectorizationWorker

#endregion

#region FoundationaLLM:Events:Profiles:GatekeeperAPI

#endregion
}
}
16 changes: 16 additions & 0 deletions src/dotnet/Common/Constants/Data/AppConfiguration.json
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,14 @@
"value": "${env:FOUNDATIONALLM_VECTORIZATION_WORKER_EVENT_GRID_PROFILE}",
"content_type": "application/json",
"first_version": "0.8.0"
},
{
"name": "GatekeeperAPI",
"description": "The settings used by the Gatekeeper API to process Azure Event Grid events.",
"secret": "",
"value": "${env:FOUNDATIONALLM_GATEKEEPER_API_EVENT_GRID_PROFILE}",
"content_type": "application/json",
"first_version": "0.8.0"
}
]
},
Expand Down Expand Up @@ -1522,5 +1530,13 @@
"description": "Configuration section used to identify the settings for the events infrastructure used by the Vectorization Worker service."
},
"configuration_keys": []
},
{
"namespace": "Events:Profiles:GatekeeperAPI",
"dependency_injection_key": null,
"configuration_section": {
"description": "Configuration section used to identify the settings for the events infrastructure used by the Gatekeeper API."
},
"configuration_keys": []
}
]
1 change: 0 additions & 1 deletion src/dotnet/Gatekeeper/Gatekeeper.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Logging" Version="8.0.0" />
<PackageReference Include="Azure.AI.ContentSafety" Version="1.0.0" />
</ItemGroup>

<ItemGroup>
Expand Down
54 changes: 54 additions & 0 deletions src/dotnet/Gatekeeper/Models/ContentSafety/AnalyzeTextResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System.Text.Json.Serialization;

namespace FoundationaLLM.Gatekeeper.Core.Models.ContentSafety
{
/// <summary>
/// Text analysis results.
/// </summary>
public class AnalyzeTextResult
{
/// <summary> Analysis result for categories. </summary>
[JsonPropertyName("categoriesAnalysis")]
public required List<TextCategoryResult> CategoriesAnalysis { get; set; }
}

/// <summary>
/// Text category results.
/// </summary>
public class TextCategoryResult
{
/// <summary> The text analysis category. </summary>
[JsonPropertyName("category")]
public required string Category { get; set; }

/// <summary> The value increases with the severity of the input content. The value of this field is determined by the output type specified in the request. The output type could be ‘FourSeverityLevels’ or ‘EightSeverity Levels’, and the output value can be 0, 2, 4, 6 or 0, 1, 2, 3, 4, 5, 6, or 7. </summary>
[JsonPropertyName("severity")]
public int? Severity { get; set; }
}

/// <summary>
/// Text category constants.
/// </summary>
public static class TextCategory
{
/// <summary>
/// Hate.
/// </summary>
public const string Hate = "Hate";

/// <summary>
/// Violence.
/// </summary>
public const string Violence = "Violence";

/// <summary>
/// SelfHarm.
/// </summary>
public const string SelfHarm = "SelfHarm";

/// <summary>
/// Sexual.
/// </summary>
public const string Sexual = "Sexual";
}
}
15 changes: 7 additions & 8 deletions src/dotnet/Gatekeeper/Services/AzureContentSafetyService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Azure;
using Azure.AI.ContentSafety;
using FoundationaLLM.Common.Constants;
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Gatekeeper.Core.Interfaces;
Expand Down Expand Up @@ -47,17 +46,17 @@ public async Task<AnalyzeTextFilterResult> AnalyzeText(string content)
{
var client = await _httpClientFactoryService.CreateClient(HttpClientNames.AzureContentSafety, _callContext.CurrentUserIdentity);

Response<AnalyzeTextResult>? results = null;
AnalyzeTextResult? results = null;
try
{
var response = await client.PostAsync("/contentsafety/text:analyze?api-version=2023-10-01",
new StringContent(JsonSerializer.Serialize(new AnalyzeTextOptions(content)),
new StringContent(JsonSerializer.Serialize(new { text = content }),
Encoding.UTF8, "application/json"));

if (response.IsSuccessStatusCode)
{
var responseContent = await response.Content.ReadAsStringAsync();
results = JsonSerializer.Deserialize<Response<AnalyzeTextResult>>(responseContent);
results = JsonSerializer.Deserialize<AnalyzeTextResult>(responseContent);
}
}
catch (RequestFailedException ex)
Expand All @@ -72,28 +71,28 @@ public async Task<AnalyzeTextFilterResult> AnalyzeText(string content)
var safe = true;
var reason = "The prompt text did not pass the content safety filter. Reason:";

var hateSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Hate)?.Severity ?? 0;
var hateSeverity = results.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Hate)?.Severity ?? 0;
if (hateSeverity > _settings.HateSeverity)
{
reason += $" hate";
safe = false;
}

var violenceSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Violence)?.Severity ?? 0;
var violenceSeverity = results.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Violence)?.Severity ?? 0;
if (violenceSeverity > _settings.ViolenceSeverity)
{
reason += $" violence";
safe = false;
}

var selfHarmSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.SelfHarm)?.Severity ?? 0;
var selfHarmSeverity = results.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.SelfHarm)?.Severity ?? 0;
if (selfHarmSeverity > _settings.SelfHarmSeverity)
{
reason += $" self-harm";
safe = false;
}

var sexualSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Sexual)?.Severity ?? 0;
var sexualSeverity = results.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Sexual)?.Severity ?? 0;
if (sexualSeverity > _settings.SexualSeverity)
{
reason += $" sexual";
Expand Down
4 changes: 2 additions & 2 deletions src/dotnet/Gatekeeper/Services/GatekeeperService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public async Task<CompletionResponse> GetCompletion(string instanceId, Completio
var promptInjectionResult = await _contentSafetyService.DetectPromptInjection(completionRequest.UserPrompt!);

if (!string.IsNullOrWhiteSpace(promptInjectionResult))
return new CompletionResponse() { OperationId=completionRequest.OperationId, Completion = promptInjectionResult };
return new CompletionResponse() { OperationId = completionRequest.OperationId, Completion = promptInjectionResult };
}

if (_gatekeeperServiceSettings.EnableAzureContentSafety)
Expand Down Expand Up @@ -103,6 +103,6 @@ public async Task<LongRunningOperation> StartCompletionOperation(string instance
public async Task<CompletionResponse> GetCompletionOperationResult(string instanceId, string operationId) =>
// TODO: Need to call State API to get the operation.
throw new NotImplementedException();

}
}
1 change: 1 addition & 0 deletions src/dotnet/GatekeeperAPI/GatekeeperAPI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

<ItemGroup>
<ProjectReference Include="..\Authorization\Authorization.csproj" />
<ProjectReference Include="..\Configuration\Configuration.csproj" />
<ProjectReference Include="..\Gatekeeper\Gatekeeper.csproj" />
</ItemGroup>

Expand Down
18 changes: 18 additions & 0 deletions src/dotnet/GatekeeperAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
using FoundationaLLM.Common.Models.Configuration.Instance;
using FoundationaLLM.Common.Models.Context;
using FoundationaLLM.Common.OpenAPI;
using FoundationaLLM.Common.Services.Azure;
using FoundationaLLM.Common.Services.Security;
using FoundationaLLM.Common.Validation;
using FoundationaLLM.Gatekeeper.Core.Interfaces;
using FoundationaLLM.Gatekeeper.Core.Models.ConfigurationOptions;
using FoundationaLLM.Gatekeeper.Core.Services;
Expand Down Expand Up @@ -46,6 +48,10 @@ public static void Main(string[] args)
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Instance);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Configuration);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIEndpoints_GatekeeperAPI_Configuration);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIEndpoints_GatekeeperAPI_Essentials);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_ResourceProviders_Configuration_Storage);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Events_Profiles_GatekeeperAPI);

//TODO: Replace this with a more granular approach that would only bring in the configuration namespaces that are actually needed.
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIEndpoints);
Expand Down Expand Up @@ -73,6 +79,18 @@ public static void Main(string[] args)
builder.Services.AddOptions<InstanceSettings>()
.Bind(builder.Configuration.GetSection(AppConfigurationKeySections.FoundationaLLM_Instance));

// Add Azure ARM services
builder.Services.AddAzureResourceManager();

// Add event services
builder.Services.AddAzureEventGridEvents(
builder.Configuration,
AppConfigurationKeySections.FoundationaLLM_Events_Profiles_GatekeeperAPI);

// Add resource providers
builder.Services.AddSingleton<IResourceValidatorFactory, ResourceValidatorFactory>();
builder.AddConfigurationResourceProvider();

// Register the downstream services and HTTP clients.
builder.AddHttpClientFactoryService();
builder.AddDownstreamAPIService(HttpClientNames.GatekeeperIntegrationAPI);
Expand Down

0 comments on commit 827b071

Please sign in to comment.