Skip to content

Commit

Permalink
Prep for upcoming backend.Key type change (#46675)
Browse files Browse the repository at this point in the history
Adds a (Key) IsZero method to determine if a key is populated. Some
Backend implementations today validate keys before operations via
a `len(key) == 0` check, however, that will no longer work once the
key is migrated.

Starts migrating the backend.LockConfiguration away from prepopulating
the lock name to passing in a list of components. There were a few
locks constructing a portion of the name manually, which will not
work when the key type is changed.
  • Loading branch information
rosstimothy committed Oct 15, 2024
1 parent ce0fe10 commit 2afbc09
Show file tree
Hide file tree
Showing 14 changed files with 92 additions and 69 deletions.
6 changes: 3 additions & 3 deletions lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,9 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e
if err := backend.RunWhileLocked(ctx,
backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: cfg.Backend,
LockName: domainName,
TTL: 30 * time.Second,
Backend: cfg.Backend,
LockNameComponents: []string{domainName},
TTL: 30 * time.Second,
},
RefreshLockInterval: 20 * time.Second,
}, func(ctx context.Context) error {
Expand Down
5 changes: 5 additions & 0 deletions lib/backend/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ func (k Key) String() string {
return string(k)
}

// IsZero reports whether k represents the zero key.
func (k Key) IsZero() bool {
return len(k) == 0
}

// HasPrefix reports whether the key begins with prefix.
func (k Key) HasPrefix(prefix Key) bool {
return bytes.HasPrefix(k, prefix)
Expand Down
7 changes: 7 additions & 0 deletions lib/backend/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,10 @@ func TestKeyCompare(t *testing.T) {
})
}
}

func TestKeyIsZero(t *testing.T) {
assert.True(t, backend.Key{}.IsZero())
assert.True(t, backend.NewKey().IsZero())
assert.False(t, backend.NewKey("a", "b").IsZero())
assert.False(t, backend.ExactKey("a", "b").IsZero())
}
19 changes: 15 additions & 4 deletions lib/backend/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,14 @@ func randomID() ([]byte, error) {
}

