Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breakfix: revert bad feature/rds/auth api release #2925

Merged
merged 2 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/658fc1c50afd443fa803916695e3583e.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "658fc1c5-0afd-443f-a803-916695e3583e",
"type": "bugfix",
"description": "**BREAKFIX**: Revert bad API release.",
"modules": [
"feature/rds/auth"
]
}
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

## Module Highlights
* `github.com/aws/aws-sdk-go-v2/feature/rds/auth`: [v1.5.0](feature/rds/auth/CHANGELOG.md#v150-2024-12-032)
* **Feature**: feat: Add Xanadu Auth Token Generator
* No change notes available for this release.
* `github.com/aws/aws-sdk-go-v2/service/athena`: [v1.49.0](service/athena/CHANGELOG.md#v1490-2024-12-032)
* **Feature**: Add FEDERATED type to CreateDataCatalog. This creates Athena Data Catalog, AWS Lambda connector, and AWS Glue connection. Create/DeleteDataCatalog returns DataCatalog. Add Status, ConnectionType, and Error to DataCatalog and DataCatalogSummary. Add DeleteCatalogOnly to delete Athena Catalog only.
* `github.com/aws/aws-sdk-go-v2/service/bedrock`: [v1.24.0](service/bedrock/CHANGELOG.md#v1240-2024-12-032)
Expand Down
2 changes: 1 addition & 1 deletion feature/rds/auth/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# v1.5.0 (2024-12-03.2)

* **Feature**: feat: Add Xanadu Auth Token Generator
* No change notes available for this release.

# v1.4.25 (2024-12-02)

Expand Down
82 changes: 11 additions & 71 deletions feature/rds/auth/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,26 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
)

const (
rdsAuthTokenID = "rds-db"
rdsClusterTokenID = "dsql"
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
userAction = "DbConnect"
adminUserAction = "DbConnectAdmin"
signingID = "rds-db"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not going to remove this? why keep feature/rds/auth? if its going to be empty?

emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
)

// BuildAuthTokenOptions is the optional set of configuration properties for BuildAuthToken
type BuildAuthTokenOptions struct {
ExpiresIn time.Duration
}
type BuildAuthTokenOptions struct{}

// BuildAuthToken will return an authorization token used as the password for a DB
// connection.
//
// * endpoint - Endpoint consists of the hostname and port needed to connect to the DB. <host>:<port>
// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port>
// * region - Region is the location of where the DB is
// * dbUser - User account within the database to sign in with
// * creds - Credentials to be signed with
Expand All @@ -57,64 +50,12 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds
return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
}

values := url.Values{
"Action": []string{"connect"},
"DBUser": []string{dbUser},
}

return generateAuthToken(ctx, endpoint, region, values, rdsAuthTokenID, creds, optFns...)
}

// GenerateDbConnectAuthToken will return an authorization token as the password for a
// DB connection.
//
// This is the regular user variant, see [GenerateDBConnectSuperUserAuthToken] for the superuser variant
//
// * endpoint - Endpoint is the hostname and optional port to connect to the DB
// * region - Region is the location of where the DB is
// * creds - Credentials to be signed with
func GenerateDbConnectAuthToken(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
values := url.Values{
"Action": []string{userAction},
}
return generateAuthToken(ctx, endpoint, region, values, rdsClusterTokenID, creds, optFns...)
}

// GenerateDBConnectSuperUserAuthToken will return an authorization token as the password for a
// DB connection.
//
// This is the superuser user variant, see [GenerateDBConnectSuperUserAuthToken] for the regular user variant
//
// * endpoint - Endpoint is the hostname and optional port to connect to the DB
// * region - Region is the location of where the DB is
// * creds - Credentials to be signed with
func GenerateDBConnectSuperUserAuthToken(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
values := url.Values{
"Action": []string{adminUserAction},
}
return generateAuthToken(ctx, endpoint, region, values, rdsClusterTokenID, creds, optFns...)
}

// All generate token functions are presigned URLs behind the scenes with the scheme stripped.
// This function abstracts generating this for all use cases
func generateAuthToken(ctx context.Context, endpoint, region string, values url.Values, signingID string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
if len(region) == 0 {
return "", fmt.Errorf("region is required")
}
if len(endpoint) == 0 {
return "", fmt.Errorf("endpoint is required")
}

o := BuildAuthTokenOptions{}

for _, fn := range optFns {
fn(&o)
}

if o.ExpiresIn == 0 {
o.ExpiresIn = 15 * time.Minute
}

if creds == nil {
return "", fmt.Errorf("credetials provider must not ne nil")
}
Expand All @@ -128,25 +69,24 @@ func generateAuthToken(ctx context.Context, endpoint, region string, values url.
if err != nil {
return "", err
}
values := req.URL.Query()
values.Set("Action", "connect")
values.Set("DBUser", dbUser)
req.URL.RawQuery = values.Encode()

signer := v4.NewSigner()

credentials, err := creds.Retrieve(ctx)
if err != nil {
return "", err
}

expires := o.ExpiresIn
// if creds expire before expiresIn, set that as the expiration time
if credentials.CanExpire && !credentials.Expires.IsZero() {
credsExpireIn := credentials.Expires.Sub(sdk.NowTime())
expires = min(o.ExpiresIn, credsExpireIn)
}
// Expire Time: 15 minute
query := req.URL.Query()
query.Set("X-Amz-Expires", strconv.Itoa(int(expires.Seconds())))
query.Set("X-Amz-Expires", "900")
req.URL.RawQuery = query.Encode()

signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, sdk.NowTime().UTC())
signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, time.Now().UTC())
if err != nil {
return "", err
}
Expand Down
148 changes: 2 additions & 146 deletions feature/rds/auth/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@ package auth_test

