Skip to content

Commit

Permalink
Add JWT middleware test
Browse files Browse the repository at this point in the history
  • Loading branch information
AmirAgassi committed Dec 17, 2024
1 parent acce0f7 commit db86c48
Showing 1 changed file with 136 additions and 131 deletions.
267 changes: 136 additions & 131 deletions backend/internal/tests/jwt_middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,134 +1,139 @@
package tests

// import (
// "context"
// "fmt"
// "net/http"
// "net/http/httptest"
// "os"
// "testing"
//
// "KonferCA/SPUR/db"
// "KonferCA/SPUR/internal/jwt"
// "github.com/google/uuid"
// "github.com/jackc/pgx/v5/pgxpool"
// "github.com/labstack/echo/v4"
// "github.com/stretchr/testify/assert"
// )
//
// func TestProtectAPIForAccessToken(t *testing.T) {
// // setup test environment
// os.Setenv("JWT_SECRET", "secret")
//
// // Connect to test database
// ctx := context.Background()
// dbURL := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
// "postgres",
// "postgres",
// "localhost",
// "5432",
// "postgres",
// "disable",
// )
//
// dbPool, err := pgxpool.New(ctx, dbURL)
// if err != nil {
// t.Fatalf("failed to connect to database: %v", err)
// }
// defer dbPool.Close()
//
// // Clean up any existing test user
// _, err = dbPool.Exec(ctx, "DELETE FROM users WHERE email = $1", "[email protected]")
// if err != nil {
// t.Fatalf("failed to clean up test user: %v", err)
// }
//
// // Create a test user directly in the database
// userID := uuid.New().String()
// _, err = dbPool.Exec(ctx, `
// INSERT INTO users (id, email, password_hash, first_name, last_name, role, token_salt)
// VALUES ($1, $2, $3, $4, $5, 'startup_owner', gen_random_bytes(32))
// `, userID, "[email protected]", "hashedpassword", "Test", "User")
// if err != nil {
// t.Fatalf("failed to create test user: %v", err)
// }
//
// // Create Echo instance with the DB connection
// e := echo.New()
// queries := db.New(dbPool)
// middleware := ProtectAPI(jwt.ACCESS_TOKEN_TYPE, queries)
// e.Use(middleware)
//
// e.GET("/protected", func(c echo.Context) error {
// return c.String(http.StatusOK, "protected resource")
// })
//
// // Get test user data from the database
// user, err := queries.GetUserByEmail(ctx, "[email protected]")
// if err != nil {
// t.Fatalf("failed to get test user: %v", err)
// }
//
// // Get the user's salt
// var salt []byte
// err = dbPool.QueryRow(ctx, "SELECT token_salt FROM users WHERE id = $1", user.ID).Scan(&salt)
// if err != nil {
// t.Fatalf("failed to get user salt: %v", err)
// }
//
// // generate valid tokens using the actual salt
// accessToken, refreshToken, err := jwt.GenerateWithSalt(user.ID, user.Role, salt)
// assert.Nil(t, err)
//
// tests := []struct {
// name string
// expectedCode int
// token string
// }{
// {
// name: "Accept access token",
// expectedCode: http.StatusOK,
// token: accessToken,
// },
// {
// name: "Reject refresh token",
// expectedCode: http.StatusUnauthorized,
// token: refreshToken,
// },
// {
// name: "Reject invalid token format",
// expectedCode: http.StatusUnauthorized,
// token: "invalid-token",
// },
// {
// name: "Reject empty token",
// expectedCode: http.StatusUnauthorized,
// token: "",
// },
// {
// name: "Reject token with invalid signature",
// expectedCode: http.StatusUnauthorized,
// token: accessToken + "tampered",
// },
// }
//
// for _, test := range tests {
// t.Run(test.name, func(t *testing.T) {
// req := httptest.NewRequest(http.MethodGet, "/protected", nil)
// rec := httptest.NewRecorder()
// if test.token != "" {
// req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", test.token))
// }
// e.ServeHTTP(rec, req)
// assert.Equal(t, test.expectedCode, rec.Code)
// })
// }
//
// // Clean up test user after test
// _, err = dbPool.Exec(ctx, "DELETE FROM users WHERE email = $1", "[email protected]")
// if err != nil {
// t.Fatalf("failed to clean up test user: %v", err)
// }
// }
//
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"

