Skip to content

Commit

Permalink
Merge pull request #65 from rgooch/presigner
Browse files Browse the repository at this point in the history
Add AWS STS presigned-URL support packages.
  • Loading branch information
rgooch authored Aug 28, 2022
2 parents 2c1ca4b + ca35e34 commit 69ff3a6
Show file tree
Hide file tree
Showing 9 changed files with 649 additions and 0 deletions.
57 changes: 57 additions & 0 deletions cmd/presignauth-test/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import (
"flag"
"fmt"
"os"

"github.com/Cloud-Foundations/Dominator/lib/log/cmdlogger"
"github.com/Cloud-Foundations/golib/pkg/awsutil/presignauth/caller"
"github.com/Cloud-Foundations/golib/pkg/awsutil/presignauth/presigner"
)

func printUsage() {
w := flag.CommandLine.Output()
fmt.Fprintln(w, "Usage: presignauth-test [flags...]")
fmt.Fprintln(w, "Common flags:")
flag.PrintDefaults()
}

func doMain() error {
flag.Usage = printUsage
flag.Parse()
logger := cmdlogger.New()
presignerClient, err := presigner.New(presigner.Params{
Logger: logger,
})
if err != nil {
return err
}
callerClient, err := caller.New(caller.Params{
Logger: logger,
})
if err != nil {
return err
}
logger.Printf("ARN: %s\n", presignerClient.GetCallerARN())
presignedReq, err := presignerClient.PresignGetCallerIdentity(nil)
if err != nil {
return err
}
logger.Printf("Method: %s, URL: %s\n",
presignedReq.Method, presignedReq.URL)
verifiedArn, err := callerClient.GetCallerIdentity(nil, presignedReq.Method,
presignedReq.URL)
if err != nil {
return err
}
logger.Printf("Verified ARN: %s\n", verifiedArn)
return nil
}

