Skip to content

Commit

Permalink
feat: add credential_type plugin config
Browse files Browse the repository at this point in the history
Signed-off-by: Junjie Gao <[email protected]>
  • Loading branch information
JeyJeyGao committed Mar 28, 2024
1 parent 62f1a4c commit f70be2c
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ coverage.txt
.vscode/
.idea/
.devcontainer
.mono/

# binary output
bin/
Expand Down
52 changes: 52 additions & 0 deletions Notation.Plugin.AzureKeyVault.Tests/KeyVault/CredentialsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using Xunit;
using Azure.Core;
using System.Collections.Generic;
using Notation.Plugin.Protocol;

namespace Notation.Plugin.AzureKeyVault.Credential.Tests
{
public class CredentialsTests
{
[Theory]
[InlineData("default")]
[InlineData("environment")]
[InlineData("workloadidentity")]
[InlineData("managedidentity")]
[InlineData("azurecli")]
public void GetCredentials_WithValidCredentialType_ReturnsExpectedCredential(string credentialType)
{
// Act
var result = Credentials.GetCredentials(credentialType);

// Assert
Assert.IsAssignableFrom<TokenCredential>(result);
}

[Fact]
public void GetCredentials_WithInvalidCredentialType_ThrowsValidationException()
{
// Arrange
string invalidCredentialType = "invalid";

// Act & Assert
var ex = Assert.Throws<ValidationException>(() => Credentials.GetCredentials(invalidCredentialType));
Assert.Equal($"Invalid credential key: {invalidCredentialType}", ex.Message);
}

[Fact]
public void GetCredentials_WithPluginConfig_ReturnsExpectedCredential()
{
// Arrange
var pluginConfig = new Dictionary<string, string>
{
{ "credential_type", "default" }
};

// Act
var result = Credentials.GetCredentials(pluginConfig);

// Assert
Assert.IsAssignableFrom<TokenCredential>(result);
}
}
}
27 changes: 14 additions & 13 deletions Notation.Plugin.AzureKeyVault.Tests/KeyVault/KeyVaultClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Azure.Security.KeyVault.Keys.Cryptography;
using Azure.Security.KeyVault.Secrets;
using Moq;
using Notation.Plugin.AzureKeyVault.Credential;
using Notation.Plugin.Protocol;
using Xunit;

Expand All @@ -23,7 +24,7 @@ public void TestConstructorWithKeyId()
{
string keyId = "https://myvault.vault.azure.net/keys/my-key/123";

KeyVaultClient keyVaultClient = new KeyVaultClient(keyId);
KeyVaultClient keyVaultClient = new KeyVaultClient(keyId, Credentials.GetCredentials("default"));

Assert.Equal("my-key", keyVaultClient.Name);
Assert.Equal("123", keyVaultClient.Version);
Expand All @@ -37,7 +38,7 @@ public void TestConstructorWithKeyVaultUrlNameVersion()
string name = "my-key";
string version = "123";

KeyVaultClient keyVaultClient = new KeyVaultClient(keyVaultUrl, name, version);
KeyVaultClient keyVaultClient = new KeyVaultClient(keyVaultUrl, name, version, Credentials.GetCredentials("default"));

Assert.Equal(name, keyVaultClient.Name);
Assert.Equal(version, keyVaultClient.Version);
Expand All @@ -51,32 +52,32 @@ public void TestConstructorWithKeyVaultUrlNameVersion()
[InlineData("http://myvault.vault.azure.net/keys/my-key/123")]
public void TestConstructorWithInvalidKeyId(string invalidKeyId)
{
Assert.Throws<ValidationException>(() => new KeyVaultClient(invalidKeyId));
Assert.Throws<ValidationException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials("default")));
}

[Theory]
[InlineData("")]
public void TestConstructorWithEmptyKeyId(string invalidKeyId)
{
Assert.Throws<ArgumentNullException>(() => new KeyVaultClient(invalidKeyId));
Assert.Throws<ArgumentNullException>(() => new KeyVaultClient(invalidKeyId, Credentials.GetCredentials("default")));
}

private class TestableKeyVaultClient : KeyVaultClient
{
public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CryptographyClient cryptoClient)
: base(keyVaultUrl, name, version)
public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CryptographyClient cryptoClient, TokenCredential credenital)
: base(keyVaultUrl, name, version, credenital)
{
this._cryptoClient = new Lazy<CryptographyClient>(() => cryptoClient);
}