type LockConfiguration struct {
Backend Backend
// Backend to create the lock in.
Backend Backend
// LockName the precomputed lock name.
// TODO(tross) DELETE WHEN teleport.e is updated to use LockNameComponents.
LockName string
// LockNameComponents are subcomponents to be used when constructing
// the lock name.
LockNameComponents []string
// TTL defines when lock will be released automatically
TTL time.Duration
// RetryInterval defines interval which is used to retry locking after
Expand All @@ -63,9 +69,14 @@ func (l *LockConfiguration) CheckAndSetDefaults() error {
if l.Backend == nil {
return trace.BadParameter("missing Backend")
}
if l.LockName == "" {
return trace.BadParameter("missing LockName")
if l.LockName == "" && len(l.LockNameComponents) == 0 {
return trace.BadParameter("missing LockName/LockNameComponents")
}

if len(l.LockNameComponents) == 0 {
l.LockNameComponents = []string{l.LockName}
}

if l.TTL == 0 {
return trace.BadParameter("missing TTL")
}
Expand All @@ -81,7 +92,7 @@ func AcquireLock(ctx context.Context, cfg LockConfiguration) (Lock, error) {
if err != nil {
return Lock{}, trace.Wrap(err)
}
key := lockKey(cfg.LockName)
key := lockKey(cfg.LockNameComponents...)
id, err := randomID()
if err != nil {
return Lock{}, trace.Wrap(err)
Expand Down
51 changes: 26 additions & 25 deletions lib/backend/lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,30 @@ func TestLockConfiguration_CheckAndSetDefaults(t *testing.T) {
{
name: "minimum valid",
in: LockConfiguration{
Backend: mockBackend{},
LockName: "lock",
TTL: 30 * time.Second,
Backend: mockBackend{},
LockNameComponents: []string{"lock"},
TTL: 30 * time.Second,
},
want: LockConfiguration{
Backend: mockBackend{},
LockName: "lock",
TTL: 30 * time.Second,
RetryInterval: 250 * time.Millisecond,
Backend: mockBackend{},
LockNameComponents: []string{"lock"},
TTL: 30 * time.Second,
RetryInterval: 250 * time.Millisecond,
},
},
{
name: "set RetryAcquireLockTimeout",
in: LockConfiguration{
Backend: mockBackend{},
LockName: "lock",
TTL: 30 * time.Second,
RetryInterval: 10 * time.Second,
Backend: mockBackend{},
LockNameComponents: []string{"lock"},
TTL: 30 * time.Second,
RetryInterval: 10 * time.Second,
},
want: LockConfiguration{
Backend: mockBackend{},
LockName: "lock",
TTL: 30 * time.Second,
RetryInterval: 10 * time.Second,
Backend: mockBackend{},
LockNameComponents: []string{"lock"},
TTL: 30 * time.Second,
RetryInterval: 10 * time.Second,
},
},
{
Expand All @@ -95,9 +95,9 @@ func TestLockConfiguration_CheckAndSetDefaults(t *testing.T) {
{
name: "missing TTL",
in: LockConfiguration{
Backend: mockBackend{},
LockName: "lock",
TTL: 0,
Backend: mockBackend{},
LockNameComponents: []string{"lock"},
TTL: 0,
},
wantErr: "missing TTL",
},
Expand All @@ -124,9 +124,9 @@ func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) {
ttl := 1 * time.Minute
minimumValidConfig := RunWhileLockedConfig{
LockConfiguration: LockConfiguration{
Backend: mockBackend{},
LockName: lockName,
TTL: ttl,
Backend: mockBackend{},
LockNameComponents: []string{lockName},
TTL: ttl,
},
}
tests := []struct {
Expand All @@ -142,10 +142,10 @@ func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) {
},
want: RunWhileLockedConfig{
LockConfiguration: LockConfiguration{
Backend: mockBackend{},
LockName: lockName,
TTL: ttl,
RetryInterval: 250 * time.Millisecond,
Backend: mockBackend{},
LockNameComponents: []string{lockName},
TTL: ttl,
RetryInterval: 250 * time.Millisecond,
},
ReleaseCtxTimeout: time.Second,
// defaults to halft of TTL.
Expand All @@ -157,6 +157,7 @@ func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) {
input: func() RunWhileLockedConfig {
cfg := minimumValidConfig
cfg.LockName = ""
cfg.LockNameComponents = nil
return cfg
},
wantErr: "missing LockName",
Expand Down
14 changes: 7 additions & 7 deletions lib/backend/test/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ func testLocking(t *testing.T, newBackend Constructor) {
defer requireNoAsyncErrors()

// Given a lock named `tok1` on the backend...
lock, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl})
lock, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl})
require.NoError(t, err)

// When I asynchronously release the lock...
Expand All @@ -848,7 +848,7 @@ func testLocking(t *testing.T, newBackend Constructor) {
}()

// ...and simultaneously attempt to create a new lock with the same name
lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl})
lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl})

// expect that the asynchronous Release() has executed - we're using the
// change in the value of the marker value as a proxy for the Release().
Expand All @@ -860,7 +860,7 @@ func testLocking(t *testing.T, newBackend Constructor) {
require.NoError(t, lock.Release(ctx, uut))

// Given a lock with the same name as previously-existing, manually-released lock
lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl})
lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl})
require.NoError(t, err)
atomic.StoreInt32(&marker, 7)

Expand All @@ -875,7 +875,7 @@ func testLocking(t *testing.T, newBackend Constructor) {
}()

// ...and simultaneously try to acquire another lock with the same name
lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl})
lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl})

// expect that the asynchronous Release() has executed - we're using the
// change in the value of the marker value as a proxy for the call to
Expand All @@ -889,9 +889,9 @@ func testLocking(t *testing.T, newBackend Constructor) {

// Given a pair of locks named `tok1` and `tok2`
y := int32(0)
lock1, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl})
lock1, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl})
require.NoError(t, err)
lock2, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok2, TTL: ttl})
lock2, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok2}, TTL: ttl})
require.NoError(t, err)

