diff --git a/backend/internal/middleware/rate_limit.go b/backend/internal/middleware/rate_limit.go index 2e112f7f..5f639a7f 100644 --- a/backend/internal/middleware/rate_limit.go +++ b/backend/internal/middleware/rate_limit.go @@ -3,11 +3,13 @@ package middleware import ( "KonferCA/SPUR/internal/v1/v1_common" "net/http" + "os" "sync" "time" "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" + "KonferCA/SPUR/common" ) /* @@ -87,7 +89,26 @@ Example (auth): func (rl *RateLimiter) RateLimit() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - ip := c.RealIP() + var ip string + + env := os.Getenv("APP_ENV") + if env == common.TEST_ENV || env == common.DEVELOPMENT_ENV { + ip = c.Request().Header.Get("CF-Connecting-IP") + if ip == "" { + ip = c.Request().Header.Get("X-Real-IP") + if ip == "" { + // Fallback to direct IP in test/dev + ip = c.RealIP() + } + } + } else { + ip = c.Request().Header.Get("CF-Connecting-IP") + } + + if ip == "" { + return echo.NewHTTPError(http.StatusForbidden, "missing client IP") + } + now := time.Now() rl.mu.Lock() diff --git a/backend/internal/tests/rate_limit_test.go b/backend/internal/tests/rate_limit_test.go index d257b948..0a7d96ec 100644 --- a/backend/internal/tests/rate_limit_test.go +++ b/backend/internal/tests/rate_limit_test.go @@ -47,7 +47,7 @@ func TestRateLimiter(t *testing.T) { t.Run("blocks requests over limit", func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.Header.Set("X-Real-IP", "192.168.1.100") + req.Header.Set("CF-Connecting-IP", "192.168.1.100") limiter := middleware.NewRateLimiter(&middleware.RateLimiterConfig{ Requests: 1, @@ -93,7 +93,7 @@ func TestRateLimiter(t *testing.T) { ips := []string{"192.168.1.1", "192.168.1.2"} for _, ip := range ips { req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.Header.Set("X-Real-IP", ip) + req.Header.Set("CF-Connecting-IP", ip) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -113,13 +113,13 @@ func TestRateLimiter(t *testing.T) { t.Run("allows requests after window reset", func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.Header.Set("X-Real-IP", "192.168.1.200") + req.Header.Set("CF-Connecting-IP", "192.168.1.200") limiter := middleware.NewRateLimiter(&middleware.RateLimiterConfig{ Requests: 1, Window: 100 * time.Millisecond, - BlockPeriod: 50 * time.Millisecond, - MaxBlocks: 2, + BlockPeriod: 50 * time.Millisecond, + MaxBlocks: 2, }) handler := func(c echo.Context) error { @@ -154,7 +154,7 @@ func TestRateLimiter(t *testing.T) { t.Run("applies progressive blocking", func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.Header.Set("X-Real-IP", "192.168.1.300") + req.Header.Set("CF-Connecting-IP", "192.168.1.300") limiter := middleware.NewRateLimiter(&middleware.RateLimiterConfig{ Requests: 1,