func main() {
if err := doMain(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
14 changes: 14 additions & 0 deletions pkg/awsutil/presignauth/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package presignauth

import (
"github.com/aws/aws-sdk-go-v2/aws/arn"
)

// NormaliseARN will normalise an AWS IAM ARN (i.e. an ARN returned from
// sts:GetCallerIdentity), returning the actual role ARN, rather than an ARN
// showing how the credentials were obtained (such as by assuming the role).
// This mirrors the way AWS policy documents are written. The ARN will have the
// form: arn:aws:iam::$AccountId:role/$RoleName
func NormaliseARN(input arn.ARN) (arn.ARN, error) {
return normaliseARN(input)
}
51 changes: 51 additions & 0 deletions pkg/awsutil/presignauth/caller/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package caller

import (
"context"
"net/http"
"net/url"
"sync"
"time"

"github.com/Cloud-Foundations/golib/pkg/log"

"github.com/aws/aws-sdk-go-v2/aws/arn"
)

type Params struct {
// Optional parameters.
HttpClient *http.Client
Logger log.DebugLogger
urlValidator func(presignedUrl string) (*url.URL, error)
}

type cacheEntry struct {
expires time.Time
normalisedArn arn.ARN
}

type Caller interface {
GetCallerIdentity(ctx context.Context, presignedMethod string,
presignedUrl string) (arn.ARN, error)
}

type callerT struct {
params Params
mutex sync.Mutex // Protect everything below.
cache map[string]cacheEntry // Key: presigned URL.
}

// Interface checks.
var _ Caller = (*callerT)(nil)

// New will create a caller for AWS STS presigned request URLs.
func New(params Params) (Caller, error) {
return newCaller(params)
}

// GetCallerIdentity will verify if the specified URL is a valid AWS STS
// presigned URL and if so will return the corresponding caller identity.
func (c *callerT) GetCallerIdentity(ctx context.Context, presignedMethod string,
presignedUrl string) (arn.ARN, error) {
return c.getCallerIdentity(ctx, presignedMethod, presignedUrl)
}
160 changes: 160 additions & 0 deletions pkg/awsutil/presignauth/caller/impl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package caller

import (
"context"
"encoding/xml"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"

"github.com/Cloud-Foundations/golib/pkg/awsutil/presignauth"
"github.com/Cloud-Foundations/golib/pkg/log/nulllogger"

"github.com/aws/aws-sdk-go-v2/aws/arn"
)

const (
presignedUrlLifetime = 15 * time.Minute
)

type getCallerIdentityResult struct {
Arn string
}

type getCallerIdentityResponse struct {
GetCallerIdentityResult getCallerIdentityResult
}

func newCaller(params Params) (*callerT, error) {
if params.HttpClient == nil {
params.HttpClient = http.DefaultClient
}
if params.Logger == nil {
params.Logger = nulllogger.New()
}
if params.urlValidator == nil {
params.urlValidator = validateStsPresignedUrl
}
caller := &callerT{
params: params,
cache: make(map[string]cacheEntry),
}
go caller.cleanupLoop()
return caller, nil
}

func validateStsPresignedUrl(presignedUrl string) (*url.URL, error) {
parsedPresignedUrl, err := url.Parse(presignedUrl)
if err != nil {
return nil, err
}
if parsedPresignedUrl.Scheme != "https" {
return nil, fmt.Errorf("invalid scheme: %s", parsedPresignedUrl.Scheme)
}
if parsedPresignedUrl.Path != "/" {
return nil, fmt.Errorf("invalid path: %s", parsedPresignedUrl.Path)
}
if !strings.HasPrefix(parsedPresignedUrl.RawQuery,
"Action=GetCallerIdentity&") {
return nil,
fmt.Errorf("invalid action: %s", parsedPresignedUrl.RawQuery)
}
splitHost := strings.Split(parsedPresignedUrl.Host, ".")
if len(splitHost) != 4 ||
splitHost[0] != "sts" ||
splitHost[2] != "amazonaws" ||
splitHost[3] != "com" {
return nil, fmt.Errorf("malformed presigned URL host")
}
return parsedPresignedUrl, nil
}

func (c *callerT) cleanupLoop() {
for {
time.Sleep(c.cleanupOnce())
}
}

func (c *callerT) cleanupOnce() time.Duration {
c.mutex.Lock()
defer c.mutex.Unlock()
nextExpiration := time.Minute
for presignedUrl, entry := range c.cache {
if expiration := time.Until(entry.expires); expiration <= 0 {
delete(c.cache, presignedUrl)
} else if expiration < nextExpiration {
nextExpiration = expiration
}
}
return nextExpiration
}

func (c *callerT) getCallerIdentity(ctx context.Context, presignedMethod string,
presignedUrl string) (arn.ARN, error) {
if cv := c.getCallerIdentityCached(presignedUrl); cv != nil {
return *cv, nil
}
validatedUrl, err := c.params.urlValidator(presignedUrl)
if err != nil {
return arn.ARN{}, err
}
presignedUrl = validatedUrl.String()
var validateReq *http.Request
if ctx == nil {
validateReq, err = http.NewRequest(presignedMethod, presignedUrl, nil)
} else {
validateReq, err = http.NewRequestWithContext(ctx, presignedMethod,
presignedUrl, nil)
}
if err != nil {
return arn.ARN{}, err
}
validateResp, err := c.params.HttpClient.Do(validateReq)
if err != nil {
return arn.ARN{}, err
}
defer validateResp.Body.Close()
if validateResp.StatusCode != http.StatusOK {
return arn.ARN{}, fmt.Errorf("verification request failed")
}
body, err := ioutil.ReadAll(validateResp.Body)
if err != nil {
return arn.ARN{}, err
}
var callerIdentity getCallerIdentityResponse
if err := xml.Unmarshal(body, &callerIdentity); err != nil {
return arn.ARN{}, err
}
parsedArn, err := arn.Parse(callerIdentity.GetCallerIdentityResult.Arn)
if err != nil {
return arn.ARN{}, err
}
normalisedArn, err := presignauth.NormaliseARN(parsedArn)
if err != nil {
return arn.ARN{}, err
}
c.mutex.Lock()
defer c.mutex.Unlock()
c.cache[presignedUrl] = cacheEntry{
expires: time.Now().Add(presignedUrlLifetime),
normalisedArn: normalisedArn,
}
return normalisedArn, nil
}

func (c *callerT) getCallerIdentityCached(presignedUrl string) *arn.ARN {
c.mutex.Lock()
defer c.mutex.Unlock()
entry, ok := c.cache[presignedUrl]
if !ok {
return nil
}
if time.Since(entry.expires) >= 0 {
delete(c.cache, presignedUrl)
return nil
}
return &entry.normalisedArn
}
89 changes: 89 additions & 0 deletions pkg/awsutil/presignauth/caller/impl_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package caller

import (
"fmt"
"net"
"net/http"
"net/url"
"testing"
)

const (
awsTestArn = "arn:aws:iam::accountid:role/TestMonkey"
awsPresignedUrlBadAction = "https://sts.a-region.amazonaws.com/?Action=BecomeRoot&Version=2011-06-15&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=cred&X-Amz-Security-Token=token&X-Amz-SignedHeaders=host&X-Amz-Signature=sig"
awsPresignedUrlBadDomain = "https://sts.a-region.hackerz.com/?Action=GetCallerIdentity&Version=2011-06-15&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=cred&X-Amz-Security-Token=token&X-Amz-SignedHeaders=host&X-Amz-Signature=sig"
awsPresignedUrlGood = "https://sts.a-region.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=cred&X-Amz-Security-Token=token&X-Amz-SignedHeaders=host&X-Amz-Signature=sig"
awsCallerIdentityResponse = `<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>arn:aws:sts::accountid:assumed-role/TestMonkey/tester</Arn>
<UserId>useridstuff:tester</UserId>
<Account>accountid</Account>
</GetCallerIdentityResult>
<ResponseMetadata>
<RequestId>some-uuid</RequestId>
</ResponseMetadata>
</GetCallerIdentityResponse>
`
)

var serverCount uint

type testAwsGetCallerIdentityType struct{}

func testValidatePresignedUrl(presignedUrl string) (*url.URL, error) {
return url.Parse(presignedUrl)
}

func (testAwsGetCallerIdentityType) ServeHTTP(w http.ResponseWriter,
r *http.Request) {
serverCount++
w.Write([]byte(awsCallerIdentityResponse))
}

func TestAwsPresignedUrlValidation(t *testing.T) {
if _, err := validateStsPresignedUrl(awsPresignedUrlBadAction); err == nil {
t.Errorf("no error with bad action URL: %s", awsPresignedUrlBadAction)
}
if _, err := validateStsPresignedUrl(awsPresignedUrlBadDomain); err == nil {
t.Errorf("no error with bad domain URL: %s", awsPresignedUrlBadDomain)
}
if _, err := validateStsPresignedUrl(awsPresignedUrlGood); err != nil {
t.Error("valid URL does not validate")
}
}

func TestAwsGetCallerIdentity(t *testing.T) {
client, err := New(Params{urlValidator: testValidatePresignedUrl})
if err != nil {
t.Fatal(err)
}
listener, err := net.Listen("tcp", "localhost:")
if err != nil {
t.Fatal(err)
}
go func() {
err := http.Serve(listener, &testAwsGetCallerIdentityType{})
if err != nil {
t.Fatal(err)
}
}()
testUrl := fmt.Sprintf("http://%s/", listener.Addr().String())
callerArn, err := client.GetCallerIdentity(nil, "GET", testUrl)
if err != nil {
t.Fatal(err)
}
if callerArn.String() != awsTestArn {
t.Errorf("expected: %s but got: %s", awsTestArn, callerArn)
}
// Check again to see if caching works.
callerArn, err = client.GetCallerIdentity(nil, "GET", testUrl)
if err != nil {
t.Fatal(err)
}
if callerArn.String() != awsTestArn {
t.Errorf("expected: %s but got: %s", awsTestArn, callerArn)
}
if serverCount != 1 {
t.Errorf("serverCount expected: 1 but got: %d", serverCount)
}
}
Loading

0 comments on commit 69ff3a6

Please sign in to comment.