From 410b8acdd659fc4c929fe57a9e9dba4c76da305d Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Thu, 3 Oct 2024 17:46:16 +0200 Subject: [PATCH] fix: refactor mfa validation into functions (#1780) ## What kind of change does this PR introduce? We make the following changes: - Move all checks for factor validation into a single function. This ensures that we can re-use it across methods so we don't miss checks - Removes un-used functions - For MFA (TOTP) changes the duplicate friendly name from one that is done at the database level to one that is done at the application level. This aligns with MFA (Phone) --- internal/api/mfa.go | 114 +++++++++++++++++--------------------- internal/api/mfa_test.go | 8 +-- internal/models/factor.go | 4 -- 3 files changed, 53 insertions(+), 73 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 357eea34ce..d6158e9642 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -15,6 +15,7 @@ import ( "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/hooks" "github.com/supabase/auth/internal/metering" @@ -69,9 +70,45 @@ const ( QRCodeGenerationErrorMessage = "Error generating QR Code" ) +func validateFactors(db *storage.Connection, user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error { + if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { + return err + } + if err := db.Load(user, "Factors"); err != nil { + return err + } + factorCount := len(user.Factors) + numVerifiedFactors := 0 + + for _, factor := range user.Factors { + if factor.FriendlyName == newFactorName { + return unprocessableEntityError( + ErrorCodeMFAFactorNameConflict, + fmt.Sprintf("A factor with the friendly name %q for this user already exists", newFactorName), + ) + } + if factor.IsVerified() { + numVerifiedFactors++ + } + } + + if factorCount >= int(config.MFA.MaxEnrolledFactors) { + return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + } + + if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { + return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + } + + if numVerifiedFactors > 0 && session != nil && !session.IsAAL2() { + return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") + } + + return nil +} + func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { ctx := r.Context() - config := a.config user := getUser(ctx) session := getSession(ctx) db := a.db.WithContext(ctx) @@ -83,37 +120,18 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * if err != nil { return badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } - factors := user.Factors - factorCount := len(factors) - numVerifiedFactors := 0 - if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { - return err - } var factorsToDelete []models.Factor for _, factor := range user.Factors { - switch { - case factor.FriendlyName == params.FriendlyName: - return unprocessableEntityError( - ErrorCodeMFAFactorNameConflict, - fmt.Sprintf("A factor with the friendly name %q for this user already exists", factor.FriendlyName), - ) - - case factor.IsPhoneFactor(): - if factor.Phone.String() == phone { - if factor.IsVerified() { - return unprocessableEntityError( - ErrorCodeMFAVerifiedFactorExists, - "A verified phone factor already exists, unenroll the existing factor to continue", - ) - } else if factor.IsUnverified() { - factorsToDelete = append(factorsToDelete, factor) - } - + if factor.IsPhoneFactor() && factor.Phone.String() == phone { + if factor.IsVerified() { + return unprocessableEntityError( + ErrorCodeMFAVerifiedFactorExists, + "A verified phone factor already exists, unenroll the existing factor to continue", + ) + } else if factor.IsUnverified() { + factorsToDelete = append(factorsToDelete, factor) } - - case factor.IsVerified(): - numVerifiedFactors++ } } @@ -121,17 +139,10 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * return internalServerError("Database error deleting unverified phone factors").WithInternalError(err) } - if factorCount >= int(config.MFA.MaxEnrolledFactors) { - return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") - } - - if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil { + return err } - if numVerifiedFactors > 0 && !session.IsAAL2() { - return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") - } factor := models.NewPhoneFactor(user, phone, params.FriendlyName) err = db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(factor); terr != nil { @@ -173,31 +184,10 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E issuer = params.Issuer } - factors := user.Factors - - factorCount := len(factors) - numVerifiedFactors := 0 - if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { + if err := validateFactors(db, user, params.FriendlyName, config, session); err != nil { return err } - for _, factor := range factors { - if factor.IsVerified() { - numVerifiedFactors += 1 - } - } - - if factorCount >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") - } - - if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") - } - - if numVerifiedFactors > 0 && !session.IsAAL2() { - return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") - } var factor *models.Factor var buf bytes.Buffer var key *otp.Key @@ -225,13 +215,9 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E err = db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(factor); terr != nil { - pgErr := utilities.NewPostgresError(terr) - if pgErr.IsUniqueConstraintViolated() { - return unprocessableEntityError(ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) - } return terr - } + if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ "factor_id": factor.ID, }); terr != nil { diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 557b0ab139..87767f0850 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -290,9 +290,7 @@ func (ts *MFATestSuite) TestDuplicateTOTPEnrollsReturnExpectedMessage() { err := json.NewDecoder(response.Body).Decode(&errorResponse) require.NoError(ts.T(), err) - // Convert the response body to a string and check for the expected error message - expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName) - require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage) + require.Contains(ts.T(), errorResponse.ErrorCode, ErrorCodeMFAFactorNameConflict) } func (ts *MFATestSuite) AAL2RequiredToUpdatePasswordAfterEnrollment() { @@ -369,7 +367,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { var w *httptest.ResponseRecorder token := accessTokenResp.Token for i := 0; i < numFactors; i++ { - w = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK) + w = performEnrollFlow(ts, token, "first-name", models.TOTP, "https://issuer.com", "", http.StatusOK) } enrollResp := EnrollFactorResponse{} @@ -379,7 +377,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { _ = performChallengeFlow(ts, enrollResp.ID, token) // Enroll another Factor (Factor 3) - _ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK) + _ = performEnrollFlow(ts, token, "second-name", models.TOTP, "https://issuer.com", "", http.StatusOK) require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) require.Equal(ts.T(), 3, len(ts.TestUser.Factors)) } diff --git a/internal/models/factor.go b/internal/models/factor.go index 24cda188f7..7309653307 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -269,10 +269,6 @@ func (f *Factor) DowngradeSessionsToAAL1(tx *storage.Connection) error { return updateFactorAssociatedSessions(tx, f.UserID, f.ID, AAL1.String()) } -func (f *Factor) IsOwnedBy(user *User) bool { - return f.UserID == user.ID -} - func (f *Factor) IsVerified() bool { return f.Status == FactorStateVerified.String() }