diff --git a/api/types/authentication.go b/api/types/authentication.go index 504e44ff6e203..64a8d972f4347 100644 --- a/api/types/authentication.go +++ b/api/types/authentication.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/tlsutils" ) @@ -171,6 +172,9 @@ type AuthPreference interface { // String represents a human readable version of authentication settings. String() string + + // Clone makes a deep copy of the AuthPreference. + Clone() AuthPreference } // NewAuthPreference is a convenience method to to create AuthPreferenceV2. @@ -759,6 +763,11 @@ func (c *AuthPreferenceV2) String() string { return fmt.Sprintf("AuthPreference(Type=%q,SecondFactor=%q)", c.Spec.Type, c.Spec.SecondFactor) } +// Clone returns a copy of the AuthPreference resource. +func (c *AuthPreferenceV2) Clone() AuthPreference { + return utils.CloneProtoMsg(c) +} + func (u *U2F) Check() error { if u.AppID == "" { return trace.BadParameter("u2f configuration missing app_id") diff --git a/api/types/sessionrecording.go b/api/types/sessionrecording.go index 9c237b1f61f11..2fde26ae74349 100644 --- a/api/types/sessionrecording.go +++ b/api/types/sessionrecording.go @@ -22,6 +22,8 @@ import ( "time" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/utils" ) // SessionRecordingConfig defines session recording configuration. This is @@ -40,6 +42,9 @@ type SessionRecordingConfig interface { // SetProxyChecksHostKeys sets if the proxy will check host keys. SetProxyChecksHostKeys(bool) + + // Clone returns a copy of the resource. + Clone() SessionRecordingConfig } // NewSessionRecordingConfigFromConfigFile is a convenience method to create @@ -168,6 +173,11 @@ func (c *SessionRecordingConfigV2) SetProxyChecksHostKeys(t bool) { c.Spec.ProxyChecksHostKeys = NewBoolOption(t) } +// Clone returns a copy of the resource. +func (c *SessionRecordingConfigV2) Clone() SessionRecordingConfig { + return utils.CloneProtoMsg(c) +} + // setStaticFields sets static resource header and metadata fields. func (c *SessionRecordingConfigV2) setStaticFields() { c.Kind = KindSessionRecordingConfig diff --git a/integration/helpers/helpers.go b/integration/helpers/helpers.go index 7101897bd6de3..63525f86b319f 100644 --- a/integration/helpers/helpers.go +++ b/integration/helpers/helpers.go @@ -483,8 +483,11 @@ func UpsertAuthPrefAndWaitForCache( _, err := srv.UpsertAuthPreference(ctx, pref) require.NoError(t, err) require.EventuallyWithT(t, func(t *assert.CollectT) { - p, err := srv.GetAuthPreference(ctx) + // we need to wait for the in-memory copy of auth pref to be updated, which + // takes a bit longer than standard cache propagation. + rp, err := srv.GetReadOnlyAuthPreference(ctx) require.NoError(t, err) + p := rp.Clone() assert.Empty(t, cmp.Diff(&pref, &p)) }, 5*time.Second, 100*time.Millisecond) } diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 45477c62638cd..86eeedb1a7269 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -105,6 +105,7 @@ import ( "github.com/gravitational/teleport/lib/resourceusage" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/spacelift" "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/sshca" @@ -483,6 +484,20 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { return nil, trace.Wrap(err) } + _, cacheEnabled := as.getCache() + + // cluster config ttl cache *must* be set up after `opts` has been applied to the server because + // the Cache field starts off as a pointer to the local backend services and is only switched + // over to being a proper cache during option processing. + as.ReadOnlyCache, err = readonly.NewCache(readonly.CacheConfig{ + Upstream: as.Cache, + Disabled: !cacheEnabled, + ReloadOnErr: true, + }) + if err != nil { + return nil, trace.Wrap(err) + } + if as.ghaIDTokenValidator == nil { as.ghaIDTokenValidator = githubactions.NewIDTokenValidator( githubactions.IDTokenValidatorConfig{ @@ -798,6 +813,10 @@ var ( // successfully authenticated. An example would be creating objects based on the user. type LoginHook func(context.Context, types.User) error +// ReadOnlyCache is a type alias used to assist with embedding [readonly.Cache] in places +// where it would have a naming conflict with other types named Cache. +type ReadOnlyCache = readonly.Cache + // Server keeps the cluster together. It acts as a certificate authority (CA) for // a cluster and: // - generates the keypair for the node it's running on @@ -849,6 +868,11 @@ type Server struct { // method on Services instead. authclient.Cache + // ReadOnlyCache is a specialized cache that provides read-only shared references + // in certain performance-critical paths where deserialization/cloning may be too + // expensive at scale. + *ReadOnlyCache + // privateKey is used in tests to use pre-generated private keys privateKey []byte @@ -1837,7 +1861,8 @@ func (a *Server) GenerateHostCert(ctx context.Context, hostPublicKey []byte, hos func (a *Server) generateHostCert( ctx context.Context, p services.HostCertParams, ) ([]byte, error) { - authPref, err := a.GetAuthPreference(ctx) + + readOnlyAuthPref, err := a.GetReadOnlyAuthPreference(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -1860,7 +1885,7 @@ func (a *Server) generateHostCert( default: locks = []types.LockTarget{{ServerID: p.HostID}, {ServerID: HostFQDN(p.HostID, p.ClusterName)}} } - if lockErr := a.checkLockInForce(authPref.GetLockingMode(), + if lockErr := a.checkLockInForce(readOnlyAuthPref.GetLockingMode(), locks, ); lockErr != nil { return nil, trace.Wrap(lockErr) @@ -2028,11 +2053,11 @@ func (a *Server) GenerateOpenSSHCert(ctx context.Context, req *proto.OpenSSHCert return nil, trace.BadParameter("public key is empty") } if req.TTL == 0 { - cap, err := a.GetAuthPreference(ctx) + readOnlyAuthPref, err := a.GetReadOnlyAuthPreference(ctx) if err != nil { return nil, trace.BadParameter("cert request does not specify a TTL and the cluster_auth_preference is not available: %v", err) } - req.TTL = proto.Duration(cap.GetDefaultSessionTTL()) + req.TTL = proto.Duration(readOnlyAuthPref.GetDefaultSessionTTL()) } if req.TTL < 0 { return nil, trace.BadParameter("TTL must be positive") @@ -2458,13 +2483,13 @@ func (a *Server) AugmentContextUserCertificates( } // Verify locks right before we re-issue any certificates. - authPref, err := a.GetAuthPreference(ctx) + readOnlyAuthPref, err := a.GetReadOnlyAuthPreference(ctx) if err != nil { return nil, trace.Wrap(err) } if err := a.verifyLocksForUserCerts(verifyLocksForUserCertsReq{ checker: authCtx.Checker, - defaultMode: authPref.GetLockingMode(), + defaultMode: readOnlyAuthPref.GetLockingMode(), username: identity.Username, mfaVerified: identity.MFAVerified, activeAccessRequests: identity.ActiveRequests, @@ -2597,13 +2622,13 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. } // Reject the cert request if there is a matching lock in force. - authPref, err := a.GetAuthPreference(ctx) + readOnlyAuthPref, err := a.GetReadOnlyAuthPreference(ctx) if err != nil { return nil, trace.Wrap(err) } if err := a.verifyLocksForUserCerts(verifyLocksForUserCertsReq{ checker: req.checker, - defaultMode: authPref.GetLockingMode(), + defaultMode: readOnlyAuthPref.GetLockingMode(), username: req.user.GetName(), mfaVerified: req.mfaVerified, activeAccessRequests: req.activeRequests.AccessRequests, @@ -2632,7 +2657,7 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. var allowedLogins []string if req.ttl == 0 { - req.ttl = time.Duration(authPref.GetDefaultSessionTTL()) + req.ttl = time.Duration(readOnlyAuthPref.GetDefaultSessionTTL()) } // If the role TTL is ignored, do not restrict session TTL and allowed logins. @@ -2662,7 +2687,7 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. } attestedKeyPolicy := keys.PrivateKeyPolicyNone - requiredKeyPolicy, err := req.checker.PrivateKeyPolicy(authPref.GetPrivateKeyPolicy()) + requiredKeyPolicy, err := req.checker.PrivateKeyPolicy(readOnlyAuthPref.GetPrivateKeyPolicy()) if err != nil { return nil, trace.Wrap(err) } @@ -2684,7 +2709,7 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. } var validateSerialNumber bool - hksnv, err := authPref.GetHardwareKeySerialNumberValidation() + hksnv, err := readOnlyAuthPref.GetHardwareKeySerialNumberValidation() if err == nil { validateSerialNumber = hksnv.Enabled } @@ -3501,7 +3526,7 @@ func (a *Server) deleteMFADeviceSafely(ctx context.Context, user, deviceName str return nil, trace.Wrap(err) } - authPref, err := a.GetAuthPreference(ctx) + readOnlyAuthPref, err := a.GetReadOnlyAuthPreference(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -3549,7 +3574,7 @@ func (a *Server) deleteMFADeviceSafely(ctx context.Context, user, deviceName str // Prevent users from deleting their last device for clusters that require second factors. const minDevices = 1 - switch sf := authPref.GetSecondFactor(); sf { + switch sf := readOnlyAuthPref.GetSecondFactor(); sf { case constants.SecondFactorOff, constants.SecondFactorOptional: // MFA is not required, allow deletion case constants.SecondFactorOn: if knownDevices <= minDevices { @@ -3570,7 +3595,7 @@ func (a *Server) deleteMFADeviceSafely(ctx context.Context, user, deviceName str // It checks whether the credential to delete is a last passkey and whether // the user has other valid local credentials. canDeleteLastPasskey := func() (bool, error) { - if !authPref.GetAllowPasswordless() || numResidentKeys > 1 || !isResidentKey(deviceToDelete) { + if !readOnlyAuthPref.GetAllowPasswordless() || numResidentKeys > 1 || !isResidentKey(deviceToDelete) { return true, nil } @@ -3592,7 +3617,7 @@ func (a *Server) deleteMFADeviceSafely(ctx context.Context, user, deviceName str // Whether we take TOTPs into consideration or not depends on whether it's // enabled. - switch sf := authPref.GetSecondFactor(); sf { + switch sf := readOnlyAuthPref.GetSecondFactor(); sf { case constants.SecondFactorOTP, constants.SecondFactorOn, constants.SecondFactorOptional: if sfToCount[constants.SecondFactorOTP] >= 1 { return true, nil diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 103561dcc5ce3..bd277a94ee1e1 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -1550,7 +1550,7 @@ func (a *ServerWithRoles) GetSSHTargets(ctx context.Context, req *proto.GetSSHTa // try to detect case-insensitive routing setting, but default to false if we can't load // networking config (equivalent to proxy routing behavior). var caseInsensitiveRouting bool - if cfg, err := a.authServer.GetClusterNetworkingConfig(ctx); err == nil { + if cfg, err := a.authServer.GetReadOnlyClusterNetworkingConfig(ctx); err == nil { caseInsensitiveRouting = cfg.GetCaseInsensitiveRouting() } @@ -2911,11 +2911,11 @@ func getBotName(user types.User) string { func (a *ServerWithRoles) generateUserCerts(ctx context.Context, req proto.UserCertsRequest, opts ...certRequestOption) (*proto.Certs, error) { // Device trust: authorize device before issuing certificates. - authPref, err := a.authServer.GetAuthPreference(ctx) + readOnlyAuthPref, err := a.authServer.GetReadOnlyAuthPreference(ctx) if err != nil { return nil, trace.Wrap(err) } - if err := a.verifyUserDeviceForCertIssuance(req.Usage, authPref.GetDeviceTrust()); err != nil { + if err := a.verifyUserDeviceForCertIssuance(req.Usage, readOnlyAuthPref.GetDeviceTrust()); err != nil { return nil, trace.Wrap(err) } @@ -3013,7 +3013,7 @@ func (a *ServerWithRoles) generateUserCerts(ctx context.Context, req proto.UserC if err != nil { return nil, trace.Wrap(err) } - sessionTTL := roleSet.AdjustSessionTTL(authPref.GetDefaultSessionTTL().Duration()) + sessionTTL := roleSet.AdjustSessionTTL(readOnlyAuthPref.GetDefaultSessionTTL().Duration()) req.Expires = a.authServer.GetClock().Now().UTC().Add(sessionTTL) } else if req.Expires.After(sessionExpires) { // Standard user impersonation has an expiry limited to the expiry @@ -4373,8 +4373,11 @@ func (a *ServerWithRoles) GetAuthPreference(ctx context.Context) (types.AuthPref if err := a.action(apidefaults.Namespace, types.KindClusterAuthPreference, types.VerbRead); err != nil { return nil, trace.Wrap(err) } - - return a.authServer.GetAuthPreference(ctx) + cfg, err := a.authServer.GetReadOnlyAuthPreference(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + return cfg.Clone(), nil } func (a *ServerWithRoles) GetUIConfig(ctx context.Context) (types.UIConfig, error) { @@ -4552,7 +4555,11 @@ func (a *ServerWithRoles) GetClusterNetworkingConfig(ctx context.Context) (types return nil, trace.Wrap(err) } } - return a.authServer.GetClusterNetworkingConfig(ctx) + cfg, err := a.authServer.GetReadOnlyClusterNetworkingConfig(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + return cfg.Clone(), nil } // SetClusterNetworkingConfig sets cluster networking configuration. @@ -6644,12 +6651,12 @@ func (a *ServerWithRoles) CreateRegisterChallenge(ctx context.Context, req *prot // enforceGlobalModeTrustedDevice is used to enforce global device trust requirements // for key endpoints. func (a *ServerWithRoles) enforceGlobalModeTrustedDevice(ctx context.Context) error { - authPref, err := a.GetAuthPreference(ctx) + readOnlyAuthPref, err := a.authServer.GetReadOnlyAuthPreference(ctx) if err != nil { return trace.Wrap(err) } - err = dtauthz.VerifyTLSUser(authPref.GetDeviceTrust(), a.context.Identity.GetIdentity()) + err = dtauthz.VerifyTLSUser(readOnlyAuthPref.GetDeviceTrust(), a.context.Identity.GetIdentity()) return trace.Wrap(err) } diff --git a/lib/auth/clusterconfig/clusterconfigv1/service.go b/lib/auth/clusterconfig/clusterconfigv1/service.go index 3c55616867006..2a86a05a3777b 100644 --- a/lib/auth/clusterconfig/clusterconfigv1/service.go +++ b/lib/auth/clusterconfig/clusterconfigv1/service.go @@ -29,6 +29,7 @@ import ( dtconfig "github.com/gravitational/teleport/lib/devicetrust/config" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/services/readonly" ) // Cache is used by the [Service] to query cluster config resources. @@ -38,6 +39,13 @@ type Cache interface { GetSessionRecordingConfig(ctx context.Context) (types.SessionRecordingConfig, error) } +// ReadOnlyCache abstracts over the required methods of [readonly.Cache]. +type ReadOnlyCache interface { + GetReadOnlyAuthPreference(context.Context) (readonly.AuthPreference, error) + GetReadOnlyClusterNetworkingConfig(ctx context.Context) (readonly.ClusterNetworkingConfig, error) + GetReadOnlySessionRecordingConfig(ctx context.Context) (readonly.SessionRecordingConfig, error) +} + // Backend is used by the [Service] to mutate cluster config resources. type Backend interface { CreateAuthPreference(ctx context.Context, preference types.AuthPreference) (types.AuthPreference, error) @@ -55,11 +63,12 @@ type Backend interface { // ServiceConfig contain dependencies required to create a [Service]. type ServiceConfig struct { - Cache Cache - Backend Backend - Authorizer authz.Authorizer - Emitter apievents.Emitter - AccessGraph AccessGraphConfig + Cache Cache + Backend Backend + Authorizer authz.Authorizer + Emitter apievents.Emitter + AccessGraph AccessGraphConfig + ReadOnlyCache ReadOnlyCache } // AccessGraphConfig contains the configuration about the access graph service @@ -81,11 +90,12 @@ type AccessGraphConfig struct { type Service struct { clusterconfigpb.UnimplementedClusterConfigServiceServer - cache Cache - backend Backend - authorizer authz.Authorizer - emitter apievents.Emitter - accessGraph AccessGraphConfig + cache Cache + backend Backend + authorizer authz.Authorizer + emitter apievents.Emitter + accessGraph AccessGraphConfig + readOnlyCache ReadOnlyCache } // NewService validates the provided configuration and returns a [Service]. @@ -101,7 +111,17 @@ func NewService(cfg ServiceConfig) (*Service, error) { return nil, trace.BadParameter("emitter is required") } - return &Service{cache: cfg.Cache, backend: cfg.Backend, authorizer: cfg.Authorizer, emitter: cfg.Emitter, accessGraph: cfg.AccessGraph}, nil + if cfg.ReadOnlyCache == nil { + readOnlyCache, err := readonly.NewCache(readonly.CacheConfig{ + Upstream: cfg.Cache, + }) + if err != nil { + return nil, trace.Wrap(err) + } + cfg.ReadOnlyCache = readOnlyCache + } + + return &Service{cache: cfg.Cache, backend: cfg.Backend, authorizer: cfg.Authorizer, emitter: cfg.Emitter, accessGraph: cfg.AccessGraph, readOnlyCache: cfg.ReadOnlyCache}, nil } // GetAuthPreference returns the locally cached auth preference. @@ -115,12 +135,12 @@ func (s *Service) GetAuthPreference(ctx context.Context, _ *clusterconfigpb.GetA return nil, trace.Wrap(err) } - pref, err := s.cache.GetAuthPreference(ctx) + pref, err := s.readOnlyCache.GetReadOnlyAuthPreference(ctx) if err != nil { return nil, trace.Wrap(err) } - authPrefV2, ok := pref.(*types.AuthPreferenceV2) + authPrefV2, ok := pref.Clone().(*types.AuthPreferenceV2) if !ok { return nil, trace.Wrap(trace.BadParameter("unexpected auth preference type %T (expected %T)", pref, authPrefV2)) } @@ -358,12 +378,12 @@ func (s *Service) GetClusterNetworkingConfig(ctx context.Context, _ *clusterconf } } - netConfig, err := s.cache.GetClusterNetworkingConfig(ctx) + netConfig, err := s.readOnlyCache.GetReadOnlyClusterNetworkingConfig(ctx) if err != nil { return nil, trace.Wrap(err) } - cfgV2, ok := netConfig.(*types.ClusterNetworkingConfigV2) + cfgV2, ok := netConfig.Clone().(*types.ClusterNetworkingConfigV2) if !ok { return nil, trace.Wrap(trace.BadParameter("unexpected cluster networking config type %T (expected %T)", netConfig, cfgV2)) } @@ -658,12 +678,12 @@ func (s *Service) GetSessionRecordingConfig(ctx context.Context, _ *clusterconfi } } - netConfig, err := s.cache.GetSessionRecordingConfig(ctx) + netConfig, err := s.readOnlyCache.GetReadOnlySessionRecordingConfig(ctx) if err != nil { return nil, trace.Wrap(err) } - cfgV2, ok := netConfig.(*types.SessionRecordingConfigV2) + cfgV2, ok := netConfig.Clone().(*types.SessionRecordingConfigV2) if !ok { return nil, trace.Wrap(trace.BadParameter("unexpected session recording config type %T (expected %T)", netConfig, cfgV2)) } diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 2adde2520ca50..ddb190da0e40f 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5974,6 +5974,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { Address: cfg.APIConfig.AccessGraph.Address, Insecure: cfg.APIConfig.AccessGraph.Insecure, }, + ReadOnlyCache: cfg.AuthServer.ReadOnlyCache, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 54570e5747801..41795c6538f9e 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -451,9 +451,10 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { srv.AuthServer.SetHeadlessAuthenticationWatcher(headlessAuthenticationWatcher) srv.Authorizer, err = authz.NewAuthorizer(authz.AuthorizerOpts{ - ClusterName: srv.ClusterName, - AccessPoint: srv.AuthServer, - LockWatcher: srv.LockWatcher, + ClusterName: srv.ClusterName, + AccessPoint: srv.AuthServer, + ReadOnlyAccessPoint: srv.AuthServer.ReadOnlyCache, + LockWatcher: srv.LockWatcher, // AuthServer does explicit device authorization checks. DeviceAuthorization: authz.DeviceAuthorizationOpts{ DisableGlobalMode: true, diff --git a/lib/auth/sessions.go b/lib/auth/sessions.go index 7098b76d6154a..ebcbdb38fd663 100644 --- a/lib/auth/sessions.go +++ b/lib/auth/sessions.go @@ -122,7 +122,7 @@ func (a *Server) NewWebSession(ctx context.Context, req NewWebSessionRequest) (t return nil, trace.Wrap(err) } - netCfg, err := a.GetClusterNetworkingConfig(ctx) + idleTimeout, err := a.getWebIdleTimeout(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -187,7 +187,7 @@ func (a *Server) NewWebSession(ctx context.Context, req NewWebSessionRequest) (t BearerToken: bearerToken, BearerTokenExpires: startTime.UTC().Add(bearerTokenTTL), LoginTime: req.LoginTime, - IdleTimeout: types.Duration(netCfg.GetWebIdleTimeout()), + IdleTimeout: types.Duration(idleTimeout), } UserLoginCount.Inc() @@ -198,6 +198,14 @@ func (a *Server) NewWebSession(ctx context.Context, req NewWebSessionRequest) (t return sess, nil } +func (a *Server) getWebIdleTimeout(ctx context.Context) (time.Duration, error) { + netCfg, err := a.GetReadOnlyClusterNetworkingConfig(ctx) + if err != nil { + return 0, trace.Wrap(err) + } + return netCfg.GetWebIdleTimeout(), nil +} + func (a *Server) upsertWebSession(ctx context.Context, session types.WebSession) error { if err := a.WebSessions().Upsert(ctx, session); err != nil { return trace.Wrap(err) diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index b8a157361e4cd..d3d73a15a48d9 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/api/utils/keys" dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" ) @@ -74,11 +75,12 @@ type DeviceAuthorizationOpts struct { // AuthorizerOpts holds creation options for [NewAuthorizer]. type AuthorizerOpts struct { - ClusterName string - AccessPoint AuthorizerAccessPoint - MFAAuthenticator MFAAuthenticator - LockWatcher *services.LockWatcher - Logger logrus.FieldLogger + ClusterName string + AccessPoint AuthorizerAccessPoint + ReadOnlyAccessPoint ReadOnlyAuthorizerAccessPoint + MFAAuthenticator MFAAuthenticator + LockWatcher *services.LockWatcher + Logger logrus.FieldLogger // DeviceAuthorization holds Device Trust authorization options. // @@ -86,6 +88,9 @@ type AuthorizerOpts struct { // support device trust to disable it. // Most services should not set this field. DeviceAuthorization DeviceAuthorizationOpts + // PermitCaching opts into the authorizer setting up its own internal + // caching when ReadOnlyAccessPoint is not provided. + PermitCaching bool } // NewAuthorizer returns new authorizer using backends @@ -101,9 +106,25 @@ func NewAuthorizer(opts AuthorizerOpts) (Authorizer, error) { logger = logrus.WithFields(logrus.Fields{teleport.ComponentKey: "authorizer"}) } + if opts.ReadOnlyAccessPoint == nil { + // we create the read-only access point if not provided in order to keep our + // code paths simpler, but the it will not perform ttl-caching unless opts.PermitCaching + // was set. This is necessary because the vast majority of our test coverage + // cannot handle caching, and will fail if caching is enabled. + var err error + opts.ReadOnlyAccessPoint, err = readonly.NewCache(readonly.CacheConfig{ + Upstream: opts.AccessPoint, + Disabled: !opts.PermitCaching, + }) + if err != nil { + return nil, trace.Wrap(err) + } + } + return &authorizer{ clusterName: opts.ClusterName, accessPoint: opts.AccessPoint, + readOnlyAccessPoint: opts.ReadOnlyAccessPoint, mfaAuthenticator: opts.MFAAuthenticator, lockWatcher: opts.LockWatcher, logger: logger, @@ -156,6 +177,20 @@ type AuthorizerAccessPoint interface { GetSessionRecordingConfig(ctx context.Context) (types.SessionRecordingConfig, error) } +// ReadOnlyAuthorizerAccessPoint is an additional optional access point interface that permits +// optimized access-control checks by sharing references to frequently accessed configuration +// objects across goroutines. +type ReadOnlyAuthorizerAccessPoint interface { + // GetReadOnlyAuthPreference returns the cluster authentication configuration. + GetReadOnlyAuthPreference(ctx context.Context) (readonly.AuthPreference, error) + + // GetReadOnlyClusterNetworkingConfig returns cluster networking configuration. + GetReadOnlyClusterNetworkingConfig(ctx context.Context) (readonly.ClusterNetworkingConfig, error) + + // GetReadOnlySessionRecordingConfig returns session recording configuration. + GetReadOnlySessionRecordingConfig(ctx context.Context) (readonly.SessionRecordingConfig, error) +} + // MFAAuthenticator authenticates MFA responses. type MFAAuthenticator interface { // ValidateMFAAuthResponse validates an MFA challenge response. @@ -175,11 +210,12 @@ type MFAAuthData struct { // authorizer creates new local authorizer type authorizer struct { - clusterName string - accessPoint AuthorizerAccessPoint - mfaAuthenticator MFAAuthenticator - lockWatcher *services.LockWatcher - logger logrus.FieldLogger + clusterName string + accessPoint AuthorizerAccessPoint + readOnlyAccessPoint ReadOnlyAuthorizerAccessPoint + mfaAuthenticator MFAAuthenticator + lockWatcher *services.LockWatcher + logger logrus.FieldLogger disableGlobalDeviceMode bool disableRoleDeviceMode bool @@ -308,7 +344,7 @@ func (c *Context) WithExtraRoles(access services.RoleGetter, clusterName string, // GetAccessState returns the AccessState based on the underlying // [services.AccessChecker] and [tlsca.Identity]. -func (c *Context) GetAccessState(authPref types.AuthPreference) services.AccessState { +func (c *Context) GetAccessState(authPref readonly.AuthPreference) services.AccessState { state := c.Checker.GetAccessState(authPref) identity := c.Identity.GetIdentity() @@ -327,7 +363,7 @@ func (c *Context) GetAccessState(authPref types.AuthPreference) services.AccessS // based on whether a connection is set to disconnect on cert expiry, and whether // the cert is a short-lived (<1m) one issued for an MFA verified session. If the session // doesn't need to be disconnected on cert expiry, it will return a zero [time.Time]. -func (c *Context) GetDisconnectCertExpiry(authPref types.AuthPreference) time.Time { +func (c *Context) GetDisconnectCertExpiry(authPref readonly.AuthPreference) time.Time { // In the case where both disconnect_expired_cert and require_session_mfa are enabled, // the PreviousIdentityExpires value of the certificate will be used, which is the // expiry of the certificate used to issue the short-lived MFA verified certificate. @@ -380,7 +416,7 @@ func (a *authorizer) Authorize(ctx context.Context) (authCtx *Context, err error } // Enforce applicable locks. - authPref, err := a.accessPoint.GetAuthPreference(ctx) + authPref, err := a.readOnlyAccessPoint.GetReadOnlyAuthPreference(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -409,7 +445,7 @@ func (a *authorizer) Authorize(ctx context.Context) (authCtx *Context, err error return authContext, nil } -func (a *authorizer) enforcePrivateKeyPolicy(ctx context.Context, authContext *Context, authPref types.AuthPreference) error { +func (a *authorizer) enforcePrivateKeyPolicy(ctx context.Context, authContext *Context, authPref readonly.AuthPreference) error { switch authContext.Identity.(type) { case BuiltinRole, RemoteBuiltinRole: // built in roles do not need to pass private key policies @@ -481,7 +517,7 @@ func (a *authorizer) isAdminActionAuthorizationRequired(ctx context.Context, aut return false, nil } - authpref, err := a.accessPoint.GetAuthPreference(ctx) + authpref, err := a.readOnlyAccessPoint.GetReadOnlyAuthPreference(ctx) if err != nil { return false, trace.Wrap(err) } @@ -730,7 +766,7 @@ func (a *authorizer) authorizeRemoteUser(ctx context.Context, u RemoteUser) (*Co // authorizeBuiltinRole authorizes builtin role func (a *authorizer) authorizeBuiltinRole(ctx context.Context, r BuiltinRole) (*Context, error) { - recConfig, err := a.accessPoint.GetSessionRecordingConfig(ctx) + recConfig, err := a.readOnlyAccessPoint.GetReadOnlySessionRecordingConfig(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -904,7 +940,7 @@ func roleSpecForProxy(clusterName string) types.RoleSpecV6 { } // RoleSetForBuiltinRoles returns RoleSet for embedded builtin role -func RoleSetForBuiltinRoles(clusterName string, recConfig types.SessionRecordingConfig, roles ...types.SystemRole) (services.RoleSet, error) { +func RoleSetForBuiltinRoles(clusterName string, recConfig readonly.SessionRecordingConfig, roles ...types.SystemRole) (services.RoleSet, error) { var definitions []types.Role for _, role := range roles { rd, err := definitionForBuiltinRole(clusterName, recConfig, role) @@ -917,7 +953,7 @@ func RoleSetForBuiltinRoles(clusterName string, recConfig types.SessionRecording } // definitionForBuiltinRole constructs the appropriate role definition for a given builtin role. -func definitionForBuiltinRole(clusterName string, recConfig types.SessionRecordingConfig, role types.SystemRole) (types.Role, error) { +func definitionForBuiltinRole(clusterName string, recConfig readonly.SessionRecordingConfig, role types.SystemRole) (types.Role, error) { switch role { case types.RoleAuth: return services.RoleFromSpec( @@ -1231,7 +1267,7 @@ func definitionForBuiltinRole(clusterName string, recConfig types.SessionRecordi } // ContextForBuiltinRole returns a context with the builtin role information embedded. -func ContextForBuiltinRole(r BuiltinRole, recConfig types.SessionRecordingConfig) (*Context, error) { +func ContextForBuiltinRole(r BuiltinRole, recConfig readonly.SessionRecordingConfig) (*Context, error) { var systemRoles []types.SystemRole if r.Role == types.RoleInstance { // instance certs encode multiple system roles in a separate field diff --git a/lib/authz/permissions_test.go b/lib/authz/permissions_test.go index a1d559fb65aa6..0ce6edeca953f 100644 --- a/lib/authz/permissions_test.go +++ b/lib/authz/permissions_test.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -1110,7 +1111,7 @@ type fakeCtxChecker struct { state services.AccessState } -func (c *fakeCtxChecker) GetAccessState(_ types.AuthPreference) services.AccessState { +func (c *fakeCtxChecker) GetAccessState(_ readonly.AuthPreference) services.AccessState { return c.state } diff --git a/lib/service/service.go b/lib/service/service.go index 80f69e7297150..75771fd9ffa10 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2162,11 +2162,12 @@ func (process *TeleportProcess) initAuthService() error { // each serving requests for a "role" which is assigned to every connected // client based on their certificate (user, server, admin, etc) authorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{ - ClusterName: clusterName, - AccessPoint: authServer, - MFAAuthenticator: authServer, - LockWatcher: lockWatcher, - Logger: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentAuth, process.id)), + ClusterName: clusterName, + AccessPoint: authServer, + ReadOnlyAccessPoint: authServer, + MFAAuthenticator: authServer, + LockWatcher: lockWatcher, + Logger: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentAuth, process.id)), // Auth Server does explicit device authorization. // Various Auth APIs must allow access to unauthorized devices, otherwise it // is not possible to acquire device-aware certificates in the first place. @@ -4371,10 +4372,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } authorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{ - ClusterName: cn.GetClusterName(), - AccessPoint: accessPoint, - LockWatcher: lockWatcher, - Logger: process.log, + ClusterName: cn.GetClusterName(), + AccessPoint: accessPoint, + LockWatcher: lockWatcher, + Logger: process.log, + PermitCaching: process.Config.CachePolicy.Enabled, }) if err != nil { return trace.Wrap(err) @@ -4635,10 +4637,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } authorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{ - ClusterName: clusterName, - AccessPoint: accessPoint, - LockWatcher: lockWatcher, - Logger: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), + ClusterName: clusterName, + AccessPoint: accessPoint, + LockWatcher: lockWatcher, + Logger: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), + PermitCaching: process.Config.CachePolicy.Enabled, }) if err != nil { return trace.Wrap(err) @@ -4799,10 +4802,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { var kubeServer *kubeproxy.TLSServer if listeners.kube != nil && !process.Config.Proxy.DisableReverseTunnel { authorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{ - ClusterName: clusterName, - AccessPoint: accessPoint, - LockWatcher: lockWatcher, - Logger: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), + ClusterName: clusterName, + AccessPoint: accessPoint, + LockWatcher: lockWatcher, + Logger: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), + PermitCaching: process.Config.CachePolicy.Enabled, }) if err != nil { return trace.Wrap(err) @@ -4900,10 +4904,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // framework. if (!listeners.db.Empty() || alpnRouter != nil) && !process.Config.Proxy.DisableReverseTunnel { authorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{ - ClusterName: clusterName, - AccessPoint: accessPoint, - LockWatcher: lockWatcher, - Logger: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), + ClusterName: clusterName, + AccessPoint: accessPoint, + LockWatcher: lockWatcher, + Logger: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), + PermitCaching: process.Config.CachePolicy.Enabled, }) if err != nil { return trace.Wrap(err) @@ -5757,6 +5762,7 @@ func (process *TeleportProcess) initApps() { // settings to be applied. DisableGlobalMode: true, }, + PermitCaching: process.Config.CachePolicy.Enabled, }) if err != nil { return trace.Wrap(err) @@ -6424,6 +6430,7 @@ func (process *TeleportProcess) initSecureGRPCServer(cfg initSecureGRPCServerCfg Logger: process.log.WithFields(logrus.Fields{ teleport.ComponentKey: teleport.Component(teleport.ComponentProxySecureGRPC, process.id), }), + PermitCaching: process.Config.CachePolicy.Enabled, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/services/access_checker.go b/lib/services/access_checker.go index fc27954e87f6a..4e12ff8e53fb2 100644 --- a/lib/services/access_checker.go +++ b/lib/services/access_checker.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -82,7 +83,7 @@ type AccessChecker interface { // CheckAccessToSAMLIdP checks access to the SAML IdP. // //nolint:revive // Because we want this to be IdP. - CheckAccessToSAMLIdP(types.AuthPreference) error + CheckAccessToSAMLIdP(readonly.AuthPreference) error // AdjustSessionTTL will reduce the requested ttl to lowest max allowed TTL // for this role set, otherwise it returns ttl unchanged @@ -224,7 +225,7 @@ type AccessChecker interface { // GetAccessState returns the AccessState for the user given their roles, the // cluster auth preference, and whether MFA and the user's device were // verified. - GetAccessState(authPref types.AuthPreference) AccessState + GetAccessState(authPref readonly.AuthPreference) AccessState // PrivateKeyPolicy returns the enforced private key policy for this role set, // or the provided defaultPolicy - whichever is stricter. PrivateKeyPolicy(defaultPolicy keys.PrivateKeyPolicy) (keys.PrivateKeyPolicy, error) diff --git a/lib/services/readonly/cache.go b/lib/services/readonly/cache.go new file mode 100644 index 0000000000000..503018f75b553 --- /dev/null +++ b/lib/services/readonly/cache.go @@ -0,0 +1,125 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package readonly + +import ( + "context" + "time" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" +) + +// Upstream represents the upstream data source that the cache will fetch data from. +type Upstream interface { + GetAuthPreference(ctx context.Context) (types.AuthPreference, error) + GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) + GetSessionRecordingConfig(ctx context.Context) (types.SessionRecordingConfig, error) +} + +// Cache provides simple ttl-based in-memory caching for select resources that are frequently accessed +// on hot paths. All resources are returned as read-only shared references. +type Cache struct { + cfg CacheConfig + ttlCache *utils.FnCache +} + +// CacheConfig holds configuration options for the cache. +type CacheConfig struct { + // Upstream is the upstream data source that the cache will fetch data from. + Upstream Upstream + // TTL is the time-to-live for each cache entry. + TTL time.Duration + // Disabled is a flag that can be used to disable ttl-caching. Useful in tests that + // don't play nicely with stale data. + Disabled bool + // ReloadOnErr controls wether or not the underlying ttl cache will hold onto error + // entries for the full TTL, or reload error entries immediately. As a general rule, + // this value aught to be true on auth servers and false on agents, though in practice + // the difference is small unless an unusually long TTL is used. + ReloadOnErr bool +} + +// NewCache sets up a new cache instance with the provided configuration. +func NewCache(cfg CacheConfig) (*Cache, error) { + if cfg.Upstream == nil { + return nil, trace.BadParameter("missing upstream data source for readonly cache") + } + if cfg.TTL == 0 { + cfg.TTL = time.Millisecond * 1600 + } + + ttlCache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: cfg.TTL, + ReloadOnErr: cfg.ReloadOnErr, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return &Cache{ + cfg: cfg, + ttlCache: ttlCache, + }, nil +} + +type ttlCacheKey struct { + kind string +} + +// GetReadOnlyAuthPreference returns a read-only shared reference to the auth preference resource. +func (c *Cache) GetReadOnlyAuthPreference(ctx context.Context) (AuthPreference, error) { + if c.cfg.Disabled { + cfg, err := c.cfg.Upstream.GetAuthPreference(ctx) + return sealAuthPreference(cfg), trace.Wrap(err) + } + cfg, err := utils.FnCacheGet(ctx, c.ttlCache, ttlCacheKey{kind: types.KindClusterAuthPreference}, func(ctx context.Context) (AuthPreference, error) { + cfg, err := c.cfg.Upstream.GetAuthPreference(ctx) + return sealAuthPreference(cfg), trace.Wrap(err) + }) + return cfg, trace.Wrap(err) +} + +// GetReadOnlyClusterNetworkingConfig returns a read-only shared reference to the cluster networking config resource. +func (c *Cache) GetReadOnlyClusterNetworkingConfig(ctx context.Context) (ClusterNetworkingConfig, error) { + if c.cfg.Disabled { + cfg, err := c.cfg.Upstream.GetClusterNetworkingConfig(ctx) + return sealClusterNetworkingConfig(cfg), trace.Wrap(err) + } + cfg, err := utils.FnCacheGet(ctx, c.ttlCache, ttlCacheKey{kind: types.KindClusterNetworkingConfig}, func(ctx context.Context) (ClusterNetworkingConfig, error) { + cfg, err := c.cfg.Upstream.GetClusterNetworkingConfig(ctx) + return sealClusterNetworkingConfig(cfg), trace.Wrap(err) + }) + return cfg, trace.Wrap(err) +} + +// GetReadOnlySessionRecordingConfig returns a read-only shared reference to the session recording config resource. +func (c *Cache) GetReadOnlySessionRecordingConfig(ctx context.Context) (SessionRecordingConfig, error) { + if c.cfg.Disabled { + cfg, err := c.cfg.Upstream.GetSessionRecordingConfig(ctx) + return sealSessionRecordingConfig(cfg), trace.Wrap(err) + } + cfg, err := utils.FnCacheGet(ctx, c.ttlCache, ttlCacheKey{kind: types.KindSessionRecordingConfig}, func(ctx context.Context) (SessionRecordingConfig, error) { + cfg, err := c.cfg.Upstream.GetSessionRecordingConfig(ctx) + return sealSessionRecordingConfig(cfg), trace.Wrap(err) + }) + return cfg, trace.Wrap(err) +} diff --git a/lib/services/readonly/readonly.go b/lib/services/readonly/readonly.go new file mode 100644 index 0000000000000..71462dd9f01b7 --- /dev/null +++ b/lib/services/readonly/readonly.go @@ -0,0 +1,105 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package readonly + +import ( + "time" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" +) + +// NOTE: is best to avoid importing anything from lib other than lib/utils in this package in +// order to ensure that we can import it anywhere api/types is being used. + +// AuthPreference is a read-only subset of types.AuthPreference used on certain hot paths +// to ensure that we do not modify the underlying AuthPreference as it may be shared across +// multiple goroutines. +type AuthPreference interface { + GetSecondFactor() constants.SecondFactorType + GetDisconnectExpiredCert() bool + GetLockingMode() constants.LockingMode + GetDeviceTrust() *types.DeviceTrust + GetPrivateKeyPolicy() keys.PrivateKeyPolicy + IsAdminActionMFAEnforced() bool + GetRequireMFAType() types.RequireMFAType + IsSAMLIdPEnabled() bool + GetDefaultSessionTTL() types.Duration + GetHardwareKeySerialNumberValidation() (*types.HardwareKeySerialNumberValidation, error) + GetAllowPasswordless() bool + Clone() types.AuthPreference +} + +type sealedAuthPreference struct { + AuthPreference +} + +// sealAuthPreference returns a read-only version of the AuthPreference. +func sealAuthPreference(p types.AuthPreference) AuthPreference { + if p == nil { + // preserving nils simplifies error flow-control + return nil + } + return sealedAuthPreference{AuthPreference: p} +} + +// ClusterNetworkingConfig is a read-only subset of types.ClusterNetworkingConfig used on certain hot paths +// to ensure that we do not modify the underlying ClusterNetworkingConfig as it may be shared across +// multiple goroutines. +type ClusterNetworkingConfig interface { + GetCaseInsensitiveRouting() bool + GetWebIdleTimeout() time.Duration + Clone() types.ClusterNetworkingConfig +} + +type sealedClusterNetworkingConfig struct { + ClusterNetworkingConfig +} + +// sealClusterNetworkingConfig returns a read-only version of the ClusterNetworkingConfig. +func sealClusterNetworkingConfig(c ClusterNetworkingConfig) ClusterNetworkingConfig { + if c == nil { + // preserving nils simplifies error flow-control + return nil + } + return sealedClusterNetworkingConfig{ClusterNetworkingConfig: c} +} + +// SessionRecordingConfig is a read-only subset of types.SessionRecordingConfig used on certain hot paths +// to ensure that we do not modify the underlying SessionRecordingConfig as it may be shared across +// multiple goroutines. +type SessionRecordingConfig interface { + GetMode() string + GetProxyChecksHostKeys() bool + Clone() types.SessionRecordingConfig +} + +type sealedSessionRecordingConfig struct { + SessionRecordingConfig +} + +// sealSessionRecordingConfig returns a read-only version of the SessionRecordingConfig. +func sealSessionRecordingConfig(c SessionRecordingConfig) SessionRecordingConfig { + if c == nil { + // preserving nils simplifies error flow-control + return nil + } + return sealedSessionRecordingConfig{SessionRecordingConfig: c} +} diff --git a/lib/services/readonly/readonly_test.go b/lib/services/readonly/readonly_test.go new file mode 100644 index 0000000000000..8876f27a7b64b --- /dev/null +++ b/lib/services/readonly/readonly_test.go @@ -0,0 +1,148 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package readonly + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" +) + +type testUpstream struct { + auth types.AuthPreference + networking types.ClusterNetworkingConfig + recording types.SessionRecordingConfig +} + +func (u *testUpstream) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { + return u.auth.Clone(), nil +} + +func (u *testUpstream) GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) { + return u.networking.Clone(), nil +} + +func (u *testUpstream) GetSessionRecordingConfig(ctx context.Context) (types.SessionRecordingConfig, error) { + return u.recording.Clone(), nil +} + +// TestAuthPreference tests the GetReadOnlyAuthPreference method and verifies the read-only protections +// on the returned resource. +func TestAuthPreference(t *testing.T) { + upstreamCfg, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{}) + require.NoError(t, err) + + // Create a new cache instance. + cache, err := NewCache(CacheConfig{ + Upstream: &testUpstream{ + auth: upstreamCfg, + }, + TTL: time.Hour, + }) + require.NoError(t, err) + + // Get the auth preference resource. + authPref, err := cache.GetReadOnlyAuthPreference(context.Background()) + require.NoError(t, err) + + // Verify that the auth preference resource cannot be cast back to a write-supporting interface. + _, ok := authPref.(types.AuthPreference) + require.False(t, ok) + + authPref2, err := cache.GetReadOnlyAuthPreference(context.Background()) + require.NoError(t, err) + + // verify pointer equality (i.e. that subsequent reads return the same shared resource). + require.True(t, pointersEqual(authPref, authPref2)) +} + +func TestClusterNetworkingConfig(t *testing.T) { + // Create a new cache instance. + cache, err := NewCache(CacheConfig{ + Upstream: &testUpstream{ + networking: types.DefaultClusterNetworkingConfig(), + }, + TTL: time.Hour, + }) + require.NoError(t, err) + + // Get the cluster networking config resource. + networking, err := cache.GetReadOnlyClusterNetworkingConfig(context.Background()) + require.NoError(t, err) + + // Verify that the cluster networking config resource cannot be cast back to a write-supporting interface. + _, ok := networking.(types.ClusterNetworkingConfig) + require.False(t, ok) + + networking2, err := cache.GetReadOnlyClusterNetworkingConfig(context.Background()) + require.NoError(t, err) + + // verify pointer equality (i.e. that subsequent reads return the same shared resource). + require.True(t, pointersEqual(networking, networking2)) +} + +func TestSessionRecordingConfig(t *testing.T) { + // Create a new cache instance. + cache, err := NewCache(CacheConfig{ + Upstream: &testUpstream{ + recording: types.DefaultSessionRecordingConfig(), + }, + TTL: time.Hour, + }) + require.NoError(t, err) + + // Get the session recording config resource. + recording, err := cache.GetReadOnlySessionRecordingConfig(context.Background()) + require.NoError(t, err) + + // Verify that the session recording config resource cannot be cast back to a write-supporting interface. + _, ok := recording.(types.SessionRecordingConfig) + require.False(t, ok) + + recording2, err := cache.GetReadOnlySessionRecordingConfig(context.Background()) + require.NoError(t, err) + + // verify pointer equality (i.e. that subsequent reads return the same shared resource). + require.True(t, pointersEqual(recording, recording2)) +} + +// TestCloneBreaksEquality tests that cloning a resource breaks equality with the original resource +// (this is a sanity-check to make sure that the other tests in this package work since they rely upon +// cloned resources being distinct from the original in terms of interface equality). +func TestCloneBreaksEquality(t *testing.T) { + authPref, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{}) + require.NoError(t, err) + require.False(t, pointersEqual(authPref, authPref.Clone())) + + networking := types.DefaultClusterNetworkingConfig() + require.False(t, pointersEqual(networking, networking.Clone())) + + recording := types.DefaultSessionRecordingConfig() + require.False(t, pointersEqual(recording, recording.Clone())) +} + +// pointersEqual is a helper function that compares two pointers for equality. used to improve readability +// and avoid incorrect lints. +func pointersEqual(a, b interface{}) bool { + return a == b +} diff --git a/lib/services/role.go b/lib/services/role.go index 20abf5e11be0a..83d67a09e547c 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -46,6 +46,7 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" @@ -1197,7 +1198,7 @@ func (set RoleSet) PinSourceIP() bool { // GetAccessState returns the AccessState, setting [AccessState.MFARequired] // according to the user's roles and cluster auth preference. -func (set RoleSet) GetAccessState(authPref types.AuthPreference) AccessState { +func (set RoleSet) GetAccessState(authPref readonly.AuthPreference) AccessState { return AccessState{ MFARequired: set.getMFARequired(authPref.GetRequireMFAType()), // We don't set EnableDeviceVerification here, as both it and DeviceVerified @@ -1524,7 +1525,7 @@ func (set RoleSet) CheckGCPServiceAccounts(ttl time.Duration, overrideTTL bool) // CheckAccessToSAMLIdP checks access to the SAML IdP. // //nolint:revive // Because we want this to be IdP. -func (set RoleSet) CheckAccessToSAMLIdP(authPref types.AuthPreference) error { +func (set RoleSet) CheckAccessToSAMLIdP(authPref readonly.AuthPreference) error { if authPref != nil { if !authPref.IsSAMLIdPEnabled() { return trace.AccessDenied("SAML IdP is disabled at the cluster level") diff --git a/lib/srv/db/common/session.go b/lib/srv/db/common/session.go index b70b0f4c9b135..f3593f8a2f313 100644 --- a/lib/srv/db/common/session.go +++ b/lib/srv/db/common/session.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/lib/authz" dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" ) @@ -73,7 +74,7 @@ func (c *Session) String() string { // GetAccessState returns the AccessState based on the underlying // [services.AccessChecker] and [tlsca.Identity]. -func (c *Session) GetAccessState(authPref types.AuthPreference) services.AccessState { +func (c *Session) GetAccessState(authPref readonly.AuthPreference) services.AccessState { state := c.Checker.GetAccessState(authPref) state.MFAVerified = c.Identity.IsMFAVerified() state.EnableDeviceVerification = true diff --git a/lib/srv/db/common/session_test.go b/lib/srv/db/common/session_test.go index ccd5045a24353..d5c42baf78b52 100644 --- a/lib/srv/db/common/session_test.go +++ b/lib/srv/db/common/session_test.go @@ -26,6 +26,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/tlsca" ) @@ -94,6 +95,6 @@ type fakeAccessChecker struct { services.AccessChecker } -func (c *fakeAccessChecker) GetAccessState(authPref types.AuthPreference) services.AccessState { +func (c *fakeAccessChecker) GetAccessState(authPref readonly.AuthPreference) services.AccessState { return services.AccessState{} } diff --git a/lib/srv/db/sqlserver/engine_test.go b/lib/srv/db/sqlserver/engine_test.go index ba45897c00b76..3e9b7b42cf5dc 100644 --- a/lib/srv/db/sqlserver/engine_test.go +++ b/lib/srv/db/sqlserver/engine_test.go @@ -37,6 +37,7 @@ import ( libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/sqlserver/protocol" "github.com/gravitational/teleport/lib/srv/db/sqlserver/protocol/fixtures" @@ -425,7 +426,7 @@ func (m *mockChecker) CheckAccess(r services.AccessCheckable, state services.Acc return nil } -func (m *mockChecker) GetAccessState(authPref types.AuthPreference) services.AccessState { +func (m *mockChecker) GetAccessState(authPref readonly.AuthPreference) services.AccessState { if authPref.GetRequireMFAType().IsSessionMFARequired() { return services.AccessState{ MFARequired: services.MFARequiredAlways,