Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/9/add-rate-limits #51

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions internal/middleware/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package middleware

import (
"net/http"
"sync"
"time"

"github.com/labstack/echo/v4"
)

type visitor struct {
lastSeen time.Time
count int
blockTime time.Time
}

type RateLimiter struct {
visitors map[string]*visitor
mu sync.RWMutex
limit int
window time.Duration
blockPeriod time.Duration
}

func NewRateLimiter(limit int, window, blockPeriod time.Duration) *RateLimiter {
return &RateLimiter{
visitors: make(map[string]*visitor),
limit: limit,
window: window,
blockPeriod: blockPeriod,
}
}

func (rl *RateLimiter) isBlocked(ip string) bool {
rl.mu.RLock()
v, exists := rl.visitors[ip]
rl.mu.RUnlock()

if !exists {
return false
}

return time.Now().Before(v.blockTime)
}

func (rl *RateLimiter) RateLimit() echo.MiddlewareFunc {
// cleanup old entries every minute
go func() {
for {
time.Sleep(time.Minute)
rl.mu.Lock()
for ip, v := range rl.visitors {
if time.Since(v.lastSeen) > rl.window && time.Now().After(v.blockTime) {
delete(rl.visitors, ip)
}
}
rl.mu.Unlock()
}
}()

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ip := c.RealIP()

// check if ip is blocked
if rl.isBlocked(ip) {
return echo.NewHTTPError(http.StatusTooManyRequests, "too many requests, please try again later")
}

rl.mu.Lock()
v, exists := rl.visitors[ip]
if !exists {
rl.visitors[ip] = &visitor{
lastSeen: time.Now(),
count: 1,
}
} else {
// reset count if window has passed
if time.Since(v.lastSeen) > rl.window {
v.count = 1
v.lastSeen = time.Now()
} else {
v.count++
// block if limit exceeded
if v.count > rl.limit {
v.blockTime = time.Now().Add(rl.blockPeriod)
rl.mu.Unlock()
return echo.NewHTTPError(http.StatusTooManyRequests, "too many requests, please try again later")
}
}
}
rl.mu.Unlock()

return next(c)
}
}
}

// NewTestRateLimiter makes a cooler rate limiter with shorter durations for testing
func NewTestRateLimiter(limit int) *RateLimiter {
return &RateLimiter{
visitors: make(map[string]*visitor),
limit: limit,
window: 100 * time.Millisecond, // 100ms window for testing
blockPeriod: 200 * time.Millisecond, // 200ms block for testing
}
}
1 change: 1 addition & 0 deletions internal/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

func (s *Server) setupAuthRoutes() {
auth := s.apiV1.Group("/auth")
auth.Use(s.authLimiter.RateLimit()) // special rate limit for auth routes
auth.POST("/signup", s.handleSignup)
auth.POST("/signin", s.handleSignin)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/server/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestAuth(t *testing.T) {
os.Setenv("DB_SSLMODE", "disable")

// create server
s, err := New()
s, err := New(true)
if err != nil {
t.Fatalf("failed to create server: %v", err)
}
Expand Down
29 changes: 27 additions & 2 deletions internal/server/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"fmt"
"os"
"time"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/labstack/echo/v4"
Expand All @@ -17,11 +18,13 @@ type Server struct {
queries *db.Queries
echoInstance *echo.Echo
apiV1 *echo.Group
authLimiter *middleware.RateLimiter
apiLimiter *middleware.RateLimiter
}

// Create a new Server instance and registers all routes and middlewares.
// Initialize database pool connection.
func New() (*Server, error) {
func New(testing bool) (*Server, error) {
connStr := fmt.Sprintf(
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
os.Getenv("DB_HOST"),
Expand All @@ -41,10 +44,30 @@ func New() (*Server, error) {

e := echo.New()

e.HTTPErrorHandler = globalErrorHandler
// create rate limiters
var authLimiter, apiLimiter *middleware.RateLimiter

if testing {
authLimiter = middleware.NewTestRateLimiter(20)
apiLimiter = middleware.NewTestRateLimiter(100)
} else {
authLimiter = middleware.NewRateLimiter(
20, // 20 requests
5*time.Minute, // per 5 minutes
15*time.Minute, // block for 15 minutes if exceeded
)
apiLimiter = middleware.NewRateLimiter(
100, // 100 requests
time.Minute, // per minute
5*time.Minute, // block for 5 minutes if exceeded
)
}

// setup error handler and middlewares
e.HTTPErrorHandler = globalErrorHandler
e.Use(middleware.Logger())
e.Use(echoMiddleware.Recover())
e.Use(apiLimiter.RateLimit()) // global rate limit

customValidator := NewCustomValidator()
fmt.Printf("Initializing validator: %+v\n", customValidator)
Expand All @@ -54,6 +77,8 @@ func New() (*Server, error) {
DBPool: pool,
queries: queries,
echoInstance: e,
authLimiter: authLimiter,
apiLimiter: apiLimiter,
}

// setup api routes
Expand Down
15 changes: 15 additions & 0 deletions internal/server/test_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package server

import (
"github.com/labstack/echo/v4"
)

// TestIPMiddleware adds a fake IP (for testing rate limits)
func TestIPMiddleware(ip string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Request().Header.Set("X-Real-IP", ip)
return next(c)
}
}
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339})
}

s, err := server.New()
s, err := server.New(false)
if err != nil {
log.Fatal().Err(err).Msg("failed to initialized server")
}
Expand Down
Loading