Skip to content

Commit

Permalink
Add Cloudflare IP support to rate limit middleware (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
SherRao authored Dec 27, 2024
2 parents 19f6523 + c9799ca commit 9428216
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
23 changes: 22 additions & 1 deletion backend/internal/middleware/rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package middleware
import (
"KonferCA/SPUR/internal/v1/v1_common"
"net/http"
"os"
"sync"
"time"

"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"KonferCA/SPUR/common"
)

/*
Expand Down Expand Up @@ -87,7 +89,26 @@ Example (auth):
func (rl *RateLimiter) RateLimit() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ip := c.RealIP()
var ip string

env := os.Getenv("APP_ENV")
if env == common.TEST_ENV || env == common.DEVELOPMENT_ENV {
ip = c.Request().Header.Get("CF-Connecting-IP")
if ip == "" {
ip = c.Request().Header.Get("X-Real-IP")
if ip == "" {
// Fallback to direct IP in test/dev
ip = c.RealIP()
}
}
} else {
ip = c.Request().Header.Get("CF-Connecting-IP")
}

if ip == "" {
return echo.NewHTTPError(http.StatusForbidden, "missing client IP")
}

now := time.Now()

rl.mu.Lock()
Expand Down
12 changes: 6 additions & 6 deletions backend/internal/tests/rate_limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestRateLimiter(t *testing.T) {
t.Run("blocks requests over limit", func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.100")
req.Header.Set("CF-Connecting-IP", "192.168.1.100")

limiter := middleware.NewRateLimiter(&middleware.RateLimiterConfig{
Requests: 1,
Expand Down Expand Up @@ -93,7 +93,7 @@ func TestRateLimiter(t *testing.T) {
ips := []string{"192.168.1.1", "192.168.1.2"}
for _, ip := range ips {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("X-Real-IP", ip)
req.Header.Set("CF-Connecting-IP", ip)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

Expand All @@ -113,13 +113,13 @@ func TestRateLimiter(t *testing.T) {
t.Run("allows requests after window reset", func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.200")
req.Header.Set("CF-Connecting-IP", "192.168.1.200")

limiter := middleware.NewRateLimiter(&middleware.RateLimiterConfig{
Requests: 1,
Window: 100 * time.Millisecond,
BlockPeriod: 50 * time.Millisecond,
MaxBlocks: 2,
BlockPeriod: 50 * time.Millisecond,
MaxBlocks: 2,
})

handler := func(c echo.Context) error {
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestRateLimiter(t *testing.T) {
t.Run("applies progressive blocking", func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.300")
req.Header.Set("CF-Connecting-IP", "192.168.1.300")

limiter := middleware.NewRateLimiter(&middleware.RateLimiterConfig{
Requests: 1,
Expand Down

0 comments on commit 9428216

Please sign in to comment.