Skip to content

Commit

Permalink
Merge pull request #52 from PSNAppz/port_filter
Browse files Browse the repository at this point in the history
Port Filtering Plugin
  • Loading branch information
KingAkeem authored Nov 30, 2024
2 parents c8bae58 + e66010a commit 9a89697
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 3 deletions.
2 changes: 1 addition & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Variables
DB_USER="gorm"
DB_PASSWORD="gorm"
DB_NAME="gorm"
DB_NAME="shadowguard"

# Check if PostgreSQL is installed; if not, install it
if ! command -v psql &> /dev/null; then
Expand Down
6 changes: 4 additions & 2 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"port": "5432",
"user": "gorm",
"password": "gorm",
"dbname": "gorm"
"dbname": "shadowguard"
},
"host": "http://localhost",
"port": ":8081",
Expand Down Expand Up @@ -49,7 +49,9 @@
"US",
"IN"
],
"region-blacklist": []
"region-blacklist": [],
"port-blacklist": [],
"port-whitelist": []
}
}
],
Expand Down
1 change: 1 addition & 0 deletions plugins/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ package plugins
import (
_ "shadowguard/plugins/ipfilter"
_ "shadowguard/plugins/monitor"
_ "shadowguard/plugins/portfilter"
_ "shadowguard/plugins/ratelimiter"
)
116 changes: 116 additions & 0 deletions plugins/portfilter/portfilter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package portfilter

import (
"errors"
"log"
"net/http"
"shadowguard/pkg/database"
"shadowguard/pkg/plugin"
"shadowguard/pkg/publisher"
"strconv"
"strings"
)

var Type string = "portfilter"

func init() {
plugin.RegisterPlugin(Type, NewPortFilterPlugin)
}

type PortFilterPlugin struct {
db database.DB
Settings map[string]interface{}
activeMode bool
portBlacklist []interface{}
portWhitelist []interface{}
publishers []publisher.Publisher
}

// NewPortFilterPlugin initializes the PortFilterPlugin.
func NewPortFilterPlugin(settings map[string]interface{}, db database.DB) plugin.Plugin {
publishers, err := publisher.CreatePublishers(settings)
if err != nil {
panic(err)
}

portBlacklist, _ := settings["port-blacklist"].([]interface{})
portWhitelist, _ := settings["port-whitelist"].([]interface{})

return &PortFilterPlugin{
db: db,
Settings: settings,
activeMode: settings["active_mode"].(bool),
portBlacklist: portBlacklist,
portWhitelist: portWhitelist,
publishers: publishers,
}
}

func (p *PortFilterPlugin) Type() string {
return Type
}

func (p *PortFilterPlugin) IsActiveMode() bool {
return p.activeMode
}

func (p *PortFilterPlugin) Notify(message string) {
for _, pub := range p.publishers {
err := pub.Publish(message)
if err != nil {
log.Printf("unable to notify publisher. message %s - error: %v", message, err)
}
}
}

func (p *PortFilterPlugin) Handle(r *http.Request) error {
// Extract port from the request's host field
hostPort := r.Host
port := 80 // Default HTTP port

if parts := strings.Split(hostPort, ":"); len(parts) == 2 {
var err error
port, err = strconv.Atoi(parts[1])
if err != nil {
return errors.New("invalid port in request")
}
}

// Check port against blacklist
for _, blacklistedPort := range p.portBlacklist {
if int(blacklistedPort.(int)) == port {
req, err := database.NewRequest(r, "portblacklist")
if err != nil {
print("ERROR")
println(err)
return err
}
p.db.Insert(req)
p.Notify("Port is blacklisted: " + strconv.Itoa(port))
return errors.New("port is blacklisted")
}
}

// Check port against whitelist if defined
if len(p.portWhitelist) > 0 {
isWhitelisted := false
for _, whitelistedPort := range p.portWhitelist {
if int(whitelistedPort.(int)) == port {
req, err := database.NewRequest(r, "portwhitelist")

if err != nil {
return err
}
p.db.Insert(req)
isWhitelisted = true
break
}
}
if !isWhitelisted {
p.Notify("Port is not whitelisted: " + strconv.Itoa(port))
return errors.New("port is not whitelisted")
}
}

return nil
}
74 changes: 74 additions & 0 deletions plugins/portfilter/portfilter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package portfilter

import (
"net/http"
"net/http/httptest"
"shadowguard/pkg/database"
"testing"
)

func TestPortFilterPlugin(t *testing.T) {
// Test 1: Port Blacklist
settings := map[string]interface{}{
"port-blacklist": []interface{}{22, 3306},
"port-whitelist": []interface{}{},
"active_mode": true,
}
plugin := NewPortFilterPlugin(settings, database.NewMock()).(*PortFilterPlugin)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Host = "localhost:22"

err := plugin.Handle(req)
if err == nil || err.Error() != "port is blacklisted" {
t.Errorf("PortFilterPlugin did not block blacklisted port. Error: %v", err)
}

// Test 2: Port Whitelist
settings = map[string]interface{}{
"port-blacklist": []interface{}{},
"port-whitelist": []interface{}{80, 443},
"active_mode": true,
}
plugin = NewPortFilterPlugin(settings, database.NewMock()).(*PortFilterPlugin)

req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Host = "localhost:80"

err = plugin.Handle(req)
if err != nil {
t.Errorf("PortFilterPlugin blocked whitelisted port. Error: %v", err)
}

// Test 3: Default Allow Behavior (No Whitelist or Blacklist)
settings = map[string]interface{}{
"port-blacklist": []interface{}{},
"port-whitelist": []interface{}{},
"active_mode": true,
}
plugin = NewPortFilterPlugin(settings, database.NewMock()).(*PortFilterPlugin)

req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Host = "localhost:8080"

err = plugin.Handle(req)
if err != nil {
t.Errorf("PortFilterPlugin blocked port when no restrictions were configured. Error: %v", err)
}

// Test 4: Whitelist Restriction (Non-whitelisted port)
settings = map[string]interface{}{
"port-blacklist": []interface{}{},
"port-whitelist": []interface{}{80, 443},
"active_mode": true,
}
plugin = NewPortFilterPlugin(settings, database.NewMock()).(*PortFilterPlugin)

req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Host = "localhost:8080"

err = plugin.Handle(req)
if err == nil || err.Error() != "port is not whitelisted" {
t.Errorf("PortFilterPlugin did not block non-whitelisted port. Error: %v", err)
}
}

0 comments on commit 9a89697

Please sign in to comment.