diff --git a/.env.sample b/.env.sample index 415f8bb..c1c2355 100644 --- a/.env.sample +++ b/.env.sample @@ -185,8 +185,8 @@ ACTIVATE_FIREWALL=yes # Block one or several IPs [LISTTYPE=blacklist | IP=x.x.x.x] LISTTYPE=whitelist # LISTTYPE=blacklist -# IP - comma-separated list -# IP=192.168.0.1,10.0.0.1 +# IP - comma-separated list, IPv4, IPv6, CIDR +# IP=192.168.0.1,10.0.0.1,172.16.0.0/12,2400:cb00::/32 IP=* # diff --git a/example/.env.sample b/example/.env.sample index 415f8bb..c1c2355 100644 --- a/example/.env.sample +++ b/example/.env.sample @@ -185,8 +185,8 @@ ACTIVATE_FIREWALL=yes # Block one or several IPs [LISTTYPE=blacklist | IP=x.x.x.x] LISTTYPE=whitelist # LISTTYPE=blacklist -# IP - comma-separated list -# IP=192.168.0.1,10.0.0.1 +# IP - comma-separated list, IPv4, IPv6, CIDR +# IP=192.168.0.1,10.0.0.1,172.16.0.0/12,2400:cb00::/32 IP=* # diff --git a/lib/middleware/firewall.go b/lib/middleware/firewall.go index c99804d..8ad8d2d 100644 --- a/lib/middleware/firewall.go +++ b/lib/middleware/firewall.go @@ -5,28 +5,77 @@ package middleware // Copyright (c) 2022 pilinux import ( + "fmt" + "net" "net/http" "strings" + "sync" "github.com/gin-gonic/gin" ) +// firewall package-level variables +var ( + parsedOnce sync.Once + ipNets []*net.IPNet + ipListMap map[string]bool + ipCIDR bool +) + // Firewall - whitelist/blacklist IPs func Firewall(listType string, ipList string) gin.HandlerFunc { return func(c *gin.Context) { - // Get the real client IP - clientIP := c.ClientIP() + // parse the IP list only once + parsedOnce.Do(func() { + parseIPList(listType, ipList) + }) + + // get the real client IP + clientNetIP := net.ParseIP(c.ClientIP()) + if clientNetIP == nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, "IP invalid") + return + } + clientIP := clientNetIP.String() if !strings.Contains(ipList, "*") { if listType == "whitelist" { - if !strings.Contains(ipList, clientIP) { + var allowIP bool + if len(ipListMap) > 0 { + if _, ok := ipListMap[clientIP]; ok { + allowIP = true + } + } + if !allowIP && ipCIDR { + for _, ipNet := range ipNets { + if ipNet.Contains(clientNetIP) { + allowIP = true + break + } + } + } + if !allowIP { c.AbortWithStatusJSON(http.StatusUnauthorized, "IP blocked") return } } if listType == "blacklist" { - if strings.Contains(ipList, clientIP) { + var blockIP bool + if len(ipListMap) > 0 { + if _, ok := ipListMap[clientIP]; ok { + blockIP = true + } + } + if !blockIP && ipCIDR { + for _, ipNet := range ipNets { + if ipNet.Contains(clientNetIP) { + blockIP = true + break + } + } + } + if blockIP { c.AbortWithStatusJSON(http.StatusUnauthorized, "IP blocked") return } @@ -43,3 +92,77 @@ func Firewall(listType string, ipList string) gin.HandlerFunc { c.Next() } } + +// helper function to parse the IP list and CIDR notations +func parseIPList(listType, ipList string) { + ipListMap = make(map[string]bool) + + // split the list by comma and trim spaces + ipListSlice := strings.Split(ipList, ",") + for _, ip := range ipListSlice { + ip = strings.TrimSpace(ip) + if ip == "" { + continue + } + if strings.Contains(ip, "/") { + // parse CIDR notations + _, ipNet, err := net.ParseCIDR(ip) + if err == nil { + ipNets = append(ipNets, ipNet) + } + } else { + ipListMap[ip] = true + } + } + + // if any CIDR notations were found, set ipCIDR to true + if len(ipNets) > 0 { + ipCIDR = true + } + + var validIPs string + var validCIDRs string + for ip := range ipListMap { + validIPs += ip + ", " + } + for _, ipNet := range ipNets { + validCIDRs += ipNet.String() + ", " + } + // remove the trailing comma and space + validIPs = strings.TrimSuffix(validIPs, ", ") + validCIDRs = strings.TrimSuffix(validCIDRs, ", ") + + fmt.Println("application firewall initialized") + if listType == "whitelist" { + if strings.Contains(validIPs, "*") { + fmt.Println("whitelisted IPs: *") + } else { + if len(validIPs) > 0 { + fmt.Println("whitelisted IPs:", validIPs) + } + if len(validCIDRs) > 0 { + fmt.Println("whitelisted CIDRs:", validCIDRs) + } + } + } + if listType == "blacklist" { + if strings.Contains(validIPs, "*") { + fmt.Println("blacklisted IPs: *") + } else { + if len(validIPs) > 0 { + fmt.Println("blacklisted IPs:", validIPs) + } + if len(validCIDRs) > 0 { + fmt.Println("blacklisted CIDRs:", validCIDRs) + } + } + } +} + +// ResetFirewallState - helper function to reset firewall package-level variables +func ResetFirewallState() { + parsedOnce = sync.Once{} + ipNets = nil + ipListMap = nil + ipCIDR = false +} diff --git a/lib/middleware/firewall_test.go b/lib/middleware/firewall_test.go index be20585..66a2e19 100644 --- a/lib/middleware/firewall_test.go +++ b/lib/middleware/firewall_test.go @@ -10,6 +10,7 @@ import ( ) type testCase struct { + testNo string listType string ipList string remoteIP string @@ -18,42 +19,279 @@ type testCase struct { func TestFirewall(t *testing.T) { testCases := []testCase{ - {"whitelist", "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", "192.168.0.1", http.StatusOK}, - {"whitelist", "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", "192.168.0.5", http.StatusUnauthorized}, - {"blacklist", "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", "192.168.0.1", http.StatusUnauthorized}, - {"blacklist", "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", "192.168.0.5", http.StatusOK}, + // list of IPs + { + "1.1", + "whitelist", + "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", + "192.168.0.1", + http.StatusOK, + }, + { + "1.2", + "whitelist", + "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", + "192.168.0.5", + http.StatusUnauthorized, + }, + { + "1.3", + "blacklist", + "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", + "192.168.0.1", + http.StatusUnauthorized, + }, + { + "1.4", + "blacklist", + "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", + "192.168.0.5", + http.StatusOK, + }, - {"whitelist", "*", "192.168.1.1", http.StatusOK}, - {"blacklist", "*", "192.168.1.1", http.StatusUnauthorized}, + // missing client IP + { + "2.1", + "whitelist", + "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", + "", + http.StatusUnauthorized, + }, + { + "2.2", + "blacklist", + "192.168.0.1, 192.168.0.2, 192.168.0.3, 192.168.0.4", + "", + http.StatusUnauthorized, + }, + + // wildcard + { + "3.1", + "whitelist", + "*", + "192.168.1.1", + http.StatusOK, + }, + { + "3.2", + "blacklist", + "*", + "192.168.1.1", + http.StatusUnauthorized, + }, + + // CIDR + { + "4.1", + "whitelist", + "192.168.0.0/16", + "192.168.1.1", + http.StatusOK, + }, + { + "4.2", + "whitelist", + "192.168.0.0/16", + "192.169.1.1", + http.StatusUnauthorized, + }, + { + "4.3", + "whitelist", + "192.168.10.0/24", + "192.168.10.255", + http.StatusOK, + }, + { + "4.4", + "whitelist", + "192.168.10.0/24", + "192.168.11.1", + http.StatusUnauthorized, + }, + { + "4.5", + "whitelist", + "192.168.10.10/32", + "192.168.10.10", + http.StatusOK, + }, + { + "4.6", + "whitelist", + "192.168.10.10/32", + "192.168.10.11", + http.StatusUnauthorized, + }, + { + "4.7", + "whitelist", + "172.16.0.0/12", + "172.22.0.1", + http.StatusOK, + }, + { + "4.8", + "whitelist", + "172.16.0.0/12", + "172.32.0.1", + http.StatusUnauthorized, + }, + { + "4.9", + "blacklist", + "192.168.0.0/16", + "192.168.1.1", + http.StatusUnauthorized, + }, + { + "4.10", + "blacklist", + "192.168.0.0/16", + "192.169.1.1", + http.StatusOK, + }, + + // CIDR and IPs + { + "5.1", + "whitelist", + "192.168.0.0/16, 192.169.0.1, 192.169.0.2, 192.169.0.3, 192.169.0.4, 192.170.10.0/24", + "192.168.0.1", + http.StatusOK, + }, + { + "5.2", + "whitelist", + "192.168.0.0/16, 192.169.0.1, 192.169.0.2, 192.169.0.3, 192.169.0.4, 192.170.10.0/24", + "192.169.0.1", + http.StatusOK, + }, + { + "5.3", + "whitelist", + "192.168.0.0/16, 192.169.0.1, 192.169.0.2, 192.169.0.3, 192.169.0.4, 192.170.10.0/24", + "192.170.10.240", + http.StatusOK, + }, + { + "5.4", + "blacklist", + "192.168.0.0/16, 192.169.0.1, 192.169.0.2, 192.169.0.3, 192.169.0.4, 192.170.10.0/24", + "192.170.0.1", + http.StatusOK, + }, + { + "5.5", + "blacklist", + "192.168.0.0/16, 192.169.0.1, 192.169.0.2, 192.169.0.3, 192.169.0.4, 192.170.10.0/24", + "192.168.10.1", + http.StatusUnauthorized, + }, + + // *, CIDR and IPs + { + "6.1", + "whitelist", + "*, 192.168.0.0/16, 192.169.0.1, 192.169.0.2, 192.169.0.3, 192.169.0.4, 192.170.10.0/24", + "192.171.0.1", + http.StatusOK, + }, + { + "6.2", + "blacklist", + "*, 192.168.0.0/16, 192.169.0.1, 192.169.0.2, 192.169.0.3, 192.169.0.4, 192.170.10.0/24", + "192.171.0.1", + http.StatusUnauthorized, + }, + + // IPv6 + { + "7.1", + "whitelist", + "2001:db8::", + "2001:db8::", + http.StatusOK, + }, + { + "7.2", + "whitelist", + "2001:db8::/32", + "2001:db8::1", + http.StatusOK, + }, + { + "7.3", + "whitelist", + "2001:db8::/128", + "2001:db8::1", + http.StatusUnauthorized, + }, + { + "7.4", + "blacklist", + "2001:db8::/32", + "2001:db8::1", + http.StatusUnauthorized, + }, + { + "7.5", + "blacklist", + "2001:db8::/128", + "2001:db8::ffff", + http.StatusOK, + }, + + // IPv4 and IPv6 + { + "8.1", + "whitelist", + "2001:db8::/32,, 192.168.10.10/32,", + "2001:db8::1", + http.StatusOK, + }, + { + "8.2", + "blacklist", + "2001:db8::/32,, 192.168.10.10/32,", + "192.168.10.10", + http.StatusUnauthorized, + }, } for _, tc := range testCases { - // set up a gin router and handler - gin.SetMode(gin.TestMode) - router := gin.New() - err := router.SetTrustedProxies(nil) - if err != nil { - t.Errorf("failed to set trusted proxies to nil") - } - router.TrustedPlatform = "X-Real-Ip" - router.Use(middleware.Firewall(tc.listType, tc.ipList)) - router.GET("/", func(c *gin.Context) { - c.Status(http.StatusOK) - }) + t.Run("TestCase"+tc.testNo, func(t *testing.T) { + // reset firewall state between test cases + middleware.ResetFirewallState() + + // set up a gin router and handler + gin.SetMode(gin.TestMode) + router := gin.New() + err := router.SetTrustedProxies(nil) + if err != nil { + t.Errorf("failed to set trusted proxies to nil") + } + router.TrustedPlatform = "X-Real-Ip" + router.Use(middleware.Firewall(tc.listType, tc.ipList)) + router.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) + }) - // create a request and response recorder - w := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Errorf("failed to create an HTTP request") - return - } - req.Header.Set("X-Real-Ip", tc.remoteIP) - - // pass the request to the router and check the response - router.ServeHTTP(w, req) - if w.Code != tc.statusExp { - t.Errorf("expected status code %d, got %d", tc.statusExp, w.Code) - } + // create a request and response recorder + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("failed to create an HTTP request") + return + } + req.Header.Set("X-Real-Ip", tc.remoteIP) + + // pass the request to the router and check the response + router.ServeHTTP(w, req) + if w.Code != tc.statusExp { + t.Errorf("testCase no %s, expected status code %d, got %d", tc.testNo, tc.statusExp, w.Code) + } + }) } }