diff --git a/internal/domains/accounts.go b/internal/domains/accounts.go index 3607c7bd5..5a77ed616 100644 --- a/internal/domains/accounts.go +++ b/internal/domains/accounts.go @@ -16,36 +16,35 @@ type AccountsDomain struct { deps *dependencies.Dependencies } -type JWTClaim struct { - jwt.RegisteredClaims - - Account *model.Account -} - -func (d *AccountsDomain) CheckToken(ctx context.Context, userJWT string) (*model.Account, error) { - token, err := jwt.ParseWithClaims(userJWT, &JWTClaim{}, func(token *jwt.Token) (interface{}, error) { +func (d *AccountsDomain) ParseToken(userJWT string) (*model.JWTClaim, error) { + token, err := jwt.ParseWithClaims(userJWT, &model.JWTClaim{}, func(token *jwt.Token) (interface{}, error) { // Validate algorithm if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) } - return d.deps.Config.Http.SecretKey, nil }) if err != nil { return nil, errors.Wrap(err, "error parsing token") } - if claims, ok := token.Claims.(*JWTClaim); ok && token.Valid { - if claims.Account.ID > 0 { - return claims.Account, nil - } - if err != nil { - return nil, err - } + if claims, ok := token.Claims.(*model.JWTClaim); ok && token.Valid { + return claims, nil + } + + return nil, fmt.Errorf("error obtaining user from JWT claims") +} +func (d *AccountsDomain) CheckToken(ctx context.Context, userJWT string) (*model.Account, error) { + claims, err := d.ParseToken(userJWT) + if err != nil { + return nil, fmt.Errorf("error parsing token: %w", err) + } + + if claims.Account.ID > 0 { return claims.Account, nil } - return nil, fmt.Errorf("error obtaining user from JWT claims") + return nil, fmt.Errorf("error obtaining user from JWT claims: %w", err) } func (d *AccountsDomain) GetAccountFromCredentials(ctx context.Context, username, password string) (*model.Account, error) { @@ -62,6 +61,10 @@ func (d *AccountsDomain) GetAccountFromCredentials(ctx context.Context, username } func (d *AccountsDomain) CreateTokenForAccount(account *model.Account, expiration time.Time) (string, error) { + if account == nil { + return "", fmt.Errorf("account is nil") + } + claims := jwt.MapClaims{ "account": account.ToDTO(), "exp": expiration.UTC().Unix(), diff --git a/internal/domains/accounts_test.go b/internal/domains/accounts_test.go new file mode 100644 index 000000000..a32bde483 --- /dev/null +++ b/internal/domains/accounts_test.go @@ -0,0 +1,145 @@ +package domains_test + +import ( + "context" + "testing" + "time" + + "github.com/go-shiori/shiori/internal/domains" + "github.com/go-shiori/shiori/internal/model" + "github.com/go-shiori/shiori/internal/testutil" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func TestAccountsDomainParseToken(t *testing.T) { + ctx := context.TODO() + logger := logrus.New() + _, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger) + domain := domains.NewAccountsDomain(deps) + + t.Run("valid token", func(t *testing.T) { + // Create a valid token + token, err := domain.CreateTokenForAccount( + testutil.GetValidAccount(), + time.Now().Add(time.Hour*1), + ) + require.NoError(t, err) + + claims, err := domain.ParseToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.Equal(t, 99, claims.Account.ID) + }) + + t.Run("invalid token", func(t *testing.T) { + claims, err := domain.ParseToken("invalid-token") + require.Error(t, err) + require.Nil(t, claims) + }) +} + +func TestAccountsDomainCheckToken(t *testing.T) { + ctx := context.TODO() + logger := logrus.New() + _, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger) + domain := domains.NewAccountsDomain(deps) + + t.Run("valid token", func(t *testing.T) { + // Create a valid token + token, err := domain.CreateTokenForAccount( + testutil.GetValidAccount(), + time.Now().Add(time.Hour*1), + ) + require.NoError(t, err) + + acc, err := domain.CheckToken(ctx, token) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, 99, acc.ID) + }) + + t.Run("expired token", func(t *testing.T) { + // Create an expired token + token, err := domain.CreateTokenForAccount( + testutil.GetValidAccount(), + time.Now().Add(time.Hour*-1), + ) + require.NoError(t, err) + + acc, err := domain.CheckToken(ctx, token) + require.Error(t, err) + require.Nil(t, acc) + }) +} + +func TestAccountsDomainGetAccountFromCredentials(t *testing.T) { + ctx := context.TODO() + logger := logrus.New() + _, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger) + domain := domains.NewAccountsDomain(deps) + + require.NoError(t, deps.Database.SaveAccount(ctx, model.Account{ + Username: "test", + Password: "test", + })) + + t.Run("valid credentials", func(t *testing.T) { + acc, err := domain.GetAccountFromCredentials(ctx, "test", "test") + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, "test", acc.Username) + }) + + t.Run("invalid credentials", func(t *testing.T) { + acc, err := domain.GetAccountFromCredentials(ctx, "test", "invalid") + require.Error(t, err) + require.Nil(t, acc) + }) + + t.Run("invalid username", func(t *testing.T) { + acc, err := domain.GetAccountFromCredentials(ctx, "nope", "invalid") + require.Error(t, err) + require.Nil(t, acc) + }) + +} + +func TestAccountsDomainCreateTokenForAccount(t *testing.T) { + ctx := context.TODO() + logger := logrus.New() + _, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger) + domain := domains.NewAccountsDomain(deps) + + t.Run("valid account", func(t *testing.T) { + token, err := domain.CreateTokenForAccount( + testutil.GetValidAccount(), + time.Now().Add(time.Hour*1), + ) + require.NoError(t, err) + require.NotEmpty(t, token) + }) + + t.Run("nil account", func(t *testing.T) { + token, err := domain.CreateTokenForAccount( + nil, + time.Now().Add(time.Hour*1), + ) + require.Error(t, err) + require.Empty(t, token) + }) + + t.Run("token expiration is valid", func(t *testing.T) { + expiration := time.Now().Add(time.Hour * 9) + token, err := domain.CreateTokenForAccount( + testutil.GetValidAccount(), + expiration, + ) + require.NoError(t, err) + require.NotEmpty(t, token) + claims, err := domain.ParseToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.Equal(t, expiration.Unix(), claims.ExpiresAt.Time.Unix()) + }) +} diff --git a/internal/http/middleware/auth.go b/internal/http/middleware/auth.go index cb3da436f..ee646d59e 100644 --- a/internal/http/middleware/auth.go +++ b/internal/http/middleware/auth.go @@ -23,6 +23,7 @@ func AuthMiddleware(deps *dependencies.Dependencies) gin.HandlerFunc { account, err := deps.Domains.Auth.CheckToken(c, token) if err != nil { + deps.Log.WithError(err).Error("Failed to check token") return } diff --git a/internal/http/middleware/auth_test.go b/internal/http/middleware/auth_test.go index 3b5d3d5ff..fbc9a2e8e 100644 --- a/internal/http/middleware/auth_test.go +++ b/internal/http/middleware/auth_test.go @@ -61,8 +61,8 @@ func TestAuthMiddleware(t *testing.T) { }) t.Run("test authorization header", func(t *testing.T) { - account := model.Account{Username: "shiori"} - token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute)) + account := testutil.GetValidAccount() + token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute)) require.NoError(t, err) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -74,8 +74,8 @@ func TestAuthMiddleware(t *testing.T) { }) t.Run("test authorization cookie", func(t *testing.T) { - account := model.Account{Username: "shiori"} - token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute)) + account := testutil.GetValidAccount() + token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute)) require.NoError(t, err) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) diff --git a/internal/http/routes/api/v1/auth.go b/internal/http/routes/api/v1/auth.go index 7aa8f2de1..3cfd57493 100644 --- a/internal/http/routes/api/v1/auth.go +++ b/internal/http/routes/api/v1/auth.go @@ -23,7 +23,7 @@ func (r *AuthAPIRoutes) Setup(group *gin.RouterGroup) model.Routes { group.GET("/me", r.meHandler) group.POST("/login", r.loginHandler) group.POST("/refresh", r.refreshHandler) - group.PATCH("/account", r.settingsHandler) + group.PATCH("/account", r.updateHandler) return r } @@ -81,18 +81,20 @@ func (r *AuthAPIRoutes) loginHandler(c *gin.Context) { return } - expiration := time.Now().Add(time.Hour) + expiration := time.Hour if payload.RememberMe { - expiration = time.Now().Add(time.Hour * 24 * 30) + expiration = time.Hour * 24 * 30 } - token, err := r.deps.Domains.Auth.CreateTokenForAccount(account, expiration) + expirationTime := time.Now().Add(expiration) + + token, err := r.deps.Domains.Auth.CreateTokenForAccount(account, expirationTime) if err != nil { response.SendInternalServerError(c) return } - sessionID, err := r.legacyLoginHandler(*account, time.Hour*24*30) + sessionID, err := r.legacyLoginHandler(*account, expiration) if err != nil { r.logger.WithError(err).Error("failed execute legacy login handler") response.SendInternalServerError(c) @@ -102,7 +104,7 @@ func (r *AuthAPIRoutes) loginHandler(c *gin.Context) { response.Send(c, http.StatusOK, loginResponseMessage{ Token: token, SessionID: sessionID, - Expiration: expiration.Unix(), + Expiration: expirationTime.Unix(), }) } @@ -154,7 +156,7 @@ func (r *AuthAPIRoutes) meHandler(c *gin.Context) { response.Send(c, http.StatusOK, ctx.GetAccount()) } -// settingsHandler godoc +// updateHandler godoc // // @Summary Perform actions on the currently logged-in user. // @Tags Auth @@ -164,7 +166,7 @@ func (r *AuthAPIRoutes) meHandler(c *gin.Context) { // @Success 200 {object} model.Account // @Failure 403 {object} nil "Token not provided/invalid" // @Router /api/v1/auth/account [patch] -func (r *AuthAPIRoutes) settingsHandler(c *gin.Context) { +func (r *AuthAPIRoutes) updateHandler(c *gin.Context) { ctx := context.NewContextFromGin(c) if !ctx.UserIsLogged() { response.SendError(c, http.StatusForbidden, nil) @@ -175,6 +177,10 @@ func (r *AuthAPIRoutes) settingsHandler(c *gin.Context) { } account := ctx.GetAccount() + if account == nil { + response.SendError(c, http.StatusUnauthorized, nil) + return + } account.Config = payload.Config err := r.deps.Database.SaveAccountSettings(c, *account) diff --git a/internal/http/routes/api/v1/auth_test.go b/internal/http/routes/api/v1/auth_test.go index 0bc0a6fd8..49686af84 100644 --- a/internal/http/routes/api/v1/auth_test.go +++ b/internal/http/routes/api/v1/auth_test.go @@ -82,14 +82,11 @@ func TestAccountsRoute(t *testing.T) { router.Setup(g.Group("/")) // Create an account manually to test - account := model.Account{ - Username: "shiori", - Password: "gopher", - Owner: true, - } - require.NoError(t, deps.Database.SaveAccount(ctx, account)) + account := testutil.GetValidAccount() + account.Owner = true + require.NoError(t, deps.Database.SaveAccount(ctx, *account)) - token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute)) + token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute)) require.NoError(t, err) req := httptest.NewRequest("GET", "/me", nil) @@ -175,9 +172,7 @@ func TestRefreshHandler(t *testing.T) { }) t.Run("token valid", func(t *testing.T) { - token, err := deps.Domains.Auth.CreateTokenForAccount(&model.Account{ - Username: "shiori", - }, time.Now().Add(time.Minute)) + token, err := deps.Domains.Auth.CreateTokenForAccount(testutil.GetValidAccount(), time.Now().Add(time.Minute)) require.NoError(t, err) w := testutil.PerformRequest(g, "POST", "/refresh", testutil.WithHeader(model.AuthorizationHeader, model.AuthorizationTokenType+" "+token)) @@ -186,7 +181,7 @@ func TestRefreshHandler(t *testing.T) { }) } -func TestSettingsHandler(t *testing.T) { +func TestUpdateHandler(t *testing.T) { logger := logrus.New() ctx := context.TODO() g := testutil.NewGin() @@ -196,10 +191,18 @@ func TestSettingsHandler(t *testing.T) { g.Use(middleware.AuthMiddleware(deps)) router.Setup(g.Group("/")) + require.NoError(t, deps.Database.SaveAccount(ctx, model.Account{ + Username: "shiori", + Password: "gopher", + })) + + t.Run("invalid token", func(t *testing.T) { + w := testutil.PerformRequest(g, "PATCH", "/account") + require.Equal(t, http.StatusForbidden, w.Code) + }) + t.Run("token valid", func(t *testing.T) { - token, err := deps.Domains.Auth.CreateTokenForAccount(&model.Account{ - Username: "shiori", - }, time.Now().Add(time.Minute)) + token, err := deps.Domains.Auth.CreateTokenForAccount(testutil.GetValidAccount(), time.Now().Add(time.Minute)) require.NoError(t, err) type settingRequestPayload struct { @@ -222,9 +225,7 @@ func TestSettingsHandler(t *testing.T) { }) t.Run("config not valid", func(t *testing.T) { - token, err := deps.Domains.Auth.CreateTokenForAccount(&model.Account{ - Username: "shiori", - }, time.Now().Add(time.Minute)) + token, err := deps.Domains.Auth.CreateTokenForAccount(testutil.GetValidAccount(), time.Now().Add(time.Minute)) require.NoError(t, err) w := testutil.PerformRequest(g, "PATCH", "/account", testutil.WithBody("notValidConfig"), testutil.WithHeader(model.AuthorizationHeader, model.AuthorizationTokenType+" "+token)) diff --git a/internal/http/routes/api/v1/bookmarks_test.go b/internal/http/routes/api/v1/bookmarks_test.go index 7af902565..5a1ffa74d 100644 --- a/internal/http/routes/api/v1/bookmarks_test.go +++ b/internal/http/routes/api/v1/bookmarks_test.go @@ -26,13 +26,9 @@ func TestUpdateBookmarkCache(t *testing.T) { router := NewBookmarksAPIRoutes(logger, deps) router.Setup(g.Group("/")) - account := model.Account{ - Username: "test", - Password: "test", - Owner: false, - } - require.NoError(t, deps.Database.SaveAccount(ctx, account)) - token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute)) + account := testutil.GetValidAccount() + require.NoError(t, deps.Database.SaveAccount(ctx, *account)) + token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute)) require.NoError(t, err) t.Run("require authentication", func(t *testing.T) { @@ -58,13 +54,9 @@ func TestReadableeBookmarkContent(t *testing.T) { router := NewBookmarksAPIRoutes(logger, deps) router.Setup(g.Group("/")) - account := model.Account{ - Username: "test", - Password: "test", - Owner: false, - } - require.NoError(t, deps.Database.SaveAccount(ctx, account)) - token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute)) + account := testutil.GetValidAccount() + require.NoError(t, deps.Database.SaveAccount(ctx, *account)) + token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute)) require.NoError(t, err) bookmark := testutil.GetValidBookmark() diff --git a/internal/http/routes/api/v1/tags_test.go b/internal/http/routes/api/v1/tags_test.go index 54300ad42..7f3ed6834 100644 --- a/internal/http/routes/api/v1/tags_test.go +++ b/internal/http/routes/api/v1/tags_test.go @@ -23,13 +23,10 @@ func TestTagList(t *testing.T) { _, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger) g.Use(middleware.AuthMiddleware(deps)) - account := model.Account{ - Username: "test", - Password: "test", - Owner: true, - } - require.NoError(t, deps.Database.SaveAccount(ctx, account)) - token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute)) + account := testutil.GetValidAccount() + account.Owner = true + require.NoError(t, deps.Database.SaveAccount(ctx, *account)) + token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute)) require.NoError(t, err) bookmark := testutil.GetValidBookmark() @@ -73,12 +70,9 @@ func TestTagCreate(t *testing.T) { _, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger) g.Use(middleware.AuthMiddleware(deps)) - account := model.Account{ - Username: "test", - Password: "test", - Owner: true, - } - require.NoError(t, deps.Database.SaveAccount(ctx, account)) + account := testutil.GetValidAccount() + account.Owner = true + require.NoError(t, deps.Database.SaveAccount(ctx, *account)) // token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute)) // require.NoError(t, err) diff --git a/internal/model/account.go b/internal/model/account.go index 065477db5..5c58fca5f 100644 --- a/internal/model/account.go +++ b/internal/model/account.go @@ -4,6 +4,8 @@ import ( "database/sql/driver" "encoding/json" "fmt" + + "github.com/golang-jwt/jwt/v5" ) // Account is the database model for account. @@ -61,3 +63,9 @@ type AccountDTO struct { Owner bool `json:"owner"` Config UserConfig `json:"config"` } + +type JWTClaim struct { + jwt.RegisteredClaims + + Account *Account +} diff --git a/internal/model/domains.go b/internal/model/domains.go index ce4e0532b..8b8fefa2c 100644 --- a/internal/model/domains.go +++ b/internal/model/domains.go @@ -18,6 +18,7 @@ type BookmarksDomain interface { } type AccountsDomain interface { + ParseToken(userJWT string) (*JWTClaim, error) CheckToken(ctx context.Context, userJWT string) (*Account, error) GetAccountFromCredentials(ctx context.Context, username, password string) (*Account, error) CreateTokenForAccount(account *Account, expiration time.Time) (string, error) diff --git a/internal/testutil/shiori.go b/internal/testutil/shiori.go index 43ceb24d9..737266dc4 100644 --- a/internal/testutil/shiori.go +++ b/internal/testutil/shiori.go @@ -54,3 +54,14 @@ func GetValidBookmark() *model.BookmarkDTO { Title: "Shiori repository", } } + +// GetValidAccount returns a valid account for testing +// It includes an ID to properly use the account when testing authentication methods +// without interacting with the database. +func GetValidAccount() *model.Account { + return &model.Account{ + ID: 99, + Username: "test", + Password: "test", + } +}