diff --git a/src/dotnet/Gatekeeper/Services/AzureContentSafetyService.cs b/src/dotnet/Gatekeeper/Services/AzureContentSafetyService.cs
index eced31d7d8..b6a02d7c56 100644
--- a/src/dotnet/Gatekeeper/Services/AzureContentSafetyService.cs
+++ b/src/dotnet/Gatekeeper/Services/AzureContentSafetyService.cs
@@ -1,6 +1,5 @@
 using Azure;
 using Azure.AI.ContentSafety;
-using FoundationaLLM.Common.Authentication;
 using FoundationaLLM.Common.Constants;
 using FoundationaLLM.Common.Interfaces;
 using FoundationaLLM.Gatekeeper.Core.Interfaces;
@@ -11,7 +10,6 @@
 using System.Text;
 using System.Text.Json;
-
 
 namespace FoundationaLLM.Gatekeeper.Core.Services
 {
     /// <summary>
@@ -21,7 +19,6 @@ public class AzureContentSafetyService : IContentSafetyService
     {
         private readonly ICallContext _callContext;
         private readonly IHttpClientFactoryService _httpClientFactoryService;
-        private readonly ContentSafetyClient _client;
         private readonly AzureContentSafetySettings _settings;
         private readonly ILogger _logger;
 
@@ -43,52 +40,60 @@ public AzureContentSafetyService(
             _httpClientFactoryService = httpClientFactoryService;
             _settings = options.Value;
             _logger = logger;
-
-            //TODO: Use IHttpClientFactoryService
-            //_client = new ContentSafetyClient(new Uri(_settings.APIUrl), DefaultAuthentication.AzureCredential);
         }
 
         /// <inheritdoc/>
         public async Task<AnalyzeTextFilterResult> AnalyzeText(string content)
         {
-            var request = new AnalyzeTextOptions(content);
+            var client = await _httpClientFactoryService.CreateClient(HttpClientNames.AzureContentSafety, _callContext.CurrentUserIdentity);
 
-            Response<AnalyzeTextResult> response;
+            Response<AnalyzeTextResult>? results = null;
             try
             {
-                response = await _client.AnalyzeTextAsync(request);
+                var response = await client.PostAsync("/contentsafety/text:analyze?api-version=2023-10-01",
+                    new StringContent(JsonSerializer.Serialize(new AnalyzeTextOptions(content)),
+                        Encoding.UTF8, "application/json"));
+
+                if (response.IsSuccessStatusCode)
+                {
+                    var responseContent = await response.Content.ReadAsStringAsync();
+                    results = JsonSerializer.Deserialize<Response<AnalyzeTextResult>>(responseContent);
+                }
             }
             catch (RequestFailedException ex)
             {
                 _logger.LogError(ex, $"Analyze prompt text failed with status code: {ex.Status}, error code: {ex.ErrorCode}, message: {ex.Message}");
-                return new AnalyzeTextFilterResult { Safe = false, Reason = "The content safety service was unable to validate the prompt text due to an internal error." };
+                results = null;
             }
 
+            if (results == null)
+                return new AnalyzeTextFilterResult { Safe = false, Reason = "The content safety service was unable to validate the prompt text due to an internal error." };
+
             var safe = true;
             var reason = "The prompt text did not pass the content safety filter. Reason:";
 
-            var hateSeverity = response.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Hate)?.Severity ?? 0;
+            var hateSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Hate)?.Severity ?? 0;
             if (hateSeverity > _settings.HateSeverity)
             {
                 reason += $" hate";
                 safe = false;
             }
 
-            var violenceSeverity = response.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Violence)?.Severity ?? 0;
+            var violenceSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Violence)?.Severity ?? 0;
             if (violenceSeverity > _settings.ViolenceSeverity)
             {
                 reason += $" violence";
                 safe = false;
             }
 
-            var selfHarmSeverity = response.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.SelfHarm)?.Severity ?? 0;
+            var selfHarmSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.SelfHarm)?.Severity ?? 0;
             if (selfHarmSeverity > _settings.SelfHarmSeverity)
            {
                 reason += $" self-harm";
                 safe = false;
             }
 
-            var sexualSeverity = response.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Sexual)?.Severity ?? 0;
+            var sexualSeverity = results.Value.CategoriesAnalysis.FirstOrDefault(a => a.Category == TextCategory.Sexual)?.Severity ?? 0;
             if (sexualSeverity > _settings.SexualSeverity)
             {
                 reason += $" sexual";
@@ -121,6 +126,7 @@ public async Task AnalyzeText(string content)
                     return "The prompt text did not pass the safety filter. Reason: Prompt injection or jailbreak detected.";
                 }
             }
+            return null;
         }
     }
 }
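
Note (not part of the diff above): the refactored AnalyzeText now posts to the Content Safety text:analyze REST endpoint through IHttpClientFactoryService and parses the JSON body itself instead of calling ContentSafetyClient. Below is a minimal, self-contained sketch of that exchange, assuming the 2023-10-01 API shape. The names ContentSafetyAnalyzeResponse, CategoryAnalysisItem, and ContentSafetySketch are hypothetical illustrations and do not appear in the PR, which serializes an AnalyzeTextOptions instance and deserializes into the Azure SDK's Response<AnalyzeTextResult> instead.

// Hypothetical sketch only -- the DTOs and class below are illustrative, not from the PR.
using System.Collections.Generic;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;

public sealed class CategoryAnalysisItem
{
    // e.g. "Hate", "Violence", "SelfHarm", "Sexual"
    public string Category { get; set; } = string.Empty;
    public int Severity { get; set; }
}

public sealed class ContentSafetyAnalyzeResponse
{
    public List<CategoryAnalysisItem> CategoriesAnalysis { get; set; } = new();
}

public static class ContentSafetySketch
{
    // Posts a prompt to text:analyze and returns the parsed per-category severities,
    // or null when the service responds with a non-success status code.
    public static async Task<ContentSafetyAnalyzeResponse?> AnalyzeAsync(HttpClient client, string content)
    {
        var payload = JsonSerializer.Serialize(new { text = content });
        var response = await client.PostAsync(
            "/contentsafety/text:analyze?api-version=2023-10-01",
            new StringContent(payload, Encoding.UTF8, "application/json"));

        if (!response.IsSuccessStatusCode)
            return null;

        // The response body is camelCase JSON, e.g.
        // {"categoriesAnalysis":[{"category":"Hate","severity":0}, ...]}
        var json = await response.Content.ReadAsStringAsync();
        return JsonSerializer.Deserialize<ContentSafetyAnalyzeResponse>(
            json, new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
    }
}

The HttpClient passed in is assumed to already have its BaseAddress set to the Content Safety resource endpoint and to carry authentication headers, mirroring what the PR obtains from IHttpClientFactoryService.CreateClient.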