From 3104fc67760f1e24973a57b0b6a510b9070d5575 Mon Sep 17 00:00:00 2001
From: Yassine Bounekhla
Date: Mon, 29 Jul 2024 04:19:36 -0400
Subject: [PATCH] migrate to aws-sdk-go-v2

---
 lib/events/api.go                             |   6 +-
 lib/events/azsessions/azsessions.go           |  10 +-
 lib/events/eventstest/uploader.go             |  16 +-
 lib/events/filesessions/filestream.go         |  18 +-
 lib/events/filesessions/filestream_test.go    |  24 +--
 lib/events/gcssessions/gcsstream.go           |   8 +-
 lib/events/s3sessions/s3handler.go            | 194 +++++++++---------
 .../s3sessions/s3handler_thirdparty_test.go   |   8 +-
 lib/events/s3sessions/s3stream.go             | 102 ++++-----
 lib/events/stream.go                          |  10 +-
 .../externalauditstorage/error_counter.go     |   4 +-
 lib/observability/metrics/s3/api.go           |  70 +++----
 lib/service/service.go                        |   3 +-
 13 files changed, 221 insertions(+), 252 deletions(-)

diff --git a/lib/events/api.go b/lib/events/api.go
index 8ac05d415ddde..4b1fe8e7ab267 100644
--- a/lib/events/api.go
+++ b/lib/events/api.go
@@ -874,7 +874,7 @@ type Streamer interface {
 // StreamPart represents uploaded stream part
 type StreamPart struct {
 	// Number is a part number
-	Number int64
+	Number int32
 	// ETag is a part e-tag
 	ETag string
 }
@@ -914,9 +914,9 @@ type MultipartUploader interface {
 	CompleteUpload(ctx context.Context, upload StreamUpload, parts []StreamPart) error
 	// ReserveUploadPart reserves an upload part. Reserve is used to identify
 	// upload errors beforehand.
-	ReserveUploadPart(ctx context.Context, upload StreamUpload, partNumber int64) error
+	ReserveUploadPart(ctx context.Context, upload StreamUpload, partNumber int32) error
 	// UploadPart uploads part and returns the part
-	UploadPart(ctx context.Context, upload StreamUpload, partNumber int64, partBody io.ReadSeeker) (*StreamPart, error)
+	UploadPart(ctx context.Context, upload StreamUpload, partNumber int32, partBody io.ReadSeeker) (*StreamPart, error)
 	// ListParts returns all uploaded parts for the completed upload in sorted order
 	ListParts(ctx context.Context, upload StreamUpload) ([]StreamPart, error)
 	// ListUploads lists uploads that have been initiated but not completed with
diff --git a/lib/events/azsessions/azsessions.go b/lib/events/azsessions/azsessions.go
index 85aa13d39084d..19bd75bda2c3a 100644
--- a/lib/events/azsessions/azsessions.go
+++ b/lib/events/azsessions/azsessions.go
@@ -85,7 +85,7 @@ func partPrefix(upload events.StreamUpload) string {
 }
 
 // partName returns the name of the blob for a specific part in an upload.
-func partName(upload events.StreamUpload, partNumber int64) string {
+func partName(upload events.StreamUpload, partNumber int32) string {
 	return fmt.Sprintf("%v%v", partPrefix(upload), partNumber)
 }
 
@@ -252,7 +252,7 @@ func (h *Handler) uploadMarkerBlob(upload events.StreamUpload) *blockblob.Client
 
 // partBlob returns a BlockBlobClient for the blob of the part of the specified
 // upload, with the given part number.
-func (h *Handler) partBlob(upload events.StreamUpload, partNumber int64) *blockblob.Client {
+func (h *Handler) partBlob(upload events.StreamUpload, partNumber int32) *blockblob.Client {
 	return h.inprogress.NewBlockBlobClient(partName(upload, partNumber))
 }
 
@@ -440,12 +440,12 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
 }
 
 // ReserveUploadPart implements [events.MultipartUploader].
-func (*Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error {
+func (*Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32) error {
 	return nil
 }
 
 // UploadPart implements [events.MultipartUploader].
-func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64, partBody io.ReadSeeker) (*events.StreamPart, error) {
+func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32, partBody io.ReadSeeker) (*events.StreamPart, error) {
 	partBlob := h.partBlob(upload, partNumber)
 
 	// our parts are just over 5 MiB (events.MinUploadPartSizeBytes) so we can
@@ -493,7 +493,7 @@ func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([]
 			continue
 		}
-		parts = append(parts, events.StreamPart{Number: partNumber})
+		parts = append(parts, events.StreamPart{Number: int32(partNumber)})
 	}
 }
diff --git a/lib/events/eventstest/uploader.go b/lib/events/eventstest/uploader.go
index 63a30cd242684..b48adccf64270 100644
--- a/lib/events/eventstest/uploader.go
+++ b/lib/events/eventstest/uploader.go
@@ -63,7 +63,7 @@ type MemoryUpload struct {
 	// id is the upload ID
 	id string
 	// parts is the upload parts
-	parts map[int64][]byte
+	parts map[int32][]byte
 	// sessionID is the session ID associated with the upload
 	sessionID session.ID
 	//completed specifies upload as completed
@@ -105,7 +105,7 @@ func (m *MemoryUploader) CreateUpload(ctx context.Context, sessionID session.ID)
 	m.uploads[upload.ID] = &MemoryUpload{
 		id:        upload.ID,
 		sessionID: sessionID,
-		parts:     make(map[int64][]byte),
+		parts:     make(map[int32][]byte),
 		Initiated: upload.Initiated,
 	}
 	return upload, nil
@@ -124,7 +124,7 @@ func (m *MemoryUploader) CompleteUpload(ctx context.Context, upload events.Strea
 	}
 	// verify that all parts have been uploaded
 	var result []byte
-	partsSet := make(map[int64]bool, len(parts))
+	partsSet := make(map[int32]bool, len(parts))
 	for _, part := range parts {
 		partsSet[part.Number] = true
 		data, ok := up.parts[part.Number]
@@ -146,7 +146,7 @@ func (m *MemoryUploader) CompleteUpload(ctx context.Context, upload events.Strea
 }
 
 // UploadPart uploads part and returns the part
-func (m *MemoryUploader) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64, partBody io.ReadSeeker) (*events.StreamPart, error) {
+func (m *MemoryUploader) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32, partBody io.ReadSeeker) (*events.StreamPart, error) {
 	data, err := io.ReadAll(partBody)
 	if err != nil {
 		return nil, trace.Wrap(err)
@@ -190,7 +190,7 @@ func (m *MemoryUploader) GetParts(uploadID string) ([][]byte, error) {
 		return nil, trace.NotFound("upload %q is not found", uploadID)
 	}
 
-	partNumbers := make([]int64, 0, len(up.parts))
+	partNumbers := make([]int32, 0, len(up.parts))
 	sortedParts := make([][]byte, 0, len(up.parts))
 	for partNumber := range up.parts {
 		partNumbers = append(partNumbers, partNumber)
@@ -222,7 +222,7 @@ func (m *MemoryUploader) ListParts(ctx context.Context, upload events.StreamUplo
 		return nil, trace.NotFound("upload %v is not found", upload.ID)
 	}
 
-	partNumbers := make([]int64, 0, len(up.parts))
+	partNumbers := make([]int32, 0, len(up.parts))
 	sortedParts := make([]events.StreamPart, 0, len(up.parts))
 	for partNumber := range up.parts {
 		partNumbers = append(partNumbers, partNumber)
@@ -278,7 +278,7 @@ func (m *MemoryUploader) GetUploadMetadata(sid session.ID) events.UploadMetadata
 }
 
 // ReserveUploadPart reserves an upload part.
-func (m *MemoryUploader) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error {
+func (m *MemoryUploader) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32) error {
 	return nil
 }
 
@@ -307,7 +307,7 @@ func (m *MockUploader) CreateUpload(ctx context.Context, sessionID session.ID) (
 	}, nil
 }
 
-func (m *MockUploader) ReserveUploadPart(_ context.Context, _ events.StreamUpload, _ int64) error {
+func (m *MockUploader) ReserveUploadPart(_ context.Context, _ events.StreamUpload, _ int32) error {
 	return m.ReserveUploadPartError
 }
 
diff --git a/lib/events/filesessions/filestream.go b/lib/events/filesessions/filestream.go
index c9c5ff5ccd855..768ee7cd3f49f 100644
--- a/lib/events/filesessions/filestream.go
+++ b/lib/events/filesessions/filestream.go
@@ -105,7 +105,7 @@ func (h *Handler) CreateUpload(ctx context.Context, sessionID session.ID) (*even
 }
 
 // UploadPart uploads part
-func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64, partBody io.ReadSeeker) (*events.StreamPart, error) {
+func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32, partBody io.ReadSeeker) (*events.StreamPart, error) {
 	if err := checkUpload(upload); err != nil {
 		return nil, trace.Wrap(err)
 	}
@@ -339,7 +339,7 @@ func (h *Handler) GetUploadMetadata(s session.ID) events.UploadMetadata {
 }
 
 // ReserveUploadPart reserves an upload part.
-func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error {
+func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32) error {
 	file, partPath, err := h.openReservationPart(upload, partNumber)
 	if err != nil {
 		return trace.ConvertSystemError(err)
@@ -361,7 +361,7 @@ func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpl
 }
 
 // openReservationPart opens a reservation upload part file.
-func (h *Handler) openReservationPart(upload events.StreamUpload, partNumber int64) (*os.File, string, error) {
+func (h *Handler) openReservationPart(upload events.StreamUpload, partNumber int32) (*os.File, string, error) {
 	partPath := h.reservationPath(upload, partNumber)
 	file, err := GetOpenFileFunc()(partPath, os.O_RDWR|os.O_CREATE, 0o600)
 	if err != nil {
@@ -383,23 +383,23 @@ func (h *Handler) uploadPath(upload events.StreamUpload) string {
 	return filepath.Join(h.uploadRootPath(upload), string(upload.SessionID))
 }
 
-func (h *Handler) partPath(upload events.StreamUpload, partNumber int64) string {
+func (h *Handler) partPath(upload events.StreamUpload, partNumber int32) string {
 	return filepath.Join(h.uploadPath(upload), partFileName(partNumber))
 }
 
-func (h *Handler) reservationPath(upload events.StreamUpload, partNumber int64) string {
+func (h *Handler) reservationPath(upload events.StreamUpload, partNumber int32) string {
 	return filepath.Join(h.uploadPath(upload), reservationFileName(partNumber))
 }
 
-func partFileName(partNumber int64) string {
+func partFileName(partNumber int32) string {
 	return fmt.Sprintf("%v%v", partNumber, partExt)
 }
 
-func reservationFileName(partNumber int64) string {
+func reservationFileName(partNumber int32) string {
 	return fmt.Sprintf("%v%v", partNumber, reservationExt)
 }
 
-func partFromFileName(fileName string) (int64, error) {
+func partFromFileName(fileName string) (int32, error) {
 	base := filepath.Base(fileName)
 	if filepath.Ext(base) != partExt {
 		return -1, trace.BadParameter("expected extension %v, got %v", partExt, base)
@@ -409,7 +409,7 @@
 	if err != nil {
 		return -1, trace.Wrap(err)
 	}
-	return partNumber, nil
+	return int32(partNumber), nil
 }
 
 // checkUpload checks that upload IDs are valid
diff --git a/lib/events/filesessions/filestream_test.go b/lib/events/filesessions/filestream_test.go
index 7e409e34920e3..61e036282b34e 100644
--- a/lib/events/filesessions/filestream_test.go
+++ b/lib/events/filesessions/filestream_test.go
@@ -33,7 +33,7 @@ import (
 
 func TestReserveUploadPart(t *testing.T) {
 	ctx := context.Background()
-	partNumber := int64(1)
+	partNumber := int32(1)
 	dir := t.TempDir()
 
 	handler, err := NewHandler(Config{
@@ -54,7 +54,7 @@ func TestReserveUploadPart(t *testing.T) {
 
 func TestUploadPart(t *testing.T) {
 	ctx := context.Background()
-	partNumber := int64(1)
+	partNumber := int32(1)
 	dir := t.TempDir()
 
 	expectedContent := []byte("upload part contents")
@@ -89,7 +89,7 @@ func TestCompleteUpload(t *testing.T) {
 	ctx := context.Background()
 
 	// Create some upload parts using reserve + write.
-	createPart := func(t *testing.T, handler *Handler, upload *events.StreamUpload, partNumber int64, content []byte) events.StreamPart {
+	createPart := func(t *testing.T, handler *Handler, upload *events.StreamUpload, partNumber int32, content []byte) events.StreamPart {
 		err := handler.ReserveUploadPart(ctx, *upload, partNumber)
 		require.NoError(t, err)
 
@@ -111,27 +111,27 @@ func TestCompleteUpload(t *testing.T) {
 			desc:            "PartsWithContent",
 			expectedContent: []byte("helloworld"),
 			partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) {
-				createPart(t, handler, upload, int64(1), []byte("hello"))
-				createPart(t, handler, upload, int64(2), []byte("world"))
+				createPart(t, handler, upload, int32(1), []byte("hello"))
+				createPart(t, handler, upload, int32(2), []byte("world"))
 			},
 		},
 		{
 			desc:            "ReservationParts",
 			expectedContent: []byte("helloworldwithreservation"),
 			partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) {
-				createPart(t, handler, upload, int64(1), []byte{})
-				createPart(t, handler, upload, int64(2), []byte("hello"))
-				createPart(t, handler, upload, int64(3), []byte("world"))
-				createPart(t, handler, upload, int64(4), []byte{})
-				createPart(t, handler, upload, int64(5), []byte("withreservation"))
+				createPart(t, handler, upload, int32(1), []byte{})
+				createPart(t, handler, upload, int32(2), []byte("hello"))
+				createPart(t, handler, upload, int32(3), []byte("world"))
+				createPart(t, handler, upload, int32(4), []byte{})
+				createPart(t, handler, upload, int32(5), []byte("withreservation"))
 			},
 		},
 		{
 			desc:            "OnlyReservation",
 			expectedContent: []byte{},
 			partsFunc: func(t *testing.T, handler *Handler, upload *events.StreamUpload) {
-				createPart(t, handler, upload, int64(1), []byte{})
-				createPart(t, handler, upload, int64(2), []byte{})
+				createPart(t, handler, upload, int32(1), []byte{})
+				createPart(t, handler, upload, int32(2), []byte{})
 			},
 		},
 	} {
diff --git a/lib/events/gcssessions/gcsstream.go b/lib/events/gcssessions/gcsstream.go
index f18487fed85e9..0c947ba12378b 100644
--- a/lib/events/gcssessions/gcsstream.go
+++ b/lib/events/gcssessions/gcsstream.go
@@ -80,7 +80,7 @@ func (h *Handler) CreateUpload(ctx context.Context, sessionID session.ID) (*even
 }
 
 // UploadPart uploads part
-func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64, partBody io.ReadSeeker) (*events.StreamPart, error) {
+func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32, partBody io.ReadSeeker) (*events.StreamPart, error) {
 	if err := upload.CheckAndSetDefaults(); err != nil {
 		return nil, trace.Wrap(err)
 	}
@@ -292,7 +292,7 @@ func (h *Handler) GetUploadMetadata(s session.ID) events.UploadMetadata {
 }
 
 // ReserveUploadPart reserves an upload part.
-func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error {
+func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32) error {
 	return nil
 }
 
@@ -341,7 +341,7 @@ func (h *Handler) partsPrefix(upload events.StreamUpload) string {
 }
 
 // partPath is "path/parts/<upload-id>/<part-number>.part"
-func (h *Handler) partPath(upload events.StreamUpload, partNumber int64) string {
+func (h *Handler) partPath(upload events.StreamUpload, partNumber int32) string {
 	return path.Join(h.partsPrefix(upload), fmt.Sprintf("%v%v", partNumber, partExt))
 }
 
@@ -398,5 +398,5 @@ func partFromPath(uploadPath string) (*events.StreamPart, error) {
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-	return &events.StreamPart{Number: partNumber}, nil
+	return &events.StreamPart{Number: int32(partNumber)}, nil
 }
diff --git a/lib/events/s3sessions/s3handler.go b/lib/events/s3sessions/s3handler.go
index a83d30b890de9..fc702e7c5fa94 100644
--- a/lib/events/s3sessions/s3handler.go
+++ b/lib/events/s3sessions/s3handler.go
@@ -20,8 +20,10 @@ package s3sessions
 
 import (
 	"context"
+	"crypto/tls"
 	"fmt"
 	"io"
+	"net/http"
 	"net/url"
 	"path"
 	"sort"
@@ -29,13 +31,11 @@ import (
 	"strings"
 	"time"
 
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/aws/credentials"
-	awssession "github.com/aws/aws-sdk-go/aws/session"
-	"github.com/aws/aws-sdk-go/service/s3"
-	"github.com/aws/aws-sdk-go/service/s3/s3iface"
-	"github.com/aws/aws-sdk-go/service/s3/s3manager"
-	"github.com/aws/aws-sdk-go/service/s3/s3manager/s3manageriface"
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	awstypes "github.com/aws/aws-sdk-go-v2/service/s3/types"
 	"github.com/gravitational/trace"
 	log "github.com/sirupsen/logrus"
 
@@ -77,10 +77,10 @@ type Config struct {
 	Endpoint string
 	// ACL is the canned ACL to send to S3
 	ACL string
-	// Session is an optional existing AWS client session
-	Session *awssession.Session
-	// Credentials if supplied are used in tests or with External Audit Storage.
-	Credentials *credentials.Credentials
+	// AWSConfig is an optional existing AWS client configuration.
+	AWSConfig aws.Config
+	// CredentialsProvider, if supplied, is used in tests or with External Audit Storage.
+	CredentialsProvider aws.CredentialsProvider
 	// SSEKMSKey specifies the optional custom CMK used for KMS SSE.
 	SSEKMSKey string
@@ -156,38 +156,37 @@ func (s *Config) CheckAndSetDefaults() error {
 	if s.Bucket == "" {
 		return trace.BadParameter("missing parameter Bucket")
 	}
-	if s.Session == nil {
-		awsConfig := aws.Config{
-			UseFIPSEndpoint: events.FIPSProtoStateToAWSState(s.UseFIPSEndpoint),
-		}
-		if s.Region != "" {
-			awsConfig.Region = aws.String(s.Region)
-		}
-		if s.Endpoint != "" {
-			awsConfig.Endpoint = aws.String(s.Endpoint)
-			awsConfig.S3ForcePathStyle = aws.Bool(true)
+
+	if s.AWSConfig.Region == "" {
+		var err error
+		opts := []func(*config.LoadOptions) error{
+			config.WithRegion(s.Region),
 		}
+
 		if s.Insecure {
-			awsConfig.DisableSSL = aws.Bool(s.Insecure)
-		}
-		if s.Credentials != nil {
-			awsConfig.Credentials = s.Credentials
+			opts = append(opts, config.WithHTTPClient(&http.Client{
+				Transport: &http.Transport{
+					TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+				},
+			}))
+		} else {
+			hc, err := defaults.HTTPClient()
+			if err != nil {
+				return trace.Wrap(err)
+			}
+
+			opts = append(opts, config.WithHTTPClient(hc))
 		}
-		hc, err := defaults.HTTPClient()
-		if err != nil {
-			return trace.Wrap(err)
+
+		if s.CredentialsProvider != nil {
+			opts = append(opts, config.WithCredentialsProvider(s.CredentialsProvider))
 		}
-		awsConfig.HTTPClient = hc
-		sess, err := awssession.NewSessionWithOptions(awssession.Options{
-			SharedConfigState: awssession.SharedConfigEnable,
-			Config:            awsConfig,
-		})
+
+		s.AWSConfig, err = config.LoadDefaultConfig(context.Background(), opts...)
 		if err != nil {
 			return trace.Wrap(err)
 		}
-		s.Session = sess
 	}
 	return nil
 }
@@ -198,20 +197,20 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
 		return nil, trace.Wrap(err)
 	}
 
-	client, err := s3metrics.NewAPIMetrics(s3.New(cfg.Session))
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
+	// Create the S3 client, pointing it at the custom endpoint (with
+	// path-style addressing) when one is configured.
+	s3Client := s3.NewFromConfig(cfg.AWSConfig, func(o *s3.Options) {
+		if cfg.Endpoint != "" {
+			o.BaseEndpoint = aws.String(cfg.Endpoint)
+			o.UsePathStyle = true
+		}
+	})
 
-	uploader, err := s3metrics.NewUploadAPIMetrics(s3manager.NewUploader(cfg.Session))
+	client, err := s3metrics.NewAPIMetrics(s3Client)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
 
-	downloader, err := s3metrics.NewDownloadAPIMetrics(s3manager.NewDownloader(cfg.Session))
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
+	uploader := manager.NewUploader(s3Client)
+	downloader := manager.NewDownloader(s3Client)
 
 	h := &Handler{
 		Entry: log.WithFields(log.Fields{
@@ -222,6 +221,7 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
 		downloader: downloader,
 		client:     client,
 	}
+
 	start := time.Now()
 	h.Infof("Setting up bucket %q, sessions path %q in region %q.", h.Bucket, h.Path, h.Region)
 	if err := h.ensureBucket(ctx); err != nil {
@@ -237,9 +237,9 @@ type Handler struct {
 	Config
 	// Entry is a logging entry
 	*log.Entry
-	uploader   s3manageriface.UploaderAPI
-	downloader s3manageriface.DownloaderAPI
-	client     s3iface.S3API
+	uploader   *manager.Uploader
+	downloader *manager.Downloader
+	client     *s3metrics.APIMetrics
 }
 
 // Close releases connection and resources associated with log if any
@@ -250,25 +250,23 @@ func (h *Handler) Close() error {
 
 // Upload uploads object to S3 bucket, reads the contents of the object from reader
 // and returns the target S3 bucket path in case of successful upload.
 func (h *Handler) Upload(ctx context.Context, sessionID session.ID, reader io.Reader) (string, error) {
-	var err error
 	path := h.path(sessionID)
-	uploadInput := &s3manager.UploadInput{
+	uploadInput := &s3.PutObjectInput{
 		Bucket: aws.String(h.Bucket),
 		Key:    aws.String(path),
 		Body:   reader,
 	}
 	if !h.Config.DisableServerSideEncryption {
-		uploadInput.ServerSideEncryption = aws.String(s3.ServerSideEncryptionAwsKms)
-
+		uploadInput.ServerSideEncryption = awstypes.ServerSideEncryptionAwsKms
 		if h.Config.SSEKMSKey != "" {
 			uploadInput.SSEKMSKeyId = aws.String(h.Config.SSEKMSKey)
 		}
 	}
 	if h.Config.ACL != "" {
-		uploadInput.ACL = aws.String(h.Config.ACL)
+		uploadInput.ACL = awstypes.ObjectCannedACL(h.Config.ACL)
 	}
-	_, err = h.uploader.UploadWithContext(ctx, uploadInput)
+	_, err := h.uploader.Upload(ctx, uploadInput)
 	if err != nil {
 		return "", awsutils.ConvertS3Error(err)
 	}
@@ -278,9 +276,6 @@ func (h *Handler) Upload(ctx context.Context, sessionID session.ID, reader io.Re
 
 // Download downloads recorded session from S3 bucket and writes the results
 // into writer return trace.NotFound error is object is not found.
 func (h *Handler) Download(ctx context.Context, sessionID session.ID, writer io.WriterAt) error {
-	// Get the oldest version of this object. This has to be done because S3
-	// allows overwriting objects in a bucket. To prevent corruption of recording
-	// data, get all versions and always return the first.
 	versionID, err := h.getOldestVersion(ctx, h.Bucket, h.path(sessionID))
 	if err != nil {
 		return trace.Wrap(err)
 	}
 
 	h.Debugf("Downloading %v/%v [%v].", h.Bucket, h.path(sessionID), versionID)
 
-	written, err := h.downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
+	_, err = h.downloader.Download(ctx, writer, &s3.GetObjectInput{
 		Bucket:    aws.String(h.Bucket),
 		Key:       aws.String(h.path(sessionID)),
 		VersionId: aws.String(versionID),
 	})
 	if err != nil {
 		return awsutils.ConvertS3Error(err)
 	}
-	if written == 0 {
-		return trace.NotFound("recording for %v is not found", sessionID)
-	}
 	return nil
 }
 
@@ -315,29 +307,28 @@ type versionID struct {
 func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix string) (string, error) {
 	var versions []versionID
 
-	// Get all versions of this object.
-	err := h.client.ListObjectVersionsPagesWithContext(ctx, &s3.ListObjectVersionsInput{
+	paginator := s3.NewListObjectVersionsPaginator(h.client, &s3.ListObjectVersionsInput{
 		Bucket: aws.String(bucket),
 		Prefix: aws.String(prefix),
-	}, func(page *s3.ListObjectVersionsOutput, lastPage bool) bool {
+	})
+
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
+		if err != nil {
+			return "", awsutils.ConvertS3Error(err)
+		}
 		for _, v := range page.Versions {
 			versions = append(versions, versionID{
-				ID:        *v.VersionId,
+				ID:        aws.ToString(v.VersionId),
 				Timestamp: *v.LastModified,
 			})
 		}
-
-		// Returning false stops iteration, stop iteration upon last page.
-		return !lastPage
-	})
-	if err != nil {
-		return "", awsutils.ConvertS3Error(err)
 	}
+
 	if len(versions) == 0 {
 		return "", trace.NotFound("%v/%v not found", bucket, prefix)
 	}
-	// Sort the versions slice so the first entry is the oldest and return it.
 	sort.Slice(versions, func(i int, j int) bool {
 		return versions[i].Timestamp.Before(versions[j].Timestamp)
 	})
@@ -346,24 +337,28 @@ func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix st
 
 // delete bucket deletes bucket and all it's contents and is used in tests
 func (h *Handler) deleteBucket(ctx context.Context) error {
-	// first, list and delete all the objects in the bucket
-	out, err := h.client.ListObjectVersionsWithContext(ctx, &s3.ListObjectVersionsInput{
+	paginator := s3.NewListObjectVersionsPaginator(h.client, &s3.ListObjectVersionsInput{
 		Bucket: aws.String(h.Bucket),
 	})
-	if err != nil {
-		return awsutils.ConvertS3Error(err)
-	}
-	for _, ver := range out.Versions {
-		_, err := h.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
-			Bucket:    aws.String(h.Bucket),
-			Key:       ver.Key,
-			VersionId: ver.VersionId,
-		})
+
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
 		if err != nil {
 			return awsutils.ConvertS3Error(err)
 		}
+		for _, ver := range page.Versions {
+			_, err := h.client.DeleteObject(ctx, &s3.DeleteObjectInput{
+				Bucket:    aws.String(h.Bucket),
+				Key:       ver.Key,
+				VersionId: ver.VersionId,
+			})
+			if err != nil {
+				return awsutils.ConvertS3Error(err)
+			}
+		}
 	}
-	_, err = h.client.DeleteBucketWithContext(ctx, &s3.DeleteBucketInput{
+
+	_, err := h.client.DeleteBucket(ctx, &s3.DeleteBucketInput{
 		Bucket: aws.String(h.Bucket),
 	})
 	return awsutils.ConvertS3Error(err)
@@ -382,11 +377,10 @@ func (h *Handler) fromPath(p string) session.ID {
 
 // ensureBucket makes sure bucket exists, and if it does not, creates it
 func (h *Handler) ensureBucket(ctx context.Context) error {
-	_, err := h.client.HeadBucketWithContext(ctx, &s3.HeadBucketInput{
+	_, err := h.client.HeadBucket(ctx, &s3.HeadBucketInput{
 		Bucket: aws.String(h.Bucket),
 	})
 	err = awsutils.ConvertS3Error(err)
-	// assumes that bucket is administered by other entity
 	if err == nil {
 		return nil
 	}
@@ -396,44 +390,42 @@ func (h *Handler) ensureBucket(ctx context.Context) error {
 	}
 	input := &s3.CreateBucketInput{
 		Bucket: aws.String(h.Bucket),
-		ACL:    aws.String("private"),
+		ACL:    awstypes.BucketCannedACLPrivate,
 	}
-	_, err = h.client.CreateBucketWithContext(ctx, input)
+	_, err = h.client.CreateBucket(ctx, input)
 	err = awsutils.ConvertS3Error(err, fmt.Sprintf("bucket %v already exists", aws.String(h.Bucket)))
 	if err != nil {
 		if !trace.IsAlreadyExists(err) {
 			return trace.Wrap(err)
 		}
-		// if this client has not created the bucket, don't reconfigure it
 		return nil
 	}
-	// Turn on versioning.
-	ver := &s3.PutBucketVersioningInput{
+
+	_, err = h.client.PutBucketVersioning(ctx, &s3.PutBucketVersioningInput{
 		Bucket: aws.String(h.Bucket),
-		VersioningConfiguration: &s3.VersioningConfiguration{
-			Status: aws.String("Enabled"),
+		VersioningConfiguration: &awstypes.VersioningConfiguration{
+			Status: awstypes.BucketVersioningStatusEnabled,
 		},
-	}
-	_, err = h.client.PutBucketVersioningWithContext(ctx, ver)
+	})
 	err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket))
 	if err != nil {
 		return trace.Wrap(err)
 	}
-	// Turn on server-side encryption for the bucket.
 	if !h.DisableServerSideEncryption {
-		_, err = h.client.PutBucketEncryptionWithContext(ctx, &s3.PutBucketEncryptionInput{
+		_, err = h.client.PutBucketEncryption(ctx, &s3.PutBucketEncryptionInput{
 			Bucket: aws.String(h.Bucket),
-			ServerSideEncryptionConfiguration: &s3.ServerSideEncryptionConfiguration{
-				Rules: []*s3.ServerSideEncryptionRule{{
-					ApplyServerSideEncryptionByDefault: &s3.ServerSideEncryptionByDefault{
-						SSEAlgorithm: aws.String(s3.ServerSideEncryptionAwsKms),
+			ServerSideEncryptionConfiguration: &awstypes.ServerSideEncryptionConfiguration{
+				Rules: []awstypes.ServerSideEncryptionRule{
+					{
+						ApplyServerSideEncryptionByDefault: &awstypes.ServerSideEncryptionByDefault{
+							SSEAlgorithm: awstypes.ServerSideEncryptionAwsKms,
+						},
 					},
-				}},
+				},
 			},
 		})
-		err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket))
+		err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set encryption state for bucket %q", h.Bucket))
 		if err != nil {
 			return trace.Wrap(err)
 		}
diff --git a/lib/events/s3sessions/s3handler_thirdparty_test.go b/lib/events/s3sessions/s3handler_thirdparty_test.go
index b379446974bec..a5a15d77cdb8a 100644
--- a/lib/events/s3sessions/s3handler_thirdparty_test.go
+++ b/lib/events/s3sessions/s3handler_thirdparty_test.go
@@ -24,7 +24,7 @@ import (
 	"net/http/httptest"
 	"testing"
 
-	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/aws/aws-sdk-go-v2/credentials"
 	"github.com/google/uuid"
 	"github.com/gravitational/trace"
 	"github.com/johannesboyne/gofakes3"
@@ -43,7 +43,7 @@ func TestThirdpartyStreams(t *testing.T) {
 	server := httptest.NewServer(faker.Server())
 
 	handler, err := NewHandler(context.Background(), Config{
-		Credentials: credentials.NewStaticCredentials("YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", ""),
+		CredentialsProvider: credentials.NewStaticCredentialsProvider("YOUR-ACCESSKEYID", "YOUR-SECRETACCESSKEY", ""),
 		Region: "us-west-1",
 		Path:   "/test/",
 		Bucket: fmt.Sprintf("teleport-test-%v", uuid.New().String()),
diff --git a/lib/events/s3sessions/s3stream.go b/lib/events/s3sessions/s3stream.go
index c855ca564180f..5ab03ca1ab864 100644
--- a/lib/events/s3sessions/s3stream.go
+++ b/lib/events/s3sessions/s3stream.go
@@ -27,14 +27,14 @@ import (
 	"strings"
 	"time"
 
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/service/s3"
-	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+	"github.com/aws/aws-sdk-go-v2/aws"
+	s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	"github.com/aws/aws-sdk-go-v2/service/s3/types"
 	"github.com/gravitational/trace"
 	"github.com/sirupsen/logrus"
 
 	"github.com/gravitational/teleport"
-	"github.com/gravitational/teleport/lib/defaults"
 	"github.com/gravitational/teleport/lib/events"
 	"github.com/gravitational/teleport/lib/session"
 	awsutils "github.com/gravitational/teleport/lib/utils/aws"
@@ -49,33 +49,32 @@ func (h *Handler) CreateUpload(ctx context.Context, sessionID session.ID) (*even
 		Key: aws.String(h.path(sessionID)),
 	}
 	if !h.Config.DisableServerSideEncryption {
-		input.ServerSideEncryption = aws.String(s3.ServerSideEncryptionAwsKms)
+		input.ServerSideEncryption = types.ServerSideEncryptionAwsKms
 		if h.Config.SSEKMSKey != "" {
 			input.SSEKMSKeyId = aws.String(h.Config.SSEKMSKey)
 		}
 	}
 	if h.Config.ACL != "" {
-		input.ACL = aws.String(h.Config.ACL)
+		input.ACL = types.ObjectCannedACL(h.Config.ACL)
 	}
-	resp, err := h.client.CreateMultipartUploadWithContext(ctx, input)
+	resp, err := h.client.CreateMultipartUpload(ctx, input)
 	if err != nil {
 		return nil, trace.Wrap(awsutils.ConvertS3Error(err), "CreateMultiPartUpload session(%v)", sessionID)
 	}
 
 	h.WithFields(logrus.Fields{
-		"upload":  aws.StringValue(resp.UploadId),
+		"upload":  aws.ToString(resp.UploadId),
 		"session": sessionID,
-		"key":     aws.StringValue(resp.Key),
+		"key":     aws.ToString(resp.Key),
 	}).Infof("Created upload in %v", time.Since(start))
 
-	return &events.StreamUpload{SessionID: sessionID, ID: aws.StringValue(resp.UploadId)}, nil
+	return &events.StreamUpload{SessionID: sessionID, ID: aws.ToString(resp.UploadId)}, nil
 }
 
 // UploadPart uploads part
-func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64, partBody io.ReadSeeker) (*events.StreamPart, error) {
-	// This upload exceeded maximum number of supported parts, error now.
+func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32, partBody io.ReadSeeker) (*events.StreamPart, error) {
 	if partNumber > s3manager.MaxUploadParts {
 		return nil, trace.LimitExceeded(
 			"exceeded total allowed S3 limit MaxUploadParts (%d). Adjust PartSize to fit in this limit", s3manager.MaxUploadParts)
@@ -94,18 +93,18 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa
 		UploadId:   aws.String(upload.ID),
 		Key:        aws.String(uploadKey),
 		Body:       partBody,
-		PartNumber: aws.Int64(partNumber),
+		PartNumber: &partNumber,
 	}
 
 	log.Debugf("Uploading part %v", partNumber)
-	resp, err := h.client.UploadPartWithContext(ctx, params)
+	resp, err := h.client.UploadPart(ctx, params)
 	if err != nil {
 		return nil, trace.Wrap(awsutils.ConvertS3Error(err),
 			"UploadPart(upload %v) part(%v) session(%v)", upload.ID, partNumber, upload.SessionID)
 	}
 
 	log.Infof("Uploaded part %v in %v", partNumber, time.Since(start))
-	return &events.StreamPart{ETag: aws.StringValue(resp.ETag), Number: partNumber}, nil
+	return &events.StreamPart{ETag: aws.ToString(resp.ETag), Number: partNumber}, nil
 }
 
 func (h *Handler) abortUpload(ctx context.Context, upload events.StreamUpload) error {
@@ -121,7 +120,7 @@ func (h *Handler) abortUpload(ctx context.Context, upload events.StreamUpload) e
 		UploadId: aws.String(upload.ID),
 	}
 	log.Debug("Aborting upload")
-	_, err := h.client.AbortMultipartUploadWithContext(ctx, req)
+	_, err := h.client.AbortMultipartUpload(ctx, req)
 	if err != nil {
 		return awsutils.ConvertS3Error(err)
 	}
@@ -157,11 +156,11 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
 		return parts[i].Number < parts[j].Number
 	})
 
-	completedParts := make([]*s3.CompletedPart, len(parts))
+	completedParts := make([]types.CompletedPart, len(parts))
 	for i := range parts {
-		completedParts[i] = &s3.CompletedPart{
+		completedParts[i] = types.CompletedPart{
 			ETag:       aws.String(parts[i].ETag),
-			PartNumber: aws.Int64(parts[i].Number),
+			PartNumber: aws.Int32(parts[i].Number),
 		}
 	}
 
@@ -170,9 +169,9 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
 		Bucket:          aws.String(h.Bucket),
 		Key:             aws.String(uploadKey),
 		UploadId:        aws.String(upload.ID),
-		MultipartUpload: &s3.CompletedMultipartUpload{Parts: completedParts},
+		MultipartUpload: &types.CompletedMultipartUpload{Parts: completedParts},
 	}
-	_, err := h.client.CompleteMultipartUploadWithContext(ctx, params)
+	_, err := h.client.CompleteMultipartUpload(ctx, params)
 	if err != nil {
 		return trace.Wrap(awsutils.ConvertS3Error(err),
 			"CompleteMultipartUpload(upload %v) session(%v)", upload.ID, upload.SessionID)
@@ -193,29 +192,26 @@ func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([]
 	log.Debug("Listing parts for upload")
 
 	var parts []events.StreamPart
-	var partNumberMarker *int64
-	for i := 0; i < defaults.MaxIterationLimit; i++ {
-		re, err := h.client.ListPartsWithContext(ctx, &s3.ListPartsInput{
-			Bucket:           aws.String(h.Bucket),
-			Key:              aws.String(uploadKey),
-			UploadId:         aws.String(upload.ID),
-			PartNumberMarker: partNumberMarker,
-		})
+
+	paginator := s3.NewListPartsPaginator(h.client, &s3.ListPartsInput{
+		Bucket:   aws.String(h.Bucket),
+		Key:      aws.String(uploadKey),
+		UploadId: aws.String(upload.ID),
+	})
+
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
 		if err != nil {
 			return nil, awsutils.ConvertS3Error(err)
 		}
-		for _, part := range re.Parts {
-
+		for _, part := range page.Parts {
 			parts = append(parts, events.StreamPart{
-				Number: aws.Int64Value(part.PartNumber),
-				ETag:   aws.StringValue(part.ETag),
+				Number: aws.ToInt32(part.PartNumber),
+				ETag:   aws.ToString(part.ETag),
 			})
 		}
-		if !aws.BoolValue(re.IsTruncated) {
-			break
-		}
-		partNumberMarker = re.NextPartNumberMarker
 	}
+
 	// Parts must be sorted in PartNumber order.
 	sort.Slice(parts, func(i, j int) bool {
 		return parts[i].Number < parts[j].Number
@@ -231,31 +227,23 @@ func (h *Handler) ListUploads(ctx context.Context) ([]events.StreamUpload, error
 		prefix = &trimmed
 	}
 	var uploads []events.StreamUpload
-	var keyMarker *string
-	var uploadIDMarker *string
-	for i := 0; i < defaults.MaxIterationLimit; i++ {
-		input := &s3.ListMultipartUploadsInput{
-			Bucket:         aws.String(h.Bucket),
-			Prefix:         prefix,
-			KeyMarker:      keyMarker,
-			UploadIdMarker: uploadIDMarker,
-		}
-		re, err := h.client.ListMultipartUploadsWithContext(ctx, input)
+	paginator := s3.NewListMultipartUploadsPaginator(h.client, &s3.ListMultipartUploadsInput{
+		Bucket: aws.String(h.Bucket),
+		Prefix: prefix,
+	})
+
+	for paginator.HasMorePages() {
+		page, err := paginator.NextPage(ctx)
 		if err != nil {
 			return nil, awsutils.ConvertS3Error(err)
 		}
-		for _, upload := range re.Uploads {
+		for _, upload := range page.Uploads {
 			uploads = append(uploads, events.StreamUpload{
-				ID:        aws.StringValue(upload.UploadId),
-				SessionID: h.fromPath(aws.StringValue(upload.Key)),
-				Initiated: aws.TimeValue(upload.Initiated),
+				ID:        aws.ToString(upload.UploadId),
+				SessionID: h.fromPath(aws.ToString(upload.Key)),
+				Initiated: aws.ToTime(upload.Initiated),
 			})
 		}
-		if !aws.BoolValue(re.IsTruncated) {
-			break
-		}
-		keyMarker = re.NextKeyMarker
-		uploadIDMarker = re.NextUploadIdMarker
 	}
 
 	sort.Slice(uploads, func(i, j int) bool {
@@ -281,6 +269,6 @@ func (h *Handler) GetUploadMetadata(sessionID session.ID) events.UploadMetadata
 }
 
 // ReserveUploadPart reserves an upload part.
-func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error {
+func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32) error {
 	return nil
 }
diff --git a/lib/events/stream.go b/lib/events/stream.go
index 6c743010d9173..09fdde1c5310b 100644
--- a/lib/events/stream.go
+++ b/lib/events/stream.go
@@ -278,7 +278,7 @@ func NewProtoStream(cfg ProtoStreamConfig) (*ProtoStream, error) {
 
 	writer := &sliceWriter{
 		proto:             stream,
-		activeUploads:     make(map[int64]*activeUpload),
+		activeUploads:     make(map[int32]*activeUpload),
 		completedUploadsC: make(chan *activeUpload, cfg.ConcurrentUploads),
 		semUploads:        make(chan struct{}, cfg.ConcurrentUploads),
 		lastPartNumber:    0,
@@ -477,9 +477,9 @@ type sliceWriter struct {
 	// current is the current slice being written to
 	current *slice
 	// lastPartNumber is the last assigned part number
-	lastPartNumber int64
+	lastPartNumber int32
 	// activeUploads tracks active uploads
-	activeUploads map[int64]*activeUpload
+	activeUploads map[int32]*activeUpload
 	// completedUploadsC receives uploads that have been completed
 	completedUploadsC chan *activeUpload
 	// semUploads controls concurrent uploads that are in flight
@@ -698,7 +698,7 @@ func (w *sliceWriter) completeStream() {
 
 // startUpload acquires upload semaphore and starts upload, returns error
 // only if there is a critical error
-func (w *sliceWriter) startUpload(partNumber int64, slice *slice) (*activeUpload, error) {
+func (w *sliceWriter) startUpload(partNumber int32, slice *slice) (*activeUpload, error) {
 	// acquire semaphore limiting concurrent uploads
 	select {
 	case w.semUploads <- struct{}{}:
@@ -794,7 +794,7 @@ type activeUpload struct {
 	mtx            sync.RWMutex
 	start          time.Time
 	end            time.Time
-	partNumber     int64
+	partNumber     int32
 	part           *StreamPart
 	err            error
 	lastEventIndex int64
diff --git a/lib/integrations/externalauditstorage/error_counter.go b/lib/integrations/externalauditstorage/error_counter.go
index fc2e229d647f8..6607d3f3ebc8f 100644
--- a/lib/integrations/externalauditstorage/error_counter.go
+++ b/lib/integrations/externalauditstorage/error_counter.go
@@ -370,14 +370,14 @@ func (c *ErrorCountingSessionHandler) CompleteUpload(ctx context.Context, upload
 }
 
 // ReserveUploadPart calls [c.wrapped.ReserveUploadPart] and counts the error or success.
-func (c *ErrorCountingSessionHandler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64) error {
+func (c *ErrorCountingSessionHandler) ReserveUploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32) error {
 	err := c.wrapped.ReserveUploadPart(ctx, upload, partNumber)
 	c.uploads.observe(err)
 	return err
 }
 
 // UploadPart calls [c.wrapped.UploadPart] and counts the error or success.
-func (c *ErrorCountingSessionHandler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int64, partBody io.ReadSeeker) (*events.StreamPart, error) {
+func (c *ErrorCountingSessionHandler) UploadPart(ctx context.Context, upload events.StreamUpload, partNumber int32, partBody io.ReadSeeker) (*events.StreamPart, error) {
 	part, err := c.wrapped.UploadPart(ctx, upload, partNumber, partBody)
 	c.uploads.observe(err)
 	return part, err
diff --git a/lib/observability/metrics/s3/api.go b/lib/observability/metrics/s3/api.go
index 43231ab7c321a..b4c1a5a08e2c7 100644
--- a/lib/observability/metrics/s3/api.go
+++ b/lib/observability/metrics/s3/api.go
@@ -22,133 +22,123 @@ import (
 	"context"
 	"time"
 
-	"github.com/aws/aws-sdk-go/aws/request"
-	"github.com/aws/aws-sdk-go/service/s3"
-	"github.com/aws/aws-sdk-go/service/s3/s3iface"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
 	"github.com/gravitational/trace"
 
 	"github.com/gravitational/teleport/lib/observability/metrics"
 )
 
 type APIMetrics struct {
-	s3iface.S3API
+	client *s3.Client
 }
 
-func NewAPIMetrics(api s3iface.S3API) (*APIMetrics, error) {
+func NewAPIMetrics(client *s3.Client) (*APIMetrics, error) {
 	if err := metrics.RegisterPrometheusCollectors(s3Collectors...); err != nil {
 		return nil, trace.Wrap(err)
 	}
 
-	return &APIMetrics{S3API: api}, nil
+	return &APIMetrics{client: client}, nil
 }
 
-func (m *APIMetrics) ListObjectVersionsPagesWithContext(ctx context.Context, input *s3.ListObjectVersionsInput, f func(*s3.ListObjectVersionsOutput, bool) bool, opts ...request.Option) error {
+func (m *APIMetrics) ListObjectVersions(ctx context.Context, params *s3.ListObjectVersionsInput, optFns ...func(*s3.Options)) (*s3.ListObjectVersionsOutput, error) {
 	start := time.Now()
-	err := m.S3API.ListObjectVersionsPagesWithContext(ctx, input, f, opts...)
-
-	recordMetrics("list_object_versions_pages", err, time.Since(start).Seconds())
-	return err
-}
-
-func (m *APIMetrics) ListObjectVersionsWithContext(ctx context.Context, input *s3.ListObjectVersionsInput, opts ...request.Option) (*s3.ListObjectVersionsOutput, error) {
-	start := time.Now()
-	output, err := m.S3API.ListObjectVersionsWithContext(ctx, input, opts...)
+	output, err := m.client.ListObjectVersions(ctx, params, optFns...)
 
 	recordMetrics("list_object_versions", err, time.Since(start).Seconds())
 	return output, err
 }
 
-func (m *APIMetrics) DeleteObjectWithContext(ctx context.Context, input *s3.DeleteObjectInput, opts ...request.Option) (*s3.DeleteObjectOutput, error) {
+func (m *APIMetrics) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) {
 	start := time.Now()
-	output, err := m.S3API.DeleteObjectWithContext(ctx, input, opts...)
+	output, err := m.client.DeleteObject(ctx, params, optFns...)
 
 	recordMetrics("delete_object", err, time.Since(start).Seconds())
 	return output, err
 }
 
-func (m *APIMetrics) DeleteBucketWithContext(ctx context.Context, input *s3.DeleteBucketInput, opts ...request.Option) (*s3.DeleteBucketOutput, error) {
+func (m *APIMetrics) DeleteBucket(ctx context.Context, params *s3.DeleteBucketInput, optFns ...func(*s3.Options)) (*s3.DeleteBucketOutput, error) {
 	start := time.Now()
-	output, err := m.S3API.DeleteBucketWithContext(ctx, input, opts...)
+	output, err := m.client.DeleteBucket(ctx, params, optFns...)
recordMetrics("delete_bucket", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) HeadBucketWithContext(ctx context.Context, input *s3.HeadBucketInput, opts ...request.Option) (*s3.HeadBucketOutput, error) { +func (m *APIMetrics) HeadBucket(ctx context.Context, params *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error) { start := time.Now() - output, err := m.S3API.HeadBucketWithContext(ctx, input, opts...) + output, err := m.client.HeadBucket(ctx, params, optFns...) recordMetrics("head_bucket", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) CreateBucketWithContext(ctx context.Context, input *s3.CreateBucketInput, opts ...request.Option) (*s3.CreateBucketOutput, error) { +func (m *APIMetrics) CreateBucket(ctx context.Context, params *s3.CreateBucketInput, optFns ...func(*s3.Options)) (*s3.CreateBucketOutput, error) { start := time.Now() - output, err := m.S3API.CreateBucketWithContext(ctx, input, opts...) + output, err := m.client.CreateBucket(ctx, params, optFns...) recordMetrics("create_bucket", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) PutBucketVersioningWithContext(ctx context.Context, input *s3.PutBucketVersioningInput, opts ...request.Option) (*s3.PutBucketVersioningOutput, error) { +func (m *APIMetrics) PutBucketVersioning(ctx context.Context, params *s3.PutBucketVersioningInput, optFns ...func(*s3.Options)) (*s3.PutBucketVersioningOutput, error) { start := time.Now() - output, err := m.S3API.PutBucketVersioningWithContext(ctx, input, opts...) + output, err := m.client.PutBucketVersioning(ctx, params, optFns...) recordMetrics("put_bucket_versioning", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) PutBucketEncryptionWithContext(ctx context.Context, input *s3.PutBucketEncryptionInput, opts ...request.Option) (*s3.PutBucketEncryptionOutput, error) { +func (m *APIMetrics) PutBucketEncryption(ctx context.Context, params *s3.PutBucketEncryptionInput, optFns ...func(*s3.Options)) (*s3.PutBucketEncryptionOutput, error) { start := time.Now() - output, err := m.S3API.PutBucketEncryptionWithContext(ctx, input, opts...) + output, err := m.client.PutBucketEncryption(ctx, params, optFns...) recordMetrics("put_bucket_encryption", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) CreateMultipartUploadWithContext(ctx context.Context, input *s3.CreateMultipartUploadInput, opts ...request.Option) (*s3.CreateMultipartUploadOutput, error) { +func (m *APIMetrics) CreateMultipartUpload(ctx context.Context, params *s3.CreateMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) { start := time.Now() - output, err := m.S3API.CreateMultipartUploadWithContext(ctx, input, opts...) + output, err := m.client.CreateMultipartUpload(ctx, params, optFns...) recordMetrics("create_multipart_upload", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) UploadPartWithContext(ctx context.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) { +func (m *APIMetrics) UploadPart(ctx context.Context, params *s3.UploadPartInput, optFns ...func(*s3.Options)) (*s3.UploadPartOutput, error) { start := time.Now() - output, err := m.S3API.UploadPartWithContext(ctx, input, opts...) + output, err := m.client.UploadPart(ctx, params, optFns...) 
recordMetrics("upload_part", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) AbortMultipartUploadWithContext(ctx context.Context, input *s3.AbortMultipartUploadInput, opts ...request.Option) (*s3.AbortMultipartUploadOutput, error) { +func (m *APIMetrics) AbortMultipartUpload(ctx context.Context, params *s3.AbortMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) { start := time.Now() - output, err := m.S3API.AbortMultipartUploadWithContext(ctx, input, opts...) + output, err := m.client.AbortMultipartUpload(ctx, params, optFns...) recordMetrics("abort_multipart_upload", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) CompleteMultipartUploadWithContext(ctx context.Context, input *s3.CompleteMultipartUploadInput, opts ...request.Option) (*s3.CompleteMultipartUploadOutput, error) { +func (m *APIMetrics) CompleteMultipartUpload(ctx context.Context, params *s3.CompleteMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) { start := time.Now() - output, err := m.S3API.CompleteMultipartUploadWithContext(ctx, input, opts...) + output, err := m.client.CompleteMultipartUpload(ctx, params, optFns...) recordMetrics("complete_multipart_upload", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) ListPartsWithContext(ctx context.Context, input *s3.ListPartsInput, opts ...request.Option) (*s3.ListPartsOutput, error) { +func (m *APIMetrics) ListParts(ctx context.Context, params *s3.ListPartsInput, optFns ...func(*s3.Options)) (*s3.ListPartsOutput, error) { start := time.Now() - output, err := m.S3API.ListPartsWithContext(ctx, input, opts...) + output, err := m.client.ListParts(ctx, params, optFns...) recordMetrics("list_parts", err, time.Since(start).Seconds()) return output, err } -func (m *APIMetrics) ListMultipartUploadsWithContext(ctx context.Context, input *s3.ListMultipartUploadsInput, opts ...request.Option) (*s3.ListMultipartUploadsOutput, error) { +func (m *APIMetrics) ListMultipartUploads(ctx context.Context, params *s3.ListMultipartUploadsInput, optFns ...func(*s3.Options)) (*s3.ListMultipartUploadsOutput, error) { start := time.Now() - output, err := m.S3API.ListMultipartUploadsWithContext(ctx, input, opts...) + output, err := m.client.ListMultipartUploads(ctx, params, optFns...) recordMetrics("list_multipart_uploads", err, time.Since(start).Seconds()) return output, err diff --git a/lib/service/service.go b/lib/service/service.go index 0ed957687395b..aae621a206b4a 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -48,7 +48,6 @@ import ( "testing" "time" - awscredentials "github.com/aws/aws-sdk-go/aws/credentials" awssession "github.com/aws/aws-sdk-go/aws/session" "github.com/google/renameio/v2" "github.com/google/uuid" @@ -1783,7 +1782,7 @@ func initAuthUploadHandler(ctx context.Context, auditConfig types.ClusterAuditCo UseFIPSEndpoint: auditConfig.GetUseFIPSEndpoint(), } if externalAuditStorage.IsUsed() { - config.Credentials = awscredentials.NewCredentials(externalAuditStorage.CredentialsProviderSDKV1()) + config.CredentialsProvider = externalAuditStorage.CredentialsProvider() } if err := config.SetFromURL(uri, auditConfig.Region()); err != nil { return nil, trace.Wrap(err)