Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add Audience for Certificate auth to work with Skills #6794

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public CertificateAppCredentials(CertificateAppCredentialsOptions options)
/// <param name="customHttpClient">Optional <see cref="HttpClient"/> to be used when acquiring tokens.</param>
/// <param name="logger">Optional <see cref="ILogger"/> to gather telemetry data while acquiring and managing credentials.</param>
public CertificateAppCredentials(X509Certificate2 clientCertificate, string appId, string channelAuthTenant = null, HttpClient customHttpClient = null, ILogger logger = null)
: this(clientCertificate, false, appId, channelAuthTenant, customHttpClient, logger)
: this(clientCertificate, appId, channelAuthTenant, string.Empty, false, customHttpClient, logger)
{
}

Expand All @@ -62,7 +62,22 @@ public CertificateAppCredentials(X509Certificate2 clientCertificate, string appI
/// <param name="customHttpClient">Optional <see cref="HttpClient"/> to be used when acquiring tokens.</param>
/// <param name="logger">Optional <see cref="ILogger"/> to gather telemetry data while acquiring and managing credentials.</param>
public CertificateAppCredentials(X509Certificate2 clientCertificate, bool sendX5c, string appId, string channelAuthTenant = null, HttpClient customHttpClient = null, ILogger logger = null)
: base(channelAuthTenant, customHttpClient, logger)
: this(clientCertificate, appId, channelAuthTenant, string.Empty, sendX5c, customHttpClient, logger)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="CertificateAppCredentials"/> class.
/// </summary>
/// <param name="clientCertificate">Client certificate to be presented for authentication.</param>
/// <param name="appId">Microsoft application Id related to the certifiacte.</param>
/// <param name="channelAuthTenant">Optional. The oauth token tenant.</param>
/// <param name="oAuthScope">Optional. The scope for the token.</param>
/// <param name="sendX5c">Optional. This parameter, if true, enables application developers to achieve easy certificates roll-over in Azure AD: setting this parameter to true will send the public certificate to Azure AD along with the token request, so that Azure AD can use it to validate the subject name based on a trusted issuer policy. </param>
/// <param name="customHttpClient">Optional <see cref="HttpClient"/> to be used when acquiring tokens.</param>
/// <param name="logger">Optional <see cref="ILogger"/> to gather telemetry data while acquiring and managing credentials.</param>
public CertificateAppCredentials(X509Certificate2 clientCertificate, string appId, string channelAuthTenant = null, string oAuthScope = null, bool sendX5c = false, HttpClient customHttpClient = null, ILogger logger = null)
: base(channelAuthTenant, customHttpClient, logger, oAuthScope: oAuthScope)
{
if (clientCertificate == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Net.Http;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
Expand All @@ -16,8 +17,13 @@ namespace Microsoft.Bot.Connector.Authentication
/// </summary>
public class CertificateServiceClientCredentialsFactory : ServiceClientCredentialsFactory
{
private readonly CertificateAppCredentials _certificateAppCredentials;
private readonly X509Certificate2 _certificate;
private readonly string _appId;
private readonly string _tenantId;
private readonly bool _sendX5c;
private readonly HttpClient _httpClient;
private readonly ILogger _logger;
private readonly ConcurrentDictionary<string, CertificateAppCredentials> _certificateAppCredentialsByAudience = new ConcurrentDictionary<string, CertificateAppCredentials>();

/// <summary>
/// Initializes a new instance of the <see cref="CertificateServiceClientCredentialsFactory"/> class.
Expand All @@ -44,16 +50,12 @@ public CertificateServiceClientCredentialsFactory(
throw new ArgumentNullException(nameof(appId));
}

_certificate = certificate ?? throw new ArgumentNullException(nameof(certificate));
_appId = appId;

// Instance must be reused otherwise it will cause throttling on AAD.
_certificateAppCredentials = new CertificateAppCredentials(
certificate ?? throw new ArgumentNullException(nameof(certificate)),
sendX5c,
appId,
tenantId,
httpClient,
logger);
_tenantId = tenantId;
_sendX5c = sendX5c;
_httpClient = httpClient;
_logger = logger;
}

/// <inheritdoc />
Expand All @@ -78,7 +80,20 @@ public override Task<ServiceClientCredentials> CreateCredentialsAsync(
throw new InvalidOperationException("Invalid Managed ID.");
}

return Task.FromResult<ServiceClientCredentials>(_certificateAppCredentials);
// Instance must be reused per audience, otherwise it will cause throttling on AAD.
var certificateAppCredentials = _certificateAppCredentialsByAudience.GetOrAdd(audience, (audience) =>
{
return new CertificateAppCredentials(
_certificate,
_appId,
_tenantId,
audience,
_sendX5c,
_httpClient,
_logger);
});

return Task.FromResult<ServiceClientCredentials>(certificateAppCredentials);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public class CertificateServiceClientCredentialsFactoryTests
private const string TestAppId = nameof(TestAppId);
private const string TestTenantId = nameof(TestTenantId);
private const string TestAudience = nameof(TestAudience);
private const string LoginEndpoint = "https://login.microsoftonline.com";
private readonly Mock<ILogger> logger = new Mock<ILogger>();
private readonly Mock<X509Certificate2> certificate = new Mock<X509Certificate2>();

Expand Down Expand Up @@ -68,19 +69,38 @@ public async void CanCreateCredentials()
var factory = new CertificateServiceClientCredentialsFactory(certificate.Object, TestAppId);

var credentials = await factory.CreateCredentialsAsync(
TestAppId, TestAudience, "https://login.microsoftonline.com", true, CancellationToken.None);
TestAppId, TestAudience, LoginEndpoint, true, CancellationToken.None);

Assert.NotNull(credentials);
Assert.IsType<CertificateAppCredentials>(credentials);
}

[Fact]
public async void ShouldCreateUniqueCredentialsByAudience()
{
var factory = new CertificateServiceClientCredentialsFactory(certificate.Object, TestAppId);

var credentials1 = await factory.CreateCredentialsAsync(
TestAppId, string.Empty, LoginEndpoint, true, CancellationToken.None);
var credentials2 = await factory.CreateCredentialsAsync(
TestAppId, TestAudience, LoginEndpoint, true, CancellationToken.None);
var credentials3 = await factory.CreateCredentialsAsync(
TestAppId, Guid.NewGuid().ToString(), LoginEndpoint, true, CancellationToken.None);
var credentials4 = await factory.CreateCredentialsAsync(
TestAppId, string.Empty, LoginEndpoint, true, CancellationToken.None);

Assert.NotEqual(credentials1, credentials2);
Assert.NotEqual(credentials1, credentials3);
Assert.Equal(credentials1, credentials4);
}

[Fact]
public void CannotCreateCredentialsWithInvalidAppId()
{
var factory = new CertificateServiceClientCredentialsFactory(certificate.Object, TestAppId);

Assert.ThrowsAsync<InvalidOperationException>(() => factory.CreateCredentialsAsync(
"InvalidAppId", TestAudience, "https://login.microsoftonline.com", true, CancellationToken.None));
"InvalidAppId", TestAudience, LoginEndpoint, true, CancellationToken.None));
}
}
}
Loading