From a1ec282d02706e074bc4986fd0412e5da3b9d00a Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Wed, 13 Jul 2022 10:37:06 -0700 Subject: [PATCH] Merge pull request from GHSA-qwrj-9hmp-gpxh * Fix claims verification for access tokens in external IdP setup Signed-off-by: Haytham Abuelfutuh * Add another test case for no signature Signed-off-by: Haytham Abuelfutuh --- auth/authzserver/authorize_test.go | 3 - auth/authzserver/resource_server.go | 15 ++- auth/authzserver/resource_server_test.go | 134 ++++++++++++++++++++--- 3 files changed, 129 insertions(+), 23 deletions(-) diff --git a/auth/authzserver/authorize_test.go b/auth/authzserver/authorize_test.go index d481283df..99a01d58a 100644 --- a/auth/authzserver/authorize_test.go +++ b/auth/authzserver/authorize_test.go @@ -65,9 +65,6 @@ func TestAuthEndpoint(t *testing.T) { }) } -// #nosec -const sampleIDToken = `eyJraWQiOiJaNmRtWl9UWGhkdXctalVCWjZ1RUV6dm5oLWpoTk8wWWhlbUI3cWFfTE9jIiwiYWxnIjoiUlMyNTYifQ.eyJzdWIiOiIwMHVra2k0OHBzSDhMaWtZVjVkNiIsIm5hbWUiOiJIYXl0aGFtIEFidWVsZnV0dWgiLCJ2ZXIiOjEsImlzcyI6Imh0dHBzOi8vZGV2LTE0MTg2NDIyLm9rdGEuY29tL29hdXRoMi9hdXNrbmdubjd1QlZpUXE2YjVkNiIsImF1ZCI6IjBvYWtraGV0ZU5qQ01FUnN0NWQ2IiwiaWF0IjoxNjE4NDUzNjc5LCJleHAiOjE2MTg0NTcyNzksImp0aSI6IklELmE0YXpLdUphVFM2YzNTeHdpWWdTMHhPbTM2bVFnVlVVN0I4V2dEdk80dFkiLCJhbXIiOlsicHdkIl0sImlkcCI6IjBvYWtrbTFjaTFVZVBwTlUwNWQ2IiwicHJlZmVycmVkX3VzZXJuYW1lIjoiaGF5dGhhbUB1bmlvbi5haSIsImF1dGhfdGltZSI6MTYxODQ0NjI0NywiYXRfaGFzaCI6Ikg5Q0FweWlrQkpGYXJ4d1FUbnB6ZFEifQ.SJ3BTD_MFcrYvTnql181Ddeb_mOm81z_S7ZKQ6P8mMgWqn94LZ2nG8k8-_odaaNAAT-M1nAFKWqZAQGvliwS1_TsD8_j0cen5zYnGcz2Uu5fFlvoHwuPgy5JYYNOXkXYgPnIb3kNkgXKbkdjS9hdbMfvnPd9rr8v0yzqf0AQBnUe-cPrzY-ZJjvh80IWDZgSjoP244tTYppPkx8UtedJLJZ4tzB7aXlEyoRV-DpmOLfJkAmblRm4OsO1qjwmx3HSIy_T-0PANn-g4AS07rpoMYHRcqncdgcAsVfGxjyWiOg3kbymLqpGlkIZgzmev-TmpoDp0QkUVPOntuiB57GZ6g` - //func TestAuthCallbackEndpoint(t *testing.T) { // originalURL := "http://localhost:8088/oauth2/authorize?client_id=my-client&redirect_uri=http%3A%2F%2Flocalhost%3A3846%2Fcallback&response_type=code&scope=photos+openid+offline&state=some-random-state-foobar&nonce=some-random-nonce&code_challenge=p0v_UR0KrXl4--BpxM2BQa7qIW5k3k4WauBhjmkVQw8&code_challenge_method=S256" // req := httptest.NewRequest(http.MethodGet, originalURL, nil) diff --git a/auth/authzserver/resource_server.go b/auth/authzserver/resource_server.go index 78e895297..19f5dbfa3 100644 --- a/auth/authzserver/resource_server.go +++ b/auth/authzserver/resource_server.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + jwtgo "github.com/golang-jwt/jwt/v4" "io/ioutil" "mime" "net/http" @@ -28,17 +29,21 @@ type ResourceServer struct { } func (r ResourceServer) ValidateAccessToken(ctx context.Context, expectedAudience, tokenStr string) (interfaces.IdentityContext, error) { - raw, err := r.signatureVerifier.VerifySignature(ctx, tokenStr) + _, err := r.signatureVerifier.VerifySignature(ctx, tokenStr) if err != nil { return nil, err } - claimsRaw := map[string]interface{}{} - if err = json.Unmarshal(raw, &claimsRaw); err != nil { - return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err) + t, _, err := jwtgo.NewParser().ParseUnverified(tokenStr, jwtgo.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %v", err) + } + + if err = t.Claims.Valid(); err != nil { + return nil, fmt.Errorf("failed to validate token: %v", err) } - return verifyClaims(sets.NewString(append(r.allowedAudience, expectedAudience)...), claimsRaw) + return verifyClaims(sets.NewString(append(r.allowedAudience, expectedAudience)...), t.Claims.(jwtgo.MapClaims)) } func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) { diff --git a/auth/authzserver/resource_server_test.go b/auth/authzserver/resource_server_test.go index 002924dfc..e300f0857 100644 --- a/auth/authzserver/resource_server_test.go +++ b/auth/authzserver/resource_server_test.go @@ -2,6 +2,8 @@ package authzserver import ( "context" + "crypto/rand" + "crypto/rsa" "encoding/json" "io" "net/http" @@ -10,6 +12,9 @@ import ( "reflect" "strings" "testing" + "time" + + "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" @@ -21,7 +26,7 @@ import ( stdlibConfig "github.com/flyteorg/flytestdlib/config" ) -func newMockResourceServer(t testing.TB) ResourceServer { +func newMockResourceServer(t testing.TB, publicKey rsa.PublicKey) (resourceServer ResourceServer, closer func()) { ctx := context.Background() dummy := "" serverURL := &dummy @@ -29,12 +34,12 @@ func newMockResourceServer(t testing.TB) ResourceServer { if r.URL.Path == "/.well-known/oauth-authorization-server" { w.Header().Set("Content-Type", "application/json") _, err := io.WriteString(w, strings.ReplaceAll(`{ - "issuer": "https://dev-14186422.okta.com", + "issuer": "https://whatever.okta.com", "authorization_endpoint": "https://example.com/auth", "token_endpoint": "https://example.com/token", - "jwks_uri": "URL/keys", + "jwks_uri": "{URL}/keys", "id_token_signing_alg_values_supported": ["RS256"] - }`, "URL", *serverURL)) + }`, "{URL}", *serverURL)) if !assert.NoError(t, err) { t.FailNow() @@ -43,6 +48,14 @@ func newMockResourceServer(t testing.TB) ResourceServer { return } else if r.URL.Path == "/keys" { keys := jwk.NewSet() + key := jwk.NewRSAPublicKey() + err := key.FromRaw(&publicKey) + if err != nil { + http.Error(w, err.Error(), 400) + return + } + + keys.Add(key) raw, err := json.Marshal(keys) if err != nil { http.Error(w, err.Error(), 400) @@ -55,36 +68,127 @@ func newMockResourceServer(t testing.TB) ResourceServer { if !assert.NoError(t, err) { t.FailNow() } + + return } http.NotFound(w, r) } s := httptest.NewServer(http.HandlerFunc(hf)) - defer s.Close() - *serverURL = s.URL http.DefaultClient = s.Client() r, err := NewOAuth2ResourceServer(ctx, authConfig.ExternalAuthorizationServer{ - BaseURL: stdlibConfig.URL{URL: *config.MustParseURL(s.URL)}, + BaseURL: stdlibConfig.URL{URL: *config.MustParseURL(s.URL)}, + AllowedAudience: []string{"https://localhost"}, }, stdlibConfig.URL{}) if !assert.NoError(t, err) { t.FailNow() } - return r -} - -func TestNewOAuth2ResourceServer(t *testing.T) { - newMockResourceServer(t) + return r, func() { + s.Close() + } } func TestResourceServer_ValidateAccessToken(t *testing.T) { - r := newMockResourceServer(t) - _, err := r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken) - assert.Error(t, err) + sampleRSAKey, err := rsa.GenerateKey(rand.Reader, 2048) + if !assert.NoError(t, err) { + t.FailNow() + } + + r, closer := newMockResourceServer(t, sampleRSAKey.PublicKey) + defer closer() + + t.Run("No signature", func(t *testing.T) { + sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.RegisteredClaims{ + Audience: r.allowedAudience, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: "localhost", + Subject: "someone", + }).SignedString(sampleRSAKey) + if !assert.NoError(t, err) { + t.FailNow() + } + + parts := strings.Split(sampleIDToken, ".") + sampleIDToken = strings.Join(parts[:len(parts)-1], ".") + "." + + _, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken) + if !assert.Error(t, err) { + t.FailNow() + } + + assert.Contains(t, err.Error(), "failed to verify id token signature") + }) + + t.Run("Invalid signature", func(t *testing.T) { + sampleRSAKey, err := rsa.GenerateKey(rand.Reader, 2048) + if !assert.NoError(t, err) { + t.FailNow() + } + + sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.RegisteredClaims{ + Audience: r.allowedAudience, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: "localhost", + Subject: "someone", + }).SignedString(sampleRSAKey) + if !assert.NoError(t, err) { + t.FailNow() + } + + _, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken) + if !assert.Error(t, err) { + t.FailNow() + } + + assert.Contains(t, err.Error(), "failed to verify id token signature") + }) + + t.Run("Invalid audience", func(t *testing.T) { + sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.RegisteredClaims{ + Audience: []string{"https://hello world"}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: "localhost", + Subject: "someone", + }).SignedString(sampleRSAKey) + if !assert.NoError(t, err) { + t.FailNow() + } + + _, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken) + if !assert.Error(t, err) { + t.FailNow() + } + + assert.Contains(t, err.Error(), "invalid audience") + }) + + t.Run("Expired token", func(t *testing.T) { + sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.StandardClaims{ + Audience: r.allowedAudience[0], + ExpiresAt: time.Now().Add(-time.Hour).Unix(), + IssuedAt: time.Now().Add(-2 * time.Hour).Unix(), + Issuer: "localhost", + Subject: "someone", + }).SignedString(sampleRSAKey) + if !assert.NoError(t, err) { + t.FailNow() + } + + _, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken) + if !assert.Error(t, err) { + t.FailNow() + } + + assert.Contains(t, err.Error(), "failed to validate token: Token is expired") + }) } func Test_doRequest(t *testing.T) {