Skip to content

Commit

Permalink
Merge pull request #51 from KonferCA/45-be-implement-rate-limiting-fo…
Browse files Browse the repository at this point in the history
…r-api-endpoints

Feature/9/add-rate-limits
  • Loading branch information
AmirAgassi authored Nov 12, 2024
2 parents 60165e3 + 63c8259 commit 5da1e0a
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 4 deletions.
107 changes: 107 additions & 0 deletions internal/middleware/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package middleware

import (
"net/http"
"sync"
"time"

"github.com/labstack/echo/v4"
)

type visitor struct {
lastSeen time.Time
count int
blockTime time.Time
}

type RateLimiter struct {
visitors map[string]*visitor
mu sync.RWMutex
limit int
window time.Duration
blockPeriod time.Duration
}

func NewRateLimiter(limit int, window, blockPeriod time.Duration) *RateLimiter {
return &RateLimiter{
visitors: make(map[string]*visitor),
limit: limit,
window: window,
blockPeriod: blockPeriod,
}
}

func (rl *RateLimiter) isBlocked(ip string) bool {
rl.mu.RLock()
v, exists := rl.visitors[ip]
rl.mu.RUnlock()

if !exists {
return false
}

return time.Now().Before(v.blockTime)
}

func (rl *RateLimiter) RateLimit() echo.MiddlewareFunc {
// cleanup old entries every minute
go func() {
for {
time.Sleep(time.Minute)
rl.mu.Lock()
for ip, v := range rl.visitors {
if time.Since(v.lastSeen) > rl.window && time.Now().After(v.blockTime) {
delete(rl.visitors, ip)
}
}
rl.mu.Unlock()
}
}()

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ip := c.RealIP()

// check if ip is blocked
if rl.isBlocked(ip) {
return echo.NewHTTPError(http.StatusTooManyRequests, "too many requests, please try again later")
}

rl.mu.Lock()
v, exists := rl.visitors[ip]
if !exists {
rl.visitors[ip] = &visitor{
lastSeen: time.Now(),
count: 1,
}
} else {
// reset count if window has passed
if time.Since(v.lastSeen) > rl.window {
v.count = 1
v.lastSeen = time.Now()
} else {
v.count++
// block if limit exceeded
if v.count > rl.limit {
v.blockTime = time.Now().Add(rl.blockPeriod)
rl.mu.Unlock()
return echo.NewHTTPError(http.StatusTooManyRequests, "too many requests, please try again later")
}
}
}
rl.mu.Unlock()

return next(c)
}
}
}

// NewTestRateLimiter makes a cooler rate limiter with shorter durations for testing
func NewTestRateLimiter(limit int) *RateLimiter {
return &RateLimiter{
visitors: make(map[string]*visitor),
limit: limit,
window: 100 * time.Millisecond, // 100ms window for testing
blockPeriod: 200 * time.Millisecond, // 200ms block for testing
}
}
1 change: 1 addition & 0 deletions internal/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

func (s *Server) setupAuthRoutes() {
auth := s.apiV1.Group("/auth")
auth.Use(s.authLimiter.RateLimit()) // special rate limit for auth routes
auth.POST("/signup", s.handleSignup)
auth.POST("/signin", s.handleSignin)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/server/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestAuth(t *testing.T) {
os.Setenv("DB_SSLMODE", "disable")

// create server
s, err := New()
s, err := New(true)
if err != nil {
t.Fatalf("failed to create server: %v", err)
}
Expand Down
29 changes: 27 additions & 2 deletions internal/server/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"fmt"
"os"
"time"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/labstack/echo/v4"
Expand All @@ -17,11 +18,13 @@ type Server struct {
queries *db.Queries
echoInstance *echo.Echo
apiV1 *echo.Group
authLimiter *middleware.RateLimiter
apiLimiter *middleware.RateLimiter
}

// Create a new Server instance and registers all routes and middlewares.
// Initialize database pool connection.
func New() (*Server, error) {
func New(testing bool) (*Server, error) {
connStr := fmt.Sprintf(
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
os.Getenv("DB_HOST"),
Expand All @@ -41,10 +44,30 @@ func New() (*Server, error) {

e := echo.New()

e.HTTPErrorHandler = globalErrorHandler
// create rate limiters
var authLimiter, apiLimiter *middleware.RateLimiter

if testing {
authLimiter = middleware.NewTestRateLimiter(20)
apiLimiter = middleware.NewTestRateLimiter(100)
} else {
authLimiter = middleware.NewRateLimiter(
20, // 20 requests
5*time.Minute, // per 5 minutes
15*time.Minute, // block for 15 minutes if exceeded
)
apiLimiter = middleware.NewRateLimiter(
100, // 100 requests
time.Minute, // per minute
5*time.Minute, // block for 5 minutes if exceeded
)
}

// setup error handler and middlewares
e.HTTPErrorHandler = globalErrorHandler
e.Use(middleware.Logger())
e.Use(echoMiddleware.Recover())
e.Use(apiLimiter.RateLimit()) // global rate limit

customValidator := NewCustomValidator()
fmt.Printf("Initializing validator: %+v\n", customValidator)
Expand All @@ -54,6 +77,8 @@ func New() (*Server, error) {
DBPool: pool,
queries: queries,
echoInstance: e,
authLimiter: authLimiter,
apiLimiter: apiLimiter,
}

// setup api routes
Expand Down
15 changes: 15 additions & 0 deletions internal/server/test_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package server

import (
"github.com/labstack/echo/v4"
)

// TestIPMiddleware adds a fake IP (for testing rate limits)
func TestIPMiddleware(ip string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Request().Header.Set("X-Real-IP", ip)
return next(c)
}
}
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339})
}

s, err := server.New()
s, err := server.New(false)
if err != nil {
log.Fatal().Err(err).Msg("failed to initialized server")
}
Expand Down

0 comments on commit 5da1e0a

Please sign in to comment.