diff --git a/common/nexus/endpoint_registry.go b/common/nexus/endpoint_registry.go index 7e3dea4bc55..a05e5d09b7c 100644 --- a/common/nexus/endpoint_registry.go +++ b/common/nexus/endpoint_registry.go @@ -73,15 +73,14 @@ type ( EndpointRegistryImpl struct { config *EndpointRegistryConfig - dataReady chan struct{} + dataReady atomic.Pointer[dataReady] dataLock sync.RWMutex // Protects tableVersion and endpoints. tableVersion int64 endpointsByID map[string]*persistencespb.NexusEndpointEntry // Mapping of endpoint ID -> endpoint. endpointsByName map[string]*persistencespb.NexusEndpointEntry // Mapping of endpoint name -> endpoint. - refreshPoller atomic.Pointer[goro.Handle] - cancelDcSub func() + cancelDcSub func() matchingClient matchingservice.MatchingServiceClient persistence p.NexusEndpointManager @@ -89,8 +88,15 @@ type ( readThroughCacheByID cache.Cache } + + dataReady struct { + refresh *goro.Handle // handle to refresh goroutine + ready chan struct{} // channel that clients can wait on for state changes + } ) +var ErrNexusDisabled = serviceerror.NewFailedPrecondition("nexus is disabled") + func NewEndpointRegistryConfig(dc *dynamicconfig.Collection) *EndpointRegistryConfig { config := &EndpointRegistryConfig{ refreshEnabled: dynamicconfig.EnableNexus.Subscribe(dc), @@ -113,7 +119,6 @@ func NewEndpointRegistry( ) *EndpointRegistryImpl { return &EndpointRegistryImpl{ config: config, - dataReady: make(chan struct{}), endpointsByID: make(map[string]*persistencespb.NexusEndpointEntry), endpointsByName: make(map[string]*persistencespb.NexusEndpointEntry), matchingClient: matchingClient, @@ -141,22 +146,30 @@ func (r *EndpointRegistryImpl) StopLifecycle() { } func (r *EndpointRegistryImpl) setEnabled(enabled bool) { - oldPoller := r.refreshPoller.Load() - if oldPoller == nil && enabled { + oldReady := r.dataReady.Load() + if oldReady == nil && enabled { backgroundCtx := headers.SetCallerInfo( context.Background(), headers.SystemBackgroundCallerInfo, ) - newPoller := goro.NewHandle(backgroundCtx) - oldPoller = r.refreshPoller.Swap(newPoller) - if oldPoller == nil { - newPoller.Go(r.refreshEndpointsLoop) + newReady := &dataReady{ + refresh: goro.NewHandle(backgroundCtx), + ready: make(chan struct{}), } - } else if oldPoller != nil && !enabled { - oldPoller = r.refreshPoller.Swap(nil) - if oldPoller != nil { - oldPoller.Cancel() - <-oldPoller.Done() + if r.dataReady.CompareAndSwap(oldReady, newReady) { + newReady.refresh.Go(func(ctx context.Context) error { + return r.refreshEndpointsLoop(ctx, newReady) + }) + } + } else if oldReady != nil && !enabled { + if r.dataReady.CompareAndSwap(oldReady, nil) { + oldReady.refresh.Cancel() + <-oldReady.refresh.Done() + // If oldReady.ready was not already closed here, callers blocked in waitUntilInitialized + // will block indefinitely (until context timeout). If we wanted to wake them up, we + // could close ready here, but we would need to use a sync.Once to avoid closing it + // twice. Then waitUntilInitialized would need to reload r.dataReady to check that the + // wakeup was due to data being ready rather than this close. } } } @@ -206,15 +219,19 @@ func (r *EndpointRegistryImpl) GetByID(ctx context.Context, id string) (*persist } func (r *EndpointRegistryImpl) waitUntilInitialized(ctx context.Context) error { + dataReady := r.dataReady.Load() + if dataReady == nil { + return ErrNexusDisabled + } select { - case <-r.dataReady: + case <-dataReady.ready: return nil case <-ctx.Done(): return ctx.Err() } } -func (r *EndpointRegistryImpl) refreshEndpointsLoop(ctx context.Context) error { +func (r *EndpointRegistryImpl) refreshEndpointsLoop(ctx context.Context, dataReady *dataReady) error { hasLoadedEndpointData := false for ctx.Err() == nil { @@ -226,7 +243,9 @@ func (r *EndpointRegistryImpl) refreshEndpointsLoop(ctx context.Context) error { err := backoff.ThrottleRetryContext(ctx, r.loadEndpoints, r.config.refreshRetryPolicy, nil) if err == nil { hasLoadedEndpointData = true - close(r.dataReady) + // Note: do not reload r.dataReady here, use value from argument to ensure that + // each channel is closed no more than once. + close(dataReady.ready) } } else { // Endpoints have previously been loaded, so just keep them up to date with long poll requests to diff --git a/common/nexus/endpoint_registry_test.go b/common/nexus/endpoint_registry_test.go index 77cd648ee53..f9827aed80b 100644 --- a/common/nexus/endpoint_registry_test.go +++ b/common/nexus/endpoint_registry_test.go @@ -26,6 +26,7 @@ import ( "context" "errors" "fmt" + "sync" "testing" "time" @@ -177,6 +178,85 @@ func TestInitializationFallback(t *testing.T) { assert.Equal(t, int64(1), reg.tableVersion) } +func TestEnableDisableEnable(t *testing.T) { + t.Parallel() + + testEntry := newEndpointEntry(t.Name()) + mocks := newTestMocks(t) + + mocks.config.refreshMinWait = dynamicconfig.GetDurationPropertyFn(time.Millisecond) + var callback func(bool) // capture callback to call later + mocks.config.refreshEnabled = func(cb func(bool)) (bool, func()) { + callback = cb + return false, func() {} + } + + // start disabled + reg := NewEndpointRegistry(mocks.config, mocks.matchingClient, mocks.persistence, log.NewNoopLogger(), metrics.NoopMetricsHandler) + reg.StartLifecycle() + defer reg.StopLifecycle() + + // check waitUntilInitialized + quickCtx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + require.ErrorIs(t, reg.waitUntilInitialized(quickCtx), ErrNexusDisabled) + + // mocks for initial load + inLongPoll := make(chan struct{}) + closeOnce := sync.OnceFunc(func() { close(inLongPoll) }) + mocks.matchingClient.EXPECT().ListNexusEndpoints(gomock.Any(), gomock.Any()).Return(&matchingservice.ListNexusEndpointsResponse{ + Entries: []*persistencepb.NexusEndpointEntry{testEntry}, + TableVersion: 1, + NextPageToken: nil, + }, nil) + mocks.matchingClient.EXPECT().ListNexusEndpoints(gomock.Any(), &matchingservice.ListNexusEndpointsRequest{ + PageSize: int32(100), + LastKnownTableVersion: int64(1), + Wait: true, + }).DoAndReturn(func(context.Context, *matchingservice.ListNexusEndpointsRequest, ...interface{}) (*matchingservice.ListNexusEndpointsResponse, error) { + closeOnce() + time.Sleep(100 * time.Millisecond) + return &matchingservice.ListNexusEndpointsResponse{TableVersion: int64(1)}, nil + }) + + // enable + callback(true) + <-inLongPoll + + // check waitUntilInitialized + quickCtx, cancel = context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + require.NoError(t, reg.waitUntilInitialized(quickCtx)) + + // now disable + callback(false) + + quickCtx, cancel = context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + require.ErrorIs(t, reg.waitUntilInitialized(quickCtx), ErrNexusDisabled) + + // enable again, should not crash + + inLongPoll = make(chan struct{}) + closeOnce = sync.OnceFunc(func() { close(inLongPoll) }) + mocks.matchingClient.EXPECT().ListNexusEndpoints(gomock.Any(), gomock.Any()).Return(&matchingservice.ListNexusEndpointsResponse{ + Entries: []*persistencepb.NexusEndpointEntry{testEntry}, + TableVersion: 1, + NextPageToken: nil, + }, nil) + mocks.matchingClient.EXPECT().ListNexusEndpoints(gomock.Any(), &matchingservice.ListNexusEndpointsRequest{ + PageSize: int32(100), + LastKnownTableVersion: int64(1), + Wait: true, + }).DoAndReturn(func(context.Context, *matchingservice.ListNexusEndpointsRequest, ...interface{}) (*matchingservice.ListNexusEndpointsResponse, error) { + closeOnce() + time.Sleep(100 * time.Millisecond) + return &matchingservice.ListNexusEndpointsResponse{TableVersion: int64(1)}, nil + }) + callback(true) + <-inLongPoll +} + func TestTableVersionErrorResetsMatchingPagination(t *testing.T) { t.Parallel()