public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CertificateClient certificateClient)
: base(keyVaultUrl, name, version)
public TestableKeyVaultClient(string keyVaultUrl, string name, string version, CertificateClient certificateClient, TokenCredential credenital)
: base(keyVaultUrl, name, version, credenital)
{
this._certificateClient = new Lazy<CertificateClient>(() => certificateClient);
}

public TestableKeyVaultClient(string keyVaultUrl, string name, string version, SecretClient secretClient)
: base(keyVaultUrl, name, version)
public TestableKeyVaultClient(string keyVaultUrl, string name, string version, SecretClient secretClient, TokenCredential credenital)
: base(keyVaultUrl, name, version, credenital)
{
this._secretClient = new Lazy<SecretClient>(() => secretClient);
}
Expand All @@ -88,7 +89,7 @@ private TestableKeyVaultClient CreateMockedKeyVaultClient(SignResult signResult)
mockCryptoClient.Setup(c => c.SignDataAsync(It.IsAny<SignatureAlgorithm>(), It.IsAny<byte[]>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(signResult);

return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-key", "123", mockCryptoClient.Object);
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-key", "123", mockCryptoClient.Object, Credentials.GetCredentials("default"));
}

private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultCertificate certificate)
Expand All @@ -97,15 +98,15 @@ private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultCertificate ce
mockCertificateClient.Setup(c => c.GetCertificateVersionAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(Response.FromValue(certificate, new Mock<Response>().Object));

return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockCertificateClient.Object);
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockCertificateClient.Object, Credentials.GetCredentials("default"));
}

private TestableKeyVaultClient CreateMockedKeyVaultClient(KeyVaultSecret secret)
{
var mockSecretClient = new Mock<SecretClient>(new Uri("https://fake.vault.azure.net/secrets/fake-secret/123"), new Mock<TokenCredential>().Object);
mockSecretClient.Setup(c => c.GetSecretAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(Response.FromValue(secret, new Mock<Response>().Object));
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockSecretClient.Object);
return new TestableKeyVaultClient("https://fake.vault.azure.net", "fake-certificate", "123", mockSecretClient.Object, Credentials.GetCredentials("default"));
}

[Fact]
Expand Down
5 changes: 4 additions & 1 deletion Notation.Plugin.AzureKeyVault/Command/DescribeKey.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Text.Json;
using Notation.Plugin.AzureKeyVault.Client;
using Notation.Plugin.AzureKeyVault.Credential;
using Notation.Plugin.Protocol;

namespace Notation.Plugin.AzureKeyVault.Command
Expand All @@ -25,7 +26,9 @@ public DescribeKey(string inputJson)
throw new ValidationException(invalidInputError);
}
this._request = request;
this._keyVaultClient = new KeyVaultClient(request.KeyId);
this._keyVaultClient = new KeyVaultClient(
id: request.KeyId,
credential: Credentials.GetCredentials(request.PluginConfig));
}

/// <summary>
Expand Down
5 changes: 4 additions & 1 deletion Notation.Plugin.AzureKeyVault/Command/GenerateSignature.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Security.Cryptography.X509Certificates;
using System.Text.Json;
using Notation.Plugin.AzureKeyVault.Credential;
using Notation.Plugin.AzureKeyVault.Certificate;
using Notation.Plugin.AzureKeyVault.Client;
using Notation.Plugin.Protocol;
Expand All @@ -26,7 +27,9 @@ public GenerateSignature(string inputJson)
throw new ValidationException("Invalid input");
}
this._request = request;
this._keyVaultClient = new KeyVaultClient(request.KeyId);
this._keyVaultClient = new KeyVaultClient(
id: request.KeyId,
credential: Credentials.GetCredentials(request.PluginConfig));
}

/// <summary>
Expand Down
67 changes: 67 additions & 0 deletions Notation.Plugin.AzureKeyVault/KeyVault/Credentials.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using Azure.Core;
using Azure.Identity;
using Notation.Plugin.Protocol;

