Skip to content

Commit

Permalink
migrate uploader to aws-sdk-go-v2 (#44728)
Browse files Browse the repository at this point in the history
  • Loading branch information
rudream committed Aug 21, 2024
1 parent 56cd382 commit b49b907
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 312 deletions.
194 changes: 97 additions & 97 deletions lib/events/s3sessions/s3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,30 @@ package s3sessions

import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
"net/url"
"path"
"sort"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
awssession "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/aws/aws-sdk-go/service/s3/s3manager/s3manageriface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
awstypes "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
s3metrics "github.com/gravitational/teleport/lib/observability/metrics/s3"
awsmetrics "github.com/gravitational/teleport/lib/observability/metrics/aws"
"github.com/gravitational/teleport/lib/session"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
)
Expand Down Expand Up @@ -77,10 +77,10 @@ type Config struct {
Endpoint string
// ACL is the canned ACL to send to S3
ACL string
// Session is an optional existing AWS client session
Session *awssession.Session
// Credentials if supplied are used in tests or with External Audit Storage.
Credentials *credentials.Credentials
// AWSConfig is an optional existing AWS client configuration
AWSConfig *aws.Config
// CredentialsProvider if supplied is used in tests or with External Audit Storage.
CredentialsProvider aws.CredentialsProvider
// SSEKMSKey specifies the optional custom CMK used for KMS SSE.
SSEKMSKey string

Expand Down Expand Up @@ -156,38 +156,40 @@ func (s *Config) CheckAndSetDefaults() error {
if s.Bucket == "" {
return trace.BadParameter("missing parameter Bucket")
}
if s.Session == nil {
awsConfig := aws.Config{
UseFIPSEndpoint: events.FIPSProtoStateToAWSState(s.UseFIPSEndpoint),
}
if s.Region != "" {
awsConfig.Region = aws.String(s.Region)
}
if s.Endpoint != "" {
awsConfig.Endpoint = aws.String(s.Endpoint)
awsConfig.S3ForcePathStyle = aws.Bool(true)

if s.AWSConfig == nil {
var err error
opts := []func(*config.LoadOptions) error{
config.WithRegion(s.Region),
}

if s.Insecure {
awsConfig.DisableSSL = aws.Bool(s.Insecure)
}
if s.Credentials != nil {
awsConfig.Credentials = s.Credentials
opts = append(opts, config.WithHTTPClient(&http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}))
} else {
hc, err := defaults.HTTPClient()
if err != nil {
return trace.Wrap(err)
}

opts = append(opts, config.WithHTTPClient(hc))
}
hc, err := defaults.HTTPClient()
if err != nil {
return trace.Wrap(err)

if s.CredentialsProvider != nil {
opts = append(opts, config.WithCredentialsProvider(s.CredentialsProvider))
}
awsConfig.HTTPClient = hc

sess, err := awssession.NewSessionWithOptions(awssession.Options{
SharedConfigState: awssession.SharedConfigEnable,
Config: awsConfig,
})
opts = append(opts, config.WithAPIOptions(awsmetrics.MetricsMiddleware()))

awsConfig, err := config.LoadDefaultConfig(context.Background(), opts...)
if err != nil {
return trace.Wrap(err)
}

s.Session = sess
s.AWSConfig = &awsConfig
}
return nil
}
Expand All @@ -198,20 +200,15 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
return nil, trace.Wrap(err)
}

client, err := s3metrics.NewAPIMetrics(s3.New(cfg.Session))
if err != nil {
return nil, trace.Wrap(err)
}

uploader, err := s3metrics.NewUploadAPIMetrics(s3manager.NewUploader(cfg.Session))
if err != nil {
return nil, trace.Wrap(err)
}
// Create S3 client with custom options
client := s3.NewFromConfig(*cfg.AWSConfig, func(o *s3.Options) {
if cfg.Endpoint != "" {
o.UsePathStyle = true
}
})

downloader, err := s3metrics.NewDownloadAPIMetrics(s3manager.NewDownloader(cfg.Session))
if err != nil {
return nil, trace.Wrap(err)
}
uploader := manager.NewUploader(client)
downloader := manager.NewDownloader(client)

h := &Handler{
Entry: log.WithFields(log.Fields{
Expand All @@ -222,6 +219,7 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
downloader: downloader,
client: client,
}

start := time.Now()
h.Infof("Setting up bucket %q, sessions path %q in region %q.", h.Bucket, h.Path, h.Region)
if err := h.ensureBucket(ctx); err != nil {
Expand All @@ -237,9 +235,9 @@ type Handler struct {
Config
// Entry is a logging entry
*log.Entry
uploader s3manageriface.UploaderAPI
downloader s3manageriface.DownloaderAPI
client s3iface.S3API
uploader *manager.Uploader
downloader *manager.Downloader
client *s3.Client
}

// Close releases connection and resources associated with log if any
Expand All @@ -250,25 +248,23 @@ func (h *Handler) Close() error {
// Upload uploads object to S3 bucket, reads the contents of the object from reader
// and returns the target S3 bucket path in case of successful upload.
func (h *Handler) Upload(ctx context.Context, sessionID session.ID, reader io.Reader) (string, error) {
var err error
path := h.path(sessionID)

uploadInput := &s3manager.UploadInput{
uploadInput := &s3.PutObjectInput{
Bucket: aws.String(h.Bucket),
Key: aws.String(path),
Body: reader,
}
if !h.Config.DisableServerSideEncryption {
uploadInput.ServerSideEncryption = aws.String(s3.ServerSideEncryptionAwsKms)

uploadInput.ServerSideEncryption = awstypes.ServerSideEncryptionAwsKms
if h.Config.SSEKMSKey != "" {
uploadInput.SSEKMSKeyId = aws.String(h.Config.SSEKMSKey)
}
}
if h.Config.ACL != "" {
uploadInput.ACL = aws.String(h.Config.ACL)
uploadInput.ACL = awstypes.ObjectCannedACL(h.Config.ACL)
}
_, err = h.uploader.UploadWithContext(ctx, uploadInput)
_, err := h.uploader.Upload(ctx, uploadInput)
if err != nil {
return "", awsutils.ConvertS3Error(err)
}
Expand All @@ -288,17 +284,14 @@ func (h *Handler) Download(ctx context.Context, sessionID session.ID, writer io.

h.Debugf("Downloading %v/%v [%v].", h.Bucket, h.path(sessionID), versionID)

written, err := h.downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
_, err = h.downloader.Download(ctx, writer, &s3.GetObjectInput{
Bucket: aws.String(h.Bucket),
Key: aws.String(h.path(sessionID)),
VersionId: aws.String(versionID),
})
if err != nil {
return awsutils.ConvertS3Error(err)
}
if written == 0 {
return trace.NotFound("recording for %v is not found", sessionID)
}
return nil
}

Expand All @@ -315,24 +308,24 @@ type versionID struct {
func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix string) (string, error) {
var versions []versionID

// Get all versions of this object.
err := h.client.ListObjectVersionsPagesWithContext(ctx, &s3.ListObjectVersionsInput{
paginator := s3.NewListObjectVersionsPaginator(h.client, &s3.ListObjectVersionsInput{
Bucket: aws.String(bucket),
Prefix: aws.String(prefix),
}, func(page *s3.ListObjectVersionsOutput, lastPage bool) bool {
})

for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return "", awsutils.ConvertS3Error(err)
}
for _, v := range page.Versions {
versions = append(versions, versionID{
ID: *v.VersionId,
ID: aws.ToString(v.VersionId),
Timestamp: *v.LastModified,
})
}

// Returning false stops iteration, stop iteration upon last page.
return !lastPage
})
if err != nil {
return "", awsutils.ConvertS3Error(err)
}

if len(versions) == 0 {
return "", trace.NotFound("%v/%v not found", bucket, prefix)
}
Expand All @@ -347,23 +340,28 @@ func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix st
// delete bucket deletes bucket and all it's contents and is used in tests
func (h *Handler) deleteBucket(ctx context.Context) error {
// first, list and delete all the objects in the bucket
out, err := h.client.ListObjectVersionsWithContext(ctx, &s3.ListObjectVersionsInput{
paginator := s3.NewListObjectVersionsPaginator(h.client, &s3.ListObjectVersionsInput{
Bucket: aws.String(h.Bucket),
})
if err != nil {
return awsutils.ConvertS3Error(err)
}
for _, ver := range out.Versions {
_, err := h.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(h.Bucket),
Key: ver.Key,
VersionId: ver.VersionId,
})

for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return awsutils.ConvertS3Error(err)
}
for _, ver := range page.Versions {
_, err := h.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(h.Bucket),
Key: ver.Key,
VersionId: ver.VersionId,
})
if err != nil {
return awsutils.ConvertS3Error(err)
}
}
}
_, err = h.client.DeleteBucketWithContext(ctx, &s3.DeleteBucketInput{

_, err := h.client.DeleteBucket(ctx, &s3.DeleteBucketInput{
Bucket: aws.String(h.Bucket),
})
return awsutils.ConvertS3Error(err)
Expand All @@ -382,7 +380,7 @@ func (h *Handler) fromPath(p string) session.ID {

// ensureBucket makes sure bucket exists, and if it does not, creates it
func (h *Handler) ensureBucket(ctx context.Context) error {
_, err := h.client.HeadBucketWithContext(ctx, &s3.HeadBucketInput{
_, err := h.client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: aws.String(h.Bucket),
})
err = awsutils.ConvertS3Error(err)
Expand All @@ -396,44 +394,46 @@ func (h *Handler) ensureBucket(ctx context.Context) error {
}
input := &s3.CreateBucketInput{
Bucket: aws.String(h.Bucket),
ACL: aws.String("private"),
ACL: awstypes.BucketCannedACLPrivate,
}
_, err = h.client.CreateBucketWithContext(ctx, input)
_, err = h.client.CreateBucket(ctx, input)
err = awsutils.ConvertS3Error(err, fmt.Sprintf("bucket %v already exists", aws.String(h.Bucket)))
if err != nil {
if !trace.IsAlreadyExists(err) {
return trace.Wrap(err)
}

// if this client has not created the bucket, don't reconfigure it
return nil
}

// Turn on versioning.
ver := &s3.PutBucketVersioningInput{
_, err = h.client.PutBucketVersioning(ctx, &s3.PutBucketVersioningInput{
Bucket: aws.String(h.Bucket),
VersioningConfiguration: &s3.VersioningConfiguration{
Status: aws.String("Enabled"),
VersioningConfiguration: &awstypes.VersioningConfiguration{
Status: awstypes.BucketVersioningStatusEnabled,
},
}
_, err = h.client.PutBucketVersioningWithContext(ctx, ver)
})
err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket))
if err != nil {
return trace.Wrap(err)
}

// Turn on server-side encryption for the bucket.
if !h.DisableServerSideEncryption {
_, err = h.client.PutBucketEncryptionWithContext(ctx, &s3.PutBucketEncryptionInput{
_, err = h.client.PutBucketEncryption(ctx, &s3.PutBucketEncryptionInput{
Bucket: aws.String(h.Bucket),
ServerSideEncryptionConfiguration: &s3.ServerSideEncryptionConfiguration{
Rules: []*s3.ServerSideEncryptionRule{{
ApplyServerSideEncryptionByDefault: &s3.ServerSideEncryptionByDefault{
SSEAlgorithm: aws.String(s3.ServerSideEncryptionAwsKms),
ServerSideEncryptionConfiguration: &awstypes.ServerSideEncryptionConfiguration{
Rules: []awstypes.ServerSideEncryptionRule{
{
ApplyServerSideEncryptionByDefault: &awstypes.ServerSideEncryptionByDefault{
SSEAlgorithm: awstypes.ServerSideEncryptionAwsKms,
},
},
}},
},
},
})
err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket))
err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set encryption state for bucket %q", h.Bucket))
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading

0 comments on commit b49b907

Please sign in to comment.