diff --git a/internal/api/api.go b/internal/api/api.go index 054167136e..287ae2995d 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -128,6 +128,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Route("/", func(r *router) { r.Use(api.isValidExternalHost) + r.Use(api.isValidAuthorizedEmail) r.Get("/settings", api.Settings) diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 6921392525..04458103b2 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -86,6 +86,9 @@ type RequestParams interface { struct { Email string `json:"email"` Phone string `json:"phone"` + } | + struct { + Email string `json:"email"` } } diff --git a/internal/api/mail.go b/internal/api/mail.go index 8bf69ba943..44f364453c 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -2,8 +2,6 @@ package api import ( "net/http" - "regexp" - "strings" "time" "github.com/didip/tollbooth/v5" @@ -550,8 +548,6 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models return nil } -var emailLabelPattern = regexp.MustCompile("[+][^@]+@") - func (a *API) validateEmail(email string) (string, error) { if email == "" { return "", badRequestError(ErrorCodeValidationFailed, "An email address is required") @@ -563,21 +559,6 @@ func (a *API) validateEmail(email string) (string, error) { return "", badRequestError(ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) } - email = strings.ToLower(email) - - if len(a.config.External.Email.AuthorizedAddresses) > 0 { - // allow labelled emails when authorization rules are in place - normalized := emailLabelPattern.ReplaceAllString(email, "@") - - for _, authorizedAddress := range a.config.External.Email.AuthorizedAddresses { - if normalized == authorizedAddress { - return email, nil - } - } - - return "", badRequestError(ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", email) - } - return email, nil } diff --git a/internal/api/mail_test.go b/internal/api/mail_test.go index fd3de7c80c..90608a13ab 100644 --- a/internal/api/mail_test.go +++ b/internal/api/mail_test.go @@ -48,41 +48,6 @@ func (ts *MailTestSuite) SetupTest() { require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new user") } -func (ts *MailTestSuite) TestValidateEmailAuthorizedAddresses() { - ts.Config.External.Email.AuthorizedAddresses = []string{"someone-a@example.com", "someone-b@example.com"} - defer func() { - ts.Config.External.Email.AuthorizedAddresses = nil - }() - - positiveExamples := []string{ - "someone-a@example.com", - "someone-b@example.com", - "someone-a+test-1@example.com", - "someone-b+test-2@example.com", - "someone-A@example.com", - "someone-B@example.com", - "someone-a@Example.com", - "someone-b@Example.com", - } - - negativeExamples := []string{ - "someone@example.com", - "s.omeone@example.com", - "someone-a+@example.com", - "someone+a@example.com", - } - - for _, example := range positiveExamples { - _, err := ts.API.validateEmail(example) - require.NoError(ts.T(), err) - } - - for _, example := range negativeExamples { - _, err := ts.API.validateEmail(example) - require.Error(ts.T(), err) - } -} - func (ts *MailTestSuite) TestGenerateLink() { // create admin jwt claims := &AccessTokenClaims{ diff --git a/internal/api/middleware.go b/internal/api/middleware.go index aa2c3e9ffa..3b56f59d1b 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/url" + "regexp" "strings" "sync" "time" @@ -170,6 +171,43 @@ func isIgnoreCaptchaRoute(req *http.Request) bool { return false } +var emailLabelPattern = regexp.MustCompile("[+][^@]+@") + +func (a *API) isValidAuthorizedEmail(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + + // skip checking for authorized email addresses if it's an admin request + if strings.HasPrefix(req.URL.Path, "/admin") || req.Method == http.MethodGet || req.Method == http.MethodDelete { + return ctx, nil + } + + var body struct { + Email string `json:"email"` + } + + if err := retrieveRequestParams(req, &body); err != nil { + // let downstream handlers handle the error + return ctx, nil + } + if body.Email == "" { + return ctx, nil + } + email := strings.ToLower(body.Email) + if len(a.config.External.Email.AuthorizedAddresses) > 0 { + // allow labelled emails when authorization rules are in place + normalized := emailLabelPattern.ReplaceAllString(email, "@") + + for _, authorizedAddress := range a.config.External.Email.AuthorizedAddresses { + if normalized == authorizedAddress { + return ctx, nil + } + } + + return ctx, badRequestError(ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", email) + } + return ctx, nil +} + func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() config := a.config diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 2d7a324935..7056d91ddb 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -515,3 +515,52 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() { }) } } + +func (ts *MiddlewareTestSuite) TestIsValidAuthorizedEmail() { + ts.API.config.External.Email.AuthorizedAddresses = []string{"valid@example.com"} + + cases := []struct { + desc string + reqPath string + body map[string]interface{} + }{ + { + desc: "bypass check for admin endpoints", + reqPath: "/admin", + body: map[string]interface{}{ + "email": "test@example.com", + }, + }, + { + desc: "bypass check if no email in request body", + reqPath: "/signup", + body: map[string]interface{}{}, + }, + { + desc: "email not in authorized list", + reqPath: "/signup", + body: map[string]interface{}{ + "email": "invalid@example.com", + }, + }, + { + desc: "email in authorized list", + reqPath: "/signup", + body: map[string]interface{}{ + "email": "valid@example.com", + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + req := httptest.NewRequest(http.MethodPost, "http://localhost"+c.reqPath, &buffer) + w := httptest.NewRecorder() + if _, err := ts.API.isValidAuthorizedEmail(w, req); err != nil { + require.Equal(ts.T(), err.(*HTTPError).ErrorCode, ErrorCodeEmailAddressNotAuthorized) + } + }) + } +}