Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TT-13271, fix for token metadata not being cached #6689

Merged
merged 5 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 68 additions & 16 deletions gateway/mw_oauth2_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -57,13 +58,18 @@ type upstreamOAuthPasswordCache struct {
func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
cacheKey := generatePasswordOAuthCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID)

tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
if err != nil {
return "", err
}

if tokenString != "" {
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString)
if tokenData != "" {
tokenContents, err := unmarshalTokenData(tokenData)
if err != nil {
return "", err
}
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, tokenContents.ExtraMetadata)
return decryptedToken, nil
}

Expand All @@ -73,10 +79,15 @@ func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *Up
}

encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, token)
tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata)
if err != nil {
return "", err
}
metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, metadataMap)

ttl := time.Until(token.Expiry)
if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil {
if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil {
return "", err
}

Expand Down Expand Up @@ -271,16 +282,26 @@ func generateClientCredentialsCacheKey(config apidef.UpstreamOAuth, apiId string
return hex.EncodeToString(hash.Sum(nil))
}

type TokenData struct {
Token string `json:"token"`
ExtraMetadata map[string]interface{} `json:"extra_metadata"`
}

func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
cacheKey := generateClientCredentialsCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID)

tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
if err != nil {
return "", err
}

if tokenString != "" {
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString)
if tokenData != "" {
tokenContents, err := unmarshalTokenData(tokenData)
if err != nil {
return "", err
}
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, tokenContents.ExtraMetadata)
return decryptedToken, nil
}

Expand All @@ -290,24 +311,55 @@ func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAut
}

encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, token)
tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata)
if err != nil {
return "", err
}
metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, metadataMap)

ttl := time.Until(token.Expiry)
if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil {
if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil {
return "", err
}

return token.AccessToken, nil
}

func setExtraMetadata(r *http.Request, keyList []string, token *oauth2.Token) {
func createTokenDataBytes(encryptedToken string, token *oauth2.Token, extraMetadataKeys []string) ([]byte, error) {
td := TokenData{
Token: encryptedToken,
ExtraMetadata: buildMetadataMap(token, extraMetadataKeys),
}
return json.Marshal(td)
}

func unmarshalTokenData(tokenData string) (TokenData, error) {
var tokenContents TokenData
err := json.Unmarshal([]byte(tokenData), &tokenContents)
if err != nil {
return TokenData{}, fmt.Errorf("failed to unmarshal token data: %w", err)
}
return tokenContents, nil
}

func buildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[string]interface{} {
metadataMap := make(map[string]interface{})
for _, key := range extraMetadataKeys {
if val := token.Extra(key); val != "" && val != nil {
metadataMap[key] = val
}
}
return metadataMap
}

func setExtraMetadata(r *http.Request, keyList []string, token map[string]interface{}) {
contextDataObject := ctxGetData(r)
if contextDataObject == nil {
contextDataObject = make(map[string]interface{})
}
for _, key := range keyList {
val := token.Extra(key)
if val != "" {
if val, ok := token[key]; ok && val != "" {
contextDataObject[key] = val
}
}
Expand All @@ -318,13 +370,13 @@ func retryGetKeyAndLock(cacheKey string, cache *storage.RedisCluster) (string, e
const maxRetries = 10
const retryDelay = 100 * time.Millisecond

var token string
var tokenData string
var err error

for i := 0; i < maxRetries; i++ {
token, err = cache.GetKey(cacheKey)
tokenData, err = cache.GetKey(cacheKey)
if err == nil {
return token, nil
return tokenData, nil
}

lockKey := cacheKey + ":lock"
Expand Down
138 changes: 138 additions & 0 deletions gateway/mw_oauth2_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"golang.org/x/oauth2"

"github.com/stretchr/testify/assert"

Expand All @@ -19,7 +22,13 @@ func TestUpstreamOauth2(t *testing.T) {
tst := StartTest(nil)
t.Cleanup(tst.Close)

var requestCount int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if requestCount > 0 {
assert.Fail(t, "Unexpected request received.")
}
requestCount++
if r.URL.String() != "/token" {
assert.Fail(t, "authenticate client request URL = %q; want %q", r.URL, "/token")
}
Expand Down Expand Up @@ -90,6 +99,23 @@ func TestUpstreamOauth2(t *testing.T) {
return true
},
},
{
Path: "/upstream-oauth-distributed/",
Code: http.StatusOK,
BodyMatchFunc: func(body []byte) bool {
resp := struct {
Headers map[string]string `json:"headers"`
}{}
err := json.Unmarshal(body, &resp)
assert.NoError(t, err)

assert.Contains(t, resp.Headers, header.Authorization)
assert.NotEmpty(t, resp.Headers[header.Authorization])
assert.Equal(t, "Bearer 90d64460d14870c08c81352a05dedd3465940a7c", resp.Headers[header.Authorization])

return true
},
},
}...)

}
Expand All @@ -98,8 +124,13 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
tst := StartTest(nil)
t.Cleanup(tst.Close)

var requestCount int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if requestCount > 0 {
assert.Fail(t, "Unexpected request received.")
}
requestCount++
expected := "/token"
if r.URL.String() != expected {
assert.Fail(t, "URL = %q; want %q", r.URL, expected)
Expand Down Expand Up @@ -174,5 +205,112 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
return true
},
},
{
Path: "/upstream-oauth-password/",
Code: http.StatusOK,
BodyMatchFunc: func(body []byte) bool {
resp := struct {
Headers map[string]string `json:"headers"`
}{}
err := json.Unmarshal(body, &resp)
assert.NoError(t, err)

assert.Contains(t, resp.Headers, header.Authorization)
assert.NotEmpty(t, resp.Headers[header.Authorization])
assert.Equal(t, "Bearer 90d64460d14870c08c81352a05dedd3465940a7c", resp.Headers[header.Authorization])

return true
},
},
}...)
}

