Skip to content

Commit

Permalink
add rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
totegamma committed Oct 15, 2024
1 parent 058bb25 commit 120ddcc
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 11 deletions.
4 changes: 3 additions & 1 deletion cmd/gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ func main() {

proxy.Transport = otelhttp.NewTransport(http.DefaultTransport)

middlewares := []echo.MiddlewareFunc{}
middlewares := []echo.MiddlewareFunc{
authService.RateLimiter(service.RateLimitConf),
}
if service.InjectCors {
middlewares = append(middlewares, cors)
}
Expand Down
15 changes: 9 additions & 6 deletions cmd/gateway/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"github.com/go-yaml/yaml"
"log"
"os"

"github.com/totegamma/concurrent/core"
)

type GatewayConfig struct {
Expand All @@ -15,12 +17,13 @@ type ServiceInfo struct {
}

type Service struct {
Name string `yaml:"name"`
Host string `yaml:"host"`
Port int `yaml:"port"`
Path string `yaml:"path"`
PreservePath bool `yaml:"preservePath"`
InjectCors bool `yaml:"injectCors"`
Name string `yaml:"name"`
Host string `yaml:"host"`
Port int `yaml:"port"`
Path string `yaml:"path"`
PreservePath bool `yaml:"preservePath"`
InjectCors bool `yaml:"injectCors"`
RateLimitConf core.RateLimitConfigMap `yaml:"rateLimit"`
}

// Load loads concurrent config from given path
Expand Down
1 change: 1 addition & 0 deletions core/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type AssociationService interface {
type AuthService interface {
IssuePassport(ctx context.Context, requester string, key []Key) (string, error)
IdentifyIdentity(next echo.HandlerFunc) echo.HandlerFunc
RateLimiter(configMap RateLimitConfigMap) echo.MiddlewareFunc
}

type DomainService interface {
Expand Down
7 changes: 7 additions & 0 deletions core/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ type SyncStatus struct {
LatestOnDB time.Time `json:"latestOnDB"`
Progress string `json:"progress"`
}

type RateLimitConfig struct {
BucketSize int `yaml:"bucketSize"`
RefillSpan float64 `yaml:"refillSpan"`
}

type RateLimitConfigMap map[string]RateLimitConfig
2 changes: 1 addition & 1 deletion wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

157 changes: 157 additions & 0 deletions x/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"net/http"
"strconv"
"strings"
"time"

"github.com/labstack/echo/v4"
"github.com/redis/go-redis/v9"
"github.com/totegamma/concurrent/core"
"github.com/totegamma/concurrent/x/jwt"
"github.com/totegamma/concurrent/x/key"
Expand Down Expand Up @@ -449,3 +451,158 @@ func Recaptcha(validator *recaptcha.ReCAPTCHA) echo.MiddlewareFunc {
}
}
}

func (s *service) RateLimiter(configMap core.RateLimitConfigMap) echo.MiddlewareFunc {

routerEcho := echo.New()
dummyFunc := func(c echo.Context) error {
return nil
}

if configMap == nil {
configMap = core.RateLimitConfigMap{}
}

if _, ok := configMap["DEFAULT"]; !ok {
configMap["DEFAULT"] = core.RateLimitConfig{
BucketSize: 1000,
RefillSpan: 1,
}
}

core.JsonPrint("RateLimitConfigMap", configMap)

for path := range configMap {
if path == "DEFAULT" {
continue
}

splitter := strings.Index(path, ":")
if splitter == -1 {
routerEcho.Any(path, dummyFunc)
continue
}

method := path[:splitter]
path = path[splitter+1:]

switch method {
case "GET":
routerEcho.GET(path, dummyFunc)
case "POST":
routerEcho.POST(path, dummyFunc)
case "PUT":
routerEcho.PUT(path, dummyFunc)
case "DELETE":
routerEcho.DELETE(path, dummyFunc)
case "PATCH":
routerEcho.PATCH(path, dummyFunc)
case "OPTIONS":
routerEcho.OPTIONS(path, dummyFunc)
case "HEAD":
routerEcho.HEAD(path, dummyFunc)
case "CONNECT":
routerEcho.CONNECT(path, dummyFunc)
case "TRACE":
routerEcho.TRACE(path, dummyFunc)
default:
fmt.Println("Invalid method")
}
}

resolvePath := func(c echo.Context) string {
req := http.Request{}
ctx := routerEcho.NewContext(&req, nil)
routerEcho.Router().Find(c.Request().Method, c.Request().URL.Path, ctx)
return fmt.Sprintf("%s:%s", c.Request().Method, ctx.Path())
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {

ctx := c.Request().Context()
path := resolvePath(c)

config, ok := configMap[path]
if !ok {
config = configMap["DEFAULT"]
path = "DEFAULT"
}

requester, ok := ctx.Value(core.RequesterIdCtxKey).(string)
if !ok {
requester = c.RealIP()
}

key := "rate_limit:" + requester + ":" + path

// Get the current value of the bucket
val, err := s.rdb.Get(ctx, key).Result()
if err != nil {
// If the key does not exist, create it
value := fmt.Sprintf("%d;%d", time.Now().UnixMilli(), config.BucketSize)
if err == redis.Nil {
s.rdb.Set(
ctx,
key,
value,
time.Second*time.Duration(config.BucketSize)*time.Duration(config.RefillSpan),
)

c.Response().Header().Set("X-RateLimit-Limit", strconv.Itoa(config.BucketSize))
c.Response().Header().Set("X-RateLimit-Remaining", strconv.Itoa(config.BucketSize))

// Continue processing the request
return next(c)
} else {
return err
}
}

// Parse the value of the bucket
split := strings.Split(val, ";")
timestamp, _ := strconv.ParseInt(split[0], 10, 64)
lastRefill := time.UnixMilli(timestamp)

bucketCount, _ := strconv.Atoi(split[1])

// Calculate the number of requests that should have been refilled
now := time.Now()
elapsed := now.Sub(lastRefill)
refillCount := int(elapsed.Seconds() / config.RefillSpan)
consumedTime := float64(refillCount) * config.RefillSpan
nextRefill := lastRefill.Add(time.Second * time.Duration(consumedTime))

// Refill the bucket
bucketCount += refillCount

// If the bucket is full, clip it
if bucketCount > config.BucketSize {
bucketCount = config.BucketSize
}

// subtract one from the bucket
bucketCount -= 1

// If the bucket is empty, return a 429
if bucketCount < 0 {
return c.String(http.StatusTooManyRequests, "Rate limit exceeded\n")
}

// Update the bucket
value := fmt.Sprintf("%d;%d", nextRefill.UnixMilli(), bucketCount)
s.rdb.Set(
ctx,
key,
value,
time.Second*time.Duration(config.BucketSize)*time.Duration(config.RefillSpan),
)

c.Response().Header().Set("X-RateLimit-Limit", strconv.Itoa(config.BucketSize))
c.Response().Header().Set("X-RateLimit-Remaining", strconv.Itoa(bucketCount))

// Continue processing the request
return next(c)
}
}
}
4 changes: 2 additions & 2 deletions x/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestLocalRootSuccess(t *testing.T) {
FQDN: "local.example.com",
}

service := NewService(config, mockEntity, mockDomain, mockKey, mockPolicy)
service := NewService(nil, config, mockEntity, mockDomain, mockKey, mockPolicy)

c, req, rec, traceID := testutil.CreateHttpRequest()

Expand Down Expand Up @@ -134,7 +134,7 @@ func TestRemoteRootSuccess(t *testing.T) {
FQDN: "local.example.com",
}

service := NewService(config, mockEntity, mockDomain, mockKey, mockPolicy)
service := NewService(nil, config, mockEntity, mockDomain, mockKey, mockPolicy)
c, req, rec, traceID := testutil.CreateHttpRequest()

fmt.Print("traceID: ", traceID, "\n")
Expand Down
6 changes: 5 additions & 1 deletion x/auth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
"fmt"
"time"

"github.com/redis/go-redis/v9"

"github.com/totegamma/concurrent/core"
)

type service struct {
rdb *redis.Client
config core.Config
entity core.EntityService
domain core.DomainService
Expand All @@ -21,13 +24,14 @@ type service struct {

// NewService creates a new auth service
func NewService(
rdb *redis.Client,
config core.Config,
entity core.EntityService,
domain core.DomainService,
key core.KeyService,
policy core.PolicyService,
) core.AuthService {
return &service{config, entity, domain, key, policy}
return &service{rdb, config, entity, domain, key, policy}
}

// GetPassport takes client signed JWT and returns server signed JWT
Expand Down

0 comments on commit 120ddcc

Please sign in to comment.