Skip to content

Commit

Permalink
keystore: add support for custom kms tags (#48132)
Browse files Browse the repository at this point in the history
* keystore: add support for custom kms tags

* remove utils.IsEmpty and use pointer for aws kms config

* test tags len

Co-authored-by: Nic Klaassen <[email protected]>

* add yaml struct tags

Co-authored-by: Nic Klaassen <[email protected]>

* Update lib/auth/keystore/aws_kms_test.go

Co-authored-by: Nic Klaassen <[email protected]>

* keep TeleportCluster tag if not specified

* add delete used key test cases with tagging

* fix unit tests

* remove unused tag var

* Update lib/auth/keystore/aws_kms.go

Co-authored-by: Nic Klaassen <[email protected]>

---------

Co-authored-by: Nic Klaassen <[email protected]>
  • Loading branch information
dboslee and nklaassen committed Nov 11, 2024
1 parent a893432 commit 0058967
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 95 deletions.
2 changes: 1 addition & 1 deletion lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
if !modules.GetModules().Features().GetEntitlement(entitlements.HSM).Enabled {
return nil, fmt.Errorf("Google Cloud KMS support requires a license with the HSM feature enabled: %w", ErrRequiresEnterprise)
}
} else if cfg.KeyStoreConfig.AWSKMS != (servicecfg.AWSKMSConfig{}) {
} else if cfg.KeyStoreConfig.AWSKMS != nil {
if !modules.GetModules().Features().GetEntitlement(entitlements.HSM).Enabled {
return nil, fmt.Errorf("AWS KMS support requires a license with the HSM feature enabled: %w", ErrRequiresEnterprise)
}
Expand Down
40 changes: 27 additions & 13 deletions lib/auth/keystore/aws_kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ type CloudClientProvider interface {

type awsKMSKeystore struct {
kms kmsiface.KMSAPI
clusterName types.ClusterName
awsAccount string
awsRegion string
multiRegionEnabled bool
tags map[string]string
clock clockwork.Clock
logger *slog.Logger
}
Expand All @@ -89,14 +89,23 @@ func newAWSKMSKeystore(ctx context.Context, cfg *servicecfg.AWSKMSConfig, opts *
if err != nil {
return nil, trace.Wrap(err)
}

tags := cfg.Tags
if tags == nil {
tags = make(map[string]string, 1)
}
if _, ok := tags[clusterTagKey]; !ok {
tags[clusterTagKey] = opts.ClusterName.GetClusterName()
}

clock := opts.clockworkOverride
if clock == nil {
clock = clockwork.NewRealClock()
}
return &awsKMSKeystore{
clusterName: opts.ClusterName,
awsAccount: cfg.AWSAccount,
awsRegion: cfg.AWSRegion,
tags: tags,
multiRegionEnabled: cfg.MultiRegion.Enabled,
kms: kmsClient,
clock: clock,
Expand All @@ -117,16 +126,18 @@ func (a *awsKMSKeystore) keyTypeDescription() string {
// generateRSA creates a new RSA private key and returns its identifier and a crypto.Signer. The returned
// identifier can be passed to getSigner later to get an equivalent crypto.Signer.
func (a *awsKMSKeystore) generateRSA(ctx context.Context, _ ...rsaKeyOption) ([]byte, crypto.Signer, error) {
tags := make([]*kms.Tag, 0, len(a.tags))
for k, v := range a.tags {
tags = append(tags, &kms.Tag{
TagKey: aws.String(k),
TagValue: aws.String(v),
})
}
output, err := a.kms.CreateKey(&kms.CreateKeyInput{
Description: aws.String("Teleport CA key"),
KeySpec: aws.String("RSA_2048"),
KeyUsage: aws.String("SIGN_VERIFY"),
Tags: []*kms.Tag{
{
TagKey: aws.String(clusterTagKey),
TagValue: aws.String(a.clusterName.GetClusterName()),
},
},
Tags: tags,
MultiRegion: aws.Bool(a.multiRegionEnabled),
})
if err != nil {
Expand Down Expand Up @@ -351,11 +362,14 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by
}
return trace.Wrap(err, "failed to fetch tags for AWS KMS key %q", keyARN)
}
if !slices.ContainsFunc(output.Tags, func(tag *kms.Tag) bool {
return aws.StringValue(tag.TagKey) == clusterTagKey && aws.StringValue(tag.TagValue) == a.clusterName.GetClusterName()
}) {
// This key was not created by this Teleport cluster, never delete it.
return nil

// All tags must match for this key to be considered for deletion.
for k, v := range a.tags {
if !slices.ContainsFunc(output.Tags, func(tag *kms.Tag) bool {
return aws.StringValue(tag.TagKey) == k && aws.StringValue(tag.TagValue) == v
}) {
return nil
}
}

// Check if this key is not enabled or was created in the past 5 minutes.
Expand Down
213 changes: 140 additions & 73 deletions lib/auth/keystore/aws_kms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,80 +56,113 @@ func TestAWSKMS_DeleteUnusedKeys(t *testing.T) {
ctx := context.Background()
clock := clockwork.NewFakeClock()

const pageSize int = 4
fakeKMS := newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", pageSize)
cfg := servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSAccount: "123456789012",
AWSRegion: "us-west-2",
for _, tc := range []struct {
name string
tags map[string]string
}{
{
name: "delete keys with default tags",
},
}
clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ClusterName: "test-cluster"})
require.NoError(t, err)
opts := &Options{
ClusterName: clusterName,
HostUUID: "uuid",
CloudClients: &cloud.TestCloudClients{
KMS: fakeKMS,
STS: &fakeAWSSTSClient{
account: "123456789012",
{
name: "delete keys with custom tags",
tags: map[string]string{
"test-key-1": "test-value-1",
},
},
clockworkOverride: clock,
}
keyStore, err := NewManager(ctx, &cfg, opts)
require.NoError(t, err)
{
name: "delete keys with override cluster tag",
tags: map[string]string{
"TeleportCluster": "test-cluster-2",
},
},
} {
t.Run(tc.name, func(t *testing.T) {
const pageSize int = 4
fakeKMS := newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", pageSize)
cfg := servicecfg.KeystoreConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "123456789012",
AWSRegion: "us-west-2",
Tags: tc.tags,
},
}
clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ClusterName: "test-cluster"})
require.NoError(t, err)
opts := &Options{
ClusterName: clusterName,
HostUUID: "uuid",
CloudClients: &cloud.TestCloudClients{
KMS: fakeKMS,
STS: &fakeAWSSTSClient{
account: "123456789012",
},
},
clockworkOverride: clock,
}
keyStore, err := NewManager(ctx, &cfg, opts)
require.NoError(t, err)

totalKeys := pageSize * 3
for i := 0; i < totalKeys; i++ {
_, err := keyStore.NewSSHKeyPair(ctx)
require.NoError(t, err)
}
var otherTags []*kms.Tag
for k, v := range keyStore.backendForNewKeys.(*awsKMSKeystore).tags {
if k != clusterTagKey {
otherTags = append(otherTags, &kms.Tag{
TagKey: aws.String(k),
TagValue: aws.String(v),
})
}
}

// Newly created keys should not be deleted.
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
assert.Equal(t, "Enabled", key.state)
}
totalKeys := pageSize * 3
for i := 0; i < totalKeys; i++ {
_, err := keyStore.NewSSHKeyPair(ctx)
require.NoError(t, err, trace.DebugReport(err))
}

// Keys created more than 5 minutes ago should be deleted.
clock.Advance(6 * time.Minute)
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
assert.Equal(t, "PendingDeletion", key.state)
}
// Newly created keys should not be deleted.
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
assert.Equal(t, kms.KeyStateEnabled, key.state)
}

// Insert a key created by a different Teleport cluster, it should not be
// deleted by the keystore.
output, err := fakeKMS.CreateKey(&kms.CreateKeyInput{
Tags: []*kms.Tag{
&kms.Tag{
TagKey: aws.String(clusterTagKey),
TagValue: aws.String("other-cluster"),
},
},
})
require.NoError(t, err)
otherClusterKeyARN := aws.StringValue(output.KeyMetadata.Arn)
// Keys created more than 5 minutes ago should be deleted.
clock.Advance(6 * time.Minute)
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
assert.Equal(t, kms.KeyStatePendingDeletion, key.state)
}

clock.Advance(6 * time.Minute)
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
if key.arn == otherClusterKeyARN {
assert.Equal(t, "Enabled", key.state)
} else {
assert.Equal(t, "PendingDeletion", key.state)
}
// Insert a key created by a different Teleport cluster, it should not be
// deleted by the keystore.
output, err := fakeKMS.CreateKey(&kms.CreateKeyInput{
KeySpec: aws.String(kms.KeySpecEccNistP256),
Tags: append(otherTags, &kms.Tag{
TagKey: aws.String(clusterTagKey),
TagValue: aws.String("other-cluster"),
}),
})
require.NoError(t, err)
otherClusterKeyARN := aws.StringValue(output.KeyMetadata.Arn)

clock.Advance(6 * time.Minute)
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
if key.arn == otherClusterKeyARN {
assert.Equal(t, kms.KeyStateEnabled, key.state)
} else {
assert.Equal(t, kms.KeyStatePendingDeletion, key.state)
}
}
})
}
}

