From 995a28baea5b9bfa001f2ba26ebc321692eefdbc Mon Sep 17 00:00:00 2001
From: Matt Brock
Date: Tue, 17 Dec 2024 17:50:46 -0600
Subject: [PATCH] Invoking the Azure fetcher in the Discovery service

---
 .../{access_graph.go => access_graph_aws.go} |  12 +-
 lib/srv/discovery/access_graph_azure.go      | 410 ++++++++++++++++++
 lib/srv/discovery/discovery.go               |  46 +-
 3 files changed, 451 insertions(+), 17 deletions(-)
 rename lib/srv/discovery/{access_graph.go => access_graph_aws.go} (98%)
 create mode 100644 lib/srv/discovery/access_graph_azure.go

diff --git a/lib/srv/discovery/access_graph.go b/lib/srv/discovery/access_graph_aws.go
similarity index 98%
rename from lib/srv/discovery/access_graph.go
rename to lib/srv/discovery/access_graph_aws.go
index 4bc207b21df01..b05e85d72bc32 100644
--- a/lib/srv/discovery/access_graph.go
+++ b/lib/srv/discovery/access_graph_aws.go
@@ -145,15 +145,15 @@ func (s *Server) reconcileAccessGraph(ctx context.Context, currentTAGResources *
 
 // getAllAWSSyncFetchers returns all AWS sync fetchers.
 func (s *Server) getAllAWSSyncFetchers() []aws_sync.AWSSync {
-	allFetchers := make([]aws_sync.AWSSync, 0, len(s.dynamicTAGSyncFetchers))
+	allFetchers := make([]aws_sync.AWSSync, 0, len(s.dynamicTAGAWSFetchers))
 
-	s.muDynamicTAGSyncFetchers.RLock()
-	for _, fetcherSet := range s.dynamicTAGSyncFetchers {
+	s.muDynamicTAGAWSFetchers.RLock()
+	for _, fetcherSet := range s.dynamicTAGAWSFetchers {
 		allFetchers = append(allFetchers, fetcherSet...)
 	}
-	s.muDynamicTAGSyncFetchers.RUnlock()
+	s.muDynamicTAGAWSFetchers.RUnlock()
 
-	allFetchers = append(allFetchers, s.staticTAGSyncFetchers...)
+	allFetchers = append(allFetchers, s.staticTAGAWSFetchers...)
 	// TODO(tigrato): submit fetchers event
 	return allFetchers
 }
@@ -443,7 +443,7 @@ func (s *Server) initAccessGraphWatchers(ctx context.Context, cfg *Config) error
 	if err != nil {
 		s.Log.ErrorContext(ctx, "Error initializing access graph fetchers", "error", err)
 	}
-	s.staticTAGSyncFetchers = fetchers
+	s.staticTAGAWSFetchers = fetchers
 
 	if cfg.AccessGraphConfig.Enabled {
 		go func() {
diff --git a/lib/srv/discovery/access_graph_azure.go b/lib/srv/discovery/access_graph_azure.go
new file mode 100644
index 0000000000000..a2898538cc699
--- /dev/null
+++ b/lib/srv/discovery/access_graph_azure.go
@@ -0,0 +1,410 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package discovery
+
+import (
+	"context"
+	"errors"
+	"io"
+	"sync"
+	"time"
+
+	"github.com/gravitational/trace"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/connectivity"
+
+	"github.com/gravitational/teleport/api/types"
+	"github.com/gravitational/teleport/api/utils/retryutils"
+	"github.com/gravitational/teleport/entitlements"
+	accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
+	"github.com/gravitational/teleport/lib/modules"
+	"github.com/gravitational/teleport/lib/services"
+	azure_sync "github.com/gravitational/teleport/lib/srv/discovery/fetchers/azure-sync"
+)
+
+// reconcileAccessGraphAzure fetches Azure resources, diffs them against the
+// resources from the previous fetch, and sends the resulting upserts and
+// deletions to the Access Graph stream.
+func (s *Server) reconcileAccessGraphAzure(
+	ctx context.Context,
+	currentTAGResources *azure_sync.Resources,
+	stream accessgraphv1alpha.AccessGraphService_AzureEventsStreamClient,
+	features azure_sync.Features,
+) error {
+	type fetcherResult struct {
+		result *azure_sync.Resources
+		err    error
+	}
+
+	// Get all the fetchers
+	allFetchers := s.getAllTAGSyncAzureFetchers()
+	if len(allFetchers) == 0 {
+		// If there are no fetchers, we don't need to continue.
+		// We will send a delete request for all resources and return.
+		upsert, toDel := azure_sync.ReconcileResults(currentTAGResources, &azure_sync.Resources{})
+
+		if err := azurePush(stream, upsert, toDel); err != nil {
+			s.Log.ErrorContext(ctx, "Error pushing empty resources to TAGs", "error", err)
+		}
+		return trace.Wrap(errNoAccessGraphFetchers)
+	}
+
+	// Fetch results concurrently, with at most three fetchers in flight.
+	resultsC := make(chan fetcherResult, len(allFetchers))
+	tokens := make(chan struct{}, 3)
+	accountIds := map[string]struct{}{}
+	for _, fetcher := range allFetchers {
+		fetcher := fetcher
+		accountIds[fetcher.GetSubscriptionID()] = struct{}{}
+		tokens <- struct{}{}
+		go func() {
+			defer func() {
+				<-tokens
+			}()
+			result, err := fetcher.Poll(ctx, features)
+			resultsC <- fetcherResult{result, trace.Wrap(err)}
+		}()
+	}
+
+	// Collect the results from all fetchers.
+	results := make([]*azure_sync.Resources, 0, len(allFetchers))
+	errs := make([]error, 0, len(allFetchers))
+	for i := 0; i < len(allFetchers); i++ {
+		// Each fetcher can return an error and a result.
+		fetcherResult := <-resultsC
+		if fetcherResult.err != nil {
+			errs = append(errs, fetcherResult.err)
+		}
+		if fetcherResult.result != nil {
+			results = append(results, fetcherResult.result)
+		}
+	}
+
+	// Aggregate all errors into a single error.
+	err := trace.NewAggregate(errs...)
+	if err != nil {
+		s.Log.ErrorContext(ctx, "Error polling TAGs", "error", err)
+	}
+
+	// Merge all results into a single result.
+	result := azure_sync.MergeResources(results...)
+
+	// Reconcile the merged result against the previous fetch and push the diff.
+	upsert, toDel := azure_sync.ReconcileResults(currentTAGResources, result)
+	pushErr := azurePush(stream, upsert, toDel)
+	if pushErr != nil {
+		s.Log.ErrorContext(ctx, "Error pushing TAGs", "error", pushErr)
+		return nil
+	}
+
+	// Update the currentTAGResources with the result of the reconciliation.
+	*currentTAGResources = *result
+	return nil
+}
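The polling loop above fans out one goroutine per fetcher but caps concurrency with a buffered channel used as a counting semaphore. A self-contained sketch of the same pattern, not part of the patch (the squaring stands in for fetcher.Poll; the capacity of 3 mirrors the hard-coded limit above):

package main

import "fmt"

func main() {
	jobs := []int{1, 2, 3, 4, 5}
	tokens := make(chan struct{}, 3)     // counting semaphore: at most 3 in flight
	results := make(chan int, len(jobs)) // buffered so workers never block on send

	for _, j := range jobs {
		j := j
		tokens <- struct{}{} // blocks while 3 workers are already running
		go func() {
			defer func() { <-tokens }() // release the slot
			results <- j * j            // stand-in for fetcher.Poll
		}()
	}
	for range jobs {
		fmt.Println(<-results)
	}
}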
+
+// azurePushUpsertInBatches upserts resources to the Access Graph in batches
+func azurePushUpsertInBatches(
+	client accessgraphv1alpha.AccessGraphService_AzureEventsStreamClient,
+	upsert *accessgraphv1alpha.AzureResourceList,
+) error {
+	for i := 0; i < len(upsert.Resources); i += batchSize {
+		end := i + batchSize
+		if end > len(upsert.Resources) {
+			end = len(upsert.Resources)
+		}
+		err := client.Send(
+			&accessgraphv1alpha.AzureEventsStreamRequest{
+				Operation: &accessgraphv1alpha.AzureEventsStreamRequest_Upsert{
+					Upsert: &accessgraphv1alpha.AzureResourceList{
+						Resources: upsert.Resources[i:end],
+					},
+				},
+			},
+		)
+		if err != nil {
+			return trace.Wrap(err)
+		}
+	}
+	return nil
+}
+
+// azurePushDeleteInBatches deletes resources from the Access Graph in batches
+func azurePushDeleteInBatches(
+	client accessgraphv1alpha.AccessGraphService_AzureEventsStreamClient,
+	toDel *accessgraphv1alpha.AzureResourceList,
+) error {
+	for i := 0; i < len(toDel.Resources); i += batchSize {
+		end := i + batchSize
+		if end > len(toDel.Resources) {
+			end = len(toDel.Resources)
+		}
+		err := client.Send(
+			&accessgraphv1alpha.AzureEventsStreamRequest{
+				Operation: &accessgraphv1alpha.AzureEventsStreamRequest_Delete{
+					Delete: &accessgraphv1alpha.AzureResourceList{
+						Resources: toDel.Resources[i:end],
+					},
+				},
+			},
+		)
+		if err != nil {
+			return trace.Wrap(err)
+		}
+	}
+	return nil
+}
+
+// azurePush upserts and deletes Azure resources to/from the Access Graph,
+// then signals the end of the sync.
+func azurePush(
+	client accessgraphv1alpha.AccessGraphService_AzureEventsStreamClient,
+	upsert *accessgraphv1alpha.AzureResourceList,
+	toDel *accessgraphv1alpha.AzureResourceList,
+) error {
+	err := azurePushUpsertInBatches(client, upsert)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	err = azurePushDeleteInBatches(client, toDel)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	err = client.Send(
+		&accessgraphv1alpha.AzureEventsStreamRequest{
+			Operation: &accessgraphv1alpha.AzureEventsStreamRequest_Sync{},
+		},
+	)
+	return trace.Wrap(err)
+}
+
+// getAllTAGSyncAzureFetchers returns both static and dynamic TAG Azure fetchers
+func (s *Server) getAllTAGSyncAzureFetchers() []*azure_sync.Fetcher {
+	allFetchers := make([]*azure_sync.Fetcher, 0, len(s.dynamicTAGAzureFetchers))
+
+	s.muDynamicTAGAzureFetchers.RLock()
+	for _, fetcherSet := range s.dynamicTAGAzureFetchers {
+		allFetchers = append(allFetchers, fetcherSet...)
+	}
+	s.muDynamicTAGAzureFetchers.RUnlock()
+
+	allFetchers = append(allFetchers, s.staticTAGAzureFetchers...)
+	return allFetchers
+}
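Both batch helpers repeat the same clamped-slice arithmetic (batchSize itself is defined elsewhere in this package, alongside the AWS sync code). A hypothetical generic helper, not introduced by the patch, shows the shape of that loop in isolation:

package main

import "fmt"

// chunk splits items into batches of at most size elements, mirroring the
// end-index clamping that both push helpers perform inline.
func chunk[T any](items []T, size int) [][]T {
	var batches [][]T
	for i := 0; i < len(items); i += size {
		end := i + size
		if end > len(items) {
			end = len(items)
		}
		batches = append(batches, items[i:end])
	}
	return batches
}

func main() {
	fmt.Println(chunk([]int{1, 2, 3, 4, 5}, 2)) // [[1 2] [3 4] [5]]
}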
+
+// initializeAndWatchAzureAccessGraph initializes and watches the TAG Azure stream
+func (s *Server) initializeAndWatchAzureAccessGraph(ctx context.Context, reloadCh chan struct{}) error {
+	// Check if the access graph is enabled
+	clusterFeatures := s.Config.ClusterFeatures()
+	policy := modules.GetProtoEntitlement(&clusterFeatures, entitlements.Policy)
+	if !clusterFeatures.AccessGraph && !policy.Enabled {
+		return trace.Wrap(errTAGFeatureNotEnabled)
+	}
+
+	// Acquire a semaphore so that only one Discovery server syncs Azure
+	// resources to the Access Graph at a time.
+	const (
+		semaphoreExpiration = time.Minute
+		semaphoreName       = "access_graph_azure_sync"
+		serviceConfig       = `{
+			"loadBalancingPolicy": "round_robin",
+			"healthCheckConfig": {
+				"serviceName": ""
+			}
+		}`
+	)
+	lease, err := services.AcquireSemaphoreLockWithRetry(
+		ctx,
+		services.SemaphoreLockConfigWithRetry{
+			SemaphoreLockConfig: services.SemaphoreLockConfig{
+				Service: s.AccessPoint,
+				Params: types.AcquireSemaphoreRequest{
+					SemaphoreKind: types.KindAccessGraph,
+					SemaphoreName: semaphoreName,
+					MaxLeases:     1,
+					Holder:        s.Config.ServerID,
+				},
+				Expiry: semaphoreExpiration,
+				Clock:  s.clock,
+			},
+			Retry: retryutils.LinearConfig{
+				Clock:  s.clock,
+				First:  time.Second,
+				Step:   semaphoreExpiration / 2,
+				Max:    semaphoreExpiration,
+				Jitter: retryutils.DefaultJitter,
+			},
+		},
+	)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	ctx, cancel := context.WithCancel(lease)
+	defer cancel()
+	defer func() {
+		lease.Stop()
+		if err := lease.Wait(); err != nil {
+			s.Log.WarnContext(ctx, "error cleaning up semaphore", "error", err)
+		}
+	}()
+
+	// Create the access graph client
+	accessGraphConn, err := newAccessGraphClient(
+		ctx,
+		s.GetClientCert,
+		s.Config.AccessGraphConfig,
+		grpc.WithDefaultServiceConfig(serviceConfig),
+	)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	// Close the connection when the function returns.
+	defer accessGraphConn.Close()
+	client := accessgraphv1alpha.NewAccessGraphServiceClient(accessGraphConn)
+
+	// Create the event stream
+	stream, err := client.AzureEventsStream(ctx)
+	if err != nil {
+		s.Log.ErrorContext(ctx, "Failed to get TAG Azure service stream", "error", err)
+		return trace.Wrap(err)
+	}
+	header, err := stream.Header()
+	if err != nil {
+		s.Log.ErrorContext(ctx, "Failed to get TAG Azure service stream header", "error", err)
+		return trace.Wrap(err)
+	}
+
+	// The server advertises which Azure resource kinds it supports in the
+	// stream header; restrict the sync features accordingly.
+	const supportedResourcesKey = "supported-kinds"
+	supportedKinds := header.Get(supportedResourcesKey)
+	if len(supportedKinds) == 0 {
+		return trace.BadParameter("TAG Azure service did not return supported kinds")
+	}
+	features := azure_sync.BuildFeatures(supportedKinds...)
+
+	// Cancels the context to stop the event watcher if the access graph
+	// connection fails.
+	var wg sync.WaitGroup
+	defer wg.Wait()
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		defer cancel()
+		if !accessGraphConn.WaitForStateChange(ctx, connectivity.Ready) {
+			s.Log.InfoContext(ctx, "access graph service connection was closed")
+		}
+	}()
+
+	// Configure the poll interval
+	tickerInterval := defaultPollInterval
+	if s.Config.Matchers.AccessGraph != nil {
+		if s.Config.Matchers.AccessGraph.PollInterval > defaultPollInterval {
+			tickerInterval = s.Config.Matchers.AccessGraph.PollInterval
+		} else {
+			s.Log.WarnContext(ctx,
+				"Access graph Azure service poll interval cannot be less than the default",
+				"default_poll_interval",
+				defaultPollInterval)
+		}
+	}
+	s.Log.InfoContext(ctx, "Access graph Azure service poll interval", "poll_interval", tickerInterval)
+
+	// Reconcile the resources as they're imported from Azure, then wait for
+	// the next tick, a config reload, or cancellation.
+	azureResources := &azure_sync.Resources{}
+	ticker := time.NewTicker(tickerInterval)
+	defer ticker.Stop()
+	for {
+		err := s.reconcileAccessGraphAzure(ctx, azureResources, stream, features)
+		if errors.Is(err, errNoAccessGraphFetchers) {
+			err := stream.CloseSend()
+			if errors.Is(err, io.EOF) {
+				err = nil
+			}
+			return trace.Wrap(err)
+		}
+		select {
+		case <-ctx.Done():
+			return trace.Wrap(ctx.Err())
+		case <-ticker.C:
+		case <-reloadCh:
+		}
+	}
+}
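One subtlety in the loop above: reconciliation runs immediately on entry and again after every wake-up, rather than waiting for the first tick. Reduced to its control flow, outside the patch (the print stands in for reconcileAccessGraphAzure; intervals are shortened so the sketch terminates quickly):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
	defer cancel()
	reloadCh := make(chan struct{}) // would be fed by config-change events
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		fmt.Println("reconcile") // first pass happens before any tick
		select {
		case <-ctx.Done():
			return
		case <-ticker.C: // periodic resync
		case <-reloadCh: // discovery config changed
		}
	}
}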
+
+// initTAGAzureWatchers initializes the TAG Azure watchers
+func (s *Server) initTAGAzureWatchers(ctx context.Context, cfg *Config) error {
+	staticFetchers, err := s.accessGraphAzureFetchersFromMatchers(cfg.Matchers, "" /* discoveryConfigName */)
+	if err != nil {
+		s.Log.ErrorContext(ctx, "Error initializing access graph fetchers", "error", err)
+	}
+	s.staticTAGAzureFetchers = staticFetchers
+	if !cfg.AccessGraphConfig.Enabled {
+		return nil
+	}
+	go func() {
+		reloadCh := s.newDiscoveryConfigChangedSub()
+		for {
+			fetchers := s.getAllTAGSyncAzureFetchers()
+			// Wait for the config to change and re-evaluate the fetchers before starting the sync.
+			if len(fetchers) == 0 {
+				s.Log.DebugContext(ctx, "No Azure sync fetchers configured. Access graph sync will not be enabled.")
+				select {
+				case <-ctx.Done():
+					return
+				case <-reloadCh:
+					// if the config changes, we need to get the updated list of fetchers
+				}
+				continue
+			}
+			// Start the sync. initializeAndWatchAzureAccessGraph begins from an
+			// empty resource set, so each run forces a full sync.
+			if err := s.initializeAndWatchAzureAccessGraph(ctx, reloadCh); errors.Is(err, errTAGFeatureNotEnabled) {
+				s.Log.WarnContext(ctx, "Access Graph specified in config, but the license does not include Teleport Policy. Access graph sync will not be enabled.")
+				break
+			} else if err != nil {
+				s.Log.WarnContext(ctx, "Error initializing and watching access graph", "error", err)
+			}
+
+			select {
+			case <-ctx.Done():
+				return
+			case <-time.After(time.Minute):
+			}
+		}
+	}()
+	return nil
+}
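The goroutine above acts as a supervisor: a licensing failure ends the loop for good, while any other error is retried a minute later. The same control flow in isolation, outside the patch (runSync and errFatal are illustrative stubs for initializeAndWatchAzureAccessGraph and errTAGFeatureNotEnabled):

package main

import (
	"context"
	"errors"
	"log"
	"time"
)

var errFatal = errors.New("feature not enabled")

func runSync(ctx context.Context) error { return errors.New("transient failure") }

func supervise(ctx context.Context) {
	for {
		if err := runSync(ctx); errors.Is(err, errFatal) {
			break // unlicensed: stop retrying entirely
		} else if err != nil {
			log.Printf("sync failed, will retry: %v", err)
		}
		select {
		case <-ctx.Done():
			return
		case <-time.After(time.Minute): // back off before the next attempt
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	supervise(ctx)
}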
+
+// accessGraphAzureFetchersFromMatchers converts matcher configuration to fetchers for Azure resource synchronization
+func (s *Server) accessGraphAzureFetchersFromMatchers(
+	matchers Matchers, discoveryConfigName string) ([]*azure_sync.Fetcher, error) {
+	var fetchers []*azure_sync.Fetcher
+	var errs []error
+	if matchers.AccessGraph == nil {
+		return fetchers, nil
+	}
+	for _, matcher := range matchers.AccessGraph.Azure {
+		fetcherCfg := azure_sync.Config{
+			CloudClients:        s.CloudClients,
+			SubscriptionID:      matcher.SubscriptionID,
+			Integration:         matcher.Integration,
+			DiscoveryConfigName: discoveryConfigName,
+		}
+		fetcher, err := azure_sync.NewFetcher(fetcherCfg, s.ctx)
+		if err != nil {
+			errs = append(errs, err)
+			continue
+		}
+		fetchers = append(fetchers, fetcher)
+	}
+	return fetchers, trace.NewAggregate(errs...)
+}
diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go
index ffb4a76353f59..dd2595950f14e 100644
--- a/lib/srv/discovery/discovery.go
+++ b/lib/srv/discovery/discovery.go
@@ -61,6 +61,7 @@ import (
 	"github.com/gravitational/teleport/lib/srv/discovery/common"
 	"github.com/gravitational/teleport/lib/srv/discovery/fetchers"
 	aws_sync "github.com/gravitational/teleport/lib/srv/discovery/fetchers/aws-sync"
+	azure_sync "github.com/gravitational/teleport/lib/srv/discovery/fetchers/azure-sync"
 	"github.com/gravitational/teleport/lib/srv/discovery/fetchers/db"
 	"github.com/gravitational/teleport/lib/srv/server"
 	logutils "github.com/gravitational/teleport/lib/utils/log"
@@ -372,11 +373,17 @@ type Server struct {
 	muDynamicServerGCPFetchers sync.RWMutex
 	staticServerGCPFetchers    []server.Fetcher
 
-	// dynamicTAGSyncFetchers holds the current TAG Fetchers for the Dynamic Matchers (those coming from DiscoveryConfig resource).
+	// dynamicTAGAWSFetchers holds the current TAG Fetchers for the Dynamic Matchers (those coming from DiscoveryConfig resource).
 	// The key is the DiscoveryConfig name.
-	dynamicTAGSyncFetchers   map[string][]aws_sync.AWSSync
-	muDynamicTAGSyncFetchers sync.RWMutex
-	staticTAGSyncFetchers    []aws_sync.AWSSync
+	dynamicTAGAWSFetchers   map[string][]aws_sync.AWSSync
+	muDynamicTAGAWSFetchers sync.RWMutex
+	staticTAGAWSFetchers    []aws_sync.AWSSync
+
+	// dynamicTAGAzureFetchers holds the current TAG Fetchers for the Dynamic Matchers (those coming from DiscoveryConfig resource).
+	// The key is the DiscoveryConfig name.
+	dynamicTAGAzureFetchers   map[string][]*azure_sync.Fetcher
+	muDynamicTAGAzureFetchers sync.RWMutex
+	staticTAGAzureFetchers    []*azure_sync.Fetcher
 
 	// dynamicKubeFetchers holds the current kube fetchers that use integration as a source of credentials,
 	// for the Dynamic Matchers (those coming from DiscoveryConfig resource).
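The struct changes above give Azure the same shape the AWS TAG fetchers already use: a mutex-guarded map of dynamic fetchers keyed by DiscoveryConfig name, plus a static slice built once at startup. The read path both getAllTAGSync*Fetchers helpers follow, reduced to a generic sketch (the type and field names here are illustrative, not from the patch):

package main

import (
	"fmt"
	"sync"
)

// registry mirrors the dynamic/static fetcher pairing on Server.
type registry[T any] struct {
	mu      sync.RWMutex
	dynamic map[string][]T // key: DiscoveryConfig name
	static  []T
}

// all snapshots the dynamic sets under a read lock, then appends the statics.
func (r *registry[T]) all() []T {
	out := make([]T, 0, len(r.static))
	r.mu.RLock()
	for _, set := range r.dynamic {
		out = append(out, set...)
	}
	r.mu.RUnlock()
	return append(out, r.static...)
}

func main() {
	r := &registry[string]{
		dynamic: map[string][]string{"dc-1": {"sub-a"}},
		static:  []string{"sub-b"},
	}
	fmt.Println(r.all()) // [sub-a sub-b]
}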
@@ -422,7 +429,8 @@ func New(ctx context.Context, cfg *Config) (*Server, error) { dynamicServerAWSFetchers: make(map[string][]server.Fetcher), dynamicServerAzureFetchers: make(map[string][]server.Fetcher), dynamicServerGCPFetchers: make(map[string][]server.Fetcher), - dynamicTAGSyncFetchers: make(map[string][]aws_sync.AWSSync), + dynamicTAGAWSFetchers: make(map[string][]aws_sync.AWSSync), + dynamicTAGAzureFetchers: make(map[string][]*azure_sync.Fetcher), dynamicDiscoveryConfig: make(map[string]*discoveryconfig.DiscoveryConfig), awsSyncStatus: awsSyncStatus{}, awsEC2ResourcesStatus: newAWSResourceStatusCollector(types.AWSMatcherEC2), @@ -467,6 +475,10 @@ func New(ctx context.Context, cfg *Config) (*Server, error) { return nil, trace.Wrap(err) } + if err := s.initTAGAzureWatchers(s.ctx, cfg); err != nil { + return nil, trace.Wrap(err) + } + return s, nil } @@ -1693,9 +1705,13 @@ func (s *Server) deleteDynamicFetchers(name string) { delete(s.dynamicServerGCPFetchers, name) s.muDynamicServerGCPFetchers.Unlock() - s.muDynamicTAGSyncFetchers.Lock() - delete(s.dynamicTAGSyncFetchers, name) - s.muDynamicTAGSyncFetchers.Unlock() + s.muDynamicTAGAWSFetchers.Lock() + delete(s.dynamicTAGAWSFetchers, name) + s.muDynamicTAGAWSFetchers.Unlock() + + s.muDynamicTAGAzureFetchers.Lock() + delete(s.dynamicTAGAzureFetchers, name) + s.muDynamicTAGAzureFetchers.Unlock() s.muDynamicKubeFetchers.Lock() delete(s.dynamicKubeFetchers, name) @@ -1749,9 +1765,17 @@ func (s *Server) upsertDynamicMatchers(ctx context.Context, dc *discoveryconfig. if err != nil { return trace.Wrap(err) } - s.muDynamicTAGSyncFetchers.Lock() - s.dynamicTAGSyncFetchers[dc.GetName()] = awsSyncMatchers - s.muDynamicTAGSyncFetchers.Unlock() + s.muDynamicTAGAWSFetchers.Lock() + s.dynamicTAGAWSFetchers[dc.GetName()] = awsSyncMatchers + s.muDynamicTAGAWSFetchers.Unlock() + + azureSyncMatchers, err := s.accessGraphAzureFetchersFromMatchers(matchers, dc.GetName()) + if err != nil { + return trace.Wrap(err) + } + s.muDynamicTAGAzureFetchers.Lock() + s.dynamicTAGAzureFetchers[dc.GetName()] = azureSyncMatchers + s.muDynamicTAGAzureFetchers.Unlock() kubeFetchers, err := s.kubeFetchersFromMatchers(matchers, dc.GetName()) if err != nil {