From f70be2c5e8e010c5d662d2073be920e3b547a47a Mon Sep 17 00:00:00 2001 From: Junjie Gao Date: Thu, 28 Mar 2024 13:32:34 +0800 Subject: [PATCH] feat: add credential_type plugin config Signed-off-by: Junjie Gao --- .gitignore | 1 + .../KeyVault/CredentialsTests.cs | 52 ++++++++++++++ .../KeyVault/KeyVaultClientTests.cs | 27 ++++---- .../Command/DescribeKey.cs | 5 +- .../Command/GenerateSignature.cs | 5 +- .../KeyVault/Credentials.cs | 67 +++++++++++++++++++ .../KeyVault/KeyVaultClient.cs | 17 ++--- 7 files changed, 151 insertions(+), 23 deletions(-) create mode 100644 Notation.Plugin.AzureKeyVault.Tests/KeyVault/CredentialsTests.cs create mode 100644 Notation.Plugin.AzureKeyVault/KeyVault/Credentials.cs diff --git a/.gitignore b/.gitignore index c8322492..81267f1d 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ coverage.txt .vscode/ .idea/ .devcontainer +.mono/ # binary output bin/ diff --git a/Notation.Plugin.AzureKeyVault.Tests/KeyVault/CredentialsTests.cs b/Notation.Plugin.AzureKeyVault.Tests/KeyVault/CredentialsTests.cs new file mode 100644 index 00000000..8da746e2 --- /dev/null +++ b/Notation.Plugin.AzureKeyVault.Tests/KeyVault/CredentialsTests.cs @@ -0,0 +1,52 @@ +using Xunit; +using Azure.Core; +using System.Collections.Generic; +using Notation.Plugin.Protocol; + +namespace Notation.Plugin.AzureKeyVault.Credential.Tests +{ + public class CredentialsTests + { + [Theory] + [InlineData("default")] + [InlineData("environment")] + [InlineData("workloadidentity")] + [InlineData("managedidentity")] + [InlineData("azurecli")] + public void GetCredentials_WithValidCredentialType_ReturnsExpectedCredential(string credentialType) + { + // Act + var result = Credentials.GetCredentials(credentialType); + + // Assert + Assert.IsAssignableFrom(result); + } + + [Fact] + public void GetCredentials_WithInvalidCredentialType_ThrowsValidationException() + { + // Arrange + string invalidCredentialType = "invalid"; + + // Act & Assert + var ex = Assert.Throws(() => Credentials.GetCredentials(invalidCredentialType)); + Assert.Equal($"Invalid credential key: {invalidCredentialType}", ex.Message); + } + + [Fact] + public void GetCredentials_WithPluginConfig_ReturnsExpectedCredential() + { + // Arrange + var pluginConfig = new Dictionary + { + { "credential_type", "default" } + }; + + // Act + var result = Credentials.GetCredentials(pluginConfig); + + // Assert + Assert.IsAssignableFrom(result); + } + } +} \ No newline at end of file diff --git a/Notation.Plugin.AzureKeyVault.Tests/KeyVault/KeyVaultClientTests.cs b/Notation.Plugin.AzureKeyVault.Tests/KeyVault/KeyVaultClientTests.cs index da3abd42..3bf2d12d 100644 --- a/Notation.Plugin.AzureKeyVault.Tests/KeyVault/KeyVaultClientTests.cs +++ b/Notation.Plugin.AzureKeyVault.Tests/KeyVault/KeyVaultClientTests.cs @@ -11,6 +11,7 @@ using Azure.Security.KeyVault.Keys.Cryptography; using Azure.Security.KeyVault.Secrets; using Moq; +using Notation.Plugin.AzureKeyVault.Credential; using Notation.Plugin.Protocol; using Xunit; @@ -23,7 +24,7 @@ public void TestConstructorWithKeyId() { string keyId = "https://myvault.vault.azure.net/keys/my-key/123"; - KeyVaultClient keyVaultClient = new KeyVaultClient(keyId); + KeyVaultClient keyVaultClient = new KeyVaultClient(keyId, Credentials.GetCredentials("default")); Assert.Equal("my-key", keyVaultClient.Name); Assert.Equal("123", keyVaultClient.Version); @@ -37,7 +38,7 @@ public void TestConstructorWithKeyVaultUrlNameVersion() string name = "my-key"; string version = "123"; - KeyVaultClient keyVaultClient = new KeyVaultClient(keyVaultUrl, name, version); + KeyVaultClient keyVaultClient = new KeyVaultClient(keyVaultUrl, name, version, Credentials.GetCredentials("default")); Assert.Equal(name, keyVaultClient.Name); Assert.Equal(version, keyVaultClient.Version); @@ -51,32 +52,32 @@ public void TestConstructorWithKeyVaultUrlNameVersion() [InlineData("http://myvault.vault.azure.net/keys/my-key/123")] public void TestConstructorWithInvalidKeyId(string invalidKeyId) { - Assert.Throws(() => new KeyVaultClient(invalidKeyId)); + Assert.Throws(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials("default"))); } [Theory] [InlineData("")] public void TestConstructorWithEmptyKeyId(string invalidKeyId) { - Assert.Throws(() => new KeyVaultClient(invalidKeyId)); + Assert.Throws(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials("default"))); } private class TestableKeyVaultClient : KeyVaultClient { - public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CryptographyClient cryptoClient) - : base(keyVaultUrl, name, version) + public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CryptographyClient cryptoClient, TokenCredential credenital) + : base(keyVaultUrl, name, version, credenital) { this._cryptoClient = new Lazy(() => cryptoClient); } - public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CertificateClient certificateClient) - : base(keyVaultUrl, name, version) + public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CertificateClient certificateClient, TokenCredential credenital) + : base(keyVaultUrl, name, version, credenital) { this._certificateClient = new Lazy(() => certificateClient); } - public TestableKeyVaultClient(string keyVaultUrl, string name, string version, SecretClient secretClient) - : base(keyVaultUrl, name, version) + public TestableKeyVaultClient(string keyVaultUrl, string name, string version, SecretClient secretClient, TokenCredential credenital) + : base(keyVaultUrl, name, version, credenital) { this._secretClient = new Lazy(() => secretClient); } @@ -88,7 +89,7 @@ private TestableKeyVaultClient CreateMockedKeyVaultClient(SignResult signResult) mockCryptoClient.Setup(c => c.SignDataAsync(It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(signResult); - return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-key", "123", mockCryptoClient.Object); + return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-key", "123", mockCryptoClient.Object, Credentials.GetCredentials("default")); } private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultCertificate certificate) @@ -97,7 +98,7 @@ private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultCertificate ce mockCertificateClient.Setup(c => c.GetCertificateVersionAsync(It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(Response.FromValue(certificate, new Mock().Object)); - return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockCertificateClient.Object); + return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockCertificateClient.Object, Credentials.GetCredentials("default")); } private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultSecret secret) @@ -105,7 +106,7 @@ private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultSecret secret) var mockSecretClient = new Mock(new Uri("https://fake.vault.azure.net/secrets/fake-secret/123"), new Mock().Object); mockSecretClient.Setup(c => c.GetSecretAsync(It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(Response.FromValue(secret, new Mock().Object)); - return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockSecretClient.Object); + return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockSecretClient.Object, Credentials.GetCredentials("default")); } [Fact] diff --git a/Notation.Plugin.AzureKeyVault/Command/DescribeKey.cs b/Notation.Plugin.AzureKeyVault/Command/DescribeKey.cs index 6e4f714d..4a4b1f49 100644 --- a/Notation.Plugin.AzureKeyVault/Command/DescribeKey.cs +++ b/Notation.Plugin.AzureKeyVault/Command/DescribeKey.cs @@ -1,5 +1,6 @@ using System.Text.Json; using Notation.Plugin.AzureKeyVault.Client; +using Notation.Plugin.AzureKeyVault.Credential; using Notation.Plugin.Protocol; namespace Notation.Plugin.AzureKeyVault.Command @@ -25,7 +26,9 @@ public DescribeKey(string inputJson) throw new ValidationException(invalidInputError); } this._request = request; - this._keyVaultClient = new KeyVaultClient(request.KeyId); + this._keyVaultClient = new KeyVaultClient( + id: request.KeyId, + credential: Credentials.GetCredentials(request.PluginConfig)); } /// diff --git a/Notation.Plugin.AzureKeyVault/Command/GenerateSignature.cs b/Notation.Plugin.AzureKeyVault/Command/GenerateSignature.cs index da1320c9..67f06ea4 100644 --- a/Notation.Plugin.AzureKeyVault/Command/GenerateSignature.cs +++ b/Notation.Plugin.AzureKeyVault/Command/GenerateSignature.cs @@ -1,5 +1,6 @@ using System.Security.Cryptography.X509Certificates; using System.Text.Json; +using Notation.Plugin.AzureKeyVault.Credential; using Notation.Plugin.AzureKeyVault.Certificate; using Notation.Plugin.AzureKeyVault.Client; using Notation.Plugin.Protocol; @@ -26,7 +27,9 @@ public GenerateSignature(string inputJson) throw new ValidationException("Invalid input"); } this._request = request; - this._keyVaultClient = new KeyVaultClient(request.KeyId); + this._keyVaultClient = new KeyVaultClient( + id: request.KeyId, + credential: Credentials.GetCredentials(request.PluginConfig)); } /// diff --git a/Notation.Plugin.AzureKeyVault/KeyVault/Credentials.cs b/Notation.Plugin.AzureKeyVault/KeyVault/Credentials.cs new file mode 100644 index 00000000..b564aacc --- /dev/null +++ b/Notation.Plugin.AzureKeyVault/KeyVault/Credentials.cs @@ -0,0 +1,67 @@ +using Azure.Core; +using Azure.Identity; +using Notation.Plugin.Protocol; + +namespace Notation.Plugin.AzureKeyVault.Credential +{ + public class Credentials + { + /// + /// Credential type key name in plugin config. + /// + public const string CredentialTypeKey = "credential_type"; + /// + /// Default credential name. + /// + public const string DefaultCredentialName = "default"; + /// + /// Environment credential name. + /// + public const string EnvironmentCredentialName = "environment"; + /// + /// Workload identity credential name. + /// + public const string WorkloadIdentityCredentialName = "workloadidentity"; + /// + /// Managed identity credential name. + /// + public const string ManagedIdentityCredentialName = "managedidentity"; + /// + /// Azure CLI credential name. + /// + public const string AzureCliCredentialName = "azurecli"; + + /// + /// Get the credential based on the credential type. + /// + public static TokenCredential GetCredentials(string credentialType) + { + credentialType = credentialType.ToLower(); + switch (credentialType) + { + case DefaultCredentialName: + return new DefaultAzureCredential(); + case EnvironmentCredentialName: + return new EnvironmentCredential(); + case WorkloadIdentityCredentialName: + return new WorkloadIdentityCredential(); + case ManagedIdentityCredentialName: + return new ManagedIdentityCredential(); + case AzureCliCredentialName: + return new AzureCliCredential(); + default: + throw new ValidationException($"Invalid credential key: {credentialType}"); + } + } + + /// + /// Get the credential based on the plugin config. + /// + public static TokenCredential GetCredentials(Dictionary? pluginConfig) + { + var credentialName = pluginConfig?.GetValueOrDefault(CredentialTypeKey, DefaultCredentialName) ?? + DefaultCredentialName; + return GetCredentials(credentialName); + } + } +} \ No newline at end of file diff --git a/Notation.Plugin.AzureKeyVault/KeyVault/KeyVaultClient.cs b/Notation.Plugin.AzureKeyVault/KeyVault/KeyVaultClient.cs index c40344a8..6fade0cf 100644 --- a/Notation.Plugin.AzureKeyVault/KeyVault/KeyVaultClient.cs +++ b/Notation.Plugin.AzureKeyVault/KeyVault/KeyVaultClient.cs @@ -1,7 +1,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; -using Azure.Identity; +using Azure.Core; using Azure.Security.KeyVault.Certificates; using Azure.Security.KeyVault.Keys.Cryptography; using Azure.Security.KeyVault.Secrets; @@ -62,7 +62,7 @@ private record KeyVaultMetadata(string KeyVaultUrl, string Name, string Version) /// Constructor to create AzureKeyVault object from keyVaultUrl, name /// and version. /// - public KeyVaultClient(string keyVaultUrl, string name, string version) + public KeyVaultClient(string keyVaultUrl, string name, string version, TokenCredential credential) { if (string.IsNullOrEmpty(keyVaultUrl)) { @@ -84,7 +84,6 @@ public KeyVaultClient(string keyVaultUrl, string name, string version) this._keyId = $"{keyVaultUrl}/keys/{name}/{version}"; // initialize credential and lazy clients - var credential = new DefaultAzureCredential(); this._certificateClient = new Lazy(() => new CertificateClient(new Uri(keyVaultUrl), credential)); this._cryptoClient = new Lazy(() => new CryptographyClient(new Uri(_keyId), credential)); this._secretClient = new Lazy(() => new SecretClient(new Uri(keyVaultUrl), credential)); @@ -93,18 +92,20 @@ public KeyVaultClient(string keyVaultUrl, string name, string version) /// /// Constructor to create AzureKeyVault object from key identifier or /// certificate identifier. - /// + /// /// /// Key identifier or certificate identifier. (e.g. https://.vault.azure.net/keys//) /// - /// - public KeyVaultClient(string id) : this(ParseId(id)) { } + /// + /// TokenCredential object to authenticate with Azure Key Vault. + /// + public KeyVaultClient(string id, TokenCredential credential) : this(ParseId(id), credential) { } /// /// A helper constructor to create KeyVaultClient from KeyVaultMetadata. /// - private KeyVaultClient(KeyVaultMetadata metadata) - : this(metadata.KeyVaultUrl, metadata.Name, metadata.Version) { } + private KeyVaultClient(KeyVaultMetadata metadata, TokenCredential credential) + : this(metadata.KeyVaultUrl, metadata.Name, metadata.Version, credential) { } /// /// A helper function to parse key identifier or certificate identifier