From 63bcaaf900c35da568e71dfd153684dbc5067cb0 Mon Sep 17 00:00:00 2001 From: Forrest <30576607+fspmarshall@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:22:20 -0700 Subject: [PATCH] switch trusted/remote cluster management to atomic write (#48094) --- api/types/trustedcluster.go | 8 + lib/auth/trustedcluster.go | 294 +++++++---------- lib/auth/trustedcluster_test.go | 17 +- lib/services/local/trust.go | 529 ++++++++++++++++++++++++++----- lib/services/local/trust_test.go | 330 ++++++++++++++++++- lib/services/trust.go | 20 ++ 6 files changed, 914 insertions(+), 284 deletions(-) diff --git a/api/types/trustedcluster.go b/api/types/trustedcluster.go index 7e233c864c826..27d8129f70cfe 100644 --- a/api/types/trustedcluster.go +++ b/api/types/trustedcluster.go @@ -22,6 +22,8 @@ import ( "time" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/utils" ) // TrustedCluster holds information needed for a cluster that can not be directly @@ -60,6 +62,8 @@ type TrustedCluster interface { SetReverseTunnelAddress(string) // CanChangeStateTo checks the TrustedCluster can transform into another. CanChangeStateTo(TrustedCluster) error + // Clone returns a deep copy of the TrustedCluster. + Clone() TrustedCluster } // NewTrustedCluster is a convenience way to create a TrustedCluster resource. @@ -259,6 +263,10 @@ func (c *TrustedClusterV2) CanChangeStateTo(t TrustedCluster) error { return nil } +func (c *TrustedClusterV2) Clone() TrustedCluster { + return utils.CloneProtoMsg(c) +} + // String represents a human readable version of trusted cluster settings. func (c *TrustedClusterV2) String() string { return fmt.Sprintf("TrustedCluster(Enabled=%v,Roles=%v,Token=%v,ProxyAddress=%v,ReverseTunnelAddress=%v)", diff --git a/lib/auth/trustedcluster.go b/lib/auth/trustedcluster.go index bd68f9d10832e..acbc46dc4f281 100644 --- a/lib/auth/trustedcluster.go +++ b/lib/auth/trustedcluster.go @@ -45,129 +45,115 @@ import ( ) // UpsertTrustedCluster creates or toggles a Trusted Cluster relationship. -func (a *Server) UpsertTrustedCluster(ctx context.Context, trustedCluster types.TrustedCluster) (newTrustedCluster types.TrustedCluster, returnErr error) { +func (a *Server) UpsertTrustedCluster(ctx context.Context, tc types.TrustedCluster) (newTrustedCluster types.TrustedCluster, returnErr error) { + // verify that trusted cluster role map does not reference non-existent roles + if err := a.checkLocalRoles(ctx, tc.GetRoleMap()); err != nil { + return nil, trace.Wrap(err) + } + // It is recommended to omit trusted cluster name because the trusted cluster name // is updated to the roots cluster name during the handshake with the root cluster. var existingCluster types.TrustedCluster - if trustedCluster.GetName() != "" { + var cas []types.CertAuthority + if tc.GetName() != "" { var err error - existingCluster, err = a.GetTrustedCluster(ctx, trustedCluster.GetName()) + existingCluster, err = a.GetTrustedCluster(ctx, tc.GetName()) if err != nil && !trace.IsNotFound(err) { return nil, trace.Wrap(err) } } - enable := trustedCluster.GetEnabled() - - // If the trusted cluster already exists in the backend, make sure it's a - // valid state change client is trying to make. - if existingCluster != nil { - if err := existingCluster.CanChangeStateTo(trustedCluster); err != nil { - return nil, trace.Wrap(err) - } + // if there is no existing cluster, switch to the create case + if existingCluster == nil { + return a.createTrustedCluster(ctx, tc) } - logger := log.WithField("trusted_cluster", trustedCluster.GetName()) + if err := existingCluster.CanChangeStateTo(tc); err != nil { + return nil, trace.Wrap(err) + } - // change state - if err := a.checkLocalRoles(ctx, trustedCluster.GetRoleMap()); err != nil { + // always load all current CAs. even if we aren't changing them as part of + // this function, Services.UpdateTrustedCluster will only correctly activate/deactivate + // CAs that are explicitly passed to it. note that we pass in the existing cluster state + // since where CAs are stored depends on the current state of the trusted cluster. + cas, err := a.getCAsForTrustedCluster(ctx, existingCluster) + if err != nil { return nil, trace.Wrap(err) } - // Update role map - if existingCluster != nil && !existingCluster.GetRoleMap().IsEqual(trustedCluster.GetRoleMap()) { - if err := a.UpdateUserCARoleMap(ctx, existingCluster.GetName(), trustedCluster.GetRoleMap(), - existingCluster.GetEnabled()); err != nil { - return nil, trace.Wrap(err) - } + // propagate any role map changes to cas + configureCAsForTrustedCluster(tc, cas) - // Reset previous UserCA role map if this func fails later on - defer func() { - if returnErr != nil { - if err := a.UpdateUserCARoleMap(ctx, trustedCluster.GetName(), existingCluster.GetRoleMap(), - trustedCluster.GetEnabled()); err != nil { - returnErr = trace.NewAggregate(err, returnErr) - } - } - }() - } - // Create or update state - switch { - case existingCluster != nil && enable == true: - if existingCluster.GetEnabled() { - break - } - log.Debugf("Enabling existing Trusted Cluster relationship.") + // state transition is valid, set the expected revision + tc.SetRevision(existingCluster.GetRevision()) - if err := a.activateCertAuthority(ctx, trustedCluster); err != nil { - if trace.IsNotFound(err) { - return nil, trace.BadParameter("enable only supported for Trusted Clusters created with Teleport 2.3 and above") - } - return nil, trace.Wrap(err) - } + revision, err := a.Services.UpdateTrustedCluster(ctx, tc, cas) + if err != nil { + return nil, trace.Wrap(err) + } - if err := a.createReverseTunnel(ctx, trustedCluster); err != nil { - return nil, trace.Wrap(err) - } - case existingCluster != nil && enable == false: - if !existingCluster.GetEnabled() { - break - } - log.Debugf("Disabling existing Trusted Cluster relationship.") + tc.SetRevision(revision) - if err := a.deactivateCertAuthority(ctx, trustedCluster); err != nil { - if trace.IsNotFound(err) { - return nil, trace.BadParameter("enable only supported for Trusted Clusters created with Teleport 2.3 and above") - } - return nil, trace.Wrap(err) - } + if err := a.onTrustedClusterWrite(ctx, tc); err != nil { + return nil, trace.Wrap(err) + } - if err := a.DeleteReverseTunnel(ctx, trustedCluster.GetName()); err != nil { - return nil, trace.Wrap(err) - } - case existingCluster == nil && enable == true: - logger.Info("Creating enabled Trusted Cluster relationship.") + return tc, nil +} - remoteCAs, err := a.establishTrust(ctx, trustedCluster) - if err != nil { - return nil, trace.Wrap(err) - } +func (a *Server) createTrustedCluster(ctx context.Context, tc types.TrustedCluster) (types.TrustedCluster, error) { + remoteCAs, err := a.establishTrust(ctx, tc) + if err != nil { + return nil, trace.Wrap(err) + } - // Force name of the trusted cluster resource - // to be equal to the name of the remote cluster it is connecting to. - trustedCluster.SetName(remoteCAs[0].GetClusterName()) + // Force name to the name of the trusted cluster. + tc.SetName(remoteCAs[0].GetClusterName()) - if err := a.addCertAuthorities(ctx, trustedCluster, remoteCAs); err != nil { - return nil, trace.Wrap(err) - } + // perform some configuration on the remote CAs + configureCAsForTrustedCluster(tc, remoteCAs) - if err := a.createReverseTunnel(ctx, trustedCluster); err != nil { - return nil, trace.Wrap(err) - } + // atomically create trusted cluster and cert authorities + revision, err := a.Services.CreateTrustedCluster(ctx, tc, remoteCAs) + if err != nil { + return nil, trace.Wrap(err) + } - case existingCluster == nil && enable == false: - logger.Info("Creating disabled Trusted Cluster relationship.") + tc.SetRevision(revision) - remoteCAs, err := a.establishTrust(ctx, trustedCluster) - if err != nil { - return nil, trace.Wrap(err) - } + if err := a.onTrustedClusterWrite(ctx, tc); err != nil { + return nil, trace.Wrap(err) + } - // Force name to the name of the trusted cluster. - trustedCluster.SetName(remoteCAs[0].GetClusterName()) + return tc, nil +} - if err := a.addCertAuthorities(ctx, trustedCluster, remoteCAs); err != nil { - return nil, trace.Wrap(err) - } +// configureCAsForTrustedCluster modifies remote CAs for use as trusted cluster CAs. +func configureCAsForTrustedCluster(tc types.TrustedCluster, cas []types.CertAuthority) { + // modify the remote CAs for use as tc cas. + for _, ca := range cas { + // change the name of the remote ca to the name of the trusted cluster. + ca.SetName(tc.GetName()) - if err := a.deactivateCertAuthority(ctx, trustedCluster); err != nil { - return nil, trace.Wrap(err) + // wipe out roles sent from the remote cluster and set roles from the trusted cluster + ca.SetRoles(nil) + if ca.GetType() == types.UserCA { + for _, r := range tc.GetRoles() { + ca.AddRole(r) + } + ca.SetRoleMap(tc.GetRoleMap()) } } +} - tc, err := a.Services.UpsertTrustedCluster(ctx, trustedCluster) - if err != nil { - return nil, trace.Wrap(err) +func (a *Server) onTrustedClusterWrite(ctx context.Context, tc types.TrustedCluster) error { + var cerr error + if tc.GetEnabled() { + cerr = a.createReverseTunnel(ctx, tc) + } else { + if err := a.DeleteReverseTunnel(ctx, tc.GetName()); err != nil && !trace.IsNotFound(err) { + cerr = err + } } if err := a.emitter.EmitAuditEvent(ctx, &apievents.TrustedClusterCreate{ @@ -177,14 +163,14 @@ func (a *Server) UpsertTrustedCluster(ctx context.Context, trustedCluster types. }, UserMetadata: authz.ClientUserMetadata(ctx), ResourceMetadata: apievents.ResourceMetadata{ - Name: trustedCluster.GetName(), + Name: tc.GetName(), }, ConnectionMetadata: authz.ConnectionMetadata(ctx), }); err != nil { - logger.WithError(err).Warn("Failed to emit trusted cluster create event.") + a.logger.WarnContext(ctx, "failed to emit trusted cluster create event", "error", err) } - return tc, nil + return trace.Wrap(cerr) } func (a *Server) checkLocalRoles(ctx context.Context, roleMap types.RoleMap) error { @@ -207,6 +193,29 @@ func (a *Server) checkLocalRoles(ctx context.Context, roleMap types.RoleMap) err return nil } +func (a *Server) getCAsForTrustedCluster(ctx context.Context, tc types.TrustedCluster) ([]types.CertAuthority, error) { + var cas []types.CertAuthority + // not all CA types are present for trusted clusters, but there isn't a meaningful downside to + // just grabbing everything. + for _, caType := range types.CertAuthTypes { + var ca types.CertAuthority + var err error + if tc.GetEnabled() { + ca, err = a.GetCertAuthority(ctx, types.CertAuthID{Type: caType, DomainName: tc.GetName()}, false) + } else { + ca, err = a.GetInactiveCertAuthority(ctx, types.CertAuthID{Type: caType, DomainName: tc.GetName()}, false) + } + if err != nil { + if trace.IsNotFound(err) { + continue + } + return nil, trace.Wrap(err) + } + cas = append(cas, ca) + } + return cas, nil +} + // DeleteTrustedCluster removes types.CertAuthority, services.ReverseTunnel, // and services.TrustedCluster resources. func (a *Server) DeleteTrustedCluster(ctx context.Context, name string) error { @@ -229,7 +238,7 @@ func (a *Server) DeleteTrustedCluster(ctx context.Context, name string) error { }) } - if err := a.DeleteCertAuthorities(ctx, ids...); err != nil { + if err := a.Services.DeleteTrustedClusterInternal(ctx, name, ids); err != nil { return trace.Wrap(err) } @@ -239,10 +248,6 @@ func (a *Server) DeleteTrustedCluster(ctx context.Context, name string) error { } } - if err := a.Services.DeleteTrustedCluster(ctx, name); err != nil { - return trace.Wrap(err) - } - if err := a.emitter.EmitAuditEvent(ctx, &apievents.TrustedClusterDelete{ Metadata: apievents.Metadata{ Type: events.TrustedClusterDeleteEvent, @@ -324,54 +329,30 @@ func (a *Server) establishTrust(ctx context.Context, trustedCluster types.Truste return validateResponse.CAs, nil } -func (a *Server) addCertAuthorities(ctx context.Context, trustedCluster types.TrustedCluster, remoteCAs []types.CertAuthority) error { - // the remote auth server has verified our token. add the - // remote certificate authority to our backend - for _, remoteCertAuthority := range remoteCAs { - // change the name of the remote ca to the name of the trusted cluster - remoteCertAuthority.SetName(trustedCluster.GetName()) - - // wipe out roles sent from the remote cluster and set roles from the trusted cluster - remoteCertAuthority.SetRoles(nil) - if remoteCertAuthority.GetType() == types.UserCA { - for _, r := range trustedCluster.GetRoles() { - remoteCertAuthority.AddRole(r) - } - remoteCertAuthority.SetRoleMap(trustedCluster.GetRoleMap()) - } - } - - // we use create here instead of upsert to prevent people from wiping out - // their own ca if it has the same name as the remote ca - _, err := a.CreateCertAuthorities(ctx, remoteCAs...) - return trace.Wrap(err) -} - // DeleteRemoteCluster deletes remote cluster resource, all certificate authorities // associated with it -func (a *Server) DeleteRemoteCluster(ctx context.Context, clusterName string) error { - // To make sure remote cluster exists - to protect against random - // clusterName requests (e.g. when clusterName is set to local cluster name) - if _, err := a.GetRemoteCluster(ctx, clusterName); err != nil { +func (a *Server) DeleteRemoteCluster(ctx context.Context, name string) error { + cn, err := a.GetClusterName() + if err != nil { return trace.Wrap(err) } + // This check ensures users are not deleting their root/own cluster. + if cn.GetClusterName() == name { + return trace.BadParameter("remote cluster %q is the name of this root cluster and cannot be removed.", name) + } + // we only expect host CAs to be present for remote clusters, but it doesn't hurt // to err on the side of paranoia and delete all CA types. var ids []types.CertAuthID for _, caType := range types.CertAuthTypes { ids = append(ids, types.CertAuthID{ Type: caType, - DomainName: clusterName, + DomainName: name, }) } - // delete cert authorities associated with the cluster - if err := a.DeleteCertAuthorities(ctx, ids...); err != nil { - return trace.Wrap(err) - } - - return trace.Wrap(a.Services.DeleteRemoteCluster(ctx, clusterName)) + return trace.Wrap(a.Services.DeleteRemoteClusterInternal(ctx, name, ids)) } // GetRemoteCluster returns remote cluster by name @@ -497,12 +478,6 @@ func (a *Server) validateTrustedCluster(ctx context.Context, validateRequest *au if remoteClusterName == domainName { return nil, trace.AccessDenied("remote cluster has same name as this cluster: %v", domainName) } - _, err = a.GetTrustedCluster(ctx, remoteClusterName) - if err == nil { - return nil, trace.AccessDenied("remote cluster has same name as trusted cluster: %v", remoteClusterName) - } else if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } remoteCluster, err := types.NewRemoteCluster(remoteClusterName) if err != nil { @@ -522,15 +497,8 @@ func (a *Server) validateTrustedCluster(ctx context.Context, validateRequest *au } remoteCluster.SetConnectionStatus(teleport.RemoteClusterStatusOffline) - _, err = a.CreateRemoteCluster(ctx, remoteCluster) - if err != nil { - if !trace.IsAlreadyExists(err) { - return nil, trace.Wrap(err) - } - } - - err = a.UpsertCertAuthority(ctx, remoteCA) - if err != nil { + _, err = a.CreateRemoteClusterInternal(ctx, remoteCluster, []types.CertAuthority{remoteCA}) + if err != nil && !trace.IsAlreadyExists(err) { return nil, trace.Wrap(err) } @@ -641,36 +609,6 @@ func (a *Server) sendValidateRequestToProxy(host string, validateRequest *authcl return validateResponse, nil } -// activateCertAuthority will activate both the user and host certificate -// authority given in the services.TrustedCluster resource. -func (a *Server) activateCertAuthority(ctx context.Context, t types.TrustedCluster) error { - return trace.Wrap(a.ActivateCertAuthorities(ctx, []types.CertAuthID{ - { - Type: types.UserCA, - DomainName: t.GetName(), - }, - { - Type: types.HostCA, - DomainName: t.GetName(), - }, - }...)) -} - -// deactivateCertAuthority will deactivate both the user and host certificate -// authority given in the services.TrustedCluster resource. -func (a *Server) deactivateCertAuthority(ctx context.Context, t types.TrustedCluster) error { - return trace.Wrap(a.DeactivateCertAuthorities(ctx, []types.CertAuthID{ - { - Type: types.UserCA, - DomainName: t.GetName(), - }, - { - Type: types.HostCA, - DomainName: t.GetName(), - }, - }...)) -} - // createReverseTunnel will create a services.ReverseTunnel givenin the // services.TrustedCluster resource. func (a *Server) createReverseTunnel(ctx context.Context, t types.TrustedCluster) error { diff --git a/lib/auth/trustedcluster_test.go b/lib/auth/trustedcluster_test.go index ba7ffac769b62..f1581dbc64fee 100644 --- a/lib/auth/trustedcluster_test.go +++ b/lib/auth/trustedcluster_test.go @@ -469,22 +469,11 @@ func TestUpsertTrustedCluster(t *testing.T) { }) require.NoError(t, err) - leafClusterCA := types.CertAuthority(suite.NewTestCA(types.HostCA, "trustedcluster")) - _, err = a.validateTrustedCluster(ctx, &authclient.ValidateTrustedClusterRequest{ - Token: validToken, - CAs: []types.CertAuthority{leafClusterCA}, - TeleportVersion: teleport.Version, - }) - require.NoError(t, err) - - _, err = a.Services.UpsertTrustedCluster(ctx, trustedCluster) - require.NoError(t, err) - ca := suite.NewTestCA(types.UserCA, "trustedcluster") - err = a.addCertAuthorities(ctx, trustedCluster, []types.CertAuthority{ca}) - require.NoError(t, err) - err = a.UpsertCertAuthority(ctx, ca) + configureCAsForTrustedCluster(trustedCluster, []types.CertAuthority{ca}) + + _, err = a.Services.CreateTrustedCluster(ctx, trustedCluster, []types.CertAuthority{ca}) require.NoError(t, err) err = a.createReverseTunnel(ctx, trustedCluster) diff --git a/lib/services/local/trust.go b/lib/services/local/trust.go index 72d2979dba675..2a2e454cdcb19 100644 --- a/lib/services/local/trust.go +++ b/lib/services/local/trust.go @@ -20,7 +20,6 @@ package local import ( "context" - "encoding/json" "errors" "log/slog" "slices" @@ -67,44 +66,164 @@ func (s *CA) CreateCertAuthority(ctx context.Context, ca types.CertAuthority) er // CreateCertAuthorities creates multiple cert authorities atomically. func (s *CA) CreateCertAuthorities(ctx context.Context, cas ...types.CertAuthority) (revision string, err error) { - var condacts []backend.ConditionalAction - var clusterNames []string - for _, ca := range cas { - if !slices.Contains(clusterNames, ca.GetName()) { - clusterNames = append(clusterNames, ca.GetName()) + condacts, err := createCertAuthoritiesCondActs(cas, true /* active */) + if err != nil { + return "", trace.Wrap(err) + } + + rev, err := s.AtomicWrite(ctx, condacts) + if err != nil { + if errors.Is(err, backend.ErrConditionFailed) { + var clusterNames []string + for _, ca := range cas { + if slices.Contains(clusterNames, ca.GetClusterName()) { + continue + } + clusterNames = append(clusterNames, ca.GetClusterName()) + } + return "", trace.AlreadyExists("one or more CAs from cluster(s) %q already exist", strings.Join(clusterNames, ",")) } + return "", trace.Wrap(err) + } + + return rev, nil +} + +// createCertAuthoritiesCondActs sets up conditional actions for creating a set of CAs. +func createCertAuthoritiesCondActs(cas []types.CertAuthority, active bool) ([]backend.ConditionalAction, error) { + condacts := make([]backend.ConditionalAction, 0, len(cas)*2) + for _, ca := range cas { if err := services.ValidateCertAuthority(ca); err != nil { - return "", trace.Wrap(err) + return nil, trace.Wrap(err) } item, err := caToItem(backend.Key{}, ca) if err != nil { - return "", trace.Wrap(err) + return nil, trace.Wrap(err) } - condacts = append(condacts, []backend.ConditionalAction{ - { - Key: activeCAKey(ca.GetID()), - Condition: backend.NotExists(), - Action: backend.Put(item), - }, - { - Key: inactiveCAKey(ca.GetID()), - Condition: backend.Whatever(), - Action: backend.Delete(), - }, - }...) + if active { + // for an enabled tc, we perform a conditional create for the active CA key + // and an unconditional delete for the inactive CA key since the active range + // is given priority over the inactive range. + condacts = append(condacts, []backend.ConditionalAction{ + { + Key: activeCAKey(ca.GetID()), + Condition: backend.NotExists(), + Action: backend.Put(item), + }, + { + Key: inactiveCAKey(ca.GetID()), + Condition: backend.Whatever(), + Action: backend.Delete(), + }, + }...) + } else { + // for a disabled tc, we perform a conditional create for the inactive CA key + // and assert the non-existence of the active CA key. + condacts = append(condacts, []backend.ConditionalAction{ + { + Key: inactiveCAKey(ca.GetID()), + Condition: backend.NotExists(), + Action: backend.Put(item), + }, + { + Key: activeCAKey(ca.GetID()), + Condition: backend.NotExists(), + Action: backend.Nop(), + }, + }...) + } } - rev, err := s.AtomicWrite(ctx, condacts) - if err != nil { - if errors.Is(err, backend.ErrConditionFailed) { - return "", trace.AlreadyExists("one or more CAs from cluster(s) %q already exist", strings.Join(clusterNames, ",")) + return condacts, nil +} + +func updateCertAuthoritiesCondActs(cas []types.CertAuthority, active bool, currentlyActive bool) ([]backend.ConditionalAction, error) { + condacts := make([]backend.ConditionalAction, 0, len(cas)*2) + for _, ca := range cas { + if err := services.ValidateCertAuthority(ca); err != nil { + return nil, trace.Wrap(err) + } + + item, err := caToItem(backend.Key{}, ca) + if err != nil { + return nil, trace.Wrap(err) + } + + if active { + if currentlyActive { + // we are updating an active CA without changing its active status. we want to perform + // a conditional update on the acitve CA key and an unconditonal delete on the inactive + // CA key in order to correctly model active range priority. + condacts = append(condacts, []backend.ConditionalAction{ + { + Key: activeCAKey(ca.GetID()), + Condition: backend.Revision(item.Revision), + Action: backend.Put(item), + }, + { + Key: inactiveCAKey(ca.GetID()), + Condition: backend.Whatever(), + Action: backend.Delete(), + }, + }...) + } else { + // we are updating a currently inactive CA to the active state. we want to perform + // a create on the active CA key and a revision-conditional delete on the inactive CA key + // to affect a "move-and-update" that respects the active range priority. + condacts = append(condacts, []backend.ConditionalAction{ + { + Key: activeCAKey(ca.GetID()), + Condition: backend.NotExists(), + Action: backend.Put(item), + }, + { + Key: inactiveCAKey(ca.GetID()), + Condition: backend.Revision(item.Revision), + Action: backend.Delete(), + }, + }...) + } + } else { + if currentlyActive { + // we are updating an active CA to the inactive state. we want to perform a conditional + // delete on the active CA key and an unconditional put on the inactive CA key to + // affect a "move-and-update" that respects the active range priority. + condacts = append(condacts, []backend.ConditionalAction{ + { + Key: activeCAKey(ca.GetID()), + Condition: backend.Revision(item.Revision), + Action: backend.Delete(), + }, + { + Key: inactiveCAKey(ca.GetID()), + Condition: backend.Whatever(), + Action: backend.Put(item), + }, + }...) + + } else { + // we are updating an inactive CA without changing its active status. we want to perform + // a conditional update on the inactive CA key and assert the non-existence of the active + // CA key. + condacts = append(condacts, []backend.ConditionalAction{ + { + Key: inactiveCAKey(ca.GetID()), + Condition: backend.Revision(item.Revision), + Action: backend.Put(item), + }, + { + Key: activeCAKey(ca.GetID()), + Condition: backend.NotExists(), + Action: backend.Nop(), + }, + }...) + } } - return "", trace.Wrap(err) } - return rev, nil + return condacts, nil } // UpsertCertAuthority updates or inserts a new certificate authority @@ -198,10 +317,15 @@ func (s *CA) DeleteCertAuthority(ctx context.Context, id types.CertAuthID) error // DeleteCertAuthorities deletes multiple cert authorities atomically. func (s *CA) DeleteCertAuthorities(ctx context.Context, ids ...types.CertAuthID) error { + _, err := s.AtomicWrite(ctx, s.deleteCertAuthoritiesCondActs(ids)) + return trace.Wrap(err) +} + +func (s *CA) deleteCertAuthoritiesCondActs(ids []types.CertAuthID) []backend.ConditionalAction { var condacts []backend.ConditionalAction for _, id := range ids { if err := id.Check(); err != nil { - return trace.Wrap(err) + continue } for _, key := range []backend.Key{activeCAKey(id), inactiveCAKey(id)} { condacts = append(condacts, backend.ConditionalAction{ @@ -211,9 +335,7 @@ func (s *CA) DeleteCertAuthorities(ctx context.Context, ids ...types.CertAuthID) }) } } - - _, err := s.AtomicWrite(ctx, condacts) - return trace.Wrap(err) + return condacts } // ActivateCertAuthority moves a CertAuthority from the deactivated list to @@ -325,10 +447,26 @@ func (s *CA) DeactivateCertAuthorities(ctx context.Context, ids ...types.CertAut // GetCertAuthority returns certificate authority by given id. Parameter loadSigningKeys // controls if signing keys are loaded func (s *CA) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool) (types.CertAuthority, error) { + return s.getCertAuthority(ctx, id, loadSigningKeys, true /* active */) +} + +// GetInactiveCertAuthority returns inactive certificate authority by given id. Parameter loadSigningKeys +// controls if signing keys are loaded. +func (s *CA) GetInactiveCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool) (types.CertAuthority, error) { + return s.getCertAuthority(ctx, id, loadSigningKeys, false /* inactive */) +} + +func (s *CA) getCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool, active bool) (types.CertAuthority, error) { if err := id.Check(); err != nil { return nil, trace.Wrap(err) } - item, err := s.Get(ctx, activeCAKey(id)) + + key := activeCAKey(id) + if !active { + key = inactiveCAKey(id) + } + + item, err := s.Get(ctx, key) if err != nil { return nil, trace.Wrap(err) } @@ -425,25 +563,135 @@ func (s *CA) UpdateUserCARoleMap(ctx context.Context, name string, roleMap types return nil } +// CreateTrustedCluster atomically creates a new trusted cluster along with associated resources. +func (s *CA) CreateTrustedCluster(ctx context.Context, tc types.TrustedCluster, cas []types.CertAuthority) (revision string, err error) { + if err := services.ValidateTrustedCluster(tc); err != nil { + return "", trace.Wrap(err) + } + + item, err := trustedClusterToItem(tc) + if err != nil { + return "", trace.Wrap(err) + } + + condacts := []backend.ConditionalAction{ + { + Key: item.Key, + Condition: backend.NotExists(), + Action: backend.Put(item), + }, + // also assert that no remote cluster exists by this name, as + // we currently do not allow for a trusted cluster and remote + // cluster to share a name (CAs end up stored at the same location). + { + Key: remoteClusterKey(tc.GetName()), + Condition: backend.NotExists(), + Action: backend.Nop(), + }, + } + + // perform some initial trusted-cluster related validation. common ca validation is handled later + // on by the createCertAuthoritiesCondActs helper. + for _, ca := range cas { + if tc.GetName() != ca.GetClusterName() { + return "", trace.BadParameter("trusted cluster name %q does not match CA cluster name %q", tc.GetName(), ca.GetClusterName()) + } + } + + ccas, err := createCertAuthoritiesCondActs(cas, tc.GetEnabled()) + if err != nil { + return "", trace.Wrap(err) + } + + condacts = append(condacts, ccas...) + + rev, err := s.AtomicWrite(ctx, condacts) + if err != nil { + if errors.Is(err, backend.ErrConditionFailed) { + if _, err := s.GetRemoteCluster(ctx, tc.GetName()); err == nil { + return "", trace.BadParameter("cannot create trusted cluster with same name as remote cluster %q, bidirectional trust is not supported", tc.GetName()) + } + + return "", trace.AlreadyExists("trusted cluster %q and/or one or more of its cert authorities already exists", tc.GetName()) + } + return "", trace.Wrap(err) + } + + return rev, nil +} + +// UpdateTrustedCluster atomically updates a trusted cluster along with associated resources. +func (s *CA) UpdateTrustedCluster(ctx context.Context, tc types.TrustedCluster, cas []types.CertAuthority) (revision string, err error) { + if err := services.ValidateTrustedCluster(tc); err != nil { + return "", trace.Wrap(err) + } + + // fetch the current state. we'll need this later on to correctly construct our CA condacts, and + // it doesn't hurt to reject mismatched revisions early. + extant, err := s.GetTrustedCluster(ctx, tc.GetName()) + if err != nil { + return "", trace.Wrap(err) + } + + if tc.GetRevision() != extant.GetRevision() { + return "", trace.CompareFailed("trusted cluster %q has been modified, please retry", tc.GetName()) + } + + item, err := trustedClusterToItem(tc) + if err != nil { + return "", trace.Wrap(err) + } + + condacts := []backend.ConditionalAction{ + { + Key: item.Key, + Condition: backend.Revision(item.Revision), + Action: backend.Put(item), + }, + } + + // perform some initial trusted-cluster related validation. common ca validation is handled later + // on by the createCertAuthoritiesCondActs helper. + for _, ca := range cas { + if tc.GetName() != ca.GetClusterName() { + return "", trace.BadParameter("trusted cluster name %q does not match CA cluster name %q", tc.GetName(), ca.GetClusterName()) + } + } + + ccas, err := updateCertAuthoritiesCondActs(cas, tc.GetEnabled(), extant.GetEnabled()) + if err != nil { + return "", trace.Wrap(err) + } + + condacts = append(condacts, ccas...) + + rev, err := s.AtomicWrite(ctx, condacts) + if err != nil { + if errors.Is(err, backend.ErrConditionFailed) { + return "", trace.CompareFailed("trusted cluster %q and/or one or more of its cert authorities have been modified, please retry", tc.GetName()) + } + return "", trace.Wrap(err) + } + + return rev, nil +} + // UpsertTrustedCluster creates or updates a TrustedCluster in the backend. func (s *CA) UpsertTrustedCluster(ctx context.Context, trustedCluster types.TrustedCluster) (types.TrustedCluster, error) { if err := services.ValidateTrustedCluster(trustedCluster); err != nil { return nil, trace.Wrap(err) } - rev := trustedCluster.GetRevision() - value, err := services.MarshalTrustedCluster(trustedCluster) + + item, err := trustedClusterToItem(trustedCluster) if err != nil { return nil, trace.Wrap(err) } - _, err = s.Put(ctx, backend.Item{ - Key: backend.NewKey(trustedClustersPrefix, trustedCluster.GetName()), - Value: value, - Expires: trustedCluster.Expiry(), - Revision: rev, - }) + + _, err = s.Put(ctx, item) if err != nil { return nil, trace.Wrap(err) } + return trustedCluster, nil } @@ -482,16 +730,44 @@ func (s *CA) GetTrustedClusters(ctx context.Context) ([]types.TrustedCluster, er // DeleteTrustedCluster removes a TrustedCluster from the backend by name. func (s *CA) DeleteTrustedCluster(ctx context.Context, name string) error { + return s.DeleteTrustedClusterInternal(ctx, name, nil /* no cert authorities */) +} + +// DeleteTrustedClusterInternal removes a trusted cluster and associated resources atomically. +func (s *CA) DeleteTrustedClusterInternal(ctx context.Context, name string, caIDs []types.CertAuthID) error { if name == "" { return trace.BadParameter("missing trusted cluster name") } - err := s.Delete(ctx, backend.NewKey(trustedClustersPrefix, name)) - if err != nil { - if trace.IsNotFound(err) { + + for _, id := range caIDs { + if err := id.Check(); err != nil { + return trace.Wrap(err) + } + + if id.DomainName != name { + return trace.BadParameter("ca %q does not belong to trusted cluster %q", id.DomainName, name) + } + } + + condacts := []backend.ConditionalAction{ + { + Key: backend.NewKey(trustedClustersPrefix, name), + Condition: backend.Exists(), + Action: backend.Delete(), + }, + } + + condacts = append(condacts, s.deleteCertAuthoritiesCondActs(caIDs)...) + + if _, err := s.AtomicWrite(ctx, condacts); err != nil { + if errors.Is(err, backend.ErrConditionFailed) { return trace.NotFound("trusted cluster %q is not found", name) } + + return trace.Wrap(err) } - return trace.Wrap(err) + + return nil } // UpsertTunnelConnection updates or creates tunnel connection @@ -608,25 +884,71 @@ func (s *CA) DeleteAllTunnelConnections() error { return trace.Wrap(err) } -// CreateRemoteCluster creates remote cluster -func (s *CA) CreateRemoteCluster( - ctx context.Context, rc types.RemoteCluster, -) (types.RemoteCluster, error) { - value, err := json.Marshal(rc) +// CreateRemoteCluster creates a remote cluster +func (s *CA) CreateRemoteCluster(ctx context.Context, rc types.RemoteCluster) (types.RemoteCluster, error) { + rev, err := s.CreateRemoteClusterInternal(ctx, rc, nil) if err != nil { return nil, trace.Wrap(err) } - item := backend.Item{ - Key: backend.NewKey(remoteClustersPrefix, rc.GetName()), - Value: value, - Expires: rc.Expiry(), + + rc.SetRevision(rev) + return rc, nil +} + +// CreateRemoteCluster atomically creates a new remote cluster along with associated resources. +func (s *CA) CreateRemoteClusterInternal(ctx context.Context, rc types.RemoteCluster, cas []types.CertAuthority) (revision string, err error) { + if err := services.CheckAndSetDefaults(rc); err != nil { + return "", trace.Wrap(err) } - lease, err := s.Create(ctx, item) + + item, err := remoteClusterToItem(rc) if err != nil { - return nil, trace.Wrap(err) + return "", trace.Wrap(err) } - rc.SetRevision(lease.Revision) - return rc, nil + + condacts := []backend.ConditionalAction{ + { + Key: item.Key, + Condition: backend.NotExists(), + Action: backend.Put(item), + }, + // also assert that no trusted cluster exists by this name, as + // we currently do not allow for a trusted cluster and remote + // cluster to share a name (CAs end up stored at the same location). + { + Key: trustedClusterKey(rc.GetName()), + Condition: backend.NotExists(), + Action: backend.Nop(), + }, + } + + // perform some initial remote-cluster related validation. common ca validation is handled later + // on by the createCertAuthoritiesCondActs helper. + for _, ca := range cas { + if rc.GetName() != ca.GetClusterName() { + return "", trace.BadParameter("remote cluster name %q does not match CA cluster name %q", rc.GetName(), ca.GetClusterName()) + } + } + + ccas, err := createCertAuthoritiesCondActs(cas, true /* remote cluster cas always considered active */) + if err != nil { + return "", trace.Wrap(err) + } + + condacts = append(condacts, ccas...) + + rev, err := s.AtomicWrite(ctx, condacts) + if err != nil { + if errors.Is(err, backend.ErrConditionFailed) { + if _, err := s.GetTrustedCluster(ctx, rc.GetName()); err == nil { + return "", trace.BadParameter("cannot create remote cluster with same name as trusted cluster %q, bidirectional trust is not supported", rc.GetName()) + } + return "", trace.AlreadyExists("remote cluster %q and/or one or more of its cert authorities already exists", rc.GetName()) + } + return "", trace.Wrap(err) + } + + return rev, nil } // UpdateRemoteCluster updates selected remote cluster fields: expiry and labels @@ -652,17 +974,12 @@ func (s *CA) UpdateRemoteCluster(ctx context.Context, rc types.RemoteCluster) (t existing.SetConnectionStatus(rc.GetConnectionStatus()) existing.SetMetadata(rc.GetMetadata()) - updateValue, err := services.MarshalRemoteCluster(existing) + item, err := remoteClusterToItem(existing) if err != nil { return nil, trace.Wrap(err) } - lease, err := s.ConditionalUpdate(ctx, backend.Item{ - Key: backend.NewKey(remoteClustersPrefix, existing.GetName()), - Value: updateValue, - Expires: existing.Expiry(), - Revision: existing.GetRevision(), - }) + lease, err := s.ConditionalUpdate(ctx, item) if err != nil { if trace.IsCompareFailed(err) { // Retry! @@ -707,17 +1024,12 @@ func (s *CA) PatchRemoteCluster( return nil, trace.BadParameter("metadata.revision: cannot be patched") } - updatedValue, err := services.MarshalRemoteCluster(updated) + item, err := remoteClusterToItem(updated) if err != nil { return nil, trace.Wrap(err) } - lease, err := s.ConditionalUpdate(ctx, backend.Item{ - Key: backend.NewKey(remoteClustersPrefix, name), - Value: updatedValue, - Expires: updated.Expiry(), - Revision: updated.GetRevision(), - }) + lease, err := s.ConditionalUpdate(ctx, item) if err != nil { if trace.IsCompareFailed(err) { // Retry! @@ -822,13 +1134,44 @@ func (s *CA) GetRemoteCluster( } // DeleteRemoteCluster deletes remote cluster by name -func (s *CA) DeleteRemoteCluster( - ctx context.Context, clusterName string, -) error { - if clusterName == "" { +func (s *CA) DeleteRemoteCluster(ctx context.Context, clusterName string) error { + return s.DeleteRemoteClusterInternal(ctx, clusterName, nil /* no cert authorities */) +} + +// DeleteRemoteClusterInternal atomically deletes a remote cluster along with associated resources. +func (s *CA) DeleteRemoteClusterInternal(ctx context.Context, name string, ids []types.CertAuthID) error { + if name == "" { return trace.BadParameter("missing parameter cluster name") } - return s.Delete(ctx, backend.NewKey(remoteClustersPrefix, clusterName)) + + for _, id := range ids { + if err := id.Check(); err != nil { + return trace.Wrap(err) + } + + if id.DomainName != name { + return trace.BadParameter("ca %q does not belong to remote cluster %q", id.DomainName, name) + } + } + + condacts := []backend.ConditionalAction{ + { + Key: remoteClusterKey(name), + Condition: backend.Exists(), + Action: backend.Delete(), + }, + } + + condacts = append(condacts, s.deleteCertAuthoritiesCondActs(ids)...) + + if _, err := s.AtomicWrite(ctx, condacts); err != nil { + if errors.Is(err, backend.ErrConditionFailed) { + return trace.NotFound("remote cluster %q is not found", name) + } + return trace.Wrap(err) + } + + return nil } // DeleteAllRemoteClusters deletes all remote clusters @@ -853,6 +1196,42 @@ func caToItem(key backend.Key, ca types.CertAuthority) (backend.Item, error) { }, nil } +func trustedClusterToItem(tc types.TrustedCluster) (backend.Item, error) { + value, err := services.MarshalTrustedCluster(tc) + if err != nil { + return backend.Item{}, trace.Wrap(err) + } + + return backend.Item{ + Key: trustedClusterKey(tc.GetName()), + Value: value, + Expires: tc.Expiry(), + Revision: tc.GetRevision(), + }, nil +} + +func trustedClusterKey(name string) backend.Key { + return backend.NewKey(trustedClustersPrefix, name) +} + +func remoteClusterToItem(rc types.RemoteCluster) (backend.Item, error) { + value, err := services.MarshalRemoteCluster(rc) + if err != nil { + return backend.Item{}, trace.Wrap(err) + } + + return backend.Item{ + Key: remoteClusterKey(rc.GetName()), + Value: value, + Expires: rc.Expiry(), + Revision: rc.GetRevision(), + }, nil +} + +func remoteClusterKey(name string) backend.Key { + return backend.NewKey(remoteClustersPrefix, name) +} + // activeCAKey builds the active key variant for the supplied ca id. func activeCAKey(id types.CertAuthID) backend.Key { return backend.NewKey(authoritiesPrefix, string(id.Type), id.DomainName) diff --git a/lib/services/local/trust_test.go b/lib/services/local/trust_test.go index 34a85171d4887..3188c546e6c16 100644 --- a/lib/services/local/trust_test.go +++ b/lib/services/local/trust_test.go @@ -20,6 +20,7 @@ package local import ( "context" + "crypto/x509/pkix" "fmt" "testing" "time" @@ -32,11 +33,205 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/tlsca" ) +func TestUpdateCertAuthorityCondActs(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // setup closure creates our initial state and returns its components + setup := func(active bool) (types.TrustedCluster, types.CertAuthority, *CA) { + bk, err := memory.New(memory.Config{}) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, bk.Close()) }) + service := NewCAService(bk) + + tc, err := types.NewTrustedCluster("tc", types.TrustedClusterSpecV2{ + Enabled: active, + Roles: []string{"rrr"}, + Token: "xxx", + ProxyAddress: "xxx", + ReverseTunnelAddress: "xxx", + }) + require.NoError(t, err) + + ca := newCertAuthority(t, types.HostCA, "tc") + revision, err := service.CreateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + tc.SetRevision(revision) + ca.SetRevision(revision) + return tc, ca, service + } + + // putCA is a helper for injecting a CA into the backend, bypassing atomic condition protections + putCA := func(ctx context.Context, service *CA, ca types.CertAuthority, active bool) { + key := activeCAKey(ca.GetID()) + if !active { + key = inactiveCAKey(ca.GetID()) + } + item, err := caToItem(key, ca) + require.NoError(t, err) + _, err = service.Put(ctx, item) + require.NoError(t, err) + } + + // delCA is a helper for deleting a CA from the backend, bypassing atomic condition protections + delCA := func(ctx context.Context, service *CA, ca types.CertAuthority, active bool) { + key := activeCAKey(ca.GetID()) + if !active { + key = inactiveCAKey(ca.GetID()) + } + require.NoError(t, service.Delete(ctx, key)) + } + + // -- update active in place --- + tc, ca, service := setup(true /* active */) + + // verify basic update works + tc.SetRoles([]string{"rrr", "zzz"}) + revision, err := service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + tc.SetRevision(revision) + ca.SetRevision(revision) + + gotTC, err := service.GetTrustedCluster(ctx, tc.GetName()) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc, gotTC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + _, err = service.GetCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + _, err = service.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + + // verify that an inactive CA doesn't prevent update + putCA(ctx, service, ca, false /* inactive */) + tc.SetRoles([]string{"rrr", "zzz", "aaa"}) + revision, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + tc.SetRevision(revision) + ca.SetRevision(revision) + + gotTC, err = service.GetTrustedCluster(ctx, tc.GetName()) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc, gotTC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + _, err = service.GetCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + _, err = service.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + + // verify that concurrent update of the active CA causes update to fail + putCA(ctx, service, ca, true /* active */) + tc.SetRoles([]string{"rrr", "zzz", "aaa", "bbb"}) + _, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.True(t, trace.IsCompareFailed(err), "err=%v", err) + + // --- update inactive in place --- + tc, ca, service = setup(false /* inactive */) + + // verify basic update works + tc.SetRoles([]string{"rrr", "zzz"}) + revision, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + tc.SetRevision(revision) + ca.SetRevision(revision) + + gotTC, err = service.GetTrustedCluster(ctx, tc.GetName()) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc, gotTC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + _, err = service.GetCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + _, err = service.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + + // verify that an active CA prevents update + putCA(ctx, service, ca, true /* active */) + tc.SetRoles([]string{"rrr", "zzz", "aaa"}) + _, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.True(t, trace.IsCompareFailed(err), "err=%v", err) + delCA(ctx, service, ca, true /* active */) + + // verify that concurrent update of the inactive CA causes update to fail + putCA(ctx, service, ca, false /* inactive */) + tc.SetRoles([]string{"rrr", "zzz", "aaa", "bbb"}) + _, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.True(t, trace.IsCompareFailed(err), "err=%v", err) + + // --- activate/deactivate --- + tc, ca, service = setup(false /* inactive */) + + // verify that activating works + tc.SetEnabled(true) + revision, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + tc.SetRevision(revision) + ca.SetRevision(revision) + + gotTC, err = service.GetTrustedCluster(ctx, tc.GetName()) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc, gotTC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + _, err = service.GetCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + _, err = service.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + + // verify that deactivating works + tc.SetEnabled(false) + revision, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + tc.SetRevision(revision) + ca.SetRevision(revision) + + gotTC, err = service.GetTrustedCluster(ctx, tc.GetName()) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc, gotTC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + _, err = service.GetCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + _, err = service.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + + // verify that an active CA conflicts with activation + putCA(ctx, service, ca, true /* active */) + tc.SetEnabled(true) + _, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.True(t, trace.IsCompareFailed(err), "err=%v", err) + delCA(ctx, service, ca, true /* active */) + + // activation should work after deleting conlicting CA + revision, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + tc.SetRevision(revision) + ca.SetRevision(revision) + + gotTC, err = service.GetTrustedCluster(ctx, tc.GetName()) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc, gotTC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + _, err = service.GetCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + _, err = service.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + + // verify that deactivation works even if there is an inaactive CA present + putCA(ctx, service, ca, false /* inactive */) + tc.SetEnabled(false) + revision, err = service.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + tc.SetRevision(revision) + ca.SetRevision(revision) + + gotTC, err = service.GetTrustedCluster(ctx, tc.GetName()) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc, gotTC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + _, err = service.GetCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + _, err = service.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) +} + func TestRemoteClusterCRUD(t *testing.T) { t.Parallel() ctx := context.Background() @@ -67,22 +262,38 @@ func TestRemoteClusterCRUD(t *testing.T) { src.SetConnectionStatus(teleport.RemoteClusterStatusOnline) src.SetLastHeartbeat(clock.Now().Add(-time.Hour)) - // create remote clusters - gotRC, err := trustService.CreateRemoteCluster(ctx, rc) + // set up fake CAs for the remote clusters + ca := newCertAuthority(t, types.HostCA, "foo") + sca := newCertAuthority(t, types.HostCA, "bar") + + // create remote cluster + revision, err := trustService.CreateRemoteClusterInternal(ctx, rc, []types.CertAuthority{ca}) require.NoError(t, err) - require.Empty(t, cmp.Diff(rc, gotRC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) - gotSRC, err := trustService.CreateRemoteCluster(ctx, src) + rc.SetRevision(revision) + ca.SetRevision(revision) + + _, err = trustService.CreateRemoteClusterInternal(ctx, rc, []types.CertAuthority{ca}) + require.True(t, trace.IsAlreadyExists(err), "err=%v", err) + + revision, err = trustService.CreateRemoteClusterInternal(ctx, src, []types.CertAuthority{sca}) require.NoError(t, err) - require.Empty(t, cmp.Diff(src, gotSRC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + src.SetRevision(revision) + sca.SetRevision(revision) // get remote cluster make sure it's correct - gotRC, err = trustService.GetRemoteCluster(ctx, "foo") + gotRC, err := trustService.GetRemoteCluster(ctx, "foo") require.NoError(t, err) require.Equal(t, "foo", gotRC.GetName()) require.Equal(t, teleport.RemoteClusterStatusOffline, gotRC.GetConnectionStatus()) require.Equal(t, clock.Now().Nanosecond(), gotRC.GetLastHeartbeat().Nanosecond()) require.Equal(t, originalLabels, gotRC.GetMetadata().Labels) + // get remote cluster CA make sure it's correct + gotCA, err := trustService.GetCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + + require.Empty(t, cmp.Diff(ca, gotCA, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + rc = gotRC updatedLabels := map[string]string{ "e": "f", @@ -99,10 +310,9 @@ func TestRemoteClusterCRUD(t *testing.T) { require.NoError(t, err) require.Empty(t, cmp.Diff(rc, gotRC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) - src = gotSRC src.SetConnectionStatus(teleport.RemoteClusterStatusOffline) src.SetLastHeartbeat(clock.Now()) - gotSRC, err = trustService.UpdateRemoteCluster(ctx, src) + gotSRC, err := trustService.UpdateRemoteCluster(ctx, src) require.NoError(t, err) require.Empty(t, cmp.Diff(src, gotSRC, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) @@ -126,13 +336,26 @@ func TestRemoteClusterCRUD(t *testing.T) { require.Len(t, allRC, 2) // delete cluster - err = trustService.DeleteRemoteCluster(ctx, "foo") + err = trustService.DeleteRemoteClusterInternal(ctx, "foo", []types.CertAuthID{ca.GetID()}) require.NoError(t, err) // make sure it's really gone - err = trustService.DeleteRemoteCluster(ctx, "foo") - require.Error(t, err) - require.ErrorIs(t, err, trace.NotFound(`key "/remoteClusters/foo" is not found`)) + _, err = trustService.GetRemoteCluster(ctx, "foo") + require.True(t, trace.IsNotFound(err)) + _, err = trustService.GetCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err)) + + // make sure we can't create trusted clusters with the same name as an extant remote cluster + tc, err := types.NewTrustedCluster("bar", types.TrustedClusterSpecV2{ + Enabled: true, + Roles: []string{"bar", "baz"}, + Token: "qux", + ProxyAddress: "quux", + ReverseTunnelAddress: "quuz", + }) + require.NoError(t, err) + _, err = trustService.CreateTrustedCluster(ctx, tc, nil) + require.True(t, trace.IsBadParameter(err), "err=%v", err) } func TestPresenceService_PatchRemoteCluster(t *testing.T) { @@ -290,10 +513,13 @@ func TestTrustedClusterCRUD(t *testing.T) { }) require.NoError(t, err) + ca := newCertAuthority(t, types.HostCA, "foo") + sca := newCertAuthority(t, types.HostCA, "bar") + // create trusted clusters - _, err = trustService.UpsertTrustedCluster(ctx, tc) + _, err = trustService.CreateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) require.NoError(t, err) - _, err = trustService.UpsertTrustedCluster(ctx, stc) + _, err = trustService.CreateTrustedCluster(ctx, stc, []types.CertAuthority{sca}) require.NoError(t, err) // get trusted cluster make sure it's correct @@ -306,17 +532,87 @@ func TestTrustedClusterCRUD(t *testing.T) { require.Equal(t, "quux", gotTC.GetProxyAddress()) require.Equal(t, "quuz", gotTC.GetReverseTunnelAddress()) + // get trusted cluster CA make sure it's correct + gotCA, err := trustService.GetCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + require.Empty(t, cmp.Diff(ca, gotCA, cmpopts.IgnoreFields(types.Metadata{}, "Revision"))) + // get all clusters allTC, err := trustService.GetTrustedClusters(ctx) require.NoError(t, err) require.Len(t, allTC, 2) + // verify that enabling/disabling correctly shows/hides CAs + tc.SetEnabled(false) + tc.SetRevision(gotTC.GetRevision()) + ca.SetRevision(gotCA.GetRevision()) + revision, err := trustService.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + _, err = trustService.GetCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + + _, err = trustService.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + + tc.SetEnabled(true) + tc.SetRevision(revision) + ca.SetRevision(revision) + _, err = trustService.UpdateTrustedCluster(ctx, tc, []types.CertAuthority{ca}) + require.NoError(t, err) + + _, err = trustService.GetCertAuthority(ctx, ca.GetID(), true) + require.NoError(t, err) + _, err = trustService.GetInactiveCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + // delete cluster - err = trustService.DeleteTrustedCluster(ctx, "foo") + err = trustService.DeleteTrustedClusterInternal(ctx, "foo", []types.CertAuthID{ca.GetID()}) require.NoError(t, err) // make sure it's really gone _, err = trustService.GetTrustedCluster(ctx, "foo") - require.Error(t, err) - require.ErrorIs(t, err, trace.NotFound(`key "/trustedclusters/foo" is not found`)) + require.True(t, trace.IsNotFound(err), "err=%v", err) + _, err = trustService.GetCertAuthority(ctx, ca.GetID(), true) + require.True(t, trace.IsNotFound(err), "err=%v", err) + + // make sure we can't create remote clusters with the same name as an extant trusted cluster + rc, err := types.NewRemoteCluster("bar") + require.NoError(t, err) + _, err = trustService.CreateRemoteCluster(ctx, rc) + require.True(t, trace.IsBadParameter(err), "err=%v", err) +} + +func newCertAuthority(t *testing.T, caType types.CertAuthType, domain string) types.CertAuthority { + t.Helper() + + ta := testauthority.New() + priv, pub, err := ta.GenerateKeyPair() + require.NoError(t, err) + + key, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: domain}, nil, time.Hour) + require.NoError(t, err) + + ca, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: caType, + ClusterName: domain, + ActiveKeys: types.CAKeySet{ + SSH: []*types.SSHKeyPair{{ + PrivateKey: priv, + PrivateKeyType: types.PrivateKeyType_RAW, + PublicKey: pub, + }}, + TLS: []*types.TLSKeyPair{{ + Cert: cert, + Key: key, + }}, + JWT: []*types.JWTKeyPair{{ + PublicKey: pub, + PrivateKey: priv, + PrivateKeyType: types.PrivateKeyType_RAW, + }}, + }, + }) + require.NoError(t, err) + + return ca } diff --git a/lib/services/trust.go b/lib/services/trust.go index c7cbfe0229bce..63775ae5b52bb 100644 --- a/lib/services/trust.go +++ b/lib/services/trust.go @@ -83,6 +83,26 @@ type Trust interface { // auth server for some local operations. type TrustInternal interface { Trust + + // CreateTrustedCluster atomically creates a new trusted cluster along with associated resources. + CreateTrustedCluster(context.Context, types.TrustedCluster, []types.CertAuthority) (revision string, err error) + + // UpdateTrustedCluster atomically updates a trusted cluster along with associated resources. + UpdateTrustedCluster(context.Context, types.TrustedCluster, []types.CertAuthority) (revision string, err error) + + // DeleteTrustedClusterInternal atomically deletes a trusted cluster along with associated resources. + DeleteTrustedClusterInternal(context.Context, string, []types.CertAuthID) error + + // CreateRemoteCluster atomically creates a new remote cluster along with associated resources. + CreateRemoteClusterInternal(context.Context, types.RemoteCluster, []types.CertAuthority) (revision string, err error) + + // DeleteRemotClusterInternal atomically deletes a remote cluster along with associated resources. + DeleteRemoteClusterInternal(context.Context, string, []types.CertAuthID) error + + // GetInactiveCertAuthority returns inactive certificate authority by given id. Parameter loadSigningKeys + // controls if signing keys are loaded. + GetInactiveCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool) (types.CertAuthority, error) + // CreateCertAuthorities creates multiple cert authorities atomically. CreateCertAuthorities(context.Context, ...types.CertAuthority) (revision string, err error)