Skip to content

Commit

Permalink
Merge pull request #678 from solliancenet/cj-fix-aks-managed-identity…
Browse files Browse the repository at this point in the history
…-auth

Fix AKS managed identity auth
  • Loading branch information
ciprianjichici authored Feb 27, 2024
2 parents 6a0aa7b + 34a80c1 commit 5bb59c0
Show file tree
Hide file tree
Showing 20 changed files with 57 additions and 38 deletions.
4 changes: 3 additions & 1 deletion src/dotnet/AgentFactoryAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static void Main(string[] args)
options.Connect(builder.Configuration[EnvironmentVariables.FoundationaLLM_AppConfig_ConnectionString]);
options.ConfigureKeyVault(options =>
{
options.SetCredential(new DefaultAzureCredential());
options.SetCredential(DefaultAuthentication.GetAzureCredential());
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIs);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_AgentFactory);
Expand All @@ -61,6 +61,8 @@ public static void Main(string[] args)
if (builder.Environment.IsDevelopment())
builder.Configuration.AddJsonFile("appsettings.development.json", true, true);

DefaultAuthentication.Production = builder.Environment.IsProduction();

// Add services to the container.
// Add the OpenTelemetry telemetry service and send telemetry data to Azure Monitor.
builder.Services.AddOpenTelemetry().UseAzureMonitor(options =>
Expand Down
24 changes: 10 additions & 14 deletions src/dotnet/Common/Authentication/DefaultAuthentication.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Azure.Core;
using Azure.Identity;
using FoundationaLLM.Common.Constants;

namespace FoundationaLLM.Common.Authentication
{
Expand All @@ -9,21 +10,16 @@ namespace FoundationaLLM.Common.Authentication
public static class DefaultAuthentication
{
/// <summary>
/// The default Azure credential to use for authentication.
/// Indicates whether the environment we run in is production or not.
/// </summary>
public static TokenCredential GetAzureCredential(bool development = false) => new DefaultAzureCredential(new DefaultAzureCredentialOptions
{
ExcludeAzureDeveloperCliCredential = true,
ExcludeAzurePowerShellCredential = true,
ExcludeEnvironmentCredential = true,
ExcludeInteractiveBrowserCredential = true,
ExcludeSharedTokenCacheCredential = true,
ExcludeVisualStudioCodeCredential = true,
ExcludeVisualStudioCredential = true,
ExcludeWorkloadIdentityCredential = true,
public static bool Production { get; set; }

ExcludeAzureCliCredential = !development,
ExcludeManagedIdentityCredential = development
});
/// <summary>
/// The default Azure credential to use for authentication.
/// </summary>
public static TokenCredential GetAzureCredential() =>
Production
? new ManagedIdentityCredential(Environment.GetEnvironmentVariable(EnvironmentVariables.AzureClientId))
: new AzureCliCredential();
}
}
5 changes: 5 additions & 0 deletions src/dotnet/Common/Constants/EnvironmentVariables.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ namespace FoundationaLLM.Common.Constants
/// </summary>
public static class EnvironmentVariables
{
/// <summary>
/// The client id of the user assigned managed identity.
/// </summary>
public const string AzureClientId = "AZURE_CLIENT_ID";

/// <summary>
/// The Azure Container App or Azure Kubernetes Service hostname.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ namespace FoundationaLLM.Common.Services.Azure
/// <summary>
/// Provides services to interact with the Azure Resource Manager (ARM) infrastructure.
/// </summary>
/// <param name="environment">The <see cref="IHostEnvironment"/> providing details about the environment.</param>
/// <param name="logger">The logger used for logging.</param>
public class AzureResourceManagerService(
IHostEnvironment environment,
ILogger<AzureResourceManagerService> logger) : IAzureResourceManagerService
{
private readonly ArmClient _armClient = new(DefaultAuthentication.GetAzureCredential(
environment.IsDevelopment()));
private readonly ArmClient _armClient = new(DefaultAuthentication.GetAzureCredential());
private readonly ILogger<AzureResourceManagerService> _logger = logger;

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Azure;
using Azure.Identity;
using Azure.Messaging.EventGrid.Namespaces;
using FoundationaLLM.Common.Authentication;
using FoundationaLLM.Common.Constants;
using FoundationaLLM.Common.Exceptions;
using FoundationaLLM.Common.Interfaces;
Expand Down Expand Up @@ -331,7 +332,7 @@ private void ValidateAPIKey(string? value)
try
{
ValidateEndpoint(_settings.Endpoint);
client = new EventGridClient(new Uri(_settings.Endpoint!), new DefaultAzureCredential());
client = new EventGridClient(new Uri(_settings.Endpoint!), DefaultAuthentication.GetAzureCredential());
}
catch (Exception ex)
{
Expand Down
3 changes: 2 additions & 1 deletion src/dotnet/Common/Services/Storage/BlobStorageService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Azure.Storage.Blobs;
using Azure.Storage.Blobs.Models;
using Azure.Storage.Blobs.Specialized;
using FoundationaLLM.Common.Authentication;
using FoundationaLLM.Common.Exceptions;
using FoundationaLLM.Common.Extensions;
using FoundationaLLM.Common.Interfaces;
Expand Down Expand Up @@ -169,6 +170,6 @@ protected override void CreateClientFromConnectionString(string connectionString
protected override void CreateClientFromIdentity(string accountName) =>
_blobServiceClient = new BlobServiceClient(
new Uri($"https://{accountName}.dfs.core.windows.net"),
new DefaultAzureCredential());
DefaultAuthentication.GetAzureCredential());
}
}
3 changes: 2 additions & 1 deletion src/dotnet/Common/Services/Storage/DataLakeStorageService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Azure.Identity;
using Azure.Storage;
using Azure.Storage.Files.DataLake;
using FoundationaLLM.Common.Authentication;
using FoundationaLLM.Common.Exceptions;
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Common.Models.Configuration.Storage;
Expand Down Expand Up @@ -112,6 +113,6 @@ protected override void CreateClientFromConnectionString(string connectionString
protected override void CreateClientFromIdentity(string accountName) =>
_dataLakeClient = new DataLakeServiceClient(
new Uri($"https://{accountName}.dfs.core.windows.net"),
new DefaultAzureCredential());
DefaultAuthentication.GetAzureCredential());
}
}
3 changes: 1 addition & 2 deletions src/dotnet/Configuration/Services/DependencyInjection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ public static void AddConfigurationResourceProvider(this IHostApplicationBuilder
{
var keyVaultUri = builder.Configuration[AppConfigurationKeys.FoundationaLLM_Configuration_KeyVaultURI];
clientBuilder.AddSecretClient(new Uri(keyVaultUri!))
.WithCredential(DefaultAuthentication.GetAzureCredential(
builder.Environment.IsDevelopment()));
.WithCredential(DefaultAuthentication.GetAzureCredential());
clientBuilder.AddConfigurationClient(
builder.Configuration[EnvironmentVariables.FoundationaLLM_AppConfig_ConnectionString]);
});
Expand Down
6 changes: 4 additions & 2 deletions src/dotnet/CoreAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public static void Main(string[] args)
options.Connect(builder.Configuration[EnvironmentVariables.FoundationaLLM_AppConfig_ConnectionString]);
options.ConfigureKeyVault(options =>
{
options.SetCredential(new DefaultAzureCredential());
options.SetCredential(DefaultAuthentication.GetAzureCredential());
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIs);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_CosmosDB);
Expand All @@ -58,7 +58,9 @@ public static void Main(string[] args)
});
if (builder.Environment.IsDevelopment())
builder.Configuration.AddJsonFile("appsettings.development.json", true, true);


