Skip to content

Commit

Permalink
fix: refactor mfa validation into functions (#1780)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

We make the following changes:
- Move all checks for factor validation into a single function. This
ensures that we can re-use it across methods so we don't miss checks
- Removes un-used functions
- For MFA (TOTP) changes the duplicate friendly name from one that is
done at the database level to one that is done at the application level.
This aligns with MFA (Phone)
  • Loading branch information
J0 authored Oct 3, 2024
1 parent 819dabb commit 410b8ac
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 73 deletions.
114 changes: 50 additions & 64 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"github.com/supabase/auth/internal/api/sms_provider"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"
"github.com/supabase/auth/internal/hooks"
"github.com/supabase/auth/internal/metering"
Expand Down Expand Up @@ -69,9 +70,45 @@ const (
QRCodeGenerationErrorMessage = "Error generating QR Code"
)

func validateFactors(db *storage.Connection, user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error {
if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil {
return err
}
if err := db.Load(user, "Factors"); err != nil {
return err
}
factorCount := len(user.Factors)
numVerifiedFactors := 0

for _, factor := range user.Factors {
if factor.FriendlyName == newFactorName {
return unprocessableEntityError(
ErrorCodeMFAFactorNameConflict,
fmt.Sprintf("A factor with the friendly name %q for this user already exists", newFactorName),
)
}
if factor.IsVerified() {
numVerifiedFactors++
}
}

if factorCount >= int(config.MFA.MaxEnrolledFactors) {
return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue")
}

if numVerifiedFactors >= config.MFA.MaxVerifiedFactors {
return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue")
}

if numVerifiedFactors > 0 && session != nil && !session.IsAAL2() {
return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor")
}

return nil
}

func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error {
ctx := r.Context()
config := a.config
user := getUser(ctx)
session := getSession(ctx)
db := a.db.WithContext(ctx)
Expand All @@ -83,55 +120,29 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *
if err != nil {
return badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
}
factors := user.Factors

factorCount := len(factors)
numVerifiedFactors := 0
if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil {
return err
}
var factorsToDelete []models.Factor
for _, factor := range user.Factors {
switch {
case factor.FriendlyName == params.FriendlyName:
return unprocessableEntityError(
ErrorCodeMFAFactorNameConflict,
fmt.Sprintf("A factor with the friendly name %q for this user already exists", factor.FriendlyName),
)

case factor.IsPhoneFactor():
if factor.Phone.String() == phone {
if factor.IsVerified() {
return unprocessableEntityError(
ErrorCodeMFAVerifiedFactorExists,
"A verified phone factor already exists, unenroll the existing factor to continue",
)
} else if factor.IsUnverified() {
factorsToDelete = append(factorsToDelete, factor)
}

if factor.IsPhoneFactor() && factor.Phone.String() == phone {
if factor.IsVerified() {
return unprocessableEntityError(
ErrorCodeMFAVerifiedFactorExists,
"A verified phone factor already exists, unenroll the existing factor to continue",
)
} else if factor.IsUnverified() {
factorsToDelete = append(factorsToDelete, factor)
}

case factor.IsVerified():
numVerifiedFactors++
}
}

if err := db.Destroy(&factorsToDelete); err != nil {
return internalServerError("Database error deleting unverified phone factors").WithInternalError(err)
}

if factorCount >= int(config.MFA.MaxEnrolledFactors) {
return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue")
}

if numVerifiedFactors >= config.MFA.MaxVerifiedFactors {
return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue")
if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil {
return err
}

if numVerifiedFactors > 0 && !session.IsAAL2() {
return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor")
}
factor := models.NewPhoneFactor(user, phone, params.FriendlyName)
err = db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Create(factor); terr != nil {
Expand Down Expand Up @@ -173,31 +184,10 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E
issuer = params.Issuer
}

factors := user.Factors

factorCount := len(factors)
numVerifiedFactors := 0
if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil {
if err := validateFactors(db, user, params.FriendlyName, config, session); err != nil {
return err
}

for _, factor := range factors {
if factor.IsVerified() {
numVerifiedFactors += 1
}
}

if factorCount >= int(config.MFA.MaxEnrolledFactors) {
return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue")
}

if numVerifiedFactors >= config.MFA.MaxVerifiedFactors {
return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue")
}

if numVerifiedFactors > 0 && !session.IsAAL2() {
return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor")
}
var factor *models.Factor
var buf bytes.Buffer
var key *otp.Key
Expand Down Expand Up @@ -225,13 +215,9 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E

err = db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Create(factor); terr != nil {
pgErr := utilities.NewPostgresError(terr)
if pgErr.IsUniqueConstraintViolated() {
return unprocessableEntityError(ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName))
}
return terr

}

if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{
"factor_id": factor.ID,
}); terr != nil {
Expand Down
8 changes: 3 additions & 5 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,7 @@ func (ts *MFATestSuite) TestDuplicateTOTPEnrollsReturnExpectedMessage() {
err := json.NewDecoder(response.Body).Decode(&errorResponse)
require.NoError(ts.T(), err)

// Convert the response body to a string and check for the expected error message
expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName)
require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage)
require.Contains(ts.T(), errorResponse.ErrorCode, ErrorCodeMFAFactorNameConflict)
}

func (ts *MFATestSuite) AAL2RequiredToUpdatePasswordAfterEnrollment() {
Expand Down Expand Up @@ -369,7 +367,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() {
var w *httptest.ResponseRecorder
token := accessTokenResp.Token
for i := 0; i < numFactors; i++ {
w = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK)
w = performEnrollFlow(ts, token, "first-name", models.TOTP, "https://issuer.com", "", http.StatusOK)
}

enrollResp := EnrollFactorResponse{}
Expand All @@ -379,7 +377,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() {
_ = performChallengeFlow(ts, enrollResp.ID, token)

// Enroll another Factor (Factor 3)
_ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK)
_ = performEnrollFlow(ts, token, "second-name", models.TOTP, "https://issuer.com", "", http.StatusOK)
require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID))
require.Equal(ts.T(), 3, len(ts.TestUser.Factors))
}
Expand Down
4 changes: 0 additions & 4 deletions internal/models/factor.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,6 @@ func (f *Factor) DowngradeSessionsToAAL1(tx *storage.Connection) error {
return updateFactorAssociatedSessions(tx, f.UserID, f.ID, AAL1.String())
}

func (f *Factor) IsOwnedBy(user *User) bool {
return f.UserID == user.ID
}

func (f *Factor) IsVerified() bool {
return f.Status == FactorStateVerified.String()
}
Expand Down

0 comments on commit 410b8ac

Please sign in to comment.