Skip to content

Commit

Permalink
Feature/82/invalidate-jwt (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
AmirAgassi authored Dec 7, 2024
2 parents e49459d + bb5b197 commit 8c5a785
Show file tree
Hide file tree
Showing 13 changed files with 261 additions and 154 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- +goose Up
-- +goose StatementBegin
CREATE EXTENSION IF NOT EXISTS pgcrypto;

ALTER TABLE users ADD COLUMN token_salt BYTEA;

-- backfill existing users with random salt
UPDATE users SET token_salt = gen_random_bytes(32);

-- make token_salt non-nullable and unique
ALTER TABLE users ALTER COLUMN token_salt SET NOT NULL;
ALTER TABLE users ADD CONSTRAINT users_token_salt_key UNIQUE (token_salt);
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
ALTER TABLE users DROP COLUMN token_salt;

-- +goose StatementEnd
13 changes: 11 additions & 2 deletions backend/.sqlc/queries/users.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
INSERT INTO users (
email,
password_hash,
role
role,
token_salt
) VALUES (
$1, $2, $3
$1, $2, $3, gen_random_bytes(32)
) RETURNING *;

-- name: GetUserByEmail :one
Expand All @@ -18,3 +19,11 @@ WHERE id = $1 LIMIT 1;
-- name: UpdateUserEmailVerifiedStatus :exec
UPDATE users SET email_verified = $1
WHERE id = $2;

-- name: UpdateUserTokenSalt :exec
UPDATE users SET token_salt = gen_random_bytes(32)
WHERE id = $1;

-- name: GetUserTokenSalt :one
SELECT token_salt FROM users
WHERE id = $1;
1 change: 1 addition & 0 deletions backend/db/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 31 additions & 5 deletions backend/db/users.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backend/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
golang.org/x/net v0.28.0 // indirect
Expand Down
5 changes: 5 additions & 0 deletions backend/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
Expand Down
16 changes: 8 additions & 8 deletions backend/internal/jwt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ const (
)

// Generates JWT tokens for the given user. Returns the access token, refresh token and error (nil if no error)
func Generate(userID string, role db.UserRole) (string, string, error) {
accessToken, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, time.Now().Add(10*time.Minute))
func GenerateWithSalt(userID string, role db.UserRole, salt []byte) (string, string, error) {
accessToken, err := generateTokenWithSalt(userID, role, ACCESS_TOKEN_TYPE, time.Now().Add(10*time.Minute), salt)
if err != nil {
return "", "", err
}

refreshToken, err := generateToken(userID, role, REFRESH_TOKEN_TYPE, time.Now().Add(24*7*time.Hour))
refreshToken, err := generateTokenWithSalt(userID, role, REFRESH_TOKEN_TYPE, time.Now().Add(24*7*time.Hour), salt)
if err != nil {
return "", "", err
}
Expand All @@ -40,7 +40,6 @@ func GenerateVerifyEmailToken(email string, id string, exp time.Time) (string, e
Email: email,
TokenType: VERIFY_EMAIL_TOKEN_TYPE,
RegisteredClaims: golangJWT.RegisteredClaims{
// expire in 1 week
ExpiresAt: golangJWT.NewNumericDate(exp),
IssuedAt: golangJWT.NewNumericDate(time.Now()),
ID: id,
Expand All @@ -51,19 +50,20 @@ func GenerateVerifyEmailToken(email string, id string, exp time.Time) (string, e
return token.SignedString([]byte(os.Getenv("JWT_SECRET_VERIFY_EMAIL")))
}

// Private helper method to generate a token.
func generateToken(userID string, role db.UserRole, tokenType string, exp time.Time) (string, error) {
// Private helper method to generate a token with user's salt
func generateTokenWithSalt(userID string, role db.UserRole, tokenType string, exp time.Time, salt []byte) (string, error) {
claims := JWTClaims{
UserID: userID,
Role: role,
TokenType: tokenType,
RegisteredClaims: golangJWT.RegisteredClaims{
// expire in 1 week
ExpiresAt: golangJWT.NewNumericDate(exp),
IssuedAt: golangJWT.NewNumericDate(time.Now()),
},
}

token := golangJWT.NewWithClaims(golangJWT.SigningMethodHS256, claims)
return token.SignedString([]byte(os.Getenv("JWT_SECRET")))
// combine base secret with user's salt
secret := append([]byte(os.Getenv("JWT_SECRET")), salt...)
return token.SignedString(secret)
}
114 changes: 62 additions & 52 deletions backend/internal/jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"KonferCA/SPUR/db"

golangJWT "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)

Expand All @@ -18,69 +17,90 @@ func TestJWT(t *testing.T) {

userID := "some-user-id"
role := db.UserRole("user")
exp := time.Now().Add(5 * time.Minute)
salt := []byte("test-salt")

t.Run("generate access token", func(t *testing.T) {
token, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, exp)
t.Run("token salt invalidation", func(t *testing.T) {
// Generate initial salt
initialSalt := []byte("initial-salt")

// Generate tokens with initial salt
accessToken, refreshToken, err := GenerateWithSalt(userID, role, initialSalt)
assert.Nil(t, err)
assert.NotEmpty(t, token)
claims, err := VerifyToken(token)
assert.NotEmpty(t, accessToken)
assert.NotEmpty(t, refreshToken)

// Verify tokens work with initial salt
claims, err := VerifyTokenWithSalt(accessToken, initialSalt)
assert.Nil(t, err)
assert.Equal(t, claims.UserID, userID)
assert.Equal(t, claims.Role, role)
assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE)
assert.Equal(t, claims.RegisteredClaims.ExpiresAt, golangJWT.NewNumericDate(exp))
})

t.Run("generate refresh token", func(t *testing.T) {
token, err := generateToken(userID, role, REFRESH_TOKEN_TYPE, exp)
// Change salt (simulating token invalidation)
newSalt := []byte("new-salt")

// Old tokens should fail verification with new salt
_, err = VerifyTokenWithSalt(accessToken, newSalt)
assert.NotNil(t, err, "Token should be invalid with new salt")

// Generate new tokens with new salt
newAccessToken, newRefreshToken, err := GenerateWithSalt(userID, role, newSalt)
assert.Nil(t, err)
assert.NotEmpty(t, token)
claims, err := VerifyToken(token)
assert.NotEmpty(t, newAccessToken)
assert.NotEmpty(t, newRefreshToken)

// New tokens should work with new salt
claims, err = VerifyTokenWithSalt(newAccessToken, newSalt)
assert.Nil(t, err)
assert.Equal(t, claims.UserID, userID)
assert.Equal(t, claims.Role, role)
assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE)
assert.Equal(t, claims.RegisteredClaims.ExpiresAt, golangJWT.NewNumericDate(exp))
})

t.Run("generate both refresh and access token", func(t *testing.T) {
a, r, err := Generate(userID, role)
t.Run("two-step verification", func(t *testing.T) {
salt := []byte("test-salt")

// Generate a token
accessToken, _, err := GenerateWithSalt(userID, role, salt)
assert.Nil(t, err)
assert.NotEmpty(t, a)
assert.NotEmpty(t, r)
claims, err := VerifyToken(a)

// Step 1: Parse claims without verification
unverifiedClaims, err := ParseUnverifiedClaims(accessToken)
assert.Nil(t, err)
assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE)
claims, err = VerifyToken(r)
assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE)
assert.Equal(t, userID, unverifiedClaims.UserID)

// Step 2: Verify with salt
verifiedClaims, err := VerifyTokenWithSalt(accessToken, salt)
assert.Nil(t, err)
assert.Equal(t, userID, verifiedClaims.UserID)

// Try to verify with wrong salt
wrongSalt := []byte("wrong-salt")
_, err = VerifyTokenWithSalt(accessToken, wrongSalt)
assert.NotNil(t, err, "Token should be invalid with wrong salt")
})

t.Run("deny token with wrong signature", func(t *testing.T) {
a, r, err := Generate(userID, role)
t.Run("generate access token", func(t *testing.T) {
accessToken, _, err := GenerateWithSalt(userID, role, salt)
assert.Nil(t, err)
// change secret
os.Setenv("JWT_SECRET", "changed")
_, err = VerifyToken(a)
assert.NotNil(t, err)
// restore error to nil
err = nil
// test the other token
_, err = VerifyToken(r)
assert.NotNil(t, err)
// restore secret
os.Setenv("JWT_SECRET", "secret")
assert.NotEmpty(t, accessToken)
claims, err := VerifyTokenWithSalt(accessToken, salt)
assert.Nil(t, err)
assert.Equal(t, claims.UserID, userID)
assert.Equal(t, claims.Role, role)
assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE)
})

t.Run("deny expired token", func(t *testing.T) {
exp = time.Now().Add(-1 * 5 * time.Minute)
token, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, exp)
t.Run("generate refresh token", func(t *testing.T) {
_, refreshToken, err := GenerateWithSalt(userID, role, salt)
assert.Nil(t, err)
_, err = VerifyToken(token)
assert.NotNil(t, err)
assert.NotEmpty(t, refreshToken)
claims, err := VerifyTokenWithSalt(refreshToken, salt)
assert.Nil(t, err)
assert.Equal(t, claims.UserID, userID)
assert.Equal(t, claims.Role, role)
assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE)
})

t.Run("generate verify email token", func(t *testing.T) {
t.Run("verify email token", func(t *testing.T) {
email := "[email protected]"
id := "some-id"
exp := time.Now().Add(time.Second * 5)
Expand All @@ -102,14 +122,4 @@ func TestJWT(t *testing.T) {
_, err = VerifyEmailToken(token)
assert.NotNil(t, err)
})

t.Run("deny expired verify email token", func(t *testing.T) {
email := "[email protected]"
id := "some-id"
exp := time.Now().Add(-1 * 5 * time.Second)
token, err := GenerateVerifyEmailToken(email, id, exp)
assert.Nil(t, err)
_, err = VerifyEmailToken(token)
assert.NotNil(t, err)
})
}
Loading

0 comments on commit 8c5a785

Please sign in to comment.