Fix enable/disable in nexus endpoint registry (#6570)
## What changed?
Correctly handle an enable → disable → enable sequence in the nexus endpoint
registry without double-closing the ready channel.

## Why?
Avoid a crash: closing an already-closed channel panics in Go, and re-enabling after a disable could close the same ready channel twice.

## How did you test it?
New unit test (`TestEnableDisableEnable`) covering the enable → disable → enable sequence.
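
For context on the crash mentioned above (an illustration added here, not part of the original commit message): Go panics on a double close, which is what happened when one ready channel outlived an enable → disable → enable cycle.

```go
package main

// Minimal reproduction of the failure mode this commit fixes: with a single
// shared ready channel, a second enable cycle closes the same channel again.
func main() {
	ready := make(chan struct{})
	close(ready) // first enable: initial load completes, waiters wake up
	close(ready) // second enable after a disable: panic: close of closed channel
}
```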
dnr authored Sep 27, 2024
1 parent 76a6276 commit e125ed5
Showing 2 changed files with 117 additions and 18 deletions.
55 changes: 37 additions & 18 deletions common/nexus/endpoint_registry.go
@@ -73,24 +73,30 @@ type (
 	EndpointRegistryImpl struct {
 		config *EndpointRegistryConfig
 
-		dataReady chan struct{}
+		dataReady atomic.Pointer[dataReady]
 
 		dataLock        sync.RWMutex // Protects tableVersion and endpoints.
 		tableVersion    int64
 		endpointsByID   map[string]*persistencespb.NexusEndpointEntry // Mapping of endpoint ID -> endpoint.
 		endpointsByName map[string]*persistencespb.NexusEndpointEntry // Mapping of endpoint name -> endpoint.
 
-		refreshPoller atomic.Pointer[goro.Handle]
-		cancelDcSub   func()
+		cancelDcSub func()
 
 		matchingClient matchingservice.MatchingServiceClient
 		persistence    p.NexusEndpointManager
 		logger         log.Logger
 
 		readThroughCacheByID cache.Cache
 	}
+
+	dataReady struct {
+		refresh *goro.Handle  // handle to refresh goroutine
+		ready   chan struct{} // channel that clients can wait on for state changes
+	}
 )
 
+var ErrNexusDisabled = serviceerror.NewFailedPrecondition("nexus is disabled")
+
 func NewEndpointRegistryConfig(dc *dynamicconfig.Collection) *EndpointRegistryConfig {
 	config := &EndpointRegistryConfig{
 		refreshEnabled: dynamicconfig.EnableNexus.Subscribe(dc),
@@ -113,7 +119,6 @@ func NewEndpointRegistry(
 ) *EndpointRegistryImpl {
 	return &EndpointRegistryImpl{
 		config:          config,
-		dataReady:       make(chan struct{}),
 		endpointsByID:   make(map[string]*persistencespb.NexusEndpointEntry),
 		endpointsByName: make(map[string]*persistencespb.NexusEndpointEntry),
 		matchingClient:  matchingClient,
@@ -141,22 +146,30 @@ func (r *EndpointRegistryImpl) StopLifecycle() {
 }
 
 func (r *EndpointRegistryImpl) setEnabled(enabled bool) {
-	oldPoller := r.refreshPoller.Load()
-	if oldPoller == nil && enabled {
+	oldReady := r.dataReady.Load()
+	if oldReady == nil && enabled {
 		backgroundCtx := headers.SetCallerInfo(
 			context.Background(),
 			headers.SystemBackgroundCallerInfo,
 		)
-		newPoller := goro.NewHandle(backgroundCtx)
-		oldPoller = r.refreshPoller.Swap(newPoller)
-		if oldPoller == nil {
-			newPoller.Go(r.refreshEndpointsLoop)
+		newReady := &dataReady{
+			refresh: goro.NewHandle(backgroundCtx),
+			ready:   make(chan struct{}),
 		}
-	} else if oldPoller != nil && !enabled {
-		oldPoller = r.refreshPoller.Swap(nil)
-		if oldPoller != nil {
-			oldPoller.Cancel()
-			<-oldPoller.Done()
+		if r.dataReady.CompareAndSwap(oldReady, newReady) {
+			newReady.refresh.Go(func(ctx context.Context) error {
+				return r.refreshEndpointsLoop(ctx, newReady)
+			})
+		}
+	} else if oldReady != nil && !enabled {
+		if r.dataReady.CompareAndSwap(oldReady, nil) {
+			oldReady.refresh.Cancel()
+			<-oldReady.refresh.Done()
+			// If oldReady.ready was not already closed here, callers blocked in waitUntilInitialized
+			// will block indefinitely (until context timeout). If we wanted to wake them up, we
+			// could close ready here, but we would need to use a sync.Once to avoid closing it
+			// twice. Then waitUntilInitialized would need to reload r.dataReady to check that the
+			// wakeup was due to data being ready rather than this close.
 		}
 	}
 }
@@ -206,15 +219,19 @@ func (r *EndpointRegistryImpl) GetByID(ctx context.Context, id string) (*persist
 }
 
 func (r *EndpointRegistryImpl) waitUntilInitialized(ctx context.Context) error {
+	dataReady := r.dataReady.Load()
+	if dataReady == nil {
+		return ErrNexusDisabled
+	}
 	select {
-	case <-r.dataReady:
+	case <-dataReady.ready:
 		return nil
 	case <-ctx.Done():
 		return ctx.Err()
 	}
 }
 
-func (r *EndpointRegistryImpl) refreshEndpointsLoop(ctx context.Context) error {
+func (r *EndpointRegistryImpl) refreshEndpointsLoop(ctx context.Context, dataReady *dataReady) error {
 	hasLoadedEndpointData := false
 
 	for ctx.Err() == nil {
@@ -226,7 +243,9 @@ func (r *EndpointRegistryImpl) refreshEndpointsLoop(ctx context.Context) error {
 		err := backoff.ThrottleRetryContext(ctx, r.loadEndpoints, r.config.refreshRetryPolicy, nil)
 		if err == nil {
 			hasLoadedEndpointData = true
-			close(r.dataReady)
+			// Note: do not reload r.dataReady here, use value from argument to ensure that
+			// each channel is closed no more than once.
+			close(dataReady.ready)
 		}
 	} else {
 		// Endpoints have previously been loaded, so just keep them up to date with long poll requests to
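The shape of the fix, sketched below as standalone Go. This is an illustration with simplified, hypothetical names (`generation`, `registry`, `refreshLoop`), using a context.CancelFunc and a done channel where the real code uses Temporal's goro.Handle, and a stub load step in place of loadEndpoints. Every enable installs a fresh generation via CompareAndSwap, so the refresh goroutine only ever closes the ready channel it owns.

```go
package nexus

import (
	"context"
	"sync/atomic"
)

// generation bundles a refresh goroutine's lifetime with the ready channel it
// owns; a fresh generation per enable means each channel is closed at most once.
type generation struct {
	cancel context.CancelFunc
	done   chan struct{} // closed when the refresh goroutine exits
	ready  chan struct{} // closed once data has first loaded
}

type registry struct {
	gen atomic.Pointer[generation]
}

func (r *registry) setEnabled(enabled bool) {
	old := r.gen.Load()
	if old == nil && enabled {
		ctx, cancel := context.WithCancel(context.Background())
		g := &generation{cancel: cancel, done: make(chan struct{}), ready: make(chan struct{})}
		if r.gen.CompareAndSwap(nil, g) {
			go func() {
				defer close(g.done)
				r.refreshLoop(ctx, g)
			}()
		} else {
			cancel() // lost the race; another caller already enabled
		}
	} else if old != nil && !enabled {
		if r.gen.CompareAndSwap(old, nil) {
			old.cancel()
			<-old.done // wait for the goroutine to stop, as the real code does
		}
	}
}

func (r *registry) refreshLoop(ctx context.Context, g *generation) {
	// ... load data here; on first success, signal readiness.
	// Closing g.ready (never a reloaded pointer) is the key invariant:
	// this goroutine owns exactly this generation's channel.
	close(g.ready)
	<-ctx.Done()
}
```

Disable swaps the pointer to nil and waits for the old goroutine to exit; a later enable allocates a brand-new channel, so no channel can ever be closed twice.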
80 changes: 80 additions & 0 deletions common/nexus/endpoint_registry_test.go
@@ -26,6 +26,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"sync"
 	"testing"
 	"time"
 
@@ -177,6 +178,85 @@ func TestInitializationFallback(t *testing.T) {
 	assert.Equal(t, int64(1), reg.tableVersion)
 }
 
+func TestEnableDisableEnable(t *testing.T) {
+	t.Parallel()
+
+	testEntry := newEndpointEntry(t.Name())
+	mocks := newTestMocks(t)
+
+	mocks.config.refreshMinWait = dynamicconfig.GetDurationPropertyFn(time.Millisecond)
+	var callback func(bool) // capture callback to call later
+	mocks.config.refreshEnabled = func(cb func(bool)) (bool, func()) {
+		callback = cb
+		return false, func() {}
+	}
+
+	// start disabled
+	reg := NewEndpointRegistry(mocks.config, mocks.matchingClient, mocks.persistence, log.NewNoopLogger(), metrics.NoopMetricsHandler)
+	reg.StartLifecycle()
+	defer reg.StopLifecycle()
+
+	// check waitUntilInitialized
+	quickCtx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+	defer cancel()
+	require.ErrorIs(t, reg.waitUntilInitialized(quickCtx), ErrNexusDisabled)
+
+	// mocks for initial load
+	inLongPoll := make(chan struct{})
+	closeOnce := sync.OnceFunc(func() { close(inLongPoll) })
+	mocks.matchingClient.EXPECT().ListNexusEndpoints(gomock.Any(), gomock.Any()).Return(&matchingservice.ListNexusEndpointsResponse{
+		Entries:       []*persistencepb.NexusEndpointEntry{testEntry},
+		TableVersion:  1,
+		NextPageToken: nil,
+	}, nil)
+	mocks.matchingClient.EXPECT().ListNexusEndpoints(gomock.Any(), &matchingservice.ListNexusEndpointsRequest{
+		PageSize:              int32(100),
+		LastKnownTableVersion: int64(1),
+		Wait:                  true,
+	}).DoAndReturn(func(context.Context, *matchingservice.ListNexusEndpointsRequest, ...interface{}) (*matchingservice.ListNexusEndpointsResponse, error) {
+		closeOnce()
+		time.Sleep(100 * time.Millisecond)
+		return &matchingservice.ListNexusEndpointsResponse{TableVersion: int64(1)}, nil
+	})
+
+	// enable
+	callback(true)
+	<-inLongPoll
+
+	// check waitUntilInitialized
+	quickCtx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	require.NoError(t, reg.waitUntilInitialized(quickCtx))
+
+	// now disable
+	callback(false)
+
+	quickCtx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	require.ErrorIs(t, reg.waitUntilInitialized(quickCtx), ErrNexusDisabled)
+
+	// enable again, should not crash
+
+	inLongPoll = make(chan struct{})
+	closeOnce = sync.OnceFunc(func() { close(inLongPoll) })
+	mocks.matchingClient.EXPECT().ListNexusEndpoints(gomock.Any(), gomock.Any()).Return(&matchingservice.ListNexusEndpointsResponse{
+		Entries:       []*persistencepb.NexusEndpointEntry{testEntry},
+		TableVersion:  1,
+		NextPageToken: nil,
+	}, nil)
+	mocks.matchingClient.EXPECT().ListNexusEndpoints(gomock.Any(), &matchingservice.ListNexusEndpointsRequest{
+		PageSize:              int32(100),
+		LastKnownTableVersion: int64(1),
+		Wait:                  true,
+	}).DoAndReturn(func(context.Context, *matchingservice.ListNexusEndpointsRequest, ...interface{}) (*matchingservice.ListNexusEndpointsResponse, error) {
+		closeOnce()
+		time.Sleep(100 * time.Millisecond)
+		return &matchingservice.ListNexusEndpointsResponse{TableVersion: int64(1)}, nil
+	})
+	callback(true)
+	<-inLongPoll
+}
+
 func TestTableVersionErrorResetsMatchingPagination(t *testing.T) {
 	t.Parallel()
 
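One detail of the test worth calling out: it closes its `inLongPoll` signal channel through sync.OnceFunc (Go 1.21+), which makes the close idempotent. A standalone sketch of that pattern (my illustration, not code from this commit):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	inLongPoll := make(chan struct{})
	signal := sync.OnceFunc(func() { close(inLongPoll) })

	signal() // first call closes the channel
	signal() // further calls are no-ops: no double-close panic

	<-inLongPoll
	fmt.Println("signaled exactly once")
}
```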
