Skip to content

Commit

Permalink
feat: configurable email and sms rate limiting (#1800)
Browse files Browse the repository at this point in the history
Adds two new configuration values for rate limiting the sending of
emails and sms messages:

 - GOTRUE_RATE_LIMIT_EMAIL_SENT
 - GOTRUE_RATE_LIMIT_SMS_SENT

It is implemented with a simple rate limiter that resets a counter at a
regular interval. The first intervals start time is set when the counter
is initialized. It will be reset when the server is restarted, but
preserved when the config is reloaded.

Syntax examples:
```
1.5       # Allow 1.5 events over 1 hour (legacy format)
100       # Allow 100 events over 1 hour (1h is default)
100/1h    # Allow 100 events over 1 hour (explicit duration)
100/24h   # Allow 100 events over 24 hours
100/72h   # Allow 100 events over 72 hours (use hours for days)
10/30m    # Allow 10  events over 30 minutes
3/10s     # Allow 3   events over 10 seconds
```

Syntax in ABNF to express the format as value:
```
value = count / rate
count = 1*DIGIT ["." 1*DIGIT]
rate = 1*DIGIT "/" ival
ival = ival-sec / ival-min / ival-hr
ival-sec = 1*DIGIT "s"
ival-min = 1*DIGIT "s"
ival-hr = 1*DIGIT "h"
```

This change was a continuation of
#1746 adapted to support the recent
preservation of rate limiters across server reloads.

---------

Co-authored-by: Chris Stockton <[email protected]>
Co-authored-by: Stojan Dimitrovski <[email protected]>
  • Loading branch information
3 people authored Oct 14, 2024
1 parent 8cc2f0e commit 5e94047
Show file tree
Hide file tree
Showing 12 changed files with 292 additions and 267 deletions.
20 changes: 7 additions & 13 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.Get("/authorize", api.ExternalProviderRedirect)

sharedLimiter := api.limitEmailOrPhoneSentHandler(api.limiterOpts)
r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) {
r.With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(api.verifyCaptcha).Route("/signup", func(r *router) {
// rate limit per hour
limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns
limitSignups := api.limiterOpts.Signups
Expand All @@ -165,24 +164,20 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
if _, err := api.limitHandler(limitSignups)(w, r); err != nil {
return err
}
// apply shared rate limiting on email / phone
if _, err := sharedLimiter(w, r); err != nil {
return err
}
return api.Signup(w, r)
})
})
r.With(api.limitHandler(api.limiterOpts.Recover)).
With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)
With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(api.limiterOpts.Resend)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)
With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(api.limiterOpts.MagicLink)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)
With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(api.limiterOpts.Otp)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)
With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(api.limiterOpts.Token)).
With(api.verifyCaptcha).Post("/token", api.Token)
Expand All @@ -200,8 +195,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.With(api.requireAuthentication).Route("/user", func(r *router) {
r.Get("/", api.UserGet)
r.With(api.limitHandler(api.limiterOpts.User)).
With(sharedLimiter).Put("/", api.UserUpdate)
r.With(api.limitHandler(api.limiterOpts.User)).Put("/", api.UserUpdate)

r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
Expand Down
19 changes: 0 additions & 19 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"net/url"

"github.com/didip/tollbooth/v5/limiter"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/supabase/auth/internal/models"
)
Expand Down Expand Up @@ -32,7 +31,6 @@ const (
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
sharedLimiterKey = contextKey("shared_limiter")
)

// withToken adds the JWT token to the context.
Expand Down Expand Up @@ -243,20 +241,3 @@ func getExternalHost(ctx context.Context) *url.URL {
}
return obj.(*url.URL)
}

type SharedLimiter struct {
EmailLimiter *limiter.Limiter
PhoneLimiter *limiter.Limiter
}

func withLimiter(ctx context.Context, limiter *SharedLimiter) context.Context {
return context.WithValue(ctx, sharedLimiterKey, limiter)
}

func getLimiter(ctx context.Context) *SharedLimiter {
obj := ctx.Value(sharedLimiterKey)
if obj == nil {
return nil
}
return obj.(*SharedLimiter)
}
17 changes: 7 additions & 10 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"strings"
"time"

"github.com/didip/tollbooth/v5"
"github.com/supabase/auth/internal/hooks"
mail "github.com/supabase/auth/internal/mailer"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -578,15 +577,13 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User,
externalURL := getExternalHost(ctx)

// apply rate limiting before the email is sent out
if limiter := getLimiter(ctx); limiter != nil {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil {
emailRateLimitCounter.Add(
ctx,
1,
metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))),
)
return EmailRateLimitExceeded
}
if ok := a.limiterOpts.Email.Allow(); !ok {
emailRateLimitCounter.Add(
ctx,
1,
metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))),
)
return EmailRateLimitExceeded
}

if config.Hook.SendEmail.Enabled {
Expand Down
21 changes: 0 additions & 21 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,6 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
}
}

func (a *API) limitEmailOrPhoneSentHandler(limiterOptions *LimiterOptions) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
c := req.Context()
config := a.config
shouldRateLimitEmail := config.External.Email.Enabled && !config.Mailer.Autoconfirm
shouldRateLimitPhone := config.External.Phone.Enabled && !config.Sms.Autoconfirm

if shouldRateLimitEmail || shouldRateLimitPhone {
if req.Method == "PUT" || req.Method == "POST" {
// store rate limiter in request context
c = withLimiter(c, &SharedLimiter{
EmailLimiter: limiterOptions.Email,
PhoneLimiter: limiterOptions.Phone,
})
}
}

return c, nil
}
}

