Skip to content

Commit

Permalink
TT-13271, fix for token metadata not being cached (#6689)
Browse files Browse the repository at this point in the history
### **User description**
Currently when the token is cached into redis we are not cacheing the
extra metadata so only the initial request gets populated with the
requiered fields.
<details open>
<summary><a href="https://tyktech.atlassian.net/browse/TT-13271"
title="TT-13271" target="_blank">TT-13271</a></summary>
  <br />
  <table>
    <tr>
      <th>Summary</th>
      <td>Add support for custom OAuth server response fields</td>
    </tr>
    <tr>
      <th>Type</th>
      <td>
<img alt="Story"
src="https://tyktech.atlassian.net/rest/api/2/universal_avatar/view/type/issuetype/avatar/10315?size=medium"
/>
        Story
      </td>
    </tr>
    <tr>
      <th>Status</th>
      <td>Ready for Testing</td>
    </tr>
    <tr>
      <th>Points</th>
      <td>N/A</td>
    </tr>
    <tr>
      <th>Labels</th>
      <td>-</td>
    </tr>
  </table>
</details>
<!--
  do not remove this marker as it will break jira-lint's functionality.
  added_by_jira_lint
-->

---

<!-- Provide a general summary of your changes in the Title above -->

## Description

<!-- Describe your changes in detail -->

## Related Issue

<!-- This project only accepts pull requests related to open issues. -->
<!-- If suggesting a new feature or change, please discuss it in an
issue first. -->
<!-- If fixing a bug, there should be an issue describing it with steps
to reproduce. -->
<!-- OSS: Please link to the issue here. Tyk: please create/link the
JIRA ticket. -->

## Motivation and Context

<!-- Why is this change required? What problem does it solve? -->

## How This Has Been Tested

<!-- Please describe in detail how you tested your changes -->
<!-- Include details of your testing environment, and the tests -->
<!-- you ran to see how your change affects other areas of the code,
etc. -->
<!-- This information is helpful for reviewers and QA. -->

## Screenshots (if appropriate)

## Types of changes

<!-- What types of changes does your code introduce? Put an `x` in all
the boxes that apply: -->

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Refactoring or add test (improvements in base code or adds test
coverage to functionality)

## Checklist

<!-- Go over all the following points, and put an `x` in all the boxes
that apply -->
<!-- If there are no documentation updates required, mark the item as
checked. -->
<!-- Raise up any additional concerns not covered by the checklist. -->

- [ ] I ensured that the documentation is up to date
- [ ] I explained why this PR updates go.mod in detail with reasoning
why it's required
- [ ] I would like a code coverage CI quality gate exception and have
explained why


___

### **PR Type**
enhancement, bug fix


___

### **Description**
- Introduced a new `TokenData` struct to encapsulate token and extra
metadata, improving the structure and readability of the code.
- Added functions `createTokenDataBytes` and `unmarshalTokenData` to
handle JSON marshaling and unmarshaling of token data, enhancing the
robustness of data handling.
- Modified the caching mechanism to store both token and metadata,
ensuring that metadata is preserved and accessible when retrieving
tokens from the cache.
- Improved the handling of extra metadata by using a map, allowing for
more flexible and efficient metadata management.



___



### **Changes walkthrough** 📝
<table><thead><tr><th></th><th align="left">Relevant
files</th></tr></thead><tbody><tr><td><strong>Enhancement</strong></td><td><table>
<tr>
  <td>
    <details>
<summary><strong>mw_oauth2_auth.go</strong><dd><code>Improve token
caching and metadata handling in OAuth2</code>&nbsp; &nbsp; &nbsp;
&nbsp; </dd></summary>
<hr>

gateway/mw_oauth2_auth.go

<li>Introduced a new <code>TokenData</code> struct to handle token and
metadata.<br> <li> Added functions to marshal and unmarshal token
data.<br> <li> Modified caching logic to store and retrieve token data
with metadata.<br> <li> Enhanced metadata handling by using a map for
extra metadata.<br>


</details>


  </td>
<td><a
href="https://github.com/TykTechnologies/tyk/pull/6689/files#diff-a90347c3ad28f06a7bd1c5554ce63448774cb486cf4e9961af2323423ce8209d">+85/-17</a>&nbsp;
</td>

</tr>                    
</table></td></tr></tr></tbody></table>

___

> 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull
request to receive relevant information
  • Loading branch information
andrei-tyk authored Nov 1, 2024
1 parent 4a14e3a commit 85e8a94
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 16 deletions.
84 changes: 68 additions & 16 deletions gateway/mw_oauth2_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -57,13 +58,18 @@ type upstreamOAuthPasswordCache struct {
func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
cacheKey := generatePasswordOAuthCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID)

tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
if err != nil {
return "", err
}

if tokenString != "" {
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString)
if tokenData != "" {
tokenContents, err := unmarshalTokenData(tokenData)
if err != nil {
return "", err
}
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, tokenContents.ExtraMetadata)
return decryptedToken, nil
}

