diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 357eea34ce..d6158e9642 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -15,6 +15,7 @@ import ( "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/hooks" "github.com/supabase/auth/internal/metering" @@ -69,9 +70,45 @@ const ( QRCodeGenerationErrorMessage = "Error generating QR Code" ) +func validateFactors(db *storage.Connection, user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error { + if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { + return err + } + if err := db.Load(user, "Factors"); err != nil { + return err + } + factorCount := len(user.Factors) + numVerifiedFactors := 0 + + for _, factor := range user.Factors { + if factor.FriendlyName == newFactorName { + return unprocessableEntityError( + ErrorCodeMFAFactorNameConflict, + fmt.Sprintf("A factor with the friendly name %q for this user already exists", newFactorName), + ) + } + if factor.IsVerified() { + numVerifiedFactors++ + } + } + + if factorCount >= int(config.MFA.MaxEnrolledFactors) { + return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + } + + if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { + return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + } + + if numVerifiedFactors > 0 && session != nil && !session.IsAAL2() { + return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") + } + + return nil +} + func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { ctx := r.Context() - config := a.config user := getUser(ctx) session := getSession(ctx) db := a.db.WithContext(ctx) @@ -83,37 +120,18 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * if err != nil { return badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } - factors := user.Factors - factorCount := len(factors) - numVerifiedFactors := 0 - if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { - return err - } var factorsToDelete []models.Factor for _, factor := range user.Factors { - switch { - case factor.FriendlyName == params.FriendlyName: - return unprocessableEntityError( - ErrorCodeMFAFactorNameConflict, - fmt.Sprintf("A factor with the friendly name %q for this user already exists", factor.FriendlyName), - ) - - case factor.IsPhoneFactor(): - if factor.Phone.String() == phone { - if factor.IsVerified() { - return unprocessableEntityError( - ErrorCodeMFAVerifiedFactorExists, - "A verified phone factor already exists, unenroll the existing factor to continue", - ) - } else if factor.IsUnverified() { - factorsToDelete = append(factorsToDelete, factor) - } - + if factor.IsPhoneFactor() && factor.Phone.String() == phone { + if factor.IsVerified() { + return unprocessableEntityError( + ErrorCodeMFAVerifiedFactorExists, + "A verified phone factor already exists, unenroll the existing factor to continue", + ) + } else if factor.IsUnverified() { + factorsToDelete = append(factorsToDelete, factor) } - - case factor.IsVerified(): - numVerifiedFactors++ } } @@ -121,17 +139,10 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * return internalServerError("Database error deleting unverified phone factors").WithInternalError(err) } - if factorCount >= int(config.MFA.MaxEnrolledFactors) { - return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") - } - - if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil { + return err } - if numVerifiedFactors > 0 && !session.IsAAL2() { - return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") - } factor := models.NewPhoneFactor(user, phone, params.FriendlyName) err = db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(factor); terr != nil { @@ -173,31 +184,10 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E issuer = params.Issuer } - factors := user.Factors - - factorCount := len(factors) - numVerifiedFactors := 0 - if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { + if err := validateFactors(db, user, params.FriendlyName, config, session); err != nil { return err } - for _, factor := range factors { - if factor.IsVerified() { - numVerifiedFactors += 1 - } - } - - if factorCount >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") - } - - if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") - } - - if numVerifiedFactors > 0 && !session.IsAAL2() { - return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") - } var factor *models.Factor var buf bytes.Buffer var key *otp.Key @@ -225,13 +215,9 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E err = db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(factor); terr != nil { - pgErr := utilities.NewPostgresError(terr) - if pgErr.IsUniqueConstraintViolated() { - return unprocessableEntityError(ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) - } return terr - } + if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ "factor_id": factor.ID, }); terr != nil { diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 557b0ab139..87767f0850 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -290,9 +290,7 @@ func (ts *MFATestSuite) TestDuplicateTOTPEnrollsReturnExpectedMessage() { err := json.NewDecoder(response.Body).Decode(&errorResponse) require.NoError(ts.T(), err) - // Convert the response body to a string and check for the expected error message - expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName) - require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage) + require.Contains(ts.T(), errorResponse.ErrorCode, ErrorCodeMFAFactorNameConflict) } func (ts *MFATestSuite) AAL2RequiredToUpdatePasswordAfterEnrollment() { @@ -369,7 +367,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { var w *httptest.ResponseRecorder token := accessTokenResp.Token for i := 0; i < numFactors; i++ { - w = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK) + w = performEnrollFlow(ts, token, "first-name", models.TOTP, "https://issuer.com", "", http.StatusOK) } enrollResp := EnrollFactorResponse{} @@ -379,7 +377,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { _ = performChallengeFlow(ts, enrollResp.ID, token) // Enroll another Factor (Factor 3) - _ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK) + _ = performEnrollFlow(ts, token, "second-name", models.TOTP, "https://issuer.com", "", http.StatusOK) require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) require.Equal(ts.T(), 3, len(ts.TestUser.Factors)) } diff --git a/internal/models/factor.go b/internal/models/factor.go index 24cda188f7..7309653307 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -269,10 +269,6 @@ func (f *Factor) DowngradeSessionsToAAL1(tx *storage.Connection) error { return updateFactorAssociatedSessions(tx, f.UserID, f.ID, AAL1.String()) } -func (f *Factor) IsOwnedBy(user *User) bool { - return f.UserID == user.ID -} - func (f *Factor) IsVerified() bool { return f.Status == FactorStateVerified.String() }