func TestSetExtraMetadata(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://tykxample.com", nil)

keyList := []string{"key1", "key2"}
token := map[string]interface{}{
"key1": "value1",
"key2": "value2",
"key3": "value3",
}

setExtraMetadata(req, keyList, token)

contextData := ctxGetData(req)

assert.Equal(t, "value1", contextData["key1"])
assert.Equal(t, "value2", contextData["key2"])
assert.NotContains(t, contextData, "key3")
}

func TestBuildMetadataMap(t *testing.T) {
token := &oauth2.Token{
AccessToken: "tyk_upstream_oauth_access_token",
TokenType: "Bearer",
Expiry: time.Now().Add(time.Hour),
}
token = token.WithExtra(map[string]interface{}{
"key1": "value1",
"key2": "value2",
"key3": "",
})
extraMetadataKeys := []string{"key1", "key2", "key3", "key4"}

metadataMap := buildMetadataMap(token, extraMetadataKeys)

assert.Equal(t, "value1", metadataMap["key1"])
assert.Equal(t, "value2", metadataMap["key2"])
assert.NotContains(t, metadataMap, "key3")
assert.NotContains(t, metadataMap, "key4")
}

func TestCreateTokenDataBytes(t *testing.T) {
token := &oauth2.Token{
AccessToken: "tyk_upstream_oauth_access_token",
TokenType: "Bearer",
Expiry: time.Now().Add(time.Hour),
}
token = token.WithExtra(map[string]interface{}{
"key1": "value1",
"key2": "value2",
"key3": "",
})

extraMetadataKeys := []string{"key1", "key2", "key3", "key4"}

encryptedToken := "encrypted_tyk_upstream_oauth_access_token"
tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, extraMetadataKeys)

assert.NoError(t, err)

var tokenData TokenData
err = json.Unmarshal(tokenDataBytes, &tokenData)
assert.NoError(t, err)

assert.Equal(t, encryptedToken, tokenData.Token)
assert.Equal(t, "value1", tokenData.ExtraMetadata["key1"])
assert.Equal(t, "value2", tokenData.ExtraMetadata["key2"])
assert.NotContains(t, tokenData.ExtraMetadata, "key3")
assert.NotContains(t, tokenData.ExtraMetadata, "key4")
}

func TestUnmarshalTokenData(t *testing.T) {
tokenData := TokenData{
Token: "tyk_upstream_oauth_access_token",
ExtraMetadata: map[string]interface{}{
"key1": "value1",
"key2": "value2",
},
}

tokenDataBytes, err := json.Marshal(tokenData)
assert.NoError(t, err)

result, err := unmarshalTokenData(string(tokenDataBytes))

assert.NoError(t, err)

assert.Equal(t, tokenData.Token, result.Token)
assert.Equal(t, tokenData.ExtraMetadata, result.ExtraMetadata)
}
Loading