From 48e627dc87fc32e14684aae8ad2d30a21cf598ec Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 1 Nov 2024 10:53:10 +0200 Subject: [PATCH 1/5] TT-13271, fix for token metadata not being cached --- gateway/mw_oauth2_auth.go | 102 +++++++++++++++++++++++++++++++------- 1 file changed, 85 insertions(+), 17 deletions(-) diff --git a/gateway/mw_oauth2_auth.go b/gateway/mw_oauth2_auth.go index bc683edbdf1..a8e429b98fc 100644 --- a/gateway/mw_oauth2_auth.go +++ b/gateway/mw_oauth2_auth.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" "net/http" "strings" @@ -57,14 +58,21 @@ type upstreamOAuthPasswordCache struct { func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { cacheKey := generatePasswordOAuthCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID) - tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) + tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) if err != nil { return "", err } - if tokenString != "" { - decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString) - return decryptedToken, nil + if tokenData != "" { + if tokenData != "" { + tokenContents, err := unmarshalTokenData(tokenData) + if err != nil { + return "", err + } + decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, tokenContents.ExtraMetadata) + return decryptedToken, nil + } } token, err := cache.obtainToken(r.Context(), OAuthSpec) @@ -73,10 +81,15 @@ func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *Up } encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, token) + tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata) + if err != nil { + return "", err + } + metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, metadataMap) ttl := time.Until(token.Expiry) - if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil { + if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil { return "", err } @@ -271,16 +284,26 @@ func generateClientCredentialsCacheKey(config apidef.UpstreamOAuth, apiId string return hex.EncodeToString(hash.Sum(nil)) } +type TokenData struct { + Token string `json:"token"` + ExtraMetadata map[string]interface{} `json:"extra_metadata"` +} + func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { cacheKey := generateClientCredentialsCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID) - tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) + tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) if err != nil { return "", err } - if tokenString != "" { - decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString) + if tokenData != "" { + tokenContents, err := unmarshalTokenData(tokenData) + if err != nil { + return "", err + } + decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, tokenContents.ExtraMetadata) return decryptedToken, nil } @@ -290,24 +313,69 @@ func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAut } encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, token) + tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata) + if err != nil { + return "", err + } + metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, metadataMap) ttl := time.Until(token.Expiry) - if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil { + if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil { return "", err } return token.AccessToken, nil } -func setExtraMetadata(r *http.Request, keyList []string, token *oauth2.Token) { +func createTokenDataBytes(encryptedToken string, token *oauth2.Token, extraMetadataKeys []string) ([]byte, error) { + td := TokenData{ + Token: encryptedToken, + ExtraMetadata: buildMetadataMap(token, extraMetadataKeys), + } + return json.Marshal(td) +} + +func unmarshalTokenData(tokenData string) (TokenData, error) { + var tokenContents TokenData + err := json.Unmarshal([]byte(tokenData), &tokenContents) + if err != nil { + return TokenData{}, fmt.Errorf("failed to unmarshal token data: %w", err) + } + return tokenContents, nil +} + +func buildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[string]interface{} { + metadataMap := make(map[string]interface{}) + for _, key := range extraMetadataKeys { + if val := token.Extra(key); val != "" { + metadataMap[key] = val + } + } + return metadataMap +} + +//func setExtraMetadata(r *http.Request, keyList []string, token *oauth2.Token) { +// contextDataObject := ctxGetData(r) +// if contextDataObject == nil { +// contextDataObject = make(map[string]interface{}) +// } +// for _, key := range keyList { +// val := token.Extra(key) +// if val != "" { +// contextDataObject[key] = val +// } +// } +// ctxSetData(r, contextDataObject) +//} + +func setExtraMetadata(r *http.Request, keyList []string, token map[string]interface{}) { contextDataObject := ctxGetData(r) if contextDataObject == nil { contextDataObject = make(map[string]interface{}) } for _, key := range keyList { - val := token.Extra(key) - if val != "" { + if val, ok := token[key]; ok && val != "" { contextDataObject[key] = val } } @@ -318,13 +386,13 @@ func retryGetKeyAndLock(cacheKey string, cache *storage.RedisCluster) (string, e const maxRetries = 10 const retryDelay = 100 * time.Millisecond - var token string + var tokenData string var err error for i := 0; i < maxRetries; i++ { - token, err = cache.GetKey(cacheKey) + tokenData, err = cache.GetKey(cacheKey) if err == nil { - return token, nil + return tokenData, nil } lockKey := cacheKey + ":lock" From c323e18424c37fedf90c41e5eded980c8d20fc1f Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 1 Nov 2024 10:54:35 +0200 Subject: [PATCH 2/5] TT-13271, PR cleanup --- gateway/mw_oauth2_auth.go | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/gateway/mw_oauth2_auth.go b/gateway/mw_oauth2_auth.go index a8e429b98fc..8b3c4f46794 100644 --- a/gateway/mw_oauth2_auth.go +++ b/gateway/mw_oauth2_auth.go @@ -64,15 +64,13 @@ func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *Up } if tokenData != "" { - if tokenData != "" { - tokenContents, err := unmarshalTokenData(tokenData) - if err != nil { - return "", err - } - decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, tokenContents.ExtraMetadata) - return decryptedToken, nil + tokenContents, err := unmarshalTokenData(tokenData) + if err != nil { + return "", err } + decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, tokenContents.ExtraMetadata) + return decryptedToken, nil } token, err := cache.obtainToken(r.Context(), OAuthSpec) @@ -355,20 +353,6 @@ func buildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[strin return metadataMap } -//func setExtraMetadata(r *http.Request, keyList []string, token *oauth2.Token) { -// contextDataObject := ctxGetData(r) -// if contextDataObject == nil { -// contextDataObject = make(map[string]interface{}) -// } -// for _, key := range keyList { -// val := token.Extra(key) -// if val != "" { -// contextDataObject[key] = val -// } -// } -// ctxSetData(r, contextDataObject) -//} - func setExtraMetadata(r *http.Request, keyList []string, token map[string]interface{}) { contextDataObject := ctxGetData(r) if contextDataObject == nil { From 0c8ee2c5ae905643a41d3e85bee0ca548ede08b5 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 1 Nov 2024 13:08:09 +0200 Subject: [PATCH 3/5] TT-13271, assert that response from cache is identical --- gateway/mw_oauth2_auth_test.go | 45 ++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/gateway/mw_oauth2_auth_test.go b/gateway/mw_oauth2_auth_test.go index 736a6561f64..b35a0944b18 100644 --- a/gateway/mw_oauth2_auth_test.go +++ b/gateway/mw_oauth2_auth_test.go @@ -19,7 +19,13 @@ func TestUpstreamOauth2(t *testing.T) { tst := StartTest(nil) t.Cleanup(tst.Close) + var requestCount int ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + if requestCount > 0 { + assert.Fail(t, "Unexpected request received.") + } + requestCount++ if r.URL.String() != "/token" { assert.Fail(t, "authenticate client request URL = %q; want %q", r.URL, "/token") } @@ -90,6 +96,23 @@ func TestUpstreamOauth2(t *testing.T) { return true }, }, + { + Path: "/upstream-oauth-distributed/", + Code: http.StatusOK, + BodyMatchFunc: func(body []byte) bool { + resp := struct { + Headers map[string]string `json:"headers"` + }{} + err := json.Unmarshal(body, &resp) + assert.NoError(t, err) + + assert.Contains(t, resp.Headers, header.Authorization) + assert.NotEmpty(t, resp.Headers[header.Authorization]) + assert.Equal(t, "Bearer 90d64460d14870c08c81352a05dedd3465940a7c", resp.Headers[header.Authorization]) + + return true + }, + }, }...) } @@ -98,8 +121,13 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { tst := StartTest(nil) t.Cleanup(tst.Close) + var requestCount int ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() + if requestCount > 0 { + assert.Fail(t, "Unexpected request received.") + } + requestCount++ expected := "/token" if r.URL.String() != expected { assert.Fail(t, "URL = %q; want %q", r.URL, expected) @@ -174,5 +202,22 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { return true }, }, + { + Path: "/upstream-oauth-password/", + Code: http.StatusOK, + BodyMatchFunc: func(body []byte) bool { + resp := struct { + Headers map[string]string `json:"headers"` + }{} + err := json.Unmarshal(body, &resp) + assert.NoError(t, err) + + assert.Contains(t, resp.Headers, header.Authorization) + assert.NotEmpty(t, resp.Headers[header.Authorization]) + assert.Equal(t, "Bearer 90d64460d14870c08c81352a05dedd3465940a7c", resp.Headers[header.Authorization]) + + return true + }, + }, }...) } From 34f1259e41b19e87844acc3be4e55630ec8a7626 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 1 Nov 2024 13:21:25 +0200 Subject: [PATCH 4/5] TT-13271, added unit test and fixed a small bug --- gateway/mw_oauth2_auth.go | 2 +- gateway/mw_oauth2_auth_test.go | 92 ++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/gateway/mw_oauth2_auth.go b/gateway/mw_oauth2_auth.go index 8b3c4f46794..6808a39ab45 100644 --- a/gateway/mw_oauth2_auth.go +++ b/gateway/mw_oauth2_auth.go @@ -346,7 +346,7 @@ func unmarshalTokenData(tokenData string) (TokenData, error) { func buildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[string]interface{} { metadataMap := make(map[string]interface{}) for _, key := range extraMetadataKeys { - if val := token.Extra(key); val != "" { + if val := token.Extra(key); val != "" && val != nil { metadataMap[key] = val } } diff --git a/gateway/mw_oauth2_auth_test.go b/gateway/mw_oauth2_auth_test.go index b35a0944b18..a92d63041de 100644 --- a/gateway/mw_oauth2_auth_test.go +++ b/gateway/mw_oauth2_auth_test.go @@ -2,10 +2,12 @@ package gateway import ( "encoding/json" + "golang.org/x/oauth2" "io" "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" @@ -221,3 +223,93 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { }, }...) } + +func TestSetExtraMetadata(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://tykxample.com", nil) + + keyList := []string{"key1", "key2"} + token := map[string]interface{}{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + setExtraMetadata(req, keyList, token) + + contextData := ctxGetData(req) + + assert.Equal(t, "value1", contextData["key1"]) + assert.Equal(t, "value2", contextData["key2"]) + assert.NotContains(t, contextData, "key3") +} + +func TestBuildMetadataMap(t *testing.T) { + token := &oauth2.Token{ + AccessToken: "tyk_upstream_oauth_access_token", + TokenType: "Bearer", + Expiry: time.Now().Add(time.Hour), + } + token = token.WithExtra(map[string]interface{}{ + "key1": "value1", + "key2": "value2", + "key3": "", + }) + extraMetadataKeys := []string{"key1", "key2", "key3", "key4"} + + metadataMap := buildMetadataMap(token, extraMetadataKeys) + + assert.Equal(t, "value1", metadataMap["key1"]) + assert.Equal(t, "value2", metadataMap["key2"]) + assert.NotContains(t, metadataMap, "key3") + assert.NotContains(t, metadataMap, "key4") +} + +func TestCreateTokenDataBytes(t *testing.T) { + token := &oauth2.Token{ + AccessToken: "tyk_upstream_oauth_access_token", + TokenType: "Bearer", + Expiry: time.Now().Add(time.Hour), + } + token = token.WithExtra(map[string]interface{}{ + "key1": "value1", + "key2": "value2", + "key3": "", + }) + + extraMetadataKeys := []string{"key1", "key2", "key3", "key4"} + + encryptedToken := "encrypted_tyk_upstream_oauth_access_token" + tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, extraMetadataKeys) + + assert.NoError(t, err) + + var tokenData TokenData + err = json.Unmarshal(tokenDataBytes, &tokenData) + assert.NoError(t, err) + + assert.Equal(t, encryptedToken, tokenData.Token) + assert.Equal(t, "value1", tokenData.ExtraMetadata["key1"]) + assert.Equal(t, "value2", tokenData.ExtraMetadata["key2"]) + assert.NotContains(t, tokenData.ExtraMetadata, "key3") + assert.NotContains(t, tokenData.ExtraMetadata, "key4") +} + +func TestUnmarshalTokenData(t *testing.T) { + tokenData := TokenData{ + Token: "tyk_upstream_oauth_access_token", + ExtraMetadata: map[string]interface{}{ + "key1": "value1", + "key2": "value2", + }, + } + + tokenDataBytes, err := json.Marshal(tokenData) + assert.NoError(t, err) + + result, err := unmarshalTokenData(string(tokenDataBytes)) + + assert.NoError(t, err) + + assert.Equal(t, tokenData.Token, result.Token) + assert.Equal(t, tokenData.ExtraMetadata, result.ExtraMetadata) +} From 7ebf8d0572105cb3f7fcf71057c92ec92b8ea2c3 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 1 Nov 2024 13:44:02 +0200 Subject: [PATCH 5/5] TT-13271, fixed linting --- gateway/mw_oauth2_auth_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gateway/mw_oauth2_auth_test.go b/gateway/mw_oauth2_auth_test.go index a92d63041de..8f26b67039b 100644 --- a/gateway/mw_oauth2_auth_test.go +++ b/gateway/mw_oauth2_auth_test.go @@ -2,13 +2,14 @@ package gateway import ( "encoding/json" - "golang.org/x/oauth2" "io" "net/http" "net/http/httptest" "testing" "time" + "golang.org/x/oauth2" + "github.com/stretchr/testify/assert" "github.com/TykTechnologies/tyk/apidef"