From cf953d18ce70a556e36c76b90cf70c727e0296c7 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Mon, 25 Nov 2024 18:12:17 -0800 Subject: [PATCH] adapt AWSAccessMiddleware to AWS SDK v2 --- lib/srv/alpnproxy/aws_local_proxy.go | 12 ++- lib/srv/alpnproxy/aws_local_proxy_test.go | 125 +++++++++++++--------- 2 files changed, 84 insertions(+), 53 deletions(-) diff --git a/lib/srv/alpnproxy/aws_local_proxy.go b/lib/srv/alpnproxy/aws_local_proxy.go index 794fc6b9b78c9..56c1bbabe4dd5 100644 --- a/lib/srv/alpnproxy/aws_local_proxy.go +++ b/lib/srv/alpnproxy/aws_local_proxy.go @@ -22,6 +22,7 @@ import ( "net/http" "strings" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/service/sts" @@ -33,6 +34,7 @@ import ( appcommon "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" + "github.com/gravitational/teleport/lib/utils/aws/migration" ) // AWSAccessMiddleware verifies the requests to AWS proxy are properly signed. @@ -42,6 +44,11 @@ type AWSAccessMiddleware struct { // AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification. AWSCredentials *credentials.Credentials + // AWSCredentialsV2Provider is an aws sdk v2 credential provider used by + // LocalProxy for request's signature verification if AWSCredentials is not + // specified. + AWSCredentialsV2Provider awsv2.CredentialsProvider + Log logrus.FieldLogger assumedRoles utils.SyncMap[string, *sts.AssumeRoleOutput] @@ -55,7 +62,10 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error { } if m.AWSCredentials == nil { - return trace.BadParameter("missing AWSCredentials") + if m.AWSCredentialsV2Provider == nil { + return trace.BadParameter("missing AWSCredentials") + } + m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider)) } return nil diff --git a/lib/srv/alpnproxy/aws_local_proxy_test.go b/lib/srv/alpnproxy/aws_local_proxy_test.go index 59a044ec8a45c..30f39290d697b 100644 --- a/lib/srv/alpnproxy/aws_local_proxy_test.go +++ b/lib/srv/alpnproxy/aws_local_proxy_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" @@ -42,60 +43,80 @@ func TestAWSAccessMiddleware(t *testing.T) { localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "") assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token") - stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) - v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) - - requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) - v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) + tests := []struct { + name string + middleware *AWSAccessMiddleware + }{ + { + name: "v1", + middleware: &AWSAccessMiddleware{ + AWSCredentials: localProxyCred, + }, + }, + { + name: "v2", + middleware: &AWSAccessMiddleware{ + AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""), + }, + }, + } - m := &AWSAccessMiddleware{ - AWSCredentials: localProxyCred, + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + m := test.middleware + require.NoError(t, m.CheckAndSetDefaults()) + + stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) + v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) + + requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) + v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) + + t.Run("request no authorization", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil))) + require.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("request signed by unknown credentials", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.True(t, m.HandleRequest(recorder, requestByAssumedRole)) + require.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("request signed by local proxy credentials", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred)) + require.Equal(t, http.StatusOK, recorder.Code) + }) + + // Verifies sts:AssumeRole output can be handled successfully. The + // credentials should be saved afterwards. + t.Run("handle sts:AssumeRole response", func(t *testing.T) { + response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred) + response.Request = stsRequestByLocalProxyCred + defer response.Body.Close() + require.NoError(t, m.HandleResponse(response)) + }) + + // This is the same request as the "unknown credentials" test above. But at + // this point, the assumed role credentials should have been saved by the + // middleware so the request can be handled successfully now. + t.Run("request signed by assumed role", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.False(t, m.HandleRequest(recorder, requestByAssumedRole)) + require.Equal(t, http.StatusOK, recorder.Code) + }) + + // Verifies non sts:AssumeRole responses do not give errors. + t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) { + response := getCallerIdentityResponse(t, assumedRoleARN) + response.Request = stsRequestByLocalProxyCred + defer response.Body.Close() + require.NoError(t, m.HandleResponse(response)) + }) + }) } - require.NoError(t, m.CheckAndSetDefaults()) - - t.Run("request no authorization", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil))) - require.Equal(t, http.StatusForbidden, recorder.Code) - }) - - t.Run("request signed by unknown credentials", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.True(t, m.HandleRequest(recorder, requestByAssumedRole)) - require.Equal(t, http.StatusForbidden, recorder.Code) - }) - - t.Run("request signed by local proxy credentials", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred)) - require.Equal(t, http.StatusOK, recorder.Code) - }) - - // Verifies sts:AssumeRole output can be handled successfully. The - // credentials should be saved afterwards. - t.Run("handle sts:AssumeRole response", func(t *testing.T) { - response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred) - response.Request = stsRequestByLocalProxyCred - defer response.Body.Close() - require.NoError(t, m.HandleResponse(response)) - }) - - // This is the same request as the "unknown credentials" test above. But at - // this point, the assumed role credentials should have been saved by the - // middleware so the request can be handled successfully now. - t.Run("request signed by assumed role", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.False(t, m.HandleRequest(recorder, requestByAssumedRole)) - require.Equal(t, http.StatusOK, recorder.Code) - }) - - // Verifies non sts:AssumeRole responses do not give errors. - t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) { - response := getCallerIdentityResponse(t, assumedRoleARN) - response.Request = stsRequestByLocalProxyCred - defer response.Body.Close() - require.NoError(t, m.HandleResponse(response)) - }) } func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response {