Expand All @@ -73,10 +79,15 @@ func (cache *upstreamOAuthPasswordCache) getToken(r *http.Request, OAuthSpec *Up
}

encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, token)
tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata)
if err != nil {
return "", err
}
metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.PasswordAuthentication.ExtraMetadata, metadataMap)

ttl := time.Until(token.Expiry)
if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil {
if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil {
return "", err
}

Expand Down Expand Up @@ -271,16 +282,26 @@ func generateClientCredentialsCacheKey(config apidef.UpstreamOAuth, apiId string
return hex.EncodeToString(hash.Sum(nil))
}

type TokenData struct {
Token string `json:"token"`
ExtraMetadata map[string]interface{} `json:"extra_metadata"`
}

func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
cacheKey := generateClientCredentialsCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID)

tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
tokenData, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
if err != nil {
return "", err
}

if tokenString != "" {
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenString)
if tokenData != "" {
tokenContents, err := unmarshalTokenData(tokenData)
if err != nil {
return "", err
}
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), tokenContents.Token)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, tokenContents.ExtraMetadata)
return decryptedToken, nil
}

Expand All @@ -290,24 +311,55 @@ func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAut
}

encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw.GetConfig().Secret), token.AccessToken)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, token)
tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata)
if err != nil {
return "", err
}
metadataMap := buildMetadataMap(token, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata)
setExtraMetadata(r, OAuthSpec.Spec.UpstreamAuth.OAuth.ClientCredentials.ExtraMetadata, metadataMap)

ttl := time.Until(token.Expiry)
if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil {
if err := setTokenInCache(cacheKey, string(tokenDataBytes), ttl, &cache.RedisCluster); err != nil {
return "", err
}

return token.AccessToken, nil
}

func setExtraMetadata(r *http.Request, keyList []string, token *oauth2.Token) {
func createTokenDataBytes(encryptedToken string, token *oauth2.Token, extraMetadataKeys []string) ([]byte, error) {
td := TokenData{
Token: encryptedToken,
ExtraMetadata: buildMetadataMap(token, extraMetadataKeys),
}
return json.Marshal(td)
}

func unmarshalTokenData(tokenData string) (TokenData, error) {
var tokenContents TokenData
err := json.Unmarshal([]byte(tokenData), &tokenContents)
if err != nil {
return TokenData{}, fmt.Errorf("failed to unmarshal token data: %w", err)
}
return tokenContents, nil
}

func buildMetadataMap(token *oauth2.Token, extraMetadataKeys []string) map[string]interface{} {
metadataMap := make(map[string]interface{})
for _, key := range extraMetadataKeys {
if val := token.Extra(key); val != "" && val != nil {
metadataMap[key] = val
}
}
return metadataMap
}

func setExtraMetadata(r *http.Request, keyList []string, token map[string]interface{}) {
contextDataObject := ctxGetData(r)
if contextDataObject == nil {
contextDataObject = make(map[string]interface{})
}
for _, key := range keyList {
val := token.Extra(key)
if val != "" {
if val, ok := token[key]; ok && val != "" {
contextDataObject[key] = val
}
}
Expand All @@ -318,13 +370,13 @@ func retryGetKeyAndLock(cacheKey string, cache *storage.RedisCluster) (string, e
const maxRetries = 10
const retryDelay = 100 * time.Millisecond

var token string
var tokenData string
var err error

for i := 0; i < maxRetries; i++ {
token, err = cache.GetKey(cacheKey)
tokenData, err = cache.GetKey(cacheKey)
if err == nil {
return token, nil
return tokenData, nil
}

lockKey := cacheKey + ":lock"
Expand Down
138 changes: 138 additions & 0 deletions gateway/mw_oauth2_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"golang.org/x/oauth2"

"github.com/stretchr/testify/assert"

Expand All @@ -19,7 +22,13 @@ func TestUpstreamOauth2(t *testing.T) {
tst := StartTest(nil)
t.Cleanup(tst.Close)

var requestCount int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if requestCount > 0 {
assert.Fail(t, "Unexpected request received.")
}
requestCount++
if r.URL.String() != "/token" {
assert.Fail(t, "authenticate client request URL = %q; want %q", r.URL, "/token")
}
Expand Down Expand Up @@ -90,6 +99,23 @@ func TestUpstreamOauth2(t *testing.T) {
return true
},
},
{
Path: "/upstream-oauth-distributed/",
Code: http.StatusOK,
BodyMatchFunc: func(body []byte) bool {
resp := struct {
Headers map[string]string `json:"headers"`
}{}
err := json.Unmarshal(body, &resp)
assert.NoError(t, err)

assert.Contains(t, resp.Headers, header.Authorization)
assert.NotEmpty(t, resp.Headers[header.Authorization])
assert.Equal(t, "Bearer 90d64460d14870c08c81352a05dedd3465940a7c", resp.Headers[header.Authorization])

return true
},
},
}...)

}
Expand All @@ -98,8 +124,13 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
tst := StartTest(nil)
t.Cleanup(tst.Close)

var requestCount int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if requestCount > 0 {
assert.Fail(t, "Unexpected request received.")
}
requestCount++
expected := "/token"
if r.URL.String() != expected {
assert.Fail(t, "URL = %q; want %q", r.URL, expected)
Expand Down Expand Up @@ -174,5 +205,112 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
return true
},
},
{
Path: "/upstream-oauth-password/",
Code: http.StatusOK,
BodyMatchFunc: func(body []byte) bool {
resp := struct {
Headers map[string]string `json:"headers"`
}{}
err := json.Unmarshal(body, &resp)
assert.NoError(t, err)

assert.Contains(t, resp.Headers, header.Authorization)
assert.NotEmpty(t, resp.Headers[header.Authorization])
assert.Equal(t, "Bearer 90d64460d14870c08c81352a05dedd3465940a7c", resp.Headers[header.Authorization])

return true
},
},
}...)
}

