Skip to content

Commit

Permalink
feat: add MAC rule
Browse files Browse the repository at this point in the history
  • Loading branch information
xishang0128 committed Aug 11, 2024
1 parent ae98c23 commit 17c6585
Show file tree
Hide file tree
Showing 11 changed files with 442 additions and 0 deletions.
91 changes: 91 additions & 0 deletions component/arp/arp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package arp

import (
"net"
"net/netip"
"sync"
"time"

"github.com/metacubex/mihomo/component/iface"
"github.com/metacubex/mihomo/log"
)

var (
table map[string]string
failedIPs map[string]int
mu sync.RWMutex
failedIPsMutex sync.RWMutex
lastFetch time.Time
)

type ARPEntry struct {
IP net.IP
MAC net.HardwareAddr
}

const refreshInterval = 5 * time.Minute

func init() {
table = make(map[string]string)
failedIPs = make(map[string]int)
}

func IsReserved(ip net.IP) bool {
if ip4 := ip.To4(); ip4 != nil {
return ip4[3] == 0 || ip4[3] == 255
}
return false
}

func refreshARPTable() {
newTable, err := GetARPTable()
if err != nil {
log.Warnln("Failed to refresh ARP table")
return
}

mu.Lock()
defer mu.Unlock()

table = newTable
lastFetch = time.Now()
}

func IPToMac(ip netip.Addr) string {
if ok, _ := iface.IsLocalIp(ip); ok {
return ""
}

mu.RLock()
if time.Since(lastFetch) > refreshInterval {
mu.RUnlock()
refreshARPTable()
mu.RLock()
}
defer mu.RUnlock()

if mac, ok := table[ip.String()]; ok {
return mac
}

if ip.IsPrivate() {
failedIPsMutex.RLock()
failCount, exists := failedIPs[ip.String()]
failedIPsMutex.RUnlock()
if exists && failCount >= 10 {
return ""
}

mu.RUnlock()
refreshARPTable()
mu.RLock()

if mac, ok := table[ip.String()]; ok {
return mac
} else {
failedIPs[ip.String()]++
}
}

return ""
}
47 changes: 47 additions & 0 deletions component/arp/arp_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package arp

import (
"fmt"
"net"

"github.com/sagernet/netlink"
)

func neighMAC(n netlink.Neigh) string {
return n.HardwareAddr.String()
}

func neighIP(n netlink.Neigh) net.IP {
return n.IP
}

func GetARPTable() (map[string]string, error) {
entries := make(map[string]string)

links, err := netlink.LinkList()
if err != nil {
return nil, err
}

for _, link := range links {
attr := link.Attrs()
neighs, err := netlink.NeighList(attr.Index, 0)
if err != nil {
fmt.Println(err)
continue
}
for _, neigh := range neighs {
ip := neighIP(neigh)
mac := neighMAC(neigh)

if IsReserved(ip) {
continue
}

if ip.IsGlobalUnicast() {
entries[ip.String()] = mac
}
}
}
return entries, nil
}
7 changes: 7 additions & 0 deletions component/arp/arp_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build !linux && !windows

package arp

func GetARPTable() (map[string]string, error) {
return nil, nil
}
22 changes: 22 additions & 0 deletions component/arp/arp_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package arp

func GetARPTable() (map[string]string, error) {
table, err := GetIpNetTable2()
if err != nil {
return nil, err
}

entries := make(map[string]string)
for _, row := range table {
entry := row.ToARPEntry()

if IsReserved(entry.IP) {
continue
}

if entry.IP.IsGlobalUnicast() {
entries[entry.IP.String()] = entry.MAC.String()
}
}
return entries, nil
}
52 changes: 52 additions & 0 deletions component/arp/get_ip_net_table2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//go:build windows
// +build windows

package arp

import (
"fmt"
"syscall"
"unsafe"

"golang.org/x/sys/windows"
)

var iphlpapi *windows.DLL

func init() {
iphlpapi = windows.MustLoadDLL("Iphlpapi.dll")
}

func GetIpNetTable2() (MIBIpNetTable2, error) {
proc, err := iphlpapi.FindProc("GetIpNetTable2")
if err != nil {
return nil, err
}

free, err := iphlpapi.FindProc("FreeMibTable")
if err != nil {
return nil, err
}

var data *rawMIBIpNetTable2
errno, _, _ := proc.Call(0, uintptr(unsafe.Pointer(&data)))
defer free.Call(uintptr(unsafe.Pointer(data)))

switch syscall.Errno(errno) {
case windows.ERROR_SUCCESS:
err = nil
case windows.ERROR_NOT_ENOUGH_MEMORY:
err = fmt.Errorf("insufficient memory resources are available to complete the operation")
case windows.ERROR_INVALID_PARAMETER:
err = fmt.Errorf("an invalid parameter was passed to the function")
case windows.ERROR_NOT_FOUND:
err = fmt.Errorf("no neighbor IP address entries as specified in the Family parameter were found")
case windows.ERROR_NOT_SUPPORTED:
err = fmt.Errorf("the IPv4 or IPv6 transports are not configured on the local computer")
default:
err = windows.GetLastError()
}

table := data.parse()
return table, err
}
143 changes: 143 additions & 0 deletions component/arp/mib_ipnet_row2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
//go:build windows
// +build windows

