Skip to content

Commit

Permalink
Merge pull request #413 from 88labs/feat/add-global-s3dialer
Browse files Browse the repository at this point in the history
add global s3dialer
  • Loading branch information
tomtwinkle authored Jan 19, 2024
2 parents d703e41 + bb9d781 commit 7835df4
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 4 deletions.
57 changes: 53 additions & 4 deletions aws/awss3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,61 @@ import (
"context"
"encoding/gob"
"fmt"
"net"
"sync"

"github.com/88labs/go-utils/aws/awsconfig"

"github.com/88labs/go-utils/aws/ctxawslocal"
"github.com/aws/aws-sdk-go-v2/aws"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"

"github.com/88labs/go-utils/aws/awsconfig"
"github.com/88labs/go-utils/aws/awss3/options/global/s3dialer"
"github.com/88labs/go-utils/aws/ctxawslocal"
)

var (
// GlobalDialer Global http dialer settings for awss3 library
GlobalDialer *s3dialer.ConfGlobalDialer

customMu sync.Mutex
customEndpointClient *s3.Client
)

// GetClient
// Get s3 client for aws-sdk-go v2.
// Using ctxawslocal.WithContext, you can make requests for local mocks
func GetClient(ctx context.Context, region awsconfig.Region) (*s3.Client, error) {
if localProfile, ok := getLocalEndpoint(ctx); ok {
return getClientLocal(ctx, *localProfile)
customMu.Lock()
defer customMu.Unlock()
var err error
if customEndpointClient != nil {
return customEndpointClient, err
}
customEndpointClient, err = getClientLocal(ctx, *localProfile)
return customEndpointClient, err
}
awsHttpClient := awshttp.NewBuildableClient()
if GlobalDialer != nil {
awsHttpClient.WithDialerOptions(func(dialer *net.Dialer) {
if GlobalDialer.Timeout != 0 {
dialer.Timeout = GlobalDialer.Timeout
}
if GlobalDialer.Deadline != nil {
dialer.Deadline = *GlobalDialer.Deadline
}
if GlobalDialer.KeepAlive != 0 {
dialer.KeepAlive = GlobalDialer.KeepAlive
}
})
}
// S3 Client
awsCfg, err := awsConfig.LoadDefaultConfig(
ctx,
awsConfig.WithRegion(region.String()),
awsConfig.WithHTTPClient(awsHttpClient),
)
if err != nil {
return nil, fmt.Errorf("unable to load SDK config, %w", err)
Expand All @@ -34,7 +68,22 @@ func GetClient(ctx context.Context, region awsconfig.Region) (*s3.Client, error)
}

func getClientLocal(ctx context.Context, localProfile LocalProfile) (*s3.Client, error) {
awsHttpClient := awshttp.NewBuildableClient()
if GlobalDialer != nil {
awsHttpClient.WithDialerOptions(func(dialer *net.Dialer) {
if GlobalDialer.Timeout != 0 {
dialer.Timeout = GlobalDialer.Timeout
}
if GlobalDialer.Deadline != nil {
dialer.Deadline = *GlobalDialer.Deadline
}
if GlobalDialer.KeepAlive != 0 {
dialer.KeepAlive = GlobalDialer.KeepAlive
}
})
}
awsCfg, err := awsConfig.LoadDefaultConfig(ctx,
awsConfig.WithHTTPClient(awsHttpClient),
awsConfig.WithCredentialsProvider(credentials.StaticCredentialsProvider{
Value: aws.Credentials{
AccessKeyID: localProfile.AccessKey,
Expand Down
63 changes: 63 additions & 0 deletions aws/awss3/options/global/global_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package global_test

import (
"bytes"
"context"
"fmt"
"testing"
"time"

"github.com/88labs/go-utils/ulid"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/assert"

"github.com/88labs/go-utils/aws/awsconfig"
"github.com/88labs/go-utils/aws/awss3"
"github.com/88labs/go-utils/aws/awss3/options/global/s3dialer"
"github.com/88labs/go-utils/aws/ctxawslocal"
)

const (
TestBucket = "test"
TestRegion = awsconfig.RegionTokyo
)

func TestGlobalOptionWithHeadObject(t *testing.T) {
ctx := ctxawslocal.WithContext(
context.Background(),
ctxawslocal.WithS3Endpoint("http://127.0.0.1:29000"), // use Minio
ctxawslocal.WithAccessKey("DUMMYACCESSKEYEXAMPLE"),
ctxawslocal.WithSecretAccessKey("DUMMYSECRETKEYEXAMPLE"),
)
s3Client, err := awss3.GetClient(ctx, TestRegion)
assert.NoError(t, err)

createFixture := func(fileSize int) awss3.Key {
key := fmt.Sprintf("awstest/%s.txt", ulid.MustNew())
uploader := manager.NewUploader(s3Client)
input := s3.PutObjectInput{
Body: bytes.NewReader(bytes.Repeat([]byte{1}, fileSize)),
Bucket: aws.String(TestBucket),
Key: aws.String(key),
Expires: aws.Time(time.Now().Add(10 * time.Minute)),
}
if _, err := uploader.Upload(ctx, &input); err != nil {
assert.NoError(t, err)
}
return awss3.Key(key)
}

t.Run("If the option is specified", func(t *testing.T) {
key := createFixture(100)
dialer := s3dialer.NewConfGlobalDialer()
dialer.WithTimeout(time.Second)
dialer.WithKeepAlive(2 * time.Second)
dialer.WithDeadline(time.Now().Add(time.Second))
awss3.GlobalDialer = dialer
res, err := awss3.HeadObject(ctx, TestRegion, TestBucket, key)
assert.NoError(t, err)
assert.Equal(t, aws.Int64(100), res.ContentLength)
})
}
50 changes: 50 additions & 0 deletions aws/awss3/options/global/s3dialer/s3dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package s3dialer

import "time"

type ConfGlobalDialer struct {
// Timeout is the maximum amount of time a dial will wait for
// a connect to complete. If Deadline is also set, it may fail
// earlier.
//
// The default is no timeout.
//
// When using TCP and dialing a host name with multiple IP
// addresses, the timeout may be divided between them.
//
// With or without a timeout, the operating system may impose
// its own earlier timeout. For instance, TCP timeouts are
// often around 3 minutes.
Timeout time.Duration

// Deadline is the absolute point in time after which dials
// will fail. If Timeout is set, it may fail earlier.
// Zero means no deadline, or dependent on the operating system
// as with the Timeout option.
Deadline *time.Time

// KeepAlive specifies the interval between keep-alive
// probes for an active network connection.
// If zero, keep-alive probes are sent with a default value
// (currently 15 seconds), if supported by the protocol and operating
// system. Network protocols or operating systems that do
// not support keep-alives ignore this field.
// If negative, keep-alive probes are disabled.
KeepAlive time.Duration
}

func NewConfGlobalDialer() *ConfGlobalDialer {
return &ConfGlobalDialer{}
}

func (c *ConfGlobalDialer) WithTimeout(timeout time.Duration) {
c.Timeout = timeout
}

func (c *ConfGlobalDialer) WithDeadline(deadline time.Time) {
c.Deadline = &deadline
}

func (c *ConfGlobalDialer) WithKeepAlive(keepAlive time.Duration) {
c.KeepAlive = keepAlive
}

0 comments on commit 7835df4

Please sign in to comment.