import (
"context"
"net/url"
"regexp"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
)

func TestBuildAuthToken(t *testing.T) {
Expand Down Expand Up @@ -70,155 +67,14 @@ func TestBuildAuthToken(t *testing.T) {
}
}

type dbAuthTestCase struct {
endpoint string
region string
expires time.Duration
credsExpireIn time.Duration
expectedHost string
expectedQueryParams []string
expectedError string
}

type tokenGenFunc func(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *auth.BuildAuthTokenOptions)) (string, error)

func TestGenerateDbConnectAuthToken(t *testing.T) {
cases := map[string]dbAuthTestCase{
"no region": {
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
expectedError: "no region",
},
"no endpoint": {
region: "us-west-2",
expectedError: "port",
},
"endpoint with scheme": {
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-east-1",
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com:3306",
expectedQueryParams: []string{"Action=DbConnect"},
},
"endpoint without scheme": {
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-east-1",
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com:3306",
expectedQueryParams: []string{"Action=DbConnect"},
},
"endpoint without port": {
endpoint: "prod-instance.us-east-1.rds.amazonaws.com",
region: "us-east-1",
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com",
expectedQueryParams: []string{"Action=DbConnect"},
},
"endpoint with region and expires": {
endpoint: "peccy.dsql.us-east-1.on.aws",
region: "us-east-1",
expires: time.Second * 450,
expectedHost: "peccy.dsql.us-east-1.on.aws",
expectedQueryParams: []string{
"Action=DbConnect",
"X-Amz-Algorithm=AWS4-HMAC-SHA256",
"X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request",
"X-Amz-Date=20240827T000000Z",
"X-Amz-Expires=450"},
},
"pick credential expires when less than expires": {
endpoint: "peccy.dsql.us-east-1.on.aws",
region: "us-east-1",
credsExpireIn: time.Second * 100,
expires: time.Second * 450,
expectedHost: "peccy.dsql.us-east-1.on.aws",
expectedQueryParams: []string{
"Action=DbConnect",
"X-Amz-Algorithm=AWS4-HMAC-SHA256",
"X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request",
"X-Amz-Date=20240827T000000Z",
"X-Amz-Expires=100"},
},
}

for _, c := range cases {
creds := &staticCredentials{AccessKey: "akid", SecretKey: "secret", expiresIn: c.credsExpireIn}
defer withTempGlobalTime(time.Date(2024, time.August, 27, 0, 0, 0, 0, time.UTC))()
optFns := func(options *auth.BuildAuthTokenOptions) {}
if c.expires != 0 {
optFns = func(options *auth.BuildAuthTokenOptions) {
options.ExpiresIn = c.expires
}
}
verifyTestCase(auth.GenerateDbConnectAuthToken, c, creds, optFns, t)

// Update the test case to use Superuser variant
updated := []string{}
for _, part := range c.expectedQueryParams {
if part == "Action=DbConnect" {
part = "Action=DbConnectAdmin"
}
updated = append(updated, part)
}
c.expectedQueryParams = updated

verifyTestCase(auth.GenerateDBConnectSuperUserAuthToken, c, creds, optFns, t)
}
}

func verifyTestCase(f tokenGenFunc, c dbAuthTestCase, creds aws.CredentialsProvider, optFns func(options *auth.BuildAuthTokenOptions), t *testing.T) {
token, err := f(context.Background(), c.endpoint, c.region, creds, optFns)
isErrorExpected := len(c.expectedError) > 0
if err != nil && !isErrorExpected {
t.Fatalf("expect no err, got: %v", err)
} else if err == nil && isErrorExpected {
t.Fatalf("Expected error %v got none", c.expectedError)
}
// adding a scheme so we can parse it back as a URL. This is because comparing
// just direct string comparison was failing since "Action=DbConnect" is a substring or
// "Action=DBConnectSuperuser"
parsed, err := url.Parse("http://" + token)
if err != nil {
t.Fatalf("Couldn't parse the token %v to URL after adding a scheme, got: %v", token, err)
}
if parsed.Host != c.expectedHost {
t.Errorf("expect host %v, got %v", c.expectedHost, parsed.Host)
}

q := parsed.Query()
queryValuePair := map[string]any{}
for k, v := range q {
pair := k + "=" + v[0]
queryValuePair[pair] = struct{}{}
}

for _, part := range c.expectedQueryParams {
if _, ok := queryValuePair[part]; !ok {
t.Errorf("expect part %s to be present at token %s", part, token)
}
}
if token != "" && c.expires == 0 {
if !strings.Contains(token, "X-Amz-Expires=900") {
t.Errorf("expect token to contain default X-Amz-Expires value of 900, got %v", token)
}
}
}

type staticCredentials struct {
AccessKey, SecretKey, Session string
expiresIn time.Duration
}

func (s *staticCredentials) Retrieve(ctx context.Context) (aws.Credentials, error) {
c := aws.Credentials{
return aws.Credentials{
AccessKeyID: s.AccessKey,
SecretAccessKey: s.SecretKey,
SessionToken: s.Session,
}
if s.expiresIn != 0 {
c.CanExpire = true
c.Expires = sdk.NowTime().Add(s.expiresIn)
}
return c, nil
}

func withTempGlobalTime(t time.Time) func() {
sdk.NowTime = func() time.Time { return t }
return func() { sdk.NowTime = time.Now }
}, nil
}
Loading