From 303c25782181b9c0253815ad9885431a42e5d44c Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Thu, 25 Jul 2024 14:26:42 +0100 Subject: [PATCH] [v15] Prevent AMRs being watched if access monitoring is not enabled (#44504) * Prevent AMRs being watched if access monitoring is not enabled * Log and ignore AMRs if cache initialisation fails * Update integrations/access/accessrequest/app.go Co-authored-by: Marco Dinis * Use allowPartialSuccess when watching over changing watch kinds * Add ability to get confirmed watch kinds from watch job * Add test coverage for newJobWithConfirmedWatchKinds * Swap watcher accepted kinds logic to use callback instead of channel * Update test for watcherjob * Add check for unset init callback to watcherjob * Remove unneeded retry utility * Remove unused import * Remove retry logic from watcherjob tests * Fix linting * Add clarifying comments and better error message * Remove unused import * Allow nil err on test process cancel --------- Co-authored-by: Marco Dinis --- integrations/access/accessrequest/app.go | 33 ++++++++---- integrations/lib/watcherjob/helpers_test.go | 35 +++++++++++++ integrations/lib/watcherjob/watcherjob.go | 36 +++++++++---- .../lib/watcherjob/watcherjob_test.go | 52 +++++++++++++++++++ 4 files changed, 138 insertions(+), 18 deletions(-) diff --git a/integrations/access/accessrequest/app.go b/integrations/access/accessrequest/app.go index 2b6dd0fcf9cb4..359db8ab3a08c 100644 --- a/integrations/access/accessrequest/app.go +++ b/integrations/access/accessrequest/app.go @@ -42,7 +42,7 @@ import ( const ( // handlerTimeout is used to bound the execution time of watcher event handler. handlerTimeout = time.Second * 5 - // defaultAccessMonitoringRulePageSize is the default number of rules to retrieve per request + // defaultAccessMonitoringRulePageSize is the default number of rules to retrieve per request. defaultAccessMonitoringRulePageSize = 10 ) @@ -123,16 +123,24 @@ func (a *App) Err() error { func (a *App) run(ctx context.Context) error { process := lib.MustGetProcess(ctx) - job, err := watcherjob.NewJob( + watchKinds := []types.WatchKind{ + {Kind: types.KindAccessRequest}, + {Kind: types.KindAccessMonitoringRule}, + } + + acceptedWatchKinds := make([]string, 0, len(watchKinds)) + job, err := watcherjob.NewJobWithConfirmedWatchKinds( a.apiClient, watcherjob.Config{ - Watch: types.Watch{Kinds: []types.WatchKind{ - {Kind: types.KindAccessRequest}, - {Kind: types.KindAccessMonitoringRule}, - }}, + Watch: types.Watch{Kinds: watchKinds, AllowPartialSuccess: true}, EventFuncTimeout: handlerTimeout, }, a.onWatcherEvent, + func(ws types.WatchStatus) { + for _, watchKind := range ws.GetKinds() { + acceptedWatchKinds = append(acceptedWatchKinds, watchKind.Kind) + } + }, ) if err != nil { return trace.Wrap(err) @@ -144,9 +152,16 @@ func (a *App) run(ctx context.Context) error { if err != nil { return trace.Wrap(err) } - - if err := a.initAccessMonitoringRulesCache(ctx); err != nil { - return trace.Wrap(err) + if len(acceptedWatchKinds) == 0 { + return trace.BadParameter("failed to initialize watcher for all the required resources: %+v", + watchKinds) + } + // Check if KindAccessMonitoringRule resources are being watched, + // the role the plugin is running as may not have access. + if slices.Contains(acceptedWatchKinds, types.KindAccessMonitoringRule) { + if err := a.initAccessMonitoringRulesCache(ctx); err != nil { + return trace.Wrap(err, "initializing Access Monitoring Rule cache") + } } a.job.SetReady(ok) diff --git a/integrations/lib/watcherjob/helpers_test.go b/integrations/lib/watcherjob/helpers_test.go index dac7ab08c0aae..f097bbcf52920 100644 --- a/integrations/lib/watcherjob/helpers_test.go +++ b/integrations/lib/watcherjob/helpers_test.go @@ -133,6 +133,11 @@ func NewMockEventsProcess(ctx context.Context, t *testing.T, config Config, fn E return &process } +// WaitReady waits for the job to be ready. +func (process *MockEventsProcess) WaitReady(ctx context.Context) (bool, error) { + return process.eventsJob.WaitReady(ctx) +} + // Shutdown sends a termination signal and waits for process completion. func (process *MockEventsProcess) Shutdown(ctx context.Context) error { process.Terminate() @@ -181,3 +186,33 @@ func (countdown *Countdown) Wait(ctx context.Context) error { return trace.Wrap(ctx.Err()) } } + +// NewMockEventsProcessWithConfirmedWatchJobs creates a new mock process that passes confirmed watch kinds back. +func NewMockEventsProcessWithConfirmedWatchJobs(ctx context.Context, t *testing.T, config Config, fn EventFunc, watchInitFunc WatchInitFunc) *MockEventsProcess { + t.Helper() + process := MockEventsProcess{ + Process: lib.NewProcess(ctx), + } + t.Cleanup(func() { + process.Terminate() + if err := process.Shutdown(ctx); err != nil { + assert.ErrorContains(t, err, context.Canceled.Error(), "if a non-nil error is returned, it should be canceled context") + } + process.Close() + }) + var err error + + process.eventsJob, err = NewJobWithConfirmedWatchKinds(&process.Events, config, fn, watchInitFunc) + require.NoError(t, err) + process.SpawnCriticalJob(process.eventsJob) + require.NoError(t, process.Events.WaitSomeWatchers(ctx)) + process.Events.Fire(types.Event{ + Type: types.OpInit, + Resource: &types.WatchStatusV1{ + Spec: types.WatchStatusSpecV1{ + Kinds: config.Watch.Kinds, + }, + }}) + + return &process +} diff --git a/integrations/lib/watcherjob/watcherjob.go b/integrations/lib/watcherjob/watcherjob.go index ce9a7070c5cda..687086f8d1ec0 100644 --- a/integrations/lib/watcherjob/watcherjob.go +++ b/integrations/lib/watcherjob/watcherjob.go @@ -43,6 +43,7 @@ const ( ) type EventFunc func(context.Context, types.Event) error +type WatchInitFunc func(types.WatchStatus) type Config struct { Watch types.Watch @@ -53,10 +54,11 @@ type Config struct { type job struct { lib.ServiceJob - config Config - eventFunc EventFunc - events types.Events - eventCh chan *types.Event + config Config + eventFunc EventFunc + events types.Events + eventCh chan *types.Event + onWatchInitFunc WatchInitFunc } type eventKey struct { @@ -68,7 +70,17 @@ func NewJob(client teleport.Client, config Config, fn EventFunc) (lib.ServiceJob return NewJobWithEvents(client, config, fn) } +// NewJobWithConfirmedWatchKinds returns a new watcherJob and passes confirmed watch kinds +// from the initialisation down confirmedWatchKindsCh. +func NewJobWithConfirmedWatchKinds(events types.Events, config Config, fn EventFunc, watchInitFunc WatchInitFunc) (lib.ServiceJob, error) { + return newJobWithEvents(events, config, fn, watchInitFunc) +} + func NewJobWithEvents(events types.Events, config Config, fn EventFunc) (lib.ServiceJob, error) { + return newJobWithEvents(events, config, fn, nil) +} + +func newJobWithEvents(events types.Events, config Config, fn EventFunc, watchInitFunc WatchInitFunc) (job, error) { if config.MaxConcurrency == 0 { config.MaxConcurrency = DefaultMaxConcurrency } @@ -78,15 +90,16 @@ func NewJobWithEvents(events types.Events, config Config, fn EventFunc) (lib.Ser if flagVar := os.Getenv(failFastEnvVarName); !config.FailFast && flagVar != "" { flag, err := strconv.ParseBool(flagVar) if err != nil { - return nil, trace.WrapWithMessage(err, "failed to parse content '%s' of the %s environment variable", flagVar, failFastEnvVarName) + return job{}, trace.WrapWithMessage(err, "failed to parse content '%s' of the %s environment variable", flagVar, failFastEnvVarName) } config.FailFast = flag } job := job{ - events: events, - config: config, - eventFunc: fn, - eventCh: make(chan *types.Event, config.MaxConcurrency), + events: events, + config: config, + eventFunc: fn, + eventCh: make(chan *types.Event, config.MaxConcurrency), + onWatchInitFunc: watchInitFunc, } job.ServiceJob = lib.NewServiceJob(func(ctx context.Context) error { process := lib.MustGetProcess(ctx) @@ -184,6 +197,11 @@ func (job job) waitInit(ctx context.Context, watcher types.Watcher, timeout time if event.Type != types.OpInit { return trace.ConnectionProblem(nil, "unexpected event type %q", event.Type) } + if watchStatus, ok := event.Resource.(types.WatchStatus); ok { + if job.onWatchInitFunc != nil { + job.onWatchInitFunc(watchStatus) + } + } return nil case <-time.After(timeout): return trace.ConnectionProblem(nil, "watcher initialization timed out") diff --git a/integrations/lib/watcherjob/watcherjob_test.go b/integrations/lib/watcherjob/watcherjob_test.go index c021093d941ef..9596487c051c5 100644 --- a/integrations/lib/watcherjob/watcherjob_test.go +++ b/integrations/lib/watcherjob/watcherjob_test.go @@ -21,6 +21,7 @@ package watcherjob import ( "context" "fmt" + "slices" "testing" "time" @@ -111,3 +112,54 @@ func TestConcurrencyLimit(t *testing.T) { timeAfter := time.Now() assert.InDelta(t, 4*time.Second, timeAfter.Sub(timeBefore), float64(750*time.Millisecond)) } + +// TestNewJobWithConfirmedWatchKinds checks that the watch kinds are passed back after init. +func TestNewJobWithConfirmedWatchKinds(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + watchKinds := []types.WatchKind{ + {Kind: types.KindAccessRequest}, + } + config := Config{ + MaxConcurrency: 4, + Watch: types.Watch{ + Kinds: watchKinds, + }, + } + countdown := NewCountdown(config.MaxConcurrency) + var acceptedWatchKinds []string + onWatchInit := func(ws types.WatchStatus) { + for _, watchKind := range ws.GetKinds() { + acceptedWatchKinds = append(acceptedWatchKinds, watchKind.Kind) + } + } + + process := NewMockEventsProcessWithConfirmedWatchJobs(ctx, t, config, + func(ctx context.Context, event types.Event) error { + defer countdown.Decrement() + time.Sleep(time.Second) + return trace.Wrap(ctx.Err()) + }, onWatchInit) + + _, err := process.WaitReady(ctx) + require.NoError(t, err) + + if !slices.ContainsFunc(acceptedWatchKinds, func(kind string) bool { + return kind == types.KindAccessRequest + }) { + t.Error("access request watch kind not returned after init: %V", acceptedWatchKinds) + } + + timeBefore := time.Now() + for i := 0; i < config.MaxConcurrency; i++ { + resource, err := types.NewAccessRequest("REQ-SAME", "foo", "admin") + require.NoError(t, err) + process.Events.Fire(types.Event{Type: types.OpPut, Resource: resource}) + } + require.NoError(t, countdown.Wait(ctx)) + + timeAfter := time.Now() + assert.InDelta(t, 4*time.Second, timeAfter.Sub(timeBefore), float64(1000*time.Millisecond)) +}