Skip to content

Commit

Permalink
use IHttpClientFactoryService for Azure Content Safety
Browse files Browse the repository at this point in the history
  • Loading branch information
alistar-andrei committed Jul 30, 2024
1 parent a1355ce commit 080a23a
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/dotnet/Gatekeeper/Services/AzureContentSafetyService.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using Azure;
using Azure.AI.ContentSafety;
using FoundationaLLM.Common.Authentication;
using FoundationaLLM.Common.Constants;
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Gatekeeper.Core.Interfaces;
Expand All @@ -11,7 +10,6 @@
using System.Text;
using System.Text.Json;


namespace FoundationaLLM.Gatekeeper.Core.Services
{
/// <summary>
Expand All @@ -21,7 +19,6 @@ public class AzureContentSafetyService : IContentSafetyService
{
private readonly ICallContext _callContext;
private readonly IHttpClientFactoryService _httpClientFactoryService;
private readonly ContentSafetyClient _client;
private readonly AzureContentSafetySettings _settings;
private readonly ILogger _logger;

Expand All @@ -43,52 +40,60 @@ public AzureContentSafetyService(
_httpClientFactoryService = httpClientFactoryService;
_settings = options.Value;
_logger = logger;

//TODO: Use IHttpClientFactoryService
//_client = new ContentSafetyClient(new Uri(_settings.APIUrl), DefaultAuthentication.AzureCredential);
}

/// <inheritdoc/>
public async Task<AnalyzeTextFilterResult> AnalyzeText(string content)
{
var request = new AnalyzeTextOptions(content);
var client = await _httpClientFactoryService.CreateClient(HttpClientNames.AzureContentSafety, _callContext.CurrentUserIdentity);

Response<AnalyzeTextResult> response;
Response<AnalyzeTextResult>? results = null;
try
{
response = await _client.AnalyzeTextAsync(request);
var response = await client.PostAsync("/contentsafety/text:analyze?api-version=2023-10-01",
new StringContent(JsonSerializer.Serialize(new AnalyzeTextOptions(content)),
Encoding.UTF8, "application/json"));

if (response.IsSuccessStatusCode)
{
var responseContent = await response.Content.ReadAsStringAsync();
results = JsonSerializer.Deserialize<Response<AnalyzeTextResult>>(responseContent);
}
}
catch (RequestFailedException ex)
{
_logger.LogError(ex, $"Analyze prompt text failed with status code: {ex.Status}, error code: {ex.ErrorCode}, message: {ex.Message}");
return new AnalyzeTextFilterResult { Safe = false, Reason = "The content safety service was unable to validate the prompt text due to an internal error." };
results = null;
}

if (results == null)
return new AnalyzeTextFilterResult { Safe = false, Reason = "The content safety service was unable to validate the prompt text due to an internal error." };

var safe = true;
var reason = "The prompt text did not pass the content safety filter. Reason:";

var hateSeverity = response.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Hate)?.Severity ?? 0;
var hateSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Hate)?.Severity ?? 0;
if (hateSeverity > _settings.HateSeverity)
{
reason += $" hate";
safe = false;
}

var violenceSeverity = response.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Violence)?.Severity ?? 0;
var violenceSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Violence)?.Severity ?? 0;
if (violenceSeverity > _settings.ViolenceSeverity)
{
reason += $" violence";
safe = false;
}

var selfHarmSeverity = response.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.SelfHarm)?.Severity ?? 0;
var selfHarmSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.SelfHarm)?.Severity ?? 0;
if (selfHarmSeverity > _settings.SelfHarmSeverity)
{
reason += $" self-harm";
safe = false;
}

var sexualSeverity = response.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Sexual)?.Severity ?? 0;
var sexualSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Sexual)?.Severity ?? 0;
if (sexualSeverity > _settings.SexualSeverity)
{
reason += $" sexual";
Expand Down Expand Up @@ -121,6 +126,7 @@ public async Task<AnalyzeTextFilterResult> AnalyzeText(string content)
return "The prompt text did not pass the safety filter. Reason: Prompt injection or jailbreak detected.";
}
}

return null;
}
}
Expand Down

0 comments on commit 080a23a

Please sign in to comment.