diff --git a/gateway/mw_oauth2_auth.go b/gateway/mw_oauth2_auth.go index bc683edbdf1..6808a39ab45 100644 --- a/gateway/mw_oauth2_auth.go +++ b/gateway/mw_oauth2_auth.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" "net/http" "strings" @@ -57,13 +58,18 @@ type upstreamOAuthPasswordCache struct { func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { cacheKey := generatePasswordOAuthCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID) - tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) + tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) if err != nil { return "", err } - if tokenString != "" { - decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString) + if tokenData != "" { + tokenContents, err := unmarshalTokenData(tokenData) + if err != nil { + return "", err + } + decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, tokenContents.ExtraMetadata) return decryptedToken, nil } @@ -73,10 +79,15 @@ func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *Up } encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, token) + tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata) + if err != nil { + return "", err + } + metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, metadataMap) ttl := time.Until(token.Expiry) - if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil { + if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil { return "", err } @@ -271,16 +282,26 @@ func generateClientCredentialsCacheKey(config apidef.UpstreamOAuth, apiId string return hex.EncodeToString(hash.Sum(nil)) } +type TokenData struct { + Token string `json:"token"` + ExtraMetadata map[string]interface{} `json:"extra_metadata"` +} + func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) { cacheKey := generateClientCredentialsCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID) - tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) + tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster) if err != nil { return "", err } - if tokenString != "" { - decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString) + if tokenData != "" { + tokenContents, err := unmarshalTokenData(tokenData) + if err != nil { + return "", err + } + decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, tokenContents.ExtraMetadata) return decryptedToken, nil } @@ -290,24 +311,55 @@ func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAut } encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken) - setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, token) + tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata) + if err != nil { + return "", err + } + metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata) + setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, metadataMap) ttl := time.Until(token.Expiry) - if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil { + if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil { return "", err } return token.AccessToken, nil } -func setExtraMetadata(r *http.Request, keyList []string, token *oauth2.Token) { +func createTokenDataBytes(encryptedToken string, token *oauth2.Token, extraMetadataKeys []string) ([]byte, error) { + td := TokenData{ + Token: encryptedToken, + ExtraMetadata: buildMetadataMap(token, extraMetadataKeys), + } + return json.Marshal(td) +} + +func unmarshalTokenData(tokenData string) (TokenData, error) { + var tokenContents TokenData + err := json.Unmarshal([]byte(tokenData), &tokenContents) + if err != nil { + return TokenData{}, fmt.Errorf("failed to unmarshal token data: %w", err) + } + return tokenContents, nil +} + +func buildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[string]interface{} { + metadataMap := make(map[string]interface{}) + for _, key := range extraMetadataKeys { + if val := token.Extra(key); val != "" && val != nil { + metadataMap[key] = val + } + } + return metadataMap +} + +func setExtraMetadata(r *http.Request, keyList []string, token map[string]interface{}) { contextDataObject := ctxGetData(r) if contextDataObject == nil { contextDataObject = make(map[string]interface{}) } for _, key := range keyList { - val := token.Extra(key) - if val != "" { + if val, ok := token[key]; ok && val != "" { contextDataObject[key] = val } } @@ -318,13 +370,13 @@ func retryGetKeyAndLock(cacheKey string, cache *storage.RedisCluster) (string, e const maxRetries = 10 const retryDelay = 100 * time.Millisecond - var token string + var tokenData string var err error for i := 0; i < maxRetries; i++ { - token, err = cache.GetKey(cacheKey) + tokenData, err = cache.GetKey(cacheKey) if err == nil { - return token, nil + return tokenData, nil } lockKey := cacheKey + ":lock" diff --git a/gateway/mw_oauth2_auth_test.go b/gateway/mw_oauth2_auth_test.go index 736a6561f64..8f26b67039b 100644 --- a/gateway/mw_oauth2_auth_test.go +++ b/gateway/mw_oauth2_auth_test.go @@ -6,6 +6,9 @@ import ( "net/http" "net/http/httptest" "testing" + "time" + + "golang.org/x/oauth2" "github.com/stretchr/testify/assert" @@ -19,7 +22,13 @@ func TestUpstreamOauth2(t *testing.T) { tst := StartTest(nil) t.Cleanup(tst.Close) + var requestCount int ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + if requestCount > 0 { + assert.Fail(t, "Unexpected request received.") + } + requestCount++ if r.URL.String() != "/token" { assert.Fail(t, "authenticate client request URL = %q; want %q", r.URL, "/token") } @@ -90,6 +99,23 @@ func TestUpstreamOauth2(t *testing.T) { return true }, }, + { + Path: "/upstream-oauth-distributed/", + Code: http.StatusOK, + BodyMatchFunc: func(body []byte) bool { + resp := struct { + Headers map[string]string `json:"headers"` + }{} + err := json.Unmarshal(body, &resp) + assert.NoError(t, err) + + assert.Contains(t, resp.Headers, header.Authorization) + assert.NotEmpty(t, resp.Headers[header.Authorization]) + assert.Equal(t, "Bearer 90d64460d14870c08c81352a05dedd3465940a7c", resp.Headers[header.Authorization]) + + return true + }, + }, }...) } @@ -98,8 +124,13 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { tst := StartTest(nil) t.Cleanup(tst.Close) + var requestCount int ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() + if requestCount > 0 { + assert.Fail(t, "Unexpected request received.") + } + requestCount++ expected := "/token" if r.URL.String() != expected { assert.Fail(t, "URL = %q; want %q", r.URL, expected) @@ -174,5 +205,112 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { return true }, }, + { + Path: "/upstream-oauth-password/", + Code: http.StatusOK, + BodyMatchFunc: func(body []byte) bool { + resp := struct { + Headers map[string]string `json:"headers"` + }{} + err := json.Unmarshal(body, &resp) + assert.NoError(t, err) + + assert.Contains(t, resp.Headers, header.Authorization) + assert.NotEmpty(t, resp.Headers[header.Authorization]) + assert.Equal(t, "Bearer 90d64460d14870c08c81352a05dedd3465940a7c", resp.Headers[header.Authorization]) + + return true + }, + }, }...) } + +func TestSetExtraMetadata(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://tykxample.com", nil) + + keyList := []string{"key1", "key2"} + token := map[string]interface{}{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + setExtraMetadata(req, keyList, token) + + contextData := ctxGetData(req) + + assert.Equal(t, "value1", contextData["key1"]) + assert.Equal(t, "value2", contextData["key2"]) + assert.NotContains(t, contextData, "key3") +} + +func TestBuildMetadataMap(t *testing.T) { + token := &oauth2.Token{ + AccessToken: "tyk_upstream_oauth_access_token", + TokenType: "Bearer", + Expiry: time.Now().Add(time.Hour), + } + token = token.WithExtra(map[string]interface{}{ + "key1": "value1", + "key2": "value2", + "key3": "", + }) + extraMetadataKeys := []string{"key1", "key2", "key3", "key4"} + + metadataMap := buildMetadataMap(token, extraMetadataKeys) + + assert.Equal(t, "value1", metadataMap["key1"]) + assert.Equal(t, "value2", metadataMap["key2"]) + assert.NotContains(t, metadataMap, "key3") + assert.NotContains(t, metadataMap, "key4") +} + +func TestCreateTokenDataBytes(t *testing.T) { + token := &oauth2.Token{ + AccessToken: "tyk_upstream_oauth_access_token", + TokenType: "Bearer", + Expiry: time.Now().Add(time.Hour), + } + token = token.WithExtra(map[string]interface{}{ + "key1": "value1", + "key2": "value2", + "key3": "", + }) + + extraMetadataKeys := []string{"key1", "key2", "key3", "key4"} + + encryptedToken := "encrypted_tyk_upstream_oauth_access_token" + tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, extraMetadataKeys) + + assert.NoError(t, err) + + var tokenData TokenData + err = json.Unmarshal(tokenDataBytes, &tokenData) + assert.NoError(t, err) + + assert.Equal(t, encryptedToken, tokenData.Token) + assert.Equal(t, "value1", tokenData.ExtraMetadata["key1"]) + assert.Equal(t, "value2", tokenData.ExtraMetadata["key2"]) + assert.NotContains(t, tokenData.ExtraMetadata, "key3") + assert.NotContains(t, tokenData.ExtraMetadata, "key4") +} + +func TestUnmarshalTokenData(t *testing.T) { + tokenData := TokenData{ + Token: "tyk_upstream_oauth_access_token", + ExtraMetadata: map[string]interface{}{ + "key1": "value1", + "key2": "value2", + }, + } + + tokenDataBytes, err := json.Marshal(tokenData) + assert.NoError(t, err) + + result, err := unmarshalTokenData(string(tokenDataBytes)) + + assert.NoError(t, err) + + assert.Equal(t, tokenData.Token, result.Token) + assert.Equal(t, tokenData.ExtraMetadata, result.ExtraMetadata) +}