func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) {
t, err := a.extractBearerToken(req)
if err != nil || t == "" {
Expand Down
174 changes: 0 additions & 174 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,52 +185,6 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() {
}
}

func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
// Set up rate limit config for this test
ts.Config.RateLimitEmailSent = 5
ts.Config.RateLimitSmsSent = 5
ts.Config.External.Phone.Enabled = true

cases := []struct {
desc string
expectedErrorMsg string
requestBody map[string]interface{}
}{
{
desc: "Email rate limit exceeded",
expectedErrorMsg: "429: Email rate limit exceeded",
requestBody: map[string]interface{}{
"email": "[email protected]",
},
},
{
desc: "SMS rate limit exceeded",
expectedErrorMsg: "429: SMS rate limit exceeded",
requestBody: map[string]interface{}{
"phone": "+1233456789",
},
},
}

limiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.requestBody))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

ctx, err := limiter(w, req)
require.NoError(ts.T(), err)

// check that shared limiter is set in the request context
sharedLimiter := getLimiter(ctx)
require.NotNil(ts.T(), sharedLimiter)
})
}
}

func (ts *MiddlewareTestSuite) TestIsValidExternalHost() {
cases := []struct {
desc string
Expand Down Expand Up @@ -388,134 +342,6 @@ func (ts *MiddlewareTestSuite) TestLimitHandler() {
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
}

func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
// setup config for shared limiter and ip-based limiter to work
ts.Config.RateLimitHeader = "X-Rate-Limit"
ts.Config.External.Email.Enabled = true
ts.Config.External.Phone.Enabled = true
ts.Config.Mailer.Autoconfirm = false
ts.Config.Sms.Autoconfirm = false

ipBasedLimiter := func(max float64) *limiter.Limiter {
return tollbooth.NewLimiter(max, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
})
}

okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limiter := getLimiter(r.Context())
if limiter != nil {
var requestBody struct {
Email string `json:"email"`
Phone string `json:"phone"`
}
err := retrieveRequestParams(r, &requestBody)
require.NoError(ts.T(), err)

if requestBody.Email != "" {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil {
sendJSON(w, http.StatusTooManyRequests, HTTPError{
HTTPStatus: http.StatusTooManyRequests,
ErrorCode: ErrorCodeOverEmailSendRateLimit,
Message: "Email rate limit exceeded",
})
}
}
if requestBody.Phone != "" {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"phone_functions"}); err != nil {
sendJSON(w, http.StatusTooManyRequests, HTTPError{
HTTPStatus: http.StatusTooManyRequests,
ErrorCode: ErrorCodeOverSMSSendRateLimit,
Message: "SMS rate limit exceeded",
})
}
}
}
w.WriteHeader(http.StatusOK)
})

cases := []struct {
desc string
sharedLimiterConfig *conf.GlobalConfiguration
ipBasedLimiterConfig float64
body map[string]interface{}
expectedErrorCode string
}{
{
desc: "Exceed ip-based rate limit before shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 10,
RateLimitSmsSent: 10,
},
ipBasedLimiterConfig: 1,
body: map[string]interface{}{
"email": "[email protected]",
},
expectedErrorCode: ErrorCodeOverRequestRateLimit,
},
{
desc: "Exceed email shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 1,
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
body: map[string]interface{}{
"email": "[email protected]",
},
expectedErrorCode: ErrorCodeOverEmailSendRateLimit,
},
{
desc: "Exceed sms shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 1,
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
body: map[string]interface{}{
"phone": "123456789",
},
expectedErrorCode: ErrorCodeOverSMSSendRateLimit,
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent
ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent
lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig))
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))

// get the minimum amount to reach the threshold just before the rate limit is exceeded
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
for i := 0; i < int(threshold); i++ {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")

w := httptest.NewRecorder()
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
}

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")

// check if the rate limit is exceeded with the expected error code
w := httptest.NewRecorder()
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)

var data map[string]interface{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), c.expectedErrorCode, data["error_code"])
})
}
}

func (ts *MiddlewareTestSuite) TestIsValidAuthorizedEmail() {
ts.API.config.External.Email.AuthorizedAddresses = []string{"[email protected]"}

Expand Down
17 changes: 5 additions & 12 deletions internal/api/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ type Option interface {
}

type LimiterOptions struct {
Email *limiter.Limiter
Phone *limiter.Limiter
Email *RateLimiter
Phone *RateLimiter

Signups *limiter.Limiter
AnonymousSignIns *limiter.Limiter
Recover *limiter.Limiter
Expand All @@ -35,16 +36,8 @@ func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }
func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
o := &LimiterOptions{}

o.Email = tollbooth.NewLimiter(gc.RateLimitEmailSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})

o.Phone = tollbooth.NewLimiter(gc.RateLimitSmsSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

o.Email = newRateLimiter(gc.RateLimitEmailSent)
o.Phone = newRateLimiter(gc.RateLimitSmsSent)
o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
Expand Down
9 changes: 2 additions & 7 deletions internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"text/template"
"time"

"github.com/didip/tollbooth/v5"
"github.com/supabase/auth/internal/hooks"

"github.com/pkg/errors"
Expand Down Expand Up @@ -45,7 +44,6 @@ func formatPhoneNumber(phone string) string {

// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) {
ctx := r.Context()
config := a.config

var token *string
Expand Down Expand Up @@ -89,11 +87,8 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
// not using test OTPs
if otp == "" {
// apply rate limiting before the sms is sent out
limiter := getLimiter(ctx)
if limiter != nil {
if err := tollbooth.LimitByKeys(limiter.PhoneLimiter, []string{"phone_functions"}); err != nil {
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}
if ok := a.limiterOpts.Phone.Allow(); !ok {
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}
otp, err = crypto.GenerateOtp(config.Sms.OtpLength)
if err != nil {
Expand Down
Loading

0 comments on commit 5e94047

Please sign in to comment.