Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate S3 uploader to aws-sdk-go-v2 #44728

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 97 additions & 97 deletions lib/events/s3sessions/s3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,30 @@ package s3sessions

import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
"net/url"
"path"
"sort"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
awssession "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/aws/aws-sdk-go/service/s3/s3manager/s3manageriface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
awstypes "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
s3metrics "github.com/gravitational/teleport/lib/observability/metrics/s3"
awsmetrics "github.com/gravitational/teleport/lib/observability/metrics/aws"
"github.com/gravitational/teleport/lib/session"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
)
Expand Down Expand Up @@ -77,10 +77,10 @@ type Config struct {
Endpoint string
// ACL is the canned ACL to send to S3
ACL string
// Session is an optional existing AWS client session
Session *awssession.Session
// Credentials if supplied are used in tests or with External Audit Storage.
Credentials *credentials.Credentials
// AWSConfig is an optional existing AWS client configuration
AWSConfig *aws.Config
// CredentialsProvider if supplied is used in tests or with External Audit Storage.
CredentialsProvider aws.CredentialsProvider
// SSEKMSKey specifies the optional custom CMK used for KMS SSE.
SSEKMSKey string

Expand Down Expand Up @@ -156,38 +156,40 @@ func (s *Config) CheckAndSetDefaults() error {
if s.Bucket == "" {
return trace.BadParameter("missing parameter Bucket")
}
if s.Session == nil {
awsConfig := aws.Config{
UseFIPSEndpoint: events.FIPSProtoStateToAWSState(s.UseFIPSEndpoint),
}
if s.Region != "" {
awsConfig.Region = aws.String(s.Region)
}
if s.Endpoint != "" {
awsConfig.Endpoint = aws.String(s.Endpoint)
awsConfig.S3ForcePathStyle = aws.Bool(true)

if s.AWSConfig == nil {
var err error
opts := []func(*config.LoadOptions) error{
config.WithRegion(s.Region),
}

if s.Insecure {
awsConfig.DisableSSL = aws.Bool(s.Insecure)
}
if s.Credentials != nil {
awsConfig.Credentials = s.Credentials
opts = append(opts, config.WithHTTPClient(&http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}))
} else {
hc, err := defaults.HTTPClient()
if err != nil {
return trace.Wrap(err)
}

opts = append(opts, config.WithHTTPClient(hc))
}
hc, err := defaults.HTTPClient()
if err != nil {
return trace.Wrap(err)

if s.CredentialsProvider != nil {
opts = append(opts, config.WithCredentialsProvider(s.CredentialsProvider))
}
awsConfig.HTTPClient = hc

sess, err := awssession.NewSessionWithOptions(awssession.Options{
SharedConfigState: awssession.SharedConfigEnable,
Config: awsConfig,
})
opts = append(opts, config.WithAPIOptions(awsmetrics.MetricsMiddleware()))

awsConfig, err := config.LoadDefaultConfig(context.Background(), opts...)
if err != nil {
return trace.Wrap(err)
}

s.Session = sess
s.AWSConfig = &awsConfig
}
return nil
}
Expand All @@ -198,20 +200,15 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
return nil, trace.Wrap(err)
}

client, err := s3metrics.NewAPIMetrics(s3.New(cfg.Session))
if err != nil {
return nil, trace.Wrap(err)
}

uploader, err := s3metrics.NewUploadAPIMetrics(s3manager.NewUploader(cfg.Session))
if err != nil {
return nil, trace.Wrap(err)
}
// Create S3 client with custom options
client := s3.NewFromConfig(*cfg.AWSConfig, func(o *s3.Options) {
if cfg.Endpoint != "" {
o.UsePathStyle = true
}
})

downloader, err := s3metrics.NewDownloadAPIMetrics(s3manager.NewDownloader(cfg.Session))
if err != nil {
return nil, trace.Wrap(err)
}
uploader := manager.NewUploader(client)
downloader := manager.NewDownloader(client)

h := &Handler{
Entry: log.WithFields(log.Fields{
Expand All @@ -222,6 +219,7 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
downloader: downloader,
client: client,
}

start := time.Now()
h.Infof("Setting up bucket %q, sessions path %q in region %q.", h.Bucket, h.Path, h.Region)
if err := h.ensureBucket(ctx); err != nil {
Expand All @@ -237,9 +235,9 @@ type Handler struct {
Config
// Entry is a logging entry
*log.Entry
uploader s3manageriface.UploaderAPI
downloader s3manageriface.DownloaderAPI
client s3iface.S3API
uploader *manager.Uploader
downloader *manager.Downloader
client *s3.Client
}

// Close releases connection and resources associated with log if any
Expand All @@ -250,25 +248,23 @@ func (h *Handler) Close() error {
// Upload uploads object to S3 bucket, reads the contents of the object from reader
// and returns the target S3 bucket path in case of successful upload.
func (h *Handler) Upload(ctx context.Context, sessionID session.ID, reader io.Reader) (string, error) {
var err error
path := h.path(sessionID)

uploadInput := &s3manager.UploadInput{
uploadInput := &s3.PutObjectInput{
Bucket: aws.String(h.Bucket),
Key: aws.String(path),
Body: reader,
}
if !h.Config.DisableServerSideEncryption {
uploadInput.ServerSideEncryption = aws.String(s3.ServerSideEncryptionAwsKms)

uploadInput.ServerSideEncryption = awstypes.ServerSideEncryptionAwsKms
if h.Config.SSEKMSKey != "" {
uploadInput.SSEKMSKeyId = aws.String(h.Config.SSEKMSKey)
}
}
if h.Config.ACL != "" {
uploadInput.ACL = aws.String(h.Config.ACL)
uploadInput.ACL = awstypes.ObjectCannedACL(h.Config.ACL)
}
_, err = h.uploader.UploadWithContext(ctx, uploadInput)
_, err := h.uploader.Upload(ctx, uploadInput)
if err != nil {
return "", awsutils.ConvertS3Error(err)
}
Expand All @@ -288,17 +284,14 @@ func (h *Handler) Download(ctx context.Context, sessionID session.ID, writer io.

h.Debugf("Downloading %v/%v [%v].", h.Bucket, h.path(sessionID), versionID)

written, err := h.downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
_, err = h.downloader.Download(ctx, writer, &s3.GetObjectInput{
Bucket: aws.String(h.Bucket),
Key: aws.String(h.path(sessionID)),
VersionId: aws.String(versionID),
})
if err != nil {
return awsutils.ConvertS3Error(err)
}
if written == 0 {
return trace.NotFound("recording for %v is not found", sessionID)
}
return nil
}

Expand All @@ -315,24 +308,24 @@ type versionID struct {
func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix string) (string, error) {
var versions []versionID

// Get all versions of this object.
err := h.client.ListObjectVersionsPagesWithContext(ctx, &s3.ListObjectVersionsInput{
paginator := s3.NewListObjectVersionsPaginator(h.client, &s3.ListObjectVersionsInput{
Bucket: aws.String(bucket),
Prefix: aws.String(prefix),
}, func(page *s3.ListObjectVersionsOutput, lastPage bool) bool {
})

for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return "", awsutils.ConvertS3Error(err)
}
for _, v := range page.Versions {
versions = append(versions, versionID{
ID: *v.VersionId,
ID: aws.ToString(v.VersionId),
Timestamp: *v.LastModified,
})
}

// Returning false stops iteration, stop iteration upon last page.
return !lastPage
})
if err != nil {
return "", awsutils.ConvertS3Error(err)
}

if len(versions) == 0 {
return "", trace.NotFound("%v/%v not found", bucket, prefix)
}
Expand All @@ -347,23 +340,28 @@ func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix st
// delete bucket deletes bucket and all it's contents and is used in tests
func (h *Handler) deleteBucket(ctx context.Context) error {
// first, list and delete all the objects in the bucket
out, err := h.client.ListObjectVersionsWithContext(ctx, &s3.ListObjectVersionsInput{
paginator := s3.NewListObjectVersionsPaginator(h.client, &s3.ListObjectVersionsInput{
Bucket: aws.String(h.Bucket),
})
if err != nil {
return awsutils.ConvertS3Error(err)
}
for _, ver := range out.Versions {
_, err := h.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(h.Bucket),
Key: ver.Key,
VersionId: ver.VersionId,
})

for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return awsutils.ConvertS3Error(err)
}
for _, ver := range page.Versions {
_, err := h.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(h.Bucket),
Key: ver.Key,
VersionId: ver.VersionId,
})
if err != nil {
return awsutils.ConvertS3Error(err)
}
}
}
_, err = h.client.DeleteBucketWithContext(ctx, &s3.DeleteBucketInput{

_, err := h.client.DeleteBucket(ctx, &s3.DeleteBucketInput{
Bucket: aws.String(h.Bucket),
})
return awsutils.ConvertS3Error(err)
Expand All @@ -382,7 +380,7 @@ func (h *Handler) fromPath(p string) session.ID {

// ensureBucket makes sure bucket exists, and if it does not, creates it
func (h *Handler) ensureBucket(ctx context.Context) error {
_, err := h.client.HeadBucketWithContext(ctx, &s3.HeadBucketInput{
_, err := h.client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: aws.String(h.Bucket),
})
err = awsutils.ConvertS3Error(err)
Expand All @@ -396,44 +394,46 @@ func (h *Handler) ensureBucket(ctx context.Context) error {
}
input := &s3.CreateBucketInput{
Bucket: aws.String(h.Bucket),
ACL: aws.String("private"),
ACL: awstypes.BucketCannedACLPrivate,
}
_, err = h.client.CreateBucketWithContext(ctx, input)
_, err = h.client.CreateBucket(ctx, input)
err = awsutils.ConvertS3Error(err, fmt.Sprintf("bucket %v already exists", aws.String(h.Bucket)))
if err != nil {
if !trace.IsAlreadyExists(err) {
return trace.Wrap(err)
}

// if this client has not created the bucket, don't reconfigure it
return nil
}

// Turn on versioning.
ver := &s3.PutBucketVersioningInput{
_, err = h.client.PutBucketVersioning(ctx, &s3.PutBucketVersioningInput{
Bucket: aws.String(h.Bucket),
VersioningConfiguration: &s3.VersioningConfiguration{
Status: aws.String("Enabled"),
VersioningConfiguration: &awstypes.VersioningConfiguration{
Status: awstypes.BucketVersioningStatusEnabled,
},
}
_, err = h.client.PutBucketVersioningWithContext(ctx, ver)
})
err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket))
if err != nil {
return trace.Wrap(err)
}

// Turn on server-side encryption for the bucket.
if !h.DisableServerSideEncryption {
_, err = h.client.PutBucketEncryptionWithContext(ctx, &s3.PutBucketEncryptionInput{
_, err = h.client.PutBucketEncryption(ctx, &s3.PutBucketEncryptionInput{
Bucket: aws.String(h.Bucket),
ServerSideEncryptionConfiguration: &s3.ServerSideEncryptionConfiguration{
Rules: []*s3.ServerSideEncryptionRule{{
ApplyServerSideEncryptionByDefault: &s3.ServerSideEncryptionByDefault{
SSEAlgorithm: aws.String(s3.ServerSideEncryptionAwsKms),
ServerSideEncryptionConfiguration: &awstypes.ServerSideEncryptionConfiguration{
Rules: []awstypes.ServerSideEncryptionRule{
{
ApplyServerSideEncryptionByDefault: &awstypes.ServerSideEncryptionByDefault{
SSEAlgorithm: awstypes.ServerSideEncryptionAwsKms,
},
},
}},
},
},
})
err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket))
err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set encryption state for bucket %q", h.Bucket))
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading
Loading