namespace Notation.Plugin.AzureKeyVault.Credential
{
public class Credentials
{
/// <summary>
/// Credential type key name in plugin config.
/// </summary>
public const string CredentialTypeKey = "credential_type";
/// <summary>
/// Default credential name.
/// </summary>
public const string DefaultCredentialName = "default";
/// <summary>
/// Environment credential name.
/// </summary>
public const string EnvironmentCredentialName = "environment";
/// <summary>
/// Workload identity credential name.
/// </summary>
public const string WorkloadIdentityCredentialName = "workloadidentity";
/// <summary>
/// Managed identity credential name.
/// </summary>
public const string ManagedIdentityCredentialName = "managedidentity";
/// <summary>
/// Azure CLI credential name.
/// </summary>
public const string AzureCliCredentialName = "azurecli";

/// <summary>
/// Get the credential based on the credential type.
/// </summary>
public static TokenCredential GetCredentials(string credentialType)
{
credentialType = credentialType.ToLower();
switch (credentialType)
{
case DefaultCredentialName:
return new DefaultAzureCredential();
case EnvironmentCredentialName:
return new EnvironmentCredential();
case WorkloadIdentityCredentialName:
return new WorkloadIdentityCredential();
case ManagedIdentityCredentialName:
return new ManagedIdentityCredential();
case AzureCliCredentialName:
return new AzureCliCredential();
default:
throw new ValidationException($"Invalid credential key: {credentialType}");
}
}

/// <summary>
/// Get the credential based on the plugin config.
/// </summary>
public static TokenCredential GetCredentials(Dictionary<string, string>? pluginConfig)
{
var credentialName = pluginConfig?.GetValueOrDefault(CredentialTypeKey, DefaultCredentialName) ??
DefaultCredentialName;
return GetCredentials(credentialName);
}
}
}
17 changes: 9 additions & 8 deletions Notation.Plugin.AzureKeyVault/KeyVault/KeyVaultClient.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using Azure.Identity;
using Azure.Core;
using Azure.Security.KeyVault.Certificates;
using Azure.Security.KeyVault.Keys.Cryptography;
using Azure.Security.KeyVault.Secrets;
Expand Down Expand Up @@ -62,7 +62,7 @@ private record KeyVaultMetadata(string KeyVaultUrl, string Name, string Version)
/// Constructor to create AzureKeyVault object from keyVaultUrl, name
/// and version.
/// </summary>
public KeyVaultClient(string keyVaultUrl, string name, string version)
public KeyVaultClient(string keyVaultUrl, string name, string version, TokenCredential credential)
{
if (string.IsNullOrEmpty(keyVaultUrl))
{
Expand All @@ -84,7 +84,6 @@ public KeyVaultClient(string keyVaultUrl, string name, string version)
this._keyId = $"{keyVaultUrl}/keys/{name}/{version}";

// initialize credential and lazy clients
var credential = new DefaultAzureCredential();
this._certificateClient = new Lazy<CertificateClient>(() => new CertificateClient(new Uri(keyVaultUrl), credential));
this._cryptoClient = new Lazy<CryptographyClient>(() => new CryptographyClient(new Uri(_keyId), credential));
this._secretClient = new Lazy<SecretClient>(() => new SecretClient(new Uri(keyVaultUrl), credential));
Expand All @@ -93,18 +92,20 @@ public KeyVaultClient(string keyVaultUrl, string name, string version)
/// <summary>
/// Constructor to create AzureKeyVault object from key identifier or
/// certificate identifier.
///
/// </summary>
/// <param name="id">
/// Key identifier or certificate identifier. (e.g. https://<vaultname>.vault.azure.net/keys/<name>/<version>)
/// </param>
/// </summary>
public KeyVaultClient(string id) : this(ParseId(id)) { }
/// <param name="credential">
/// TokenCredential object to authenticate with Azure Key Vault.
/// </param>
public KeyVaultClient(string id, TokenCredential credential) : this(ParseId(id), credential) { }

/// <summary>
/// A helper constructor to create KeyVaultClient from KeyVaultMetadata.
/// </summary>
private KeyVaultClient(KeyVaultMetadata metadata)
: this(metadata.KeyVaultUrl, metadata.Name, metadata.Version) { }
private KeyVaultClient(KeyVaultMetadata metadata, TokenCredential credential)
: this(metadata.KeyVaultUrl, metadata.Name, metadata.Version, credential) { }

/// <summary>
/// A helper function to parse key identifier or certificate identifier
Expand Down

0 comments on commit f70be2c

Please sign in to comment.