From cde91ba11f77c9e17d1f9c1bad5fbc5d8fef0418 Mon Sep 17 00:00:00 2001 From: AmirAgassi <33383085+AmirAgassi@users.noreply.github.com> Date: Mon, 11 Nov 2024 21:22:56 -0500 Subject: [PATCH] Add rate limits --- internal/middleware/ratelimit.go | 107 +++++++++++++++++++++++++++++++ internal/server/auth.go | 1 + internal/server/auth_test.go | 2 +- internal/server/index.go | 28 +++++++- internal/server/test_helpers.go | 15 +++++ main.go | 2 +- 6 files changed, 152 insertions(+), 3 deletions(-) create mode 100644 internal/middleware/ratelimit.go create mode 100644 internal/server/test_helpers.go diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go new file mode 100644 index 0000000..6b86df1 --- /dev/null +++ b/internal/middleware/ratelimit.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/labstack/echo/v4" +) + +type visitor struct { + lastSeen time.Time + count int + blockTime time.Time +} + +type RateLimiter struct { + visitors map[string]*visitor + mu sync.RWMutex + limit int + window time.Duration + blockPeriod time.Duration +} + +func NewRateLimiter(limit int, window, blockPeriod time.Duration) *RateLimiter { + return &RateLimiter{ + visitors: make(map[string]*visitor), + limit: limit, + window: window, + blockPeriod: blockPeriod, + } +} + +func (rl *RateLimiter) isBlocked(ip string) bool { + rl.mu.RLock() + v, exists := rl.visitors[ip] + rl.mu.RUnlock() + + if !exists { + return false + } + + return time.Now().Before(v.blockTime) +} + +func (rl *RateLimiter) RateLimit() echo.MiddlewareFunc { + // cleanup old entries every minute + go func() { + for { + time.Sleep(time.Minute) + rl.mu.Lock() + for ip, v := range rl.visitors { + if time.Since(v.lastSeen) > rl.window && time.Now().After(v.blockTime) { + delete(rl.visitors, ip) + } + } + rl.mu.Unlock() + } + }() + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + ip := c.RealIP() + + // check if ip is blocked + if rl.isBlocked(ip) { + return echo.NewHTTPError(http.StatusTooManyRequests, "too many requests, please try again later") + } + + rl.mu.Lock() + v, exists := rl.visitors[ip] + if !exists { + rl.visitors[ip] = &visitor{ + lastSeen: time.Now(), + count: 1, + } + } else { + // reset count if window has passed + if time.Since(v.lastSeen) > rl.window { + v.count = 1 + v.lastSeen = time.Now() + } else { + v.count++ + // block if limit exceeded + if v.count > rl.limit { + v.blockTime = time.Now().Add(rl.blockPeriod) + rl.mu.Unlock() + return echo.NewHTTPError(http.StatusTooManyRequests, "too many requests, please try again later") + } + } + } + rl.mu.Unlock() + + return next(c) + } + } +} + +// NewTestRateLimiter makes a cooler rate limiter with shorter durations for testing +func NewTestRateLimiter(limit int) *RateLimiter { + return &RateLimiter{ + visitors: make(map[string]*visitor), + limit: limit, + window: 100 * time.Millisecond, // 100ms window for testing + blockPeriod: 200 * time.Millisecond, // 200ms block for testing + } +} diff --git a/internal/server/auth.go b/internal/server/auth.go index b6dd3d6..e9e842d 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -13,6 +13,7 @@ import ( func (s *Server) setupAuthRoutes() { auth := s.apiV1.Group("/auth") + auth.Use(s.authLimiter.RateLimit()) // special rate limit for auth routes auth.POST("/signup", s.handleSignup) auth.POST("/signin", s.handleSignin) } diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go index edf3b5b..7d8254f 100644 --- a/internal/server/auth_test.go +++ b/internal/server/auth_test.go @@ -24,7 +24,7 @@ func TestAuth(t *testing.T) { os.Setenv("DB_SSLMODE", "disable") // create server - s, err := New() + s, err := New(true) if err != nil { t.Fatalf("failed to create server: %v", err) } diff --git a/internal/server/index.go b/internal/server/index.go index f4f84a6..888304f 100644 --- a/internal/server/index.go +++ b/internal/server/index.go @@ -3,6 +3,7 @@ package server import ( "fmt" "os" + "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/labstack/echo/v4" @@ -17,11 +18,13 @@ type Server struct { queries *db.Queries echoInstance *echo.Echo apiV1 *echo.Group + authLimiter *middleware.RateLimiter + apiLimiter *middleware.RateLimiter } // Create a new Server instance and registers all routes and middlewares. // Initialize database pool connection. -func New() (*Server, error) { +func New(testing bool) (*Server, error) { connStr := fmt.Sprintf( "host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", os.Getenv("DB_HOST"), @@ -41,8 +44,29 @@ func New() (*Server, error) { e := echo.New() + // create rate limiters + var authLimiter, apiLimiter *middleware.RateLimiter + + if testing { + authLimiter = middleware.NewTestRateLimiter(20) + apiLimiter = middleware.NewTestRateLimiter(100) + } else { + authLimiter = middleware.NewRateLimiter( + 20, // 20 requests + 5*time.Minute, // per 5 minutes + 15*time.Minute, // block for 15 minutes if exceeded + ) + apiLimiter = middleware.NewRateLimiter( + 100, // 100 requests + time.Minute, // per minute + 5*time.Minute, // block for 5 minutes if exceeded + ) + } + + // setup middlewares e.Use(middleware.Logger()) e.Use(echoMiddleware.Recover()) + e.Use(apiLimiter.RateLimit()) // global rate limit customValidator := NewCustomValidator() fmt.Printf("Initializing validator: %+v\n", customValidator) @@ -52,6 +76,8 @@ func New() (*Server, error) { DBPool: pool, queries: queries, echoInstance: e, + authLimiter: authLimiter, + apiLimiter: apiLimiter, } // setup api routes diff --git a/internal/server/test_helpers.go b/internal/server/test_helpers.go new file mode 100644 index 0000000..d9069a9 --- /dev/null +++ b/internal/server/test_helpers.go @@ -0,0 +1,15 @@ +package server + +import ( + "github.com/labstack/echo/v4" +) + +// TestIPMiddleware adds a fake IP (for testing rate limits) +func TestIPMiddleware(ip string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Request().Header.Set("X-Real-IP", ip) + return next(c) + } + } +} diff --git a/main.go b/main.go index b81c499..d58be26 100644 --- a/main.go +++ b/main.go @@ -24,7 +24,7 @@ func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}) } - s, err := server.New() + s, err := server.New(false) if err != nil { log.Fatal().Err(err).Msg("failed to initialized server") }