diff --git a/lib/srv/discovery/access_graph_azure.go b/lib/srv/discovery/access_graph_azure.go index 8b0ec38cacc0c..5916e6f79dc54 100644 --- a/lib/srv/discovery/access_graph_azure.go +++ b/lib/srv/discovery/access_graph_azure.go @@ -29,6 +29,7 @@ func (s *Server) reconcileAccessGraphAzure( err error } + // Get all the fetchers allFetchers := s.getAllTAGSyncAzureFetchers() if len(allFetchers) == 0 { // If there are no fetchers, we don't need to continue. @@ -41,16 +42,8 @@ func (s *Server) reconcileAccessGraphAzure( return trace.Wrap(errNoAccessGraphFetchers) } - // TODO (mbrock): Update discovery config status - /* - s.awsSyncStatus.iterationStarted(allFetchers, s.clock.Now()) - for _, discoveryConfigName := range s.awsSyncStatus.discoveryConfigs() { - s.updateDiscoveryConfigStatus(discoveryConfigName) - } - */ - + // Fetch results concurrently resultsC := make(chan fetcherResult, len(allFetchers)) - // Use a channel to limit the number of concurrent fetchers. tokens := make(chan struct{}, 3) accountIds := map[string]struct{}{} for _, fetcher := range allFetchers { @@ -66,11 +59,11 @@ func (s *Server) reconcileAccessGraphAzure( }() } + // Collect the results from all fetchers. results := make([]*azure_sync.Resources, 0, len(allFetchers)) errs := make([]error, 0, len(allFetchers)) - // Collect the results from all fetchers. - // Each fetcher can return an error and a result. for i := 0; i < len(allFetchers); i++ { + // Each fetcher can return an error and a result. fetcherResult := <-resultsC if fetcherResult.err != nil { errs = append(errs, fetcherResult.err) @@ -79,28 +72,23 @@ func (s *Server) reconcileAccessGraphAzure( results = append(results, fetcherResult.result) } } + // Aggregate all errors into a single error. err := trace.NewAggregate(errs...) if err != nil { s.Log.ErrorContext(ctx, "Error polling TAGs", "error", err) } result := azure_sync.MergeResources(results...) + // Merge all results into a single result upsert, toDel := azure_sync.ReconcileResults(currentTAGResources, result) pushErr := azurePush(stream, upsert, toDel) - // TODO (mbrock): Update discovery config status - /* - s.awsSyncStatus.iterationFinished(allFetchers, pushErr, s.clock.Now()) - for _, discoveryConfigName := range s.awsSyncStatus.discoveryConfigs() { - s.updateDiscoveryConfigStatus(discoveryConfigName) - } - */ - if pushErr != nil { s.Log.ErrorContext(ctx, "Error pushing TAGs", "error", pushErr) return nil } + // Update the currentTAGResources with the result of the reconciliation. *currentTAGResources = *result return nil @@ -321,38 +309,39 @@ func (s *Server) initTAGAzureWatchers(ctx context.Context, cfg *Config) error { s.Log.ErrorContext(ctx, "Error initializing access graph fetchers", "error", err) } s.staticTAGAzureFetchers = staticFetchers - if cfg.AccessGraphConfig.Enabled { - go func() { - reloadCh := s.newDiscoveryConfigChangedSub() - for { - fetchers := s.getAllTAGSyncAzureFetchers() - // Wait for the config to change and re-evaluate the fetchers before starting the sync. - if len(fetchers) == 0 { - s.Log.Debug("No Azure sync fetchers configured. Access graph sync will not be enabled.") - select { - case <-ctx.Done(): - return - case <-reloadCh: - // if the config changes, we need to get the updated list of fetchers - } - continue - } - // Reset the Azure resources to force a full sync - if err := s.initializeAndWatchAzureAccessGraph(ctx, reloadCh); errors.Is(err, errTAGFeatureNotEnabled) { - s.Log.Warn("Access Graph specified in config, but the license does not include Teleport Policy. Access graph sync will not be enabled.") - break - } else if err != nil { - s.Log.Warn("Error initializing and watching access graph", "error", err) - } - + if !cfg.AccessGraphConfig.Enabled { + return nil + } + go func() { + reloadCh := s.newDiscoveryConfigChangedSub() + for { + fetchers := s.getAllTAGSyncAzureFetchers() + // Wait for the config to change and re-evaluate the fetchers before starting the sync. + if len(fetchers) == 0 { + s.Log.Debug("No Azure sync fetchers configured. Access graph sync will not be enabled.") select { case <-ctx.Done(): return - case <-time.After(time.Minute): + case <-reloadCh: + // if the config changes, we need to get the updated list of fetchers } + continue } - }() - } + // Reset the Azure resources to force a full sync + if err := s.initializeAndWatchAzureAccessGraph(ctx, reloadCh); errors.Is(err, errTAGFeatureNotEnabled) { + s.Log.Warn("Access Graph specified in config, but the license does not include Teleport Policy. Access graph sync will not be enabled.") + break + } else if err != nil { + s.Log.Warn("Error initializing and watching access graph", "error", err) + } + + select { + case <-ctx.Done(): + return + case <-time.After(time.Minute): + } + } + }() return nil } diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index 0c634d5545f39..1cf95e108d623 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -383,6 +383,7 @@ func New(ctx context.Context, cfg *Config) (*Server, error) { dynamicServerAzureFetchers: make(map[string][]server.Fetcher), dynamicServerGCPFetchers: make(map[string][]server.Fetcher), dynamicTAGAWSFetchers: make(map[string][]aws_sync.AWSSync), + dynamicTAGAzureFetchers: make(map[string][]*azure_sync.Fetcher), dynamicDiscoveryConfig: make(map[string]*discoveryconfig.DiscoveryConfig), awsSyncStatus: awsSyncStatus{}, } @@ -424,6 +425,10 @@ func New(ctx context.Context, cfg *Config) (*Server, error) { return nil, trace.Wrap(err) } + if err := s.initTAGAzureWatchers(s.ctx, cfg); err != nil { + return nil, trace.Wrap(err) + } + return s, nil } diff --git a/lib/srv/discovery/fetchers/azure-sync/azure-sync.go b/lib/srv/discovery/fetchers/azure-sync/azure-sync.go index b6bfc60fa18ac..bc62e439ff248 100644 --- a/lib/srv/discovery/fetchers/azure-sync/azure-sync.go +++ b/lib/srv/discovery/fetchers/azure-sync/azure-sync.go @@ -20,7 +20,6 @@ package azure_sync import ( "context" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3" @@ -83,26 +82,21 @@ type Fetcher struct { func NewFetcher(cfg Config, ctx context.Context) (*Fetcher, error) { // Establish the credential from the managed identity - cred, err := azidentity.NewManagedIdentityCredential(nil) + cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, trace.Wrap(err) } - token, err := cred.GetToken(ctx, policy.TokenRequestOptions{}) - if err != nil { - return nil, trace.Wrap(err) - } - staticCred := azure.NewStaticCredential(token) // Create the clients - vmClient, err := azure.NewVirtualMachinesClient(cfg.SubscriptionID, staticCred, nil) + vmClient, err := azure.NewVirtualMachinesClient(cfg.SubscriptionID, cred, nil) if err != nil { return nil, trace.Wrap(err) } - roleDefClient, err := azure.NewRoleDefinitionsClient(cfg.SubscriptionID, staticCred, nil) + roleDefClient, err := azure.NewRoleDefinitionsClient(cfg.SubscriptionID, cred, nil) if err != nil { return nil, trace.Wrap(err) } - roleAssignClient, err := azure.NewRoleAssignmentsClient(cfg.SubscriptionID, staticCred, nil) + roleAssignClient, err := azure.NewRoleAssignmentsClient(cfg.SubscriptionID, cred, nil) if err != nil { return nil, trace.Wrap(err) }