// When I asynchronously release the locks...
Expand All @@ -908,7 +908,7 @@ func testLocking(t *testing.T, newBackend Constructor) {
}
}()

lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl})
lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl})
require.NoError(t, err)
require.Equal(t, int32(15), atomic.LoadInt32(&y))
require.NoError(t, lock.Release(ctx, uut))
Expand Down
4 changes: 2 additions & 2 deletions lib/events/athena/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ func (c *consumer) runContinuouslyOnSingleAuth(ctx context.Context, eventsProces
default:
err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: c.backend,
LockName: "athena_lock",
Backend: c.backend,
LockNameComponents: []string{"athena_lock"},
// TTL is higher then batchMaxInterval because we want to optimize
// for low backend writes.
TTL: 5 * c.batchMaxInterval,
Expand Down
6 changes: 3 additions & 3 deletions lib/services/local/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,9 @@ func (s *AccessService) DeleteAllLocks(ctx context.Context) error {
func (s *AccessService) ReplaceRemoteLocks(ctx context.Context, clusterName string, newRemoteLocks []types.Lock) error {
return backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: s.Backend,
LockName: "ReplaceRemoteLocks/" + clusterName,
TTL: time.Minute,
Backend: s.Backend,
LockNameComponents: []string{"ReplaceRemoteLocks", clusterName},
TTL: time.Minute,
},
}, func(ctx context.Context) error {
remoteLocksKey := backend.ExactKey(locksPrefix, clusterName)
Expand Down
9 changes: 4 additions & 5 deletions lib/services/local/access_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package local

import (
"context"
"strings"
"time"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -196,7 +195,7 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces

var err error
if feature := modules.GetModules().Features(); !feature.IGSEnabled() {
err = a.service.RunWhileLocked(ctx, "createAccessListLimitLock", accessListLockTTL, func(ctx context.Context, _ backend.Backend) error {
err = a.service.RunWhileLocked(ctx, []string{"createAccessListLimitLock"}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error {
if err := a.VerifyAccessListCreateLimit(ctx, accessList.GetName()); err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -453,7 +452,7 @@ func (a *AccessListService) UpsertAccessListWithMembers(ctx context.Context, acc

var err error
if feature := modules.GetModules().Features(); !feature.IGSEnabled() {
err = a.service.RunWhileLocked(ctx, "createAccessListWithMembersLimitLock", accessListLockTTL, func(ctx context.Context, _ backend.Backend) error {
err = a.service.RunWhileLocked(ctx, []string{"createAccessListWithMembersLimitLock"}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error {
if err := a.VerifyAccessListCreateLimit(ctx, accessList.GetName()); err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -638,8 +637,8 @@ func (a *AccessListService) DeleteAllAccessListReviews(ctx context.Context) erro
return trace.Wrap(a.reviewService.DeleteAllResources(ctx))
}

func lockName(accessListName string) string {
return strings.Join([]string{"access_list", accessListName}, string(backend.Separator))
func lockName(accessListName string) []string {
return []string{"access_list", accessListName}
}

// VerifyAccessListCreateLimit ensures creating access list is limited to no more than 1 (updating is allowed).
Expand Down
18 changes: 9 additions & 9 deletions lib/services/local/externalauditstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ func (s *ExternalAuditStorageService) CreateDraftExternalAuditStorage(ctx contex
var lease *backend.Lease
err = backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: s.backend,
LockName: externalAuditStorageLockName,
TTL: externalAuditStorageLockTTL,
Backend: s.backend,
LockNameComponents: []string{externalAuditStorageLockName},
TTL: externalAuditStorageLockTTL,
},
}, func(ctx context.Context) error {
// Check that the referenced AWS OIDC integration actually exists.
Expand Down Expand Up @@ -122,9 +122,9 @@ func (s *ExternalAuditStorageService) UpsertDraftExternalAuditStorage(ctx contex
var lease *backend.Lease
err = backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: s.backend,
LockName: externalAuditStorageLockName,
TTL: externalAuditStorageLockTTL,
Backend: s.backend,
LockNameComponents: []string{externalAuditStorageLockName},
TTL: externalAuditStorageLockTTL,
},
}, func(ctx context.Context) error {
// Check that the referenced AWS OIDC integration actually exists.
Expand Down Expand Up @@ -185,9 +185,9 @@ func (s *ExternalAuditStorageService) GetClusterExternalAuditStorage(ctx context
func (s *ExternalAuditStorageService) PromoteToClusterExternalAuditStorage(ctx context.Context) error {
err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: s.backend,
LockName: externalAuditStorageLockName,
TTL: externalAuditStorageLockTTL,
Backend: s.backend,
LockNameComponents: []string{externalAuditStorageLockName},
TTL: externalAuditStorageLockTTL,
},
}, func(ctx context.Context) error {
draft, err := s.GetDraftExternalAuditStorage(ctx)
Expand Down
10 changes: 5 additions & 5 deletions lib/services/local/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,14 +421,14 @@ func (s *Service[T]) MakeKey(name string) backend.Key {
}

// RunWhileLocked will run the given function in a backend lock. This is a wrapper around the backend.RunWhileLocked function.
func (s *Service[T]) RunWhileLocked(ctx context.Context, lockName string, ttl time.Duration, fn func(context.Context, backend.Backend) error) error {
func (s *Service[T]) RunWhileLocked(ctx context.Context, lockNameComponents []string, ttl time.Duration, fn func(context.Context, backend.Backend) error) error {
return trace.Wrap(backend.RunWhileLocked(ctx,
backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: s.backend,
LockName: lockName,
TTL: ttl,
RetryInterval: s.runWhileLockedRetryInterval,
Backend: s.backend,
LockNameComponents: lockNameComponents,
TTL: ttl,
RetryInterval: s.runWhileLockedRetryInterval,
},
}, func(ctx context.Context) error {
return fn(ctx, s.backend)
Expand Down
2 changes: 1 addition & 1 deletion lib/services/local/generic/generic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func TestGenericCRUD(t *testing.T) {
require.ErrorIs(t, err, trace.NotFound(`generic resource "doesnotexist" doesn't exist`))

// Test running while locked.
err = service.RunWhileLocked(ctx, "test-lock", time.Second*5, func(ctx context.Context, backend backend.Backend) error {
err = service.RunWhileLocked(ctx, []string{"test-lock"}, time.Second*5, func(ctx context.Context, backend backend.Backend) error {
item, err := backend.Get(ctx, service.MakeKey(r1.GetName()))
require.NoError(t, err)

Expand Down
6 changes: 3 additions & 3 deletions lib/services/local/integrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ func (s *IntegrationsService) DeleteIntegration(ctx context.Context, name string
// so that no new EAS integrations can be concurrently created.
err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: s.backend,
LockName: externalAuditStorageLockName,
TTL: externalAuditStorageLockTTL,
Backend: s.backend,
LockNameComponents: []string{externalAuditStorageLockName},
TTL: externalAuditStorageLockTTL,
},
}, func(ctx context.Context) error {
if err := notReferencedByEAS(ctx, s.backend, name); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions lib/services/local/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
return trace.Wrap(err)
}

return trace.Wrap(s.svc.RunWhileLocked(ctx, samlIDPServiceProviderModifyLock, samlIDPServiceProviderModifyLockTTL,
return trace.Wrap(s.svc.RunWhileLocked(ctx, []string{samlIDPServiceProviderModifyLock}, samlIDPServiceProviderModifyLockTTL,
func(ctx context.Context, backend backend.Backend) error {
if err := s.ensureEntityIDIsUnique(ctx, sp); err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -181,7 +181,7 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context
return trace.Wrap(err)
}

return trace.Wrap(s.svc.RunWhileLocked(ctx, samlIDPServiceProviderModifyLock, samlIDPServiceProviderModifyLockTTL,
return trace.Wrap(s.svc.RunWhileLocked(ctx, []string{samlIDPServiceProviderModifyLock}, samlIDPServiceProviderModifyLockTTL,
func(ctx context.Context, backend backend.Backend) error {
if err := s.ensureEntityIDIsUnique(ctx, sp); err != nil {
return trace.Wrap(err)
Expand Down

0 comments on commit 2afbc09

Please sign in to comment.