Skip to content

Commit

Permalink
Move trusted cluster storage out of presence
Browse files Browse the repository at this point in the history
This groups the trusted cluster and cert authority storage into
the same service. While on its own this doesn't change much, it
unlocks the ability to fix racy behavior described in
#36400.
  • Loading branch information
rosstimothy committed Jul 3, 2024
1 parent 268109d commit abafb51
Show file tree
Hide file tree
Showing 15 changed files with 904 additions and 874 deletions.
21 changes: 7 additions & 14 deletions lib/auth/migration/0001_db_ca.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package migration

import (
"context"
"log/slog"

"github.com/gravitational/trace"

Expand All @@ -33,9 +34,8 @@ import (
// Database CA for all existing clusters that do not already
// have one. Introduced in v10.
type createDBAuthority struct {
trustServiceFn func(b backend.Backend) services.Trust
configServiceFn func(b backend.Backend) (services.ClusterConfiguration, error)
presenceServiceFn func(b backend.Backend) services.Presence
trustServiceFn func(b backend.Backend) services.Trust
configServiceFn func(b backend.Backend) (services.ClusterConfiguration, error)
}

func (d createDBAuthority) Version() int64 {
Expand Down Expand Up @@ -66,25 +66,18 @@ func (d createDBAuthority) Up(ctx context.Context, b backend.Backend) error {
}
}

if d.presenceServiceFn == nil {
d.presenceServiceFn = func(b backend.Backend) services.Presence {
return local.NewPresenceService(b)
}
}

trustSvc := d.trustServiceFn(b)
configSvc, err := d.configServiceFn(b)
if err != nil {
return trace.Wrap(err)
}
presenceSvc := d.presenceServiceFn(b)

localClusterName, err := configSvc.GetClusterName()
if err != nil {
return trace.Wrap(err)
}

trustedClusters, err := presenceSvc.GetTrustedClusters(ctx)
trustedClusters, err := trustSvc.GetTrustedClusters(ctx)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -133,7 +126,7 @@ func migrateDBAuthority(ctx context.Context, trustSvc services.Trust, cluster st
// The migration for this cluster can be skipped since
// the new CA already exists.
if err == nil {
log.Debugf("Migrations: cert authority %q already exists.", toType)
slog.DebugContext(ctx, "Migrations: cert authority already exists.", "authority", toType)
return nil
}
if !trace.IsNotFound(err) {
Expand All @@ -157,7 +150,7 @@ func migrateDBAuthority(ctx context.Context, trustSvc services.Trust, cluster st
return trace.Wrap(err)
}

log.Infof("Migrating %s CA for cluster: %s", toType, cluster)
slog.InfoContext(ctx, "Migrating CA", "authority", toType, "cluster", cluster)

existingCAV2, ok := existingCA.(*types.CertAuthorityV2)
if !ok {
Expand All @@ -175,7 +168,7 @@ func migrateDBAuthority(ctx context.Context, trustSvc services.Trust, cluster st

err = trustSvc.CreateCertAuthority(ctx, newCA)
if trace.IsAlreadyExists(err) {
log.Warnf("%s CA has already been created by a different Auth instance", toType)
slog.WarnContext(ctx, "CA has already been created by a different Auth instance", "authority", toType)
return nil
}
return trace.Wrap(err)
Expand Down
27 changes: 7 additions & 20 deletions lib/auth/migration/0001_db_ca_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ func TestDBAuthorityUp(t *testing.T) {
cases := []struct {
name string
fakeTrust *fakeTrust
fakePresence fakePresence
assertion require.ErrorAssertionFunc
validateFunc func(t *testing.T, created []types.CertAuthority)
}{
Expand All @@ -194,8 +193,6 @@ func TestDBAuthorityUp(t *testing.T) {
leaf1DB: fakeCA,
leaf2Host: fakeCA,
},
},
fakePresence: fakePresence{
clusters: []types.TrustedCluster{
&types.TrustedClusterV2{
Kind: types.KindTrustedCluster,
Expand Down Expand Up @@ -227,8 +224,6 @@ func TestDBAuthorityUp(t *testing.T) {
leaf2DB: fakeCA,
leaf1DB: fakeCA,
},
},
fakePresence: fakePresence{
clusters: []types.TrustedCluster{
&types.TrustedClusterV2{
Kind: types.KindTrustedCluster,
Expand All @@ -255,7 +250,6 @@ func TestDBAuthorityUp(t *testing.T) {
b, err := memory.New(memory.Config{EventsOff: true})
require.NoError(t, err)

test.fakePresence.Presence = local.NewPresenceService(b)
test.fakeTrust.Trust = local.NewCAService(b)

migration := createDBAuthority{
Expand All @@ -272,9 +266,6 @@ func TestDBAuthorityUp(t *testing.T) {
clusterName: clusterName("root"),
}, nil
},
presenceServiceFn: func(b backend.Backend) services.Presence {
return test.fakePresence
},
}

test.assertion(t, migration.Up(context.Background(), b))
Expand All @@ -292,22 +283,14 @@ func (f fakeConfig) GetClusterName(opts ...services.MarshalOption) (types.Cluste
return f.clusterName, nil
}

type fakePresence struct {
services.Presence
clusters []types.TrustedCluster
}

func (f fakePresence) GetTrustedClusters(ctx context.Context) ([]types.TrustedCluster, error) {
return f.clusters, nil
}

type fakeTrust struct {
services.Trust

authorities map[types.CertAuthID]types.CertAuthority

mu sync.Mutex
created []types.CertAuthority
clusters []types.TrustedCluster
mu sync.Mutex
created []types.CertAuthority
}

func (f *fakeTrust) casCreated() []types.CertAuthority {
Expand Down Expand Up @@ -336,3 +319,7 @@ func (f *fakeTrust) GetCertAuthority(ctx context.Context, id types.CertAuthID, l

return ca, nil
}

func (f *fakeTrust) GetTrustedClusters(ctx context.Context) ([]types.TrustedCluster, error) {
return f.clusters, nil
}
6 changes: 3 additions & 3 deletions lib/auth/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1403,8 +1403,8 @@ func TestTunnelConnectionsCRUD(t *testing.T) {
require.NoError(t, err)

suite := &suite.ServicesTestSuite{
PresenceS: clt,
Clock: clockwork.NewFakeClock(),
TrustS: clt,
Clock: clockwork.NewFakeClock(),
}
suite.TunnelConnectionsCRUD(t)
}
Expand Down Expand Up @@ -4102,7 +4102,7 @@ func TestEvents(t *testing.T) {
LocalConfigS: testSrv.Auth(),
EventsS: clt,
PresenceS: testSrv.Auth(),
CAS: testSrv.Auth(),
TrustS: testSrv.Auth(),
ProvisioningS: clt,
Access: clt,
UsersS: clt,
Expand Down
2 changes: 1 addition & 1 deletion lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2110,7 +2110,7 @@ func (c *Cache) GetRemoteCluster(ctx context.Context, clusterName string) (types
rg.Release()
// fallback is sane because this method is never used
// in construction of derivative caches.
if rc, err := c.Config.Presence.GetRemoteCluster(ctx, clusterName); err == nil {
if rc, err := c.Config.Trust.GetRemoteCluster(ctx, clusterName); err == nil {
return rc, nil
}
}
Expand Down
16 changes: 8 additions & 8 deletions lib/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1626,16 +1626,16 @@ func TestTunnelConnections(t *testing.T) {
LastHeartbeat: time.Now().UTC(),
})
},
create: modifyNoContext(p.presenceS.UpsertTunnelConnection),
create: modifyNoContext(p.trustS.UpsertTunnelConnection),
list: func(ctx context.Context) ([]types.TunnelConnection, error) {
return p.presenceS.GetTunnelConnections(clusterName)
return p.trustS.GetTunnelConnections(clusterName)
},
cacheList: func(ctx context.Context) ([]types.TunnelConnection, error) {
return p.cache.GetTunnelConnections(clusterName)
},
update: modifyNoContext(p.presenceS.UpsertTunnelConnection),
update: modifyNoContext(p.trustS.UpsertTunnelConnection),
deleteAll: func(ctx context.Context) error {
return p.presenceS.DeleteAllTunnelConnections()
return p.trustS.DeleteAllTunnelConnections()
},
})
}
Expand Down Expand Up @@ -1730,11 +1730,11 @@ func TestRemoteClusters(t *testing.T) {
return types.NewRemoteCluster(name)
},
create: func(ctx context.Context, rc types.RemoteCluster) error {
_, err := p.presenceS.CreateRemoteCluster(ctx, rc)
_, err := p.trustS.CreateRemoteCluster(ctx, rc)
return err
},
list: func(ctx context.Context) ([]types.RemoteCluster, error) {
return p.presenceS.GetRemoteClusters(ctx)
return p.trustS.GetRemoteClusters(ctx)
},
cacheGet: func(ctx context.Context, name string) (types.RemoteCluster, error) {
return p.cache.GetRemoteCluster(ctx, name)
Expand All @@ -1743,11 +1743,11 @@ func TestRemoteClusters(t *testing.T) {
return p.cache.GetRemoteClusters(ctx)
},
update: func(ctx context.Context, rc types.RemoteCluster) error {
_, err := p.presenceS.UpdateRemoteCluster(ctx, rc)
_, err := p.trustS.UpdateRemoteCluster(ctx, rc)
return err
},
deleteAll: func(ctx context.Context) error {
return p.presenceS.DeleteAllRemoteClusters(ctx)
return p.trustS.DeleteAllRemoteClusters(ctx)
},
})
}
Expand Down
26 changes: 13 additions & 13 deletions lib/cache/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,28 +809,28 @@ var _ executor[types.AccessRequest, noReader] = accessRequestExecutor{}
type tunnelConnectionExecutor struct{}

func (tunnelConnectionExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.TunnelConnection, error) {
return cache.Presence.GetAllTunnelConnections()
return cache.Trust.GetAllTunnelConnections()
}

func (tunnelConnectionExecutor) upsert(ctx context.Context, cache *Cache, resource types.TunnelConnection) error {
return cache.presenceCache.UpsertTunnelConnection(resource)
return cache.trustCache.UpsertTunnelConnection(resource)
}

func (tunnelConnectionExecutor) deleteAll(ctx context.Context, cache *Cache) error {
return cache.presenceCache.DeleteAllTunnelConnections()
return cache.trustCache.DeleteAllTunnelConnections()
}

func (tunnelConnectionExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
return cache.presenceCache.DeleteTunnelConnection(resource.GetSubKind(), resource.GetName())
return cache.trustCache.DeleteTunnelConnection(resource.GetSubKind(), resource.GetName())
}

func (tunnelConnectionExecutor) isSingleton() bool { return false }

func (tunnelConnectionExecutor) getReader(cache *Cache, cacheOK bool) tunnelConnectionGetter {
if cacheOK {
return cache.presenceCache
return cache.trustCache
}
return cache.Config.Presence
return cache.Config.Trust
}

type tunnelConnectionGetter interface {
Expand All @@ -843,36 +843,36 @@ var _ executor[types.TunnelConnection, tunnelConnectionGetter] = tunnelConnectio
type remoteClusterExecutor struct{}

func (remoteClusterExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.RemoteCluster, error) {
return cache.Presence.GetRemoteClusters(ctx)
return cache.Trust.GetRemoteClusters(ctx)
}

func (remoteClusterExecutor) upsert(ctx context.Context, cache *Cache, resource types.RemoteCluster) error {
err := cache.presenceCache.DeleteRemoteCluster(ctx, resource.GetName())
err := cache.trustCache.DeleteRemoteCluster(ctx, resource.GetName())
if err != nil {
if !trace.IsNotFound(err) {
cache.Logger.WithError(err).Warnf("Failed to delete remote cluster %v.", resource.GetName())
return trace.Wrap(err)
}
}
_, err = cache.presenceCache.CreateRemoteCluster(ctx, resource)
_, err = cache.trustCache.CreateRemoteCluster(ctx, resource)
return trace.Wrap(err)
}

func (remoteClusterExecutor) deleteAll(ctx context.Context, cache *Cache) error {
return cache.presenceCache.DeleteAllRemoteClusters(ctx)
return cache.trustCache.DeleteAllRemoteClusters(ctx)
}

func (remoteClusterExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
return cache.presenceCache.DeleteRemoteCluster(ctx, resource.GetName())
return cache.trustCache.DeleteRemoteCluster(ctx, resource.GetName())
}

func (remoteClusterExecutor) isSingleton() bool { return false }

func (remoteClusterExecutor) getReader(cache *Cache, cacheOK bool) remoteClusterGetter {
if cacheOK {
return cache.presenceCache
return cache.trustCache
}
return cache.Config.Presence
return cache.Config.Trust
}

type remoteClusterGetter interface {
Expand Down
Loading

0 comments on commit abafb51

Please sign in to comment.