DefaultAuthentication.Production = builder.Environment.IsProduction();

var allowAllCorsOrigins = "AllowAllOrigins";
builder.Services.AddCors(policyBuilder =>
{
Expand Down
3 changes: 2 additions & 1 deletion src/dotnet/CoreWorker/Program.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Azure.Identity;
using FoundationaLLM.Common.Authentication;
using FoundationaLLM.Common.Constants;
using FoundationaLLM.Core.Interfaces;
using FoundationaLLM.Core.Models.Configuration;
Expand All @@ -16,7 +17,7 @@

options.ConfigureKeyVault(options =>
{
options.SetCredential(new DefaultAzureCredential());
options.SetCredential(DefaultAuthentication.GetAzureCredential());
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_CoreWorker);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_CosmosDB);
Expand Down
4 changes: 3 additions & 1 deletion src/dotnet/GatekeeperAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public static void Main(string[] args)
options.Connect(builder.Configuration[EnvironmentVariables.FoundationaLLM_AppConfig_ConnectionString]);
options.ConfigureKeyVault(options =>
{
options.SetCredential(new DefaultAzureCredential());
options.SetCredential(DefaultAuthentication.GetAzureCredential());
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIs);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Refinement);
Expand All @@ -55,6 +55,8 @@ public static void Main(string[] args)
if (builder.Environment.IsDevelopment())
builder.Configuration.AddJsonFile("appsettings.development.json", true, true);

DefaultAuthentication.Production = builder.Environment.IsProduction();

// Add services to the container.
// Add the OpenTelemetry telemetry service and send telemetry data to Azure Monitor.
builder.Services.AddOpenTelemetry().UseAzureMonitor(options =>
Expand Down
4 changes: 3 additions & 1 deletion src/dotnet/ManagementAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static async Task Main(string[] args)
builder.Configuration.AddAzureAppConfiguration(options =>
{
options.Connect(builder.Configuration[EnvironmentVariables.FoundationaLLM_AppConfig_ConnectionString]);
options.ConfigureKeyVault(options => { options.SetCredential(new DefaultAzureCredential()); });
options.ConfigureKeyVault(options => { options.SetCredential(DefaultAuthentication.GetAzureCredential()); });
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Instance);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIs);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_CosmosDB);
Expand All @@ -67,6 +67,8 @@ public static async Task Main(string[] args)
if (builder.Environment.IsDevelopment())
builder.Configuration.AddJsonFile("appsettings.development.json", true, true);

DefaultAuthentication.Production = builder.Environment.IsProduction();

// Add the Configuration resource provider
builder.AddConfigurationResourceProvider();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Azure.Identity;
using FoundationaLLM.Common.Authentication;
using FoundationaLLM.Common.Exceptions;
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Common.Models.TextEmbedding;
Expand Down Expand Up @@ -77,7 +78,7 @@ private AzureAISearchMemoryStore CreateMemoryStoreFromAPIKey(string endpoint, st
/// <param name="endpoint">The endpoint of the Azure AI Search deployment.</param>
/// <returns>The <see cref="Kernel"/> instance.</returns>
private AzureAISearchMemoryStore CreateMemoryStoreFromIdentity(string endpoint) =>
new(endpoint, new DefaultAzureCredential());
new(endpoint, DefaultAuthentication.GetAzureCredential());

private void ValidateEndpoint(string? value)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Azure.Identity;
using FoundationaLLM.Common.Authentication;
using FoundationaLLM.Common.Exceptions;
using FoundationaLLM.Common.Interfaces;
using FoundationaLLM.Common.Models.TextEmbedding;
Expand Down Expand Up @@ -76,7 +77,7 @@ private Kernel CreateKernelFromAPIKey(string deploymentName, string endpoint, st
private Kernel CreateKernelFromIdentity(string deploymentName, string endpoint)
{
var builder = Kernel.CreateBuilder();
builder.AddAzureOpenAITextEmbeddingGeneration(deploymentName, endpoint, new DefaultAzureCredential());
builder.AddAzureOpenAITextEmbeddingGeneration(deploymentName, endpoint, DefaultAuthentication.GetAzureCredential());
return builder.Build();
}

Expand Down
4 changes: 3 additions & 1 deletion src/dotnet/SemanticKernelAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static void Main(string[] args)
options.Connect(builder.Configuration[EnvironmentVariables.FoundationaLLM_AppConfig_ConnectionString]);
options.ConfigureKeyVault(options =>
{
options.SetCredential(new DefaultAzureCredential());
options.SetCredential(DefaultAuthentication.GetAzureCredential());
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_APIs);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_DurableSystemPrompt);
Expand All @@ -53,6 +53,8 @@ public static void Main(string[] args)
if (builder.Environment.IsDevelopment())
builder.Configuration.AddJsonFile("appsettings.development.json", true, true);

DefaultAuthentication.Production = builder.Environment.IsProduction();

// Add services to the container.
//builder.Services.AddApplicationInsightsTelemetry();
builder.Services.AddAuthorization();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using FoundationaLLM.Vectorization.Models.Resources;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;

namespace FoundationaLLM.Vectorization.Services.ContentSources
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using System;
using PnP.Core.Model.SharePoint;
using FoundationaLLM.Common.Models.TextEmbedding;
using FoundationaLLM.Common.Authentication;

namespace FoundationaLLM.Vectorization.Services.ContentSources
{
Expand Down Expand Up @@ -90,11 +91,11 @@ private async Task<X509Certificate2> GetCertificate()
{
ValidateSettings();

var certificateClient = new CertificateClient(new Uri(_settings.KeyVaultURL!), new DefaultAzureCredential());
var certificateClient = new CertificateClient(new Uri(_settings.KeyVaultURL!), DefaultAuthentication.GetAzureCredential());
var certificateWithPolicy = await certificateClient.GetCertificateAsync(_settings.CertificateName);
var certificateIdentifier = new KeyVaultSecretIdentifier(certificateWithPolicy.Value.SecretId);

var secretClient = new SecretClient(new Uri(_settings.KeyVaultURL!), new DefaultAzureCredential());
var secretClient = new SecretClient(new Uri(_settings.KeyVaultURL!), DefaultAuthentication.GetAzureCredential());
var secret = await secretClient.GetSecretAsync(certificateIdentifier.Name, certificateIdentifier.Version);
var secretBytes = Convert.FromBase64String(secret.Value.Value);

Expand Down
5 changes: 3 additions & 2 deletions src/dotnet/VectorizationAPI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
options.Connect(builder.Configuration[EnvironmentVariables.FoundationaLLM_AppConfig_ConnectionString]);
options.ConfigureKeyVault(options =>
{
options.SetCredential(DefaultAuthentication.GetAzureCredential(
builder.Environment.IsDevelopment()));
options.SetCredential(DefaultAuthentication.GetAzureCredential());
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Instance);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Vectorization);
Expand All @@ -45,6 +44,8 @@
if (builder.Environment.IsDevelopment())
builder.Configuration.AddJsonFile("appsettings.development.json", true, true);

DefaultAuthentication.Production = builder.Environment.IsProduction();

// Add the Configuration resource provider
builder.AddConfigurationResourceProvider();

Expand Down
4 changes: 3 additions & 1 deletion src/dotnet/VectorizationWorker/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
options.Connect(builder.Configuration[EnvironmentVariables.FoundationaLLM_AppConfig_ConnectionString]);
options.ConfigureKeyVault(options =>
{
options.SetCredential(new DefaultAzureCredential());
options.SetCredential(DefaultAuthentication.GetAzureCredential());
});
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Instance);
options.Select(AppConfigurationKeyFilters.FoundationaLLM_Vectorization);
Expand All @@ -46,6 +46,8 @@
if (builder.Environment.IsDevelopment())
builder.Configuration.AddJsonFile("appsettings.development.json", true, true);

DefaultAuthentication.Production = builder.Environment.IsProduction();

// Add the Configuration resource provider
builder.AddConfigurationResourceProvider();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Azure.Search.Documents.Indexes.Models;
using SemanticKernel.Tests.Models;
using FoundationaLLM.Common.Models.TextEmbedding;
using FoundationaLLM.Common.Authentication;

namespace FoundationaLLM.SemanticKernel.Tests.Services
{
Expand All @@ -23,7 +24,7 @@ public AzureAISearchIndexingServiceTests()
var endpoint = Environment.GetEnvironmentVariable("AzureAISearchIndexingServiceTestsSearchEndpoint") ?? "";
_searchIndexClient = new SearchIndexClient(
new Uri(endpoint),
new DefaultAzureCredential()
DefaultAuthentication.GetAzureCredential()
);
_indexingService = new AzureAISearchIndexingService(
Options.Create(
Expand Down

0 comments on commit 5bb59c0

Please sign in to comment.