"KonferCA/SPUR/db"
"KonferCA/SPUR/internal/jwt"
"KonferCA/SPUR/internal/middleware"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func TestJWTMiddleware(t *testing.T) {
// setup test environment
setupEnv()
os.Setenv("JWT_SECRET", "secret")

// Connect to test database
ctx := context.Background()
dbURL := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
os.Getenv("DB_USER"),
os.Getenv("DB_PASSWORD"),
os.Getenv("DB_HOST"),
os.Getenv("DB_PORT"),
os.Getenv("DB_NAME"),
os.Getenv("DB_SSLMODE"),
)

dbPool, err := pgxpool.New(ctx, dbURL)
if err != nil {
t.Fatalf("failed to connect to database: %v", err)
}
defer dbPool.Close()

// Clean up any existing test user
_, err = dbPool.Exec(ctx, "DELETE FROM users WHERE email = $1", "[email protected]")
if err != nil {
t.Fatalf("failed to clean up test user: %v", err)
}

// Create a test user directly in the database
userID := uuid.New()
_, err = dbPool.Exec(ctx, `
INSERT INTO users (
id,
email,
password_hash,
role,
email_verified,
token_salt,
first_name,
last_name
)
VALUES ($1, $2, $3, $4, $5, gen_random_bytes(32), $6, $7)
`, userID, "[email protected]", "hashedpassword", db.UserRoleStartupOwner, true, "Test", "User")
if err != nil {
t.Fatalf("failed to create test user: %v", err)
}

// Create Echo instance with the middleware
e := echo.New()
middlewareConfig := middleware.AuthConfig{
AcceptTokenType: jwt.ACCESS_TOKEN_TYPE,
AcceptUserRoles: []db.UserRole{db.UserRoleStartupOwner},
}
e.Use(middleware.Auth(middlewareConfig, dbPool))

e.GET("/protected", func(c echo.Context) error {
return c.String(http.StatusOK, "protected resource")
})

// Get user's salt from database
var salt []byte
err = dbPool.QueryRow(ctx, "SELECT token_salt FROM users WHERE id = $1", userID).Scan(&salt)
if err != nil {
t.Fatalf("failed to get user salt: %v", err)
}

// generate valid tokens using the actual salt
accessToken, refreshToken, err := jwt.GenerateWithSalt(userID.String(), db.UserRoleStartupOwner, salt)
assert.Nil(t, err)

tests := []struct {
name string
expectedCode int
token string
}{
{
name: "Accept access token",
expectedCode: http.StatusOK,
token: accessToken,
},
{
name: "Reject refresh token",
expectedCode: http.StatusUnauthorized,
token: refreshToken,
},
{
name: "Reject invalid token format",
expectedCode: http.StatusUnauthorized,
token: "invalid-token",
},
{
name: "Reject empty token",
expectedCode: http.StatusUnauthorized,
token: "",
},
{
name: "Reject token with invalid signature",
expectedCode: http.StatusUnauthorized,
token: accessToken + "tampered",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
rec := httptest.NewRecorder()
if test.token != "" {
req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", test.token))
}
e.ServeHTTP(rec, req)
assert.Equal(t, test.expectedCode, rec.Code)
})
}

// Clean up test user after test
_, err = dbPool.Exec(ctx, "DELETE FROM users WHERE email = $1", "[email protected]")
if err != nil {
t.Fatalf("failed to clean up test user: %v", err)
}
}

0 comments on commit db86c48

Please sign in to comment.