diff --git a/lib/vnet/setup.go b/lib/vnet/setup.go index 56456097cff5f..55e9f93520472 100644 --- a/lib/vnet/setup.go +++ b/lib/vnet/setup.go @@ -18,6 +18,7 @@ package vnet import ( "context" + "fmt" "log/slog" "os" "time" @@ -27,69 +28,142 @@ import ( "golang.zx2c4.com/wireguard/tun" ) -// Run is a blocking call to create and start Teleport VNet. -func Run(ctx context.Context, appProvider AppProvider) error { +// SetupAndRun creates a network stack for VNet and runs it in the background. To do this, it also +// needs to launch an admin subcommand in the background. It returns [ProcessManager] which controls +// the lifecycle of both background tasks. +// +// The caller is expected to call Close on the process manager to close the network stack, clean +// up any resources used by it and terminate the admin subcommand. +// +// ctx is used to wait for setup steps that happen before SetupAndRun hands out the control to the +// process manager. If ctx gets canceled during SetupAndRun, the process manager gets closed along +// with its background tasks. +func SetupAndRun(ctx context.Context, appProvider AppProvider) (*ProcessManager, error) { ipv6Prefix, err := NewIPv6Prefix() if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - dnsIPv6 := ipv6WithSuffix(ipv6Prefix, []byte{2}) - ctx, cancel := context.WithCancel(ctx) - defer cancel() + pm, processCtx := newProcessManager() + success := false + defer func() { + if !success { + // Closes the socket and background tasks. + pm.Close() + } + }() + + // Create the socket that's used to receive the TUN device from the admin subcommand. + socket, socketPath, err := createUnixSocket() + if err != nil { + return nil, trace.Wrap(err) + } + slog.DebugContext(ctx, "Created unix socket for admin subcommand", "socket", socketPath) + pm.AddCriticalBackgroundTask("socket closer", func() error { + // Keep the socket open until the process context is canceled. + // Closing the socket signals the admin subcommand to terminate. + <-processCtx.Done() + return trace.NewAggregate(processCtx.Err(), socket.Close()) + }) - tunCh, adminCommandErrCh := CreateAndSetupTUNDevice(ctx, ipv6Prefix.String(), dnsIPv6.String()) + pm.AddCriticalBackgroundTask("admin subcommand", func() error { + return trace.Wrap(execAdminSubcommand(processCtx, socketPath, ipv6Prefix.String(), dnsIPv6.String())) + }) + + recvTUNErr := make(chan error, 1) + var tun tun.Device + go func() { + // Unblocks after receiving a TUN device or when the context gets canceled (and thus socket gets + // closed). + tunDevice, err := receiveTUNDevice(socket) + tun = tunDevice + recvTUNErr <- err + }() - var tun TUNDevice select { - case err := <-adminCommandErrCh: - return trace.Wrap(err) - case tun = <-tunCh: + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err()) + case <-processCtx.Done(): + return nil, trace.Wrap(context.Cause(processCtx)) + case err := <-recvTUNErr: + if err != nil { + if processCtx.Err() != nil { + // Both errors being present means that VNet failed to receive a TUN device because of a + // problem with the admin subcommand. + // Returning error from processCtx will be more informative to the user, e.g., the error + // will say "password prompt closed by user" instead of "read from closed socket". 
+ slog.DebugContext(ctx, "Error from recvTUNErr ignored in favor of processCtx.Err", "error", err) + return nil, trace.Wrap(context.Cause(processCtx)) + } + return nil, trace.Wrap(err, "receiving TUN from admin subcommand") + } } appResolver, err := NewTCPAppResolver(appProvider) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - manager, err := NewManager(&Config{ + ns, err := newNetworkStack(&Config{ TUNDevice: tun, IPv6Prefix: ipv6Prefix, DNSIPv6: dnsIPv6, TCPHandlerResolver: appResolver, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - allErrors := make(chan error, 2) - g, ctx := errgroup.WithContext(ctx) - g.Go(func() error { - // Make sure to cancel the context if manager.Run terminates for any reason. - defer cancel() - err := trace.Wrap(manager.Run(ctx), "running VNet manager") - allErrors <- err - return err + pm.AddCriticalBackgroundTask("network stack", func() error { + return trace.Wrap(ns.Run(processCtx)) }) - g.Go(func() error { - var adminCommandErr error - select { - case adminCommandErr = <-adminCommandErrCh: - // The admin command exited before the context was canceled, cancel everything and exit. - cancel() - case <-ctx.Done(): - // The context has been canceled, the admin command should now exit. - adminCommandErr = <-adminCommandErrCh + + success = true + return pm, nil +} + +func newProcessManager() (*ProcessManager, context.Context) { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + return &ProcessManager{ + g: g, + cancel: cancel, + }, ctx +} + +// ProcessManager handles background tasks needed to run VNet. +// Its semantics are similar to an error group with a context, but it cancels the context whenever +// any task returns prematurely, that is, a task exits while the context was not canceled. +type ProcessManager struct { + g *errgroup.Group + cancel context.CancelFunc +} + +// AddCriticalBackgroundTask adds a function to the error group. [task] is expected to block until +// the context returned by [newProcessManager] gets canceled. The context gets canceled either by +// calling Close on [ProcessManager] or if any task returns. +func (pm *ProcessManager) AddCriticalBackgroundTask(name string, task func() error) { + pm.g.Go(func() error { + err := task() + if err == nil { + // Make sure to always return an error so that the errgroup context is canceled. + err = fmt.Errorf("critical task %q exited prematurely", name) } - adminCommandErr = trace.Wrap(adminCommandErr, "running admin subcommand") - allErrors <- adminCommandErr - return adminCommandErr + return trace.Wrap(err) }) - // Deliberately ignoring the error from g.Wait() to return an aggregate of all errors. - _ = g.Wait() - close(allErrors) - return trace.NewAggregateFromChannel(allErrors, context.Background()) +} + +// Wait blocks and waits for the background tasks to finish, which typically happens when another +// goroutine calls Close on the process manager. +func (pm *ProcessManager) Wait() error { + return trace.Wrap(pm.g.Wait()) +} + +// Close stops any active background tasks by canceling the underlying context. 
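The ProcessManager introduced above is essentially a thin wrapper around golang.org/x/sync/errgroup whose one twist is that a task returning nil is still converted into an error, so every sibling task sees the shared context get canceled. A minimal standalone sketch of the same pattern (the criticalGroup name and the main function are illustrative only, not part of this change):

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// criticalGroup mimics the ProcessManager semantics: any task that returns,
// even with a nil error, causes the shared context to be canceled for all
// other tasks.
type criticalGroup struct {
	g      *errgroup.Group
	cancel context.CancelFunc
}

func newCriticalGroup() (*criticalGroup, context.Context) {
	ctx, cancel := context.WithCancel(context.Background())
	g, ctx := errgroup.WithContext(ctx)
	return &criticalGroup{g: g, cancel: cancel}, ctx
}

func (cg *criticalGroup) Go(name string, task func() error) {
	cg.g.Go(func() error {
		err := task()
		if err == nil {
			// A non-nil error is what makes errgroup cancel its context.
			err = fmt.Errorf("critical task %q exited prematurely", name)
		}
		return err
	})
}

func main() {
	cg, ctx := newCriticalGroup()
	defer cg.cancel()

	// A well-behaved task blocks until the shared context is canceled.
	cg.Go("context-aware task", func() error {
		<-ctx.Done()
		return ctx.Err()
	})

	// A misbehaving task returns immediately; the group treats that as fatal
	// and unblocks the task above by canceling ctx.
	cg.Go("premature return", func() error {
		return nil
	})

	fmt.Println(cg.g.Wait()) // critical task "premature return" exited prematurely
}
```

Because Wait surfaces the synthesized error rather than nil, SetupAndRun can treat any premature task exit, including a clean one, as fatal for the whole VNet setup.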
+func (pm *ProcessManager) Close() { + pm.cancel() } // AdminSubcommand is the tsh subcommand that should run as root that will create and setup a TUN device and @@ -136,28 +210,6 @@ func AdminSubcommand(ctx context.Context, socketPath, ipv6Prefix, dnsAddr string } } -// CreateAndSetupTUNDevice creates a virtual network device and configures the host OS to use that device for -// VNet connections. -// -// If not already running as root, it will spawn a root process to handle the TUN creation and host -// configuration. -// -// After the TUN device is created, it will be sent on the result channel. Any error will be sent on the err -// channel. Always select on both the result channel and the err channel when waiting for a result. -// -// This will keep running until [ctx] is canceled or an unrecoverable error is encountered, in order to keep -// the host OS configuration up to date. -func CreateAndSetupTUNDevice(ctx context.Context, ipv6Prefix, dnsAddr string) (<-chan tun.Device, <-chan error) { - if os.Getuid() == 0 { - // We can get here if the user runs `tsh vnet` as root, but it is not in the expected path when - // started as a regular user. Typically we expect `tsh vnet` to be run as a non-root user, and for - // AdminSubcommand to directly call createAndSetupTUNDeviceAsRoot. - return createAndSetupTUNDeviceAsRoot(ctx, ipv6Prefix, dnsAddr) - } else { - return createAndSetupTUNDeviceWithoutRoot(ctx, ipv6Prefix, dnsAddr) - } -} - // createAndSetupTUNDeviceAsRoot creates a virtual network device and configures the host OS to use that device for // VNet connections. // diff --git a/lib/vnet/setup_darwin.go b/lib/vnet/setup_darwin.go index 5872889c2b644..3cb7a3ae6346c 100644 --- a/lib/vnet/setup_darwin.go +++ b/lib/vnet/setup_darwin.go @@ -23,7 +23,6 @@ import ( "context" "errors" "fmt" - "log/slog" "net" "os" "os/exec" @@ -33,7 +32,6 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" - "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/tun" @@ -42,56 +40,16 @@ import ( "github.com/gravitational/teleport/api/types" ) -// createAndSetupTUNDeviceWithoutRoot creates a virtual network device and configures the host OS to use that -// device for VNet connections. It will spawn a root process to handle the TUN creation and host -// configuration. -// -// After the TUN device is created, it will be sent on the result channel. Any error will be sent on the err -// channel. Always select on both the result channel and the err channel when waiting for a result. -// -// This will keep running until [ctx] is canceled or an unrecoverable error is encountered, in order to keep -// the host OS configuration up to date. -func createAndSetupTUNDeviceWithoutRoot(ctx context.Context, ipv6Prefix, dnsAddr string) (<-chan tun.Device, <-chan error) { - tunCh := make(chan tun.Device, 1) - errCh := make(chan error, 1) - - slog.InfoContext(ctx, "Spawning child process as root to create and setup TUN device") - socket, socketPath, err := createUnixSocket() +// receiveTUNDevice is a blocking call which waits for the admin subcommand to pass over the socket +// the name and fd of the TUN device. +func receiveTUNDevice(socket *net.UnixListener) (tun.Device, error) { + tunName, tunFd, err := recvTUNNameAndFd(socket) if err != nil { - errCh <- trace.Wrap(err, "creating unix socket") - return tunCh, errCh - } - - // Make sure all goroutines complete before sending an err on the error chan, to be sure they all have a - // chance to clean up before the process terminates. 
- g, ctx := errgroup.WithContext(ctx) - g.Go(func() error { - // Requirements: - // - must close the socket concurrently with recvTUNNameAndFd if the context is canceled to unblock - // a stuck AcceptUnix (can't defer). - // - must close the socket exactly once before letting the process terminate. - <-ctx.Done() - return trace.Wrap(socket.Close()) - }) - g.Go(func() error { - // Admin command is expected to run until ctx is canceled. - return trace.Wrap(execAdminSubcommand(ctx, socketPath, ipv6Prefix, dnsAddr)) - }) - g.Go(func() error { - tunName, tunFd, err := recvTUNNameAndFd(ctx, socket) - if err != nil { - return trace.Wrap(err, "receiving TUN name and file descriptor") - } - tunDevice, err := tun.CreateTUNFromFile(os.NewFile(tunFd, tunName), 0) - if err != nil { - return trace.Wrap(err, "creating TUN device from file descriptor") - } - tunCh <- tunDevice - return nil - }) - go func() { errCh <- g.Wait() }() + return nil, trace.Wrap(err, "receiving TUN name and file descriptor") + } - return tunCh, errCh + tunDevice, err := tun.CreateTUNFromFile(os.NewFile(tunFd, tunName), 0) + return tunDevice, trace.Wrap(err, "creating TUN device from file descriptor") } func execAdminSubcommand(ctx context.Context, socketPath, ipv6Prefix, dnsAddr string) error { @@ -135,10 +93,36 @@ do shell script quoted form of executableName & `+ if strings.Contains(stderr, "-128") { return trace.Errorf("password prompt closed by user") } - return trace.Wrap(exitError, "admin subcommand exited, stderr: %s", stderr) + + if errors.Is(ctx.Err(), context.Canceled) { + // osascript exiting due to canceled context. + return ctx.Err() + } + + stderrDesc := "" + if stderr != "" { + stderrDesc = fmt.Sprintf(", stderr: %s", stderr) + } + return trace.Wrap(exitError, "osascript exited%s", stderrDesc) } + return trace.Wrap(err) } + + if ctx.Err() == nil { + // The admin subcommand is expected to run until VNet gets stopped (in other words, until ctx + // gets canceled). + // + // If it exits with no error _before_ ctx is canceled, then it most likely means that the socket + // was unexpectedly removed. When the socket gets removed, the admin subcommand assumes that the + // unprivileged process (executing this code here) has quit and thus it should quit as well. But + // we know that it's not the case, so in this scenario we return an error instead. + // + // If we don't return an error here, then other code won't be properly notified about the fact + // that the admin process has quit. + return trace.Errorf("admin subcommand exited prematurely with no error (likely because socket was removed)") + } + return nil } @@ -178,19 +162,12 @@ func sendTUNNameAndFd(socketPath, tunName string, fd uintptr) error { // recvTUNNameAndFd receives the name of a TUN device and its open file descriptor over a unix socket, meant // for passing the TUN from the root process which must create it to the user process. -func recvTUNNameAndFd(ctx context.Context, socket *net.UnixListener) (string, uintptr, error) { +func recvTUNNameAndFd(socket *net.UnixListener) (string, uintptr, error) { conn, err := socket.AcceptUnix() if err != nil { return "", 0, trace.Wrap(err, "accepting connection on unix socket") } - - // Close the connection early to unblock reads if the context is canceled. 
- ctx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - <-ctx.Done() - conn.Close() - }() + defer conn.Close() msg := make([]byte, 128) oob := make([]byte, unix.CmsgSpace(4)) // Fd is 4 bytes diff --git a/lib/vnet/setup_other.go b/lib/vnet/setup_other.go index 0c19b56368f6a..4d1976f93c088 100644 --- a/lib/vnet/setup_other.go +++ b/lib/vnet/setup_other.go @@ -21,6 +21,7 @@ package vnet import ( "context" + "net" "runtime" "github.com/gravitational/trace" @@ -32,16 +33,22 @@ var ( ErrVnetNotImplemented = &trace.NotImplementedError{Message: "VNet is not implemented on " + runtime.GOOS} ) -func createAndSetupTUNDeviceWithoutRoot(ctx context.Context, ipv6Prefix, dnsAddr string) (<-chan tun.Device, <-chan error) { - errCh := make(chan error, 1) - errCh <- trace.Wrap(ErrVnetNotImplemented) - return nil, errCh +func createUnixSocket() (*net.UnixListener, string, error) { + return nil, "", trace.Wrap(ErrVnetNotImplemented) } func sendTUNNameAndFd(socketPath, tunName string, fd uintptr) error { return trace.Wrap(ErrVnetNotImplemented) } +func receiveTUNDevice(socket *net.UnixListener) (tun.Device, error) { + return nil, trace.Wrap(ErrVnetNotImplemented) +} + func configureOS(ctx context.Context, cfg *osConfig) error { return trace.Wrap(ErrVnetNotImplemented) } + +func execAdminSubcommand(ctx context.Context, socketPath, ipv6Prefix, dnsAddr string) error { + return trace.Wrap(ErrVnetNotImplemented) +} diff --git a/lib/vnet/setup_test.go b/lib/vnet/setup_test.go new file mode 100644 index 0000000000000..5309150e35b5e --- /dev/null +++ b/lib/vnet/setup_test.go @@ -0,0 +1,77 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package vnet + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestProcessManager_PrematureReturn(t *testing.T) { + pm, pmCtx := newProcessManager() + defer pm.Close() + + pm.AddCriticalBackgroundTask("premature return", func() error { + return nil + }) + pm.AddCriticalBackgroundTask("context-aware task", func() error { + <-pmCtx.Done() + return pmCtx.Err() + }) + + err := pm.Wait() + require.ErrorContains(t, err, "critical task \"premature return\" exited prematurely") + // Verify that the cancellation cause is propagated through the context. 
+ require.ErrorIs(t, err, context.Cause(pmCtx)) +} + +func TestProcessManager_ReturnWithError(t *testing.T) { + pm, pmCtx := newProcessManager() + defer pm.Close() + + expectedErr := fmt.Errorf("lorem ipsum dolor sit amet") + pm.AddCriticalBackgroundTask("return with error", func() error { + return expectedErr + }) + pm.AddCriticalBackgroundTask("context-aware task", func() error { + <-pmCtx.Done() + return pmCtx.Err() + }) + + err := pm.Wait() + require.ErrorIs(t, err, expectedErr) + require.ErrorIs(t, err, context.Cause(pmCtx)) +} + +func TestProcessManager_Close(t *testing.T) { + pm, pmCtx := newProcessManager() + defer pm.Close() + + pm.AddCriticalBackgroundTask("context-aware task", func() error { + <-pmCtx.Done() + return pmCtx.Err() + }) + + pm.Close() + + err := pm.Wait() + require.ErrorIs(t, err, context.Canceled) + require.ErrorIs(t, err, context.Cause(pmCtx)) +} diff --git a/lib/vnet/vnet.go b/lib/vnet/vnet.go index 812fdc34fb33a..66f07a92ff6cf 100644 --- a/lib/vnet/vnet.go +++ b/lib/vnet/vnet.go @@ -21,6 +21,7 @@ import ( "errors" "log/slog" "net" + "os" "sync" "github.com/gravitational/trace" @@ -148,8 +149,8 @@ type TUNDevice interface { Close() error } -// Manager holds configuration and state for the VNet. -type Manager struct { +// NetworkStack holds configuration and state for the VNet. +type NetworkStack struct { // stack is the gVisor networking stack. stack *stack.Stack @@ -177,19 +178,19 @@ type Manager struct { // destroyed is a channel that will be closed when the VNet is in the process of being destroyed. // All goroutines should terminate quickly after either this is closed or the context passed to - // [Manager.Run] is canceled. + // [NetworkStack.Run] is canceled. destroyed chan struct{} - // wg is a [sync.WaitGroup] that keeps track of all running goroutines started by the [Manager]. + // wg is a [sync.WaitGroup] that keeps track of all running goroutines started by the [NetworkStack]. wg sync.WaitGroup - // state holds all mutable state for the Manager, it is currently protect by a single RWMutex, this could - // be optimized as necessary. + // state holds all mutable state for the NetworkStack. state state slog *slog.Logger } type state struct { + // mu is a single mutex that protects the whole state struct. This could be optimized as necessary. mu sync.RWMutex // Each app gets assigned both an IPv4 address and an IPv6 address, where the 4-bit suffix of the IPv6 @@ -213,12 +214,10 @@ func newState() state { } } -// NewManager creates a new VNet manager with the given configuration and root -// context. Call Run() on the returned manager to start the VNet. -// NewManager creates a new VNet manager with the given configuration and root context. It takes ownership of -// [cfg.TUNDevice] and will handle closing it before Run() returns. Call Run() on the returned manager to -// start the VNet. -func NewManager(cfg *Config) (*Manager, error) { +// newNetworkStack creates a new VNet network stack with the given configuration and root context. +// It takes ownership of [cfg.TUNDevice] and will handle closing it before Run() returns. Call Run() +// on the returned network stack to start the VNet. 
+func newNetworkStack(cfg *Config) (*NetworkStack, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } @@ -233,7 +232,7 @@ func NewManager(cfg *Config) (*Manager, error) { return nil, trace.Wrap(err) } - m := &Manager{ + ns := &NetworkStack{ tun: cfg.TUNDevice, stack: stack, linkEndpoint: linkEndpoint, @@ -244,11 +243,11 @@ func NewManager(cfg *Config) (*Manager, error) { slog: slog, } - tcpForwarder := tcp.NewForwarder(m.stack, tcpReceiveBufferSize, maxInFlightTCPConnectionAttempts, m.handleTCP) - m.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + tcpForwarder := tcp.NewForwarder(ns.stack, tcpReceiveBufferSize, maxInFlightTCPConnectionAttempts, ns.handleTCP) + ns.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) - udpForwarder := udp.NewForwarder(m.stack, m.handleUDP) - m.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + udpForwarder := udp.NewForwarder(ns.stack, ns.handleUDP) + ns.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) if cfg.DNSIPv6 != (tcpip.Address{}) { upstreamNameserverSource := cfg.upstreamNameserverSource @@ -258,17 +257,17 @@ func NewManager(cfg *Config) (*Manager, error) { return nil, trace.Wrap(err) } } - dnsServer, err := dns.NewServer(m, upstreamNameserverSource) + dnsServer, err := dns.NewServer(ns, upstreamNameserverSource) if err != nil { return nil, trace.Wrap(err) } - if err := m.assignUDPHandler(cfg.DNSIPv6, dnsServer); err != nil { + if err := ns.assignUDPHandler(cfg.DNSIPv6, dnsServer); err != nil { return nil, trace.Wrap(err) } slog.DebugContext(context.Background(), "Serving DNS on IPv6.", "dns_addr", cfg.DNSIPv6) } - return m, nil + return ns, nil } func createStack() (*stack.Stack, *channel.Endpoint, error) { @@ -314,8 +313,8 @@ func installVnetRoutes(stack *stack.Stack) error { // Run starts the VNet. It blocks until [ctx] is canceled, at which point it closes the link endpoint, waits // for all goroutines to terminate, and destroys the networking stack. -func (m *Manager) Run(ctx context.Context) error { - m.slog.InfoContext(ctx, "Running Teleport VNet.", "ipv6_prefix", m.ipv6Prefix) +func (ns *NetworkStack) Run(ctx context.Context) error { + ns.slog.InfoContext(ctx, "Running Teleport VNet.", "ipv6_prefix", ns.ipv6Prefix) ctx, cancel := context.WithCancel(ctx) @@ -324,7 +323,7 @@ func (m *Manager) Run(ctx context.Context) error { g.Go(func() error { // Make sure to cancel the context in case this exits prematurely with a nil error. defer cancel() - err := forwardBetweenTunAndNetstack(ctx, m.tun, m.linkEndpoint) + err := forwardBetweenTunAndNetstack(ctx, ns.tun, ns.linkEndpoint) allErrors <- err return err }) @@ -333,13 +332,13 @@ func (m *Manager) Run(ctx context.Context) error { // have canceled it) destroy everything and quit. <-ctx.Done() - // In-flight connections should start terminating after closing [m.destroyed]. - close(m.destroyed) + // In-flight connections should start terminating after closing [ns.destroyed]. + close(ns.destroyed) // Close the link endpoint and the TUN, this should cause [forwardBetweenTunAndNetstack] to terminate // if it hasn't already. 
- m.linkEndpoint.Close() - err := trace.Wrap(m.tun.Close(), "closing TUN device") + ns.linkEndpoint.Close() + err := trace.Wrap(ns.tun.Close(), "closing TUN device") allErrors <- err return err @@ -349,19 +348,19 @@ func (m *Manager) Run(ctx context.Context) error { _ = g.Wait() // Wait for all connections and goroutines to clean themselves up. - m.wg.Wait() + ns.wg.Wait() // Now we can destroy the gVisor networking stack and wait for all its goroutines to terminate. - m.stack.Destroy() + ns.stack.Destroy() close(allErrors) return trace.NewAggregateFromChannel(allErrors, context.Background()) } -func (m *Manager) handleTCP(req *tcp.ForwarderRequest) { +func (ns *NetworkStack) handleTCP(req *tcp.ForwarderRequest) { // Add 1 to the waitgroup because the networking stack runs this in its own goroutine. - m.wg.Add(1) - defer m.wg.Done() + ns.wg.Add(1) + defer ns.wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -376,11 +375,11 @@ func (m *Manager) handleTCP(req *tcp.ForwarderRequest) { }() id := req.ID() - slog := m.slog.With("request", id) + slog := ns.slog.With("request", id) slog.DebugContext(ctx, "Handling TCP connection.") defer slog.DebugContext(ctx, "Finished handling TCP connection.") - handler, ok := m.getTCPHandler(id.LocalAddress) + handler, ok := ns.getTCPHandler(id.LocalAddress) if !ok { slog.DebugContext(ctx, "No handler for address.", "addr", id.LocalAddress) return @@ -404,17 +403,17 @@ func (m *Manager) handleTCP(req *tcp.ForwarderRequest) { conn := gonet.NewTCPConn(&wq, endpoint) - m.wg.Add(1) + ns.wg.Add(1) go func() { defer func() { cancel() conn.Close() - m.wg.Done() + ns.wg.Done() }() select { case <-notifyCh: slog.DebugContext(ctx, "Got HUP or ERR, canceling request context and closing TCP conn.") - case <-m.destroyed: + case <-ns.destroyed: slog.DebugContext(ctx, "VNet is being destroyed, canceling request context and closing TCP conn.") case <-ctx.Done(): slog.DebugContext(ctx, "Request context canceled, closing TCP conn.") @@ -433,63 +432,63 @@ func (m *Manager) handleTCP(req *tcp.ForwarderRequest) { } } -func (m *Manager) getTCPHandler(addr tcpip.Address) (TCPHandler, bool) { - m.state.mu.RLock() - defer m.state.mu.RUnlock() - handler, ok := m.state.tcpHandlers[ipv4Suffix(addr)] +func (ns *NetworkStack) getTCPHandler(addr tcpip.Address) (TCPHandler, bool) { + ns.state.mu.RLock() + defer ns.state.mu.RUnlock() + handler, ok := ns.state.tcpHandlers[ipv4Suffix(addr)] return handler, ok } // assignTCPHandler assigns an IPv4 address to [handlerSpec] from its preferred CIDR range, and returns that // new assigned address. 
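assignTCPHandler below relies on a randomFreeIPv4InNet helper that is not part of this diff; judging by its call site it picks a random address in the handler's preferred CIDR range and retries until the supplied predicate reports the address as free. A rough standalone sketch of that allocation strategy, under that assumption (the real helper may differ):

```go
package main

import (
	"errors"
	"fmt"
	"math/rand"
	"net"
)

// randomFreeIPv4 picks a random host address inside ipNet for which the
// free predicate returns true, giving up after a fixed number of attempts.
// A real allocator would also skip the network and broadcast addresses.
func randomFreeIPv4(ipNet *net.IPNet, free func(ip [4]byte) bool) ([4]byte, error) {
	base := ipNet.IP.To4()
	if base == nil {
		return [4]byte{}, errors.New("not an IPv4 network")
	}
	ones, bits := ipNet.Mask.Size()
	hostMask := uint32(uint64(1)<<uint(bits-ones) - 1)

	for attempt := 0; attempt < 64; attempt++ {
		// Randomize only the host portion of the address.
		offset := rand.Uint32() & hostMask
		var ip [4]byte
		copy(ip[:], base)
		for i := 0; i < 4; i++ {
			ip[3-i] |= byte(offset >> (8 * i))
		}
		if free(ip) {
			return ip, nil
		}
	}
	return [4]byte{}, errors.New("no free address found")
}

func main() {
	_, ipNet, err := net.ParseCIDR("100.64.0.0/10")
	if err != nil {
		panic(err)
	}
	taken := map[[4]byte]bool{}

	// Allocate two addresses, marking each as taken so it cannot be reused.
	for i := 0; i < 2; i++ {
		ip, err := randomFreeIPv4(ipNet, func(ip [4]byte) bool { return !taken[ip] })
		if err != nil {
			panic(err)
		}
		taken[ip] = true
		fmt.Println(net.IP(ip[:]))
	}
}
```

In the diff itself the chosen address is then registered twice: as an IPv4 protocol address and, combined with the VNet IPv6 prefix, as the matching IPv6 address.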
-func (m *Manager) assignTCPHandler(handlerSpec *TCPHandlerSpec, fqdn string) (ipv4, error) { +func (ns *NetworkStack) assignTCPHandler(handlerSpec *TCPHandlerSpec, fqdn string) (ipv4, error) { _, ipNet, err := net.ParseCIDR(handlerSpec.IPv4CIDRRange) if err != nil { return 0, trace.Wrap(err, "parsing CIDR %q", handlerSpec.IPv4CIDRRange) } - m.state.mu.Lock() - defer m.state.mu.Unlock() + ns.state.mu.Lock() + defer ns.state.mu.Unlock() ip, err := randomFreeIPv4InNet(ipNet, func(ip ipv4) bool { - _, taken := m.state.tcpHandlers[ip] + _, taken := ns.state.tcpHandlers[ip] return !taken }) if err != nil { return 0, trace.Wrap(err, "assigning IP address") } - m.state.tcpHandlers[ip] = handlerSpec.TCPHandler - m.state.appIPs[fqdn] = ip + ns.state.tcpHandlers[ip] = handlerSpec.TCPHandler + ns.state.appIPs[fqdn] = ip - if err := m.addProtocolAddress(tcpip.AddrFrom4(ip.asArray())); err != nil { + if err := ns.addProtocolAddress(tcpip.AddrFrom4(ip.asArray())); err != nil { return 0, trace.Wrap(err) } - if err := m.addProtocolAddress(ipv6WithSuffix(m.ipv6Prefix, ip.asSlice())); err != nil { + if err := ns.addProtocolAddress(ipv6WithSuffix(ns.ipv6Prefix, ip.asSlice())); err != nil { return 0, trace.Wrap(err) } return ip, nil } -func (m *Manager) handleUDP(req *udp.ForwarderRequest) { - m.wg.Add(1) +func (ns *NetworkStack) handleUDP(req *udp.ForwarderRequest) { + ns.wg.Add(1) go func() { - defer m.wg.Done() - m.handleUDPConcurrent(req) + defer ns.wg.Done() + ns.handleUDPConcurrent(req) }() } -func (m *Manager) handleUDPConcurrent(req *udp.ForwarderRequest) { +func (ns *NetworkStack) handleUDPConcurrent(req *udp.ForwarderRequest) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() id := req.ID() - slog := m.slog.With("request", id) + slog := ns.slog.With("request", id) slog.DebugContext(ctx, "Handling UDP request.") defer slog.DebugContext(ctx, "Finished handling UDP request.") - handler, ok := m.getUDPHandler(id.LocalAddress) + handler, ok := ns.getUDPHandler(id.LocalAddress) if !ok { slog.DebugContext(ctx, "No handler for address.") return @@ -505,20 +504,20 @@ func (m *Manager) handleUDPConcurrent(req *udp.ForwarderRequest) { return } - conn := gonet.NewUDPConn(m.stack, &wq, endpoint) + conn := gonet.NewUDPConn(ns.stack, &wq, endpoint) defer conn.Close() - m.wg.Add(1) + ns.wg.Add(1) go func() { defer func() { cancel() conn.Close() - m.wg.Done() + ns.wg.Done() }() select { case <-notifyCh: slog.DebugContext(ctx, "Got HUP or ERR, canceling request context and closing UDP conn.") - case <-m.destroyed: + case <-ns.destroyed: slog.DebugContext(ctx, "VNet is being destroyed, canceling request context and closing UDP conn.") case <-ctx.Done(): slog.DebugContext(ctx, "Request context canceled, closing UDP conn.") @@ -530,42 +529,42 @@ func (m *Manager) handleUDPConcurrent(req *udp.ForwarderRequest) { } } -func (m *Manager) getUDPHandler(addr tcpip.Address) (UDPHandler, bool) { +func (ns *NetworkStack) getUDPHandler(addr tcpip.Address) (UDPHandler, bool) { ipv4 := ipv4Suffix(addr) - m.state.mu.RLock() - defer m.state.mu.RUnlock() - handler, ok := m.state.udpHandlers[ipv4] + ns.state.mu.RLock() + defer ns.state.mu.RUnlock() + handler, ok := ns.state.udpHandlers[ipv4] return handler, ok } -func (m *Manager) assignUDPHandler(addr tcpip.Address, handler UDPHandler) error { +func (ns *NetworkStack) assignUDPHandler(addr tcpip.Address, handler UDPHandler) error { ipv4 := ipv4Suffix(addr) - m.state.mu.Lock() - defer m.state.mu.Unlock() - if _, ok := m.state.udpHandlers[ipv4]; ok { + 
ns.state.mu.Lock() + defer ns.state.mu.Unlock() + if _, ok := ns.state.udpHandlers[ipv4]; ok { return trace.AlreadyExists("Handler for %s is already set", addr) } - if err := m.addProtocolAddress(addr); err != nil { + if err := ns.addProtocolAddress(addr); err != nil { return trace.Wrap(err) } - m.state.udpHandlers[ipv4] = handler + ns.state.udpHandlers[ipv4] = handler return nil } // ResolveA implements [dns.Resolver.ResolveA]. -func (m *Manager) ResolveA(ctx context.Context, fqdn string) (dns.Result, error) { +func (ns *NetworkStack) ResolveA(ctx context.Context, fqdn string) (dns.Result, error) { // Do the actual resolution within a [singleflight.Group] keyed by [fqdn] to avoid concurrent requests to // resolve an FQDN and then assign an address to it. - resultAny, err, _ := m.resolveHandlerGroup.Do(fqdn, func() (any, error) { + resultAny, err, _ := ns.resolveHandlerGroup.Do(fqdn, func() (any, error) { // If we've already assigned an IP address to this app, resolve to it. - if ip, ok := m.appIPv4(fqdn); ok { + if ip, ok := ns.appIPv4(fqdn); ok { return dns.Result{ A: ip.asArray(), }, nil } // If fqdn is a Teleport-managed app, create a new handler for it. - handlerSpec, err := m.tcpHandlerResolver.ResolveTCPHandler(ctx, fqdn) + handlerSpec, err := ns.tcpHandlerResolver.ResolveTCPHandler(ctx, fqdn) if err != nil { if errors.Is(err, ErrNoTCPHandler) { // Did not find any known app, forward the DNS request upstream. @@ -575,7 +574,7 @@ func (m *Manager) ResolveA(ctx context.Context, fqdn string) (dns.Result, error) } // Assign an unused IP address to this app's handler. - ip, err := m.assignTCPHandler(handlerSpec, fqdn) + ip, err := ns.assignTCPHandler(handlerSpec, fqdn) if err != nil { return dns.Result{}, trace.Wrap(err, "assigning address to handler for %q", fqdn) } @@ -592,22 +591,22 @@ func (m *Manager) ResolveA(ctx context.Context, fqdn string) (dns.Result, error) } // ResolveAAAA implements [dns.Resolver.ResolveAAAA]. 
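ResolveA above funnels concurrent lookups for the same FQDN through a singleflight.Group so that an app is resolved, and assigned an address, at most once at a time. A minimal standalone sketch of that deduplication pattern, detached from the VNet types (the resolve helper and the hard-coded address are illustrative only):

```go
package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/singleflight"
)

func main() {
	var group singleflight.Group
	var resolutions int

	// resolve stands in for ResolveA: concurrent callers asking for the same
	// key share a single execution of the inner function and all receive its
	// result, so an address is assigned at most once per FQDN.
	resolve := func(fqdn string) (string, error) {
		v, err, _ := group.Do(fqdn, func() (any, error) {
			resolutions++ // executions for the same key never overlap
			return "100.64.0.2", nil
		})
		if err != nil {
			return "", err
		}
		return v.(string), nil
	}

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			addr, err := resolve("echo.example.teleport.sh")
			fmt.Println(addr, err)
		}()
	}
	wg.Wait()
	fmt.Println("underlying resolutions:", resolutions) // typically 1, never more than 4
}
```

The group only collapses calls that overlap in time, which is enough here because the first successful resolution stores the assigned IP in state.appIPs and later calls hit that cache.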
-func (m *Manager) ResolveAAAA(ctx context.Context, fqdn string) (dns.Result, error) { - result, err := m.ResolveA(ctx, fqdn) +func (ns *NetworkStack) ResolveAAAA(ctx context.Context, fqdn string) (dns.Result, error) { + result, err := ns.ResolveA(ctx, fqdn) if err != nil { return dns.Result{}, trace.Wrap(err) } if result.A != ([4]byte{}) { - result.AAAA = ipv6WithSuffix(m.ipv6Prefix, result.A[:]).As16() + result.AAAA = ipv6WithSuffix(ns.ipv6Prefix, result.A[:]).As16() result.A = [4]byte{} } return result, nil } -func (m *Manager) appIPv4(fqdn string) (ipv4, bool) { - m.state.mu.RLock() - defer m.state.mu.RUnlock() - ipv4, ok := m.state.appIPs[fqdn] +func (ns *NetworkStack) appIPv4(fqdn string) (ipv4, bool) { + ns.state.mu.RLock() + defer ns.state.mu.RUnlock() + ipv4, ok := ns.state.appIPs[fqdn] return ipv4, ok } @@ -615,7 +614,7 @@ func forwardBetweenTunAndNetstack(ctx context.Context, tun TUNDevice, linkEndpoi slog.DebugContext(ctx, "Forwarding IP packets between OS and VNet.") g, ctx := errgroup.WithContext(ctx) g.Go(func() error { return forwardNetstackToTUN(ctx, linkEndpoint, tun) }) - g.Go(func() error { return forwardTUNtoNetstack(tun, linkEndpoint) }) + g.Go(func() error { return forwardTUNtoNetstack(ctx, tun, linkEndpoint) }) return trace.Wrap(g.Wait()) } @@ -640,7 +639,9 @@ func forwardNetstackToTUN(ctx context.Context, linkEndpoint *channel.Endpoint, t } } -func forwardTUNtoNetstack(tun TUNDevice, linkEndpoint *channel.Endpoint) error { +// forwardTUNtoNetstack does not abort on ctx being canceled, but it does check the ctx error before +// returning os.ErrClosed from tun.Read. +func forwardTUNtoNetstack(ctx context.Context, tun TUNDevice, linkEndpoint *channel.Endpoint) error { const readOffset = device.MessageTransportHeaderSize bufs := make([][]byte, tun.BatchSize()) for i := range bufs { @@ -650,6 +651,10 @@ func forwardTUNtoNetstack(tun TUNDevice, linkEndpoint *channel.Endpoint) error { for { n, err := tun.Read(bufs, sizes, readOffset) if err != nil { + // tun.Read might get interrupted due to the TUN device getting closed after ctx cancellation. + if errors.Is(err, os.ErrClosed) && ctx.Err() != nil { + return ctx.Err() + } return trace.Wrap(err, "reading packets from TUN") } for i := range sizes[:n] { @@ -666,12 +671,12 @@ func forwardTUNtoNetstack(tun TUNDevice, linkEndpoint *channel.Endpoint) error { } } -func (m *Manager) addProtocolAddress(localAddress tcpip.Address) error { +func (ns *NetworkStack) addProtocolAddress(localAddress tcpip.Address) error { protocolAddress, err := protocolAddress(localAddress) if err != nil { return trace.Wrap(err) } - if err := m.stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { + if err := ns.stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { return trace.Errorf("%s", err) } return nil diff --git a/lib/vnet/vnet_test.go b/lib/vnet/vnet_test.go index ac33f5f07a513..84bf8434faf63 100644 --- a/lib/vnet/vnet_test.go +++ b/lib/vnet/vnet_test.go @@ -64,7 +64,7 @@ func TestMain(m *testing.M) { type testPack struct { vnetIPv6Prefix tcpip.Address dnsIPv6 tcpip.Address - manager *Manager + ns *NetworkStack testStack *stack.Stack testLinkEndpoint *channel.Endpoint @@ -77,9 +77,9 @@ func newTestPack(t *testing.T, ctx context.Context, clock clockwork.FakeClock, a // Create an isolated userspace networking stack that can be manipulated from test code. It will be // connected to the VNet over the TUN interface. This emulates the host networking stack. 
- // This is a completely separate gvisor network stack than the one that will be created in NewManager - - // the two will be connected over a fake TUN interface. This exists so that the test can setup IP routes - // without affecting the host running the Test. + // This is a completely separate gvisor network stack than the one that will be created in + // NewNetworkStack - the two will be connected over a fake TUN interface. This exists so that the + // test can setup IP routes without affecting the host running the Test. testStack, testLinkEndpoint, err := createStack() require.NoError(t, err) @@ -120,12 +120,11 @@ func newTestPack(t *testing.T, ctx context.Context, clock clockwork.FakeClock, a }}) dnsIPv6 := ipv6WithSuffix(vnetIPv6Prefix, []byte{2}) - tcpHandlerResolver, err := NewTCPAppResolver(appProvider, withClock(clock)) require.NoError(t, err) // Create the VNet and connect it to the other side of the TUN. - manager, err := NewManager(&Config{ + ns, err := newNetworkStack(&Config{ TUNDevice: tun2, IPv6Prefix: vnetIPv6Prefix, DNSIPv6: dnsIPv6, @@ -137,7 +136,7 @@ func newTestPack(t *testing.T, ctx context.Context, clock clockwork.FakeClock, a utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ Name: "VNet", Task: func(ctx context.Context) error { - if err := manager.Run(ctx); !errIsOK(err) { + if err := ns.Run(ctx); !errIsOK(err) { return trace.Wrap(err) } return nil @@ -147,7 +146,7 @@ func newTestPack(t *testing.T, ctx context.Context, clock clockwork.FakeClock, a return &testPack{ vnetIPv6Prefix: vnetIPv6Prefix, dnsIPv6: dnsIPv6, - manager: manager, + ns: ns, testStack: testStack, testLinkEndpoint: testLinkEndpoint, localAddress: localAddress, @@ -168,7 +167,7 @@ func (p *testPack) dialIPPort(ctx context.Context, addr tcpip.Address, port uint NIC: nicID, Addr: addr, Port: port, - LinkAddr: p.manager.linkEndpoint.LinkAddress(), + LinkAddr: p.ns.linkEndpoint.LinkAddress(), }, ipv6.ProtocolNumber, ) @@ -187,7 +186,7 @@ func (p *testPack) dialUDP(ctx context.Context, addr tcpip.Address, port uint16) NIC: nicID, Addr: addr, Port: port, - LinkAddr: p.manager.linkEndpoint.LinkAddress(), + LinkAddr: p.ns.linkEndpoint.LinkAddress(), }, ipv6.ProtocolNumber, ) diff --git a/tool/tsh/common/vnet_common.go b/tool/tsh/common/vnet_common.go index aed9cad9c8bff..f753d3124b580 100644 --- a/tool/tsh/common/vnet_common.go +++ b/tool/tsh/common/vnet_common.go @@ -105,6 +105,7 @@ func (p *vnetAppProvider) GetDialOptions(ctx context.Context, profileName string dialOpts := &vnet.DialOptions{ WebProxyAddr: profile.WebProxyAddr, ALPNConnUpgradeRequired: profile.TLSRoutingConnUpgradeRequired, + InsecureSkipVerify: p.cf.InsecureSkipVerify, } if dialOpts.ALPNConnUpgradeRequired { dialOpts.RootClusterCACertPool, err = p.getRootClusterCACertPool(ctx, profileName) diff --git a/tool/tsh/common/vnet_darwin.go b/tool/tsh/common/vnet_darwin.go index 1313ffcacbed8..c31fccb363ab7 100644 --- a/tool/tsh/common/vnet_darwin.go +++ b/tool/tsh/common/vnet_darwin.go @@ -44,7 +44,18 @@ func (c *vnetCommand) run(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - return trace.Wrap(vnet.Run(cf.Context, appProvider)) + + processManager, err := vnet.SetupAndRun(cf.Context, appProvider) + if err != nil { + return trace.Wrap(err) + } + + go func() { + <-cf.Context.Done() + processManager.Close() + }() + + return trace.Wrap(processManager.Wait()) } type vnetAdminSetupCommand struct {
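The TUN handoff between the admin subcommand and the unprivileged process (sendTUNNameAndFd and recvTUNNameAndFd in setup_darwin.go) rides on SCM_RIGHTS file-descriptor passing over the unix socket created in SetupAndRun. A trimmed-down, self-contained sketch of that mechanism on a unix platform, using a temp file in place of a TUN device (the demo names are made up; only the net and x/sys/unix calls mirror the real code):

```go
package main

import (
	"fmt"
	"net"
	"os"
	"path/filepath"

	"golang.org/x/sys/unix"
)

func main() {
	socketPath := filepath.Join(os.TempDir(), "fd-passing-demo.sock")
	defer os.Remove(socketPath)

	listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: socketPath, Net: "unix"})
	if err != nil {
		panic(err)
	}
	defer listener.Close()

	// "Admin" side: dial the socket and send a file name in-band plus the
	// open descriptor out-of-band in an SCM_RIGHTS control message.
	go func() {
		conn, err := net.DialUnix("unix", nil, &net.UnixAddr{Name: socketPath, Net: "unix"})
		if err != nil {
			panic(err)
		}
		defer conn.Close()

		f, err := os.CreateTemp("", "pretend-tun")
		if err != nil {
			panic(err)
		}
		defer f.Close()
		fmt.Fprintln(f, "hello from the privileged side")

		rights := unix.UnixRights(int(f.Fd()))
		if _, _, err := conn.WriteMsgUnix([]byte(f.Name()), rights, nil); err != nil {
			panic(err)
		}
	}()

	// Unprivileged side: accept the connection and rebuild an *os.File from
	// the received descriptor.
	conn, err := listener.AcceptUnix()
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	buf := make([]byte, 256)
	oob := make([]byte, unix.CmsgSpace(4)) // room for one 4-byte fd
	n, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
	if err != nil {
		panic(err)
	}
	msgs, err := unix.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		panic(err)
	}
	fds, err := unix.ParseUnixRights(&msgs[0])
	if err != nil {
		panic(err)
	}

	received := os.NewFile(uintptr(fds[0]), string(buf[:n]))
	defer received.Close()
	received.Seek(0, 0) // the passed fd shares the sender's file offset
	contents := make([]byte, 64)
	cn, _ := received.Read(contents)
	fmt.Printf("got fd for %q containing %q\n", received.Name(), contents[:cn])
}
```

On the receiving side this is essentially what recvTUNNameAndFd does before receiveTUNDevice wraps the reconstructed *os.File with tun.CreateTUNFromFile.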