diff --git a/lib/vnet/app_resolver.go b/lib/vnet/app_resolver.go index ab1d44545005b..fbd91fe030bfa 100644 --- a/lib/vnet/app_resolver.go +++ b/lib/vnet/app_resolver.go @@ -26,6 +26,7 @@ import ( "log/slog" "net" "strings" + "sync" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -265,7 +266,7 @@ func (r *TCPAppResolver) resolveTCPHandlerForCluster( return nil, ErrNoTCPHandler } app := resp.Resources[0].GetApp() - appHandler, err := r.newTCPAppHandler(ctx, profileName, leafClusterName, app) + appHandler, err := r.newTCPAppHandler(ctx, log, profileName, leafClusterName, app) if err != nil { return nil, trace.Wrap(err) } @@ -282,49 +283,88 @@ func (r *TCPAppResolver) resolveTCPHandlerForCluster( } type tcpAppHandler struct { - profileName string - leafClusterName string - lp *alpnproxy.LocalProxy + log *slog.Logger + appProvider AppProvider + clock clockwork.Clock + profileName string + leafClusterName string + app types.Application + portToLocalProxy map[uint16]*alpnproxy.LocalProxy + // mu guards access to portToLocalProxy. + mu sync.Mutex } func (r *TCPAppResolver) newTCPAppHandler( ctx context.Context, + log *slog.Logger, profileName string, leafClusterName string, app types.Application, ) (*tcpAppHandler, error) { - dialOpts, err := r.appProvider.GetDialOptions(ctx, profileName) + return &tcpAppHandler{ + appProvider: r.appProvider, + clock: r.clock, + profileName: profileName, + leafClusterName: leafClusterName, + app: app, + portToLocalProxy: make(map[uint16]*alpnproxy.LocalProxy), + log: log.With(teleport.ComponentKey, "VNet.AppHandler"), + }, nil +} + +// getOrInitializeLocalProxy returns a separate local proxy for each port for multi-port apps. For +// single-port apps, it returns the same local proxy no matter the port. +func (h *tcpAppHandler) getOrInitializeLocalProxy(ctx context.Context, localPort uint16) (*alpnproxy.LocalProxy, error) { + h.mu.Lock() + defer h.mu.Unlock() + + // Connections to single-port apps need to go through a local proxy that has a cert with TargetPort + // set to 0. This ensures that the old behavior is kept for such apps, where the client can dial + // the public address of an app on any port and be routed to the port from the URI. + // + // https://github.com/gravitational/teleport/blob/master/rfd/0182-multi-port-tcp-app-access.md#vnet-with-single-port-apps + if len(h.app.GetTCPPorts()) == 0 { + localPort = 0 + } + + lp, ok := h.portToLocalProxy[localPort] + if ok { + return lp, nil + } + + dialOpts, err := h.appProvider.GetDialOptions(ctx, h.profileName) if err != nil { - return nil, trace.Wrap(err, "getting dial options for profile %q", profileName) + return nil, trace.Wrap(err, "getting dial options for profile %q", h.profileName) } - clusterClient, err := r.appProvider.GetCachedClient(ctx, profileName, leafClusterName) + clusterClient, err := h.appProvider.GetCachedClient(ctx, h.profileName, h.leafClusterName) if err != nil { return nil, trace.Wrap(err) } routeToApp := proto.RouteToApp{ - Name: app.GetName(), - PublicAddr: app.GetPublicAddr(), + Name: h.app.GetName(), + PublicAddr: h.app.GetPublicAddr(), // ClusterName must not be set to "" when targeting an app from a root cluster. Otherwise the // connection routed through a local proxy will just get lost somewhere in the cluster (with no // clear error being reported) and hang forever. ClusterName: clusterClient.ClusterName(), - URI: app.GetURI(), + URI: h.app.GetURI(), + TargetPort: uint32(localPort), } appCertIssuer := &appCertIssuer{ - appProvider: r.appProvider, - profileName: profileName, - leafClusterName: leafClusterName, + appProvider: h.appProvider, + profileName: h.profileName, + leafClusterName: h.leafClusterName, routeToApp: routeToApp, } - certChecker := client.NewCertChecker(appCertIssuer, r.clock) + certChecker := client.NewCertChecker(appCertIssuer, h.clock) middleware := &localProxyMiddleware{ certChecker: certChecker, - appProvider: r.appProvider, + appProvider: h.appProvider, routeToApp: routeToApp, - profileName: profileName, - leafClusterName: leafClusterName, + profileName: h.profileName, + leafClusterName: h.leafClusterName, } localProxyConfig := alpnproxy.LocalProxyConfig{ @@ -336,25 +376,28 @@ func (r *TCPAppResolver) newTCPAppHandler( ALPNConnUpgradeRequired: dialOpts.ALPNConnUpgradeRequired, Middleware: middleware, InsecureSkipVerify: dialOpts.InsecureSkipVerify, - Clock: r.clock, + Clock: h.clock, } - lp, err := alpnproxy.NewLocalProxy(localProxyConfig) + h.log.DebugContext(ctx, "Creating local proxy", "target_port", localPort) + newLP, err := alpnproxy.NewLocalProxy(localProxyConfig) if err != nil { return nil, trace.Wrap(err, "creating local proxy") } - return &tcpAppHandler{ - profileName: profileName, - leafClusterName: leafClusterName, - lp: lp, - }, nil + h.portToLocalProxy[localPort] = newLP + + return newLP, nil } // HandleTCPConnector handles an incoming TCP connection from VNet by passing it to the local alpn proxy, // which is set up with middleware to automatically handler certificate renewal and re-logins. -func (h *tcpAppHandler) HandleTCPConnector(ctx context.Context, connector func() (net.Conn, error)) error { - return trace.Wrap(h.lp.HandleTCPConnector(ctx, connector), "handling TCP connector") +func (h *tcpAppHandler) HandleTCPConnector(ctx context.Context, localPort uint16, connector func() (net.Conn, error)) error { + lp, err := h.getOrInitializeLocalProxy(ctx, localPort) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(lp.HandleTCPConnector(ctx, connector), "handling TCP connector") } // appCertIssuer implements [client.CertIssuer]. diff --git a/lib/vnet/vnet.go b/lib/vnet/vnet.go index 66f07a92ff6cf..fb5b6710ac220 100644 --- a/lib/vnet/vnet.go +++ b/lib/vnet/vnet.go @@ -117,7 +117,7 @@ type TCPHandlerSpec struct { // [connector] to complete the TCP handshake and get the TCP conn. This is so that clients will see that the // TCP connection was refused, instead of seeing a successful TCP dial that is immediately closed. type TCPHandler interface { - HandleTCPConnector(ctx context.Context, connector func() (net.Conn, error)) error + HandleTCPConnector(ctx context.Context, localPort uint16, connector func() (net.Conn, error)) error } // UDPHandler defines the behavior for handling UDP connections from VNet. @@ -423,7 +423,7 @@ func (ns *NetworkStack) handleTCP(req *tcp.ForwarderRequest) { return conn, nil } - if err := handler.HandleTCPConnector(ctx, connector); err != nil { + if err := handler.HandleTCPConnector(ctx, id.LocalPort, connector); err != nil { if errors.Is(err, context.Canceled) { slog.DebugContext(ctx, "TCP connection handler returned early due to canceled context.") } else { diff --git a/lib/vnet/vnet_test.go b/lib/vnet/vnet_test.go index 40eaaf10cfe62..59b75450bfb6e 100644 --- a/lib/vnet/vnet_test.go +++ b/lib/vnet/vnet_test.go @@ -52,6 +52,7 @@ import ( headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" "github.com/gravitational/teleport/api/gen/proto/go/teleport/vnet/v1" "github.com/gravitational/teleport/api/types" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/utils" ) @@ -210,7 +211,12 @@ func (p *testPack) lookupHost(ctx context.Context, host string) ([]string, error return resolver.LookupHost(ctx, host) } -func (p *testPack) dialHost(ctx context.Context, host string) (net.Conn, error) { +// dialHost dials port 123 if port is zero. port matters only for multi-port tests. +func (p *testPack) dialHost(ctx context.Context, host string, port int) (net.Conn, error) { + const defaultPort = 123 + if port == 0 { + port = defaultPort + } addrs, err := p.lookupHost(ctx, host) if err != nil { return nil, trace.Wrap(err) @@ -219,7 +225,7 @@ func (p *testPack) dialHost(ctx context.Context, host string) (net.Conn, error) for _, addr := range addrs { netIP := net.ParseIP(addr) ip := tcpip.AddrFromSlice(netIP) - conn, err := p.dialIPPort(ctx, ip, 123) + conn, err := p.dialIPPort(ctx, ip, uint16(port)) if err != nil { allErrs = append(allErrs, trace.Wrap(err, "dialing %s", addr)) continue @@ -238,6 +244,7 @@ func (n noUpstreamNameservers) UpstreamNameservers(ctx context.Context) ([]strin type appSpec struct { // publicAddr is used bothas the name of the app and its public address in the final spec. publicAddr string + tcpPorts []*types.PortRange } type testClusterSpec struct { @@ -252,14 +259,18 @@ type echoAppProvider struct { dialOpts DialOptions reissueAppCert func() tls.Certificate onNewConnectionCallCount atomic.Uint32 + // requestedRouteToApps indexed by public address. + requestedRouteToApps map[string][]proto.RouteToApp + requestedRouteToAppsMu sync.RWMutex } // newEchoAppProvider returns an app provider with the list of named apps in each profile and leaf cluster. func newEchoAppProvider(clusterSpecs map[string]testClusterSpec, dialOpts DialOptions, reissueAppCert func() tls.Certificate) *echoAppProvider { return &echoAppProvider{ - clusters: clusterSpecs, - dialOpts: dialOpts, - reissueAppCert: reissueAppCert, + clusters: clusterSpecs, + dialOpts: dialOpts, + reissueAppCert: reissueAppCert, + requestedRouteToApps: make(map[string][]proto.RouteToApp), } } @@ -297,9 +308,25 @@ func (p *echoAppProvider) GetCachedClient(ctx context.Context, profileName, leaf } func (p *echoAppProvider) ReissueAppCert(ctx context.Context, profileName, leafClusterName string, routeToApp proto.RouteToApp) (tls.Certificate, error) { + p.requestedRouteToAppsMu.Lock() + defer p.requestedRouteToAppsMu.Unlock() + + p.requestedRouteToApps[routeToApp.PublicAddr] = append(p.requestedRouteToApps[routeToApp.PublicAddr], routeToApp) + return p.reissueAppCert(), nil } +func (p *echoAppProvider) AreAllRequestedRouteToAppsForPort(publicAddr string, port int) bool { + p.requestedRouteToAppsMu.RLock() + defer p.requestedRouteToAppsMu.RUnlock() + + routes := p.requestedRouteToApps[publicAddr] + + return apiutils.All(routes, func(route proto.RouteToApp) bool { + return route.TargetPort == uint32(port) + }) +} + func (p *echoAppProvider) GetDialOptions(ctx context.Context, profileName string) (*DialOptions, error) { return &p.dialOpts, nil } @@ -387,6 +414,19 @@ func (c *fakeAuthClient) GetResources(ctx context.Context, req *proto.ListResour if !strings.Contains(req.PredicateExpression, app.publicAddr) { continue } + spec := &types.AppV3{ + Metadata: types.Metadata{ + Name: app.publicAddr, + }, + Spec: types.AppSpecV3{ + PublicAddr: app.publicAddr, + }, + } + + if len(app.tcpPorts) != 0 { + spec.SetTCPPorts(app.tcpPorts) + } + resp.Resources = append(resp.Resources, &proto.PaginatedResource{ Resource: &proto.PaginatedResource_AppServer{ AppServer: &types.AppServerV3{ @@ -395,14 +435,7 @@ func (c *fakeAuthClient) GetResources(ctx context.Context, req *proto.ListResour Name: app.publicAddr, }, Spec: types.AppServerSpecV3{ - App: &types.AppV3{ - Metadata: types.Metadata{ - Name: app.publicAddr, - }, - Spec: types.AppSpecV3{ - PublicAddr: app.publicAddr, - }, - }, + App: spec, }, }, }, @@ -477,6 +510,17 @@ func TestDialFakeApp(t *testing.T) { appSpec{ publicAddr: "not.in.a.custom.zone", }, + appSpec{ + publicAddr: "multi-port.root1.example.com", + tcpPorts: []*types.PortRange{ + &types.PortRange{ + Port: 1337, + }, + &types.PortRange{ + Port: 4242, + }, + }, + }, }, customDNSZones: []string{ "myzone.example.com", @@ -488,6 +532,17 @@ func TestDialFakeApp(t *testing.T) { appSpec{ publicAddr: "echo1.leaf1.example.com", }, + appSpec{ + publicAddr: "multi-port.leaf1.example.com", + tcpPorts: []*types.PortRange{ + &types.PortRange{ + Port: 1337, + }, + &types.PortRange{ + Port: 4242, + }, + }, + }, }, }, "leaf2.example.com": { @@ -527,6 +582,7 @@ func TestDialFakeApp(t *testing.T) { validTestCases := []struct { app string + port int expectCIDR string }{ { @@ -565,6 +621,16 @@ func TestDialFakeApp(t *testing.T) { app: "echo1.leaf3.example.com", expectCIDR: defaultIPv4CIDRRange, }, + { + app: "multi-port.root1.example.com", + port: 1337, + expectCIDR: "192.168.2.0/24", + }, + { + app: "multi-port.leaf1.example.com", + port: 1337, + expectCIDR: defaultIPv4CIDRRange, + }, } t.Run("valid", func(t *testing.T) { @@ -584,7 +650,7 @@ func TestDialFakeApp(t *testing.T) { _, expectNet, err := net.ParseCIDR(tc.expectCIDR) require.NoError(t, err) - conn, err := p.dialHost(ctx, tc.app) + conn, err := p.dialHost(ctx, tc.app, tc.port) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, conn.Close()) }) @@ -599,6 +665,14 @@ func TestDialFakeApp(t *testing.T) { require.True(t, expectNet.Contains(remoteIPSuffix), "expected CIDR range %s does not include remote IP %s", expectNet, remoteIPSuffix) testEchoConnection(t, conn) + + // For multi-port apps, certs should have RouteToApp.TargetPort set to the specified + // cert. + // + // Single-port apps are going to be dialed on defaultPort in tests, but certs for them + // need to have RouteToApp.TargetPort set to 0. + require.True(t, appProvider.AreAllRequestedRouteToAppsForPort(tc.app, tc.port), + "not all requested certs had RouteToApp.TargetPort set to %d", tc.port) }) } }) @@ -688,7 +762,7 @@ func TestOnNewConnection(t *testing.T) { require.Equal(t, uint32(0), appProvider.onNewConnectionCallCount.Load()) // Establish a connection to a valid app and verify that OnNewConnection was called. - conn, err := p.dialHost(ctx, validAppName) + conn, err := p.dialHost(ctx, validAppName, 0) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, conn.Close()) }) require.Equal(t, uint32(1), appProvider.onNewConnectionCallCount.Load())