Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/82/invalidate-jwt #91

Merged
merged 2 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- +goose Up
-- +goose StatementBegin
CREATE EXTENSION IF NOT EXISTS pgcrypto;

ALTER TABLE users ADD COLUMN token_salt BYTEA;
AmirAgassi marked this conversation as resolved.
Show resolved Hide resolved

-- backfill existing users with random salt
UPDATE users SET token_salt = gen_random_bytes(32);

-- make token_salt non-nullable
ALTER TABLE users ALTER COLUMN token_salt SET NOT NULL;
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
ALTER TABLE users DROP COLUMN token_salt;

-- +goose StatementEnd
13 changes: 11 additions & 2 deletions backend/.sqlc/queries/users.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
INSERT INTO users (
email,
password_hash,
role
role,
token_salt
) VALUES (
$1, $2, $3
$1, $2, $3, gen_random_bytes(32)
) RETURNING *;

-- name: GetUserByEmail :one
Expand All @@ -18,3 +19,11 @@ WHERE id = $1 LIMIT 1;
-- name: UpdateUserEmailVerifiedStatus :exec
UPDATE users SET email_verified = $1
WHERE id = $2;

-- name: UpdateUserTokenSalt :exec
UPDATE users SET token_salt = gen_random_bytes(32)
WHERE id = $1;

-- name: GetUserTokenSalt :one
SELECT token_salt FROM users
WHERE id = $1;
1 change: 1 addition & 0 deletions backend/db/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 31 additions & 5 deletions backend/db/users.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backend/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
golang.org/x/net v0.28.0 // indirect
Expand Down
5 changes: 5 additions & 0 deletions backend/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
Expand Down
16 changes: 8 additions & 8 deletions backend/internal/jwt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ const (
)

