diff --git a/lib/services/access_request_cache.go b/lib/services/access_request_cache.go index 100458ef49061..ff23fe2d278ad 100644 --- a/lib/services/access_request_cache.go +++ b/lib/services/access_request_cache.go @@ -53,6 +53,8 @@ type AccessRequestCacheConfig struct { Events types.Events // Getter is an access request getter client. Getter AccessRequestGetter + // MaxRetryPeriod is the maximum retry period on failed watches. + MaxRetryPeriod time.Duration } // CheckAndSetDefaults valides the config and provides reasonable defaults for optional fields. @@ -84,8 +86,12 @@ type AccessRequestCache struct { primaryCache *sortcache.SortCache[*types.AccessRequestV3] ttlCache *utils.FnCache initC chan struct{} + initOnce sync.Once closeContext context.Context cancel context.CancelFunc + // onInit is a callback used in tests to detect + // individual initializations. + onInit func() } // NewAccessRequestCache sets up a new [AccessRequestCache] instance based on the supplied @@ -117,8 +123,9 @@ func NewAccessRequestCache(cfg AccessRequestCacheConfig) (*AccessRequestCache, e } if _, err := newResourceWatcher(ctx, c, ResourceWatcherConfig{ - Component: "access-request-cache", - Client: cfg.Events, + Component: "access-request-cache", + Client: cfg.Events, + MaxRetryPeriod: cfg.MaxRetryPeriod, }); err != nil { cancel() return nil, trace.Wrap(err) @@ -344,11 +351,23 @@ func (c *AccessRequestCache) getResourcesAndUpdateCurrent(ctx context.Context) e c.rw.Lock() defer c.rw.Unlock() c.primaryCache = cache - close(c.initC) + c.initOnce.Do(func() { + close(c.initC) + }) + if c.onInit != nil { + c.onInit() + } return nil } -// processEventAndUpdateCurrent is part of the resourceCollector interface and is used to update the +// SetInitCallback is used in tests that care about cache inits. 
+func (c *AccessRequestCache) SetInitCallback(cb func()) { + c.rw.Lock() + defer c.rw.Unlock() + c.onInit = cb +} + +// processEventAndUpdateCurrent is part of the resourceCollector interface and is used to update the // primary cache state when modification events occur. func (c *AccessRequestCache) processEventAndUpdateCurrent(ctx context.Context, event types.Event) { c.rw.RLock() @@ -385,6 +404,7 @@ func (c *AccessRequestCache) notifyStale() { } c.primaryCache = nil c.initC = make(chan struct{}) + c.initOnce = sync.Once{} } // initializationChan is part of the resourceCollector interface and gets the channel diff --git a/lib/services/access_request_cache_test.go b/lib/services/access_request_cache_test.go index d34aa948fe064..decd85684fe71 100644 --- a/lib/services/access_request_cache_test.go +++ b/lib/services/access_request_cache_test.go @@ -22,7 +22,10 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" @@ -34,6 +37,8 @@ import ( type accessRequestServices struct { types.Events services.DynamicAccessExt + + bk *memory.Memory } func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.AccessRequestCache) { @@ -43,11 +48,13 @@ func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.Access svcs := accessRequestServices{ Events: local.NewEventsService(bk), DynamicAccessExt: local.NewDynamicAccessService(bk), + bk: bk, } cache, err := services.NewAccessRequestCache(services.AccessRequestCacheConfig{ - Events: svcs, - Getter: svcs, + Events: svcs, + Getter: svcs, + MaxRetryPeriod: time.Millisecond * 100, }) require.NoError(t, err) @@ -60,6 +67,108 @@ func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.Access return svcs, cache } +func TestAccessRequestCacheResets(t *testing.T) { + const ( + requestCount = 
100 + workers = 20 + resets = 3 + ) + + t.Parallel() + + svcs, cache := newAccessRequestPack(t) + defer cache.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < requestCount; i++ { + r, err := types.NewAccessRequest(uuid.New().String(), "alice@example.com", "some-role") + require.NoError(t, err) + + _, err = svcs.CreateAccessRequestV2(ctx, r) + require.NoError(t, err) + } + + timeout := time.After(time.Second * 30) + + for { + rsp, err := cache.ListAccessRequests(ctx, &proto.ListAccessRequestsRequest{ + Limit: requestCount, + }) + require.NoError(t, err) + if len(rsp.AccessRequests) == requestCount { + break + } + + select { + case <-timeout: + require.FailNow(t, "timeout waiting for access request cache to populate") + case <-time.After(time.Millisecond * 200): + } + } + + doneC := make(chan struct{}) + reads := make(chan struct{}, workers) + var eg errgroup.Group + + for i := 0; i < workers; i++ { + eg.Go(func() error { + for { + select { + case <-doneC: + return nil + case <-time.After(time.Millisecond * 20): + } + + rsp, err := cache.ListAccessRequests(ctx, &proto.ListAccessRequestsRequest{ + Limit: int32(requestCount), + }) + if err != nil { + return trace.Errorf("unexpected read failure: %v", err) + } + + select { + case reads <- struct{}{}: + default: + } + + if len(rsp.AccessRequests) != requestCount { + return trace.Errorf("unexpected number of access requests: %d (expected %d)", len(rsp.AccessRequests), requestCount) + } + } + }) + } + + inits := make(chan struct{}, resets+1) + cache.SetInitCallback(func() { + inits <- struct{}{} + }) + + timeout = time.After(time.Second * 30) + for i := 0; i < resets; i++ { + svcs.bk.CloseWatchers() + select { + case <-inits: + case <-timeout: + require.FailNowf(t, "timeout waiting for access request cache to reset", "reset=%d", i) + } + + for j := 0; j < workers; j++ { + // ensure that we're not racing ahead of worker reads too + // much if inits are happening 
quickly. + select { + case <-reads: + case <-timeout: + require.FailNowf(t, "timeout waiting for worker reads to catch up", "reset=%d", i) + } + } + } + + close(doneC) + require.NoError(t, eg.Wait()) +} + // TestAccessRequestCacheBasics verifies the basic expected behaviors of the access request cache, // including correct sorting and handling of put/delete events. func TestAccessRequestCacheBasics(t *testing.T) {