Skip to content

Commit

Permalink
[v15] Prevent AMRs being watched if access monitoring is not enabled (#…
Browse files Browse the repository at this point in the history
…44504)

* Prevent AMRs being watched if access monitoring is not enabled

* Log and ignore AMRs if cache initialisation fails

* Update integrations/access/accessrequest/app.go

Co-authored-by: Marco Dinis <[email protected]>

* Use allowPartialSuccess when watching over changing watch kinds

* Add ability to get confirmed watch kinds from watch job

* Add test coverage for newJobWithConfirmedWatchKinds

* Swap watcher accepted kinds logic to use callback instead of channel

* Update test for watcherjob

* Add check for unset init callback to watcherjob

* Remove unneeded retry utility

* Remove unused import

* Remove retry logic from watcherjob tests

* Fix linting

* Add clarifying comments and better error message

* Remove unused import

* Allow nil err on test process cancel

---------

Co-authored-by: Marco Dinis <[email protected]>
  • Loading branch information
EdwardDowling and marcoandredinis authored Jul 25, 2024
1 parent 88d22c6 commit 303c257
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 18 deletions.
33 changes: 24 additions & 9 deletions integrations/access/accessrequest/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import (
const (
// handlerTimeout is used to bound the execution time of watcher event handler.
handlerTimeout = time.Second * 5
// defaultAccessMonitoringRulePageSize is the default number of rules to retrieve per request
// defaultAccessMonitoringRulePageSize is the default number of rules to retrieve per request.
defaultAccessMonitoringRulePageSize = 10
)

Expand Down Expand Up @@ -123,16 +123,24 @@ func (a *App) Err() error {
func (a *App) run(ctx context.Context) error {
process := lib.MustGetProcess(ctx)

job, err := watcherjob.NewJob(
watchKinds := []types.WatchKind{
{Kind: types.KindAccessRequest},
{Kind: types.KindAccessMonitoringRule},
}

acceptedWatchKinds := make([]string, 0, len(watchKinds))
job, err := watcherjob.NewJobWithConfirmedWatchKinds(
a.apiClient,
watcherjob.Config{
Watch: types.Watch{Kinds: []types.WatchKind{
{Kind: types.KindAccessRequest},
{Kind: types.KindAccessMonitoringRule},
}},
Watch: types.Watch{Kinds: watchKinds, AllowPartialSuccess: true},
EventFuncTimeout: handlerTimeout,
},
a.onWatcherEvent,
func(ws types.WatchStatus) {
for _, watchKind := range ws.GetKinds() {
acceptedWatchKinds = append(acceptedWatchKinds, watchKind.Kind)
}
},
)
if err != nil {
return trace.Wrap(err)
Expand All @@ -144,9 +152,16 @@ func (a *App) run(ctx context.Context) error {
if err != nil {
return trace.Wrap(err)
}

if err := a.initAccessMonitoringRulesCache(ctx); err != nil {
return trace.Wrap(err)
if len(acceptedWatchKinds) == 0 {
return trace.BadParameter("failed to initialize watcher for all the required resources: %+v",
watchKinds)
}
// Check if KindAccessMonitoringRule resources are being watched,
// the role the plugin is running as may not have access.
if slices.Contains(acceptedWatchKinds, types.KindAccessMonitoringRule) {
if err := a.initAccessMonitoringRulesCache(ctx); err != nil {
return trace.Wrap(err, "initializing Access Monitoring Rule cache")
}
}

a.job.SetReady(ok)
Expand Down
35 changes: 35 additions & 0 deletions integrations/lib/watcherjob/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ func NewMockEventsProcess(ctx context.Context, t *testing.T, config Config, fn E
return &process
}

// WaitReady waits for the job to be ready.
func (process *MockEventsProcess) WaitReady(ctx context.Context) (bool, error) {
return process.eventsJob.WaitReady(ctx)
}

// Shutdown sends a termination signal and waits for process completion.
func (process *MockEventsProcess) Shutdown(ctx context.Context) error {
process.Terminate()
Expand Down Expand Up @@ -181,3 +186,33 @@ func (countdown *Countdown) Wait(ctx context.Context) error {
return trace.Wrap(ctx.Err())
}
}

// NewMockEventsProcessWithConfirmedWatchJobs creates a new mock process that passes confirmed watch kinds back.
func NewMockEventsProcessWithConfirmedWatchJobs(ctx context.Context, t *testing.T, config Config, fn EventFunc, watchInitFunc WatchInitFunc) *MockEventsProcess {
t.Helper()
process := MockEventsProcess{
Process: lib.NewProcess(ctx),
}
t.Cleanup(func() {
process.Terminate()
if err := process.Shutdown(ctx); err != nil {
assert.ErrorContains(t, err, context.Canceled.Error(), "if a non-nil error is returned, it should be canceled context")
}
process.Close()
})
var err error

process.eventsJob, err = NewJobWithConfirmedWatchKinds(&process.Events, config, fn, watchInitFunc)
require.NoError(t, err)
process.SpawnCriticalJob(process.eventsJob)
require.NoError(t, process.Events.WaitSomeWatchers(ctx))
process.Events.Fire(types.Event{
Type: types.OpInit,
Resource: &types.WatchStatusV1{
Spec: types.WatchStatusSpecV1{
Kinds: config.Watch.Kinds,
},
}})

return &process
}
36 changes: 27 additions & 9 deletions integrations/lib/watcherjob/watcherjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
)

type EventFunc func(context.Context, types.Event) error
type WatchInitFunc func(types.WatchStatus)

type Config struct {
Watch types.Watch
Expand All @@ -53,10 +54,11 @@ type Config struct {

type job struct {
lib.ServiceJob
config Config
eventFunc EventFunc
events types.Events
eventCh chan *types.Event
config Config
eventFunc EventFunc
events types.Events
eventCh chan *types.Event
onWatchInitFunc WatchInitFunc
}

type eventKey struct {
Expand All @@ -68,7 +70,17 @@ func NewJob(client teleport.Client, config Config, fn EventFunc) (lib.ServiceJob
return NewJobWithEvents(client, config, fn)
}

// NewJobWithConfirmedWatchKinds returns a new watcherJob and passes confirmed watch kinds
// from the initialisation down confirmedWatchKindsCh.
func NewJobWithConfirmedWatchKinds(events types.Events, config Config, fn EventFunc, watchInitFunc WatchInitFunc) (lib.ServiceJob, error) {
return newJobWithEvents(events, config, fn, watchInitFunc)
}

func NewJobWithEvents(events types.Events, config Config, fn EventFunc) (lib.ServiceJob, error) {
return newJobWithEvents(events, config, fn, nil)
}

func newJobWithEvents(events types.Events, config Config, fn EventFunc, watchInitFunc WatchInitFunc) (job, error) {
if config.MaxConcurrency == 0 {
config.MaxConcurrency = DefaultMaxConcurrency
}
Expand All @@ -78,15 +90,16 @@ func NewJobWithEvents(events types.Events, config Config, fn EventFunc) (lib.Ser
if flagVar := os.Getenv(failFastEnvVarName); !config.FailFast && flagVar != "" {
flag, err := strconv.ParseBool(flagVar)
if err != nil {
return nil, trace.WrapWithMessage(err, "failed to parse content '%s' of the %s environment variable", flagVar, failFastEnvVarName)
return job{}, trace.WrapWithMessage(err, "failed to parse content '%s' of the %s environment variable", flagVar, failFastEnvVarName)
}
config.FailFast = flag
}
job := job{
events: events,
config: config,
eventFunc: fn,
eventCh: make(chan *types.Event, config.MaxConcurrency),
events: events,
config: config,
eventFunc: fn,
eventCh: make(chan *types.Event, config.MaxConcurrency),
onWatchInitFunc: watchInitFunc,
}
job.ServiceJob = lib.NewServiceJob(func(ctx context.Context) error {
process := lib.MustGetProcess(ctx)
Expand Down Expand Up @@ -184,6 +197,11 @@ func (job job) waitInit(ctx context.Context, watcher types.Watcher, timeout time
if event.Type != types.OpInit {
return trace.ConnectionProblem(nil, "unexpected event type %q", event.Type)
}
if watchStatus, ok := event.Resource.(types.WatchStatus); ok {
if job.onWatchInitFunc != nil {
job.onWatchInitFunc(watchStatus)
}
}
return nil
case <-time.After(timeout):
return trace.ConnectionProblem(nil, "watcher initialization timed out")
Expand Down
52 changes: 52 additions & 0 deletions integrations/lib/watcherjob/watcherjob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package watcherjob
import (
"context"
"fmt"
"slices"
"testing"
"time"

Expand Down Expand Up @@ -111,3 +112,54 @@ func TestConcurrencyLimit(t *testing.T) {
timeAfter := time.Now()
assert.InDelta(t, 4*time.Second, timeAfter.Sub(timeBefore), float64(750*time.Millisecond))
}

// TestNewJobWithConfirmedWatchKinds checks that the watch kinds are passed back after init.
func TestNewJobWithConfirmedWatchKinds(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)

watchKinds := []types.WatchKind{
{Kind: types.KindAccessRequest},
}
config := Config{
MaxConcurrency: 4,
Watch: types.Watch{
Kinds: watchKinds,
},
}
countdown := NewCountdown(config.MaxConcurrency)
var acceptedWatchKinds []string
onWatchInit := func(ws types.WatchStatus) {
for _, watchKind := range ws.GetKinds() {
acceptedWatchKinds = append(acceptedWatchKinds, watchKind.Kind)
}
}

process := NewMockEventsProcessWithConfirmedWatchJobs(ctx, t, config,
func(ctx context.Context, event types.Event) error {
defer countdown.Decrement()
time.Sleep(time.Second)
return trace.Wrap(ctx.Err())
}, onWatchInit)

_, err := process.WaitReady(ctx)
require.NoError(t, err)

if !slices.ContainsFunc(acceptedWatchKinds, func(kind string) bool {
return kind == types.KindAccessRequest
}) {
t.Error("access request watch kind not returned after init: %V", acceptedWatchKinds)
}

timeBefore := time.Now()
for i := 0; i < config.MaxConcurrency; i++ {
resource, err := types.NewAccessRequest("REQ-SAME", "foo", "admin")
require.NoError(t, err)
process.Events.Fire(types.Event{Type: types.OpPut, Resource: resource})
}
require.NoError(t, countdown.Wait(ctx))

timeAfter := time.Now()
assert.InDelta(t, 4*time.Second, timeAfter.Sub(timeBefore), float64(1000*time.Millisecond))
}

0 comments on commit 303c257

Please sign in to comment.