func TestSetExtraMetadata(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://tykxample.com", nil)

keyList := []string{"key1", "key2"}
token := map[string]interface{}{
"key1": "value1",
"key2": "value2",
"key3": "value3",
}

setExtraMetadata(req, keyList, token)

contextData := ctxGetData(req)

assert.Equal(t, "value1", contextData["key1"])
assert.Equal(t, "value2", contextData["key2"])
assert.NotContains(t, contextData, "key3")
}

func TestBuildMetadataMap(t *testing.T) {
token := &oauth2.Token{
AccessToken: "tyk_upstream_oauth_access_token",
TokenType: "Bearer",
Expiry: time.Now().Add(time.Hour),
}
token = token.WithExtra(map[string]interface{}{
"key1": "value1",
"key2": "value2",
"key3": "",
})
extraMetadataKeys := []string{"key1", "key2", "key3", "key4"}

metadataMap := buildMetadataMap(token, extraMetadataKeys)

assert.Equal(t, "value1", metadataMap["key1"])
assert.Equal(t, "value2", metadataMap["key2"])
assert.NotContains(t, metadataMap, "key3")
assert.NotContains(t, metadataMap, "key4")
}

func TestCreateTokenDataBytes(t *testing.T) {
token := &oauth2.Token{
AccessToken: "tyk_upstream_oauth_access_token",
TokenType: "Bearer",
Expiry: time.Now().Add(time.Hour),
}
token = token.WithExtra(map[string]interface{}{
"key1": "value1",
"key2": "value2",
"key3": "",
})

extraMetadataKeys := []string{"key1", "key2", "key3", "key4"}

encryptedToken := "encrypted_tyk_upstream_oauth_access_token"
tokenDataBytes, err := createTokenDataBytes(encryptedToken, token, extraMetadataKeys)

assert.NoError(t, err)

var tokenData TokenData
err = json.Unmarshal(tokenDataBytes, &tokenData)
assert.NoError(t, err)

assert.Equal(t, encryptedToken, tokenData.Token)
assert.Equal(t, "value1", tokenData.ExtraMetadata["key1"])
assert.Equal(t, "value2", tokenData.ExtraMetadata["key2"])
assert.NotContains(t, tokenData.ExtraMetadata, "key3")
assert.NotContains(t, tokenData.ExtraMetadata, "key4")
}

func TestUnmarshalTokenData(t *testing.T) {
tokenData := TokenData{
Token: "tyk_upstream_oauth_access_token",
ExtraMetadata: map[string]interface{}{
"key1": "value1",
"key2": "value2",
},
}

tokenDataBytes, err := json.Marshal(tokenData)
assert.NoError(t, err)

result, err := unmarshalTokenData(string(tokenDataBytes))

assert.NoError(t, err)

assert.Equal(t, tokenData.Token, result.Token)
assert.Equal(t, tokenData.ExtraMetadata, result.ExtraMetadata)
}

0 comments on commit 85e8a94

Please sign in to comment.