func TestAWSKMS_WrongAccount(t *testing.T) {
clock := clockwork.NewFakeClock()
cfg := &servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "111111111111",
AWSRegion: "us-west-2",
},
Expand Down Expand Up @@ -161,7 +194,7 @@ func TestAWSKMS_RetryWhilePending(t *testing.T) {
pageLimit: 1000,
}
cfg := &servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "111111111111",
AWSRegion: "us-west-2",
},
Expand Down Expand Up @@ -214,13 +247,13 @@ func TestAWSKMS_RetryWhilePending(t *testing.T) {
require.Error(t, err)
}

// TestMultiRegionKeys asserts that a keystore created with multi-region enabled
// correctly passes this argument to the AWS client. This gives very little real
// coverage since the AWS KMS service here is faked, but at least we know the
// keystore passed the bool to the client correctly. TestBackends and
// TestManager are both able to run with a real AWS KMS client and you can
// confirm the keys are really multi-region there.
func TestMultiRegionKeys(t *testing.T) {
// TestKeyAWSKeyCreationParameters asserts that an AWS keystore created with a
// variety of parameters correctly passes these parameters to the AWS client.
// This gives very little real coverage since the AWS KMS service here is faked,
// but at least we know the keystore passed the parameters to the client correctly.
// TestBackends and TestManager are both able to run with a real AWS KMS client
// and you can confirm the keys are configured correctly there.
func TestAWSKeyCreationParameters(t *testing.T) {
ctx := context.Background()
clock := clockwork.NewFakeClock()

Expand All @@ -240,15 +273,36 @@ func TestMultiRegionKeys(t *testing.T) {
clockworkOverride: clock,
}

for _, multiRegion := range []bool{false, true} {
t.Run(fmt.Sprint(multiRegion), func(t *testing.T) {
for _, tc := range []struct {
name string
multiRegion bool
tags map[string]string
}{
{
name: "multi-region enabled with default tags",
multiRegion: true,
},
{
name: "multi-region disabled with default tags",
multiRegion: false,
},
{
name: "multi region disabled with custom tags",
multiRegion: false,
tags: map[string]string{
"key": "value",
},
},
} {
t.Run(tc.name, func(t *testing.T) {
cfg := servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "123456789012",
AWSRegion: "us-west-2",
MultiRegion: struct{ Enabled bool }{
Enabled: multiRegion,
Enabled: tc.multiRegion,
},
Tags: tc.tags,
},
}
keyStore, err := NewManager(ctx, &cfg, opts)
Expand All @@ -260,11 +314,24 @@ func TestMultiRegionKeys(t *testing.T) {
keyID, err := parseAWSKMSKeyID(sshKeyPair.PrivateKey)
require.NoError(t, err)

if multiRegion {
if tc.multiRegion {
assert.Contains(t, keyID.arn, "mrk-")
} else {
assert.NotContains(t, keyID.arn, "mrk-")
}

tagsOut, err := fakeKMS.ListResourceTags(&kms.ListResourceTagsInput{KeyId: &keyID.arn})
require.NoError(t, err)
if len(tc.tags) == 0 {
tc.tags = map[string]string{
"TeleportCluster": clusterName.GetClusterName(),
}
}
require.Equal(t, len(tc.tags), len(tagsOut.Tags))
for _, tag := range tagsOut.Tags {
v := tc.tags[aws.StringValue(tag.TagKey)]
require.Equal(t, v, aws.StringValue(tag.TagValue))
}
})
}
}
Expand Down
6 changes: 3 additions & 3 deletions lib/auth/keystore/keystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack {
if config, ok := awsKMSTestConfig(t); ok {
config.AWSKMS.MultiRegion.Enabled = multiRegion

backend, err := newAWSKMSKeystore(ctx, &config.AWSKMS, opts)
backend, err := newAWSKMSKeystore(ctx, config.AWSKMS, opts)
require.NoError(t, err)
name := "aws_kms"
if multiRegion {
Expand All @@ -552,15 +552,15 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack {

// Always test with fake AWS client.
fakeAWSKMSConfig := servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "123456789012",
AWSRegion: "us-west-2",
MultiRegion: struct{ Enabled bool }{
Enabled: multiRegion,
},
},
}
fakeAWSKMSBackend, err := newAWSKMSKeystore(ctx, &fakeAWSKMSConfig.AWSKMS, opts)
fakeAWSKMSBackend, err := newAWSKMSKeystore(ctx, fakeAWSKMSConfig.AWSKMS, opts)
require.NoError(t, err)
name := "fake_aws_kms"
if multiRegion {
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/keystore/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ func NewManager(ctx context.Context, cfg *servicecfg.KeystoreConfig, opts *Optio
}
backendForNewKeys = gcpBackend
usableSigningBackends = []backend{gcpBackend, softwareBackend}
case cfg.AWSKMS != (servicecfg.AWSKMSConfig{}):
awsBackend, err := newAWSKMSKeystore(ctx, &cfg.AWSKMS, opts)
case cfg.AWSKMS != nil:
awsBackend, err := newAWSKMSKeystore(ctx, cfg.AWSKMS, opts)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
Loading

0 comments on commit 0058967

Please sign in to comment.