-
Notifications
You must be signed in to change notification settings - Fork 1
/
limiter.go
108 lines (100 loc) · 2.39 KB
/
limiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package ksmux
import (
"net"
"net/http"
"time"
"github.com/kamalshkeir/kmap"
"github.com/kamalshkeir/lg"
"golang.org/x/time/rate"
)
// limiterClient is the per-IP entry stored in the package-level limited map.
type limiterClient struct {
	// limiter is the token bucket shared by every request from this client.
	limiter *rate.Limiter
	// lastSeen is when this client last made a request; the cleanup goroutine
	// started by Limiter evicts entries idle for longer than BlockDuration.
	lastSeen time.Time
}
// Limiter state shared by every middleware returned from Limiter.
var (
	// limited holds one limiterClient per client IP.
	limited = kmap.New[string, *limiterClient]()
	// limiterQuit stops the cleanup goroutine started by Limiter.
	limiterQuit chan struct{}
	// limiterUsed is flipped to true once Limiter has been called.
	limiterUsed bool
)

// Fallback values Limiter substitutes for unset ConfigLimiter fields.
var (
	defCheckEvery    = time.Minute * 5
	defBlockDuration = time.Minute * 10
	defRateEvery     = time.Minute * 10
	defBurstsN       = 100
	defMessage       = "TOO MANY REQUESTS"
)
// ConfigLimiter configures the per-IP rate-limiting middleware built by Limiter.
// Zero-valued fields are replaced with the package defaults noted below.
type ConfigLimiter struct {
	// Message is the response body sent with 429 Too Many Requests.
	Message string // default "TOO MANY REQUESTS"
	// RateEvery is the interval at which each client regains one request token.
	RateEvery time.Duration // default 10 min
	// BurstsN is the maximum number of tokens a client's bucket can hold,
	// i.e. how many requests it may make back-to-back.
	BurstsN int // default 100
	// CheckEvery is how often idle clients are swept from memory.
	CheckEvery time.Duration // default 5 min
	// BlockDuration is how long a client may stay idle before its entry,
	// including its limiter state, is discarded.
	BlockDuration time.Duration // default 10 min
}
// Limiter returns a middleware that rate-limits requests per client IP.
//
// Each IP gets its own token bucket that refills one token every RateEvery and
// holds at most BurstsN tokens. A request that finds its bucket empty receives
// 429 Too Many Requests with Message as the body. A background goroutine wakes
// every CheckEvery and drops clients not seen for longer than BlockDuration; it
// exits when the package-level limiterQuit channel is closed or signalled.
//
// conf may be nil. Zero or negative fields are replaced by the package
// defaults (defCheckEvery, defBlockDuration, defRateEvery, defBurstsN,
// defMessage). The caller's ConfigLimiter is never modified.
//
// Every middleware built by this function shares the package-level client
// table, and each call replaces limiterQuit.
func Limiter(conf *ConfigLimiter) func(http.Handler) http.Handler {
	// Work on a copy so defaults are not written back into the caller's struct.
	var cfg ConfigLimiter
	if conf != nil {
		cfg = *conf
	}
	if cfg.CheckEvery <= 0 {
		// time.NewTicker panics on a non-positive duration.
		cfg.CheckEvery = defCheckEvery
	}
	if cfg.BlockDuration <= 0 {
		cfg.BlockDuration = defBlockDuration
	}
	if cfg.RateEvery <= 0 {
		// rate.Every turns a non-positive interval into an unlimited rate.
		cfg.RateEvery = defRateEvery
	}
	if cfg.BurstsN <= 0 {
		// A non-positive burst would reject every request.
		cfg.BurstsN = defBurstsN
	}
	if cfg.Message == "" {
		cfg.Message = defMessage
	}

	ticker := time.NewTicker(cfg.CheckEvery)
	limiterQuit = make(chan struct{})
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// Collect first and delete after Range returns, so Delete is
				// never called while Range may still be iterating the map.
				var stale []string
				limited.Range(func(key string, value *limiterClient) bool {
					if time.Since(value.lastSeen) > cfg.BlockDuration {
						stale = append(stale, key)
					}
					return true
				})
				for _, key := range stale {
					// Re-check: the client may have come back since Range ran.
					if client, ok := limited.Get(key); ok && time.Since(client.lastSeen) > cfg.BlockDuration {
						limited.Delete(key)
					}
				}
			// NOTE(review): limiterQuit is the package-level variable, re-read on
			// every loop iteration (not captured), so shutdown code that closes
			// the newest channel is preserved; confirm how it is closed elsewhere.
			case <-limiterQuit:
				return
			}
		}
	}()
	limiterUsed = true

	return func(handler http.Handler) http.Handler {
		return Handler(func(c *Context) {
			remote := c.Request.RemoteAddr
			ip, _, err := net.SplitHostPort(remote)
			if err != nil {
				// RemoteAddr may already be a bare IP (e.g. rewritten by a
				// real-IP middleware); only fail when it is not an address.
				if net.ParseIP(remote) == nil {
					lg.CheckError(err)
					c.SetStatus(http.StatusInternalServerError)
					return
				}
				ip = remote
			}

			// NOTE(review): Get followed by Set is not atomic; two concurrent
			// first requests from one IP may each create a limiter, and the
			// last Set wins.
			var ll *rate.Limiter
			if client, found := limited.Get(ip); found {
				ll = client.limiter
			} else {
				ll = rate.NewLimiter(rate.Every(cfg.RateEvery), cfg.BurstsN)
			}
			// Store a fresh record instead of mutating the shared one, so the
			// cleanup goroutine never reads lastSeen while it is being written.
			limited.Set(ip, &limiterClient{
				limiter:  ll,
				lastSeen: time.Now(),
			})

			if !ll.Allow() {
				c.Status(http.StatusTooManyRequests).Text(cfg.Message)
				return
			}
			handler.ServeHTTP(c.ResponseWriter, c.Request)
		})
	}
}