diff --git a/component/arp/arp.go b/component/arp/arp.go new file mode 100644 index 0000000000..b70d8ac40e --- /dev/null +++ b/component/arp/arp.go @@ -0,0 +1,91 @@ +package arp + +import ( + "net" + "net/netip" + "sync" + "time" + + "github.com/metacubex/mihomo/component/iface" + "github.com/metacubex/mihomo/log" +) + +var ( + table map[string]string + failedIPs map[string]int + mu sync.RWMutex + failedIPsMutex sync.RWMutex + lastFetch time.Time +) + +type ARPEntry struct { + IP net.IP + MAC net.HardwareAddr +} + +const refreshInterval = 5 * time.Minute + +func init() { + table = make(map[string]string) + failedIPs = make(map[string]int) +} + +func IsReserved(ip net.IP) bool { + if ip4 := ip.To4(); ip4 != nil { + return ip4[3] == 0 || ip4[3] == 255 + } + return false +} + +func refreshARPTable() { + newTable, err := GetARPTable() + if err != nil { + log.Warnln("Failed to refresh ARP table") + return + } + + mu.Lock() + defer mu.Unlock() + + table = newTable + lastFetch = time.Now() +} + +func IPToMac(ip netip.Addr) string { + if ok, _ := iface.IsLocalIp(ip); ok { + return "" + } + + mu.RLock() + if time.Since(lastFetch) > refreshInterval { + mu.RUnlock() + refreshARPTable() + mu.RLock() + } + defer mu.RUnlock() + + if mac, ok := table[ip.String()]; ok { + return mac + } + + if ip.IsPrivate() { + failedIPsMutex.RLock() + failCount, exists := failedIPs[ip.String()] + failedIPsMutex.RUnlock() + if exists && failCount >= 10 { + return "" + } + + mu.RUnlock() + refreshARPTable() + mu.RLock() + + if mac, ok := table[ip.String()]; ok { + return mac + } else { + failedIPs[ip.String()]++ + } + } + + return "" +} diff --git a/component/arp/arp_linux.go b/component/arp/arp_linux.go new file mode 100644 index 0000000000..55bd74d30c --- /dev/null +++ b/component/arp/arp_linux.go @@ -0,0 +1,47 @@ +package arp + +import ( + "fmt" + "net" + + "github.com/sagernet/netlink" +) + +func neighMAC(n netlink.Neigh) string { + return n.HardwareAddr.String() +} + +func neighIP(n netlink.Neigh) net.IP { + return n.IP +} + +func GetARPTable() (map[string]string, error) { + entries := make(map[string]string) + + links, err := netlink.LinkList() + if err != nil { + return nil, err + } + + for _, link := range links { + attr := link.Attrs() + neighs, err := netlink.NeighList(attr.Index, 0) + if err != nil { + fmt.Println(err) + continue + } + for _, neigh := range neighs { + ip := neighIP(neigh) + mac := neighMAC(neigh) + + if IsReserved(ip) { + continue + } + + if ip.IsGlobalUnicast() { + entries[ip.String()] = mac + } + } + } + return entries, nil +} diff --git a/component/arp/arp_other.go b/component/arp/arp_other.go new file mode 100644 index 0000000000..35ecc353ef --- /dev/null +++ b/component/arp/arp_other.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package arp + +func GetARPTable() (map[string]string, error) { + return nil, nil +} diff --git a/component/arp/arp_windows.go b/component/arp/arp_windows.go new file mode 100644 index 0000000000..4095ee1c78 --- /dev/null +++ b/component/arp/arp_windows.go @@ -0,0 +1,22 @@ +package arp + +func GetARPTable() (map[string]string, error) { + table, err := GetIpNetTable2() + if err != nil { + return nil, err + } + + entries := make(map[string]string) + for _, row := range table { + entry := row.ToARPEntry() + + if IsReserved(entry.IP) { + continue + } + + if entry.IP.IsGlobalUnicast() { + entries[entry.IP.String()] = entry.MAC.String() + } + } + return entries, nil +} diff --git a/component/arp/get_ip_net_table2.go b/component/arp/get_ip_net_table2.go new file mode 100644 index 0000000000..c9f9ec7f74 --- /dev/null +++ b/component/arp/get_ip_net_table2.go @@ -0,0 +1,52 @@ +//go:build windows +// +build windows + +package arp + +import ( + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var iphlpapi *windows.DLL + +func init() { + iphlpapi = windows.MustLoadDLL("Iphlpapi.dll") +} + +func GetIpNetTable2() (MIBIpNetTable2, error) { + proc, err := iphlpapi.FindProc("GetIpNetTable2") + if err != nil { + return nil, err + } + + free, err := iphlpapi.FindProc("FreeMibTable") + if err != nil { + return nil, err + } + + var data *rawMIBIpNetTable2 + errno, _, _ := proc.Call(0, uintptr(unsafe.Pointer(&data))) + defer free.Call(uintptr(unsafe.Pointer(data))) + + switch syscall.Errno(errno) { + case windows.ERROR_SUCCESS: + err = nil + case windows.ERROR_NOT_ENOUGH_MEMORY: + err = fmt.Errorf("insufficient memory resources are available to complete the operation") + case windows.ERROR_INVALID_PARAMETER: + err = fmt.Errorf("an invalid parameter was passed to the function") + case windows.ERROR_NOT_FOUND: + err = fmt.Errorf("no neighbor IP address entries as specified in the Family parameter were found") + case windows.ERROR_NOT_SUPPORTED: + err = fmt.Errorf("the IPv4 or IPv6 transports are not configured on the local computer") + default: + err = windows.GetLastError() + } + + table := data.parse() + return table, err +} diff --git a/component/arp/mib_ipnet_row2.go b/component/arp/mib_ipnet_row2.go new file mode 100644 index 0000000000..657395a1db --- /dev/null +++ b/component/arp/mib_ipnet_row2.go @@ -0,0 +1,143 @@ +//go:build windows +// +build windows + +package arp + +import ( + "encoding/binary" + "net" + "time" +) + +const MIBIpNetRow2Size = 88 +const SockAddrSize = 28 + +type SockAddrIn struct { + sinFamily uint16 + sinPort uint16 + sinAddr net.IP + sinZero []byte +} + +func NewSockAddrIn(buffer []byte) SockAddrIn { + addr := SockAddrIn{ + sinFamily: binary.LittleEndian.Uint16(buffer[:2]), + sinPort: binary.LittleEndian.Uint16(buffer[2:4]), + sinAddr: net.IP(make([]byte, 4)).To4(), + sinZero: make([]byte, 8), + } + copy(addr.sinAddr, buffer[4:8]) + copy(addr.sinZero, buffer[8:16]) + return addr +} + +func (s SockAddrIn) Family() uint16 { + return s.sinFamily +} + +func (s SockAddrIn) Addr() net.IP { + return s.sinAddr.To4() +} + +type SockAddrIn6 struct { + sin6Family uint16 + sin6Port uint16 + sin6FlowInfo uint32 + sin6Addr net.IP + sin6ScopeId uint32 +} + +func NewSockAddrIn6(buffer []byte) SockAddrIn6 { + addr := SockAddrIn6{ + sin6Family: binary.LittleEndian.Uint16(buffer[:2]), + sin6Port: binary.LittleEndian.Uint16(buffer[2:4]), + sin6FlowInfo: binary.LittleEndian.Uint32(buffer[4:8]), + sin6Addr: net.IP(make([]byte, 16)).To16(), + sin6ScopeId: binary.LittleEndian.Uint32(buffer[24:28]), + } + copy(addr.sin6Addr, buffer[8:24]) + return addr +} + +func (s SockAddrIn6) Family() uint16 { + return s.sin6Family +} + +func (s SockAddrIn6) Addr() net.IP { + return s.sin6Addr.To16() +} + +type SockAddr interface { + Family() uint16 + Addr() net.IP +} + +func parseSockAddr(buffer []byte) SockAddr { + sockType := binary.LittleEndian.Uint16(buffer[:2]) + switch sockType { + case 2: // IPv4 + return NewSockAddrIn(buffer[:SockAddrSize]) + case 23: // IPv6 + return NewSockAddrIn6(buffer[:SockAddrSize]) + default: + return nil + } +} + +func parsePhysicalAddress(buffer []byte, physicalAddressLength uint32) net.HardwareAddr { + pa := make(net.HardwareAddr, physicalAddressLength) + copy(pa, buffer[:physicalAddressLength]) + return pa +} + +type MIBIpNetRow2 struct { + address SockAddr + interfaceIndex uint32 + interfaceLuid uint64 + physicalAddress net.HardwareAddr + physicalAddressLength uint32 + flags uint32 + reachabilityTime time.Duration +} + +func (r MIBIpNetRow2) MAC() net.HardwareAddr { + mac := make(net.HardwareAddr, r.physicalAddressLength) + copy(mac, r.physicalAddress) + return mac +} + +func (r MIBIpNetRow2) IP() net.IP { + length := len(r.address.Addr()) + ip := make(net.IP, length) + copy(ip, r.address.Addr()) + return ip +} + +func (r MIBIpNetRow2) ToARPEntry() ARPEntry { + return ARPEntry{ + MAC: r.MAC(), + IP: r.IP(), + } +} + +type rawMIBIpNetRow2 struct { + address [28]byte + interfaceIndex uint32 + interfaceLuid uint64 + physicalAddress [32]byte + physicalAddressLength uint32 + flags uint32 + reachabilityTime uint32 +} + +func (r rawMIBIpNetRow2) Parse() MIBIpNetRow2 { + return MIBIpNetRow2{ + address: parseSockAddr(r.address[:]), + interfaceIndex: r.interfaceIndex, + interfaceLuid: r.interfaceLuid, + physicalAddress: parsePhysicalAddress(r.physicalAddress[:], r.physicalAddressLength), + physicalAddressLength: r.physicalAddressLength, + flags: r.flags, + reachabilityTime: time.Duration(r.reachabilityTime * uint32(time.Millisecond)), + } +} diff --git a/component/arp/mib_ipnet_table2.go b/component/arp/mib_ipnet_table2.go new file mode 100644 index 0000000000..6a3823a5cb --- /dev/null +++ b/component/arp/mib_ipnet_table2.go @@ -0,0 +1,22 @@ +//go:build windows +// +build windows + +package arp + +const anySize = 1 << 16 + +type MIBIpNetTable2 []MIBIpNetRow2 + +type rawMIBIpNetTable2 struct { + numEntries uint32 + padding uint32 + table [anySize]rawMIBIpNetRow2 +} + +func (r *rawMIBIpNetTable2) parse() MIBIpNetTable2 { + t := make([]MIBIpNetRow2, r.numEntries) + for i := 0; i < int(r.numEntries); i++ { + t[i] = r.table[i].Parse() + } + return t +} diff --git a/constant/rule.go b/constant/rule.go index a91ee6cb07..e221bd0060 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -19,6 +19,7 @@ const ( DstPort InPort DSCP + Mac InUser InName InType @@ -98,6 +99,8 @@ func (rt RuleType) String() string { return "Uid" case SubRules: return "SubRules" + case Mac: + return "Mac" case AND: return "AND" case OR: diff --git a/rules/common/mac.go b/rules/common/mac.go new file mode 100644 index 0000000000..9699f4d299 --- /dev/null +++ b/rules/common/mac.go @@ -0,0 +1,52 @@ +package common + +import ( + "fmt" + "regexp" + "runtime" + "strings" + + "github.com/metacubex/mihomo/component/arp" + C "github.com/metacubex/mihomo/constant" +) + +type Mac struct { + *Base + mac string + adapter string +} + +func (m *Mac) RuleType() C.RuleType { + return C.Mac +} + +func (m *Mac) Match(metadata *C.Metadata) (bool, string) { + if runtime.GOOS == "windows" || runtime.GOOS == "linux" { + if arp.IPToMac(metadata.SrcIP) == m.mac { + return true, m.adapter + } + } + return false, m.adapter +} + +func (m *Mac) Adapter() string { + return m.adapter +} + +func (m *Mac) Payload() string { + return m.mac +} + +func NewMAC(mac string, adapter string) (*Mac, error) { + mac = regexp.MustCompile(`[^a-fA-F0-9]`).ReplaceAllString(mac, "") + if len(mac) != 12 { + return nil, fmt.Errorf("invalid MAC address length") + } + formattedMAC := fmt.Sprintf("%s:%s:%s:%s:%s:%s", + mac[0:2], mac[2:4], mac[4:6], mac[6:8], mac[8:10], mac[10:12]) + return &Mac{ + Base: &Base{}, + mac: strings.ToLower(formattedMAC), + adapter: adapter, + }, nil +} diff --git a/rules/parser.go b/rules/parser.go index 9b1f552007..178d1d10ab 100644 --- a/rules/parser.go +++ b/rules/parser.go @@ -49,6 +49,8 @@ func ParseRule(tp, payload, target string, params []string, subRules map[string] parsed, parseErr = RC.NewPort(payload, target, C.InPort) case "DSCP": parsed, parseErr = RC.NewDSCP(payload, target) + case "MAC", "SRC-MAC": + parsed, parseErr = RC.NewMAC(payload, target) case "PROCESS-NAME": parsed, parseErr = RC.NewProcess(payload, target, true, false) case "PROCESS-PATH": diff --git a/tunnel/dns_dialer.go b/tunnel/dns_dialer.go index 1839869b4a..ba293f3887 100644 --- a/tunnel/dns_dialer.go +++ b/tunnel/dns_dialer.go @@ -37,6 +37,7 @@ func (d *DNSDialer) DialContext(ctx context.Context, network, addr string) (net. metadata := &C.Metadata{ NetWork: C.TCP, Type: C.INNER, + Process: C.Name, } err := metadata.SetRemoteAddress(addr) // tcp can resolve host by remote if err != nil {