From f32569cb05e844424910b85fa6241d7053d5584e Mon Sep 17 00:00:00 2001 From: Saurlax Date: Wed, 4 Dec 2024 13:40:07 +0800 Subject: [PATCH] chore: optmize packet capturing --- .gitignore | 1 + main.go | 10 +-- tic/netvigil.go | 7 +- tic/threatbook.go | 8 +- tic/tic.go | 28 ++++-- util/action.go | 13 +-- util/config.go | 4 +- util/db.go | 5 +- util/netstat.go | 217 ++++++++++++++++++++-------------------------- util/web.go | 3 +- 10 files changed, 138 insertions(+), 158 deletions(-) diff --git a/.gitignore b/.gitignore index 7b56a77..84f3505 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ go.work *.db +*.db-journal *.mmdb dist /config.toml diff --git a/main.go b/main.go index b02b30f..e08b434 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "log" "os" "os/signal" "syscall" @@ -15,18 +16,11 @@ func main() { sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) <-sig + log.Println("Shutting down...") util.DB.Close() } func init() { - if viper.GetDuration("capture_interval") > 0 { - go func() { - for { - util.Capture() - time.Sleep(viper.GetDuration("capture_interval")) - } - }() - } if viper.GetDuration("check_interval") > 0 { go func() { for { diff --git a/tic/netvigil.go b/tic/netvigil.go index 7c19038..4d50375 100644 --- a/tic/netvigil.go +++ b/tic/netvigil.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "log" "net/http" "github.com/saurlax/netvigil/util" @@ -28,13 +29,13 @@ func (t *Netvigil) Check(ips []string) []*util.Threat { IPs: ips, }) if err != nil { - fmt.Println("[Netvigil] Failed to marshal request:", err) + log.Println("[Netvigil] Failed to marshal request:", err) return threats } resp, err := http.Post(fmt.Sprintf("%s/api/check", t.Server), "application/json", bytes.NewBuffer(requestBody)) if err != nil { - fmt.Println("[Netvigil] Failed to request:", err) + log.Println("[Netvigil] Failed to request:", err) return threats } defer resp.Body.Close() @@ -42,7 +43,7 @@ func (t *Netvigil) Check(ips []string) []*util.Threat { var res NetvigilResponse err = json.NewDecoder(resp.Body).Decode(&res) if err != nil { - fmt.Println("[Netvigil] Failed to decode response:", err) + log.Println("[Netvigil] Failed to decode response:", err) return threats } diff --git a/tic/threatbook.go b/tic/threatbook.go index 3ee96a5..f27c1a1 100644 --- a/tic/threatbook.go +++ b/tic/threatbook.go @@ -2,7 +2,7 @@ package tic import ( "encoding/json" - "fmt" + "log" "net" "net/http" "net/url" @@ -41,7 +41,7 @@ func (t *Threatbook) Check(ips []string) []*util.Threat { "resource": resource, }) if err != nil { - fmt.Println("[Threatbook] Failed to request:", err) + log.Println("[Threatbook] Failed to request:", err) return threats } defer resp.Body.Close() @@ -49,11 +49,11 @@ func (t *Threatbook) Check(ips []string) []*util.Threat { var res ThreatbookResponse err = json.NewDecoder(resp.Body).Decode(&res) if err != nil { - fmt.Println("[Threatbook] Failed to decode response:", err) + log.Println("[Threatbook] Failed to decode response:", err) return threats } if res.ResponseCode != 0 { - fmt.Printf("[Threatbook] Abnormal response (%v): %v\n", res.ResponseCode, res.VerBoseMsg) + log.Printf("[Threatbook] Abnormal response (%v): %v\n", res.ResponseCode, res.VerBoseMsg) } for ip, data := range res.Data { diff --git a/tic/tic.go b/tic/tic.go index 566c752..d979068 100644 --- a/tic/tic.go +++ b/tic/tic.go @@ -1,9 +1,10 @@ package tic import ( - "fmt" + "log" "net" + "github.com/google/gopacket/layers" "github.com/saurlax/netvigil/util" "github.com/spf13/viper" ) @@ -15,8 +16,8 @@ type TIC interface { var tics = make([]TIC, 0) -// Create a TIC instance with config -func Create(m map[string]any) TIC { +// create a TIC instance with config +func create(m map[string]any) TIC { switch m["type"] { case "local": blacklist := make([]net.IP, 0) @@ -41,7 +42,7 @@ func Create(m map[string]any) TIC { } // Check all IPs with all TICs created -func CheckIPs(ips []string) []*util.Threat { +func CheckAll(ips []string) []*util.Threat { threats, _ := util.GetThreatsByIPs(ips) ips2check := make([]string, 0) Loop: @@ -71,13 +72,22 @@ func Check() { Loop: for { select { - case ip := <-util.IPs: - ips = append(ips, ip) + case packet := <-util.Packets: + ipv4Layer := packet.Layer(layers.LayerTypeIPv4) + if ipv4Layer != nil { + ip := ipv4Layer.(*layers.IPv4) + ips = append(ips, ip.DstIP.String()) + } + ipv6Layer := packet.Layer(layers.LayerTypeIPv6) + if ipv6Layer != nil { + ip := ipv6Layer.(*layers.IPv6) + ips = append(ips, ip.DstIP.String()) + } default: break Loop } } - CheckIPs(ips) + CheckAll(ips) } func init() { @@ -87,9 +97,9 @@ func init() { if !ok { break } - tic := Create(m) + tic := create(m) if tic != nil { - fmt.Printf("[TIC] %s created\n", m["type"]) + log.Printf("[TIC] %s created\n", m["type"]) tics = append(tics, tic) } } diff --git a/util/action.go b/util/action.go index 6510ca2..9c18d78 100644 --- a/util/action.go +++ b/util/action.go @@ -2,6 +2,7 @@ package util import ( "fmt" + "log" "os" "os/exec" @@ -35,15 +36,15 @@ func DelFireWall(ip string) { } func suspiciousAction(n Netstat) { - AddFireWall(n.RemoteIP) - fmt.Printf("\x1B[33mSuspicious threat detected: %s → %s\x1B[0m\n", n.Executable, n.RemoteIP) - beeep.Notify("Suspicious threat detected!", fmt.Sprintf("%s → %s", n.Executable, n.RemoteIP), "") + AddFireWall(n.DstIP) + log.Printf("\x1B[33mSuspicious threat detected: %s → %s\x1B[0m\n", n.Executable, n.DstIP) + beeep.Notify("Suspicious threat detected!", fmt.Sprintf("%s → %s", n.Executable, n.DstIP), "") } func maliciousAction(n Netstat) { - AddFireWall(n.RemoteIP) - fmt.Printf("\x1B[31mMalicious threat detected: %s → %s\x1B[0m\n", n.Executable, n.RemoteIP) - beeep.Notify("Malicious threat detected!", fmt.Sprintf("%s → %s", n.Executable, n.RemoteIP), "") + AddFireWall(n.DstIP) + log.Printf("\x1B[31mMalicious threat detected: %s → %s\x1B[0m\n", n.Executable, n.DstIP) + beeep.Notify("Malicious threat detected!", fmt.Sprintf("%s → %s", n.Executable, n.DstIP), "") } func (t Threat) Action(n Netstat) { diff --git a/util/config.go b/util/config.go index 2e966b9..33918ab 100644 --- a/util/config.go +++ b/util/config.go @@ -1,6 +1,8 @@ package util import ( + "log" + "github.com/spf13/viper" ) @@ -12,6 +14,6 @@ func init() { viper.SetConfigFile("config.toml") err := viper.ReadInConfig() if err != nil { - panic(err) + log.Panicln("Failed to read config:", err) } } diff --git a/util/db.go b/util/db.go index 9e05a5f..73c4984 100644 --- a/util/db.go +++ b/util/db.go @@ -2,6 +2,7 @@ package util import ( "database/sql" + "log" "github.com/IncSW/geoip2" _ "github.com/mattn/go-sqlite3" @@ -16,10 +17,10 @@ var ( func init() { DB, err = sql.Open("sqlite3", "file:netvigil.db") if err != nil { - panic(err) + log.Panicln("Failed to open database:", err) } GeoLiteCity, err = geoip2.NewCityReaderFromFile("GeoLite2-City.mmdb") if err != nil { - panic(err) + log.Panicln("Failed to open GeoLite2-City.mmdb:", err) } } diff --git a/util/netstat.go b/util/netstat.go index f261045..ac1c07c 100644 --- a/util/netstat.go +++ b/util/netstat.go @@ -1,11 +1,8 @@ package util import ( - "fmt" "log" "net" - "sync" - "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -16,152 +13,124 @@ import ( type Netstat struct { ID int64 `json:"id"` Time int64 `json:"time"` - LocalIP string `json:"localIP"` - LocalPort uint16 `json:"localPort"` - RemoteIP string `json:"remoteIP"` - RemotePort uint16 `json:"remotePort"` + SrcIP string `json:"srcIP"` + SrcPort uint16 `json:"srcPort"` + DstIP string `json:"dstIP"` + DstPort uint16 `json:"dstPort"` Executable string `json:"executable"` Location string `json:"location"` } -var IPs chan string - -// netstat obtained from the system at different times will be duplicated, using a cache to deduplicate -// -// key: LocalIP:LcoalPort-RemoteIP:RemotePort, value: time -var cache = make(map[string]time.Time) +var ( + Packets chan gopacket.Packet +) // Capture network traffic information -func Capture() { - - for e, t := range cache { - if time.Since(t) > 60*time.Second { - delete(cache, e) +func capture(ps *gopacket.PacketSource) { + for packet := range ps.Packets() { + n := Netstat{ + Time: packet.Metadata().Timestamp.Unix(), } - } - - devices, err := pcap.FindAllDevs() - if err != nil { - log.Fatal(err) - } - - var wg sync.WaitGroup - - // Open pcap device - for _, device := range devices { - wg.Add(1) // Increment the WaitGroup counter for each device - // Launch a goroutine to capture packets for each device - go func(device pcap.Interface) { - defer wg.Done() // Decrement the WaitGroup counter when the goroutine finishes - - handle, err := pcap.OpenLive(device.Name, 1600, true, -1) - if err != nil { - log.Printf("Error opening device %s: %v\n", device.Name, err) - return + ipv4Layer := packet.Layer(layers.LayerTypeIPv4) + if ipv4Layer != nil { + if ipv4Layer.(*layers.IPv4).DstIP.IsLoopback() { + continue } - defer handle.Close() - - // Set a timeout for each device capture (2 seconds) - timer := time.NewTimer(2 * time.Second) - defer timer.Stop() - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - log.Printf("Starting capture on device: %s\n", device.Name) - - // Capture packets until the timeout is reached - for { - select { - case packet := <-packetSource.Packets(): - ipLayer := packet.Layer(layers.LayerTypeIPv4) - if ipLayer != nil { - ip := ipLayer.(*layers.IPv4) - - // Skip loopback addresses (127.0.0.1) - if ip.SrcIP.IsLoopback() || ip.DstIP.IsLoopback() { - continue - } - - key := fmt.Sprintf("%s-%s", ip.SrcIP.String(), ip.DstIP.String()) - - // Continue if the entry is already in the cache - if _, ok := cache[key]; ok { - continue - } - - log.Printf("Processing connection: %s\n", key) - cache[key] = time.Now() - - n := Netstat{ - Time: time.Now().UnixMilli(), - LocalIP: ip.SrcIP.String(), - LocalPort: 0, - RemoteIP: ip.DstIP.String(), - RemotePort: 0, - } - - // Get the location of the IP address - ipAddr := net.ParseIP(n.RemoteIP) - record, err := GeoLiteCity.Lookup(ipAddr) - if err == nil { - countryName := record.Country.Names["zh-CN"] - if countryName == "" { - countryName = record.Country.Names["en"] - } - cityName := record.City.Names["zh-CN"] - if cityName == "" { - cityName = record.City.Names["en"] - } - n.Location = countryName - if cityName != "" { - if n.Location != "" { - n.Location += " " - } - n.Location += cityName - } - } + n.SrcIP = ipv4Layer.(*layers.IPv4).SrcIP.String() + n.DstIP = ipv4Layer.(*layers.IPv4).DstIP.String() + } else { + ipv6Layer := packet.Layer(layers.LayerTypeIPv6) + if ipv6Layer != nil { + if ipv6Layer.(*layers.IPv6).DstIP.IsLoopback() { + continue + } + n.SrcIP = ipv6Layer.(*layers.IPv6).SrcIP.String() + n.DstIP = ipv6Layer.(*layers.IPv6).DstIP.String() + } else { + continue + } + } - n.Save() + tcpLayer := packet.Layer(layers.LayerTypeTCP) + if tcpLayer != nil { + n.SrcPort = uint16(tcpLayer.(*layers.TCP).SrcPort) + n.DstPort = uint16(tcpLayer.(*layers.TCP).DstPort) + } else { + udpLayer := packet.Layer(layers.LayerTypeUDP) + if udpLayer != nil { + n.SrcPort = uint16(udpLayer.(*layers.UDP).SrcPort) + n.DstPort = uint16(udpLayer.(*layers.UDP).DstPort) + } else { + continue + } + } - // Send the remote IP address to the channel - select { - case IPs <- n.RemoteIP: - default: - // If the channel is full, exit the goroutine - return - } - } - // Timeout, stop capturing and move to the next device - case <-timer.C: - log.Printf("Capture timeout reached for device %s\n", device.Name) - return + // Get the location of the IP address + ipAddr := net.ParseIP(n.DstIP) + record, _ := GeoLiteCity.Lookup(ipAddr) + if record != nil { + countryName := record.Country.Names["zh-CN"] + if countryName == "" { + countryName = record.Country.Names["en"] + } + cityName := record.City.Names["zh-CN"] + if cityName == "" { + cityName = record.City.Names["en"] + } + n.Location = countryName + if cityName != "" { + if n.Location != "" { + n.Location += " " } + n.Location += cityName } - }(device) // Pass the device to the goroutine - } + } - // Wait for all goroutines to finish before exiting the function - wg.Wait() + n.Save() + Packets <- packet + } } func init() { - IPs = make(chan string, viper.GetInt("buffer_size")) - DB.Exec("CREATE TABLE IF NOT EXISTS netstats (time INTEGER, local_ip TEXT, local_port INTEGER, remote_ip TEXT, remote_port INTEGER, executable TEXT, location TEXT)") + Packets = make(chan gopacket.Packet, viper.GetInt("buffer_size")) + DB.Exec("CREATE TABLE IF NOT EXISTS netstats (time INTEGER, src_ip TEXT, src_port INTEGER, dst_ip TEXT, dst_port INTEGER, executable TEXT, location TEXT)") DB.Exec("CREATE INDEX IF NOT EXISTS idx_time ON netstats (time)") - DB.Exec("CREATE INDEX IF NOT EXISTS idx_local_ip ON netstats (local_ip)") - DB.Exec("CREATE INDEX IF NOT EXISTS idx_local_port ON netstats (local_port)") - DB.Exec("CREATE INDEX IF NOT EXISTS idx_remote_ip ON netstats (remote_ip)") - DB.Exec("CREATE INDEX IF NOT EXISTS idx_remote_port ON netstats (remote_port)") + DB.Exec("CREATE INDEX IF NOT EXISTS idx_src_ip ON netstats (src_ip)") + DB.Exec("CREATE INDEX IF NOT EXISTS idx_src_port ON netstats (src_port)") + DB.Exec("CREATE INDEX IF NOT EXISTS idx_dst_ip ON netstats (dst_ip)") + DB.Exec("CREATE INDEX IF NOT EXISTS idx_dst_port ON netstats (dst_port)") + + devices, err := pcap.FindAllDevs() + if err != nil { + log.Fatal(err) + } + + for _, dev := range devices { + for _, addr := range dev.Addresses { + if !addr.IP.IsLoopback() && !addr.IP.IsUnspecified() { + handle, err := pcap.OpenLive(dev.Name, 1600, true, -1) + if err != nil { + log.Fatalf("Error opening device %s: %v\n", dev.Name, err) + } else { + log.Printf("Capturing on device: %s\n", dev.Name) + } + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + go capture(packetSource) + break + } + } + } } func (n *Netstat) Save() error { - _, err := DB.Exec("INSERT INTO netstats (time, local_ip, local_port, remote_ip, remote_port, executable, location) VALUES (?, ?, ?, ?, ?, ?, ?)", n.Time, n.LocalIP, n.LocalPort, n.RemoteIP, n.RemotePort, n.Executable, n.Location) + _, err := DB.Exec("INSERT INTO netstats (time, src_ip, src_port, dst_ip, dst_port, executable, location) VALUES (?, ?, ?, ?, ?, ?, ?)", n.Time, n.SrcIP, n.SrcPort, n.DstIP, n.DstPort, n.Executable, n.Location) return err } func GetNetstats(limit int, page int) ([]*Netstat, error) { offset := limit * (page - 1) - rows, err := DB.Query("SELECT ROWID, time, local_ip, local_port, remote_ip, remote_port, executable, location FROM netstats ORDER BY time DESC LIMIT ? OFFSET ?", limit, offset) + rows, err := DB.Query("SELECT ROWID, time, src_ip, src_port, dst_ip, dst_port, executable, location FROM netstats ORDER BY time DESC LIMIT ? OFFSET ?", limit, offset) if err != nil { return nil, err } @@ -170,7 +139,7 @@ func GetNetstats(limit int, page int) ([]*Netstat, error) { var netstats []*Netstat for rows.Next() { var n Netstat - err := rows.Scan(&n.ID, &n.Time, &n.LocalIP, &n.LocalPort, &n.RemoteIP, &n.RemotePort, &n.Executable, &n.Location) + err := rows.Scan(&n.ID, &n.Time, &n.SrcIP, &n.SrcPort, &n.DstIP, &n.DstPort, &n.Executable, &n.Location) if err != nil { return nil, err } diff --git a/util/web.go b/util/web.go index 12cbd93..d0d6e86 100644 --- a/util/web.go +++ b/util/web.go @@ -3,6 +3,7 @@ package util import ( "crypto/rand" "fmt" + "log" "os" "strconv" "strings" @@ -183,6 +184,6 @@ func init() { r.POST("/api/check", checkHandler) r.NoRoute(staticHandler) - fmt.Printf("Web server started on http://%s/\n", addr) + log.Printf("Web server started on http://%s/\n", addr) go r.Run(addr) }