Skip to content

Commit

Permalink
Merge pull request #22 from aws/fix/RolesAnywhere-3788
Browse files Browse the repository at this point in the history
RolesAnywhere-3788: Include token TTL header in token API response
  • Loading branch information
13ajay authored Jan 17, 2023
2 parents 0d7625b + 67cfd99 commit 6f08015
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION=1.0.3
VERSION=1.0.4

release:
go build -buildmode=pie -ldflags "-X 'main.Version=${VERSION}' -linkmode=external -w -s" -trimpath -o build/bin/aws_signing_helper cmd/aws_signing_helper/main.go
48 changes: 48 additions & 0 deletions aws_signing_helper/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ type RefreshableCred struct {
AccessKeyId string
SecretAccessKey string
Token string
Code string
Type string
Expiration time.Time
LastUpdated time.Time
}

type Endpoint struct {
Expand All @@ -50,6 +53,9 @@ const DEFAULT_TOKEN_TTL_SECONDS = "21600"

const X_FORWARDED_FOR_HEADER = "X-Forwarded-For"

const REFRESHABLE_CRED_TYPE = "AWS-HMAC"
const REFRESHABLE_CRED_CODE = "Success"

const MAX_TOKENS = 256

var mutex sync.Mutex
Expand Down Expand Up @@ -120,6 +126,27 @@ func CheckValidToken(w http.ResponseWriter, r *http.Request) error {
return nil
}

// Helper function that finds a token's TTL in seconds
func FindTokenTTLSeconds(r *http.Request) (string, error) {
token := r.Header.Get(EC2_METADATA_TOKEN_HEADER)
if token == "" {
msg := "no token provided"
return "", errors.New(msg)
}

mutex.Lock()
expiration, ok := tokenMap[token]
mutex.Unlock()
if ok {
tokenTTLFloat := expiration.Sub(time.Now()).Seconds()
tokenTTLInt64 := int64(tokenTTLFloat)
return strconv.FormatInt(tokenTTLInt64, 10), nil
} else {
msg := "invalid token provided"
return "", errors.New(msg)
}
}

func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *CredentialsOpts) (http.HandlerFunc, http.HandlerFunc, http.HandlerFunc) {
// Handles PUT requests to /latest/api/token/
putTokenHandler := func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -158,6 +185,7 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials
expirationTime := time.Now().Add(time.Second * time.Duration(tokenTTL))
InsertToken(token, expirationTime)

w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTLStr)
io.WriteString(w, token) // nosemgrep
}

Expand All @@ -173,6 +201,12 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials
return
}

tokenTTL, err := FindTokenTTLSeconds(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL)
io.WriteString(w, roleName) // nosemgrep
}

Expand All @@ -195,7 +229,11 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials
cred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
cred.Token = credentialProcessOutput.SessionToken
cred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
cred.Code = REFRESHABLE_CRED_CODE
cred.LastUpdated = time.Now()
cred.Type = REFRESHABLE_CRED_TYPE
err := json.NewEncoder(w).Encode(cred)

if err != nil {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "failed to encode credentials")
Expand All @@ -209,6 +247,13 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials
return
}
}

tokenTTL, err := FindTokenTTLSeconds(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL)
}

return putTokenHandler, getRoleNameHandler, getCredentialsHandler
Expand All @@ -228,6 +273,9 @@ func Serve(port int, credentialsOptions CredentialsOpts) {
refreshableCred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
refreshableCred.Token = credentialProcessOutput.SessionToken
refreshableCred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
refreshableCred.Code = REFRESHABLE_CRED_CODE
refreshableCred.LastUpdated = time.Now()
refreshableCred.Type = REFRESHABLE_CRED_TYPE
endpoint := &Endpoint{PortNum: port, TmpCred: refreshableCred}
endpoint.Server = &http.Server{}
roleResourceParts := strings.Split(roleArn.Resource, "/")
Expand Down

0 comments on commit 6f08015

Please sign in to comment.