From 67cfd997dbeb8b0b7378c3c96235e0326751e5a5 Mon Sep 17 00:00:00 2001 From: Ajay Gupta Date: Tue, 17 Jan 2023 10:20:51 -0500 Subject: [PATCH] RolesAnywhere-3788: Include token TTL header in token API response --- Makefile | 2 +- aws_signing_helper/serve.go | 48 +++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index bac14ad..9307e1b 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -VERSION=1.0.3 +VERSION=1.0.4 release: go build -buildmode=pie -ldflags "-X 'main.Version=${VERSION}' -linkmode=external -w -s" -trimpath -o build/bin/aws_signing_helper cmd/aws_signing_helper/main.go diff --git a/aws_signing_helper/serve.go b/aws_signing_helper/serve.go index e17b420..b0b892f 100644 --- a/aws_signing_helper/serve.go +++ b/aws_signing_helper/serve.go @@ -28,7 +28,10 @@ type RefreshableCred struct { AccessKeyId string SecretAccessKey string Token string + Code string + Type string Expiration time.Time + LastUpdated time.Time } type Endpoint struct { @@ -50,6 +53,9 @@ const DEFAULT_TOKEN_TTL_SECONDS = "21600" const X_FORWARDED_FOR_HEADER = "X-Forwarded-For" +const REFRESHABLE_CRED_TYPE = "AWS-HMAC" +const REFRESHABLE_CRED_CODE = "Success" + const MAX_TOKENS = 256 var mutex sync.Mutex @@ -120,6 +126,27 @@ func CheckValidToken(w http.ResponseWriter, r *http.Request) error { return nil } +// Helper function that finds a token's TTL in seconds +func FindTokenTTLSeconds(r *http.Request) (string, error) { + token := r.Header.Get(EC2_METADATA_TOKEN_HEADER) + if token == "" { + msg := "no token provided" + return "", errors.New(msg) + } + + mutex.Lock() + expiration, ok := tokenMap[token] + mutex.Unlock() + if ok { + tokenTTLFloat := expiration.Sub(time.Now()).Seconds() + tokenTTLInt64 := int64(tokenTTLFloat) + return strconv.FormatInt(tokenTTLInt64, 10), nil + } else { + msg := "invalid token provided" + return "", errors.New(msg) + } +} + func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *CredentialsOpts) (http.HandlerFunc, http.HandlerFunc, http.HandlerFunc) { // Handles PUT requests to /latest/api/token/ putTokenHandler := func(w http.ResponseWriter, r *http.Request) { @@ -158,6 +185,7 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials expirationTime := time.Now().Add(time.Second * time.Duration(tokenTTL)) InsertToken(token, expirationTime) + w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTLStr) io.WriteString(w, token) // nosemgrep } @@ -173,6 +201,12 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials return } + tokenTTL, err := FindTokenTTLSeconds(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL) io.WriteString(w, roleName) // nosemgrep } @@ -195,7 +229,11 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials cred.SecretAccessKey = credentialProcessOutput.SecretAccessKey cred.Token = credentialProcessOutput.SessionToken cred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration) + cred.Code = REFRESHABLE_CRED_CODE + cred.LastUpdated = time.Now() + cred.Type = REFRESHABLE_CRED_TYPE err := json.NewEncoder(w).Encode(cred) + if err != nil { w.WriteHeader(http.StatusInternalServerError) io.WriteString(w, "failed to encode credentials") @@ -209,6 +247,13 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials return } } + + tokenTTL, err := FindTokenTTLSeconds(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL) } return putTokenHandler, getRoleNameHandler, getCredentialsHandler @@ -228,6 +273,9 @@ func Serve(port int, credentialsOptions CredentialsOpts) { refreshableCred.SecretAccessKey = credentialProcessOutput.SecretAccessKey refreshableCred.Token = credentialProcessOutput.SessionToken refreshableCred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration) + refreshableCred.Code = REFRESHABLE_CRED_CODE + refreshableCred.LastUpdated = time.Now() + refreshableCred.Type = REFRESHABLE_CRED_TYPE endpoint := &Endpoint{PortNum: port, TmpCred: refreshableCred} endpoint.Server = &http.Server{} roleResourceParts := strings.Split(roleArn.Resource, "/")