diff --git a/docs/pages/admin-guides/api/rbac.mdx b/docs/pages/admin-guides/api/rbac.mdx index 83b9ae16edc0b..f292d19a1b63e 100644 --- a/docs/pages/admin-guides/api/rbac.mdx +++ b/docs/pages/admin-guides/api/rbac.mdx @@ -859,7 +859,11 @@ spec: enabled: true max_session_ttl: 30h0m0s pin_source_ip: false - port_forwarding: true + ssh_port_forwarding: + remote: + enabled: true + local: + enabled: true record_session: default: best_effort desktop: true @@ -906,7 +910,11 @@ spec: enabled: true max_session_ttl: 30h0m0s pin_source_ip: false - port_forwarding: true + ssh_port_forwarding: + remote: + enabled: true + local: + enabled: true record_session: default: best_effort desktop: true diff --git a/docs/pages/enroll-resources/server-access/rbac.mdx b/docs/pages/enroll-resources/server-access/rbac.mdx index 6cc43636c5d27..036a91e7aa642 100644 --- a/docs/pages/enroll-resources/server-access/rbac.mdx +++ b/docs/pages/enroll-resources/server-access/rbac.mdx @@ -135,8 +135,18 @@ spec: create_host_user_mode: keep # forward_agent controls whether SSH agent forwarding is allowed forward_agent: true - # port_forwarding controls whether TCP port forwarding is allowed for SSH - port_forwarding: true + # ssh_port_forwarding controls which TCP port forwarding modes are allowed over SSH. This + # replaces the deprecated port_forwarding field, which did not differentiate between remote + # and local port forwarding modes. If any existing role allows forwarding by enabling the + # legacy port_forwarding field, the forwarding controls configured in ssh_port_forwarding + # are ignored. + ssh_port_forwarding: + # configures remote port forwarding behavior + remote: + enabled: true + # configures local port forwarding behavior + local: + enabled: true # ssh_file_copy controls whether file copying (SCP/SFTP) is allowed. # Defaults to true. ssh_file_copy: false diff --git a/docs/pages/includes/role-spec.mdx b/docs/pages/includes/role-spec.mdx index ef780959cd30b..1a0f2cfeece0c 100644 --- a/docs/pages/includes/role-spec.mdx +++ b/docs/pages/includes/role-spec.mdx @@ -13,8 +13,18 @@ spec: max_session_ttl: 8h # forward_agent controls whether SSH agent forwarding is allowed forward_agent: true - # port_forwarding controls whether TCP port forwarding is allowed for SSH - port_forwarding: true + # ssh_port_forwarding controls which TCP port forwarding modes are allowed over SSH. This + # replaces the deprecated port_forwarding field, which did not differentiate between remote + # and local port forwarding modes. If any existing role allows forwarding by enabling the + # legacy port_forwarding field, the forwarding controls configured in ssh_port_forwarding + # are ignored. + ssh_port_forwarding: + # configures remote port forwarding behavior + remote: + enabled: true + # configures local port forwarding behavior + local: + enabled: true # ssh_file_copy controls whether file copying (SCP/SFTP) is allowed. # Defaults to true. ssh_file_copy: false diff --git a/docs/pages/reference/access-controls/roles.mdx b/docs/pages/reference/access-controls/roles.mdx index 86029bff5012b..c67dd234b8642 100644 --- a/docs/pages/reference/access-controls/roles.mdx +++ b/docs/pages/reference/access-controls/roles.mdx @@ -52,7 +52,7 @@ user: | - | - | - | | `max_session_ttl` | Max. time to live (TTL) of a user's SSH certificates | The shortest TTL wins | | `forward_agent` | Allow SSH agent forwarding | Logical "OR" i.e. 
if any role allows agent forwarding, it's allowed | -| `port_forwarding` | Allow TCP port forwarding | Logical "OR" i.e. if any role allows port forwarding, it's allowed | +| `ssh_port_forwarding` | Allow TCP port forwarding | Logical "AND" i.e. if any role denies port forwarding, it's denied | | `ssh_file_copy` | Allow SCP/SFTP | Logical "AND" i.e. if all roles allows file copying, it's allowed | | `client_idle_timeout` | Forcefully terminate active sessions after an idle interval | The shortest timeout value wins, i.e. the most restrictive value is selected | | `disconnect_expired_cert` | Forcefully terminate active sessions when a client certificate expires | Logical "OR" i.e. evaluates to "yes" if at least one role requires session termination | diff --git a/docs/pages/reference/terraform-provider/resources/role.mdx b/docs/pages/reference/terraform-provider/resources/role.mdx index 3d573fa65646b..70d9c3edc0f1e 100644 --- a/docs/pages/reference/terraform-provider/resources/role.mdx +++ b/docs/pages/reference/terraform-provider/resources/role.mdx @@ -27,9 +27,17 @@ resource "teleport_role" "example" { spec = { options = { - forward_agent = false - max_session_ttl = "7m" - port_forwarding = false + forward_agent = false + max_session_ttl = "7m" + ssh_port_forwarding = { + remote = { + enabled = false + } + + local = { + enabled = false + } + } client_idle_timeout = "1h" disconnect_expired_cert = true permit_x11_forwarding = false diff --git a/examples/resources/admin.yaml b/examples/resources/admin.yaml index 2c8427a632a4a..acb170f290e18 100644 --- a/examples/resources/admin.yaml +++ b/examples/resources/admin.yaml @@ -28,5 +28,9 @@ spec: - network forward_agent: true max_session_ttl: 30h0m0s - port_forwarding: true + ssh_port_forwarding: + remote: + enabled: true + local: + enabled: true version: v3 diff --git a/examples/resources/user.yaml b/examples/resources/user.yaml index 07ab839b3286a..8f47afd3fc0e9 100644 --- a/examples/resources/user.yaml +++ b/examples/resources/user.yaml @@ -56,5 +56,9 @@ spec: - network forward_agent: true max_session_ttl: 30h0m0s - port_forwarding: true + ssh_port_forwarding: + remote: + enabled: true + local: + enabled: true version: v3 diff --git a/integrations/terraform/examples/resources/teleport_role/resource.tf b/integrations/terraform/examples/resources/teleport_role/resource.tf index c5ac6c920e5d9..c76e71248e962 100644 --- a/integrations/terraform/examples/resources/teleport_role/resource.tf +++ b/integrations/terraform/examples/resources/teleport_role/resource.tf @@ -13,9 +13,17 @@ resource "teleport_role" "example" { spec = { options = { - forward_agent = false - max_session_ttl = "7m" - port_forwarding = false + forward_agent = false + max_session_ttl = "7m" + ssh_port_forwarding = { + remote = { + enabled = false + } + + local = { + enabled = false + } + } client_idle_timeout = "1h" disconnect_expired_cert = true permit_x11_forwarding = false diff --git a/integrations/terraform/reference.mdx b/integrations/terraform/reference.mdx index 078a1bb73ca08..a5df9cc060e46 100755 --- a/integrations/terraform/reference.mdx +++ b/integrations/terraform/reference.mdx @@ -2051,7 +2051,8 @@ Options is for OpenSSH options like agent forwarding. | max_sessions | number | | MaxSessions defines the maximum number of concurrent sessions per connection. | | permit_x11_forwarding | bool | | PermitX11Forwarding authorizes use of X11 forwarding. 
| | pin_source_ip | bool | | PinSourceIP forces the same client IP for certificate generation and usage | -| port_forwarding | bool | | | +| ssh_port_forwarding | object | | SSHPortForwarding configures what types of SSH port forwarding are allowed by a role. | +| port_forwarding | bool | | Deprecated: Use SSHPortForwarding instead. | | record_session | object | | RecordDesktopSession indicates whether desktop access sessions should be recorded. It defaults to true unless explicitly set to false. | | request_access | string | | RequestAccess defines the access request strategy (optional|note|always) where optional is the default. | | request_prompt | string | | RequestPrompt is an optional message which tells users what they aught to request. | @@ -2085,6 +2086,31 @@ SAML are options related to the Teleport SAML IdP. |---------|------|----------|-------------| | enabled | bool | | | +##### spec.options.ssh_port_forwarding + +SSHPortForwarding configures what types of SSH port forwarding are allowed by a role. + +| Name | Type | Required | Description | +|--------|--------|----------|-----------------------------------------------------------| +| remote | object | | remote contains options related to remote port forwarding | +| local | object | | local contains options related to local port forwarding | + +###### spec.options.ssh_port_forwarding.remote + +remote contains options related to remote port forwarding + +| Name | Type | Required | Description | +|---------|------|----------|-------------| +| enabled | bool | | | + +###### spec.options.ssh_port_forwarding.local + +local contains options related to local port forwarding + +| Name | Type | Required | Description | +|---------|------|----------|-------------| +| enabled | bool | | | + ##### spec.options.record_session RecordDesktopSession indicates whether desktop access sessions should be recorded. It defaults to true unless explicitly set to false. @@ -2114,11 +2140,19 @@ resource "teleport_role" "example" { options = { forward_agent = false max_session_ttl = "7m" - port_forwarding = false client_idle_timeout = "1h" disconnect_expired_cert = true permit_x11_forwarding = false request_access = "denied" + ssh_port_forwarding = { + remote = { + enabled = false + } + + local = { + enabled = false + } + } } allow = { diff --git a/lib/auth/auth.go b/lib/auth/auth.go index c240ad6fc585f..4a88c5e083603 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -1508,12 +1508,11 @@ func (a *Server) runPeriodicOperations() { heartbeatsMissedByAuth.Inc() } + if srv.GetSubKind() != types.SubKindOpenSSHNode { + return false, nil + } // TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then. if !validServerHostname(srv.GetHostname()) { - if srv.GetSubKind() != types.SubKindOpenSSHNode { - return false, nil - } - logger := a.logger.With("server", srv.GetName(), "hostname", srv.GetHostname()) logger.DebugContext(a.closeCtx, "sanitizing invalid static SSH server hostname") @@ -1527,6 +1526,17 @@ func (a *Server) runPeriodicOperations() { if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) { logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err) } + } else if oldHostname, ok := srv.GetLabel(replacedHostnameLabel); ok && validServerHostname(oldHostname) { + // If the hostname has been replaced by a sanitized version, revert it back to the original + // if the original is valid under the most recent rules. 
+ logger := a.logger.With("server", srv.GetName(), "old_hostname", oldHostname, "sanitized_hostname", srv.GetHostname()) + if err := restoreSanitizedHostname(srv); err != nil { + logger.WarnContext(a.closeCtx, "failed to restore sanitized static SSH server hostname", "error", err) + return false, nil + } + if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) { + logger.WarnContext(a.closeCtx, "failed to update node hostname", "error", err) + } } return false, nil @@ -5650,7 +5660,7 @@ func (a *Server) KeepAliveServer(ctx context.Context, h types.KeepAlive) error { const ( serverHostnameMaxLen = 256 - serverHostnameRegexPattern = `^[a-zA-Z0-9]([\.-]?[a-zA-Z0-9]+)*$` + serverHostnameRegexPattern = `^[a-zA-Z0-9]+[a-zA-Z0-9\.-]*$` replacedHostnameLabel = types.TeleportInternalLabelPrefix + "invalid-hostname" ) @@ -5658,7 +5668,7 @@ var serverHostnameRegex = regexp.MustCompile(serverHostnameRegexPattern) // validServerHostname returns false if the hostname is longer than 256 characters or // does not entirely consist of alphanumeric characters as well as '-' and '.'. A valid hostname also -// cannot begin with a symbol, and a symbol cannot be followed immediately by another symbol. +// cannot begin with a symbol. func validServerHostname(hostname string) bool { return len(hostname) <= serverHostnameMaxLen && serverHostnameRegex.MatchString(hostname) } @@ -5697,6 +5707,26 @@ func sanitizeHostname(server types.Server) error { return nil } +// restoreSanitizedHostname restores the original hostname of a server and removes the label. +func restoreSanitizedHostname(server types.Server) error { + oldHostname, ok := server.GetLabels()[replacedHostnameLabel] + // if the label is not present or the hostname is invalid under the most recent rules, do nothing. + if !ok || !validServerHostname(oldHostname) { + return nil + } + + switch s := server.(type) { + case *types.ServerV2: + // restore the original hostname and remove the label. + s.Spec.Hostname = oldHostname + delete(s.Metadata.Labels, replacedHostnameLabel) + default: + return trace.BadParameter("invalid server provided") + } + + return nil +} + // UpsertNode implements [services.Presence] by delegating to [Server.Services] // and potentially emitting a [usagereporter] event. 
func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) { diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index e4978e32e358a..8f535a1727588 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -4478,6 +4478,10 @@ func TestServerHostnameSanitization(t *testing.T) { name: "uuid dns hostname", hostname: uuid.NewString() + ".example.com", }, + { + name: "valid dns hostname with multi-dots", + hostname: "llama..example.com", + }, { name: "empty hostname", hostname: "", @@ -4488,11 +4492,6 @@ func TestServerHostnameSanitization(t *testing.T) { hostname: strings.Repeat("a", serverHostnameMaxLen*2), invalidHostname: true, }, - { - name: "invalid dns hostname", - hostname: "llama..example.com", - invalidHostname: true, - }, { name: "spaces in hostname", hostname: "the quick brown fox jumps over the lazy dog", @@ -4562,3 +4561,74 @@ func TestServerHostnameSanitization(t *testing.T) { }) } } + +func TestValidServerHostname(t *testing.T) { + t.Parallel() + tests := []struct { + name string + hostname string + want bool + }{ + { + name: "valid dns hostname", + hostname: "llama.example.com", + want: true, + }, + { + name: "valid friendly hostname", + hostname: "llama", + want: true, + }, + { + name: "uuid hostname", + hostname: uuid.NewString(), + want: true, + }, + { + name: "valid hostname with multi-dashes", + hostname: "llama--example.com", + want: true, + }, + { + name: "valid hostname with multi-dots", + hostname: "llama..example.com", + want: true, + }, + { + name: "valid hostname with numbers", + hostname: "llama9", + want: true, + }, + { + name: "hostname with invalid characters", + hostname: "llama?!$", + want: false, + }, + { + name: "super long hostname", + hostname: strings.Repeat("a", serverHostnameMaxLen*2), + want: false, + }, + { + name: "hostname with spaces", + hostname: "the quick brown fox jumps over the lazy dog", + want: false, + }, + { + name: "hostname with ;", + hostname: "llama;example.com", + want: false, + }, + { + name: "empty hostname", + hostname: "", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := validServerHostname(tt.hostname) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/lib/client/clientcache/clientcache.go b/lib/client/clientcache/clientcache.go index 5a9c4df29e7de..f5e8f44aafdf9 100644 --- a/lib/client/clientcache/clientcache.go +++ b/lib/client/clientcache/clientcache.go @@ -18,11 +18,11 @@ package clientcache import ( "context" + "log/slog" "slices" "sync" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "golang.org/x/sync/singleflight" "github.com/gravitational/teleport" @@ -53,7 +53,7 @@ type RetryWithReloginFunc func(ctx context.Context, tc *client.TeleportClient, f type Config struct { NewClientFunc NewClientFunc RetryWithReloginFunc RetryWithReloginFunc - Log logrus.FieldLogger + Logger *slog.Logger } func (c *Config) checkAndSetDefaults() error { @@ -63,8 +63,8 @@ func (c *Config) checkAndSetDefaults() error { if c.RetryWithReloginFunc == nil { return trace.BadParameter("RetryWithReloginFunc is required") } - if c.Log == nil { - c.Log = logrus.WithField(teleport.ComponentKey, "clientcache") + if c.Logger == nil { + c.Logger = slog.With(teleport.ComponentKey, "clientcache") } return nil } @@ -99,7 +99,7 @@ func (c *Cache) Get(ctx context.Context, profileName, leafClusterName string) (* k := key{profile: profileName, leafCluster: leafClusterName} groupClt, err, _ := c.group.Do(k.String(), func() (any, error) { if 
fromCache := c.getFromCache(k); fromCache != nil { - c.cfg.Log.WithField("cluster", k).Debug("Retrieved client from cache.") + c.cfg.Logger.DebugContext(ctx, "Retrieved client from cache", "cluster", k) return fromCache, nil } @@ -123,7 +123,7 @@ func (c *Cache) Get(ctx context.Context, profileName, leafClusterName string) (* // Save the client in the cache, so we don't have to build a new connection next time. c.addToCache(k, newClient) - c.cfg.Log.WithField("cluster", k).Info("Added client to cache.") + c.cfg.Logger.InfoContext(ctx, "Added client to cache", "cluster", k) return newClient, nil }) @@ -159,9 +159,10 @@ func (c *Cache) ClearForRoot(profileName string) error { } } - c.cfg.Log.WithFields( - logrus.Fields{"cluster": profileName, "clients": deleted}, - ).Info("Invalidated cached clients for root cluster.") + c.cfg.Logger.InfoContext(context.Background(), "Invalidated cached clients for root cluster", + "cluster", profileName, + "clients", deleted, + ) return trace.NewAggregate(errors...) diff --git a/lib/client/db/dbcmd/dbcmd.go b/lib/client/db/dbcmd/dbcmd.go index 603a284ea7ec9..376435b260d40 100644 --- a/lib/client/db/dbcmd/dbcmd.go +++ b/lib/client/db/dbcmd/dbcmd.go @@ -21,6 +21,7 @@ package dbcmd import ( "context" "fmt" + "log/slog" "net/url" "os" "os/exec" @@ -30,7 +31,6 @@ import ( "strings" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "github.com/gravitational/teleport/api/constants" @@ -143,8 +143,8 @@ func NewCmdBuilder(tc *client.TeleportClient, profile *client.ProfileStatus, host, port = tc.DatabaseProxyHostPort(db) } - if options.log == nil { - options.log = logrus.NewEntry(logrus.StandardLogger()) + if options.logger == nil { + options.logger = slog.Default() } if options.exe == nil { @@ -256,8 +256,11 @@ func (c *CLICommandBuilder) getPostgresCommand() *exec.Cmd { func (c *CLICommandBuilder) getCockroachCommand() *exec.Cmd { // If cockroach CLI client is not available, fallback to psql. if _, err := c.options.exe.LookPath(cockroachBin); err != nil { - c.options.log.Debugf("Couldn't find %q client in PATH, falling back to %q: %v.", - cockroachBin, postgresBin, err) + c.options.logger.DebugContext(context.Background(), "Couldn't find cockroach client in PATH, falling back to postgres client", + "cockroach_client", cockroachBin, + "postgres_client", postgresBin, + "error", err, + ) return c.getPostgresCommand() } return exec.Command(cockroachBin, "sql", "--url", c.getPostgresConnString()) @@ -560,7 +563,10 @@ func (c *CLICommandBuilder) getMongoAddress() string { // force a different timeout for debugging purpose or extreme situations. 
serverSelectionTimeoutMS := "5000" if envValue := os.Getenv(envVarMongoServerSelectionTimeoutMS); envValue != "" { - c.options.log.Infof("Using environment variable %s=%s.", envVarMongoServerSelectionTimeoutMS, envValue) + c.options.logger.InfoContext(context.Background(), "Using server selection timeout value from environment variable", + "environment_variable", envVarMongoServerSelectionTimeoutMS, + "server_selection_timeout", envValue, + ) serverSelectionTimeoutMS = envValue } query.Set("serverSelectionTimeoutMS", serverSelectionTimeoutMS) @@ -905,7 +911,7 @@ type connectionCommandOpts struct { noTLS bool printFormat bool tolerateMissingCLIClient bool - log *logrus.Entry + logger *slog.Logger exe Execer password string gcp types.GCPCloudSQL @@ -969,9 +975,9 @@ func WithPrintFormat() ConnectCommandFunc { // WithLogger is the connect command option that allows the caller to pass a logger that will be // used by CLICommandBuilder. -func WithLogger(log *logrus.Entry) ConnectCommandFunc { +func WithLogger(log *slog.Logger) ConnectCommandFunc { return func(opts *connectionCommandOpts) { - opts.log = log + opts.logger = log } } diff --git a/lib/httplib/httplib.go b/lib/httplib/httplib.go index 98775ec69040c..f241f6d36ddb8 100644 --- a/lib/httplib/httplib.go +++ b/lib/httplib/httplib.go @@ -41,7 +41,6 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/observability/tracing" tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" - "github.com/gravitational/teleport/lib/httplib/csrf" "github.com/gravitational/teleport/lib/utils" ) @@ -155,23 +154,6 @@ func MakeStdHandlerWithErrorWriter(fn StdHandlerFunc, errWriter ErrorWriter) htt } } -// WithCSRFProtection ensures that request to unauthenticated API is checked against CSRF attacks -func WithCSRFProtection(fn HandlerFunc) httprouter.Handle { - handlerFn := MakeHandler(fn) - return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { - if r.Method != http.MethodGet && r.Method != http.MethodHead { - errHeader := csrf.VerifyHTTPHeader(r) - errForm := csrf.VerifyFormField(r) - if errForm != nil && errHeader != nil { - slog.WarnContext(r.Context(), "unable to validate CSRF token", "header_error", errHeader, "form_error", errForm) - trace.WriteError(w, trace.AccessDenied("access denied")) - return - } - } - handlerFn(w, r, p) - } -} - // ReadJSON reads HTTP json request and unmarshals it // into passed any obj. A reasonable maximum size is enforced // to mitigate resource exhaustion attacks. @@ -188,6 +170,7 @@ func ReadResourceJSON(r *http.Request, val any) error { func readJSON(r *http.Request, val any, maxSize int64) error { // Check content type to mitigate CSRF attack. + // (Form POST requests don't support application/json payloads.) 
contentType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) if err != nil { slog.WarnContext(r.Context(), "Error parsing media type for reading JSON", "error", err) diff --git a/lib/inventory/controller.go b/lib/inventory/controller.go index 8ea733c950dc6..4bdbd2f596c41 100644 --- a/lib/inventory/controller.go +++ b/lib/inventory/controller.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/inventory/internal/delay" + "github.com/gravitational/teleport/lib/services" usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/interval" @@ -118,11 +119,13 @@ const ( keepAliveKubeTick = "keep-alive-kube-tick" ) -// instanceHBStepSize is the step size used for the variable instance heartbeat duration. This value is -// basically arbitrary. It was selected because it produces a scaling curve that makes a fairly reasonable -// tradeoff between heartbeat availability and load scaling. See test coverage in the 'interval' package -// for a demonstration of the relationship between step sizes and interval/duration scaling. -const instanceHBStepSize = 1024 +// heartbeatStepSize is the step size used for the variable heartbeat intervals. +// This value is basically arbitrary. It was selected because it produces a +// scaling curve that makes a fairly reasonable tradeoff between heartbeat +// availability and load scaling. See test coverage in the 'interval' package +// for a demonstration of the relationship between step sizes and +// interval/duration scaling. +const heartbeatStepSize = 1024 type controllerOptions struct { serverKeepAlive time.Duration @@ -233,6 +236,10 @@ type Controller struct { instanceTTL time.Duration instanceHBEnabled bool instanceHBVariableDuration *interval.VariableDuration + sshHBVariableDuration *interval.VariableDuration + appHBVariableDuration *interval.VariableDuration + dbHBVariableDuration *interval.VariableDuration + kubeHBVariableDuration *interval.VariableDuration maxKeepAliveErrs int usageReporter usagereporter.UsageReporter testEvents chan testEvent @@ -254,18 +261,55 @@ func NewController(auth Auth, usageReporter usagereporter.UsageReporter, opts .. 
instanceHBVariableDuration := interval.NewVariableDuration(interval.VariableDurationConfig{ MinDuration: options.instanceHBInterval, MaxDuration: apidefaults.MaxInstanceHeartbeatInterval, - Step: instanceHBStepSize, + Step: heartbeatStepSize, }) + var ( + sshHBVariableDuration *interval.VariableDuration + appHBVariableDuration *interval.VariableDuration + dbHBVariableDuration *interval.VariableDuration + kubeHBVariableDuration *interval.VariableDuration + ) + serverTTL := apidefaults.ServerAnnounceTTL + if !variableRateHeartbeatsDisabledEnv() { + // by default, heartbeats will scale from 1.5 to 6 minutes, and will + // have a TTL of 15 minutes + serverTTL = apidefaults.ServerAnnounceTTL * 3 / 2 + sshHBVariableDuration = interval.NewVariableDuration(interval.VariableDurationConfig{ + MinDuration: options.serverKeepAlive, + MaxDuration: options.serverKeepAlive * 4, + Step: heartbeatStepSize, + }) + appHBVariableDuration = interval.NewVariableDuration(interval.VariableDurationConfig{ + MinDuration: options.serverKeepAlive, + MaxDuration: options.serverKeepAlive * 4, + Step: heartbeatStepSize, + }) + dbHBVariableDuration = interval.NewVariableDuration(interval.VariableDurationConfig{ + MinDuration: options.serverKeepAlive, + MaxDuration: options.serverKeepAlive * 4, + Step: heartbeatStepSize, + }) + kubeHBVariableDuration = interval.NewVariableDuration(interval.VariableDurationConfig{ + MinDuration: options.serverKeepAlive, + MaxDuration: options.serverKeepAlive * 4, + Step: heartbeatStepSize, + }) + } + ctx, cancel := context.WithCancel(context.Background()) return &Controller{ store: NewStore(), serviceCounter: &serviceCounter{}, serverKeepAlive: options.serverKeepAlive, - serverTTL: apidefaults.ServerAnnounceTTL, + serverTTL: serverTTL, instanceTTL: apidefaults.InstanceHeartbeatTTL, instanceHBEnabled: !instanceHeartbeatsDisabledEnv(), instanceHBVariableDuration: instanceHBVariableDuration, + sshHBVariableDuration: sshHBVariableDuration, + appHBVariableDuration: appHBVariableDuration, + dbHBVariableDuration: dbHBVariableDuration, + kubeHBVariableDuration: kubeHBVariableDuration, maxKeepAliveErrs: options.maxKeepAliveErrs, auth: auth, authID: options.authID, @@ -417,23 +461,36 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { if handle.sshServer != nil { c.onDisconnectFunc(constants.KeepAliveNode, 1) + if c.sshHBVariableDuration != nil { + c.sshHBVariableDuration.Dec() + } + handle.sshServer = nil } if len(handle.appServers) > 0 { c.onDisconnectFunc(constants.KeepAliveApp, len(handle.appServers)) + if c.appHBVariableDuration != nil { + c.appHBVariableDuration.Add(-len(handle.appServers)) + } + clear(handle.appServers) } if len(handle.databaseServers) > 0 { c.onDisconnectFunc(constants.KeepAliveDatabase, len(handle.databaseServers)) + if c.dbHBVariableDuration != nil { + c.dbHBVariableDuration.Add(-len(handle.databaseServers)) + } + clear(handle.databaseServers) } if len(handle.kubernetesServers) > 0 { c.onDisconnectFunc(constants.KeepAliveKube, len(handle.kubernetesServers)) + if c.kubeHBVariableDuration != nil { + c.kubeHBVariableDuration.Add(-len(handle.kubernetesServers)) + } + clear(handle.kubernetesServers) } - clear(handle.appServers) - clear(handle.databaseServers) - clear(handle.kubernetesServers) c.testEvent(handlerClose) }() @@ -448,40 +505,60 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { case proto.UpstreamInventoryAgentMetadata: c.handleAgentMetadata(handle, m) case proto.InventoryHeartbeat: - if err := c.handleHeartbeatMsg(handle, m); 
err != nil { - handle.CloseWithError(err) - return + // XXX: when adding new services to the heartbeat logic, make + // sure to also update the 'icsServiceToMetricName' mapping in + // auth/grpcserver.go in order to ensure that metrics start + // counting the control stream as a registered keepalive stream + // for that service. + + if m.SSHServer != nil { + // we initialize sshKeepAliveDelay before calling + // handleSSHServerHB unlike the other heartbeat types + // because handleSSHServerHB needs the delay to reset it + // after an announce, including the first one + if sshKeepAliveDelay == nil { + sshKeepAliveDelay = c.createKeepAliveDelay(c.sshHBVariableDuration) + } + + if err := c.handleSSHServerHB(handle, m.SSHServer, sshKeepAliveDelay); err != nil { + handle.CloseWithError(trace.Wrap(err)) + return + } } - // we initialize delays lazily here, depending on the protocol - if sshKeepAliveDelay == nil && m.SSHServer != nil { - sshKeepAliveDelay = delay.New(delay.Params{ - FirstInterval: retryutils.HalfJitter(c.serverKeepAlive), - FixedInterval: c.serverKeepAlive, - Jitter: retryutils.SeventhJitter, - }) - } - if appKeepAliveDelay == nil && m.AppServer != nil { - appKeepAliveDelay = delay.New(delay.Params{ - FirstInterval: retryutils.HalfJitter(c.serverKeepAlive), - FixedInterval: c.serverKeepAlive, - Jitter: retryutils.SeventhJitter, - }) + if m.AppServer != nil { + if err := c.handleAppServerHB(handle, m.AppServer); err != nil { + handle.CloseWithError(err) + return + } + + if appKeepAliveDelay == nil { + appKeepAliveDelay = c.createKeepAliveDelay(c.appHBVariableDuration) + } } - if dbKeepAliveDelay == nil && m.DatabaseServer != nil { - dbKeepAliveDelay = delay.New(delay.Params{ - FirstInterval: retryutils.HalfJitter(c.serverKeepAlive), - FixedInterval: c.serverKeepAlive, - Jitter: retryutils.SeventhJitter, - }) + + if m.DatabaseServer != nil { + if err := c.handleDatabaseServerHB(handle, m.DatabaseServer); err != nil { + handle.CloseWithError(err) + return + } + + if dbKeepAliveDelay == nil { + dbKeepAliveDelay = c.createKeepAliveDelay(c.dbHBVariableDuration) + } } - if kubeKeepAliveDelay == nil && m.KubernetesServer != nil { - kubeKeepAliveDelay = delay.New(delay.Params{ - FirstInterval: retryutils.HalfJitter(c.serverKeepAlive), - FixedInterval: c.serverKeepAlive, - Jitter: retryutils.SeventhJitter, - }) + + if m.KubernetesServer != nil { + if err := c.handleKubernetesServerHB(handle, m.KubernetesServer); err != nil { + handle.CloseWithError(err) + return + } + + if kubeKeepAliveDelay == nil { + kubeKeepAliveDelay = c.createKeepAliveDelay(c.kubeHBVariableDuration) + } } + case proto.UpstreamInventoryPong: c.handlePong(handle, m) case proto.UpstreamInventoryGoodbye: @@ -570,6 +647,12 @@ func instanceHeartbeatsDisabledEnv() bool { return os.Getenv("TELEPORT_UNSTABLE_DISABLE_INSTANCE_HB") == "yes" } +// variableRateHeartbeatsDisabledEnv checks if variable rate heartbeats have +// been explicitly disabled via environment variable. 
+func variableRateHeartbeatsDisabledEnv() bool { + return os.Getenv("TELEPORT_UNSTABLE_DISABLE_VARIABLE_RATE_HEARTBEATS") == "yes" +} + func (c *Controller) heartbeatInstanceState(handle *upstreamHandle, now time.Time) error { if !c.instanceHBEnabled { return nil @@ -673,39 +756,7 @@ func (c *Controller) handlePingRequest(handle *upstreamHandle, req pingRequest) return nil } -func (c *Controller) handleHeartbeatMsg(handle *upstreamHandle, hb proto.InventoryHeartbeat) error { - // XXX: when adding new services to the heartbeat logic, make sure to also update the - // 'icsServiceToMetricName' mapping in auth/grpcserver.go in order to ensure that metrics - // start counting the control stream as a registered keepalive stream for that service. - - if hb.SSHServer != nil { - if err := c.handleSSHServerHB(handle, hb.SSHServer); err != nil { - return trace.Wrap(err) - } - } - - if hb.AppServer != nil { - if err := c.handleAppServerHB(handle, hb.AppServer); err != nil { - return trace.Wrap(err) - } - } - - if hb.DatabaseServer != nil { - if err := c.handleDatabaseServerHB(handle, hb.DatabaseServer); err != nil { - return trace.Wrap(err) - } - } - - if hb.KubernetesServer != nil { - if err := c.handleKubernetesServerHB(handle, hb.KubernetesServer); err != nil { - return trace.Wrap(err) - } - } - - return nil -} - -func (c *Controller) handleSSHServerHB(handle *upstreamHandle, sshServer *types.ServerV2) error { +func (c *Controller) handleSSHServerHB(handle *upstreamHandle, sshServer *types.ServerV2, sshDelay *delay.Delay) error { // the auth layer verifies that a stream's hello message matches the identity and capabilities of the // client cert. after that point it is our responsibility to ensure that heartbeated information is // consistent with the identity and capabilities claimed in the initial hello. @@ -722,31 +773,48 @@ func (c *Controller) handleSSHServerHB(handle *upstreamHandle, sshServer *types. sshServer.SetAddr(utils.ReplaceLocalhost(sshServer.GetAddr(), handle.PeerAddr())) } + sshServer.SetExpiry(time.Now().Add(c.serverTTL).UTC()) + if handle.sshServer == nil { c.onConnectFunc(constants.KeepAliveNode) - handle.sshServer = &heartBeatInfo[*types.ServerV2]{} + if c.sshHBVariableDuration != nil { + c.sshHBVariableDuration.Inc() + } + handle.sshServer = &heartBeatInfo[*types.ServerV2]{ + resource: sshServer, + } + } else if handle.sshServer.keepAliveErrs == 0 && services.CompareServers(handle.sshServer.resource, sshServer) < services.Different { + // if we have successfully upserted this exact server the last time + // (except for the expiry), we don't need to upsert it again right now + return nil + } else { + handle.sshServer.resource = sshServer } - now := c.clock.Now() - - sshServer.SetExpiry(now.Add(c.serverTTL).UTC()) - - lease, err := c.auth.UpsertNode(c.closeContext, sshServer) - if err == nil { + if _, err := c.auth.UpsertNode(c.closeContext, handle.sshServer.resource); err == nil { c.testEvent(sshUpsertOk) - // store the new lease and reset retry state - handle.sshServer.lease = lease + // reset the error status + handle.sshServer.keepAliveErrs = 0 handle.sshServer.retryUpsert = false + + sshDelay.Reset() } else { c.testEvent(sshUpsertErr) - slog.WarnContext(c.closeContext, "Failed to upsert ssh server on heartbeat", + slog.WarnContext(c.closeContext, "Failed to announce SSH server", "server_id", handle.Hello().ServerID, "error", err, ) - // blank old lease if any and set retry state. next time handleKeepAlive is called - // we will attempt to upsert the server again. 
- handle.sshServer.lease = nil + // we use keepAliveErrs as a general upsert error count for SSH, + // retryUpsert as a flag to signify that we MUST succeed the very next + // upsert: if we're here it means that we have a new resource to upsert + // and we have failed to do so once, so if we fail again we are going to + // fall too far behind and we should let the instance go and connect to + // a healthier auth server + handle.sshServer.keepAliveErrs++ + if handle.sshServer.retryUpsert || handle.sshServer.keepAliveErrs > c.maxKeepAliveErrs { + return trace.Wrap(err, "failed to announce SSH server") + } handle.sshServer.retryUpsert = true } handle.sshServer.resource = sshServer @@ -772,6 +840,9 @@ func (c *Controller) handleAppServerHB(handle *upstreamHandle, appServer *types. if _, ok := handle.appServers[appKey]; !ok { c.onConnectFunc(constants.KeepAliveApp) + if c.appHBVariableDuration != nil { + c.appHBVariableDuration.Inc() + } handle.appServers[appKey] = &heartBeatInfo[*types.AppServerV3]{} } @@ -823,6 +894,9 @@ func (c *Controller) handleDatabaseServerHB(handle *upstreamHandle, databaseServ if _, ok := handle.databaseServers[dbKey]; !ok { c.onConnectFunc(constants.KeepAliveDatabase) + if c.dbHBVariableDuration != nil { + c.dbHBVariableDuration.Inc() + } handle.databaseServers[dbKey] = &heartBeatInfo[*types.DatabaseServerV3]{} } @@ -874,6 +948,9 @@ func (c *Controller) handleKubernetesServerHB(handle *upstreamHandle, kubernetes if _, ok := handle.kubernetesServers[kubeKey]; !ok { c.onConnectFunc(constants.KeepAliveKube) + if c.kubeHBVariableDuration != nil { + c.kubeHBVariableDuration.Inc() + } handle.kubernetesServers[kubeKey] = &heartBeatInfo[*types.KubernetesServerV3]{} } @@ -951,6 +1028,9 @@ func (c *Controller) keepAliveAppServer(handle *upstreamHandle, now time.Time) e if shouldRemove { c.testEvent(appKeepAliveDel) c.onDisconnectFunc(constants.KeepAliveApp, 1) + if c.appHBVariableDuration != nil { + c.appHBVariableDuration.Dec() + } delete(handle.appServers, name) } } else { @@ -1002,6 +1082,9 @@ func (c *Controller) keepAliveDatabaseServer(handle *upstreamHandle, now time.Ti if shouldRemove { c.testEvent(dbKeepAliveDel) c.onDisconnectFunc(constants.KeepAliveDatabase, 1) + if c.dbHBVariableDuration != nil { + c.dbHBVariableDuration.Dec() + } delete(handle.databaseServers, name) } } else { @@ -1053,6 +1136,9 @@ func (c *Controller) keepAliveKubernetesServer(handle *upstreamHandle, now time. 
if shouldRemove { c.testEvent(kubeKeepAliveDel) c.onDisconnectFunc(constants.KeepAliveKube, 1) + if c.kubeHBVariableDuration != nil { + c.kubeHBVariableDuration.Dec() + } delete(handle.kubernetesServers, name) } } else { @@ -1088,50 +1174,54 @@ func (c *Controller) keepAliveSSHServer(handle *upstreamHandle, now time.Time) e return nil } - if handle.sshServer.lease != nil { - lease := *handle.sshServer.lease - lease.Expires = now.Add(c.serverTTL).UTC() - if err := c.auth.KeepAliveServer(c.closeContext, lease); err != nil { - c.testEvent(sshKeepAliveErr) - handle.sshServer.keepAliveErrs++ - shouldClose := handle.sshServer.keepAliveErrs > c.maxKeepAliveErrs - - slog.WarnContext(c.closeContext, "Failed to keep alive ssh server", - "server_id", handle.Hello().ServerID, - "error", err, - "error_count", handle.sshServer.keepAliveErrs, - "should_remove", shouldClose, - ) - - if shouldClose { - return trace.Errorf("failed to keep alive ssh server: %v", err) - } + if _, err := c.auth.UpsertNode(c.closeContext, handle.sshServer.resource); err == nil { + if handle.sshServer.retryUpsert { + c.testEvent(sshUpsertRetryOk) } else { - handle.sshServer.keepAliveErrs = 0 c.testEvent(sshKeepAliveOk) } - } else if handle.sshServer.retryUpsert { - handle.sshServer.resource.SetExpiry(c.clock.Now().Add(c.serverTTL).UTC()) - lease, err := c.auth.UpsertNode(c.closeContext, handle.sshServer.resource) - if err != nil { + handle.sshServer.keepAliveErrs = 0 + handle.sshServer.retryUpsert = false + } else { + if handle.sshServer.retryUpsert { c.testEvent(sshUpsertRetryErr) - slog.WarnContext(c.closeContext, "Failed to upsert ssh server on retry", + slog.WarnContext(c.closeContext, "Failed to upsert SSH server on retry", "server_id", handle.Hello().ServerID, "error", err, ) - // since this is retry-specific logic, an error here means that upsert failed twice in - // a row. Missing upserts is more problematic than missing keepalives so we don'resource bother - // attempting a third time. - return trace.Errorf("failed to upsert ssh server on retry: %v", err) + // retryUpsert is set when we get a new resource and we fail to + // upsert it; if we're here it means that we have failed to upsert + // it _again_, so we have fallen quite far behind + return trace.Wrap(err, "failed to upsert SSH server on retry") + } + + c.testEvent(sshKeepAliveErr) + handle.sshServer.keepAliveErrs++ + closing := handle.sshServer.keepAliveErrs > c.maxKeepAliveErrs + slog.WarnContext(c.closeContext, "Failed to upsert SSH server on keepalive", + "server_id", handle.Hello().ServerID, + "error", err, + "count", handle.sshServer.keepAliveErrs, + "closing", closing, + ) + + if closing { + return trace.Wrap(err, "failed to keep alive SSH server") } - c.testEvent(sshUpsertRetryOk) - handle.sshServer.lease = lease - handle.sshServer.retryUpsert = false } return nil } +func (c *Controller) createKeepAliveDelay(variableDuration *interval.VariableDuration) *delay.Delay { + return delay.New(delay.Params{ + FirstInterval: retryutils.HalfJitter(c.serverKeepAlive), + FixedInterval: c.serverKeepAlive, + VariableInterval: variableDuration, + Jitter: retryutils.SeventhJitter, + }) +} + // Close terminates all control streams registered with this controller. Control streams // registered after Close() is called are closed immediately. 
func (c *Controller) Close() error { diff --git a/lib/inventory/controller_test.go b/lib/inventory/controller_test.go index 9ec509f725293..badc1e6920d97 100644 --- a/lib/inventory/controller_test.go +++ b/lib/inventory/controller_test.go @@ -247,14 +247,13 @@ func TestSSHServerBasics(t *testing.T) { // set up to induce some failures, but not enough to cause the control // stream to be closed. auth.mu.Lock() - auth.failUpserts = 1 - auth.failKeepAlives = 2 + auth.failUpserts = 2 auth.mu.Unlock() // keepalive should fail twice, but since the upsert is already known // to have succeeded, we should not see an upsert failure yet. awaitEvents(t, events, - expect(sshKeepAliveErr, sshKeepAliveErr), + expect(sshKeepAliveErr, sshKeepAliveErr, sshKeepAliveOk), deny(sshUpsertErr, handlerClose), ) @@ -270,6 +269,32 @@ func TestSSHServerBasics(t *testing.T) { }) require.NoError(t, err) + // this explicit upsert will not happen since the server is the same, but + // keepalives should work + awaitEvents(t, events, + expect(sshKeepAliveOk), + deny(sshKeepAliveErr, sshUpsertErr, sshUpsertRetryOk, handlerClose), + ) + + err = downstream.Send(ctx, proto.InventoryHeartbeat{ + SSHServer: &types.ServerV2{ + Metadata: types.Metadata{ + Name: serverID, + Labels: map[string]string{ + "changed": "changed", + }, + }, + Spec: types.ServerSpecV2{ + Addr: zeroAddr, + }, + }, + }) + require.NoError(t, err) + + auth.mu.Lock() + auth.failUpserts = 1 + auth.mu.Unlock() + // we should now see an upsert failure, but no additional // keepalive failures, and the upsert should succeed on retry. awaitEvents(t, events, diff --git a/lib/inventory/internal/delay/delay.go b/lib/inventory/internal/delay/delay.go index 7a5ac8a06d74a..bb94478daf875 100644 --- a/lib/inventory/internal/delay/delay.go +++ b/lib/inventory/internal/delay/delay.go @@ -74,7 +74,7 @@ type Delay struct { } // Elapsed returns the channel on which the ticks are delivered. This method can -// be called on a nil delay, resulting in a nil channel. The [*Delay.Advance] +// be called on a nil delay, resulting in a nil channel. The [Delay.Advance] // method must be called after receiving a tick from the channel. // // select { @@ -102,7 +102,7 @@ func (i *Delay) interval() time.Duration { } // Advance sets up the next tick of the delay. Must be called after receiving -// from the [*Delay.Elapsed] channel; specifically, to maintain compatibility +// from the [Delay.Elapsed] channel; specifically, to maintain compatibility // with [clockwork.Clock], it must only be called with a drained timer channel. // For consistency, the value passed to Advance should be the value received // from the Elapsed channel (passing the current time will also work, but will @@ -111,8 +111,20 @@ func (i *Delay) Advance(now time.Time) { i.timer.Reset(i.interval() - i.clock.Since(now)) } -// Stop stops the delay. Only needed for [clockwork.Clock] compatibility. Can be -// called on a nil delay, as a no-op. The delay should not be used afterwards. +// Reset restarts the ticker from the current time. Must only be called while +// the timer is running (i.e. it must not be called between receiving from +// [Delay.Elapsed] and calling [Delay.Advance]). +func (i *Delay) Reset() { + // the drain is for Go earlier than 1.23 and for [clockwork.Clock] + if !i.timer.Stop() { + <-i.timer.Chan() + } + i.timer.Reset(i.interval()) +} + +// Stop stops the delay. Only needed for Go 1.22 and [clockwork.Clock] +// compatibility. Can be called on a nil delay, as a no-op. 
The delay should not +// be used afterwards. func (i *Delay) Stop() { if i == nil { return diff --git a/lib/teleterm/apiserver/apiserver.go b/lib/teleterm/apiserver/apiserver.go index f622fe5614437..42916c94b571a 100644 --- a/lib/teleterm/apiserver/apiserver.go +++ b/lib/teleterm/apiserver/apiserver.go @@ -19,11 +19,12 @@ package apiserver import ( + "context" "fmt" + "log/slog" "net" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "google.golang.org/grpc" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" @@ -70,7 +71,7 @@ func New(cfg Config) (*APIServer, error) { } grpcServer := grpc.NewServer(cfg.TshdServerCreds, - grpc.ChainUnaryInterceptor(withErrorHandling(cfg.Log)), + grpc.ChainUnaryInterceptor(withErrorHandling(cfg.Logger)), grpc.MaxConcurrentStreams(defaults.GRPCMaxConcurrentStreams), ) @@ -96,7 +97,7 @@ func (s *APIServer) Stop() { // immediate. Closing the VNet service before the gRPC server gives some time for the VNet admin // process to notice that the client is gone and shut down as well. if err := s.vnetService.Close(); err != nil { - log.WithError(err).Error("Error while closing VNet service") + slog.ErrorContext(context.Background(), "Error while closing VNet service", "error", err) } s.grpcServer.GracefulStop() @@ -120,7 +121,7 @@ func newListener(hostAddr string, listeningC chan<- utils.NetAddr) (net.Listener listeningC <- addr } - log.Infof("tsh daemon is listening on %v.", addr.FullAddress()) + slog.InfoContext(context.Background(), "tsh daemon listener created", "listen_addr", addr.FullAddress()) return lis, nil } diff --git a/lib/teleterm/apiserver/config.go b/lib/teleterm/apiserver/config.go index 76495b8a181b2..37e233a33f706 100644 --- a/lib/teleterm/apiserver/config.go +++ b/lib/teleterm/apiserver/config.go @@ -19,9 +19,10 @@ package apiserver import ( + "log/slog" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "google.golang.org/grpc" "github.com/gravitational/teleport" @@ -39,8 +40,8 @@ type Config struct { Daemon *daemon.Service ClusterIDCache *clusteridcache.Cache InstallationID string - // Log is a component logger - Log logrus.FieldLogger + // Logger is a component logger + Logger *slog.Logger TshdServerCreds grpc.ServerOption Clock clockwork.Clock // ListeningC propagates the address on which the gRPC server listens. 
Mostly useful in tests, as @@ -66,8 +67,8 @@ func (c *Config) CheckAndSetDefaults() error { return trace.BadParameter("missing TshdServerCreds") } - if c.Log == nil { - c.Log = logrus.WithField(teleport.ComponentKey, "conn:apiserver") + if c.Logger == nil { + c.Logger = slog.With(teleport.ComponentKey, "conn:apiserver") } if c.InstallationID == "" { diff --git a/lib/teleterm/apiserver/middleware.go b/lib/teleterm/apiserver/middleware.go index 7f8bb787a4038..520b97bb76565 100644 --- a/lib/teleterm/apiserver/middleware.go +++ b/lib/teleterm/apiserver/middleware.go @@ -20,14 +20,14 @@ package apiserver import ( "context" + "log/slog" "github.com/gravitational/trace/trail" - "github.com/sirupsen/logrus" "google.golang.org/grpc" ) // withErrorHandling is gRPC middleware that maps internal errors to proper gRPC error codes -func withErrorHandling(log logrus.FieldLogger) grpc.UnaryServerInterceptor { +func withErrorHandling(log *slog.Logger) grpc.UnaryServerInterceptor { return func( ctx context.Context, req interface{}, @@ -36,7 +36,7 @@ func withErrorHandling(log logrus.FieldLogger) grpc.UnaryServerInterceptor { ) (interface{}, error) { resp, err := handler(ctx, req) if err != nil { - log.WithError(err).Error("Request failed.") + log.ErrorContext(ctx, "Request failed", "error", err) return resp, trail.ToGRPC(err) } diff --git a/lib/teleterm/clusters/cluster.go b/lib/teleterm/clusters/cluster.go index 3899dc64fadff..ef8f22e461c86 100644 --- a/lib/teleterm/clusters/cluster.go +++ b/lib/teleterm/clusters/cluster.go @@ -20,10 +20,10 @@ package clusters import ( "context" + "log/slog" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -49,8 +49,8 @@ type Cluster struct { Name string // ProfileName is the name of the tsh profile ProfileName string - // Log is a component logger - Log *logrus.Entry + // Logger is a component logger + Logger *slog.Logger // dir is the directory where cluster certificates are stored dir string // Status is the cluster status @@ -192,9 +192,7 @@ func (c *Cluster) GetWithDetails(ctx context.Context, authClient authclient.Clie return roles, nil }) if err != nil { - c.Log. - WithError(err). - Warn("Failed to calculate trusted device requirement") + c.Logger.WarnContext(ctx, "Failed to calculate trusted device requirement", "error", err) } roleSet := services.NewRoleSet(roles...) 
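Editor's note: the teleterm hunks above and below migrate component loggers from logrus.FieldLogger to *slog.Logger injected through each Config. A minimal sketch of that injection pattern follows; the Config type, checkAndSetDefaults method, and the literal "component" attribute key are illustrative stand-ins (the real code uses teleport.ComponentKey), not the actual Teleport identifiers.

```go
// Sketch of the slog injection pattern used across the teleterm configs:
// accept an optional *slog.Logger and fill in a default when it is nil.
package main

import (
	"context"
	"log/slog"
)

type Config struct {
	// Logger is a component logger; a default is filled in when nil.
	Logger *slog.Logger
}

func (c *Config) checkAndSetDefaults() {
	if c.Logger == nil {
		// slog.With derives a logger from slog.Default with the given
		// attributes attached to every record it emits.
		c.Logger = slog.With("component", "conn:apiserver")
	}
}

func main() {
	cfg := Config{}
	cfg.checkAndSetDefaults()
	// The *Context variants accept a context so handlers can attach
	// request-scoped values such as trace IDs.
	cfg.Logger.InfoContext(context.Background(), "tsh daemon listener created",
		"listen_addr", "127.0.0.1:13337")
}
```

One mechanical difference worth noting: slog takes alternating key/value arguments rather than chained fields, which is why the diffs replace `WithField(...).Debug(...)` calls with `DebugContext(ctx, msg, key, value)`.
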
diff --git a/lib/teleterm/clusters/cluster_auth.go b/lib/teleterm/clusters/cluster_auth.go index c8b8b4ebe1a40..e50256410ee48 100644 --- a/lib/teleterm/clusters/cluster_auth.go +++ b/lib/teleterm/clusters/cluster_auth.go @@ -22,10 +22,10 @@ import ( "context" "encoding/json" "errors" + "log/slog" "sort" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" @@ -47,9 +47,9 @@ func (c *Cluster) SyncAuthPreference(ctx context.Context) (*webclient.WebConfigA } pingResponseJSON, err := json.Marshal(pingResponse) if err != nil { - c.Log.WithError(err).Debugln("Could not marshal ping response to JSON") + c.Logger.DebugContext(ctx, "Could not marshal ping response to JSON", "error", err) } else { - c.Log.WithField("response", string(pingResponseJSON)).Debugln("Got ping response") + c.Logger.DebugContext(ctx, "Got ping response", "response", string(pingResponseJSON)) } if err := c.clusterClient.SaveProfile(false); err != nil { @@ -227,7 +227,7 @@ func (c *Cluster) passwordlessLogin(stream api.TerminalService_LoginPasswordless response, err := client.SSHAgentPasswordlessLogin(ctx, client.SSHLoginPasswordless{ SSHLogin: sshLogin, AuthenticatorAttachment: c.clusterClient.AuthenticatorAttachment, - CustomPrompt: newPwdlessLoginPrompt(ctx, c.Log, stream), + CustomPrompt: newPwdlessLoginPrompt(ctx, c.Logger, stream), WebauthnLogin: c.clusterClient.WebauthnLogin, }) if err != nil { @@ -239,11 +239,11 @@ func (c *Cluster) passwordlessLogin(stream api.TerminalService_LoginPasswordless // pwdlessLoginPrompt is a implementation for wancli.LoginPrompt for teleterm passwordless logins. type pwdlessLoginPrompt struct { - log *logrus.Entry + log *slog.Logger Stream api.TerminalService_LoginPasswordlessServer } -func newPwdlessLoginPrompt(ctx context.Context, log *logrus.Entry, stream api.TerminalService_LoginPasswordlessServer) *pwdlessLoginPrompt { +func newPwdlessLoginPrompt(ctx context.Context, log *slog.Logger, stream api.TerminalService_LoginPasswordlessServer) *pwdlessLoginPrompt { return &pwdlessLoginPrompt{ log: log, Stream: stream, @@ -283,7 +283,7 @@ func (p *pwdlessLoginPrompt) ackTouch() error { // The current gRPC message type switch in teleterm client code will reject // any new message types, making this difficult to add without breaking // older clients. 
- p.log.Debug("Detected security key tap") + p.log.DebugContext(context.Background(), "Detected security key tap") return nil } diff --git a/lib/teleterm/clusters/cluster_auth_test.go b/lib/teleterm/clusters/cluster_auth_test.go index f9c7cd8e2c4ea..36165f053c1eb 100644 --- a/lib/teleterm/clusters/cluster_auth_test.go +++ b/lib/teleterm/clusters/cluster_auth_test.go @@ -20,20 +20,17 @@ package clusters import ( "context" + "log/slog" "testing" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "google.golang.org/grpc" - "github.com/gravitational/teleport" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" ) -var log = logrus.WithField(teleport.ComponentKey, "cluster_auth_test") - func TestPwdlessLoginPrompt_PromptPIN(t *testing.T) { stream := &mockLoginPwdlessStream{} @@ -49,7 +46,7 @@ func TestPwdlessLoginPrompt_PromptPIN(t *testing.T) { }}, nil } - prompt := newPwdlessLoginPrompt(context.Background(), log, stream) + prompt := newPwdlessLoginPrompt(context.Background(), slog.Default(), stream) pin, err := prompt.PromptPIN() require.NoError(t, err) require.Equal(t, "1234", pin) @@ -74,7 +71,7 @@ func TestPwdlessLoginPrompt_PromptTouch(t *testing.T) { return nil } - prompt := newPwdlessLoginPrompt(context.Background(), log, stream) + prompt := newPwdlessLoginPrompt(context.Background(), slog.Default(), stream) ackTouch, err := prompt.PromptTouch() require.NoError(t, err) require.NoError(t, ackTouch()) @@ -110,7 +107,7 @@ func TestPwdlessLoginPrompt_PromptCredential(t *testing.T) { }}, nil } - prompt := newPwdlessLoginPrompt(context.Background(), log, stream) + prompt := newPwdlessLoginPrompt(context.Background(), slog.Default(), stream) cred, err := prompt.PromptCredential(unsortedCreds) require.NoError(t, err) require.Equal(t, "foo", cred.User.Name) diff --git a/lib/teleterm/clusters/cluster_gateways.go b/lib/teleterm/clusters/cluster_gateways.go index 590fa27611f21..64577c35cf7dd 100644 --- a/lib/teleterm/clusters/cluster_gateways.go +++ b/lib/teleterm/clusters/cluster_gateways.go @@ -105,7 +105,7 @@ func (c *Cluster) createDBGateway(ctx context.Context, params CreateGatewayParam Cert: cert, Insecure: c.clusterClient.InsecureSkipVerify, WebProxyAddr: c.clusterClient.WebProxyAddr, - Log: c.Log, + Logger: c.Logger, TCPPortAllocator: params.TCPPortAllocator, OnExpiredCert: params.OnExpiredCert, Clock: c.clock, @@ -145,7 +145,7 @@ func (c *Cluster) createKubeGateway(ctx context.Context, params CreateGatewayPar Cert: cert, Insecure: c.clusterClient.InsecureSkipVerify, WebProxyAddr: c.clusterClient.WebProxyAddr, - Log: c.Log, + Logger: c.Logger, TCPPortAllocator: params.TCPPortAllocator, OnExpiredCert: params.OnExpiredCert, Clock: c.clock, @@ -187,7 +187,7 @@ func (c *Cluster) createAppGateway(ctx context.Context, params CreateGatewayPara Protocol: app.GetProtocol(), Insecure: c.clusterClient.InsecureSkipVerify, WebProxyAddr: c.clusterClient.WebProxyAddr, - Log: c.Log, + Logger: c.Logger, TCPPortAllocator: params.TCPPortAllocator, OnExpiredCert: params.OnExpiredCert, Clock: c.clock, diff --git a/lib/teleterm/clusters/config.go b/lib/teleterm/clusters/config.go index 6af0ad1bbfad3..ff94f4fdb533a 100644 --- a/lib/teleterm/clusters/config.go +++ b/lib/teleterm/clusters/config.go @@ -19,9 +19,10 @@ package clusters import ( + "log/slog" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" 
"github.com/gravitational/teleport" "github.com/gravitational/teleport/api/utils/keys" @@ -37,8 +38,8 @@ type Config struct { Clock clockwork.Clock // InsecureSkipVerify is an option to skip TLS cert check InsecureSkipVerify bool - // Log is a component logger - Log *logrus.Entry + // Logger is a component logger + Logger *slog.Logger // WebauthnLogin allows tests to override the Webauthn Login func. // Defaults to wancli.Login. WebauthnLogin client.WebauthnLoginFunc @@ -63,8 +64,8 @@ func (c *Config) CheckAndSetDefaults() error { c.Clock = clockwork.NewRealClock() } - if c.Log == nil { - c.Log = logrus.WithField(teleport.ComponentKey, "conn:storage") + if c.Logger == nil { + c.Logger = slog.With(teleport.ComponentKey, "conn:storage") } if c.AddKeysToAgent == "" { diff --git a/lib/teleterm/clusters/storage.go b/lib/teleterm/clusters/storage.go index f00adfc73c15c..7d5b1292aa25e 100644 --- a/lib/teleterm/clusters/storage.go +++ b/lib/teleterm/clusters/storage.go @@ -179,13 +179,13 @@ func (s *Storage) addCluster(ctx context.Context, dir, webProxyAddress string) ( return nil, nil, trace.Wrap(err) } - clusterLog := s.Log.WithField("cluster", clusterURI) + clusterLog := s.Logger.With("cluster", clusterURI) pingResponseJSON, err := json.Marshal(pingResponse) if err != nil { - clusterLog.WithError(err).Debugln("Could not marshal ping response to JSON") + clusterLog.DebugContext(ctx, "Could not marshal ping response to JSON", "error", err) } else { - clusterLog.WithField("response", string(pingResponseJSON)).Debugln("Got ping response") + clusterLog.DebugContext(ctx, "Got ping response", "response", string(pingResponseJSON)) } if err := clusterClient.SaveProfile(false); err != nil { @@ -201,7 +201,7 @@ func (s *Storage) addCluster(ctx context.Context, dir, webProxyAddress string) ( clusterClient: clusterClient, dir: s.Dir, clock: s.Clock, - Log: clusterLog, + Logger: clusterLog, }, clusterClient, nil } @@ -241,7 +241,7 @@ func (s *Storage) fromProfile(profileName, leafClusterName string) (*Cluster, *c dir: s.Dir, clock: s.Clock, statusError: err, - Log: s.Log.WithField("cluster", clusterURI), + Logger: s.Logger.With("cluster", clusterURI), } if status != nil { cluster.status = *status @@ -258,7 +258,7 @@ func (s *Storage) loadProfileStatusAndClusterKey(clusterClient *client.TeleportC _, err := clusterClient.LocalAgent().GetKeyRing(clusterNameForKey) if err != nil { if trace.IsNotFound(err) { - s.Log.Infof("No keys found for cluster %v.", clusterNameForKey) + s.Logger.InfoContext(context.Background(), "No keys found for cluster", "cluster", clusterNameForKey) } else { return nil, trace.Wrap(err) } diff --git a/lib/teleterm/cmd/db_test.go b/lib/teleterm/cmd/db_test.go index cd165b850cdc4..68f7a82fdcec4 100644 --- a/lib/teleterm/cmd/db_test.go +++ b/lib/teleterm/cmd/db_test.go @@ -21,12 +21,12 @@ package cmd import ( "context" "fmt" + "log/slog" "os/exec" "path/filepath" "testing" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/client/proto" @@ -66,7 +66,7 @@ func (m fakeDatabaseGateway) TargetName() string { return m.targetURI func (m fakeDatabaseGateway) TargetUser() string { return "alice" } func (m fakeDatabaseGateway) TargetSubresourceName() string { return m.subresourceName } func (m fakeDatabaseGateway) Protocol() string { return m.protocol } -func (m fakeDatabaseGateway) Log() *logrus.Entry { return nil } +func (m fakeDatabaseGateway) Log() *slog.Logger { return nil } func (m fakeDatabaseGateway) 
LocalAddress() string { return "localhost" } func (m fakeDatabaseGateway) LocalPortInt() int { return 8888 } func (m fakeDatabaseGateway) LocalPort() string { return "8888" } diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index 80cc79d081946..3646b78f05e2b 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -20,10 +20,10 @@ package daemon import ( "context" + "log/slog" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "google.golang.org/grpc" "github.com/gravitational/teleport" @@ -52,8 +52,8 @@ type Config struct { Clock clockwork.Clock // Storage is a storage service that reads/writes to tsh profiles Storage Storage - // Log is a component logger - Log *logrus.Entry + // Logger is a component logger + Logger *slog.Logger // PrehogAddr is the URL where prehog events should be submitted. PrehogAddr string // KubeconfigsDir is the directory containing kubeconfigs for Kubernetes @@ -121,8 +121,8 @@ func (c *Config) CheckAndSetDefaults() error { c.GatewayCreator = clusters.NewGatewayCreator(c.Storage) } - if c.Log == nil { - c.Log = logrus.NewEntry(logrus.StandardLogger()).WithField(teleport.ComponentKey, "daemon") + if c.Logger == nil { + c.Logger = slog.With(teleport.ComponentKey, "daemon") } if c.ConnectMyComputerRoleSetup == nil { @@ -172,7 +172,7 @@ func (c *Config) CheckAndSetDefaults() error { return clusters.AddMetadataToRetryableError(ctx, fn) } return clientcache.New(clientcache.Config{ - Log: c.Log, + Logger: c.Logger, NewClientFunc: newClientFunc, RetryWithReloginFunc: clientcache.RetryWithReloginFunc(retryWithRelogin), }) diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index 13f12f4dfa253..d3528793a4b99 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -357,7 +357,7 @@ func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams) go func() { if err := gateway.Serve(); err != nil { - gateway.Log().WithError(err).Warn("Failed to handle a gateway connection.") + gateway.Log().WarnContext(ctx, "Failed to handle a gateway connection", "error", err) } }() @@ -416,7 +416,7 @@ func (s *Service) reissueGatewayCerts(ctx context.Context, g gateway.Gateway) (t }, }) if notifyErr != nil { - s.cfg.Log.WithError(notifyErr).Error("Failed to send a notification for an error encountered during gateway cert reissue") + s.cfg.Logger.ErrorContext(ctx, "Failed to send a notification for an error encountered during gateway cert reissue", "error", notifyErr) } // Return the error to the alpn.LocalProxy's middleware. @@ -559,9 +559,9 @@ func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (gateway.Gat // Rather than continuing in presence of the race condition, let's attempt to close the new // gateway (since it shouldn't be used anyway) and return the error. 
 		if newGatewayCloseErr := newGateway.Close(); newGatewayCloseErr != nil {
-			newGateway.Log().Warnf(
-				"Failed to close the new gateway after failing to close the old gateway: %v",
-				newGatewayCloseErr,
+			newGateway.Log().WarnContext(s.closeContext,
+				"Failed to close the new gateway after failing to close the old gateway",
+				"error", newGatewayCloseErr,
 			)
 		}
 		return nil, trace.Wrap(err)
@@ -571,7 +571,7 @@ func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (gateway.Gat

 	go func() {
 		if err := newGateway.Serve(); err != nil {
-			newGateway.Log().WithError(err).Warn("Failed to handle a gateway connection.")
+			newGateway.Log().WarnContext(s.closeContext, "Failed to handle a gateway connection", "error", err)
 		}
 	}()

@@ -842,7 +842,7 @@ func (s *Service) Stop() {
 	s.mu.RLock()
 	defer s.mu.RUnlock()

-	s.cfg.Log.Info("Stopping")
+	s.cfg.Logger.InfoContext(s.closeContext, "Stopping")

 	for _, gateway := range s.gateways {
 		gateway.Close()
@@ -851,14 +851,14 @@ func (s *Service) Stop() {
 	s.StopHeadlessWatchers()

 	if err := s.clientCache.Clear(); err != nil {
-		s.cfg.Log.WithError(err).Error("Failed to close remote clients")
+		s.cfg.Logger.ErrorContext(s.closeContext, "Failed to close remote clients", "error", err)
 	}

 	timeoutCtx, cancel := context.WithTimeout(s.closeContext, time.Second*10)
 	defer cancel()

 	if err := s.usageReporter.GracefulStop(timeoutCtx); err != nil {
-		s.cfg.Log.WithError(err).Error("Gracefully stopping usage reporter failed")
+		s.cfg.Logger.ErrorContext(timeoutCtx, "Gracefully stopping usage reporter failed", "error", err)
 	}

 	// s.closeContext is used for the tshd events client which might make requests as long as any of
diff --git a/lib/teleterm/daemon/daemon_headless.go b/lib/teleterm/daemon/daemon_headless.go
index 9ddf02a52405c..310b853d229c3 100644
--- a/lib/teleterm/daemon/daemon_headless.go
+++ b/lib/teleterm/daemon/daemon_headless.go
@@ -30,6 +30,7 @@ import (
 	api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1"
 	"github.com/gravitational/teleport/lib/defaults"
 	"github.com/gravitational/teleport/lib/teleterm/clusters"
+	logutils "github.com/gravitational/teleport/lib/utils/log"
 )

 // UpdateHeadlessAuthenticationState updates a headless authentication state.
@@ -96,7 +97,7 @@ func (s *Service) startHeadlessWatcher(rootCluster *clusters.Cluster, waitInit b
 	watchCtx, watchCancel := context.WithCancel(s.closeContext)
 	s.headlessWatcherClosers[rootCluster.URI.String()] = watchCancel

-	log := s.cfg.Log.WithField("cluster", rootCluster.URI.String())
+	log := s.cfg.Logger.With("cluster", logutils.StringerAttr(rootCluster.URI))

 	pendingRequests := make(map[string]context.CancelFunc)
 	pendingRequestsMu := sync.Mutex{}
@@ -180,7 +181,7 @@ func (s *Service) startHeadlessWatcher(rootCluster *clusters.Cluster, waitInit b
 			defer cancelSend()
 			if err := s.sendPendingHeadlessAuthentication(sendCtx, ha, rootCluster.URI.String()); err != nil {
 				if !strings.Contains(err.Error(), context.Canceled.Error()) && !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) {
-					log.WithError(err).Debug("sendPendingHeadlessAuthentication resulted in unexpected error.")
+					log.DebugContext(sendCtx, "sendPendingHeadlessAuthentication resulted in unexpected error", "error", err)
 				}
 			}
 		}()
@@ -210,7 +211,7 @@ func (s *Service) startHeadlessWatcher(rootCluster *clusters.Cluster, waitInit b
 		}
 	}

-	log.Debugf("Starting headless watch loop.")
+	log.DebugContext(watchCtx, "Starting headless watch loop")
 	go func() {
 		defer func() {
 			s.headlessWatcherClosersMu.Lock()
@@ -222,31 +223,36 @@ func (s *Service) startHeadlessWatcher(rootCluster *clusters.Cluster, waitInit b
 			default:
 				// watcher closed due to error or cluster disconnect.
 				if err := s.stopHeadlessWatcher(rootCluster.URI.String()); err != nil {
-					log.WithError(err).Debug("Failed to remove headless watcher.")
+					log.DebugContext(watchCtx, "Failed to remove headless watcher", "error", err)
 				}
 			}
 		}()

 		for {
 			if !rootCluster.Connected() {
-				log.Debugf("Not connected to cluster. Returning from headless watch loop.")
+				log.DebugContext(watchCtx, "Not connected to cluster, terminating headless watch loop")
 				return
 			}

 			err := watch()
 			if trace.IsNotImplemented(err) {
 				// Don't retry watch if we are connecting to an old Auth Server.
-				log.WithError(err).Debug("Headless watcher not supported.")
+				log.DebugContext(watchCtx, "Headless watcher not supported", "error", err)
 				return
 			}

 			startedWaiting := s.cfg.Clock.Now()
 			select {
 			case t := <-retry.After():
-				log.WithError(err).Debugf("Restarting watch on error after waiting %v.", t.Sub(startedWaiting))
+				log.DebugContext(watchCtx, "Restarting watch on error",
+					"backoff", t.Sub(startedWaiting),
+					"error", err,
+				)
 				retry.Inc()
 			case <-watchCtx.Done():
-				log.WithError(watchCtx.Err()).Debugf("Context closed with err. Returning from headless watch loop.")
+				log.DebugContext(watchCtx, "Context closed with error, ending headless watch loop",
+					"error", watchCtx.Err(),
+				)
 				return
 			}
 		}
@@ -295,7 +301,10 @@ func (s *Service) StopHeadlessWatchers() {

 	for uri := range s.headlessWatcherClosers {
 		if err := s.stopHeadlessWatcher(uri); err != nil {
-			s.cfg.Log.WithField("cluster", uri).WithError(err).Debug("Encountered unexpected error closing headless watcher")
+			s.cfg.Logger.DebugContext(s.closeContext, "Encountered unexpected error closing headless watcher",
+				"error", err,
+				"cluster", uri,
+			)
 		}
 	}
 }
diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go
index 9c46057f7ac30..6f5670d61fe57 100644
--- a/lib/teleterm/daemon/daemon_test.go
+++ b/lib/teleterm/daemon/daemon_test.go
@@ -21,6 +21,7 @@ package daemon
 import (
 	"context"
 	"errors"
+	"log/slog"
 	"net"
 	"net/http"
 	"net/http/httptest"
@@ -30,7 +31,6 @@ import (
 	"time"

 	"github.com/gravitational/trace"
-	"github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"google.golang.org/grpc"
@@ -726,7 +726,7 @@ func (m fakeGateway) TargetName() string { return m.targetURI.GetDbNa
 func (m fakeGateway) TargetUser() string { return "alice" }
 func (m fakeGateway) TargetSubresourceName() string { return m.subresourceName }
 func (m fakeGateway) Protocol() string { return defaults.ProtocolSQLServer }
-func (m fakeGateway) Log() *logrus.Entry { return nil }
+func (m fakeGateway) Log() *slog.Logger { return nil }
 func (m fakeGateway) LocalAddress() string { return "localhost" }
 func (m fakeGateway) LocalPortInt() int { return 8888 }
 func (m fakeGateway) LocalPort() string { return "8888" }
diff --git a/lib/teleterm/gateway/app.go b/lib/teleterm/gateway/app.go
index 603d640a05a9c..110d36604aeff 100644
--- a/lib/teleterm/gateway/app.go
+++ b/lib/teleterm/gateway/app.go
@@ -56,7 +56,7 @@ func makeAppGateway(cfg Config) (Gateway, error) {
 	}

 	middleware := &appMiddleware{
-		log: a.cfg.Log,
+		logger: a.cfg.Logger,
 		onExpiredCert: func(ctx context.Context) (tls.Certificate, error) {
 			cert, err := a.cfg.OnExpiredCert(ctx, a)
 			return cert, trace.Wrap(err)
diff --git a/lib/teleterm/gateway/app_middleware.go b/lib/teleterm/gateway/app_middleware.go
index 8af69271ade03..9b58de8624016 100644
--- a/lib/teleterm/gateway/app_middleware.go
+++ b/lib/teleterm/gateway/app_middleware.go
@@ -21,16 +21,16 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
+	"log/slog"

 	"github.com/gravitational/trace"
-	"github.com/sirupsen/logrus"

 	alpn "github.com/gravitational/teleport/lib/srv/alpnproxy"
 )

 type appMiddleware struct {
 	onExpiredCert func(context.Context) (tls.Certificate, error)
-	log *logrus.Entry
+	logger *slog.Logger
 }

 // OnNewConnection calls m.onExpiredCert to get a fresh cert if the cert has expired and then sets
@@ -48,7 +48,7 @@ func (m *appMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy
 		return trace.Wrap(err)
 	}

-	m.log.WithError(err).Debug("Gateway certificates have expired")
+	m.logger.DebugContext(ctx, "Gateway certificates have expired", "error", err)

 	cert, err := m.onExpiredCert(ctx)
 	if err != nil {
diff --git a/lib/teleterm/gateway/base.go b/lib/teleterm/gateway/base.go
index e0a9a33cc4d86..3a8b076307c60 100644
--- a/lib/teleterm/gateway/base.go
+++ b/lib/teleterm/gateway/base.go
@@ -21,11 +21,11 @@ package gateway
 import (
 	"context"
 	"fmt"
+	"log/slog"
 	"net"
 	"strconv"

 	"github.com/gravitational/trace"
-	"github.com/sirupsen/logrus"

 	alpn "github.com/gravitational/teleport/lib/srv/alpnproxy"
"github.com/gravitational/teleport/lib/teleterm/api/uri" @@ -107,8 +107,8 @@ func (b *base) Close() error { // Serve starts the underlying ALPN proxy. Blocks until closeContext is canceled. func (b *base) Serve() error { - b.cfg.Log.Info("Gateway is open.") - defer b.cfg.Log.Info("Gateway has closed.") + b.cfg.Logger.InfoContext(b.closeContext, "Gateway is open") + defer b.cfg.Logger.InfoContext(b.closeContext, "Gateway has closed") if b.forwardProxy != nil { return trace.Wrap(b.serveWithForwardProxy()) @@ -165,8 +165,8 @@ func (b *base) SetTargetSubresourceName(value string) { b.cfg.TargetSubresourceName = value } -func (b *base) Log() *logrus.Entry { - return b.cfg.Log +func (b *base) Log() *slog.Logger { + return b.cfg.Logger } func (b *base) LocalAddress() string { diff --git a/lib/teleterm/gateway/config.go b/lib/teleterm/gateway/config.go index c870df9075728..67768d05900db 100644 --- a/lib/teleterm/gateway/config.go +++ b/lib/teleterm/gateway/config.go @@ -22,13 +22,13 @@ import ( "context" "crypto/tls" "crypto/x509" + "log/slog" "net" "runtime" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/defaults" @@ -69,8 +69,8 @@ type Config struct { Username string // WebProxyAddr WebProxyAddr string - // Log is a component logger - Log *logrus.Entry + // Logger is a component logger + Logger *slog.Logger // TCPPortAllocator creates listeners on the given ports. This interface lets us avoid occupying // hardcoded ports in tests. TCPPortAllocator TCPPortAllocator @@ -125,8 +125,8 @@ func (c *Config) CheckAndSetDefaults() error { c.LocalPort = "0" } - if c.Log == nil { - c.Log = logrus.NewEntry(logrus.StandardLogger()) + if c.Logger == nil { + c.Logger = slog.Default() } if c.TargetName == "" { @@ -154,10 +154,10 @@ func (c *Config) CheckAndSetDefaults() error { } } - c.Log = c.Log.WithFields(logrus.Fields{ - "resource": c.TargetURI.String(), - "gateway": c.URI.String(), - }) + c.Logger = c.Logger.With( + "resource", c.TargetURI.String(), + "gateway", c.URI.String(), + ) return nil } diff --git a/lib/teleterm/gateway/db.go b/lib/teleterm/gateway/db.go index a6b25e685b9c8..b1602d0648b08 100644 --- a/lib/teleterm/gateway/db.go +++ b/lib/teleterm/gateway/db.go @@ -54,7 +54,7 @@ func makeDatabaseGateway(cfg Config) (Database, error) { } middleware := &dbMiddleware{ - log: d.cfg.Log, + logger: d.cfg.Logger, onExpiredCert: func(ctx context.Context) (tls.Certificate, error) { cert, err := d.cfg.OnExpiredCert(ctx, d) return cert, trace.Wrap(err) diff --git a/lib/teleterm/gateway/db_middleware.go b/lib/teleterm/gateway/db_middleware.go index cd189fff048a0..110f6969d41a8 100644 --- a/lib/teleterm/gateway/db_middleware.go +++ b/lib/teleterm/gateway/db_middleware.go @@ -23,9 +23,9 @@ import ( "crypto/tls" "crypto/x509" "errors" + "log/slog" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" alpn "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/tlsca" @@ -33,7 +33,7 @@ import ( type dbMiddleware struct { onExpiredCert func(context.Context) (tls.Certificate, error) - log *logrus.Entry + logger *slog.Logger dbRoute tlsca.RouteToDatabase } @@ -54,7 +54,7 @@ func (m *dbMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy) return trace.Wrap(err) } - m.log.WithError(err).Debug("Gateway certificates have expired") + m.logger.DebugContext(ctx, "Gateway certificates have expired", 
"error", err) cert, err := m.onExpiredCert(ctx) if err != nil { diff --git a/lib/teleterm/gateway/db_middleware_test.go b/lib/teleterm/gateway/db_middleware_test.go index cecf12306de2e..1f786a5d8226c 100644 --- a/lib/teleterm/gateway/db_middleware_test.go +++ b/lib/teleterm/gateway/db_middleware_test.go @@ -21,11 +21,11 @@ package gateway import ( "context" "crypto/tls" + "log/slog" "testing" "time" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "github.com/gravitational/teleport" @@ -112,7 +112,7 @@ func TestDBMiddleware_OnNewConnection(t *testing.T) { hasCalledOnExpiredCert = true return tls.Certificate{}, nil }, - log: logrus.WithField(teleport.ComponentKey, "middleware"), + logger: slog.With(teleport.ComponentKey, "middleware"), dbRoute: tt.dbRoute, } diff --git a/lib/teleterm/gateway/interfaces.go b/lib/teleterm/gateway/interfaces.go index 4cedf02e7ffd7..27bc6735a2b9d 100644 --- a/lib/teleterm/gateway/interfaces.go +++ b/lib/teleterm/gateway/interfaces.go @@ -19,8 +19,9 @@ package gateway import ( + "log/slog" + "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/tlsca" @@ -41,7 +42,7 @@ type Gateway interface { TargetUser() string TargetSubresourceName() string SetTargetSubresourceName(value string) - Log() *logrus.Entry + Log() *slog.Logger LocalAddress() string LocalPort() string LocalPortInt() int diff --git a/lib/teleterm/grpccredentials.go b/lib/teleterm/grpccredentials.go index f0c7c7562927f..8228f2e7ae631 100644 --- a/lib/teleterm/grpccredentials.go +++ b/lib/teleterm/grpccredentials.go @@ -21,11 +21,11 @@ package teleterm import ( "crypto/tls" "crypto/x509" + "log/slog" "os" "path/filepath" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -65,21 +65,21 @@ func createServerCredentials(serverKeyPair tls.Certificate, clientCertPaths []st Certificates: []tls.Certificate{serverKeyPair}, } - config.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) { + config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { certPool := x509.NewCertPool() for _, clientCertPath := range clientCertPaths { - log := log.WithField("cert_path", clientCertPath) + log := slog.With("cert_path", clientCertPath) clientCert, err := os.ReadFile(clientCertPath) if err != nil { - log.WithError(err).Error("Failed to read the client cert file") + log.ErrorContext(info.Context(), "Failed to read the client cert file", "error", err) // Fall back to the default config. return nil, nil } if !certPool.AppendCertsFromPEM(clientCert) { - log.Error("Failed to add the client cert to the pool") + log.ErrorContext(info.Context(), "Failed to add the client cert to the pool") // Fall back to the default config. 
 				return nil, nil
 			}
diff --git a/lib/teleterm/services/connectmycomputer/connectmycomputer.go b/lib/teleterm/services/connectmycomputer/connectmycomputer.go
index 26ecc8aafe8d9..d26e17621c9b5 100644
--- a/lib/teleterm/services/connectmycomputer/connectmycomputer.go
+++ b/lib/teleterm/services/connectmycomputer/connectmycomputer.go
@@ -21,6 +21,7 @@ package connectmycomputer
 import (
 	"context"
 	"fmt"
+	"log/slog"
 	"os"
 	"os/user"
 	"path/filepath"
@@ -31,7 +32,6 @@ import (
 	"github.com/google/uuid"
 	"github.com/gravitational/trace"
 	"github.com/jonboulle/clockwork"
-	"github.com/sirupsen/logrus"

 	"github.com/gravitational/teleport"
 	apidefaults "github.com/gravitational/teleport/api/defaults"
@@ -105,8 +105,9 @@ func (s *RoleSetup) Run(ctx context.Context, accessAndIdentity AccessAndIdentity

 	reloadCerts := false

+	logger := s.cfg.Logger.With("role", roleName)
 	if !doesRoleExist {
-		s.cfg.Log.Infof("Creating the role %v.", roleName)
+		logger.InfoContext(ctx, "Creating the role")

 		role, err := types.NewRole(roleName, types.RoleSpecV6{
 			Allow: types.RoleConditions{
@@ -123,7 +124,7 @@ func (s *RoleSetup) Run(ctx context.Context, accessAndIdentity AccessAndIdentity
 			return noCertsReloaded, trace.Wrap(err, "creating role %v", roleName)
 		}
 	} else {
-		s.cfg.Log.Infof("The role %v already exists", roleName)
+		logger.InfoContext(ctx, "The role already exists")
 		isRoleDirty := false

 		// Ensure that the current system username is in the role.
@@ -134,7 +135,9 @@ func (s *RoleSetup) Run(ctx context.Context, accessAndIdentity AccessAndIdentity

 		allowedLogins := existingRole.GetLogins(types.Allow)
 		if !slices.Contains(allowedLogins, systemUser.Username) {
-			s.cfg.Log.Infof("Adding %v to the logins of the role %v.", systemUser.Username, roleName)
+			logger.InfoContext(ctx, "Adding username to the logins of the role",
+				"username", systemUser.Username,
+			)
 			existingRole.SetLogins(types.Allow, append(allowedLogins, systemUser.Username))

 			isRoleDirty = true
@@ -156,7 +159,7 @@ func (s *RoleSetup) Run(ctx context.Context, accessAndIdentity AccessAndIdentity

 		expectedOwnerNodeLabelValue := []string{clusterUser.GetName()}
 		if !slices.Equal(ownerNodeLabelValue, expectedOwnerNodeLabelValue) {
-			s.cfg.Log.Infof("Overwriting the owner node label in the role %v.", roleName)
+			logger.InfoContext(ctx, "Overwriting the owner node label in the role")
 			allowedNodeLabels[types.ConnectMyComputerNodeOwnerLabel] = expectedOwnerNodeLabelValue

 			isRoleDirty = true
@@ -178,9 +181,9 @@ func (s *RoleSetup) Run(ctx context.Context, accessAndIdentity AccessAndIdentity

 	hasCMCRole := slices.Contains(clusterUser.GetRoles(), roleName)
 	if hasCMCRole {
-		s.cfg.Log.Infof("The user %v already has the role %v.", clusterUser.GetName(), roleName)
+		logger.InfoContext(ctx, "The user already has the role", "user", clusterUser.GetName())
 	} else {
-		s.cfg.Log.Infof("Adding the role %v to the user %v.", roleName, clusterUser.GetName())
+		logger.InfoContext(ctx, "Adding the role to the user", "user", clusterUser.GetName())
 		clusterUser.AddRole(roleName)
 		timeoutCtx, cancel := context.WithTimeout(ctx, resourceUpdateTimeout)
 		defer cancel()
@@ -197,7 +200,7 @@ func (s *RoleSetup) Run(ctx context.Context, accessAndIdentity AccessAndIdentity
 	}

 	if reloadCerts {
-		s.cfg.Log.Info("Reissuing certs.")
+		s.cfg.Logger.InfoContext(ctx, "Reissuing certs")
 		// ReissueUserCerts called with CertCacheDrop and a bogus access request ID in DropAccessRequests
 		// allows us to refresh the role list in the certs without forcing the user to relogin.
 		//
@@ -273,12 +276,12 @@ type CertManager interface {
 }

 type RoleSetupConfig struct {
-	Log *logrus.Entry
+	Logger *slog.Logger
 }

 func (c *RoleSetupConfig) CheckAndSetDefaults() error {
-	if c.Log == nil {
-		c.Log = logrus.NewEntry(logrus.StandardLogger()).WithField(teleport.ComponentKey, "CMC role")
+	if c.Logger == nil {
+		c.Logger = slog.With(teleport.ComponentKey, "CMC role")
 	}

 	return nil
diff --git a/lib/teleterm/teleterm.go b/lib/teleterm/teleterm.go
index c8bc92b4e0e8a..0c6cde8efd031 100644
--- a/lib/teleterm/teleterm.go
+++ b/lib/teleterm/teleterm.go
@@ -20,6 +20,7 @@ package teleterm
 import (
 	"context"
+	"log/slog"
 	"os"
 	"os/signal"
 	"path/filepath"
@@ -28,7 +29,6 @@ import (

 	"github.com/gravitational/trace"
 	"github.com/jonboulle/clockwork"
-	log "github.com/sirupsen/logrus"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials/insecure"
@@ -112,9 +112,9 @@ func Serve(ctx context.Context, cfg Config) error {

 	select {
 	case <-ctx.Done():
-		log.Info("Context closed, stopping service.")
+		slog.InfoContext(ctx, "Context closed, stopping service")
 	case sig := <-c:
-		log.Infof("Captured %s, stopping service.", sig)
+		slog.InfoContext(ctx, "Captured signal, stopping service", "signal", sig)
 	}

 	daemonService.Stop()
diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go
index a21d165e382de..b88f4c0102edf 100644
--- a/lib/web/apiserver.go
+++ b/lib/web/apiserver.go
@@ -780,7 +780,7 @@ func (h *Handler) bindDefaultEndpoints() {
 	h.POST("/webapi/sessions/app", h.WithAuth(h.createAppSession))

 	// Web sessions
-	h.POST("/webapi/sessions/web", httplib.WithCSRFProtection(h.WithLimiterHandlerFunc(h.createWebSession)))
+	h.POST("/webapi/sessions/web", h.WithLimiter(h.createWebSession))
 	h.DELETE("/webapi/sessions/web", h.WithAuth(h.deleteWebSession))
 	h.POST("/webapi/sessions/web/renew", h.WithAuth(h.renewWebSession))
 	h.POST("/webapi/users", h.WithAuth(h.createUserHandle))
@@ -793,7 +793,7 @@ func (h *Handler) bindDefaultEndpoints() {
 	//
 	h.GET("/webapi/users/password/token/:token", h.WithLimiter(h.getResetPasswordTokenHandle))
 	h.GET("/webapi/users/*wildcard", h.handleGetUserOrResetToken)
-	h.PUT("/webapi/users/password/token", httplib.WithCSRFProtection(h.changeUserAuthentication))
+	h.PUT("/webapi/users/password/token", h.WithLimiter(h.changeUserAuthentication))
 	h.PUT("/webapi/users/password", h.WithAuth(h.changePassword))
 	h.POST("/webapi/users/password/token", h.WithAuth(h.createResetPasswordToken))
 	h.POST("/webapi/users/privilege/token", h.WithAuth(h.createPrivilegeTokenHandle))
@@ -1994,7 +1994,6 @@ func (h *Handler) githubLoginWeb(w http.ResponseWriter, r *http.Request, p httpr
 	}

 	response, err := h.cfg.ProxyClient.CreateGithubAuthRequest(r.Context(), types.GithubAuthRequest{
-		CSRFToken: req.CSRFToken,
 		ConnectorID: req.ConnectorID,
 		CreateWebSession: true,
 		ClientRedirectURL: req.ClientRedirectURL,
@@ -2004,7 +2003,6 @@ func (h *Handler) githubLoginWeb(w http.ResponseWriter, r *http.Request, p httpr
 	if err != nil {
 		logger.WithError(err).Error("Error creating auth request.")
 		return client.LoginFailedRedirectURL
-
 	}

 	return response.RedirectURL
@@ -4705,21 +4703,6 @@ func (h *Handler) WithSession(fn ContextHandler) httprouter.Handle {
 	})
 }

-// WithAuthCookieAndCSRF ensures that a request is authenticated
-// for plain old non-AJAX requests (does not check the Bearer header).
-// It enforces CSRF checks (except for "safe" methods).
-func (h *Handler) WithAuthCookieAndCSRF(fn ContextHandler) httprouter.Handle {
-	f := func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
-		sctx, err := h.AuthenticateRequest(w, r, false)
-		if err != nil {
-			return nil, trace.Wrap(err)
-		}
-		return fn(w, r, p, sctx)
-	}
-
-	return httplib.WithCSRFProtection(f)
-}
-
 // WithUnauthenticatedLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated requests.
 // This is a good default to use as both Cluster and User auth are checked here, but `WithLimiter` can be used if
 // you're certain that no authenticated requests will be made.
@@ -5054,8 +5037,6 @@ type SSORequestParams struct {
 	// ConnectorID identifies the SSO connector to use to log in, from
 	// the connector_id query parameter.
 	ConnectorID string
-	// CSRFToken is the token in the CSRF cookie header.
-	CSRFToken string
 }

 // ParseSSORequestParams extracts the SSO request parameters from an http.Request,
@@ -5088,15 +5069,9 @@
 		return nil, trace.BadParameter("missing connector_id query parameter")
 	}

-	csrfToken, err := csrf.ExtractTokenFromCookie(r)
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
-
 	return &SSORequestParams{
 		ClientRedirectURL: clientRedirectURL,
 		ConnectorID: connectorID,
-		CSRFToken: csrfToken,
 	}, nil
 }
diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go
index 962d52514cad9..48a9ff61179a5 100644
--- a/lib/web/apiserver_test.go
+++ b/lib/web/apiserver_test.go
@@ -117,7 +117,6 @@ import (
 	"github.com/gravitational/teleport/lib/events"
 	"github.com/gravitational/teleport/lib/events/eventstest"
 	"github.com/gravitational/teleport/lib/httplib"
-	"github.com/gravitational/teleport/lib/httplib/csrf"
 	"github.com/gravitational/teleport/lib/inventory"
 	kubeproxy "github.com/gravitational/teleport/lib/kube/proxy"
 	"github.com/gravitational/teleport/lib/limiter"
@@ -947,10 +946,6 @@ func TestWebSessionsCRUD(t *testing.T) {
 func TestCSRF(t *testing.T) {
 	t.Parallel()
 	s := newWebSuite(t)
-	type input struct {
-		reqToken string
-		cookieToken string
-	}

 	// create a valid user
 	user := "csrfuser"
@@ -958,39 +953,25 @@ func TestCSRF(t *testing.T) {
 	otpSecret := newOTPSharedSecret()
 	s.createUser(t, user, user, pass, otpSecret)

-	encodedToken1 := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
-	encodedToken2 := "bf355921bbf3ef3672a03e410d4194077dfa5fe863c652521763b3e7f81e7b11"
-	invalid := []input{
-		{reqToken: encodedToken2, cookieToken: encodedToken1},
-		{reqToken: "", cookieToken: encodedToken1},
-		{reqToken: "", cookieToken: ""},
-		{reqToken: encodedToken1, cookieToken: ""},
-	}
-
 	clt := s.client(t)
 	ctx := context.Background()

 	// valid
 	validReq := loginWebOTPParams{
-		webClient:  clt,
-		clock:      s.clock,
-		user:       user,
-		password:   pass,
-		otpSecret:  otpSecret,
-		cookieCSRF: &encodedToken1,
-		headerCSRF: &encodedToken1,
+		webClient: clt,
+		clock:     s.clock,
+		user:      user,
+		password:  pass,
+		otpSecret: otpSecret,
 	}
 	loginWebOTP(t, ctx, validReq)

-	// invalid
-	for i := range invalid {
-		req := validReq
-		req.cookieCSRF = &invalid[i].cookieToken
-		req.headerCSRF = &invalid[i].reqToken
-		httpResp, _, err := rawLoginWebOTP(ctx, req)
-		require.NoError(t, err, "Login via /webapi/sessions/new failed unexpectedly")
-		assert.Equal(t, http.StatusForbidden, httpResp.StatusCode, "HTTP status code mismatch")
-	}
+	// invalid - wrong content-type header
+	invalidReq := validReq
+	invalidReq.overrideContentType = "multipart/form-data"
+	httpResp, _, err := rawLoginWebOTP(ctx, invalidReq)
+	require.NoError(t, err, "Login via /webapi/sessions/new failed unexpectedly")
+	require.Equal(t, http.StatusBadRequest, httpResp.StatusCode, "HTTP status code mismatch")
 }

 func TestPasswordChange(t *testing.T) {
@@ -5953,13 +5934,9 @@ func TestChangeUserAuthentication_WithPrivacyPolicyEnabledError(t *testing.T) {
 	httpReqData, err := json.Marshal(req)
 	require.NoError(t, err)

-	// CSRF protected endpoint.
-	csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
 	httpReq, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(httpReqData))
 	require.NoError(t, err)
-	addCSRFCookieToReq(httpReq, csrfToken)
 	httpReq.Header.Set("Content-Type", "application/json")
-	httpReq.Header.Set(csrf.HeaderName, csrfToken)

 	httpRes, err := httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) {
 		return clt.HTTPClient().Do(httpReq)
 	}))
@@ -6104,10 +6081,6 @@ func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing
 	req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(body))
 	require.NoError(t, err)

-	csrfToken, err := csrf.GenerateToken()
-	require.NoError(t, err)
-	addCSRFCookieToReq(req, csrfToken)
-	req.Header.Set(csrf.HeaderName, csrfToken)
 	req.Header.Set("Content-Type", "application/json")

 	re, err := clt.Client.RoundTrip(func() (*http.Response, error) {
@@ -6129,8 +6102,6 @@ func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing
 func TestParseSSORequestParams(t *testing.T) {
 	t.Parallel()

-	token := "someMeaninglessTokenString"
-
 	tests := []struct {
 		name, url string
 		wantErr bool
@@ -6142,7 +6113,6 @@
 			expected: &SSORequestParams{
 				ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc",
 				ConnectorID: "oidc",
-				CSRFToken: token,
 			},
 		},
 		{
@@ -6151,7 +6121,6 @@
 			expected: &SSORequestParams{
 				ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc",
 				ConnectorID: "github",
-				CSRFToken: token,
 			},
 		},
 		{
@@ -6160,7 +6129,6 @@
 			expected: &SSORequestParams{
 				ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/apps?query=search(%22watermelon%22%2C%20%22this%22)%20%26%26%20labels%5B%22unique-id%22%5D%20%3D%3D%20%22hi%22&sort=name:asc",
 				ConnectorID: "saml",
-				CSRFToken: token,
 			},
 		},
 		{
@@ -6179,7 +6147,6 @@
 		t.Run(tc.name, func(t *testing.T) {
 			req, err := http.NewRequest("", tc.url, nil)
 			require.NoError(t, err)
-			addCSRFCookieToReq(req, token)

 			params, err := ParseSSORequestParams(req)

@@ -7932,15 +7899,6 @@ func (s *WebSuite) url() *url.URL {
 	return u
 }

-func addCSRFCookieToReq(req *http.Request, token string) {
-	cookie := &http.Cookie{
-		Name: csrf.CookieName,
-		Value: token,
-	}
-
-	req.AddCookie(cookie)
-}
-
 func removeSpace(in string) string {
 	for _, c := range []string{"\n", "\r", "\t"} {
 		in = strings.Replace(in, c, " ", -1)
diff --git a/lib/web/login_helper_test.go b/lib/web/login_helper_test.go
index ae7f2ddbd93c8..2829a2d1400d7 100644
--- a/lib/web/login_helper_test.go
+++ b/lib/web/login_helper_test.go
@@ -18,6 +18,7 @@ package web

 import (
 	"bytes"
+	"cmp"
 	"context"
 	"encoding/base32"
 	"encoding/json"
@@ -34,7 +35,6 @@ import (
"github.com/gravitational/teleport/lib/auth/mocku2f" "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/httplib/csrf" ) // newOTPSharedSecret returns an OTP shared secret, encoded as a base32 string. @@ -54,9 +54,8 @@ type loginWebOTPParams struct { // If empty then no OTP is sent in the request. otpSecret string - userAgent string // Optional. - - cookieCSRF, headerCSRF *string // Explicit CSRF tokens. Optional. + userAgent string // Optional. + overrideContentType string // Optional. } // DrainedHTTPResponse mimics an http.Response, but without a body. @@ -124,24 +123,11 @@ func rawLoginWebOTP(ctx context.Context, params loginWebOTPParams) (resp *Draine } // Set assorted headers. - req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Type", cmp.Or(params.overrideContentType, "application/json")) if params.userAgent != "" { req.Header.Set("User-Agent", params.userAgent) } - // Set CSRF cookie and header. - const defaultCSRFToken = "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" - cookieCSRF := defaultCSRFToken - if params.cookieCSRF != nil { - cookieCSRF = *params.cookieCSRF - } - addCSRFCookieToReq(req, cookieCSRF) - headerCSRF := defaultCSRFToken - if params.headerCSRF != nil { - headerCSRF = *params.headerCSRF - } - req.Header.Set(csrf.HeaderName, headerCSRF) - httpResp, err := webClient.HTTPClient().Do(req) if err != nil { return nil, nil, trace.Wrap(err, "do HTTP request") diff --git a/rfd/0007-rbac-oss.md b/rfd/0007-rbac-oss.md index 33848f7780105..23d33cb443f8d 100644 --- a/rfd/0007-rbac-oss.md +++ b/rfd/0007-rbac-oss.md @@ -90,7 +90,11 @@ role: name: user spec: options: - port_forwarding: true + ssh_port_forwarding: + remote: + enabled: true + local: + enabled: true max_session_ttl: 30h forward_agent: true enhanced_recording: ['command', 'network'] diff --git a/rfd/0008-application-access.md b/rfd/0008-application-access.md index 89ec8b837a3fd..adf15b34e8ede 100644 --- a/rfd/0008-application-access.md +++ b/rfd/0008-application-access.md @@ -303,7 +303,11 @@ version: v3 spec: options: forward_agent: true - port_forwarding: false + ssh_port_forwarding: + remote: + enabled: false + local: + enabled: false allow: logins: ["rjones"] # Application labels define labels that an application must match for this diff --git a/tool/tbot/kube.go b/tool/tbot/kube.go index 2314ec83bc984..997b668087e11 100644 --- a/tool/tbot/kube.go +++ b/tool/tbot/kube.go @@ -73,6 +73,10 @@ func onKubeCredentialsCommand( return trace.Wrap(err) } + if err = destination.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + idData, err := destination.Read(ctx, config.IdentityFilePath) if err != nil { return trace.Wrap(err) diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index d26a215a8b4e4..baf222ab69f90 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -539,7 +539,7 @@ func onDatabaseConfig(cf *CLIConf) error { case dbFormatCommand: cmd, err := dbcmd.NewCmdBuilder(tc, profile, *database, rootCluster, dbcmd.WithPrintFormat(), - dbcmd.WithLogger(log), + dbcmd.WithLogger(logger), dbcmd.WithGetDatabaseFunc(getDatabase), ).GetConnectCommand(cf.Context) if err != nil { @@ -779,7 +779,7 @@ func onDatabaseConnect(cf *CLIConf) error { return trace.Wrap(err) } opts = append(opts, - dbcmd.WithLogger(log), + dbcmd.WithLogger(logger), dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd), ) diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 4f0f1fee92135..489c3f483846f 100644 --- 
a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -233,7 +233,7 @@ func onProxyCommandDB(cf *CLIConf) error { opts := []dbcmd.ConnectCommandFunc{ dbcmd.WithLocalProxy("localhost", addr.Port(0), ""), dbcmd.WithNoTLS(), - dbcmd.WithLogger(log), + dbcmd.WithLogger(logger), dbcmd.WithPrintFormat(), dbcmd.WithTolerateMissingCLIClient(), dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd), diff --git a/web/packages/teleport/src/lib/tdp/playerClient.ts b/web/packages/teleport/src/lib/tdp/playerClient.ts index dd32996b5bd79..9adb15ef10038 100644 --- a/web/packages/teleport/src/lib/tdp/playerClient.ts +++ b/web/packages/teleport/src/lib/tdp/playerClient.ts @@ -112,6 +112,11 @@ export class PlayerClient extends Client { } this.lastUpdateTime = Date.now(); + this.send(JSON.stringify({ action: Action.TOGGLE_PLAY_PAUSE })); + + if (this.paused) { + return; + } if (this.isSeekingForward()) { const next = Math.max(this.skipTimeUpdatesUntil, this.lastTimestamp); @@ -119,8 +124,6 @@ export class PlayerClient extends Client { } else { this.scheduleNextUpdate(this.lastTimestamp); } - - this.send(JSON.stringify({ action: Action.TOGGLE_PLAY_PAUSE })); } // setPlaySpeed sets the playback speed of the recording.