diff --git a/socket_linux.go b/socket_linux.go index b881fe49..e812ee42 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -54,10 +54,8 @@ func (r *socketRequest) Serialize() []byte { copy(b.Next(16), r.ID.Source) copy(b.Next(16), r.ID.Destination) } else { - copy(b.Next(4), r.ID.Source.To4()) - b.Next(12) - copy(b.Next(4), r.ID.Destination.To4()) - b.Next(12) + copy(b.Next(16), r.ID.Source.To4()) + copy(b.Next(16), r.ID.Destination.To4()) } native.PutUint32(b.Next(4), r.ID.Interface) native.PutUint32(b.Next(4), r.ID.Cookie[0]) @@ -117,20 +115,44 @@ func (s *Socket) deserialize(b []byte) error { // SocketGet returns the Socket identified by its local and remote addresses. func SocketGet(local, remote net.Addr) (*Socket, error) { - localTCP, ok := local.(*net.TCPAddr) - if !ok { + var protocol uint8 + var localIP, remoteIP net.IP + var localPort, remotePort uint16 + switch l := local.(type) { + case *net.TCPAddr: + r, ok := remote.(*net.TCPAddr) + if !ok { + return nil, ErrNotImplemented + } + localIP = l.IP + localPort = uint16(l.Port) + remoteIP = r.IP + remotePort = uint16(r.Port) + protocol = unix.IPPROTO_TCP + case *net.UDPAddr: + r, ok := remote.(*net.UDPAddr) + if !ok { + return nil, ErrNotImplemented + } + localIP = l.IP + localPort = uint16(l.Port) + remoteIP = r.IP + remotePort = uint16(r.Port) + protocol = unix.IPPROTO_UDP + default: return nil, ErrNotImplemented } - remoteTCP, ok := remote.(*net.TCPAddr) - if !ok { - return nil, ErrNotImplemented + + var family uint8 + if localIP.To4() != nil && remoteIP.To4() != nil { + family = unix.AF_INET } - localIP := localTCP.IP.To4() - if localIP == nil { - return nil, ErrNotImplemented + + if family == 0 && localIP.To16() != nil && remoteIP.To16() != nil { + family = unix.AF_INET6 } - remoteIP := remoteTCP.IP.To4() - if remoteIP == nil { + + if family == 0 { return nil, ErrNotImplemented } @@ -139,19 +161,24 @@ func SocketGet(local, remote net.Addr) (*Socket, error) { return nil, err } defer s.Close() - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0) + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&socketRequest{ - Family: unix.AF_INET, - Protocol: unix.IPPROTO_TCP, + Family: family, + Protocol: protocol, + States: 0xffffffff, ID: SocketID{ - SourcePort: uint16(localTCP.Port), - DestinationPort: uint16(remoteTCP.Port), + SourcePort: localPort, + DestinationPort: remotePort, Source: localIP, Destination: remoteIP, Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE}, }, }) - s.Send(req) + + if err := s.Send(req); err != nil { + return nil, err + } + msgs, from, err := s.Receive() if err != nil { return nil, err diff --git a/socket_test.go b/socket_test.go index f21520f1..eaebcaaa 100644 --- a/socket_test.go +++ b/socket_test.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package netlink @@ -14,47 +15,90 @@ import ( func TestSocketGet(t *testing.T) { defer setUpNetlinkTestWithLoopback(t)() - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - log.Fatal(err) + type Addr struct { + IP net.IP + Port int } - l, err := net.ListenTCP("tcp", addr) - if err != nil { - log.Fatal(err) - } - defer l.Close() - conn, err := net.Dial(l.Addr().Network(), l.Addr().String()) - if err != nil { - t.Fatal(err) + getAddr := func(a net.Addr) Addr { + var addr Addr + switch v := a.(type) { + case *net.UDPAddr: + addr.IP = v.IP + addr.Port = v.Port + case *net.TCPAddr: + addr.IP = v.IP + addr.Port = v.Port + } + return addr } - defer conn.Close() - localAddr := conn.LocalAddr().(*net.TCPAddr) - remoteAddr := conn.RemoteAddr().(*net.TCPAddr) - socket, err := SocketGet(localAddr, remoteAddr) - if err != nil { - t.Fatal(err) - } + checkSocket := func(t *testing.T, local, remote net.Addr) { + socket, err := SocketGet(local, remote) + if err != nil { + t.Fatal(err) + } - if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) { - t.Fatalf("local ip = %v, want %v", got, want) - } - if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) { - t.Fatalf("remote ip = %v, want %v", got, want) - } - if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want { - t.Fatalf("local port = %d, want %d", got, want) - } - if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want { - t.Fatalf("remote port = %d, want %d", got, want) + localAddr, remoteAddr := getAddr(local), getAddr(remote) + + if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) { + t.Fatalf("local ip = %v, want %v", got, want) + } + if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) { + t.Fatalf("remote ip = %v, want %v", got, want) + } + if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want { + t.Fatalf("local port = %d, want %d", got, want) + } + if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want { + t.Fatalf("remote port = %d, want %d", got, want) + } + u, err := user.Current() + if err != nil { + t.Fatal(err) + } + if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want { + t.Fatalf("UID = %s, want %s", got, want) + } } - u, err := user.Current() - if err != nil { - t.Fatal(err) + + for _, v := range [...]string{"tcp4", "tcp6"} { + addr, err := net.ResolveTCPAddr(v, "localhost:0") + if err != nil { + log.Fatal(err) + } + l, err := net.ListenTCP(v, addr) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + conn, err := net.Dial(l.Addr().Network(), l.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + checkSocket(t, conn.LocalAddr(), conn.RemoteAddr()) } - if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want { - t.Fatalf("UID = %s, want %s", got, want) + + for _, v := range [...]string{"udp4", "udp6"} { + addr, err := net.ResolveUDPAddr(v, "localhost:0") + if err != nil { + log.Fatal(err) + } + l, err := net.ListenUDP(v, addr) + if err != nil { + log.Fatal(err) + } + defer l.Close() + conn, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + checkSocket(t, conn.LocalAddr(), conn.RemoteAddr()) } }