From e7f555566718741cd683a22952881c34412aa177 Mon Sep 17 00:00:00 2001 From: Yassine Bounekhla Date: Tue, 20 Aug 2024 14:32:43 -0400 Subject: [PATCH] migrate uploader to aws-sdk-go-v2 --- lib/events/s3sessions/s3handler.go | 194 +++++++++--------- .../s3sessions/s3handler_thirdparty_test.go | 27 ++- lib/events/s3sessions/s3stream.go | 116 ++++++----- lib/observability/metrics/s3/api.go | 155 -------------- lib/service/service.go | 3 +- 5 files changed, 183 insertions(+), 312 deletions(-) delete mode 100644 lib/observability/metrics/s3/api.go diff --git a/lib/events/s3sessions/s3handler.go b/lib/events/s3sessions/s3handler.go index a83d30b890de9..3c4d878a1d10d 100644 --- a/lib/events/s3sessions/s3handler.go +++ b/lib/events/s3sessions/s3handler.go @@ -20,8 +20,10 @@ package s3sessions import ( "context" + "crypto/tls" "fmt" "io" + "net/http" "net/url" "path" "sort" @@ -29,13 +31,11 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - awssession "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3iface" - "github.com/aws/aws-sdk-go/service/s3/s3manager" - "github.com/aws/aws-sdk-go/service/s3/s3manager/s3manageriface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + awstypes "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/gravitational/trace" log "github.com/sirupsen/logrus" @@ -43,7 +43,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" - s3metrics "github.com/gravitational/teleport/lib/observability/metrics/s3" + awsmetrics "github.com/gravitational/teleport/lib/observability/metrics/aws" "github.com/gravitational/teleport/lib/session" awsutils "github.com/gravitational/teleport/lib/utils/aws" ) @@ -77,10 +77,10 @@ type Config struct { Endpoint string // ACL is the canned ACL to send to S3 ACL string - // Session is an optional existing AWS client session - Session *awssession.Session - // Credentials if supplied are used in tests or with External Audit Storage. - Credentials *credentials.Credentials + // AWSConfig is an optional existing AWS client configuration + AWSConfig *aws.Config + // CredentialsProvider if supplied is used in tests or with External Audit Storage. + CredentialsProvider aws.CredentialsProvider // SSEKMSKey specifies the optional custom CMK used for KMS SSE. 
 	SSEKMSKey string
@@ -156,38 +156,40 @@ func (s *Config) CheckAndSetDefaults() error {
 	if s.Bucket == "" {
 		return trace.BadParameter("missing parameter Bucket")
 	}
-	if s.Session == nil {
-		awsConfig := aws.Config{
-			UseFIPSEndpoint: events.FIPSProtoStateToAWSState(s.UseFIPSEndpoint),
-		}
-		if s.Region != "" {
-			awsConfig.Region = aws.String(s.Region)
-		}
-		if s.Endpoint != "" {
-			awsConfig.Endpoint = aws.String(s.Endpoint)
-			awsConfig.S3ForcePathStyle = aws.Bool(true)
+
+	if s.AWSConfig == nil {
+		opts := []func(*config.LoadOptions) error{
+			config.WithRegion(s.Region),
 		}
+
 		if s.Insecure {
-			awsConfig.DisableSSL = aws.Bool(s.Insecure)
-		}
-		if s.Credentials != nil {
-			awsConfig.Credentials = s.Credentials
+			opts = append(opts, config.WithHTTPClient(&http.Client{
+				Transport: &http.Transport{
+					TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+				},
+			}))
+		} else {
+			hc, err := defaults.HTTPClient()
+			if err != nil {
+				return trace.Wrap(err)
+			}
+
+			opts = append(opts, config.WithHTTPClient(hc))
 		}
-		hc, err := defaults.HTTPClient()
-		if err != nil {
-			return trace.Wrap(err)
+
+		if s.CredentialsProvider != nil {
+			opts = append(opts, config.WithCredentialsProvider(s.CredentialsProvider))
 		}
-		awsConfig.HTTPClient = hc
-		sess, err := awssession.NewSessionWithOptions(awssession.Options{
-			SharedConfigState: awssession.SharedConfigEnable,
-			Config:            awsConfig,
-		})
+
+		// Preserve the FIPS endpoint behavior of the v1 configuration.
+		if s.UseFIPSEndpoint == types.ClusterAuditConfigSpecV2_FIPS_ENABLED {
+			opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
+		}
+
+		opts = append(opts, config.WithAPIOptions(awsmetrics.MetricsMiddleware()))
+
+		awsConfig, err := config.LoadDefaultConfig(context.Background(), opts...)
 		if err != nil {
 			return trace.Wrap(err)
 		}
-		s.Session = sess
+		s.AWSConfig = &awsConfig
 	}
 	return nil
 }
@@ -198,20 +200,15 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
 		return nil, trace.Wrap(err)
 	}
 
-	client, err := s3metrics.NewAPIMetrics(s3.New(cfg.Session))
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
-
-	uploader, err := s3metrics.NewUploadAPIMetrics(s3manager.NewUploader(cfg.Session))
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
+	// Create the S3 client. A custom endpoint implies an S3-compatible
+	// backend, which generally requires path-style addressing.
+	client := s3.NewFromConfig(*cfg.AWSConfig, func(o *s3.Options) {
+		if cfg.Endpoint != "" {
+			o.BaseEndpoint = aws.String(cfg.Endpoint)
+			o.UsePathStyle = true
+		}
+	})
 
-	downloader, err := s3metrics.NewDownloadAPIMetrics(s3manager.NewDownloader(cfg.Session))
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
+	uploader := manager.NewUploader(client)
+	downloader := manager.NewDownloader(client)
 
 	h := &Handler{
 		Entry: log.WithFields(log.Fields{
@@ -222,6 +219,7 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
 		downloader: downloader,
 		client:     client,
 	}
+
 	start := time.Now()
 	h.Infof("Setting up bucket %q, sessions path %q in region %q.", h.Bucket, h.Path, h.Region)
 	if err := h.ensureBucket(ctx); err != nil {
@@ -237,9 +235,9 @@ type Handler struct {
 	Config
 	// Entry is a logging entry
 	*log.Entry
-	uploader   s3manageriface.UploaderAPI
-	downloader s3manageriface.DownloaderAPI
-	client     s3iface.S3API
+	uploader   *manager.Uploader
+	downloader *manager.Downloader
+	client     *s3.Client
 }
 
 // Close releases connection and resources associated with log if any
@@ -250,25 +248,23 @@ func (h *Handler) Close() error {
 
 // Upload uploads object to S3 bucket, reads the contents of the object from reader
 // and returns the target S3 bucket path in case of successful upload.
func (h *Handler) Upload(ctx context.Context, sessionID session.ID, reader io.Reader) (string, error) { - var err error path := h.path(sessionID) - uploadInput := &s3manager.UploadInput{ + uploadInput := &s3.PutObjectInput{ Bucket: aws.String(h.Bucket), Key: aws.String(path), Body: reader, } if !h.Config.DisableServerSideEncryption { - uploadInput.ServerSideEncryption = aws.String(s3.ServerSideEncryptionAwsKms) - + uploadInput.ServerSideEncryption = awstypes.ServerSideEncryptionAwsKms if h.Config.SSEKMSKey != "" { uploadInput.SSEKMSKeyId = aws.String(h.Config.SSEKMSKey) } } if h.Config.ACL != "" { - uploadInput.ACL = aws.String(h.Config.ACL) + uploadInput.ACL = awstypes.ObjectCannedACL(h.Config.ACL) } - _, err = h.uploader.UploadWithContext(ctx, uploadInput) + _, err := h.uploader.Upload(ctx, uploadInput) if err != nil { return "", awsutils.ConvertS3Error(err) } @@ -288,7 +284,7 @@ func (h *Handler) Download(ctx context.Context, sessionID session.ID, writer io. h.Debugf("Downloading %v/%v [%v].", h.Bucket, h.path(sessionID), versionID) - written, err := h.downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{ + _, err = h.downloader.Download(ctx, writer, &s3.GetObjectInput{ Bucket: aws.String(h.Bucket), Key: aws.String(h.path(sessionID)), VersionId: aws.String(versionID), @@ -296,9 +292,6 @@ func (h *Handler) Download(ctx context.Context, sessionID session.ID, writer io. if err != nil { return awsutils.ConvertS3Error(err) } - if written == 0 { - return trace.NotFound("recording for %v is not found", sessionID) - } return nil } @@ -315,24 +308,24 @@ type versionID struct { func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix string) (string, error) { var versions []versionID - // Get all versions of this object. - err := h.client.ListObjectVersionsPagesWithContext(ctx, &s3.ListObjectVersionsInput{ + paginator := s3.NewListObjectVersionsPaginator(h.client, &s3.ListObjectVersionsInput{ Bucket: aws.String(bucket), Prefix: aws.String(prefix), - }, func(page *s3.ListObjectVersionsOutput, lastPage bool) bool { + }) + + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return "", awsutils.ConvertS3Error(err) + } for _, v := range page.Versions { versions = append(versions, versionID{ - ID: *v.VersionId, + ID: aws.ToString(v.VersionId), Timestamp: *v.LastModified, }) } - - // Returning false stops iteration, stop iteration upon last page. 
-		// Returning false stops iteration, stop iteration upon last page.
-		return !lastPage
-	})
-	if err != nil {
-		return "", awsutils.ConvertS3Error(err)
 	}
+
 	if len(versions) == 0 {
 		return "", trace.NotFound("%v/%v not found", bucket, prefix)
 	}
@@ -347,23 +340,28 @@ func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix st
 
-// delete bucket deletes bucket and all it's contents and is used in tests
+// deleteBucket deletes the bucket and all its contents and is used in tests
 func (h *Handler) deleteBucket(ctx context.Context) error {
 	// first, list and delete all the objects in the bucket
-	out, err := h.client.ListObjectVersionsWithContext(ctx, &s3.ListObjectVersionsInput{
+	paginator := s3.NewListObjectVersionsPaginator(h.client, &s3.ListObjectVersionsInput{
 		Bucket: aws.String(h.Bucket),
 	})
-	if err != nil {
-		return awsutils.ConvertS3Error(err)
-	}
-	for _, ver := range out.Versions {
-		_, err := h.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
-			Bucket:    aws.String(h.Bucket),
-			Key:       ver.Key,
-			VersionId: ver.VersionId,
-		})
+
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
 		if err != nil {
 			return awsutils.ConvertS3Error(err)
 		}
+		for _, ver := range page.Versions {
+			_, err := h.client.DeleteObject(ctx, &s3.DeleteObjectInput{
+				Bucket:    aws.String(h.Bucket),
+				Key:       ver.Key,
+				VersionId: ver.VersionId,
+			})
+			if err != nil {
+				return awsutils.ConvertS3Error(err)
+			}
+		}
 	}
-	_, err = h.client.DeleteBucketWithContext(ctx, &s3.DeleteBucketInput{
+
+	_, err := h.client.DeleteBucket(ctx, &s3.DeleteBucketInput{
 		Bucket: aws.String(h.Bucket),
 	})
 	return awsutils.ConvertS3Error(err)
@@ -382,7 +380,7 @@ func (h *Handler) fromPath(p string) session.ID {
 
 // ensureBucket makes sure bucket exists, and if it does not, creates it
 func (h *Handler) ensureBucket(ctx context.Context) error {
-	_, err := h.client.HeadBucketWithContext(ctx, &s3.HeadBucketInput{
+	_, err := h.client.HeadBucket(ctx, &s3.HeadBucketInput{
 		Bucket: aws.String(h.Bucket),
 	})
 	err = awsutils.ConvertS3Error(err)
@@ -396,26 +394,26 @@ func (h *Handler) ensureBucket(ctx context.Context) error {
 	}
 	input := &s3.CreateBucketInput{
 		Bucket: aws.String(h.Bucket),
-		ACL:    aws.String("private"),
+		ACL:    awstypes.BucketCannedACLPrivate,
 	}
-	_, err = h.client.CreateBucketWithContext(ctx, input)
-	err = awsutils.ConvertS3Error(err, fmt.Sprintf("bucket %v already exists", aws.String(h.Bucket)))
+	_, err = h.client.CreateBucket(ctx, input)
+	err = awsutils.ConvertS3Error(err, fmt.Sprintf("bucket %v already exists", h.Bucket))
 	if err != nil {
 		if !trace.IsAlreadyExists(err) {
 			return trace.Wrap(err)
 		}
+		// if this client has not created the bucket, don't reconfigure it
 		return nil
 	}
 
 	// Turn on versioning.
-	ver := &s3.PutBucketVersioningInput{
+	_, err = h.client.PutBucketVersioning(ctx, &s3.PutBucketVersioningInput{
 		Bucket: aws.String(h.Bucket),
-		VersioningConfiguration: &s3.VersioningConfiguration{
-			Status: aws.String("Enabled"),
+		VersioningConfiguration: &awstypes.VersioningConfiguration{
+			Status: awstypes.BucketVersioningStatusEnabled,
 		},
-	}
-	_, err = h.client.PutBucketVersioningWithContext(ctx, ver)
+	})
 	err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket))
 	if err != nil {
 		return trace.Wrap(err)
 	}
 
 	// Turn on server-side encryption for the bucket.
if !h.DisableServerSideEncryption { - _, err = h.client.PutBucketEncryptionWithContext(ctx, &s3.PutBucketEncryptionInput{ + _, err = h.client.PutBucketEncryption(ctx, &s3.PutBucketEncryptionInput{ Bucket: aws.String(h.Bucket), - ServerSideEncryptionConfiguration: &s3.ServerSideEncryptionConfiguration{ - Rules: []*s3.ServerSideEncryptionRule{{ - ApplyServerSideEncryptionByDefault: &s3.ServerSideEncryptionByDefault{ - SSEAlgorithm: aws.String(s3.ServerSideEncryptionAwsKms), + ServerSideEncryptionConfiguration: &awstypes.ServerSideEncryptionConfiguration{ + Rules: []awstypes.ServerSideEncryptionRule{ + { + ApplyServerSideEncryptionByDefault: &awstypes.ServerSideEncryptionByDefault{ + SSEAlgorithm: awstypes.ServerSideEncryptionAwsKms, + }, }, - }}, + }, }, }) - err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket)) + err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set encryption state for bucket %q", h.Bucket)) if err != nil { return trace.Wrap(err) } diff --git a/lib/events/s3sessions/s3handler_thirdparty_test.go b/lib/events/s3sessions/s3handler_thirdparty_test.go index b379446974bec..7cbecb1e92ca8 100644 --- a/lib/events/s3sessions/s3handler_thirdparty_test.go +++ b/lib/events/s3sessions/s3handler_thirdparty_test.go @@ -24,7 +24,9 @@ import ( "net/http/httptest" "testing" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/johannesboyne/gofakes3" @@ -41,12 +43,31 @@ func TestThirdpartyStreams(t *testing.T) { backend := s3mem.New(s3mem.WithTimeSource(timeSource)) faker := gofakes3.New(backend, gofakes3.WithLogger(gofakes3.GlobalLog())) server := httptest.NewServer(faker.Server()) + defer server.Close() + + bucketName := fmt.Sprintf("teleport-test-%v", uuid.New().String()) + + config := aws.Config{ + Credentials: credentials.NewStaticCredentialsProvider("YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", ""), + Region: "us-west-1", + BaseEndpoint: aws.String(server.URL), + } + + s3Client := s3.NewFromConfig(config, func(o *s3.Options) { + o.UsePathStyle = true + }) + + // Create the bucket. 
+ _, err := s3Client.CreateBucket(context.Background(), &s3.CreateBucketInput{ + Bucket: aws.String(bucketName), + }) + require.NoError(t, err) handler, err := NewHandler(context.Background(), Config{ - Credentials: credentials.NewStaticCredentials("YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", ""), + AWSConfig: &config, Region: "us-west-1", Path: "/test/", - Bucket: fmt.Sprintf("teleport-test-%v", uuid.New().String()), + Bucket: bucketName, Endpoint: server.URL, DisableServerSideEncryption: true, }) diff --git a/lib/events/s3sessions/s3stream.go b/lib/events/s3sessions/s3stream.go index ec04d7ae9d761..3ae8cbdea8f87 100644 --- a/lib/events/s3sessions/s3stream.go +++ b/lib/events/s3sessions/s3stream.go @@ -20,6 +20,8 @@ package s3sessions import ( "context" + "crypto/md5" + "encoding/base64" "fmt" "io" "net/url" @@ -27,14 +29,14 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go-v2/aws" + s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/gravitational/trace" "github.com/sirupsen/logrus" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/session" awsutils "github.com/gravitational/teleport/lib/utils/aws" @@ -49,34 +51,34 @@ func (h *Handler) CreateUpload(ctx context.Context, sessionID session.ID) (*even Key: aws.String(h.path(sessionID)), } if !h.Config.DisableServerSideEncryption { - input.ServerSideEncryption = aws.String(s3.ServerSideEncryptionAwsKms) + input.ServerSideEncryption = types.ServerSideEncryptionAwsKms if h.Config.SSEKMSKey != "" { input.SSEKMSKeyId = aws.String(h.Config.SSEKMSKey) } } if h.Config.ACL != "" { - input.ACL = aws.String(h.Config.ACL) + input.ACL = types.ObjectCannedACL(h.Config.ACL) } - resp, err := h.client.CreateMultipartUploadWithContext(ctx, input) + resp, err := h.client.CreateMultipartUpload(ctx, input) if err != nil { return nil, trace.Wrap(awsutils.ConvertS3Error(err), "CreateMultiPartUpload session(%v)", sessionID) } h.WithFields(logrus.Fields{ - "upload": aws.StringValue(resp.UploadId), + "upload": aws.ToString(resp.UploadId), "session": sessionID, - "key": aws.StringValue(resp.Key), + "key": aws.ToString(resp.Key), }).Infof("Created upload in %v", time.Since(start)) - return &events.StreamUpload{SessionID: sessionID, ID: aws.StringValue(resp.UploadId)}, nil + return &events.StreamUpload{SessionID: sessionID, ID: aws.ToString(resp.UploadId)}, nil } // UploadPart uploads part func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64, partBody io.ReadSeeker) (*events.StreamPart, error) { // This upload exceeded maximum number of supported parts, error now. - if partNumber > s3manager.MaxUploadParts { + if partNumber > int64(s3manager.MaxUploadParts) { return nil, trace.LimitExceeded( "exceeded total allowed S3 limit MaxUploadParts (%d). Adjust PartSize to fit in this limit", s3manager.MaxUploadParts) } @@ -89,16 +91,30 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa "key": uploadKey, }) + // Calculate the content MD5 hash to be included in the request. This is required for S3 buckets with Object Lock enabled. 
+	hash := md5.New()
+	if _, err := io.Copy(hash, partBody); err != nil {
+		return nil, trace.Wrap(err, "failed to calculate content MD5 hash")
+	}
+	md5sum := base64.StdEncoding.EncodeToString(hash.Sum(nil))
+
+	// Reset partBody to the beginning before handing it to the request params:
+	// computing the MD5 hash above consumed the reader and left it at EOF.
+	if _, err := partBody.Seek(0, io.SeekStart); err != nil {
+		return nil, trace.Wrap(err, "failed to reset part body reader to beginning")
+	}
+
 	params := &s3.UploadPartInput{
 		Bucket:     aws.String(h.Bucket),
 		UploadId:   aws.String(upload.ID),
 		Key:        aws.String(uploadKey),
 		Body:       partBody,
-		PartNumber: aws.Int64(partNumber),
+		PartNumber: aws.Int32(int32(partNumber)),
+		ContentMD5: aws.String(md5sum),
 	}
 
 	log.Debugf("Uploading part %v", partNumber)
-	resp, err := h.client.UploadPartWithContext(ctx, params)
+	resp, err := h.client.UploadPart(ctx, params)
 	if err != nil {
 		return nil, trace.Wrap(awsutils.ConvertS3Error(err),
 			"UploadPart(upload %v) part(%v) session(%v)", upload.ID, partNumber, upload.SessionID)
@@ -111,7 +127,7 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa
 	// the part we just uploaded, however.
 	log.Infof("Uploaded part %v in %v", partNumber, time.Since(start))
 	return &events.StreamPart{
-		ETag:         aws.StringValue(resp.ETag),
+		ETag:         aws.ToString(resp.ETag),
 		Number:       partNumber,
 		LastModified: time.Now(),
 	}, nil
@@ -130,7 +146,7 @@ func (h *Handler) abortUpload(ctx context.Context, upload events.StreamUpload) e
 		UploadId: aws.String(upload.ID),
 	}
 	log.Debug("Aborting upload")
-	_, err := h.client.AbortMultipartUploadWithContext(ctx, req)
+	_, err := h.client.AbortMultipartUpload(ctx, req)
 	if err != nil {
 		return awsutils.ConvertS3Error(err)
 	}
@@ -166,11 +182,11 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
 		return parts[i].Number < parts[j].Number
 	})
 
-	completedParts := make([]*s3.CompletedPart, len(parts))
+	completedParts := make([]types.CompletedPart, len(parts))
 	for i := range parts {
-		completedParts[i] = &s3.CompletedPart{
+		completedParts[i] = types.CompletedPart{
 			ETag:       aws.String(parts[i].ETag),
-			PartNumber: aws.Int64(parts[i].Number),
+			PartNumber: aws.Int32(int32(parts[i].Number)),
 		}
 	}
 
@@ -179,9 +195,9 @@
 		Bucket:          aws.String(h.Bucket),
 		Key:             aws.String(uploadKey),
 		UploadId:        aws.String(upload.ID),
-		MultipartUpload: &s3.CompletedMultipartUpload{Parts: completedParts},
+		MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
 	}
-	_, err := h.client.CompleteMultipartUploadWithContext(ctx, params)
+	_, err := h.client.CompleteMultipartUpload(ctx, params)
 	if err != nil {
 		return trace.Wrap(awsutils.ConvertS3Error(err),
 			"CompleteMultipartUpload(upload %v) session(%v)", upload.ID, upload.SessionID)
@@ -202,29 +218,27 @@ func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([]
 	log.Debug("Listing parts for upload")
 
 	var parts []events.StreamPart
-	var partNumberMarker *int64
-	for i := 0; i < defaults.MaxIterationLimit; i++ {
-		re, err := h.client.ListPartsWithContext(ctx, &s3.ListPartsInput{
-			Bucket:           aws.String(h.Bucket),
-			Key:              aws.String(uploadKey),
-			UploadId:         aws.String(upload.ID),
-			PartNumberMarker: partNumberMarker,
-		})
+
+	paginator := s3.NewListPartsPaginator(h.client, &s3.ListPartsInput{
+		Bucket:   aws.String(h.Bucket),
+		Key:      aws.String(uploadKey),
+		UploadId: aws.String(upload.ID),
+	})
+
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
 		if err != nil {
 			return nil, awsutils.ConvertS3Error(err)
 		}
-		for _, part := range re.Parts {
+		for _, part := range page.Parts {
 			parts = append(parts, events.StreamPart{
-				Number:       aws.Int64Value(part.PartNumber),
-				ETag:         aws.StringValue(part.ETag),
-				LastModified: aws.TimeValue(part.LastModified),
+				Number:       int64(aws.ToInt32(part.PartNumber)),
+				ETag:         aws.ToString(part.ETag),
+				LastModified: aws.ToTime(part.LastModified),
 			})
 		}
-		if !aws.BoolValue(re.IsTruncated) {
-			break
-		}
-		partNumberMarker = re.NextPartNumberMarker
 	}
+
 	// Parts must be sorted in PartNumber order.
 	sort.Slice(parts, func(i, j int) bool {
 		return parts[i].Number < parts[j].Number
@@ -240,31 +254,23 @@ func (h *Handler) ListUploads(ctx context.Context) ([]events.StreamUpload, error
 		prefix = &trimmed
 	}
 	var uploads []events.StreamUpload
-	var keyMarker *string
-	var uploadIDMarker *string
-	for i := 0; i < defaults.MaxIterationLimit; i++ {
-		input := &s3.ListMultipartUploadsInput{
-			Bucket:         aws.String(h.Bucket),
-			Prefix:         prefix,
-			KeyMarker:      keyMarker,
-			UploadIdMarker: uploadIDMarker,
-		}
-		re, err := h.client.ListMultipartUploadsWithContext(ctx, input)
+	paginator := s3.NewListMultipartUploadsPaginator(h.client, &s3.ListMultipartUploadsInput{
+		Bucket: aws.String(h.Bucket),
+		Prefix: prefix,
+	})
+
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
 		if err != nil {
 			return nil, awsutils.ConvertS3Error(err)
 		}
-		for _, upload := range re.Uploads {
+		for _, upload := range page.Uploads {
 			uploads = append(uploads, events.StreamUpload{
-				ID:        aws.StringValue(upload.UploadId),
-				SessionID: h.fromPath(aws.StringValue(upload.Key)),
-				Initiated: aws.TimeValue(upload.Initiated),
+				ID:        aws.ToString(upload.UploadId),
+				SessionID: h.fromPath(aws.ToString(upload.Key)),
+				Initiated: aws.ToTime(upload.Initiated),
 			})
 		}
-		if !aws.BoolValue(re.IsTruncated) {
-			break
-		}
-		keyMarker = re.NextKeyMarker
-		uploadIDMarker = re.NextUploadIdMarker
 	}
 
 	sort.Slice(uploads, func(i, j int) bool {
diff --git a/lib/observability/metrics/s3/api.go b/lib/observability/metrics/s3/api.go
deleted file mode 100644
index 43231ab7c321a..0000000000000
--- a/lib/observability/metrics/s3/api.go
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see <https://www.gnu.org/licenses/>.
- */ - -package s3 - -import ( - "context" - "time" - - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3iface" - "github.com/gravitational/trace" - - "github.com/gravitational/teleport/lib/observability/metrics" -) - -type APIMetrics struct { - s3iface.S3API -} - -func NewAPIMetrics(api s3iface.S3API) (*APIMetrics, error) { - if err := metrics.RegisterPrometheusCollectors(s3Collectors...); err != nil { - return nil, trace.Wrap(err) - } - - return &APIMetrics{S3API: api}, nil -} - -func (m *APIMetrics) ListObjectVersionsPagesWithContext(ctx context.Context, input *s3.ListObjectVersionsInput, f func(*s3.ListObjectVersionsOutput, bool) bool, opts ...request.Option) error { - start := time.Now() - err := m.S3API.ListObjectVersionsPagesWithContext(ctx, input, f, opts...) - - recordMetrics("list_object_versions_pages", err, time.Since(start).Seconds()) - return err -} - -func (m *APIMetrics) ListObjectVersionsWithContext(ctx context.Context, input *s3.ListObjectVersionsInput, opts ...request.Option) (*s3.ListObjectVersionsOutput, error) { - start := time.Now() - output, err := m.S3API.ListObjectVersionsWithContext(ctx, input, opts...) - - recordMetrics("list_object_versions", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) DeleteObjectWithContext(ctx context.Context, input *s3.DeleteObjectInput, opts ...request.Option) (*s3.DeleteObjectOutput, error) { - start := time.Now() - output, err := m.S3API.DeleteObjectWithContext(ctx, input, opts...) - - recordMetrics("delete_object", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) DeleteBucketWithContext(ctx context.Context, input *s3.DeleteBucketInput, opts ...request.Option) (*s3.DeleteBucketOutput, error) { - start := time.Now() - output, err := m.S3API.DeleteBucketWithContext(ctx, input, opts...) - - recordMetrics("delete_bucket", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) HeadBucketWithContext(ctx context.Context, input *s3.HeadBucketInput, opts ...request.Option) (*s3.HeadBucketOutput, error) { - start := time.Now() - output, err := m.S3API.HeadBucketWithContext(ctx, input, opts...) - - recordMetrics("head_bucket", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) CreateBucketWithContext(ctx context.Context, input *s3.CreateBucketInput, opts ...request.Option) (*s3.CreateBucketOutput, error) { - start := time.Now() - output, err := m.S3API.CreateBucketWithContext(ctx, input, opts...) - - recordMetrics("create_bucket", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) PutBucketVersioningWithContext(ctx context.Context, input *s3.PutBucketVersioningInput, opts ...request.Option) (*s3.PutBucketVersioningOutput, error) { - start := time.Now() - output, err := m.S3API.PutBucketVersioningWithContext(ctx, input, opts...) - - recordMetrics("put_bucket_versioning", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) PutBucketEncryptionWithContext(ctx context.Context, input *s3.PutBucketEncryptionInput, opts ...request.Option) (*s3.PutBucketEncryptionOutput, error) { - start := time.Now() - output, err := m.S3API.PutBucketEncryptionWithContext(ctx, input, opts...) 
- - recordMetrics("put_bucket_encryption", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) CreateMultipartUploadWithContext(ctx context.Context, input *s3.CreateMultipartUploadInput, opts ...request.Option) (*s3.CreateMultipartUploadOutput, error) { - start := time.Now() - output, err := m.S3API.CreateMultipartUploadWithContext(ctx, input, opts...) - - recordMetrics("create_multipart_upload", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) UploadPartWithContext(ctx context.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) { - start := time.Now() - output, err := m.S3API.UploadPartWithContext(ctx, input, opts...) - - recordMetrics("upload_part", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) AbortMultipartUploadWithContext(ctx context.Context, input *s3.AbortMultipartUploadInput, opts ...request.Option) (*s3.AbortMultipartUploadOutput, error) { - start := time.Now() - output, err := m.S3API.AbortMultipartUploadWithContext(ctx, input, opts...) - - recordMetrics("abort_multipart_upload", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) CompleteMultipartUploadWithContext(ctx context.Context, input *s3.CompleteMultipartUploadInput, opts ...request.Option) (*s3.CompleteMultipartUploadOutput, error) { - start := time.Now() - output, err := m.S3API.CompleteMultipartUploadWithContext(ctx, input, opts...) - - recordMetrics("complete_multipart_upload", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) ListPartsWithContext(ctx context.Context, input *s3.ListPartsInput, opts ...request.Option) (*s3.ListPartsOutput, error) { - start := time.Now() - output, err := m.S3API.ListPartsWithContext(ctx, input, opts...) - - recordMetrics("list_parts", err, time.Since(start).Seconds()) - return output, err -} - -func (m *APIMetrics) ListMultipartUploadsWithContext(ctx context.Context, input *s3.ListMultipartUploadsInput, opts ...request.Option) (*s3.ListMultipartUploadsOutput, error) { - start := time.Now() - output, err := m.S3API.ListMultipartUploadsWithContext(ctx, input, opts...) - - recordMetrics("list_multipart_uploads", err, time.Since(start).Seconds()) - return output, err -} diff --git a/lib/service/service.go b/lib/service/service.go index 656b5a59056cc..76810b3970294 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -48,7 +48,6 @@ import ( "testing" "time" - awscredentials "github.com/aws/aws-sdk-go/aws/credentials" awssession "github.com/aws/aws-sdk-go/aws/session" "github.com/google/renameio/v2" "github.com/google/uuid" @@ -1783,7 +1782,7 @@ func initAuthUploadHandler(ctx context.Context, auditConfig types.ClusterAuditCo UseFIPSEndpoint: auditConfig.GetUseFIPSEndpoint(), } if externalAuditStorage.IsUsed() { - config.Credentials = awscredentials.NewCredentials(externalAuditStorage.CredentialsProviderSDKV1()) + config.CredentialsProvider = externalAuditStorage.CredentialsProvider() } if err := config.SetFromURL(uri, auditConfig.Region()); err != nil { return nil, trace.Wrap(err)
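
---
Reviewer notes (not part of the commit): the snippets below sketch the main v1-to-v2 migration patterns this patch relies on. They are illustrative and self-contained; bucket names, endpoints, and helper names are placeholders, not code from this patch.

CheckAndSetDefaults and NewHandler follow the standard v2 construction pattern: load a shared aws.Config once via functional options, then apply per-client settings (endpoint, path style) when building the S3 client. A minimal sketch of the same pattern, with a placeholder region and endpoint:

    package main

    import (
    	"context"
    	"log"

    	"github.com/aws/aws-sdk-go-v2/aws"
    	"github.com/aws/aws-sdk-go-v2/config"
    	"github.com/aws/aws-sdk-go-v2/service/s3"
    )

    func main() {
    	// LoadDefaultConfig resolves credentials and region from the
    	// environment, shared config files, and IMDS; options override
    	// individual fields, replacing v1's session.NewSessionWithOptions.
    	cfg, err := config.LoadDefaultConfig(context.Background(),
    		config.WithRegion("us-west-1"), // placeholder region
    	)
    	if err != nil {
    		log.Fatal(err)
    	}

    	// Per-client options replace v1's aws.Config Endpoint and
    	// S3ForcePathStyle fields.
    	client := s3.NewFromConfig(cfg, func(o *s3.Options) {
    		o.BaseEndpoint = aws.String("http://127.0.0.1:9000") // placeholder endpoint
    		o.UsePathStyle = true
    	})
    	_ = client
    }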
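NewHandler now builds the transfer manager from the client rather than the session. For completeness, a round-trip sketch with a placeholder bucket and key; note that Download needs an io.WriterAt, satisfied in memory by manager.NewWriteAtBuffer:

    func roundTrip(ctx context.Context, client *s3.Client) ([]byte, error) {
    	// manager.NewUploader/NewDownloader take the client, not a session.
    	uploader := manager.NewUploader(client)
    	if _, err := uploader.Upload(ctx, &s3.PutObjectInput{
    		Bucket: aws.String("example-bucket"), // placeholder
    		Key:    aws.String("example-key"),    // placeholder
    		Body:   strings.NewReader("hello"),
    	}); err != nil {
    		return nil, err
    	}

    	downloader := manager.NewDownloader(client)
    	buf := manager.NewWriteAtBuffer(nil) // in-memory io.WriterAt
    	if _, err := downloader.Download(ctx, buf, &s3.GetObjectInput{
    		Bucket: aws.String("example-bucket"),
    		Key:    aws.String("example-key"),
    	}); err != nil {
    		return nil, err
    	}
    	return buf.Bytes(), nil
    }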
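The callback-style v1 pagination (ListObjectVersionsPagesWithContext, plus the hand-rolled KeyMarker/UploadIdMarker and NextPartNumberMarker loops) is replaced throughout by v2 paginators. The shape is always the same; a sketch using the object-versions paginator, assuming the same aws, s3, and s3/types imports as the patch:

    // listAllVersions collects every object version under a prefix.
    func listAllVersions(ctx context.Context, client *s3.Client, bucket, prefix string) ([]types.ObjectVersion, error) {
    	var versions []types.ObjectVersion
    	paginator := s3.NewListObjectVersionsPaginator(client, &s3.ListObjectVersionsInput{
    		Bucket: aws.String(bucket),
    		Prefix: aws.String(prefix),
    	})
    	// HasMorePages/NextPage replaces both the v1 callback and the
    	// manual IsTruncated/marker bookkeeping.
    	for paginator.HasMorePages() {
    		page, err := paginator.NextPage(ctx)
    		if err != nil {
    			return nil, err
    		}
    		versions = append(versions, page.Versions...)
    	}
    	return versions, nil
    }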
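UploadPart now sends a Content-MD5 header, which S3 requires for buckets with Object Lock enabled. Because hashing consumes the reader, the body must be rewound before the upload; that is why the part body is an io.ReadSeeker. The two steps in isolation:

    // computeContentMD5 returns the base64-encoded MD5 digest S3 expects
    // in the Content-MD5 header, rewinding body so it can be re-read.
    func computeContentMD5(body io.ReadSeeker) (string, error) {
    	hash := md5.New()
    	if _, err := io.Copy(hash, body); err != nil {
    		return "", err
    	}
    	// io.Copy left the reader at EOF; rewind before the upload.
    	if _, err := body.Seek(0, io.SeekStart); err != nil {
    		return "", err
    	}
    	return base64.StdEncoding.EncodeToString(hash.Sum(nil)), nil
    }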
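The deleted lib/observability/metrics/s3/api.go wrapped every API call by hand; in v2 the equivalent cross-cutting instrumentation hangs off the middleware stack, wired in once via config.WithAPIOptions(awsmetrics.MetricsMiddleware()). The sketch below shows the general smithy-go shape such a middleware takes; timingMiddleware and its logging are illustrative stand-ins, not Teleport's actual implementation:

    import (
    	"context"
    	"log"
    	"time"

    	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
    	"github.com/aws/smithy-go/middleware"
    )

    // timingMiddleware records the duration of every SDK operation. One
    // middleware covers all operations, replacing one wrapper per method.
    func timingMiddleware() func(*middleware.Stack) error {
    	return func(stack *middleware.Stack) error {
    		return stack.Finalize.Add(middleware.FinalizeMiddlewareFunc("requestTiming",
    			func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (middleware.FinalizeOutput, middleware.Metadata, error) {
    				start := time.Now()
    				out, md, err := next.HandleFinalize(ctx, in)
    				// The operation name ("PutObject", "ListParts", ...) travels in the context.
    				log.Printf("s3 %s took %v", awsmiddleware.GetOperationName(ctx), time.Since(start))
    				return out, md, err
    			}), middleware.After)
    	}
    }

It would be registered the same way the patch registers MetricsMiddleware: config.WithAPIOptions([]func(*middleware.Stack) error{timingMiddleware()}).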