Skip to content

Commit

Permalink
fix access request cache panic (#45225) (#45494)
Browse files Browse the repository at this point in the history
  • Loading branch information
fspmarshall authored Aug 14, 2024
1 parent 54dd4bb commit 984999b
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 6 deletions.
28 changes: 24 additions & 4 deletions lib/services/access_request_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ type AccessRequestCacheConfig struct {
Events types.Events
// Getter is an access request getter client.
Getter AccessRequestGetter
// MaxRetryPeriod is the maximum retry period on failed watches.
MaxRetryPeriod time.Duration
}

// CheckAndSetDefaults validates the config and provides reasonable defaults for optional fields.
Expand Down Expand Up @@ -84,8 +86,12 @@ type AccessRequestCache struct {
primaryCache *sortcache.SortCache[*types.AccessRequestV3]
ttlCache *utils.FnCache
initC chan struct{}
initOnce sync.Once
closeContext context.Context
cancel context.CancelFunc
// onInit is a callback used in tests to detect
// individual initializations.
onInit func()
}

// NewAccessRequestCache sets up a new [AccessRequestCache] instance based on the supplied
Expand Down Expand Up @@ -117,8 +123,9 @@ func NewAccessRequestCache(cfg AccessRequestCacheConfig) (*AccessRequestCache, e
}

if _, err := newResourceWatcher(ctx, c, ResourceWatcherConfig{
Component: "access-request-cache",
Client: cfg.Events,
Component: "access-request-cache",
Client: cfg.Events,
MaxRetryPeriod: cfg.MaxRetryPeriod,
}); err != nil {
cancel()
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -344,11 +351,23 @@ func (c *AccessRequestCache) getResourcesAndUpdateCurrent(ctx context.Context) e
c.rw.Lock()
defer c.rw.Unlock()
c.primaryCache = cache
close(c.initC)
c.initOnce.Do(func() {
close(c.initC)
})
if c.onInit != nil {
c.onInit()
}
return nil
}

// processEventAndUpdateCurrent is part of the resourceCollector interface and is used to update the
// SetInitCallback stores cb as the cache's init callback, replacing whatever
// callback was set before. The assignment is performed while holding the
// cache's write lock. It exists for tests that need to observe cache inits.
func (c *AccessRequestCache) SetInitCallback(cb func()) {
	c.rw.Lock()
	c.onInit = cb
	c.rw.Unlock()
}

// processEventAndUpdateCurrent is part of the resourceCollector interface and is used to update the
// primary cache state when modification events occur.
func (c *AccessRequestCache) processEventAndUpdateCurrent(ctx context.Context, event types.Event) {
c.rw.RLock()
Expand Down Expand Up @@ -385,6 +404,7 @@ func (c *AccessRequestCache) notifyStale() {
}
c.primaryCache = nil
c.initC = make(chan struct{})
c.initOnce = sync.Once{}
}

// initializationChan is part of the resourceCollector interface and gets the channel
Expand Down
113 changes: 111 additions & 2 deletions lib/services/access_request_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ import (
"testing"
"time"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
Expand All @@ -34,6 +37,8 @@ import (
// accessRequestServices bundles the services that the access request cache
// tests hand to the cache, together with the backend those services were
// built on. newAccessRequestPack passes a value of this type as both the
// Events and the Getter of the cache config.
type accessRequestServices struct {
// Events is the embedded events service, constructed from the backend.
types.Events
// DynamicAccessExt is the embedded dynamic access service, constructed
// from the same backend.
services.DynamicAccessExt

// bk is the backend both services were created from. It is kept so tests
// can act on the backend directly (e.g. closing its watchers).
bk *memory.Memory
}

func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.AccessRequestCache) {
Expand All @@ -43,11 +48,13 @@ func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.Access
svcs := accessRequestServices{
Events: local.NewEventsService(bk),
DynamicAccessExt: local.NewDynamicAccessService(bk),
bk: bk,
}

cache, err := services.NewAccessRequestCache(services.AccessRequestCacheConfig{
Events: svcs,
Getter: svcs,
Events: svcs,
Getter: svcs,
MaxRetryPeriod: time.Millisecond * 100,
})
require.NoError(t, err)

Expand All @@ -60,6 +67,108 @@ func newAccessRequestPack(t *testing.T) (accessRequestServices, *services.Access
return svcs, cache
}

// TestAccessRequestCacheResets checks that concurrent readers of the access
// request cache keep seeing the full set of access requests while the
// backend's watchers are repeatedly closed and the cache's init callback
// fires again.
func TestAccessRequestCacheResets(t *testing.T) {
	const (
		requestCount = 100
		workers      = 20
		resets       = 3
	)

	t.Parallel()

	svcs, cache := newAccessRequestPack(t)
	defer cache.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// countRequests issues a single list call against the cache (with a fresh
	// request value each time) and reports how many access requests came back.
	countRequests := func() (int, error) {
		rsp, err := cache.ListAccessRequests(ctx, &proto.ListAccessRequestsRequest{
			Limit: requestCount,
		})
		if err != nil {
			return 0, err
		}
		return len(rsp.AccessRequests), nil
	}

	// Seed the backend with the access requests the readers expect to see.
	for created := 0; created < requestCount; created++ {
		req, err := types.NewAccessRequest(uuid.New().String(), "[email protected]", "some-role")
		require.NoError(t, err)

		_, err = svcs.CreateAccessRequestV2(ctx, req)
		require.NoError(t, err)
	}

	// Poll until the cache reports every seeded request.
	populateDeadline := time.After(time.Second * 30)
	for {
		n, err := countRequests()
		require.NoError(t, err)
		if n == requestCount {
			break
		}

		select {
		case <-time.After(time.Millisecond * 200):
		case <-populateDeadline:
			require.FailNow(t, "timeout waiting for access request cache to populate")
		}
	}

	stopC := make(chan struct{})
	readC := make(chan struct{}, workers)
	var group errgroup.Group

	// worker reads from the cache on a short interval until stopC is closed,
	// failing if any read errors or returns an incomplete result.
	worker := func() error {
		for {
			select {
			case <-stopC:
				return nil
			case <-time.After(time.Millisecond * 20):
			}

			n, err := countRequests()
			if err != nil {
				return trace.Errorf("unexpected read failure: %v", err)
			}

			// Best-effort signal that a read completed; never block the worker.
			select {
			case readC <- struct{}{}:
			default:
			}

			if n != requestCount {
				return trace.Errorf("unexpected number of access requests: %d (expected %d)", n, requestCount)
			}
		}
	}

	for w := 0; w < workers; w++ {
		group.Go(worker)
	}

	initC := make(chan struct{}, resets+1)
	cache.SetInitCallback(func() {
		initC <- struct{}{}
	})

	resetDeadline := time.After(time.Second * 30)
	for reset := 0; reset < resets; reset++ {
		svcs.bk.CloseWatchers()
		select {
		case <-initC:
		case <-resetDeadline:
			require.FailNowf(t, "timeout waiting for access request cache to reset", "reset=%d", reset)
		}

		// Wait for a batch of worker reads before the next reset so that
		// quick successive inits do not race too far ahead of the readers.
		for seen := 0; seen < workers; seen++ {
			select {
			case <-readC:
			case <-resetDeadline:
				require.FailNowf(t, "timeout waiting for worker reads to catch up", "reset=%d", reset)
			}
		}
	}

	close(stopC)
	require.NoError(t, group.Wait())
}

// TestAccessRequestCacheBasics verifies the basic expected behaviors of the access request cache,
// including correct sorting and handling of put/delete events.
func TestAccessRequestCacheBasics(t *testing.T) {
Expand Down

0 comments on commit 984999b

Please sign in to comment.