Skip to content

Commit

Permalink
test: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
james-d-elliott committed Feb 15, 2024
1 parent 1f87b31 commit 760314f
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 1 deletion.
4 changes: 4 additions & 0 deletions handler/oauth2/introspector_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ func IsJWTProfileAccessToken(token *jwt.Token) bool {
ok bool
)

if token == nil {
return false
}

if raw, ok = token.Header[string(jwt.JWTHeaderType)]; !ok {
return false
}
Expand Down
78 changes: 78 additions & 0 deletions handler/oauth2/introspector_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ func TestIntrospectJWT(t *testing.T) {
return token
},
},
{
description: "should fail bad typ",
token: func() string {
jwt := jwtValidCase(fosite.AccessToken)

s := jwt.Session.(*JWTSession)

s.JWTHeader.Extra["typ"] = "JWT"

jwt.Session = s

token, _, err := strat.GenerateAccessToken(context.Background(), jwt)
assert.NoError(t, err)
return token
},
expectErr: fosite.ErrRequestUnauthorized,
},
} {
t.Run(fmt.Sprintf("case=%d:%v", k, c.description), func(t *testing.T) {
if c.scopes == nil {
Expand Down Expand Up @@ -142,3 +159,64 @@ func BenchmarkIntrospectJWT(b *testing.B) {

assert.NoError(b, err)
}

func TestIsJWTProfileAccessToken(t *testing.T) {
testCases := []struct {
name string
have *jwt.Token
expected bool
}{
{
"ShouldPassTypATJWT",
&jwt.Token{
Header: map[string]interface{}{
"typ": "at+jwt",
},
},
true,
},
{
"ShouldPassTypApplicationATJWT",
&jwt.Token{
Header: map[string]interface{}{
"typ": "application/at+jwt",
},
},
true,
},
{
"ShouldFailJWT",
&jwt.Token{
Header: map[string]interface{}{
"typ": "JWT",
},
},
false,
},
{
"ShouldFailNoValue",
&jwt.Token{
Header: map[string]interface{}{},
},
false,
},
{
"ShouldFailNilValue",
&jwt.Token{
Header: nil,
},
false,
},
{
"ShouldFailNilInput",
nil,
false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, IsJWTProfileAccessToken(tc.have))
})
}
}
9 changes: 8 additions & 1 deletion handler/oauth2/strategy_jwt_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@ func (j *JWTSession) GetJWTClaims() jwt.JWTClaimsContainer {

func (j *JWTSession) GetJWTHeader() *jwt.Headers {
if j.JWTHeader == nil {
j.JWTHeader = &jwt.Headers{}
j.JWTHeader = &jwt.Headers{
Extra: map[string]interface{}{
"typ": "at+jwt",
},
}
} else if j.JWTHeader.Extra["typ"] == nil {
j.JWTHeader.Extra["typ"] = "at+jwt"
}

return j.JWTHeader
}

Expand Down
46 changes: 46 additions & 0 deletions handler/oauth2/strategy_jwt_session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package oauth2

import (
"github.com/ory/fosite/token/jwt"
"github.com/stretchr/testify/assert"
"testing"
)

func TestJWTSession_GetJWTHeader(t *testing.T) {
testCases := []struct {
name string
have *JWTSession
expected string
}{
{
"ShouldReturnDefaultTyp",
&JWTSession{},
"at+jwt",
},
{
"ShouldReturnConfiguredATJWTTyp",
&JWTSession{JWTHeader: &jwt.Headers{Extra: map[string]interface{}{
"typ": "at+jwt",
}}},
"at+jwt",
},
{
"ShouldReturnConfiguredJWTTyp",
&JWTSession{JWTHeader: &jwt.Headers{Extra: map[string]interface{}{
"typ": "JWT",
}}},
"JWT",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
header := tc.have.GetJWTHeader()

assert.Equal(t, tc.expected, header.Get("typ"))
})
}
}
10 changes: 10 additions & 0 deletions handler/oauth2/strategy_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,16 @@ func TestAccessToken(t *testing.T) {
require.Len(t, parts, 3, "%s - %v", token, parts)
assert.Equal(t, parts[2], signature)

rawHeader, err := base64.RawURLEncoding.DecodeString(parts[0])
require.NoError(t, err)

var header map[string]interface{}
require.NoError(t, json.Unmarshal(rawHeader, &header))

typ, ok := header["typ"]
assert.True(t, ok)
assert.Equal(t, "at+jwt", typ)

rawPayload, err := base64.RawURLEncoding.DecodeString(parts[1])
require.NoError(t, err)
var payload map[string]interface{}
Expand Down

0 comments on commit 760314f

Please sign in to comment.