From 330dc3e51c3c11400672920673ceb011dca25748 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 27 Jun 2024 17:54:36 +0200 Subject: [PATCH] Get rid of uds.Conn (#43597) --- lib/srv/reexec.go | 4 ++-- lib/srv/reexec_test.go | 8 ++++---- lib/srv/regular/sshserver.go | 8 ++++---- lib/srv/regular/sshserver_unix_test.go | 10 +++++----- lib/sshutils/ctx.go | 3 +-- lib/tbot/service_ssh_multiplexer.go | 3 +-- lib/utils/uds/socketpair_other.go | 6 ++++-- lib/utils/uds/socketpair_unix.go | 6 ++++-- lib/utils/uds/socketpair_unix_test.go | 7 ++++--- lib/utils/uds/uds_other.go | 20 +++++++------------- lib/utils/uds/uds_unix.go | 18 +++++------------- 11 files changed, 41 insertions(+), 52 deletions(-) diff --git a/lib/srv/reexec.go b/lib/srv/reexec.go index 792597dd9f01e..c123ca739a568 100644 --- a/lib/srv/reexec.go +++ b/lib/srv/reexec.go @@ -687,7 +687,7 @@ func handleRemotePortForward(ctx context.Context, addr string, file *os.File) er } else { payload = []byte(err.Error()) } - _, _, err2 := controlConn.WriteWithFDs(payload, files) + _, _, err2 := uds.WriteWithFDs(controlConn, payload, files) return trace.NewAggregate(err, err2) } @@ -764,7 +764,7 @@ func runForward(handler forwardHandler) (errw io.Writer, code int, err error) { for { buf := make([]byte, 1024) fbuf := make([]*os.File, 1) - n, fn, err := conn.ReadWithFDs(buf, fbuf) + n, fn, err := uds.ReadWithFDs(conn, buf, fbuf) if err != nil { if utils.IsOKNetworkError(err) { return errorWriter, teleport.RemoteCommandSuccess, nil diff --git a/lib/srv/reexec_test.go b/lib/srv/reexec_test.go index 9716fbc396a29..071f7efe7156e 100644 --- a/lib/srv/reexec_test.go +++ b/lib/srv/reexec_test.go @@ -183,7 +183,7 @@ func TestStartNewParker(t *testing.T) { } } -func newSocketPair(t *testing.T) (localConn *uds.Conn, remoteFD *os.File) { +func newSocketPair(t *testing.T) (localConn *net.UnixConn, remoteFD *os.File) { localConn, remoteConn, err := uds.NewSocketpair(uds.SocketTypeDatagram) require.NoError(t, err) t.Cleanup(func() { @@ -233,7 +233,7 @@ func TestLocalPortForwardCommand(t *testing.T) { Transport: &http.Transport{ Dial: func(network, addr string) (net.Conn, error) { dialConn, dialFD := newSocketPair(t) - if _, _, err := controlConn.WriteWithFDs([]byte(addr), []*os.File{dialFD}); err != nil { + if _, _, err := uds.WriteWithFDs(controlConn, []byte(addr), []*os.File{dialFD}); err != nil { return nil, trace.Wrap(err) } return dialConn, nil @@ -272,10 +272,10 @@ func testRemotePortForwardCommand(t *testing.T, login string) { // Request a listener from the forwarder. replyConn, replyFD := newSocketPair(t) - _, _, err = controlConn.WriteWithFDs([]byte("127.0.0.1:0"), []*os.File{replyFD}) + _, _, err = uds.WriteWithFDs(controlConn, []byte("127.0.0.1:0"), []*os.File{replyFD}) require.NoError(t, err) var fbuf [1]*os.File - _, fn, err := replyConn.ReadWithFDs(nil, fbuf[:]) + _, fn, err := uds.ReadWithFDs(replyConn, nil, fbuf[:]) require.NoError(t, err) require.Equal(t, 1, fn) listener, err := net.FileListener(fbuf[0]) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index e866de565f265..f20ddcdd04040 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1139,7 +1139,7 @@ func (s *Server) getDirectTCPIPForwardDialer(scx *srv.ServerContext) (sshutils.T } defer remoteFD.Close() - _, _, err = proc.Conn.WriteWithFDs([]byte(addr), []*os.File{remoteFD}) + _, _, err = uds.WriteWithFDs(proc.Conn, []byte(addr), []*os.File{remoteFD}) if err != nil { local.Close() return nil, trace.Wrap(err) @@ -1183,7 +1183,7 @@ func (s *Server) listenTCPIP(scx *srv.ServerContext, addr string) (*net.TCPListe return nil, trace.Wrap(err) } defer remoteFD.Close() - _, _, err = proc.Conn.WriteWithFDs([]byte(addr), []*os.File{remoteFD}) + _, _, err = uds.WriteWithFDs(proc.Conn, []byte(addr), []*os.File{remoteFD}) if err != nil { return nil, trace.Wrap(err) } @@ -1197,7 +1197,7 @@ func (s *Server) listenTCPIP(scx *srv.ServerContext, addr string) (*net.TCPListe defer close(fileCh) fbuf := make([]*os.File, 1) - if _, fn, _ := localConn.ReadWithFDs(nil, fbuf); fn == 0 { + if _, fn, _ := uds.ReadWithFDs(localConn, nil, fbuf); fn == 0 { fileCh <- nil } select { @@ -1218,7 +1218,7 @@ func (s *Server) listenTCPIP(scx *srv.ServerContext, addr string) (*net.TCPListe if listenerFD == nil { return nil, trace.BadParameter("forwarding process did not return a listener") } - if err := validateListenerSocket(scx, localConn.UnixConn, listenerFD); err != nil { + if err := validateListenerSocket(scx, localConn, listenerFD); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/regular/sshserver_unix_test.go b/lib/srv/regular/sshserver_unix_test.go index 6bf92aeaec8bc..9a1a7dae48009 100644 --- a/lib/srv/regular/sshserver_unix_test.go +++ b/lib/srv/regular/sshserver_unix_test.go @@ -36,7 +36,7 @@ import ( func TestValidateListenerSocket(t *testing.T) { t.Parallel() - newSocketFiles := func(t *testing.T) (*uds.Conn, *os.File) { + newSocketFiles := func(t *testing.T) (*net.UnixConn, *os.File) { left, right, err := uds.NewSocketpair(uds.SocketTypeStream) require.NoError(t, err) @@ -64,7 +64,7 @@ func TestValidateListenerSocket(t *testing.T) { tests := []struct { name string - mutateFiles func(*testing.T, *uds.Conn, *os.File) (*uds.Conn, *os.File) + mutateFiles func(*testing.T, *net.UnixConn, *os.File) (*net.UnixConn, *os.File) mutateConn func(*testing.T, *os.File) assert require.ErrorAssertionFunc }{ @@ -74,7 +74,7 @@ func TestValidateListenerSocket(t *testing.T) { }, { name: "socket type not STREAM", - mutateFiles: func(t *testing.T, conn *uds.Conn, file *os.File) (*uds.Conn, *os.File) { + mutateFiles: func(t *testing.T, conn *net.UnixConn, file *os.File) (*net.UnixConn, *os.File) { left, right, err := uds.NewSocketpair(uds.SocketTypeDatagram) require.NoError(t, err) listenerFD, err := right.File() @@ -99,7 +99,7 @@ func TestValidateListenerSocket(t *testing.T) { }, { name: "listener socket is not listening", - mutateFiles: func(t *testing.T, conn *uds.Conn, file *os.File) (*uds.Conn, *os.File) { + mutateFiles: func(t *testing.T, conn *net.UnixConn, file *os.File) (*net.UnixConn, *os.File) { left, right, err := uds.NewSocketpair(uds.SocketTypeStream) require.NoError(t, err) listenerFD, err := right.File() @@ -123,7 +123,7 @@ func TestValidateListenerSocket(t *testing.T) { if tc.mutateConn != nil { tc.mutateConn(t, listenerFD) } - err := validateListenerSocket(&srv.ServerContext{}, conn.UnixConn, listenerFD) + err := validateListenerSocket(&srv.ServerContext{}, conn, listenerFD) tc.assert(t, err) }) } diff --git a/lib/sshutils/ctx.go b/lib/sshutils/ctx.go index 0484d0c2ae757..554044c6ed177 100644 --- a/lib/sshutils/ctx.go +++ b/lib/sshutils/ctx.go @@ -31,7 +31,6 @@ import ( "golang.org/x/crypto/ssh/agent" "github.com/gravitational/teleport/lib/teleagent" - "github.com/gravitational/teleport/lib/utils/uds" ) // TCPIPForwardDialer represents a dialer used to handle TCPIP forward requests. @@ -40,7 +39,7 @@ type TCPIPForwardDialer func(string) (net.Conn, error) // TCPIPForwardProcess represents an instance of a port forwarding process. type TCPIPForwardProcess struct { // Conn is the socket used to request a dialer or listener in the process. - Conn *uds.Conn + Conn *net.UnixConn // Done signals when the process completes. Done <-chan struct{} // Closer contains and extra io.Closer to run when the process as a whole diff --git a/lib/tbot/service_ssh_multiplexer.go b/lib/tbot/service_ssh_multiplexer.go index f8e1e877f6a78..5b8a5a93a1a19 100644 --- a/lib/tbot/service_ssh_multiplexer.go +++ b/lib/tbot/service_ssh_multiplexer.go @@ -558,8 +558,7 @@ func (s *SSHMultiplexerService) handleConn( b := make([]byte, 1) fds := make([]*os.File, 1) - // TODO(espadolini): get rid of [uds.Conn] - n, fdn, err := (&uds.Conn{UnixConn: un}).ReadWithFDs(b, fds) + n, fdn, err := uds.ReadWithFDs(un, b, fds) if err != nil { return trace.Wrap(err, "reading request") } diff --git a/lib/utils/uds/socketpair_other.go b/lib/utils/uds/socketpair_other.go index 2360d0a7d0a53..60edd6671d2f7 100644 --- a/lib/utils/uds/socketpair_other.go +++ b/lib/utils/uds/socketpair_other.go @@ -22,13 +22,15 @@ package uds import ( "errors" + "net" "github.com/gravitational/trace" ) var nonUnixErr = errors.New("socket pair not available on non-unix platform") -// NewSocketpair creates a unix socket pair, returning the halves as files. -func NewSocketpair(t SocketType) (left, right *Conn, err error) { +// NewSocketpair creates a unix socket pair, returning the halves as +// [*net.UnixConn]s. +func NewSocketpair(t SocketType) (left, right *net.UnixConn, err error) { return nil, nil, trace.Wrap(nonUnixErr) } diff --git a/lib/utils/uds/socketpair_unix.go b/lib/utils/uds/socketpair_unix.go index 9c3ba63a66707..dca499a15fe18 100644 --- a/lib/utils/uds/socketpair_unix.go +++ b/lib/utils/uds/socketpair_unix.go @@ -21,13 +21,15 @@ package uds import ( + "net" "os" "github.com/gravitational/trace" ) -// NewSocketpair creates a unix socket pair, returning the halves as files. -func NewSocketpair(t SocketType) (left, right *Conn, err error) { +// NewSocketpair creates a unix socket pair, returning the halves as +// [*net.UnixConn]s. +func NewSocketpair(t SocketType) (left, right *net.UnixConn, err error) { lfd, rfd, err := cloexecSocketpair(t) if err != nil { return nil, nil, trace.Wrap(err) diff --git a/lib/utils/uds/socketpair_unix_test.go b/lib/utils/uds/socketpair_unix_test.go index d17413c883780..48f07eb7368d6 100644 --- a/lib/utils/uds/socketpair_unix_test.go +++ b/lib/utils/uds/socketpair_unix_test.go @@ -23,6 +23,7 @@ package uds import ( "fmt" "io" + "net" "os" "testing" @@ -51,7 +52,7 @@ func TestSocketparFDPassing(t *testing.T) { buf := make([]byte, 1024) fbuf := make([]*os.File, maxFiles*2) for { - n, fn, err := server.ReadWithFDs(buf, fbuf) + n, fn, err := ReadWithFDs(server, buf, fbuf) if err != nil { select { case <-done: @@ -79,7 +80,7 @@ func TestSocketparFDPassing(t *testing.T) { msg := fmt.Sprintf("send-%d", f) // conns are the local halves of socket pairs that we // will use to read our message back from the server. - conns := make([]*Conn, 0, f) + conns := make([]*net.UnixConn, 0, f) // fds are the remote halves of socket pairs to be sent // to the server along with the associated message. @@ -102,7 +103,7 @@ func TestSocketparFDPassing(t *testing.T) { // write message and files together so that server reads them // together and therefore will know what message to send back // over the fds. - _, _, err := client.WriteWithFDs([]byte(msg), fds) + _, _, err := WriteWithFDs(client, []byte(msg), fds) if err != nil { return trace.Errorf("failed to write fds: %v", err) } diff --git a/lib/utils/uds/uds_other.go b/lib/utils/uds/uds_other.go index 328da1e26d2a3..02bb149924c6f 100644 --- a/lib/utils/uds/uds_other.go +++ b/lib/utils/uds/uds_other.go @@ -28,31 +28,25 @@ import ( "github.com/gravitational/trace" ) -var errNonUnix = errors.New("uds.Conn only supported on unix") +// TODO(espadolini): reevaluate UDS support on Windows (passing file descriptors +// is probably not supported, still) +var errNonUnix = errors.New("unix domain sockets are only supported on unix") -// assert that *Conn implements net.Conn. -var _ net.Conn = (*Conn)(nil) - -// Conn extends net.UnixConn with additional useful methods. -type Conn struct { - *net.UnixConn -} - -// FromFile attempts to create a [Conn] from a file. The returned conn +// FromFile attempts to create a [net.UnixConn] from a file. The returned conn // is independent from the file and closing one does not close the other. -func FromFile(fd *os.File) (*Conn, error) { +func FromFile(fd *os.File) (*net.UnixConn, error) { return nil, trace.Wrap(errNonUnix) } // WriteWithFDs performs a write that may also send associated FDs. Note that unless the // underlying socket is a datagram socket it is not guaranteed that the exact bytes written // will match the bytes received with the fds on the reader side. -func (c *Conn) WriteWithFDs(b []byte, fds []*os.File) (n, fdn int, err error) { +func WriteWithFDs(c *net.UnixConn, b []byte, fds []*os.File) (n, fdn int, err error) { return 0, 0, trace.Wrap(errNonUnix) } // ReadWithFDs reads data and associated fds. Note that the underlying socket must be a datagram socket // if you need the bytes read to exactly match the bytes sent with the FDs. -func (c *Conn) ReadWithFDs(b []byte, fds []*os.File) (n, fdn int, err error) { +func ReadWithFDs(c *net.UnixConn, b []byte, fds []*os.File) (n, fdn int, err error) { return 0, 0, trace.Wrap(errNonUnix) } diff --git a/lib/utils/uds/uds_unix.go b/lib/utils/uds/uds_unix.go index e9645dfd4193a..42dad1a6e463b 100644 --- a/lib/utils/uds/uds_unix.go +++ b/lib/utils/uds/uds_unix.go @@ -29,17 +29,9 @@ import ( "github.com/gravitational/trace" ) -// assert that *Conn implements net.Conn. -var _ net.Conn = (*Conn)(nil) - -// Conn extends net.UnixConn with additional useful methods. -type Conn struct { - *net.UnixConn -} - -// FromFile attempts to create a [Conn] from a file. The returned conn +// FromFile attempts to create a [net.UnixConn] from a file. The returned conn // is independent from the file and closing one does not close the other. -func FromFile(fd *os.File) (*Conn, error) { +func FromFile(fd *os.File) (*net.UnixConn, error) { fconn, err := net.FileConn(fd) if err != nil { return nil, trace.Wrap(err) @@ -50,13 +42,13 @@ func FromFile(fd *os.File) (*Conn, error) { return nil, trace.Errorf("unexpected conn type %T (expected %T)", fconn, uconn) } - return &Conn{uconn}, nil + return uconn, nil } // WriteWithFDs performs a write that may also send associated FDs. Note that unless the // underlying socket is a datagram socket it is not guaranteed that the exact bytes written // will match the bytes received with the fds on the reader side. -func (c *Conn) WriteWithFDs(b []byte, fds []*os.File) (n, fdn int, err error) { +func WriteWithFDs(c *net.UnixConn, b []byte, fds []*os.File) (n, fdn int, err error) { fbuf := make([]int, 0, len(fds)) for _, fd := range fds { @@ -84,7 +76,7 @@ const ( // ReadWithFDs reads data and associated fds. Note that the underlying socket must be a datagram socket // if you need the bytes read to exactly match the bytes sent with the FDs. -func (c *Conn) ReadWithFDs(b []byte, fds []*os.File) (n, fdn int, err error) { +func ReadWithFDs(c *net.UnixConn, b []byte, fds []*os.File) (n, fdn int, err error) { // set up a buffer capable of supporting the maximum possible out of band data obuf := make([]byte, syscall.CmsgSpace(sizeOfInt*len(fds)))