From e22cd926ecf54ce0864a4482b1d8eae55b7cbcd0 Mon Sep 17 00:00:00 2001 From: Asutorufa <16442314+Asutorufa@users.noreply.github.com> Date: Tue, 19 Sep 2023 21:05:05 +0800 Subject: [PATCH] SocketGet support udp and ipv6 Signed-off-by: Asutorufa <16442314+Asutorufa@users.noreply.github.com> --- socket_linux.go | 67 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/socket_linux.go b/socket_linux.go index b881fe49..e812ee42 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -54,10 +54,8 @@ func (r *socketRequest) Serialize() []byte { copy(b.Next(16), r.ID.Source) copy(b.Next(16), r.ID.Destination) } else { - copy(b.Next(4), r.ID.Source.To4()) - b.Next(12) - copy(b.Next(4), r.ID.Destination.To4()) - b.Next(12) + copy(b.Next(16), r.ID.Source.To4()) + copy(b.Next(16), r.ID.Destination.To4()) } native.PutUint32(b.Next(4), r.ID.Interface) native.PutUint32(b.Next(4), r.ID.Cookie[0]) @@ -117,20 +115,44 @@ func (s *Socket) deserialize(b []byte) error { // SocketGet returns the Socket identified by its local and remote addresses. func SocketGet(local, remote net.Addr) (*Socket, error) { - localTCP, ok := local.(*net.TCPAddr) - if !ok { + var protocol uint8 + var localIP, remoteIP net.IP + var localPort, remotePort uint16 + switch l := local.(type) { + case *net.TCPAddr: + r, ok := remote.(*net.TCPAddr) + if !ok { + return nil, ErrNotImplemented + } + localIP = l.IP + localPort = uint16(l.Port) + remoteIP = r.IP + remotePort = uint16(r.Port) + protocol = unix.IPPROTO_TCP + case *net.UDPAddr: + r, ok := remote.(*net.UDPAddr) + if !ok { + return nil, ErrNotImplemented + } + localIP = l.IP + localPort = uint16(l.Port) + remoteIP = r.IP + remotePort = uint16(r.Port) + protocol = unix.IPPROTO_UDP + default: return nil, ErrNotImplemented } - remoteTCP, ok := remote.(*net.TCPAddr) - if !ok { - return nil, ErrNotImplemented + + var family uint8 + if localIP.To4() != nil && remoteIP.To4() != nil { + family = unix.AF_INET } - localIP := localTCP.IP.To4() - if localIP == nil { - return nil, ErrNotImplemented + + if family == 0 && localIP.To16() != nil && remoteIP.To16() != nil { + family = unix.AF_INET6 } - remoteIP := remoteTCP.IP.To4() - if remoteIP == nil { + + if family == 0 { return nil, ErrNotImplemented } @@ -139,19 +161,24 @@ func SocketGet(local, remote net.Addr) (*Socket, error) { return nil, err } defer s.Close() - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0) + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&socketRequest{ - Family: unix.AF_INET, - Protocol: unix.IPPROTO_TCP, + Family: family, + Protocol: protocol, + States: 0xffffffff, ID: SocketID{ - SourcePort: uint16(localTCP.Port), - DestinationPort: uint16(remoteTCP.Port), + SourcePort: localPort, + DestinationPort: remotePort, Source: localIP, Destination: remoteIP, Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE}, }, }) - s.Send(req) + + if err := s.Send(req); err != nil { + return nil, err + } + msgs, from, err := s.Receive() if err != nil { return nil, err