diff --git a/backend/.sqlc/migrations/20241219000000_add_token_salt_to_users.sql b/backend/.sqlc/migrations/20241219000000_add_token_salt_to_users.sql new file mode 100644 index 00000000..012f7e25 --- /dev/null +++ b/backend/.sqlc/migrations/20241219000000_add_token_salt_to_users.sql @@ -0,0 +1,19 @@ +-- +goose Up +-- +goose StatementBegin +CREATE EXTENSION IF NOT EXISTS pgcrypto; + +ALTER TABLE users ADD COLUMN token_salt BYTEA; + +-- backfill existing users with random salt +UPDATE users SET token_salt = gen_random_bytes(32); + +-- make token_salt non-nullable and unique +ALTER TABLE users ALTER COLUMN token_salt SET NOT NULL; +ALTER TABLE users ADD CONSTRAINT users_token_salt_key UNIQUE (token_salt); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE users DROP COLUMN token_salt; + +-- +goose StatementEnd \ No newline at end of file diff --git a/backend/.sqlc/queries/users.sql b/backend/.sqlc/queries/users.sql index cbbdc978..329c1992 100644 --- a/backend/.sqlc/queries/users.sql +++ b/backend/.sqlc/queries/users.sql @@ -2,9 +2,10 @@ INSERT INTO users ( email, password_hash, - role + role, + token_salt ) VALUES ( - $1, $2, $3 + $1, $2, $3, gen_random_bytes(32) ) RETURNING *; -- name: GetUserByEmail :one @@ -18,3 +19,11 @@ WHERE id = $1 LIMIT 1; -- name: UpdateUserEmailVerifiedStatus :exec UPDATE users SET email_verified = $1 WHERE id = $2; + +-- name: UpdateUserTokenSalt :exec +UPDATE users SET token_salt = gen_random_bytes(32) +WHERE id = $1; + +-- name: GetUserTokenSalt :one +SELECT token_salt FROM users +WHERE id = $1; diff --git a/backend/db/models.go b/backend/db/models.go index 6d619168..0e317fbc 100644 --- a/backend/db/models.go +++ b/backend/db/models.go @@ -235,6 +235,7 @@ type User struct { UpdatedAt pgtype.Timestamp Role UserRole EmailVerified bool + TokenSalt []byte } type VerifyEmailToken struct { diff --git a/backend/db/users.sql.go b/backend/db/users.sql.go index 29858623..535fdbb2 100644 --- a/backend/db/users.sql.go +++ b/backend/db/users.sql.go @@ -13,10 +13,11 @@ const createUser = `-- name: CreateUser :one INSERT INTO users ( email, password_hash, - role + role, + token_salt ) VALUES ( - $1, $2, $3 -) RETURNING id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role, email_verified + $1, $2, $3, gen_random_bytes(32) +) RETURNING id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role, email_verified, token_salt ` type CreateUserParams struct { @@ -39,12 +40,13 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e &i.UpdatedAt, &i.Role, &i.EmailVerified, + &i.TokenSalt, ) return i, err } const getUserByEmail = `-- name: GetUserByEmail :one -SELECT id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role, email_verified FROM users +SELECT id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role, email_verified, token_salt FROM users WHERE email = $1 LIMIT 1 ` @@ -62,12 +64,13 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error &i.UpdatedAt, &i.Role, &i.EmailVerified, + &i.TokenSalt, ) return i, err } const getUserByID = `-- name: GetUserByID :one -SELECT id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role, email_verified FROM users +SELECT id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role, email_verified, token_salt FROM users WHERE id = $1 LIMIT 1 ` @@ -85,10 +88,23 @@ func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) { &i.UpdatedAt, &i.Role, &i.EmailVerified, + &i.TokenSalt, ) return i, err } +const getUserTokenSalt = `-- name: GetUserTokenSalt :one +SELECT token_salt FROM users +WHERE id = $1 +` + +func (q *Queries) GetUserTokenSalt(ctx context.Context, id string) ([]byte, error) { + row := q.db.QueryRow(ctx, getUserTokenSalt, id) + var token_salt []byte + err := row.Scan(&token_salt) + return token_salt, err +} + const updateUserEmailVerifiedStatus = `-- name: UpdateUserEmailVerifiedStatus :exec UPDATE users SET email_verified = $1 WHERE id = $2 @@ -103,3 +119,13 @@ func (q *Queries) UpdateUserEmailVerifiedStatus(ctx context.Context, arg UpdateU _, err := q.db.Exec(ctx, updateUserEmailVerifiedStatus, arg.EmailVerified, arg.ID) return err } + +const updateUserTokenSalt = `-- name: UpdateUserTokenSalt :exec +UPDATE users SET token_salt = gen_random_bytes(32) +WHERE id = $1 +` + +func (q *Queries) UpdateUserTokenSalt(ctx context.Context, id string) error { + _, err := q.db.Exec(ctx, updateUserTokenSalt, id) + return err +} diff --git a/backend/go.mod b/backend/go.mod index 993307c1..6d74f78a 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -50,6 +50,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect golang.org/x/net v0.28.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index df3a635f..01c655cb 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -97,8 +97,13 @@ github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= diff --git a/backend/internal/jwt/generate.go b/backend/internal/jwt/generate.go index d4976166..91e475ae 100644 --- a/backend/internal/jwt/generate.go +++ b/backend/internal/jwt/generate.go @@ -16,13 +16,13 @@ const ( ) // Generates JWT tokens for the given user. Returns the access token, refresh token and error (nil if no error) -func Generate(userID string, role db.UserRole) (string, string, error) { - accessToken, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, time.Now().Add(10*time.Minute)) +func GenerateWithSalt(userID string, role db.UserRole, salt []byte) (string, string, error) { + accessToken, err := generateTokenWithSalt(userID, role, ACCESS_TOKEN_TYPE, time.Now().Add(10*time.Minute), salt) if err != nil { return "", "", err } - refreshToken, err := generateToken(userID, role, REFRESH_TOKEN_TYPE, time.Now().Add(24*7*time.Hour)) + refreshToken, err := generateTokenWithSalt(userID, role, REFRESH_TOKEN_TYPE, time.Now().Add(24*7*time.Hour), salt) if err != nil { return "", "", err } @@ -40,7 +40,6 @@ func GenerateVerifyEmailToken(email string, id string, exp time.Time) (string, e Email: email, TokenType: VERIFY_EMAIL_TOKEN_TYPE, RegisteredClaims: golangJWT.RegisteredClaims{ - // expire in 1 week ExpiresAt: golangJWT.NewNumericDate(exp), IssuedAt: golangJWT.NewNumericDate(time.Now()), ID: id, @@ -51,19 +50,20 @@ func GenerateVerifyEmailToken(email string, id string, exp time.Time) (string, e return token.SignedString([]byte(os.Getenv("JWT_SECRET_VERIFY_EMAIL"))) } -// Private helper method to generate a token. -func generateToken(userID string, role db.UserRole, tokenType string, exp time.Time) (string, error) { +// Private helper method to generate a token with user's salt +func generateTokenWithSalt(userID string, role db.UserRole, tokenType string, exp time.Time, salt []byte) (string, error) { claims := JWTClaims{ UserID: userID, Role: role, TokenType: tokenType, RegisteredClaims: golangJWT.RegisteredClaims{ - // expire in 1 week ExpiresAt: golangJWT.NewNumericDate(exp), IssuedAt: golangJWT.NewNumericDate(time.Now()), }, } token := golangJWT.NewWithClaims(golangJWT.SigningMethodHS256, claims) - return token.SignedString([]byte(os.Getenv("JWT_SECRET"))) + // combine base secret with user's salt + secret := append([]byte(os.Getenv("JWT_SECRET")), salt...) + return token.SignedString(secret) } diff --git a/backend/internal/jwt/jwt_test.go b/backend/internal/jwt/jwt_test.go index ae1664db..5b5a2630 100644 --- a/backend/internal/jwt/jwt_test.go +++ b/backend/internal/jwt/jwt_test.go @@ -7,7 +7,6 @@ import ( "KonferCA/SPUR/db" - golangJWT "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" ) @@ -18,69 +17,90 @@ func TestJWT(t *testing.T) { userID := "some-user-id" role := db.UserRole("user") - exp := time.Now().Add(5 * time.Minute) + salt := []byte("test-salt") - t.Run("generate access token", func(t *testing.T) { - token, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, exp) + t.Run("token salt invalidation", func(t *testing.T) { + // Generate initial salt + initialSalt := []byte("initial-salt") + + // Generate tokens with initial salt + accessToken, refreshToken, err := GenerateWithSalt(userID, role, initialSalt) assert.Nil(t, err) - assert.NotEmpty(t, token) - claims, err := VerifyToken(token) + assert.NotEmpty(t, accessToken) + assert.NotEmpty(t, refreshToken) + + // Verify tokens work with initial salt + claims, err := VerifyTokenWithSalt(accessToken, initialSalt) assert.Nil(t, err) assert.Equal(t, claims.UserID, userID) assert.Equal(t, claims.Role, role) assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE) - assert.Equal(t, claims.RegisteredClaims.ExpiresAt, golangJWT.NewNumericDate(exp)) - }) - t.Run("generate refresh token", func(t *testing.T) { - token, err := generateToken(userID, role, REFRESH_TOKEN_TYPE, exp) + // Change salt (simulating token invalidation) + newSalt := []byte("new-salt") + + // Old tokens should fail verification with new salt + _, err = VerifyTokenWithSalt(accessToken, newSalt) + assert.NotNil(t, err, "Token should be invalid with new salt") + + // Generate new tokens with new salt + newAccessToken, newRefreshToken, err := GenerateWithSalt(userID, role, newSalt) assert.Nil(t, err) - assert.NotEmpty(t, token) - claims, err := VerifyToken(token) + assert.NotEmpty(t, newAccessToken) + assert.NotEmpty(t, newRefreshToken) + + // New tokens should work with new salt + claims, err = VerifyTokenWithSalt(newAccessToken, newSalt) assert.Nil(t, err) assert.Equal(t, claims.UserID, userID) - assert.Equal(t, claims.Role, role) - assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE) - assert.Equal(t, claims.RegisteredClaims.ExpiresAt, golangJWT.NewNumericDate(exp)) }) - t.Run("generate both refresh and access token", func(t *testing.T) { - a, r, err := Generate(userID, role) + t.Run("two-step verification", func(t *testing.T) { + salt := []byte("test-salt") + + // Generate a token + accessToken, _, err := GenerateWithSalt(userID, role, salt) assert.Nil(t, err) - assert.NotEmpty(t, a) - assert.NotEmpty(t, r) - claims, err := VerifyToken(a) + + // Step 1: Parse claims without verification + unverifiedClaims, err := ParseUnverifiedClaims(accessToken) assert.Nil(t, err) - assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE) - claims, err = VerifyToken(r) - assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE) + assert.Equal(t, userID, unverifiedClaims.UserID) + + // Step 2: Verify with salt + verifiedClaims, err := VerifyTokenWithSalt(accessToken, salt) + assert.Nil(t, err) + assert.Equal(t, userID, verifiedClaims.UserID) + + // Try to verify with wrong salt + wrongSalt := []byte("wrong-salt") + _, err = VerifyTokenWithSalt(accessToken, wrongSalt) + assert.NotNil(t, err, "Token should be invalid with wrong salt") }) - t.Run("deny token with wrong signature", func(t *testing.T) { - a, r, err := Generate(userID, role) + t.Run("generate access token", func(t *testing.T) { + accessToken, _, err := GenerateWithSalt(userID, role, salt) assert.Nil(t, err) - // change secret - os.Setenv("JWT_SECRET", "changed") - _, err = VerifyToken(a) - assert.NotNil(t, err) - // restore error to nil - err = nil - // test the other token - _, err = VerifyToken(r) - assert.NotNil(t, err) - // restore secret - os.Setenv("JWT_SECRET", "secret") + assert.NotEmpty(t, accessToken) + claims, err := VerifyTokenWithSalt(accessToken, salt) + assert.Nil(t, err) + assert.Equal(t, claims.UserID, userID) + assert.Equal(t, claims.Role, role) + assert.Equal(t, claims.TokenType, ACCESS_TOKEN_TYPE) }) - t.Run("deny expired token", func(t *testing.T) { - exp = time.Now().Add(-1 * 5 * time.Minute) - token, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, exp) + t.Run("generate refresh token", func(t *testing.T) { + _, refreshToken, err := GenerateWithSalt(userID, role, salt) assert.Nil(t, err) - _, err = VerifyToken(token) - assert.NotNil(t, err) + assert.NotEmpty(t, refreshToken) + claims, err := VerifyTokenWithSalt(refreshToken, salt) + assert.Nil(t, err) + assert.Equal(t, claims.UserID, userID) + assert.Equal(t, claims.Role, role) + assert.Equal(t, claims.TokenType, REFRESH_TOKEN_TYPE) }) - t.Run("generate verify email token", func(t *testing.T) { + t.Run("verify email token", func(t *testing.T) { email := "test@mail.com" id := "some-id" exp := time.Now().Add(time.Second * 5) @@ -102,14 +122,4 @@ func TestJWT(t *testing.T) { _, err = VerifyEmailToken(token) assert.NotNil(t, err) }) - - t.Run("deny expired verify email token", func(t *testing.T) { - email := "test@mail.com" - id := "some-id" - exp := time.Now().Add(-1 * 5 * time.Second) - token, err := GenerateVerifyEmailToken(email, id, exp) - assert.Nil(t, err) - _, err = VerifyEmailToken(token) - assert.NotNil(t, err) - }) } diff --git a/backend/internal/jwt/verify.go b/backend/internal/jwt/verify.go index ec43078a..95e013cc 100644 --- a/backend/internal/jwt/verify.go +++ b/backend/internal/jwt/verify.go @@ -7,6 +7,37 @@ import ( golangJWT "github.com/golang-jwt/jwt/v5" ) +// ParseUnverifiedClaims parses the token without verifying the signature +// to extract the claims. This is used as the first step in the two-step +// verification process. +func ParseUnverifiedClaims(token string) (*JWTClaims, error) { + // Create parser that skips claims validation + parser := golangJWT.NewParser(golangJWT.WithoutClaimsValidation()) + claims := &JWTClaims{} + _, _, err := parser.ParseUnverified(token, claims) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + return claims, nil +} + +// VerifyTokenWithSalt verifies the token using the user's salt +func VerifyTokenWithSalt(token string, salt []byte) (*JWTClaims, error) { + claims := &JWTClaims{} + _, err := golangJWT.ParseWithClaims(token, claims, func(t *golangJWT.Token) (interface{}, error) { + if _, ok := t.Method.(*golangJWT.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + // combine base secret with user's salt + secret := append([]byte(os.Getenv("JWT_SECRET")), salt...) + return secret, nil + }) + if err != nil { + return nil, err + } + return claims, nil +} + // Verifies the given token. If successful, then it will // return the JWTClaims of the token, otherwise an error is returned. func VerifyToken(token string) (*JWTClaims, error) { diff --git a/backend/internal/middleware/jwt.go b/backend/internal/middleware/jwt.go index 2ea676ad..3c40009d 100644 --- a/backend/internal/middleware/jwt.go +++ b/backend/internal/middleware/jwt.go @@ -4,6 +4,7 @@ import ( "net/http" "strings" + "KonferCA/SPUR/db" "KonferCA/SPUR/internal/jwt" "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" @@ -13,7 +14,7 @@ const JWT_CLAIMS = "MIDDLEWARE_JWT_CLAIMS" // Middleware that validate the "Authorization" header for a Bearer token. // Matches the received token with the accepted token type. -func ProtectAPI(acceptTokenType string) echo.MiddlewareFunc { +func ProtectAPI(acceptTokenType string, queries *db.Queries) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { authorization := c.Request().Header.Get(echo.HeaderAuthorization) @@ -21,16 +22,34 @@ func ProtectAPI(acceptTokenType string) echo.MiddlewareFunc { if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { return echo.NewHTTPError(http.StatusUnauthorized, "Invalid authorization header. Only accept Bearer token.") } - claims, err := jwt.VerifyToken(parts[1]) + + // Step 1: Parse claims without verification to get userID + unverifiedClaims, err := jwt.ParseUnverifiedClaims(parts[1]) + if err != nil { + log.Error().Err(err).Msg("Failed to parse JWT claims") + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid token format") + } + + // Step 2: Get user's salt + salt, err := queries.GetUserTokenSalt(c.Request().Context(), unverifiedClaims.UserID) + if err != nil { + log.Error().Err(err).Msg("Failed to get user's token salt") + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid token") + } + + // Step 3: Verify token with salt + claims, err := jwt.VerifyTokenWithSalt(parts[1], salt) if err != nil { log.Error().Err(err).Msg("JWT verification error") - return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired token.") + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired token") } - // match token type + + // Step 4: Verify token type if acceptTokenType != claims.TokenType { - log.Error().Str("accept", acceptTokenType).Str("received", claims.TokenType).Msg("Invalid token type.") - return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired token.") + log.Error().Str("accept", acceptTokenType).Str("received", claims.TokenType).Msg("Invalid token type") + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid token type") } + c.Set(JWT_CLAIMS, claims) return next(c) } diff --git a/backend/internal/middleware/jwt_test.go b/backend/internal/middleware/jwt_test.go index 1bf44bd0..70239c89 100644 --- a/backend/internal/middleware/jwt_test.go +++ b/backend/internal/middleware/jwt_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -9,15 +10,55 @@ import ( "KonferCA/SPUR/db" "KonferCA/SPUR/internal/jwt" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) +// MockDBTX implements db.DBTX interface +type MockDBTX struct { + mock.Mock +} + +func (m *MockDBTX) Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) { + return pgconn.CommandTag{}, nil +} + +func (m *MockDBTX) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + return nil, nil +} + +func (m *MockDBTX) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { + // For GetUserTokenSalt, we'll return a mock row that returns our test salt + return &mockRow{salt: []byte("test-salt")} +} + +// mockRow implements pgx.Row for our test +type mockRow struct { + salt []byte +} + +func (m *mockRow) Scan(dest ...interface{}) error { + if len(dest) > 0 { + if p, ok := dest[0].(*[]byte); ok { + *p = m.salt + return nil + } + } + return fmt.Errorf("unexpected scan") +} + func TestProtectAPIForAccessToken(t *testing.T) { os.Setenv("JWT_SECRET", "secret") e := echo.New() + mockDB := &MockDBTX{} + queries := db.New(mockDB) - e.Use(ProtectAPI(jwt.ACCESS_TOKEN_TYPE)) + // Create a middleware instance with the mock + middleware := ProtectAPI(jwt.ACCESS_TOKEN_TYPE, queries) + e.Use(middleware) e.GET("/protected", func(c echo.Context) error { return c.String(http.StatusOK, "protected resource") @@ -26,7 +67,8 @@ func TestProtectAPIForAccessToken(t *testing.T) { // generate valid tokens userID := "user-id" role := db.UserRole("user-role") - accessToken, refreshToken, err := jwt.Generate(userID, role) + salt := []byte("test-salt") + accessToken, refreshToken, err := jwt.GenerateWithSalt(userID, role, salt) assert.Nil(t, err) tests := []struct { @@ -55,48 +97,17 @@ func TestProtectAPIForAccessToken(t *testing.T) { assert.Equal(t, test.expectedCode, rec.Code) }) } - - // change jwt secret and generate new tokens - os.Setenv("JWT_SECRET", "another-secret") - accessToken, refreshToken, err = jwt.Generate(userID, role) - assert.Nil(t, err) - - // reset secret - os.Setenv("JWT_SECRET", "secret") - - tests = []struct { - name string - expectedCode int - token string - }{ - { - name: "Reject access token signed with wrong secret", - expectedCode: http.StatusUnauthorized, - token: accessToken, - }, - { - name: "Reject refresh token signed with wrong secret", - expectedCode: http.StatusUnauthorized, - token: refreshToken, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/protected", nil) - rec := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", test.token)) - e.ServeHTTP(rec, req) - assert.Equal(t, test.expectedCode, rec.Code) - }) - } } func TestProtectAPIForRefreshToken(t *testing.T) { os.Setenv("JWT_SECRET", "secret") e := echo.New() + mockDB := &MockDBTX{} + queries := db.New(mockDB) - e.Use(ProtectAPI(jwt.REFRESH_TOKEN_TYPE)) + // Create a middleware instance with the mock + middleware := ProtectAPI(jwt.REFRESH_TOKEN_TYPE, queries) + e.Use(middleware) e.GET("/protected", func(c echo.Context) error { return c.String(http.StatusOK, "protected resource") @@ -105,7 +116,8 @@ func TestProtectAPIForRefreshToken(t *testing.T) { // generate valid tokens userID := "user-id" role := db.UserRole("user-role") - accessToken, refreshToken, err := jwt.Generate(userID, role) + salt := []byte("test-salt") + accessToken, refreshToken, err := jwt.GenerateWithSalt(userID, role, salt) assert.Nil(t, err) tests := []struct { @@ -134,39 +146,4 @@ func TestProtectAPIForRefreshToken(t *testing.T) { assert.Equal(t, test.expectedCode, rec.Code) }) } - - // change jwt secret and generate new tokens - os.Setenv("JWT_SECRET", "another-secret") - accessToken, refreshToken, err = jwt.Generate(userID, role) - assert.Nil(t, err) - - // reset secret - os.Setenv("JWT_SECRET", "secret") - - tests = []struct { - name string - expectedCode int - token string - }{ - { - name: "Reject access token signed with wrong secret", - expectedCode: http.StatusUnauthorized, - token: accessToken, - }, - { - name: "Reject refresh token signed with wrong secret", - expectedCode: http.StatusUnauthorized, - token: refreshToken, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/protected", nil) - rec := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", test.token)) - e.ServeHTTP(rec, req) - assert.Equal(t, test.expectedCode, rec.Code) - }) - } } diff --git a/backend/internal/server/auth.go b/backend/internal/server/auth.go index 4313cb90..ab5ab081 100644 --- a/backend/internal/server/auth.go +++ b/backend/internal/server/auth.go @@ -59,7 +59,13 @@ func (s *Server) handleSignup(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, "failed to create user") } - accessToken, refreshToken, err := jwt.Generate(user.ID, user.Role) + // Get user's token salt + salt, err := s.queries.GetUserTokenSalt(ctx, user.ID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user's token salt") + } + + accessToken, refreshToken, err := jwt.GenerateWithSalt(user.ID, user.Role, salt) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "failed to generate token") } @@ -110,9 +116,6 @@ func (s *Server) handleSignup(c echo.Context) error { func (s *Server) handleSignin(c echo.Context) error { req, ok := c.Get(mw.REQUEST_BODY_KEY).(*SigninRequest) if !ok { - // no bueno... - // should never really reach this state since the validator should reject - // the request body if it is not a proper SigninRequest type return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } @@ -126,7 +129,13 @@ func (s *Server) handleSignin(c echo.Context) error { return echo.NewHTTPError(http.StatusUnauthorized, "invalid credentials") } - accessToken, refreshToken, err := jwt.Generate(user.ID, user.Role) + // Get user's token salt + salt, err := s.queries.GetUserTokenSalt(ctx, user.ID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user's token salt") + } + + accessToken, refreshToken, err := jwt.GenerateWithSalt(user.ID, user.Role, salt) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "failed to generate token") } diff --git a/backend/internal/server/project_comment_test.go b/backend/internal/server/project_comment_test.go index 26e51bf9..8529981c 100644 --- a/backend/internal/server/project_comment_test.go +++ b/backend/internal/server/project_comment_test.go @@ -51,8 +51,8 @@ func TestProjectCommentEndpoints(t *testing.T) { // Create a test user directly in the database userID := uuid.New().String() _, err = s.DBPool.Exec(ctx, ` - INSERT INTO users (id, email, password_hash, first_name, last_name, role) - VALUES ($1, $2, $3, $4, $5, 'startup_owner') + INSERT INTO users (id, email, password_hash, first_name, last_name, role, token_salt) + VALUES ($1, $2, $3, $4, $5, 'startup_owner', gen_random_bytes(32)) `, userID, "test@example.com", "hashedpassword", "Test", "User") if err != nil { t.Fatalf("failed to create test user: %v", err)