package arp

import (
"encoding/binary"
"net"
"time"
)

const MIBIpNetRow2Size = 88
const SockAddrSize = 28

type SockAddrIn struct {
sinFamily uint16
sinPort uint16
sinAddr net.IP
sinZero []byte
}

func NewSockAddrIn(buffer []byte) SockAddrIn {
addr := SockAddrIn{
sinFamily: binary.LittleEndian.Uint16(buffer[:2]),
sinPort: binary.LittleEndian.Uint16(buffer[2:4]),
sinAddr: net.IP(make([]byte, 4)).To4(),
sinZero: make([]byte, 8),
}
copy(addr.sinAddr, buffer[4:8])
copy(addr.sinZero, buffer[8:16])
return addr
}

func (s SockAddrIn) Family() uint16 {
return s.sinFamily
}

func (s SockAddrIn) Addr() net.IP {
return s.sinAddr.To4()
}

type SockAddrIn6 struct {
sin6Family uint16
sin6Port uint16
sin6FlowInfo uint32
sin6Addr net.IP
sin6ScopeId uint32
}

func NewSockAddrIn6(buffer []byte) SockAddrIn6 {
addr := SockAddrIn6{
sin6Family: binary.LittleEndian.Uint16(buffer[:2]),
sin6Port: binary.LittleEndian.Uint16(buffer[2:4]),
sin6FlowInfo: binary.LittleEndian.Uint32(buffer[4:8]),
sin6Addr: net.IP(make([]byte, 16)).To16(),
sin6ScopeId: binary.LittleEndian.Uint32(buffer[24:28]),
}
copy(addr.sin6Addr, buffer[8:24])
return addr
}

func (s SockAddrIn6) Family() uint16 {
return s.sin6Family
}

func (s SockAddrIn6) Addr() net.IP {
return s.sin6Addr.To16()
}

type SockAddr interface {
Family() uint16
Addr() net.IP
}

func parseSockAddr(buffer []byte) SockAddr {
sockType := binary.LittleEndian.Uint16(buffer[:2])
switch sockType {
case 2: // IPv4
return NewSockAddrIn(buffer[:SockAddrSize])
case 23: // IPv6
return NewSockAddrIn6(buffer[:SockAddrSize])
default:
return nil
}
}

func parsePhysicalAddress(buffer []byte, physicalAddressLength uint32) net.HardwareAddr {
pa := make(net.HardwareAddr, physicalAddressLength)
copy(pa, buffer[:physicalAddressLength])
return pa
}

type MIBIpNetRow2 struct {
address SockAddr
interfaceIndex uint32
interfaceLuid uint64
physicalAddress net.HardwareAddr
physicalAddressLength uint32
flags uint32
reachabilityTime time.Duration
}

func (r MIBIpNetRow2) MAC() net.HardwareAddr {
mac := make(net.HardwareAddr, r.physicalAddressLength)
copy(mac, r.physicalAddress)
return mac
}

func (r MIBIpNetRow2) IP() net.IP {
length := len(r.address.Addr())
ip := make(net.IP, length)
copy(ip, r.address.Addr())
return ip
}

func (r MIBIpNetRow2) ToARPEntry() ARPEntry {
return ARPEntry{
MAC: r.MAC(),
IP: r.IP(),
}
}

type rawMIBIpNetRow2 struct {
address [28]byte
interfaceIndex uint32
interfaceLuid uint64
physicalAddress [32]byte
physicalAddressLength uint32
flags uint32
reachabilityTime uint32
}

func (r rawMIBIpNetRow2) Parse() MIBIpNetRow2 {
return MIBIpNetRow2{
address: parseSockAddr(r.address[:]),
interfaceIndex: r.interfaceIndex,
interfaceLuid: r.interfaceLuid,
physicalAddress: parsePhysicalAddress(r.physicalAddress[:], r.physicalAddressLength),
physicalAddressLength: r.physicalAddressLength,
flags: r.flags,
reachabilityTime: time.Duration(r.reachabilityTime * uint32(time.Millisecond)),
}
}
22 changes: 22 additions & 0 deletions component/arp/mib_ipnet_table2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//go:build windows
// +build windows

package arp

const anySize = 1 << 16

type MIBIpNetTable2 []MIBIpNetRow2

type rawMIBIpNetTable2 struct {
numEntries uint32
padding uint32
table [anySize]rawMIBIpNetRow2
}

func (r *rawMIBIpNetTable2) parse() MIBIpNetTable2 {
t := make([]MIBIpNetRow2, r.numEntries)
for i := 0; i < int(r.numEntries); i++ {
t[i] = r.table[i].Parse()
}
return t
}
Loading

0 comments on commit 17c6585

Please sign in to comment.