Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ds 5114 update the aws secrets provider to aws sdk v2 #76

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions aws/aws_kms/aws_kms.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package aws_kms

import (
"context"
"encoding/base64"
"fmt"
"github.com/libopenstorage/secrets/aws/utils"
"os"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/libopenstorage/secrets/aws/utils"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/libopenstorage/secrets"
sc "github.com/libopenstorage/secrets/aws/credentials"
"github.com/libopenstorage/secrets/pkg/store"
Expand All @@ -28,9 +29,8 @@ const (
)

type awsKmsSecrets struct {
client *kms.KMS
creds *credentials.Credentials
sess *session.Session
client *kms.Client
creds *aws.Credentials
cmk string
asc sc.AWSCredentials
ps store.PersistenceStore
Expand Down Expand Up @@ -84,15 +84,16 @@ func New(
if err != nil {
return nil, fmt.Errorf("Failed to get credentials: %v", err)
}
config := &aws.Config{
Credentials: creds,
Region: &region,
credProv, err := asc.GetCredentialsProvider()
config := aws.Config{
Credentials: credProv,
Region: region,
}
sess := session.New(config)
kmsClient := kms.New(sess)

kmsClient := kms.NewFromConfig(config)

return &awsKmsSecrets{
client: kmsClient,
sess: sess,
creds: creds,
cmk: cmk,
asc: asc,
Expand Down Expand Up @@ -139,10 +140,10 @@ func (a *awsKmsSecrets) GetSecret(
decodedCipherBlob = cipherBlob
}
input := &kms.DecryptInput{
EncryptionContext: getAWSKeyContext(keyContext),
EncryptionContext: keyContext,
CiphertextBlob: decodedCipherBlob,
}
output, err := a.client.Decrypt(input)
output, err := a.client.Decrypt(context.TODO(), input)
if err != nil {
return nil, secrets.NoVersion, err
}
Expand Down Expand Up @@ -203,11 +204,11 @@ func (a *awsKmsSecrets) PutSecret(
keySpec := "AES_256"
input := &kms.GenerateDataKeyInput{
KeyId: &a.cmk,
EncryptionContext: getAWSKeyContext(keyContext),
KeySpec: &keySpec,
EncryptionContext: keyContext,
KeySpec: types.DataKeySpec(keySpec),
}

output, err := a.client.GenerateDataKey(input)
output, err := a.client.GenerateDataKey(context.TODO(), input)
if err != nil {
return secrets.NoVersion, err
}
Expand Down
4 changes: 4 additions & 0 deletions aws/aws_kms/aws_kms_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package aws_kms
import (
"os"
"testing"
"time"

"github.com/libopenstorage/secrets"
"github.com/libopenstorage/secrets/aws/utils"
Expand Down Expand Up @@ -164,6 +165,9 @@ func (a *awsSecretTest) TestDeleteSecret(t *testing.T) error {
err := a.s.DeleteSecret(secretIdWithData, nil)
assert.NoError(t, err, "Expected DeleteSecret to succeed")

// Add a delay to allow time for deletion to propagate
time.Sleep(time.Second * 90)

// Get of a deleted key should fail
_, _, err = a.s.GetSecret(secretIdWithData, nil)
assert.EqualError(t, secrets.ErrInvalidSecretId, err.Error(), "Unexpected error on GetSecret after delete")
Expand Down
76 changes: 34 additions & 42 deletions aws/aws_secrets_manager/aws_scm.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ import (
"strconv"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager/types"
"github.com/libopenstorage/secrets"
sc "github.com/libopenstorage/secrets/aws/credentials"
"github.com/libopenstorage/secrets/aws/utils"
Expand All @@ -26,7 +25,7 @@ const (

// AWSSecretsMgr is backend for secrets.SecretStore.
type AWSSecretsMgr struct {
scm *secretsmanager.SecretsManager
scm *secretsmanager.Client
}

// New creates new instance of AWSSecretsMgr with provided configuration.
Expand All @@ -39,7 +38,7 @@ func New(

awsConfig, ok := secretConfig[utils.AwsConfigKey]
if ok {
awsConfig, ok := awsConfig.(*aws.Config)
awsConfig, ok := awsConfig.(aws.Config)
if !ok {
return nil, utils.ErrAWSConfigWrongType
}
Expand All @@ -64,25 +63,22 @@ func New(
if err != nil {
return nil, fmt.Errorf("failed to create aws credentials instance: %v", err)
}
creds, err := asc.Get()
_, err = asc.Get()
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %v", err)
}
config := &aws.Config{
Credentials: creds,
Region: &region,
credProv, err := asc.GetCredentialsProvider()
config := aws.Config{
Credentials: credProv,
Region: region,
}

return NewFromAWSConfig(config)
}

// NewFromAWSConfig creates new instance of AWSSecretsMgr with provided AWS configuration (aws.Config).
func NewFromAWSConfig(config *aws.Config) (*AWSSecretsMgr, error) {
sess, err := session.NewSession(config)
if err != nil {
return nil, fmt.Errorf("failed to create a session: %v", err)
}
scm := secretsmanager.New(sess)
func NewFromAWSConfig(config aws.Config) (*AWSSecretsMgr, error) {
scm := secretsmanager.NewFromConfig(config)
return &AWSSecretsMgr{
scm: scm,
}, nil
Expand Down Expand Up @@ -175,17 +171,15 @@ func (a *AWSSecretsMgr) Rencrypt(
}

func (a *AWSSecretsMgr) get(secretID string) (map[string]interface{}, secrets.Version, error) {
secretValueOutput, err := a.scm.GetSecretValue(&secretsmanager.GetSecretValueInput{
secretValueOutput, err := a.scm.GetSecretValue(context.TODO(), &secretsmanager.GetSecretValueInput{
SecretId: aws.String(secretID),
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
if aerr.Code() == secretsmanager.ErrCodeResourceNotFoundException {
return nil, secrets.NoVersion, secrets.ErrInvalidSecretId
} else if aerr.Code() == secretsmanager.ErrCodeInvalidRequestException &&
strings.Contains(aerr.Error(), "marked for deletion") {
return nil, secrets.NoVersion, secrets.ErrInvalidSecretId
}
if _, ok := err.(*types.ResourceNotFoundException); ok {
return nil, secrets.NoVersion, secrets.ErrInvalidSecretId
} else if aerr, ok := err.(*types.InvalidRequestException); ok &&
strings.Contains(aerr.Error(), "marked for deletion") {
return nil, secrets.NoVersion, secrets.ErrInvalidSecretId
}
return nil, secrets.NoVersion, &secrets.ErrProviderInternal{Reason: err.Error(), Provider: Name}
}
Expand Down Expand Up @@ -214,12 +208,12 @@ func (a *AWSSecretsMgr) put(
return secrets.NoVersion, fmt.Errorf("failed to marshal secret data: %v", err)
}
// Check if there already exists a key.
_, err = a.scm.GetSecretValue(&secretsmanager.GetSecretValueInput{
_, err = a.scm.GetSecretValue(context.TODO(), &secretsmanager.GetSecretValueInput{
SecretId: aws.String(secretID),
})
if err == nil {
// Update the existing secret
secretValueOutput, putErr := a.scm.PutSecretValue(&secretsmanager.PutSecretValueInput{
secretValueOutput, putErr := a.scm.PutSecretValue(context.TODO(), &secretsmanager.PutSecretValueInput{
SecretId: aws.String(secretID),
SecretString: aws.String(string(secretBytes)),
})
Expand All @@ -231,21 +225,19 @@ func (a *AWSSecretsMgr) put(
}
return secrets.Version(*secretValueOutput.VersionId), nil
} else {
if aerr, ok := err.(awserr.Error); ok {
if aerr.Code() == secretsmanager.ErrCodeResourceNotFoundException {
// Create a new secret
secretValueOutput, createErr := a.scm.CreateSecret(&secretsmanager.CreateSecretInput{
SecretString: aws.String(string(secretBytes)),
Name: aws.String(secretID),
})
if createErr != nil {
return secrets.NoVersion, &secrets.ErrProviderInternal{Reason: createErr.Error(), Provider: Name}
}
if secretValueOutput.VersionId == nil {
return secrets.NoVersion, &secrets.ErrProviderInternal{Reason: "invalid version returned by aws", Provider: Name}
}
return secrets.Version(*secretValueOutput.VersionId), nil
} // return the aws error
if _, ok := err.(*types.ResourceNotFoundException); ok {
// Create a new secret
secretValueOutput, createErr := a.scm.CreateSecret(context.TODO(), &secretsmanager.CreateSecretInput{
SecretString: aws.String(string(secretBytes)),
Name: aws.String(secretID),
})
if createErr != nil {
return secrets.NoVersion, &secrets.ErrProviderInternal{Reason: createErr.Error(), Provider: Name}
}
if secretValueOutput.VersionId == nil {
return secrets.NoVersion, &secrets.ErrProviderInternal{Reason: "invalid version returned by aws", Provider: Name}
}
return secrets.Version(*secretValueOutput.VersionId), nil
} // return the non-aws error
}
// Gets, Puts & Creates have failed
Expand Down Expand Up @@ -278,7 +270,7 @@ func (a *AWSSecretsMgr) delete(
}
}

_, err := a.scm.DeleteSecret(deleteSecretInput)
_, err := a.scm.DeleteSecret(context.TODO(), deleteSecretInput)
if err != nil {
return &secrets.ErrProviderInternal{Reason: err.Error(), Provider: Name}
}
Expand Down
7 changes: 7 additions & 0 deletions aws/aws_secrets_manager/aws_scm_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package aws_secrets_manager
import (
"os"
"testing"
"time"

"github.com/libopenstorage/secrets"
"github.com/libopenstorage/secrets/aws/utils"
Expand Down Expand Up @@ -110,6 +111,9 @@ func (a *awsSecretTest) TestDeleteSecret(t *testing.T) error {
err := a.s.DeleteSecret(a.secretIdWithData, nil)
assert.NoError(t, err, "Expected DeleteSecret to succeed")

// Add a delay to allow time for deletion to propagate
time.Sleep(time.Second * 200)

// Get of a deleted key should fail
_, version, err := a.s.GetSecret(a.secretIdWithData, nil)
assert.EqualError(t, secrets.ErrInvalidSecretId, err.Error(), "Unexpected error on GetSecret after delete")
Expand All @@ -119,6 +123,9 @@ func (a *awsSecretTest) TestDeleteSecret(t *testing.T) error {
err = a.s.DeleteSecret(a.secretIdWithoutData, nil)
assert.NoError(t, err, "Expected DeleteSecret to succeed")

// Add a delay to allow time for deletion to propagate
time.Sleep(time.Second * 200)

// GetSecret using a secretId without data
_, version, err = a.s.GetSecret(a.secretIdWithoutData, nil)
assert.EqualError(t, secrets.ErrInvalidSecretId, err.Error(), "Unexpected error on GetSecret after delete")
Expand Down
84 changes: 45 additions & 39 deletions aws/credentials/credentials.go
Original file line number Diff line number Diff line change
@@ -1,69 +1,75 @@
package credentials

import (
"fmt"
"net/http"
"context"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
)

type AWSCredentials interface {
Get() (*credentials.Credentials, error)
Get() (*aws.Credentials, error)
GetCredentialsProvider() (aws.CredentialsProvider, error)
}

type awsCred struct {
creds *credentials.Credentials
creds *aws.Credentials
credsprovider aws.CredentialsProvider
}

func NewAWSCredentials(id, secret, token string, runningOnEc2 bool) (AWSCredentials, error) {
var creds *credentials.Credentials
sess, err := session.NewSession()
if err != nil {
return nil, fmt.Errorf("error creating new aws credentials: %w", err)
}
var creds aws.Credentials
var credsprovider aws.CredentialsProvider
var ctx context.Context
if id != "" && secret != "" {
creds = credentials.NewStaticCredentials(id, secret, token)
if _, err := creds.Get(); err != nil {
cfg, err := config.LoadDefaultConfig(ctx, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(id, secret, token)))
if err != nil {
return nil, err
}
} else if sess.Config.Credentials != nil {
// sess config loads credential automatically from environment variable
// this is used to prioritize loading aws web identity token whenever it's specified.
creds = sess.Config.Credentials
} else {
providers := []credentials.Provider{
&credentials.EnvProvider{},

creds, err = cfg.Credentials.Retrieve(context.Background())
if err != nil {
return nil, err
}
if runningOnEc2 {
client := http.Client{Timeout: time.Second * 10}
ec2RoleProvider := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(sess, &aws.Config{
HTTPClient: &client,
}),
}
providers = append(providers, ec2RoleProvider)

} else if runningOnEc2 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious, does AWS not allow chaining of providers anymore? Previously we were able to add EnvProvider and RoleProvider together. With this change we won't since each provider is in its own if condition.


ec2Provider := ec2rolecreds.New(func(o *ec2rolecreds.Options) {
o.Client = imds.New(imds.Options{
HTTPClient: http.NewBuildableClient().WithTimeout(10 * time.Second),
})
})

cfg, err := config.LoadDefaultConfig(context.TODO(),
config.WithCredentialsProvider(ec2Provider),
)
if err != nil {
return nil, err
}
providers = append(providers, &credentials.SharedCredentialsProvider{})
creds = credentials.NewChainCredentials(providers)
if _, err := creds.Get(); err != nil {

creds, err = cfg.Credentials.Retrieve(context.Background())
if err != nil {
return nil, err
}
}
return &awsCred{creds}, nil
return &awsCred{&creds, credsprovider}, nil
}

func (a *awsCred) Get() (*credentials.Credentials, error) {
if a.creds.IsExpired() {
arivankar-px marked this conversation as resolved.
Show resolved Hide resolved
func (a *awsCred) Get() (*aws.Credentials, error) {
if a.creds.Expired() {
// Refresh the credentials
_, err := a.creds.Get()
if err != nil {
if _, err := a.credsprovider.Retrieve(context.TODO()); err != nil {
return nil, err
}
}
return a.creds, nil
}

func (a *awsCred) GetCredentialsProvider() (aws.CredentialsProvider, error) {
return a.credsprovider, nil
}
Loading
Loading