// Generates JWT tokens for the given user. Returns the access token, refresh token and error (nil if no error)
func Generate(userID string, role db.UserRole) (string, string, error) {
accessToken, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, time.Now().Add(10*time.Minute))
func GenerateWithSalt(userID string, role db.UserRole, salt []byte) (string, string, error) {
accessToken, err := generateTokenWithSalt(userID, role, ACCESS_TOKEN_TYPE, time.Now().Add(10*time.Minute), salt)
if err != nil {
return "", "", err
}

refreshToken, err := generateToken(userID, role, REFRESH_TOKEN_TYPE, time.Now().Add(24*7*time.Hour))
refreshToken, err := generateTokenWithSalt(userID, role, REFRESH_TOKEN_TYPE, time.Now().Add(24*7*time.Hour), salt)
if err != nil {
return "", "", err
}
Expand All @@ -40,7 +40,6 @@ func GenerateVerifyEmailToken(email string, id string, exp time.Time) (string, e
Email: email,
TokenType: VERIFY_EMAIL_TOKEN_TYPE,
RegisteredClaims: golangJWT.RegisteredClaims{
// expire in 1 week
ExpiresAt: golangJWT.NewNumericDate(exp),
IssuedAt: golangJWT.NewNumericDate(time.Now()),
ID: id,
Expand All @@ -51,19 +50,20 @@ func GenerateVerifyEmailToken(email string, id string, exp time.Time) (string, e
return token.SignedString([]byte(os.Getenv("JWT_SECRET_VERIFY_EMAIL")))
}

// Private helper method to generate a token.
func generateToken(userID string, role db.UserRole, tokenType string, exp time.Time) (string, error) {
// Private helper method to generate a token with user's salt
func generateTokenWithSalt(userID string, role db.UserRole, tokenType string, exp time.Time, salt []byte) (string, error) {
claims := JWTClaims{
UserID: userID,
Role: role,
TokenType: tokenType,
RegisteredClaims: golangJWT.RegisteredClaims{
// expire in 1 week
ExpiresAt: golangJWT.NewNumericDate(exp),
IssuedAt: golangJWT.NewNumericDate(time.Now()),
},
}

token := golangJWT.NewWithClaims(golangJWT.SigningMethodHS256, claims)
return token.SignedString([]byte(os.Getenv("JWT_SECRET")))
// combine base secret with user's salt
secret := append([]byte(os.Getenv("JWT_SECRET")), salt...)
return token.SignedString(secret)
}
114 changes: 62 additions & 52 deletions backend/internal/jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"KonferCA/SPUR/db"

golangJWT "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)

Expand All @@ -18,69 +17,90 @@ func TestJWT(t *testing.T) {

userID := "some-user-id"
role := db.UserRole("user")
exp := time.Now().Add(5 * time.Minute)
salt := []byte("test-salt")

t.Run("generate access token", func(t *testing.T) {
token, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, exp)
t.Run("token salt invalidation", func(t *testing.T) {
// Generate initial salt
initialSalt := []byte("initial-salt")

// Generate tokens with initial salt
accessToken, refreshToken, err := GenerateWithSalt(userID, role, initialSalt)
assert.Nil(t, err)
assert.NotEmpty(t, token)
claims, err := VerifyToken(token)
assert.NotEmpty(t, accessToken)
assert.NotEmpty(t, refreshToken)

// Verify tokens work with initial salt
claims, err := VerifyTokenWithSalt(accessToken, initialSalt)
assert.Nil(t, err)
assert.Equal(t, claims.UserID, userID)
assert.Equal(t, claims.Role, role)
assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE)
assert.Equal(t, claims.RegisteredClaims.ExpiresAt, golangJWT.NewNumericDate(exp))
})

t.Run("generate refresh token", func(t *testing.T) {
token, err := generateToken(userID, role, REFRESH_TOKEN_TYPE, exp)
// Change salt (simulating token invalidation)
newSalt := []byte("new-salt")

// Old tokens should fail verification with new salt
_, err = VerifyTokenWithSalt(accessToken, newSalt)
assert.NotNil(t, err, "Token should be invalid with new salt")

// Generate new tokens with new salt
newAccessToken, newRefreshToken, err := GenerateWithSalt(userID, role, newSalt)
assert.Nil(t, err)
assert.NotEmpty(t, token)
claims, err := VerifyToken(token)
assert.NotEmpty(t, newAccessToken)
assert.NotEmpty(t, newRefreshToken)

// New tokens should work with new salt
claims, err = VerifyTokenWithSalt(newAccessToken, newSalt)
assert.Nil(t, err)
assert.Equal(t, claims.UserID, userID)
assert.Equal(t, claims.Role, role)
assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE)
assert.Equal(t, claims.RegisteredClaims.ExpiresAt, golangJWT.NewNumericDate(exp))
})

t.Run("generate both refresh and access token", func(t *testing.T) {
a, r, err := Generate(userID, role)
t.Run("two-step verification", func(t *testing.T) {
salt := []byte("test-salt")

// Generate a token
accessToken, _, err := GenerateWithSalt(userID, role, salt)
assert.Nil(t, err)
assert.NotEmpty(t, a)
assert.NotEmpty(t, r)
claims, err := VerifyToken(a)

// Step 1: Parse claims without verification
unverifiedClaims, err := ParseUnverifiedClaims(accessToken)
assert.Nil(t, err)
assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE)
claims, err = VerifyToken(r)
assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE)
assert.Equal(t, userID, unverifiedClaims.UserID)

// Step 2: Verify with salt
verifiedClaims, err := VerifyTokenWithSalt(accessToken, salt)
assert.Nil(t, err)
assert.Equal(t, userID, verifiedClaims.UserID)

// Try to verify with wrong salt
wrongSalt := []byte("wrong-salt")
_, err = VerifyTokenWithSalt(accessToken, wrongSalt)
assert.NotNil(t, err, "Token should be invalid with wrong salt")
})

t.Run("deny token with wrong signature", func(t *testing.T) {
a, r, err := Generate(userID, role)
t.Run("generate access token", func(t *testing.T) {
accessToken, _, err := GenerateWithSalt(userID, role, salt)
assert.Nil(t, err)
// change secret
os.Setenv("JWT_SECRET", "changed")
_, err = VerifyToken(a)
assert.NotNil(t, err)
// restore error to nil
err = nil
// test the other token
_, err = VerifyToken(r)
assert.NotNil(t, err)
// restore secret
os.Setenv("JWT_SECRET", "secret")
assert.NotEmpty(t, accessToken)
claims, err := VerifyTokenWithSalt(accessToken, salt)
assert.Nil(t, err)
assert.Equal(t, claims.UserID, userID)
assert.Equal(t, claims.Role, role)
assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE)
})

t.Run("deny expired token", func(t *testing.T) {
exp = time.Now().Add(-1 * 5 * time.Minute)
token, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, exp)
t.Run("generate refresh token", func(t *testing.T) {
_, refreshToken, err := GenerateWithSalt(userID, role, salt)
assert.Nil(t, err)
_, err = VerifyToken(token)
assert.NotNil(t, err)
assert.NotEmpty(t, refreshToken)
claims, err := VerifyTokenWithSalt(refreshToken, salt)
assert.Nil(t, err)
assert.Equal(t, claims.UserID, userID)
assert.Equal(t, claims.Role, role)
assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE)
})

t.Run("generate verify email token", func(t *testing.T) {
t.Run("verify email token", func(t *testing.T) {
email := "[email protected]"
id := "some-id"
exp := time.Now().Add(time.Second * 5)
Expand All @@ -102,14 +122,4 @@ func TestJWT(t *testing.T) {
_, err = VerifyEmailToken(token)
assert.NotNil(t, err)
})

t.Run("deny expired verify email token", func(t *testing.T) {
email := "[email protected]"
id := "some-id"
exp := time.Now().Add(-1 * 5 * time.Second)
token, err := GenerateVerifyEmailToken(email, id, exp)
assert.Nil(t, err)
_, err = VerifyEmailToken(token)
assert.NotNil(t, err)
})
}
Loading
Loading