Skip to content

Commit

Permalink
Get rid of uds.Conn (#43596)
Browse files Browse the repository at this point in the history
  • Loading branch information
espadolini authored Jun 27, 2024
1 parent 9fd2a7e commit 3de08d2
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 52 deletions.
4 changes: 2 additions & 2 deletions lib/srv/reexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ func handleRemotePortForward(ctx context.Context, addr string, file *os.File) er
} else {
payload = []byte(err.Error())
}
_, _, err2 := controlConn.WriteWithFDs(payload, files)
_, _, err2 := uds.WriteWithFDs(controlConn, payload, files)
return trace.NewAggregate(err, err2)
}

Expand Down Expand Up @@ -764,7 +764,7 @@ func runForward(handler forwardHandler) (errw io.Writer, code int, err error) {
for {
buf := make([]byte, 1024)
fbuf := make([]*os.File, 1)
n, fn, err := conn.ReadWithFDs(buf, fbuf)
n, fn, err := uds.ReadWithFDs(conn, buf, fbuf)
if err != nil {
if utils.IsOKNetworkError(err) {
return errorWriter, teleport.RemoteCommandSuccess, nil
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/reexec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestStartNewParker(t *testing.T) {
}
}

func newSocketPair(t *testing.T) (localConn *uds.Conn, remoteFD *os.File) {
func newSocketPair(t *testing.T) (localConn *net.UnixConn, remoteFD *os.File) {
localConn, remoteConn, err := uds.NewSocketpair(uds.SocketTypeDatagram)
require.NoError(t, err)
t.Cleanup(func() {
Expand Down Expand Up @@ -233,7 +233,7 @@ func TestLocalPortForwardCommand(t *testing.T) {
Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
dialConn, dialFD := newSocketPair(t)
if _, _, err := controlConn.WriteWithFDs([]byte(addr), []*os.File{dialFD}); err != nil {
if _, _, err := uds.WriteWithFDs(controlConn, []byte(addr), []*os.File{dialFD}); err != nil {
return nil, trace.Wrap(err)
}
return dialConn, nil
Expand Down Expand Up @@ -272,10 +272,10 @@ func testRemotePortForwardCommand(t *testing.T, login string) {

// Request a listener from the forwarder.
replyConn, replyFD := newSocketPair(t)
_, _, err = controlConn.WriteWithFDs([]byte("127.0.0.1:0"), []*os.File{replyFD})
_, _, err = uds.WriteWithFDs(controlConn, []byte("127.0.0.1:0"), []*os.File{replyFD})
require.NoError(t, err)
var fbuf [1]*os.File
_, fn, err := replyConn.ReadWithFDs(nil, fbuf[:])
_, fn, err := uds.ReadWithFDs(replyConn, nil, fbuf[:])
require.NoError(t, err)
require.Equal(t, 1, fn)
listener, err := net.FileListener(fbuf[0])
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ func (s *Server) getDirectTCPIPForwardDialer(scx *srv.ServerContext) (sshutils.T
}
defer remoteFD.Close()

_, _, err = proc.Conn.WriteWithFDs([]byte(addr), []*os.File{remoteFD})
_, _, err = uds.WriteWithFDs(proc.Conn, []byte(addr), []*os.File{remoteFD})
if err != nil {
local.Close()
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -1152,7 +1152,7 @@ func (s *Server) listenTCPIP(scx *srv.ServerContext, addr string) (*net.TCPListe
return nil, trace.Wrap(err)
}
defer remoteFD.Close()
_, _, err = proc.Conn.WriteWithFDs([]byte(addr), []*os.File{remoteFD})
_, _, err = uds.WriteWithFDs(proc.Conn, []byte(addr), []*os.File{remoteFD})
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -1166,7 +1166,7 @@ func (s *Server) listenTCPIP(scx *srv.ServerContext, addr string) (*net.TCPListe
defer close(fileCh)

fbuf := make([]*os.File, 1)
if _, fn, _ := localConn.ReadWithFDs(nil, fbuf); fn == 0 {
if _, fn, _ := uds.ReadWithFDs(localConn, nil, fbuf); fn == 0 {
fileCh <- nil
}
select {
Expand All @@ -1187,7 +1187,7 @@ func (s *Server) listenTCPIP(scx *srv.ServerContext, addr string) (*net.TCPListe
if listenerFD == nil {
return nil, trace.BadParameter("forwarding process did not return a listener")
}
if err := validateListenerSocket(scx, localConn.UnixConn, listenerFD); err != nil {
if err := validateListenerSocket(scx, localConn, listenerFD); err != nil {
return nil, trace.Wrap(err)
}

Expand Down
10 changes: 5 additions & 5 deletions lib/srv/regular/sshserver_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (
func TestValidateListenerSocket(t *testing.T) {
t.Parallel()

newSocketFiles := func(t *testing.T) (*uds.Conn, *os.File) {
newSocketFiles := func(t *testing.T) (*net.UnixConn, *os.File) {
left, right, err := uds.NewSocketpair(uds.SocketTypeStream)
require.NoError(t, err)

Expand Down Expand Up @@ -64,7 +64,7 @@ func TestValidateListenerSocket(t *testing.T) {

tests := []struct {
name string
mutateFiles func(*testing.T, *uds.Conn, *os.File) (*uds.Conn, *os.File)
mutateFiles func(*testing.T, *net.UnixConn, *os.File) (*net.UnixConn, *os.File)
mutateConn func(*testing.T, *os.File)
assert require.ErrorAssertionFunc
}{
Expand All @@ -74,7 +74,7 @@ func TestValidateListenerSocket(t *testing.T) {
},
{
name: "socket type not STREAM",
mutateFiles: func(t *testing.T, conn *uds.Conn, file *os.File) (*uds.Conn, *os.File) {
mutateFiles: func(t *testing.T, conn *net.UnixConn, file *os.File) (*net.UnixConn, *os.File) {
left, right, err := uds.NewSocketpair(uds.SocketTypeDatagram)
require.NoError(t, err)
listenerFD, err := right.File()
Expand All @@ -99,7 +99,7 @@ func TestValidateListenerSocket(t *testing.T) {
},
{
name: "listener socket is not listening",
mutateFiles: func(t *testing.T, conn *uds.Conn, file *os.File) (*uds.Conn, *os.File) {
mutateFiles: func(t *testing.T, conn *net.UnixConn, file *os.File) (*net.UnixConn, *os.File) {
left, right, err := uds.NewSocketpair(uds.SocketTypeStream)
require.NoError(t, err)
listenerFD, err := right.File()
Expand All @@ -123,7 +123,7 @@ func TestValidateListenerSocket(t *testing.T) {
if tc.mutateConn != nil {
tc.mutateConn(t, listenerFD)
}
err := validateListenerSocket(&srv.ServerContext{}, conn.UnixConn, listenerFD)
err := validateListenerSocket(&srv.ServerContext{}, conn, listenerFD)
tc.assert(t, err)
})
}
Expand Down
3 changes: 1 addition & 2 deletions lib/sshutils/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"golang.org/x/crypto/ssh/agent"

"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils/uds"
)

// TCPIPForwardDialer represents a dialer used to handle TCPIP forward requests.
Expand All @@ -40,7 +39,7 @@ type TCPIPForwardDialer func(string) (net.Conn, error)
// TCPIPForwardProcess represents an instance of a port forwarding process.
type TCPIPForwardProcess struct {
// Conn is the socket used to request a dialer or listener in the process.
Conn *uds.Conn
Conn *net.UnixConn
// Done signals when the process completes.
Done <-chan struct{}
// Closer contains and extra io.Closer to run when the process as a whole
Expand Down
3 changes: 1 addition & 2 deletions lib/tbot/service_ssh_multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,7 @@ func (s *SSHMultiplexerService) handleConn(
b := make([]byte, 1)
fds := make([]*os.File, 1)

// TODO(espadolini): get rid of [uds.Conn]
n, fdn, err := (&uds.Conn{UnixConn: un}).ReadWithFDs(b, fds)
n, fdn, err := uds.ReadWithFDs(un, b, fds)
if err != nil {
return trace.Wrap(err, "reading request")
}
Expand Down
6 changes: 4 additions & 2 deletions lib/utils/uds/socketpair_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ package uds

import (
"errors"
"net"

"github.com/gravitational/trace"
)

var nonUnixErr = errors.New("socket pair not available on non-unix platform")

// NewSocketpair creates a unix socket pair, returning the halves as files.
func NewSocketpair(t SocketType) (left, right *Conn, err error) {
// NewSocketpair creates a unix socket pair, returning the halves as
// [*net.UnixConn]s.
func NewSocketpair(t SocketType) (left, right *net.UnixConn, err error) {
return nil, nil, trace.Wrap(nonUnixErr)
}
6 changes: 4 additions & 2 deletions lib/utils/uds/socketpair_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
package uds

import (
"net"
"os"

"github.com/gravitational/trace"
)

// NewSocketpair creates a unix socket pair, returning the halves as files.
func NewSocketpair(t SocketType) (left, right *Conn, err error) {
// NewSocketpair creates a unix socket pair, returning the halves as
// [*net.UnixConn]s.
func NewSocketpair(t SocketType) (left, right *net.UnixConn, err error) {
lfd, rfd, err := cloexecSocketpair(t)
if err != nil {
return nil, nil, trace.Wrap(err)
Expand Down
7 changes: 4 additions & 3 deletions lib/utils/uds/socketpair_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package uds
import (
"fmt"
"io"
"net"
"os"
"testing"

Expand Down Expand Up @@ -51,7 +52,7 @@ func TestSocketparFDPassing(t *testing.T) {
buf := make([]byte, 1024)
fbuf := make([]*os.File, maxFiles*2)
for {
n, fn, err := server.ReadWithFDs(buf, fbuf)
n, fn, err := ReadWithFDs(server, buf, fbuf)
if err != nil {
select {
case <-done:
Expand Down Expand Up @@ -79,7 +80,7 @@ func TestSocketparFDPassing(t *testing.T) {
msg := fmt.Sprintf("send-%d", f)
// conns are the local halves of socket pairs that we
// will use to read our message back from the server.
conns := make([]*Conn, 0, f)
conns := make([]*net.UnixConn, 0, f)

// fds are the remote halves of socket pairs to be sent
// to the server along with the associated message.
Expand All @@ -102,7 +103,7 @@ func TestSocketparFDPassing(t *testing.T) {
// write message and files together so that server reads them
// together and therefore will know what message to send back
// over the fds.
_, _, err := client.WriteWithFDs([]byte(msg), fds)
_, _, err := WriteWithFDs(client, []byte(msg), fds)
if err != nil {
return trace.Errorf("failed to write fds: %v", err)
}
Expand Down
20 changes: 7 additions & 13 deletions lib/utils/uds/uds_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,25 @@ import (
"github.com/gravitational/trace"
)

var errNonUnix = errors.New("uds.Conn only supported on unix")
// TODO(espadolini): reevaluate UDS support on Windows (passing file descriptors
// is probably not supported, still)
var errNonUnix = errors.New("unix domain sockets are only supported on unix")

// assert that *Conn implements net.Conn.
var _ net.Conn = (*Conn)(nil)

// Conn extends net.UnixConn with additional useful methods.
type Conn struct {
*net.UnixConn
}

// FromFile attempts to create a [Conn] from a file. The returned conn
// FromFile attempts to create a [net.UnixConn] from a file. The returned conn
// is independent from the file and closing one does not close the other.
func FromFile(fd *os.File) (*Conn, error) {
func FromFile(fd *os.File) (*net.UnixConn, error) {
return nil, trace.Wrap(errNonUnix)
}

// WriteWithFDs performs a write that may also send associated FDs. Note that unless the
// underlying socket is a datagram socket it is not guaranteed that the exact bytes written
// will match the bytes received with the fds on the reader side.
func (c *Conn) WriteWithFDs(b []byte, fds []*os.File) (n, fdn int, err error) {
func WriteWithFDs(c *net.UnixConn, b []byte, fds []*os.File) (n, fdn int, err error) {
return 0, 0, trace.Wrap(errNonUnix)
}

// ReadWithFDs reads data and associated fds. Note that the underlying socket must be a datagram socket
// if you need the bytes read to exactly match the bytes sent with the FDs.
func (c *Conn) ReadWithFDs(b []byte, fds []*os.File) (n, fdn int, err error) {
func ReadWithFDs(c *net.UnixConn, b []byte, fds []*os.File) (n, fdn int, err error) {
return 0, 0, trace.Wrap(errNonUnix)
}
18 changes: 5 additions & 13 deletions lib/utils/uds/uds_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,9 @@ import (
"github.com/gravitational/trace"
)

// assert that *Conn implements net.Conn.
var _ net.Conn = (*Conn)(nil)

// Conn extends net.UnixConn with additional useful methods.
type Conn struct {
*net.UnixConn
}

// FromFile attempts to create a [Conn] from a file. The returned conn
// FromFile attempts to create a [net.UnixConn] from a file. The returned conn
// is independent from the file and closing one does not close the other.
func FromFile(fd *os.File) (*Conn, error) {
func FromFile(fd *os.File) (*net.UnixConn, error) {
fconn, err := net.FileConn(fd)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -50,13 +42,13 @@ func FromFile(fd *os.File) (*Conn, error) {
return nil, trace.Errorf("unexpected conn type %T (expected %T)", fconn, uconn)
}

return &Conn{uconn}, nil
return uconn, nil
}

// WriteWithFDs performs a write that may also send associated FDs. Note that unless the
// underlying socket is a datagram socket it is not guaranteed that the exact bytes written
// will match the bytes received with the fds on the reader side.
func (c *Conn) WriteWithFDs(b []byte, fds []*os.File) (n, fdn int, err error) {
func WriteWithFDs(c *net.UnixConn, b []byte, fds []*os.File) (n, fdn int, err error) {
fbuf := make([]int, 0, len(fds))

for _, fd := range fds {
Expand Down Expand Up @@ -84,7 +76,7 @@ const (

// ReadWithFDs reads data and associated fds. Note that the underlying socket must be a datagram socket
// if you need the bytes read to exactly match the bytes sent with the FDs.
func (c *Conn) ReadWithFDs(b []byte, fds []*os.File) (n, fdn int, err error) {
func ReadWithFDs(c *net.UnixConn, b []byte, fds []*os.File) (n, fdn int, err error) {
// set up a buffer capable of supporting the maximum possible out of band data
obuf := make([]byte, syscall.CmsgSpace(sizeOfInt*len(fds)))

Expand Down

0 comments on commit 3de08d2

Please sign in to comment.