diff --git a/api/client/dynamicwindows/dynamicwindows.go b/api/client/dynamicwindows/dynamicwindows.go index 6c158a39e4243..32ba1762f1aed 100644 --- a/api/client/dynamicwindows/dynamicwindows.go +++ b/api/client/dynamicwindows/dynamicwindows.go @@ -46,7 +46,7 @@ func (c *Client) GetDynamicWindowsDesktop(ctx context.Context, name string) (typ return desktop, trace.Wrap(err) } -func (c *Client) ListDynamicWindowsDesktop(ctx context.Context, pageSize int, pageToken string) ([]types.DynamicWindowsDesktop, string, error) { +func (c *Client) ListDynamicWindowsDesktops(ctx context.Context, pageSize int, pageToken string) ([]types.DynamicWindowsDesktop, string, error) { resp, err := c.grpcClient.ListDynamicWindowsDesktops(ctx, &dynamicwindows.ListDynamicWindowsDesktopsRequest{ PageSize: int32(pageSize), PageToken: pageToken, diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go index 005a44cdda8a1..4d3bdabd846b7 100644 --- a/lib/auth/authclient/clt.go +++ b/lib/auth/authclient/clt.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/crownjewel" "github.com/gravitational/teleport/api/client/databaseobject" + "github.com/gravitational/teleport/api/client/dynamicwindows" "github.com/gravitational/teleport/api/client/externalauditstorage" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/secreport" @@ -1600,6 +1601,8 @@ type ClientI interface { types.WebSessionsGetter types.WebTokensGetter + DynamicDesktopClient() *dynamicwindows.Client + // TrustClient returns a client to the Trust service. TrustClient() trustpb.TrustServiceClient diff --git a/lib/srv/desktop/discovery.go b/lib/srv/desktop/discovery.go index cc7dd71daa6d2..c44bcc0a9cabb 100644 --- a/lib/srv/desktop/discovery.go +++ b/lib/srv/desktop/discovery.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "log/slog" + "maps" "net" "net/netip" "strings" @@ -32,6 +33,7 @@ import ( "github.com/go-ldap/ldap/v3" "github.com/gravitational/trace" + "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/windows" @@ -306,3 +308,91 @@ func (s *WindowsService) ldapEntryToWindowsDesktop(ctx context.Context, entry *l desktop.SetExpiry(s.cfg.Clock.Now().UTC().Add(apidefaults.ServerAnnounceTTL * 3)) return desktop, nil } + +// startDynamicReconciler starts resource watcher and reconciler that registers/unregisters Windows desktops +// according to the up-to-date list of dynamic Windows desktops resources. +func (s *WindowsService) startDynamicReconciler(ctx context.Context) (*services.DynamicWindowsDesktopWatcher, error) { + if len(s.cfg.ResourceMatchers) == 0 { + s.cfg.Logger.DebugContext(ctx, "Not starting dynamic desktop resource watcher.") + return nil, nil + } + s.cfg.Logger.DebugContext(ctx, "Starting dynamic desktop resource watcher.") + dynamicDesktopClient := s.cfg.AuthClient.DynamicDesktopClient() + watcher, err := services.NewDynamicWindowsDesktopWatcher(ctx, services.DynamicWindowsDesktopWatcherConfig{ + DynamicWindowsDesktopGetter: dynamicDesktopClient, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentWindowsDesktop, + Client: s.cfg.AccessPoint, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + currentResources := make(map[string]types.WindowsDesktop) + var newResources map[string]types.WindowsDesktop + + reconciler, err := services.NewReconciler(services.ReconcilerConfig[types.WindowsDesktop]{ + Matcher: func(desktop types.WindowsDesktop) bool { + return services.MatchResourceLabels(s.cfg.ResourceMatchers, desktop.GetAllLabels()) + }, + GetCurrentResources: func() map[string]types.WindowsDesktop { + return currentResources + }, + GetNewResources: func() map[string]types.WindowsDesktop { + return newResources + }, + OnCreate: s.upsertDesktop, + OnUpdate: s.updateDesktop, + OnDelete: s.deleteDesktop, + }) + if err != nil { + return nil, trace.Wrap(err) + } + go func() { + defer s.cfg.Logger.DebugContext(ctx, "DynamicWindowsDesktop resource watcher done.") + defer watcher.Close() + for { + select { + case desktops := <-watcher.DynamicWindowsDesktopsC: + newResources = make(map[string]types.WindowsDesktop) + for _, dynamicDesktop := range desktops { + desktop, err := s.toWindowsDesktop(dynamicDesktop) + if err != nil { + s.cfg.Logger.WarnContext(ctx, "Can't create desktop resource", "error", err) + continue + } + newResources[dynamicDesktop.GetName()] = desktop + } + if err := reconciler.Reconcile(ctx); err != nil { + s.cfg.Logger.WarnContext(ctx, "Reconciliation failed, will retry", "error", err) + continue + } + currentResources = newResources + case <-watcher.Done(): + return + case <-ctx.Done(): + return + } + } + }() + return watcher, nil +} + +func (s *WindowsService) toWindowsDesktop(dynamicDesktop types.DynamicWindowsDesktop) (*types.WindowsDesktopV3, error) { + width, height := dynamicDesktop.GetScreenSize() + desktopLabels := dynamicDesktop.GetAllLabels() + labels := make(map[string]string, len(desktopLabels)+1) + maps.Copy(labels, desktopLabels) + labels[types.OriginLabel] = types.OriginDynamic + return types.NewWindowsDesktopV3(dynamicDesktop.GetName(), labels, types.WindowsDesktopSpecV3{ + Addr: dynamicDesktop.GetAddr(), + Domain: dynamicDesktop.GetDomain(), + HostID: s.cfg.Heartbeat.HostUUID, + NonAD: dynamicDesktop.NonAD(), + ScreenSize: &types.Resolution{ + Width: width, + Height: height, + }, + }) +} diff --git a/lib/srv/desktop/discovery_test.go b/lib/srv/desktop/discovery_test.go index 2a35918d264b4..fc188f75ce1d6 100644 --- a/lib/srv/desktop/discovery_test.go +++ b/lib/srv/desktop/discovery_test.go @@ -33,7 +33,9 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/windows" + "github.com/gravitational/teleport/lib/services" logutils "github.com/gravitational/teleport/lib/utils/log" ) @@ -169,3 +171,134 @@ func TestDNSErrors(t *testing.T) { require.Less(t, time.Since(start), dnsQueryTimeout-1*time.Second) require.Error(t, err) } + +func TestDynamicWindowsDiscovery(t *testing.T) { + t.Parallel() + authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ + ClusterName: "test", + Dir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, authServer.Close()) + }) + + tlsServer, err := authServer.NewTestTLSServer() + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, tlsServer.Close()) + }) + + client, err := tlsServer.NewClient(auth.TestServerID(types.RoleWindowsDesktop, "test-host-id")) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, client.Close()) + }) + + dynamicWindowsClient := client.DynamicDesktopClient() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for _, testCase := range []struct { + name string + labels map[string]string + expected int + }{ + { + name: "no labels", + expected: 0, + }, + { + name: "no matching labels", + labels: map[string]string{"xyz": "abc"}, + expected: 0, + }, + { + name: "matching labels", + labels: map[string]string{"foo": "bar"}, + expected: 1, + }, + { + name: "matching wildcard labels", + labels: map[string]string{"abc": "abc"}, + expected: 1, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + s := &WindowsService{ + cfg: WindowsServiceConfig{ + Heartbeat: HeartbeatConfig{ + HostUUID: "1234", + }, + Logger: slog.New(logutils.NewSlogTextHandler(io.Discard, logutils.SlogTextHandlerConfig{})), + Clock: clockwork.NewFakeClock(), + AuthClient: client, + AccessPoint: client, + ResourceMatchers: []services.ResourceMatcher{{ + Labels: types.Labels{ + "foo": {"bar"}, + }, + }, { + Labels: types.Labels{ + "abc": {"*"}, + }, + }}, + }, + dnsResolver: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("this resolver always fails") + }, + }, + } + reconciler, err := s.startDynamicReconciler(ctx) + require.NoError(t, err) + t.Cleanup(func() { + reconciler.Close() + require.NoError(t, authServer.AuthServer.DeleteAllWindowsDesktops(ctx)) + require.NoError(t, authServer.AuthServer.DeleteAllDynamicWindowsDesktops(ctx)) + }) + + desktop, err := types.NewDynamicWindowsDesktopV1("test", testCase.labels, types.DynamicWindowsDesktopSpecV1{ + Addr: "addr", + }) + require.NoError(t, err) + + _, err = dynamicWindowsClient.CreateDynamicWindowsDesktop(ctx, desktop) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + desktops, err := client.GetWindowsDesktops(ctx, types.WindowsDesktopFilter{}) + require.NoError(t, err) + require.Len(t, desktops, testCase.expected) + if testCase.expected > 0 { + require.Equal(t, desktop.GetName(), desktops[0].GetName()) + require.Equal(t, desktop.GetAddr(), desktops[0].GetAddr()) + } + + desktop.Spec.Addr = "addr2" + _, err = dynamicWindowsClient.UpdateDynamicWindowsDesktop(ctx, desktop) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + desktops, err = client.GetWindowsDesktops(ctx, types.WindowsDesktopFilter{}) + require.NoError(t, err) + require.Len(t, desktops, testCase.expected) + if testCase.expected > 0 { + require.Equal(t, desktop.GetName(), desktops[0].GetName()) + require.Equal(t, desktop.GetAddr(), desktops[0].GetAddr()) + } + + require.NoError(t, dynamicWindowsClient.DeleteDynamicWindowsDesktop(ctx, "test")) + + time.Sleep(10 * time.Millisecond) + + desktops, err = client.GetWindowsDesktops(ctx, types.WindowsDesktopFilter{}) + require.NoError(t, err) + require.Empty(t, desktops) + }) + + } +} diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go index 28aaee2a48483..791f477861666 100644 --- a/lib/srv/desktop/windows_server.go +++ b/lib/srv/desktop/windows_server.go @@ -411,6 +411,10 @@ func NewWindowsService(cfg WindowsServiceConfig) (*WindowsService, error) { return nil, trace.Wrap(err) } + if _, err := s.startDynamicReconciler(ctx); err != nil { + return nil, trace.Wrap(err) + } + if len(s.cfg.DiscoveryBaseDN) > 0 { if err := s.startDesktopDiscovery(); err != nil { return nil, trace.Wrap(err)