diff --git a/docs/middlewares.md b/docs/middlewares.md index 80d0dc1..aefb030 100644 --- a/docs/middlewares.md +++ b/docs/middlewares.md @@ -190,6 +190,43 @@ func main() { ``` +## global concurrent limiter + +全局的并发请求限制,可以用于控制应用的并发请求量。 + +**Example** +```go +package main + +import ( + "bytes" + "sync" + "time" + + "github.com/vicanso/elton" + "github.com/vicanso/elton/middleware" +) + +func main() { + + e := elton.New() + e.Use(middleware.NewGlobalConcurrentLimiter(middleware.GlobalConcurrentLimiterConfig{ + Max: 1000, + })) + + e.POST("/login", func(c *elton.Context) (err error) { + time.Sleep(3 * time.Second) + c.BodyBuffer = bytes.NewBufferString("hello world") + return + }) + + err := e.ListenAndServe(":3000") + if err != nil { + panic(err) + } +} +``` + ## concurrent limiter 并发请求限制,可以通过指定请求的参数,如IP、query的字段或者body等获取,限制同时并发性的提交请求,主要用于避免相同的请求多次提交。指定的Key分为以下几种: diff --git a/middleware/concurrent_limiter.go b/middleware/concurrent_limiter.go index a6f2777..e9b3826 100644 --- a/middleware/concurrent_limiter.go +++ b/middleware/concurrent_limiter.go @@ -26,6 +26,7 @@ import ( "errors" "net/http" "strings" + "sync/atomic" "github.com/tidwall/gjson" "github.com/vicanso/elton" @@ -39,6 +40,12 @@ var ( Message: "submit too frequently", Category: ErrConcurrentLimiterCategory, } + // ErrTooManyRequests too many request + ErrTooManyRequests = &hes.Error{ + StatusCode: http.StatusTooManyRequests, + Message: "Too Many Requests", + Category: ErrConcurrentLimiterCategory, + } ErrNotAllowEmpty = &hes.Error{ StatusCode: http.StatusBadRequest, Message: "empty value is not allowed", @@ -78,9 +85,14 @@ type ( Body bool IP bool } + // GlobalConcurrentLimiterConfig + GlobalConcurrentLimiterConfig struct { + Skipper elton.Skipper + Max uint32 + } ) -// New create a concurrent limiter middleware +// NewConcurrentLimiter create a concurrent limiter middleware func NewConcurrentLimiter(config ConcurrentLimiterConfig) elton.Handler { if config.Lock == nil { @@ -177,3 +189,17 @@ func NewConcurrentLimiter(config ConcurrentLimiterConfig) elton.Handler { return c.Next() } } + +// NewGlobalConcurrentLimiter create a new global concurrent limiter +func NewGlobalConcurrentLimiter(config GlobalConcurrentLimiterConfig) elton.Handler { + var count uint32 + return func(c *elton.Context) (err error) { + value := atomic.AddUint32(&count, 1) + defer atomic.AddUint32(&count, ^uint32(0)) + if value >= config.Max { + err = ErrTooManyRequests + return + } + return c.Next() + } +} diff --git a/middleware/concurrent_limiter_test.go b/middleware/concurrent_limiter_test.go index 8a9af16..3445a7e 100644 --- a/middleware/concurrent_limiter_test.go +++ b/middleware/concurrent_limiter_test.go @@ -136,3 +136,27 @@ func TestConcurrentLimiter(t *testing.T) { assert.Equal(ErrNotAllowEmpty, err) }) } + +func TestGlobalConcurrentLimiter(t *testing.T) { + assert := assert.New(t) + fn := NewGlobalConcurrentLimiter(GlobalConcurrentLimiterConfig{ + Max: 1, + }) + req := httptest.NewRequest("POST", "/users/login?type=1", nil) + resp := httptest.NewRecorder() + c := elton.NewContext(resp, req) + err := fn(c) + assert.Equal(ErrTooManyRequests, err) + + fn = NewGlobalConcurrentLimiter(GlobalConcurrentLimiterConfig{ + Max: 2, + }) + done := false + c.Next = func() error { + done = true + return nil + } + err = fn(c) + assert.Nil(err) + assert.True(done) +}