From c37d45237a6af93c2934b3a40d8f884e9232f091 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Tue, 17 Dec 2024 16:25:19 -0500 Subject: [PATCH] Convert lib/reversetunnel to use slog --- lib/reversetunnel/agent.go | 67 ++++++----- lib/reversetunnel/agent_dialer.go | 10 +- lib/reversetunnel/agent_dialer_test.go | 3 +- lib/reversetunnel/agentpool.go | 64 +++++----- lib/reversetunnel/conn.go | 27 +++-- lib/reversetunnel/localsite.go | 102 +++++++++------- lib/reversetunnel/localsite_test.go | 8 +- lib/reversetunnel/peer.go | 13 +-- lib/reversetunnel/rc_manager.go | 18 +-- lib/reversetunnel/remotesite.go | 130 +++++++++++---------- lib/reversetunnel/srv.go | 135 ++++++++++------------ lib/reversetunnel/srv_test.go | 7 +- lib/reversetunnel/transport.go | 70 +++++++---- lib/reversetunnelclient/api_with_roles.go | 4 +- lib/service/service.go | 9 +- 15 files changed, 356 insertions(+), 311 deletions(-) diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index da41578411067..4bd870c3418d6 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -27,13 +27,13 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "strings" "sync" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport/api/constants" @@ -42,6 +42,7 @@ import ( "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) type AgentState string @@ -113,8 +114,8 @@ type agentConfig struct { // clock is use to get the current time. Mock clocks can be used for // testing. clock clockwork.Clock - // log is an optional logger. - log logrus.FieldLogger + // logger is an optional logger. + logger *slog.Logger // localAuthAddresses is a list of auth servers to use when dialing back to // the local cluster. localAuthAddresses []string @@ -145,12 +146,13 @@ func (c *agentConfig) checkAndSetDefaults() error { if c.clock == nil { c.clock = clockwork.NewRealClock() } - if c.log == nil { - c.log = logrus.New() + if c.logger == nil { + c.logger = slog.Default() } - c.log = c.log. - WithField("leaseID", c.lease.ID()). - WithField("target", c.addr.String()) + c.logger = c.logger.With( + "lease_id", c.lease.ID(), + "target", c.addr.String(), + ) return nil } @@ -284,7 +286,10 @@ func (a *agent) updateState(state AgentState) (AgentState, error) { prevState := a.state a.state = state - a.log.Debugf("Changing state %s -> %s.", prevState, state) + a.logger.DebugContext(a.ctx, "Agent state updated", + "previous_state", prevState, + "current_state", state, + ) if a.agentConfig.stateCallback != nil { go a.agentConfig.stateCallback(a.state) @@ -296,7 +301,7 @@ func (a *agent) updateState(state AgentState) (AgentState, error) { // Start starts an agent returning after successfully connecting and sending // the first heartbeat. 
func (a *agent) Start(ctx context.Context) error { - a.log.Debugf("Starting agent %v", a.addr) + a.logger.DebugContext(ctx, "Starting agent", "addr", a.addr.FullAddress()) var err error defer func() { @@ -325,7 +330,7 @@ func (a *agent) Start(ctx context.Context) error { a.wg.Add(1) go func() { if err := a.handleGlobalRequests(a.ctx, a.client.GlobalRequests()); err != nil { - a.log.WithError(err).Debug("Failed to handle global requests.") + a.logger.DebugContext(a.ctx, "Failed to handle global requests", "error", err) } a.wg.Done() a.Stop() @@ -336,7 +341,7 @@ func (a *agent) Start(ctx context.Context) error { a.wg.Add(1) go func() { if err := a.handleDrainChannels(); err != nil { - a.log.WithError(err).Debug("Failed to handle drainable channels.") + a.logger.DebugContext(a.ctx, "Failed to handle drainable channels", "error", err) } a.wg.Done() a.Stop() @@ -345,7 +350,7 @@ func (a *agent) Start(ctx context.Context) error { a.wg.Add(1) go func() { if err := a.handleChannels(); err != nil { - a.log.WithError(err).Debug("Failed to handle channels.") + a.logger.DebugContext(a.ctx, "Failed to handle channels", "error", err) } a.wg.Done() a.Stop() @@ -460,23 +465,23 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R case versionRequest: version, err := a.versionGetter.getVersion(ctx) if err != nil { - a.log.WithError(err).Warnf("Failed to retrieve auth version in response to %v request.", r.Type) + a.logger.WarnContext(ctx, "Failed to retrieve auth version in response to x-teleport-version request", "error", err) if err := a.client.Reply(r, false, []byte("Failed to retrieve auth version")); err != nil { - a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + a.logger.DebugContext(ctx, "Failed to reply to x-teleport-version request", "error", err) continue } } if err := a.client.Reply(r, true, []byte(version)); err != nil { - a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + a.logger.DebugContext(ctx, "Failed to reply to x-teleport-version request", "error", err) continue } case reconnectRequest: - a.log.Debugf("Received reconnect advisory request from proxy.") + a.logger.DebugContext(ctx, "Received reconnect advisory request from proxy") if r.WantReply { err := a.client.Reply(r, true, nil) if err != nil { - a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + a.logger.DebugContext(ctx, "Failed to reply to reconnect@goteleport.com request", "error", err) } } @@ -487,7 +492,7 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R // This handles keep-alive messages and matches the behavior of OpenSSH. err := a.client.Reply(r, false, nil) if err != nil { - a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err) + a.logger.DebugContext(ctx, "Failed to reply to global request", "request_type", r.Type, "error", err) continue } } @@ -555,10 +560,10 @@ func (a *agent) handleDrainChannels() error { bytes, _ := a.clock.Now().UTC().MarshalText() _, err := a.hbChannel.SendRequest(a.ctx, "ping", false, bytes) if err != nil { - a.log.Error(err) + a.logger.ErrorContext(a.ctx, "failed to send ping request", "error", err) return trace.Wrap(err) } - a.log.Debugf("Ping -> %v.", a.client.RemoteAddr()) + a.logger.DebugContext(a.ctx, "Sent ping request", "target_addr", logutils.StringerAttr(a.client.RemoteAddr())) // Handle transport requests. 
case nch := <-a.transportC: if nch == nil { @@ -567,15 +572,15 @@ func (a *agent) handleDrainChannels() error { if a.isDraining() { err := nch.Reject(ssh.ConnectionFailed, "agent connection is draining") if err != nil { - a.log.WithError(err).Warningf("Failed to reject transport channel.") + a.logger.WarnContext(a.ctx, "Failed to reject transport channel", "error", err) } continue } - a.log.Debugf("Transport request: %v.", nch.ChannelType()) + a.logger.DebugContext(a.ctx, "Received transport request", "channel_type", nch.ChannelType()) ch, req, err := nch.Accept() if err != nil { - a.log.Warningf("Failed to accept transport request: %v.", err) + a.logger.WarnContext(a.ctx, "Failed to accept transport request", "error", err) continue } @@ -601,10 +606,10 @@ func (a *agent) handleChannels() error { if nch == nil { continue } - a.log.Debugf("Discovery request channel opened: %v.", nch.ChannelType()) + a.logger.DebugContext(a.ctx, "Discovery request channel opened", "channel_type", nch.ChannelType()) ch, req, err := nch.Accept() if err != nil { - a.log.Warningf("Failed to accept discovery channel request: %v.", err) + a.logger.WarnContext(a.ctx, "Failed to accept discovery channel request", "error", err) continue } @@ -624,11 +629,11 @@ func (a *agent) handleChannels() error { // ch : SSH channel which received "teleport-transport" out-of-band request // reqC : request payload func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { - a.log.Debugf("handleDiscovery requests channel.") + a.logger.DebugContext(a.ctx, "handleDiscovery requests channel") sshutils.DiscardChannelData(ch) defer func() { if err := ch.Close(); err != nil { - a.log.Warnf("Failed to close discovery channel: %v", err) + a.logger.WarnContext(a.ctx, "Failed to close discovery channel", "error", err) } }() @@ -639,17 +644,17 @@ func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { return case req = <-reqC: if req == nil { - a.log.Infof("Connection closed, returning") + a.logger.InfoContext(a.ctx, "Connection closed, returning") return } var r discoveryRequest if err := json.Unmarshal(req.Payload, &r); err != nil { - a.log.WithError(err).Warn("Bad payload") + a.logger.WarnContext(a.ctx, "Received discovery request with bad payload", "error", err) return } - a.log.Debugf("Received discovery request: %s", &r) + a.logger.DebugContext(a.ctx, "Received discovery request", "discovery_request", logutils.StringerAttr(&r)) a.tracker.TrackExpected(r.TrackProxies()...) } } diff --git a/lib/reversetunnel/agent_dialer.go b/lib/reversetunnel/agent_dialer.go index 56c710733a343..01f79397c3dff 100644 --- a/lib/reversetunnel/agent_dialer.go +++ b/lib/reversetunnel/agent_dialer.go @@ -20,10 +20,10 @@ package reversetunnel import ( "context" + "log/slog" "strings" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" apidefaults "github.com/gravitational/teleport/api/defaults" @@ -55,7 +55,7 @@ type agentDialer struct { authMethods []ssh.AuthMethod fips bool options []proxy.DialerOptionFunc - log logrus.FieldLogger + logger *slog.Logger isClaimed func(principals ...string) bool } @@ -65,7 +65,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC dialer := proxy.DialerFromEnvironment(addr.Addr, d.options...)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultIOTimeout) if err != nil { - d.log.WithError(err).Debugf("Failed to dial %s.", addr.Addr) + d.logger.DebugContext(ctx, "Failed to dial", "error", err, "target_addr", addr.Addr) return nil, trace.Wrap(err) } @@ -75,7 +75,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC GetHostCheckers: d.hostCheckerFunc(ctx), OnCheckCert: func(c *ssh.Certificate) error { if d.isClaimed != nil && d.isClaimed(c.ValidPrincipals...) { - d.log.Debugf("Aborting SSH handshake because the proxy %q is already claimed by some other agent.", c.ValidPrincipals[0]) + d.logger.DebugContext(ctx, "Aborting SSH handshake because the proxy is already claimed by some other agent.", "proxy_id", c.ValidPrincipals[0]) // the error message must end with // [proxyAlreadyClaimedError] to be recognized by // [isProxyAlreadyClaimed] @@ -88,7 +88,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC FIPS: d.fips, }) if err != nil { - d.log.Debugf("Failed to create host key callback for %v: %v.", addr.Addr, err) + d.logger.DebugContext(ctx, "Failed to create host key callback", "target_addr", addr.Addr, "error", err) return nil, trace.Wrap(err) } diff --git a/lib/reversetunnel/agent_dialer_test.go b/lib/reversetunnel/agent_dialer_test.go index 2293c7b7a2620..ec912a51fa690 100644 --- a/lib/reversetunnel/agent_dialer_test.go +++ b/lib/reversetunnel/agent_dialer_test.go @@ -23,7 +23,6 @@ import ( "testing" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -90,7 +89,7 @@ func TestAgentCertChecker(t *testing.T) { dialer := agentDialer{ client: &fakeClient{caKey: ca.PublicKey()}, authMethods: []ssh.AuthMethod{ssh.PublicKeys(signer)}, - log: logrus.New(), + logger: utils.NewSlogLoggerForTests(), } _, err = dialer.DialContext(context.Background(), *utils.MustParseAddr(sshServer.Addr())) diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 25a59cc1cdebe..c4ea000758570 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -25,13 +25,13 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "sync" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" @@ -94,7 +94,7 @@ type AgentPool struct { // backoff limits the rate at which new agents are created. 
backoff retryutils.Retry - log logrus.FieldLogger + logger *slog.Logger } // AgentPoolConfig holds configuration parameters for the agent pool @@ -201,13 +201,11 @@ func NewAgentPool(ctx context.Context, config AgentPoolConfig) (*AgentPool, erro active: newAgentStore(), events: make(chan Agent), backoff: retry, - log: logrus.WithFields(logrus.Fields{ - teleport.ComponentKey: teleport.ComponentReverseTunnelAgent, - teleport.ComponentFields: logrus.Fields{ - "targetCluster": config.Cluster, - "localCluster": config.LocalCluster, - }, - }), + logger: slog.With( + teleport.ComponentKey, teleport.ComponentReverseTunnelAgent, + "target_cluster", config.Cluster, + "local_cluster", config.LocalCluster, + ), runtimeConfig: newAgentPoolRuntimeConfig(), } @@ -239,7 +237,7 @@ func (p *AgentPool) updateConnectedProxies() { } proxies := p.active.proxyIDs() - p.log.Debugf("Updating connected proxies: %v", proxies) + p.logger.DebugContext(p.ctx, "Updating connected proxies", "proxies", proxies) p.AgentPoolConfig.ConnectedProxyGetter.setProxyIDs(proxies) } @@ -250,12 +248,15 @@ func (p *AgentPool) Count() int { // Start starts the agent pool in the background. func (p *AgentPool) Start() error { - p.log.Debugf("Starting agent pool %s.%s...", p.HostUUID, p.Cluster) + p.logger.DebugContext(p.ctx, "Starting agent pool", + "host_uuid", p.HostUUID, + "cluster", p.Cluster, + ) p.wg.Add(1) go func() { if err := p.run(); err != nil { - p.log.WithError(err).Warn("Agent pool exited.") + p.logger.WarnContext(p.ctx, "Agent pool exited", "error", err) } p.cancel() @@ -274,9 +275,9 @@ func (p *AgentPool) run() error { } else if isProxyAlreadyClaimed(err) { // "proxy already claimed" is a fairly benign error, we should not // spam the log with stack traces for it - p.log.Debugf("Failed to connect agent: %v.", err) + p.logger.DebugContext(p.ctx, "Failed to connect agent", "error", err) } else { - p.log.WithError(err).Debugf("Failed to connect agent.") + p.logger.DebugContext(p.ctx, "Failed to connect agent", "error", err) } } else { p.wg.Add(1) @@ -288,7 +289,7 @@ func (p *AgentPool) run() error { if p.ctx.Err() != nil { return nil } else if err != nil { - p.log.WithError(err).Debugf("Failed to wait for backoff.") + p.logger.DebugContext(p.ctx, "Failed to wait for backoff", "error", err) } } } @@ -337,7 +338,10 @@ func (p *AgentPool) updateRuntimeConfig(ctx context.Context) error { restrictConnectionCount := p.runtimeConfig.restrictConnectionCount() connectionCount := p.runtimeConfig.getConnectionCount() - p.log.Debugf("Runtime config: restrict_connection_count: %v connection_count: %v", restrictConnectionCount, connectionCount) + p.logger.DebugContext(ctx, "Runtime config updated", + "restrict_connection_count", restrictConnectionCount, + "connection_count", connectionCount, + ) if restrictConnectionCount { p.tracker.SetConnectionCount(connectionCount) @@ -420,7 +424,7 @@ func (p *AgentPool) handleEvent(ctx context.Context, agent Agent) { } } p.updateConnectedProxies() - p.log.Debugf("Active agent count: %d", p.active.len()) + p.logger.DebugContext(ctx, "Processed agent event", "active_agent_count", p.active.len()) } // stateCallback adds events to the queue for each agent state change. 
@@ -444,7 +448,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease err = p.runtimeConfig.updateRemote(ctx, addr) if err != nil { - p.log.WithError(err).Debugf("Failed to update remote config.") + p.logger.DebugContext(ctx, "Failed to update remote config", "error", err) } options := []proxy.DialerOptionFunc{proxy.WithInsecureSkipTLSVerify(lib.IsInsecureDevMode())} @@ -458,7 +462,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease authMethods: p.AuthMethods, options: options, username: p.HostUUID, - log: p.log, + logger: p.logger, isClaimed: p.tracker.IsClaimed, } @@ -471,7 +475,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease tracker: tracker, lease: lease, clock: p.Clock, - log: p.log, + logger: p.logger, localAuthAddresses: p.LocalAuthAddresses, proxySigner: p.PROXYSigner, }) @@ -536,7 +540,7 @@ func (p *AgentPool) handleTransport(ctx context.Context, channel ssh.Channel, re sconn: conn, channel: channel, requestCh: requests, - log: p.log, + logger: p.logger, authServers: p.LocalAuthAddresses, proxySigner: p.PROXYSigner, forwardClientAddress: true, @@ -566,7 +570,7 @@ func (p *AgentPool) handleLocalTransport(ctx context.Context, channel ssh.Channe return case <-time.After(apidefaults.DefaultIOTimeout): go ssh.DiscardRequests(reqC) - p.log.Warn("Timed out waiting for transport dial request.") + p.logger.WarnContext(ctx, "Timed out waiting for transport dial request") return case r, ok := <-reqC: if !ok { @@ -579,14 +583,14 @@ func (p *AgentPool) handleLocalTransport(ctx context.Context, channel ssh.Channe // sconn should never be nil, but it's sourced from the agent state and // starts as nil, and the original transport code checked against it if sconn == nil || p.Server == nil { - p.log.Error("Missing client or server (this is a bug).") + p.logger.ErrorContext(ctx, "Missing client or server (this is a bug)") fmt.Fprintf(channel.Stderr(), "internal server error") req.Reply(false, nil) return } if err := req.Reply(true, nil); err != nil { - p.log.Errorf("Failed to respond to dial request: %v.", err) + p.logger.ErrorContext(ctx, "Failed to respond to dial request", "error", err) return } @@ -596,8 +600,9 @@ func (p *AgentPool) handleLocalTransport(ctx context.Context, channel ssh.Channe switch dialReq.Address { case reversetunnelclient.LocalNode, reversetunnelclient.LocalKubernetes, reversetunnelclient.LocalWindowsDesktop: default: - p.log.WithField("address", dialReq.Address). 
- Warn("Received dial request for unexpected address, routing to the local service anyway.") + p.logger.WarnContext(ctx, "Received dial request for unexpected address, routing to the local service anyway", + "dial_addr", dialReq.Address, + ) } if src, err := utils.ParseAddr(dialReq.ClientSrcAddr); err == nil { conn = utils.NewConnWithSrcAddr(conn, getTCPAddr(src)) } @@ -768,7 +773,10 @@ func (c *agentPoolRuntimeConfig) updateRemote(ctx context.Context, addr *utils.N c.remoteTLSRoutingEnabled = tlsRoutingEnabled if c.remoteTLSRoutingEnabled { c.tlsRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(ctx, addr.Addr, lib.IsInsecureDevMode()) - logrus.Debugf("ALPN upgrade required for remote %v: %v", addr.Addr, c.tlsRoutingConnUpgradeRequired) + slog.DebugContext(ctx, "ALPN upgrade required for remote cluster", + "remote_addr", addr.Addr, + "conn_upgrade_required", c.tlsRoutingConnUpgradeRequired, + ) } return nil } @@ -802,7 +810,7 @@ func (c *agentPoolRuntimeConfig) update(ctx context.Context, netConfig types.Clu if err == nil { c.tlsRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(ctx, addr.Addr, lib.IsInsecureDevMode()) } else { - logrus.WithError(err).Warnf("Failed to resolve addr.") + slog.WarnContext(ctx, "Failed to resolve addr", "error", err) } } } diff --git a/lib/reversetunnel/conn.go b/lib/reversetunnel/conn.go index b78f1b97d9378..8a678690f97cd 100644 --- a/lib/reversetunnel/conn.go +++ b/lib/reversetunnel/conn.go @@ -19,8 +19,10 @@ package reversetunnel import ( + "context" "encoding/json" "fmt" + "log/slog" "net" "sync" "sync/atomic" @@ -28,12 +30,12 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/sshutils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) // connKey is a key used to identify tunnel connections. It contains the UUID @@ -54,8 +56,8 @@ type remoteConn struct { lastHeartbeat atomic.Int64 *connConfig - mu sync.Mutex - log *logrus.Entry + mu sync.Mutex + logger *slog.Logger // discoveryCh is the SSH channel over which discovery requests are sent.
discoveryCh ssh.Channel @@ -109,9 +111,7 @@ type connConfig struct { func newRemoteConn(cfg *connConfig) *remoteConn { c := &remoteConn{ - log: logrus.WithFields(logrus.Fields{ - teleport.ComponentKey: "discovery", - }), + logger: slog.With(teleport.ComponentKey, "discovery"), connConfig: cfg, clock: clockwork.NewRealClock(), newProxiesC: make(chan []types.Server, 100), @@ -181,7 +181,11 @@ func (c *remoteConn) markInvalid(err error) { c.lastError = err c.invalid.Store(true) - c.log.Warnf("Unhealthy connection to %v %v: %v.", c.clusterName, c.conn.RemoteAddr(), err) + c.logger.WarnContext(context.Background(), "Unhealthy reverse tunnel connection", + "cluster", c.clusterName, + "remote_addr", logutils.StringerAttr(c.conn.RemoteAddr()), + "error", err, + ) } func (c *remoteConn) markValid() { @@ -256,7 +260,7 @@ func (c *remoteConn) updateProxies(proxies []types.Server) { default: // Missing proxies update is no longer critical with more permissive // discovery protocol that tolerates conflicting, stale or missing updates - c.log.Warnf("Discovery channel overflow at %v.", len(c.newProxiesC)) + c.logger.WarnContext(context.Background(), "Discovery channel overflow", "new_proxy_count", len(c.newProxiesC)) } } @@ -267,7 +271,7 @@ func (c *remoteConn) adviseReconnect() error { // sendDiscoveryRequest sends a discovery request with up to date // list of connected proxies -func (c *remoteConn) sendDiscoveryRequest(req discoveryRequest) error { +func (c *remoteConn) sendDiscoveryRequest(ctx context.Context, req discoveryRequest) error { discoveryCh, err := c.openDiscoveryChannel() if err != nil { return trace.Wrap(err) @@ -282,7 +286,10 @@ func (c *remoteConn) sendDiscoveryRequest(req discoveryRequest) error { // Log the discovery request being sent. Useful for debugging to know what // proxies the tunnel server thinks exist. - c.log.Debugf("Sending discovery request with proxies %v to %v.", req.ProxyNames(), c.sconn.RemoteAddr()) + c.logger.DebugContext(ctx, "Sending discovery request", + "proxies", req.ProxyNames(), + "target_addr", logutils.StringerAttr(c.sconn.RemoteAddr()), + ) if _, err := discoveryCh.SendRequest(chanDiscoveryReq, false, payload); err != nil { c.markInvalid(err) diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 7c89ea25273b0..3446a882cec23 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -21,6 +21,7 @@ package reversetunnel import ( "context" "fmt" + "log/slog" "net" "slices" "sync" @@ -29,7 +30,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" @@ -48,6 +48,7 @@ import ( "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" proxyutils "github.com/gravitational/teleport/lib/utils/proxy" ) @@ -102,12 +103,10 @@ func newLocalSite(srv *server, domainName string, authServers []string, opts ... 
authServers: authServers, remoteConns: make(map[connKey][]*remoteConn), clock: srv.Clock, - log: log.WithFields(log.Fields{ - teleport.ComponentKey: teleport.ComponentReverseTunnelServer, - teleport.ComponentFields: map[string]string{ - "cluster": domainName, - }, - }), + logger: slog.With( + teleport.ComponentKey, teleport.ComponentReverseTunnelServer, + "cluster", domainName, + ), offlineThreshold: srv.offlineThreshold, peerClient: srv.PeerClient, periodicFunctionInterval: periodicFunctionInterval, @@ -129,7 +128,7 @@ func newLocalSite(srv *server, domainName string, authServers []string, opts ... // // it implements RemoteSite interface type localSite struct { - log log.FieldLogger + logger *slog.Logger domainName string authServers []string srv *server @@ -292,13 +291,15 @@ func (s *localSite) maybeSendSignedPROXYHeader(params reversetunnelclient.DialPa // TODO(awly): unit test this func (s *localSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, error) { - s.log.Debugf("Dialing %v.", params) + ctx := s.srv.ctx + logger := s.logger.With("dial_params", logutils.StringerAttr(params)) + logger.DebugContext(ctx, "Initiating dial request") conn, useTunnel, err := s.getConn(params) if err != nil { return nil, trace.Wrap(err) } - s.log.Debugf("Succeeded dialing %v.", params) + logger.DebugContext(ctx, "Succeeded dialing") if err := s.maybeSendSignedPROXYHeader(params, conn, useTunnel); err != nil { return nil, trace.Wrap(err) @@ -320,12 +321,12 @@ func (s *localSite) adviseReconnect(ctx context.Context) { s.remoteConnsMtx.Lock() for _, conns := range s.remoteConns { for _, conn := range conns { - s.log.Debugf("Sending reconnect: %s", conn.nodeID) + s.logger.DebugContext(ctx, "Sending reconnect to server", "server_id", conn.nodeID) wg.Add(1) go func(conn *remoteConn) { if err := conn.adviseReconnect(); err != nil { - s.log.WithError(err).Warn("Failed sending reconnect advisory") + s.logger.WarnContext(ctx, "Failed sending reconnect advisory", "error", err) } wg.Done() }(conn) @@ -346,10 +347,15 @@ func (s *localSite) adviseReconnect(ctx context.Context) { } func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) { + ctx := s.srv.ctx + if params.GetUserAgent == nil && !params.IsAgentlessNode { return nil, trace.BadParameter("agentless node require an agent getter") } - s.log.Debugf("Dialing and forwarding from %v to %v.", params.From, params.To) + s.logger.DebugContext(ctx, "Initiating dial and forwarding request", + "source_addr", logutils.StringerAttr(params.From), + "target_addr", logutils.StringerAttr(params.To), + ) // request user agent connection if a SSH user agent is set var userAgent teleagent.Agent @@ -378,7 +384,7 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net } // Get a host certificate for the forwarding node from the cache.
- hostCertificate, err := s.certificateCache.getHostCertificate(context.TODO(), params.Address, params.Principals) + hostCertificate, err := s.certificateCache.getHostCertificate(ctx, params.Address, params.Principals) if err != nil { return nil, trace.Wrap(err) } @@ -438,7 +444,10 @@ func (s *localSite) dialTunnel(dreq *sshutils.DialReq) (net.Conn, error) { return nil, trace.NotFound("no tunnel connection found: %v", err) } - s.log.Debugf("Tunnel dialing to %v, client source %v", dreq.ServerID, dreq.ClientSrcAddr) + s.logger.DebugContext(s.srv.ctx, "Tunnel dialing to host", + "target_host_id", dreq.ServerID, + "src_addr", dreq.ClientSrcAddr, + ) conn, err := s.chanTransportConn(rconn, dreq) if err != nil { @@ -607,7 +616,7 @@ func (s *localSite) getConn(params reversetunnelclient.DialParams) (conn net.Con peeringEnabled := s.tryProxyPeering(params) if peeringEnabled { - s.log.Info("Dialing over peer proxy") + s.logger.InfoContext(s.srv.ctx, "Dialing over peer proxy") conn, peerErr = s.peerClient.DialNode( params.ProxyIDs, params.ServerID, params.From, params.To, params.ConnType, ) @@ -645,7 +654,7 @@ func (s *localSite) getConn(params reversetunnelclient.DialParams) (conn net.Con dialTimeout := apidefaults.DefaultIOTimeout if cnc, err := s.accessPoint.GetClusterNetworkingConfig(s.srv.Context); err != nil { - s.log.WithError(err).Warn("Failed to get cluster networking config - using default dial timeout") + s.logger.WarnContext(s.srv.ctx, "Failed to get cluster networking config - using default dial timeout", "error", err) } else { dialTimeout = cnc.GetSSHDialTimeout() } @@ -653,7 +662,12 @@ func (s *localSite) getConn(params reversetunnelclient.DialParams) (conn net.Con conn, directErr = dialer.DialTimeout(s.srv.Context, params.To.Network(), params.To.String(), dialTimeout) if directErr != nil { directMsg := getTunnelErrorMessage(params, "direct dial", directErr) - s.log.WithField("address", params.To.String()).Debugf("All attempted dial methods failed. tunnel=%q, peer=%q, direct=%q", tunnelErr, peerErr, directErr) + s.logger.DebugContext(s.srv.ctx, "All attempted dial methods failed", + "target_addr", logutils.StringerAttr(params.To), + "tunnel_error", tunnelErr, + "peer_error", peerErr, + "direct_error", directErr, + ) aggregateErr := trace.NewAggregate(tunnelErr, peerErr, directErr) return nil, false, trace.ConnectionProblem(aggregateErr, directMsg) } @@ -701,29 +715,29 @@ func (s *localSite) fanOutProxies(proxies []types.Server) { // handleHeartbeat receives heartbeat messages from the connected agent // if the agent has missed several heartbeats in a row, Proxy marks // the connection as invalid. 
-func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) { +func (s *localSite) handleHeartbeat(ctx context.Context, rconn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) { sshutils.DiscardChannelData(ch) if ch != nil { defer func() { if err := ch.Close(); err != nil { - s.log.Warnf("Failed to close heartbeat channel: %v", err) + s.logger.WarnContext(ctx, "Failed to close heartbeat channel", "error", err) } }() } - logger := s.log.WithFields(log.Fields{ - "serverID": rconn.nodeID, - "addr": rconn.conn.RemoteAddr().String(), - }) + logger := s.logger.With( + "server_id", rconn.nodeID, + "addr", logutils.StringerAttr(rconn.conn.RemoteAddr()), + ) firstHeartbeat := true proxyResyncTicker := s.clock.NewTicker(s.proxySyncInterval) defer func() { proxyResyncTicker.Stop() - logger.Warn("Closing remote connection to agent.") + logger.WarnContext(ctx, "Closing remote connection to agent") s.removeRemoteConn(rconn) if err := rconn.Close(); err != nil && !utils.IsOKNetworkError(err) { - logger.WithError(err).Warn("Failed to close remote connection") + logger.WarnContext(ctx, "Failed to close remote connection", "error", err) } if !firstHeartbeat { reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Dec() @@ -735,18 +749,18 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch for { select { case <-s.srv.ctx.Done(): - logger.Info("Closing") + logger.InfoContext(ctx, "Closing") return case <-proxyResyncTicker.Chan(): var req discoveryRequest - proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + proxies, err := s.srv.proxyWatcher.CurrentResources(ctx) if err != nil { - logger.WithError(err).Warn("Failed to get proxy set") + logger.WarnContext(ctx, "Failed to get proxy set", "error", err) } req.SetProxies(proxies) - if err := rconn.sendDiscoveryRequest(req); err != nil { - logger.WithError(err).Debug("Marking connection invalid on error") + if err := rconn.sendDiscoveryRequest(ctx, req); err != nil { + logger.DebugContext(ctx, "Marking connection invalid on error", "error", err) rconn.markInvalid(err) return } @@ -754,14 +768,14 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch var req discoveryRequest req.SetProxies(proxies) - if err := rconn.sendDiscoveryRequest(req); err != nil { - logger.WithError(err).Debug("Failed to send discovery request to agent") + if err := rconn.sendDiscoveryRequest(ctx, req); err != nil { + logger.DebugContext(ctx, "Failed to send discovery request to agent", "error", err) rconn.markInvalid(err) return } case req := <-reqC: if req == nil { - logger.Debug("Agent disconnected.") + logger.DebugContext(ctx, "Agent disconnected") rconn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected")) return } @@ -770,7 +784,7 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch // send it the list of current proxies back proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) if err != nil { - logger.WithError(err).Warn("Failed to get proxy set") + logger.WarnContext(ctx, "Failed to get proxy set", "error", err) } if len(proxies) > 0 { rconn.updateProxies(proxies) @@ -788,9 +802,9 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch log := logger if roundtrip != 0 { - log = logger.WithField("latency", roundtrip.String()) + log = logger.With("latency", logutils.StringerAttr(roundtrip)) } - log.Debugf("Ping <- %v", rconn.conn.RemoteAddr()) + log.DebugContext(ctx, "Received ping request", 
"remote_addr", logutils.StringerAttr(rconn.conn.RemoteAddr())) rconn.setLastHeartbeat(s.clock.Now().UTC()) rconn.markValid() @@ -799,10 +813,10 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch // terminate and remove the connection if offline, otherwise warn and wait for the next heartbeat if rconn.isOffline(t, s.offlineThreshold*missedHeartBeatThreshold) { - logger.Errorf("Closing unhealthy and idle connection. Heartbeat last received at %s", rconn.getLastHeartbeat()) + logger.ErrorContext(ctx, "Closing unhealthy and idle connection", "last_heartbeat", rconn.getLastHeartbeat()) return } - logger.Warnf("Deferring closure of unhealthy connection due to %d active connections", rconn.activeSessions()) + logger.WarnContext(ctx, "Deferring closure of unhealthy connection due to active connections", "active_conn_count", rconn.activeSessions()) offlineThresholdTimer.Reset(s.offlineThreshold) continue @@ -878,7 +892,7 @@ func (s *localSite) getRemoteConn(dreq *sshutils.DialReq) (*remoteConn, error) { } func (s *localSite) chanTransportConn(rconn *remoteConn, dreq *sshutils.DialReq) (net.Conn, error) { - s.log.Debugf("Connecting to %v through tunnel.", rconn.conn.RemoteAddr()) + s.logger.DebugContext(s.srv.ctx, "Connecting to target through tunnel", "target_addr", logutils.StringerAttr(rconn.conn.RemoteAddr())) conn, markInvalid, err := sshutils.ConnectProxyTransport(rconn.sconn, dreq, false) if err != nil { @@ -934,7 +948,7 @@ func (s *localSite) periodicFunctions() { return case <-ticker.Chan(): if err := s.sshTunnelStats(); err != nil { - s.log.Warningf("Failed to report SSH tunnel statistics for: %v: %v.", s.domainName, err) + s.logger.WarnContext(s.srv.ctx, "Failed to report SSH tunnel statistics ", "cluster", s.domainName, "error", err) } } } @@ -988,7 +1002,11 @@ func (s *localSite) sshTunnelStats() error { if n > 10 { n = 10 } - s.log.Debugf("Cluster %v is missing %v tunnels. A small number of missing tunnels is normal, for example, a node could have just been shut down, the proxy restarted, etc. However, if this error persists with an elevated number of missing tunnels, it often indicates nodes can not discover all registered proxies. Check that all of your proxies are behind a load balancer and the load balancer is using a round robin strategy. Some of the missing hosts: %v.", s.domainName, len(missing), missing[:n]) + s.logger.DebugContext(s.srv.ctx, "Cluster is missing some tunnels. A small number of missing tunnels is normal, for example, a node could have just been shut down, the proxy restarted, etc. However, if this error persists with an elevated number of missing tunnels, it often indicates nodes can not discover all registered proxies. 
Check that all of your proxies are behind a load balancer and the load balancer is using a round robin strategy", + "cluster", s.domainName, + "missing_count", len(missing), + "missing", missing[:n], + ) } return nil } diff --git a/lib/reversetunnel/localsite_test.go b/lib/reversetunnel/localsite_test.go index 195a1e76510c2..543ecfd894c2e 100644 --- a/lib/reversetunnel/localsite_test.go +++ b/lib/reversetunnel/localsite_test.go @@ -77,7 +77,7 @@ func TestRemoteConnCleanup(t *testing.T) { ctx: ctx, Config: Config{Clock: clock}, localAuthClient: &mockLocalSiteClient{}, - log: utils.NewLoggerForTests(), + logger: utils.NewSlogLoggerForTests(), offlineThreshold: time.Second, proxyWatcher: watcher, } @@ -102,7 +102,7 @@ func TestRemoteConnCleanup(t *testing.T) { // terminated by too many missed heartbeats go func() { - site.handleHeartbeat(conn1, nil, reqs) + site.handleHeartbeat(ctx, conn1, nil, reqs) cancel() }() @@ -273,7 +273,7 @@ func TestProxyResync(t *testing.T) { ctx: ctx, Config: Config{Clock: clock}, localAuthClient: &mockLocalSiteClient{}, - log: utils.NewLoggerForTests(), + logger: utils.NewSlogLoggerForTests(), offlineThreshold: 24 * time.Hour, proxyWatcher: watcher, } @@ -312,7 +312,7 @@ func TestProxyResync(t *testing.T) { // terminated by canceled context go func() { - site.handleHeartbeat(conn1, nil, reqs) + site.handleHeartbeat(ctx, conn1, nil, reqs) }() expected := []types.Server{proxy1, proxy2} diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 570be5edf4bbe..675ad71e4522a 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -26,7 +26,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" @@ -155,14 +154,8 @@ func (p *clusterPeers) Close() error { return nil } // newClusterPeer returns new cluster peer func newClusterPeer(srv *server, connInfo types.TunnelConnection, offlineThreshold time.Duration) (*clusterPeer, error) { clusterPeer := &clusterPeer{ - srv: srv, - connInfo: connInfo, - log: log.WithFields(log.Fields{ - teleport.ComponentKey: teleport.ComponentReverseTunnelServer, - teleport.ComponentFields: map[string]string{ - "cluster": connInfo.GetClusterName(), - }, - }), + srv: srv, + connInfo: connInfo, clock: clockwork.NewRealClock(), offlineThreshold: offlineThreshold, } @@ -173,8 +166,6 @@ func newClusterPeer(srv *server, connInfo types.TunnelConnection, offlineThresho // clusterPeer is a remote cluster that has established // a tunnel to the peers type clusterPeer struct { - log *log.Entry - mu sync.Mutex connInfo types.TunnelConnection srv *server diff --git a/lib/reversetunnel/rc_manager.go b/lib/reversetunnel/rc_manager.go index 03db3f13f613f..f1e539ac3bf8b 100644 --- a/lib/reversetunnel/rc_manager.go +++ b/lib/reversetunnel/rc_manager.go @@ -20,12 +20,12 @@ package reversetunnel import ( "context" + "log/slog" "sync" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" @@ -81,8 +81,8 @@ type RemoteClusterTunnelManagerConfig struct { KubeDialAddr utils.NetAddr // FIPS indicates if Teleport was started in FIPS mode. FIPS bool - // Log is the logger - Log logrus.FieldLogger + // Logger is the logger + Logger *slog.Logger // LocalAuthAddresses is a list of auth servers to use when dialing back to // the local cluster. 
LocalAuthAddresses []string @@ -109,8 +109,8 @@ func (c *RemoteClusterTunnelManagerConfig) CheckAndSetDefaults() error { if c.Clock == nil { c.Clock = clockwork.NewRealClock() } - if c.Log == nil { - c.Log = logrus.New() + if c.Logger == nil { + c.Logger = slog.Default() } return nil @@ -153,7 +153,7 @@ func (w *RemoteClusterTunnelManager) Run(ctx context.Context) { w.mu.Unlock() if err := w.Sync(ctx); err != nil { - w.cfg.Log.Warningf("Failed to sync reverse tunnels: %v.", err) + w.cfg.Logger.WarnContext(ctx, "Failed to sync reverse tunnels", "error", err) } ticker := time.NewTicker(defaults.ResyncInterval) @@ -162,11 +162,11 @@ func (w *RemoteClusterTunnelManager) Run(ctx context.Context) { for { select { case <-ctx.Done(): - w.cfg.Log.Debugf("Closing.") + w.cfg.Logger.DebugContext(ctx, "Closing") return case <-ticker.C: if err := w.Sync(ctx); err != nil { - w.cfg.Log.Warningf("Failed to sync reverse tunnels: %v.", err) + w.cfg.Logger.WarnContext(ctx, "Failed to sync reverse tunnels", "error", err) continue } } @@ -247,7 +247,7 @@ func realNewAgentPool(ctx context.Context, cfg RemoteClusterTunnelManagerConfig, } if err := pool.Start(); err != nil { - cfg.Log.WithError(err).Error("Failed to start agent pool") + cfg.Logger.ErrorContext(ctx, "Failed to start agent pool", "error", err) } return pool, nil diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index f9617f33b87d5..bfb3fa91412b4 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -23,13 +23,13 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "log/slog" "net" "sync" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" @@ -46,6 +46,7 @@ import ( "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) // remoteSite is a remote site that established the inbound connection to @@ -54,7 +55,7 @@ import ( type remoteSite struct { sync.RWMutex - logger *log.Entry + logger *slog.Logger domainName string connections []*remoteConn lastUsed int @@ -115,7 +116,7 @@ func (s *remoteSite) getRemoteClient() (authclient.ClientI, bool, error) { // The fact that cluster has keys to remote CA means that the key exchange // has completed. 
- s.logger.Debug("Using TLS client to remote cluster.") + s.logger.DebugContext(s.ctx, "Using TLS client to remote cluster") tlsConfig := utils.TLSConfig(s.srv.ClientTLSCipherSuites) // encode the name of this cluster to identify this cluster, // connecting to the remote one (it is used to find the right certificate @@ -272,7 +273,7 @@ func (s *remoteSite) removeInvalidConns() { } else { go func(conn *remoteConn) { if err := conn.Close(); err != nil { - s.logger.WithError(err).Warn("Failed to close invalid connection") + s.logger.WarnContext(s.ctx, "Failed to close invalid connection", "error", err) } }(s.connections[i]) } @@ -305,12 +306,12 @@ func (s *remoteSite) adviseReconnect(ctx context.Context) { s.RLock() for _, conn := range s.connections { - s.logger.Debugf("Sending reconnect: %s", conn.nodeID) + s.logger.DebugContext(ctx, "Sending reconnect to server", "server_id", conn.nodeID) wg.Add(1) go func(conn *remoteConn) { if err := conn.adviseReconnect(); err != nil { - s.logger.WithError(err).Warn("Failed to send reconnection advisory") + s.logger.WarnContext(ctx, "Failed to send reconnection advisory", "error", err) } wg.Done() }(conn) @@ -365,7 +366,7 @@ func (s *remoteSite) registerHeartbeat(t time.Time) { s.setLastConnInfo(connInfo) err := s.localAccessPoint.UpsertTunnelConnection(connInfo) if err != nil { - s.logger.WithError(err).Warn("Failed to register heartbeat") + s.logger.WarnContext(s.ctx, "Failed to register heartbeat", "error", err) } } @@ -373,7 +374,7 @@ func (s *remoteSite) registerHeartbeat(t time.Time) { // that this node lost the connection and needs to be discovered func (s *remoteSite) deleteConnectionRecord() { if err := s.localAccessPoint.DeleteTunnelConnection(s.connInfo.GetClusterName(), s.connInfo.GetName()); err != nil { - s.logger.WithError(err).Warn("Failed to delete tunnel connection") + s.logger.WarnContext(s.ctx, "Failed to delete tunnel connection", "error", err) } } @@ -391,17 +392,17 @@ func (s *remoteSite) fanOutProxies(proxies []types.Server) { // handleHeartbeat receives heartbeat messages from the connected agent // if the agent has missed several heartbeats in a row, Proxy marks // the connection as invalid. 
-func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) { - logger := s.logger.WithFields(log.Fields{ - "serverID": conn.nodeID, - "addr": conn.conn.RemoteAddr().String(), - }) +func (s *remoteSite) handleHeartbeat(ctx context.Context, conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) { + logger := s.logger.With( + "server_id", conn.nodeID, + "addr", logutils.StringerAttr(conn.conn.RemoteAddr()), + ) sshutils.DiscardChannelData(ch) if ch != nil { defer func() { if err := ch.Close(); err != nil { - logger.Warnf("Failed to close heartbeat channel: %v", err) + logger.WarnContext(ctx, "Failed to close heartbeat channel", "error", err) } }() } @@ -410,14 +411,14 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch proxyResyncTicker := s.clock.NewTicker(s.proxySyncInterval) defer func() { proxyResyncTicker.Stop() - logger.Info("Cluster connection closed.") + logger.InfoContext(ctx, "Cluster connection closed") if err := conn.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) { - logger.WithError(err).Warnf("Failed to close remote connection for remote site") + logger.WarnContext(ctx, "Failed to close remote connection for remote site", "error", err) } if err := s.srv.onSiteTunnelClose(s); err != nil { - logger.WithError(err).Warn("Failed to close remote site") + logger.WarnContext(ctx, "Failed to close remote site", "error", err) } }() @@ -426,18 +427,18 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch for { select { case <-s.ctx.Done(): - logger.Infof("closing") + logger.InfoContext(ctx, "closing") return case <-proxyResyncTicker.Chan(): var req discoveryRequest proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) if err != nil { - logger.WithError(err).Warn("Failed to get proxy set") + logger.WarnContext(ctx, "Failed to get proxy set", "error", err) } req.SetProxies(proxies) - if err := conn.sendDiscoveryRequest(req); err != nil { - logger.WithError(err).Debug("Marking connection invalid on error") + if err := conn.sendDiscoveryRequest(ctx, req); err != nil { + logger.DebugContext(ctx, "Marking connection invalid on error", "error", err) conn.markInvalid(err) return } @@ -445,17 +446,17 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch var req discoveryRequest req.SetProxies(proxies) - if err := conn.sendDiscoveryRequest(req); err != nil { - logger.WithError(err).Debug("Marking connection invalid on error") + if err := conn.sendDiscoveryRequest(ctx, req); err != nil { + logger.DebugContext(ctx, "Marking connection invalid on error", "error", err) conn.markInvalid(err) return } case req := <-reqC: if req == nil { - logger.Info("Cluster agent disconnected.") + logger.InfoContext(ctx, "Cluster agent disconnected") conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected")) if !s.HasValidConnections() { - logger.Debug("Deleting connection record.") + logger.DebugContext(ctx, "Deleting connection record") s.deleteConnectionRecord() } return @@ -463,9 +464,9 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch if firstHeartbeat { // as soon as the agent connects and sends a first heartbeat // send it the list of current proxies back - proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + proxies, err := s.srv.proxyWatcher.CurrentResources(ctx) if err != nil { - logger.WithError(err).Warn("Failed to get proxy set") + logger.WarnContext(ctx, "Failed to get proxy set", "error", err) 
} if len(proxies) > 0 { conn.updateProxies(proxies) @@ -482,9 +483,9 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch pinglog := logger if roundtrip != 0 { - pinglog = pinglog.WithField("latency", roundtrip) + pinglog = pinglog.With("latency", roundtrip) } - pinglog.Debugf("Ping <- %v", conn.conn.RemoteAddr()) + pinglog.DebugContext(ctx, "Received ping request", "remote_addr", logutils.StringerAttr(conn.conn.RemoteAddr())) tm := s.clock.Now().UTC() conn.setLastHeartbeat(tm) @@ -498,11 +499,11 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch if t.After(hb.Add(s.offlineThreshold * missedHeartBeatThreshold)) { count := conn.activeSessions() if count == 0 { - logger.Errorf("Closing unhealthy and idle connection. Heartbeat last received at %s", hb) + logger.ErrorContext(ctx, "Closing unhealthy and idle connection", "last_heartbeat", hb) return } - logger.Warnf("Deferring closure of unhealthy connection due to %d active connections", count) + logger.WarnContext(ctx, "Deferring closure of unhealthy connection due to active connections", "active_conn_count", count) } offlineThresholdTimer.Reset(s.offlineThreshold) @@ -554,24 +555,24 @@ func (s *remoteSite) updateCertAuthorities(retry retryutils.Retry, remoteWatcher if err != nil { switch { case trace.IsNotFound(err): - s.logger.Debug("Remote cluster does not support cert authorities rotation yet.") + s.logger.DebugContext(s.ctx, "Remote cluster does not support cert authorities rotation yet") case trace.IsCompareFailed(err): - s.logger.Info("Remote cluster has updated certificate authorities, going to force reconnect.") + s.logger.InfoContext(s.ctx, "Remote cluster has updated certificate authorities, going to force reconnect") if err := s.srv.onSiteTunnelClose(&alwaysClose{RemoteSite: s}); err != nil { - s.logger.WithError(err).Warn("Failed to close remote site") + s.logger.WarnContext(s.ctx, "Failed to close remote site", "error", err) } return case trace.IsConnectionProblem(err): - s.logger.Debug("Remote cluster is offline.") + s.logger.DebugContext(s.ctx, "Remote cluster is offline") default: - s.logger.Warnf("Could not perform cert authorities update: %v.", trace.DebugReport(err)) + s.logger.WarnContext(s.ctx, "Could not perform cert authorities update", "error", err) } } startedWaiting := s.clock.Now() select { case t := <-retry.After(): - s.logger.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) + s.logger.DebugContext(s.ctx, "Initiating new cert authority watch after applying backoff", "backoff_duration", t.Sub(startedWaiting)) retry.Inc() case <-s.ctx.Done(): return @@ -592,7 +593,7 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW } defer func() { if err := localWatch.Close(); err != nil { - s.logger.WithError(err).Warn("Failed to close local ca watcher subscription.") + s.logger.WarnContext(s.ctx, "Failed to close local ca watcher subscription", "error", err) } }() @@ -607,7 +608,7 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW } defer func() { if err := remoteWatch.Close(); err != nil { - s.logger.WithError(err).Warn("Failed to close remote ca watcher subscription.") + s.logger.WarnContext(s.ctx, "Failed to close remote ca watcher subscription", "error", err) } }() @@ -624,7 +625,7 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, ca); err != nil { return 
trace.Wrap(err, "failed to push local cert authority") } - s.logger.Debugf("Pushed local cert authority %v", caID.String()) + s.logger.DebugContext(s.ctx, "Pushed local cert authority", "cert_authority", logutils.StringerAttr(caID)) localCAs[caType] = ca } @@ -650,7 +651,7 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW // if CA is changed or does not exist, update backend if err != nil || !services.CertAuthoritiesEquivalent(oldRemoteCA, remoteCA) { - s.logger.Debugf("Ingesting remote cert authority %v", remoteCA.GetID()) + s.logger.DebugContext(s.ctx, "Ingesting remote cert authority", "cert_authority", logutils.StringerAttr(remoteCA.GetID())) if err := s.localClient.UpsertCertAuthority(s.ctx, remoteCA); err != nil { return trace.Wrap(err) } @@ -668,17 +669,16 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW return trace.Wrap(err) } - s.logger.Debugf("Watching for cert authority changes.") + s.logger.DebugContext(s.ctx, "Watching for cert authority changes") for { select { case <-s.ctx.Done(): - s.logger.WithError(s.ctx.Err()).Debug("Context is closing.") return trace.Wrap(s.ctx.Err()) case <-localWatch.Done(): - s.logger.Warn("Local CertAuthority watcher subscription has closed") + s.logger.WarnContext(s.ctx, "Local CertAuthority watcher subscription has closed") return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName) case <-remoteWatch.Done(): - s.logger.Warn("Remote CertAuthority watcher subscription has closed") + s.logger.WarnContext(s.ctx, "Remote CertAuthority watcher subscription has closed") return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName) case evt := <-localWatch.Events(): switch evt.Type { @@ -699,7 +699,7 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW // TODO(espadolini): figure out who should be responsible for validating the CA *once* newCA = newCA.Clone() if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, newCA); err != nil { - log.WithError(err).Warn("Failed to rotate external ca") + s.logger.WarnContext(s.ctx, "Failed to rotate external ca", "error", err) return trace.Wrap(err) } @@ -724,13 +724,13 @@ func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityW } func (s *remoteSite) updateLocks(retry retryutils.Retry) { - s.logger.Debugf("Watching for remote lock changes.") + s.logger.DebugContext(s.ctx, "Watching for remote lock changes") for { startedWaiting := s.clock.Now() select { case t := <-retry.After(): - s.logger.Debugf("Initiating new lock watch after waiting %v.", t.Sub(startedWaiting)) + s.logger.DebugContext(s.ctx, "Initiating new lock watch after applying backoff", "backoff_duration", t.Sub(startedWaiting)) retry.Inc() case <-s.ctx.Done(): return @@ -739,11 +739,11 @@ func (s *remoteSite) updateLocks(retry retryutils.Retry) { if err := s.watchLocks(); err != nil { switch { case trace.IsNotImplemented(err): - s.logger.Debugf("Remote cluster %v does not support locks yet.", s.domainName) + s.logger.DebugContext(s.ctx, "Remote cluster does not support locks yet", "cluster", s.domainName) case trace.IsConnectionProblem(err): - s.logger.Debugf("Remote cluster %v is offline.", s.domainName) + s.logger.DebugContext(s.ctx, "Remote cluster is offline", "cluster", s.domainName) default: - s.logger.WithError(err).Warn("Could not update remote locks.") + s.logger.WarnContext(s.ctx, "Could not update remote locks", "error", err) } } } @@ -752,22 +752,21 @@ func (s *remoteSite) 
updateLocks(retry retryutils.Retry) { func (s *remoteSite) watchLocks() error { watcher, err := s.srv.LockWatcher.Subscribe(s.ctx) if err != nil { - s.logger.WithError(err).Error("Failed to subscribe to LockWatcher") + s.logger.ErrorContext(s.ctx, "Failed to subscribe to LockWatcher", "error", err) return err } defer func() { if err := watcher.Close(); err != nil { - s.logger.WithError(err).Warn("Failed to close lock watcher subscription.") + s.logger.WarnContext(s.ctx, "Failed to close lock watcher subscription", "error", err) } }() for { select { case <-watcher.Done(): - s.logger.WithError(watcher.Error()).Warn("Lock watcher subscription has closed") + s.logger.WarnContext(s.ctx, "Lock watcher subscription has closed", "error", watcher.Error()) return trace.Wrap(watcher.Error()) case <-s.ctx.Done(): - s.logger.WithError(s.ctx.Err()).Debug("Context is closing.") return trace.Wrap(s.ctx.Err()) case evt := <-watcher.Events(): switch evt.Type { @@ -822,7 +821,10 @@ func (s *remoteSite) Dial(params reversetunnelclient.DialParams) (net.Conn, erro } func (s *remoteSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, error) { - s.logger.Debugf("Dialing from %v to %v.", params.From, params.To) + s.logger.DebugContext(s.ctx, "Initiating dial request", + "source_addr", logutils.StringerAttr(params.From), + "target_addr", logutils.StringerAttr(params.To), + ) conn, err := s.connThroughTunnel(&sshutils.DialReq{ Address: params.To.String(), @@ -843,7 +845,10 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne if params.GetUserAgent == nil && !params.IsAgentlessNode { return nil, trace.BadParameter("user agent getter is required for teleport nodes") } - s.logger.Debugf("Dialing and forwarding from %v to %v.", params.From, params.To) + s.logger.DebugContext(s.ctx, "Initiating dial and forward request", + "source_addr", logutils.StringerAttr(params.From), + "target_addr", logutils.StringerAttr(params.To), + ) // request user agent connection if a SSH user agent is set var userAgent teleagent.Agent @@ -930,7 +935,7 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne // UseTunnel makes a channel request asking for the type of connection. If // the other side does not respond (older cluster) or takes to long to // respond, be on the safe side and assume it's not a tunnel connection. -func UseTunnel(logger *log.Entry, c *sshutils.ChConn) bool { +func UseTunnel(logger *slog.Logger, c *sshutils.ChConn) bool { responseCh := make(chan bool, 1) go func() { @@ -946,13 +951,16 @@ func UseTunnel(logger *log.Entry, c *sshutils.ChConn) bool { case response := <-responseCh: return response case <-time.After(1 * time.Second): - logger.Debugf("Timed out waiting for response: returning false.") + logger.DebugContext(context.Background(), "Timed out waiting for response: returning false") return false } } func (s *remoteSite) connThroughTunnel(req *sshutils.DialReq) (*sshutils.ChConn, error) { - s.logger.Debugf("Requesting connection to %v [%v] in remote cluster.", req.Address, req.ServerID) + s.logger.DebugContext(s.ctx, "Requesting connection in remote cluster.", + "target_address", req.Address, + "target_server_id", req.ServerID, + ) // Loop through existing remote connections and try and establish a // connection over the "reverse tunnel". 
@@ -963,7 +971,7 @@ func (s *remoteSite) connThroughTunnel(req *sshutils.DialReq) (*sshutils.ChConn, if err == nil { return conn, nil } - s.logger.WithError(err).Warn("Request for connection to remote site failed") + s.logger.WarnContext(s.ctx, "Request for connection to remote site failed", "error", err) } // Didn't connect and no error? This means we didn't have any connected diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 10591e2042bdd..eb7483eec6477 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -31,7 +31,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" @@ -55,6 +54,7 @@ import ( "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) var ( @@ -111,8 +111,8 @@ type server struct { // ctx is a context used for signaling and broadcast ctx context.Context - // log specifies the logger - log log.FieldLogger + // logger specifies the logger + logger *slog.Logger // proxyWatcher monitors changes to the proxies // and broadcasts updates @@ -186,10 +186,6 @@ type Config struct { // Component is a component used in logs Component string - // Log specifies the logger - // TODO(tross): remove this once Logger is used everywhere - Log log.FieldLogger - // Logger specifies the logger Logger *slog.Logger @@ -265,10 +261,6 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.Component == "" { cfg.Component = teleport.Component(teleport.ComponentProxy, teleport.ComponentServer) } - if cfg.Log == nil { - cfg.Log = log.StandardLogger() - } - cfg.Log = cfg.Log.WithField(teleport.ComponentKey, cfg.Component) if cfg.Logger == nil { cfg.Logger = slog.Default() @@ -331,7 +323,7 @@ func NewServer(cfg Config) (reversetunnelclient.Server, error) { cancel: cancel, proxyWatcher: proxyWatcher, clusterPeers: make(map[string]*clusterPeers), - log: cfg.Log, + logger: cfg.Logger, offlineThreshold: offlineThreshold, proxySigner: cfg.PROXYSigner, } @@ -384,9 +376,9 @@ func remoteClustersMap(rc []types.RemoteCluster) map[string]types.RemoteCluster func (s *server) disconnectClusters(connectedRemoteClusters []*remoteSite, remoteMap map[string]types.RemoteCluster) error { for _, cluster := range connectedRemoteClusters { if _, ok := remoteMap[cluster.GetName()]; !ok { - s.log.Infof("Remote cluster %q has been deleted. 
Disconnecting it from the proxy.", cluster.GetName()) + s.logger.InfoContext(s.ctx, "Remote cluster has been deleted, disconnecting it from the proxy", "remote_cluster", cluster.GetName()) if err := s.onSiteTunnelClose(&alwaysClose{RemoteSite: cluster}); err != nil { - s.log.Debugf("Failure closing cluster %q: %v.", cluster.GetName(), err) + s.logger.DebugContext(s.ctx, "Failure closing cluster", "remote_cluster", cluster.GetName(), "error", err) } remoteClustersStats.DeleteLabelValues(cluster.GetName()) } @@ -399,36 +391,36 @@ func (s *server) periodicFunctions() { defer ticker.Stop() if err := s.fetchClusterPeers(); err != nil { - s.log.Warningf("Failed to fetch cluster peers: %v.", err) + s.logger.WarnContext(s.ctx, "Failed to fetch cluster peers", "error", err) } for { select { case <-s.ctx.Done(): - s.log.Debugf("Closing.") + s.logger.DebugContext(s.ctx, "Closing") return // Proxies have been updated, notify connected agents about the update. case proxies := <-s.proxyWatcher.ResourcesC: s.fanOutProxies(proxies) case <-ticker.C: if err := s.fetchClusterPeers(); err != nil { - s.log.WithError(err).Warn("Failed to fetch cluster peers") + s.logger.WarnContext(s.ctx, "Failed to fetch cluster peers", "error", err) } connectedRemoteClusters := s.getRemoteClusters() remoteClusters, err := s.localAccessPoint.GetRemoteClusters(s.ctx) if err != nil { - s.log.WithError(err).Warn("Failed to get remote clusters") + s.logger.WarnContext(s.ctx, "Failed to get remote clusters", "error", err) } remoteMap := remoteClustersMap(remoteClusters) if err := s.disconnectClusters(connectedRemoteClusters, remoteMap); err != nil { - s.log.Warningf("Failed to disconnect clusters: %v.", err) + s.logger.WarnContext(s.ctx, "Failed to disconnect clusters", "error", err) } if err := s.reportClusterStats(connectedRemoteClusters, remoteMap); err != nil { - s.log.Warningf("Failed to report cluster stats: %v.", err) + s.logger.WarnContext(s.ctx, "Failed to report cluster stats", "error", err) } } } @@ -555,11 +547,11 @@ func (s *server) removeClusterPeers(conns []types.TunnelConnection) { for _, conn := range conns { peers, ok := s.clusterPeers[conn.GetClusterName()] if !ok { - s.log.Warningf("failed to remove cluster peer, not found peers for %v.", conn) + s.logger.WarnContext(s.ctx, "failed to remove missing cluster peer", "tunnel_connection", logutils.StringerAttr(conn)) continue } peers.removePeer(conn) - s.log.Debugf("Removed cluster peer %v.", conn) + s.logger.DebugContext(s.ctx, "Removed cluster peer", "tunnel_connection", logutils.StringerAttr(conn)) } } @@ -620,11 +612,11 @@ func (s *server) DrainConnections(ctx context.Context) error { // Ensure listener is closed before sending reconnects. err := s.srv.Close() s.RLock() - s.log.Debugf("Advising reconnect to local site: %s", s.localSite.GetName()) + s.logger.DebugContext(ctx, "Advising reconnect to local site", "local_site", s.localSite.GetName()) go s.localSite.adviseReconnect(ctx) for _, site := range s.remoteSites { - s.log.Debugf("Advising reconnect to remote site: %s", site.GetName()) + s.logger.DebugContext(ctx, "Advising reconnect to remote site", "remote_site", site.GetName()) go site.adviseReconnect(ctx) } s.RUnlock() @@ -650,7 +642,7 @@ func (s *server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont switch channelType { // Heartbeats can come from nodes or proxies. 
case chanHeartbeat: - s.handleHeartbeat(conn, sconn, nch) + s.handleHeartbeat(ctx, conn, sconn, nch) // Transport requests come from nodes requesting a connection to the Auth // Server through the reverse tunnel. case constants.ChanTransport: @@ -665,19 +657,20 @@ func (s *server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont if channelType == "session" { msg = "Cannot open new SSH session on reverse tunnel. Are you connecting to the right port?" } - s.log.Warn(msg) + //nolint:sloglint // message should be a constant but in this case we are creating it at runtime. + s.logger.WarnContext(ctx, msg) s.rejectRequest(nch, ssh.ConnectionFailed, msg) return } } func (s *server) handleTransport(sconn *ssh.ServerConn, nch ssh.NewChannel) { - s.log.Debug("Received transport request.") + s.logger.DebugContext(s.ctx, "Received transport request") channel, requestC, err := nch.Accept() if err != nil { sconn.Close() // avoid WithError to reduce log spam on network errors - s.log.Warnf("Failed to accept request: %v.", err) + s.logger.WarnContext(s.ctx, "Failed to accept request", "error", err) return } @@ -696,7 +689,7 @@ func (s *server) handleTransportChannel(sconn *ssh.ServerConn, ch ssh.Channel, r return case <-time.After(apidefaults.DefaultIOTimeout): go ssh.DiscardRequests(reqC) - s.log.Warn("Timed out waiting for transport dial request.") + s.logger.WarnContext(s.ctx, "Timed out waiting for transport dial request") return case r, ok := <-reqC: if !ok { @@ -708,13 +701,12 @@ func (s *server) handleTransportChannel(sconn *ssh.ServerConn, ch ssh.Channel, r dialReq := parseDialReq(req.Payload) if dialReq.Address != constants.RemoteAuthServer { - s.log.WithField("address", dialReq.Address). - Warn("Received dial request for unexpected address, routing to the auth server anyway.") + s.logger.WarnContext(s.ctx, "Received dial request for unexpected address, routing to the auth server anyway", "address", dialReq.Address) } authAddress := utils.ChooseRandomString(s.LocalAuthAddresses) if authAddress == "" { - s.log.Error("No auth servers configured.") + s.logger.ErrorContext(s.ctx, "No auth servers configured") fmt.Fprint(ch.Stderr(), "internal server error") req.Reply(false, nil) return @@ -726,7 +718,7 @@ func (s *server) handleTransportChannel(sconn *ssh.ServerConn, ch ssh.Channel, r if s.proxySigner != nil && clientSrcAddr != nil && clientDstAddr != nil { h, err := s.proxySigner.SignPROXYHeader(clientSrcAddr, clientDstAddr) if err != nil { - s.log.WithError(err).Error("Failed to create signed PROXY header.") + s.logger.ErrorContext(s.ctx, "Failed to create signed PROXY header", "error", err) fmt.Fprint(ch.Stderr(), "internal server error") req.Reply(false, nil) } @@ -736,7 +728,7 @@ func (s *server) handleTransportChannel(sconn *ssh.ServerConn, ch ssh.Channel, r d := net.Dialer{Timeout: apidefaults.DefaultIOTimeout} conn, err := d.DialContext(s.ctx, "tcp", authAddress) if err != nil { - s.log.Errorf("Failed to dial auth: %v.", err) + s.logger.ErrorContext(s.ctx, "Failed to dial auth", "error", err) fmt.Fprint(ch.Stderr(), "failed to dial auth server") req.Reply(false, nil) return @@ -745,7 +737,7 @@ func (s *server) handleTransportChannel(sconn *ssh.ServerConn, ch ssh.Channel, r _ = conn.SetWriteDeadline(time.Now().Add(apidefaults.DefaultIOTimeout)) if _, err := conn.Write(proxyHeader); err != nil { - s.log.Errorf("Failed to send PROXY header: %v.", err) + s.logger.ErrorContext(s.ctx, "Failed to send PROXY header", "error", err) fmt.Fprint(ch.Stderr(), "failed to dial auth server") 
req.Reply(false, nil) return @@ -753,7 +745,7 @@ func (s *server) handleTransportChannel(sconn *ssh.ServerConn, ch ssh.Channel, r _ = conn.SetWriteDeadline(time.Time{}) if err := req.Reply(true, nil); err != nil { - s.log.Errorf("Failed to respond to dial request: %v.", err) + s.logger.ErrorContext(s.ctx, "Failed to respond to dial request", "error", err) return } @@ -761,10 +753,10 @@ func (s *server) handleTransportChannel(sconn *ssh.ServerConn, ch ssh.Channel, r } // TODO(awly): unit test this -func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) { - s.log.Debugf("New tunnel from %v.", sconn.RemoteAddr()) +func (s *server) handleHeartbeat(ctx context.Context, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) { + s.logger.DebugContext(ctx, "New tunnel established", "remote_addr", logutils.StringerAttr(sconn.RemoteAddr())) if sconn.Permissions.Extensions[utils.ExtIntCertType] != utils.ExtIntCertTypeHost { - s.log.Error(trace.BadParameter("can't retrieve certificate type in certType")) + s.logger.ErrorContext(ctx, "can't retrieve certificate type in certtype@teleport extension") return } @@ -772,7 +764,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N // nodes it's a node dialing back. val, ok := sconn.Permissions.Extensions[extCertRole] if !ok { - s.log.Errorf("Failed to accept connection, missing %q extension", extCertRole) + s.logger.ErrorContext(ctx, "Failed to accept connection, missing role extension") s.rejectRequest(nch, ssh.ConnectionFailed, "unknown role") return } @@ -781,64 +773,64 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N switch role { // Node is dialing back. case types.RoleNode: - s.handleNewService(role, conn, sconn, nch, types.NodeTunnel) + s.handleNewService(ctx, role, conn, sconn, nch, types.NodeTunnel) // App is dialing back. case types.RoleApp: - s.handleNewService(role, conn, sconn, nch, types.AppTunnel) + s.handleNewService(ctx, role, conn, sconn, nch, types.AppTunnel) // Kubernetes service is dialing back. case types.RoleKube: - s.handleNewService(role, conn, sconn, nch, types.KubeTunnel) + s.handleNewService(ctx, role, conn, sconn, nch, types.KubeTunnel) // Database proxy is dialing back. case types.RoleDatabase: - s.handleNewService(role, conn, sconn, nch, types.DatabaseTunnel) + s.handleNewService(ctx, role, conn, sconn, nch, types.DatabaseTunnel) // Proxy is dialing back. case types.RoleProxy: - s.handleNewCluster(conn, sconn, nch) + s.handleNewCluster(ctx, conn, sconn, nch) case types.RoleWindowsDesktop: - s.handleNewService(role, conn, sconn, nch, types.WindowsDesktopTunnel) + s.handleNewService(ctx, role, conn, sconn, nch, types.WindowsDesktopTunnel) case types.RoleOkta: - s.handleNewService(role, conn, sconn, nch, types.OktaTunnel) + s.handleNewService(ctx, role, conn, sconn, nch, types.OktaTunnel) // Unknown role. 
default: - s.log.Errorf("Unsupported role attempting to connect: %v", val) + s.logger.ErrorContext(ctx, "Unsupported role attempting to connect", "role", val) s.rejectRequest(nch, ssh.ConnectionFailed, fmt.Sprintf("unsupported role %v", val)) } } -func (s *server) handleNewService(role types.SystemRole, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel, connType types.TunnelType) { +func (s *server) handleNewService(ctx context.Context, role types.SystemRole, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel, connType types.TunnelType) { cluster, rconn, err := s.upsertServiceConn(conn, sconn, connType) if err != nil { - s.log.Errorf("Failed to upsert %s: %v.", role, err) + s.logger.ErrorContext(ctx, "Failed to upsert service connection", "role", role, "error", err) sconn.Close() return } ch, req, err := nch.Accept() if err != nil { - s.log.Errorf("Failed to accept on channel: %v.", err) + s.logger.ErrorContext(ctx, "Failed to accept on channel", "error", err) sconn.Close() return } - go cluster.handleHeartbeat(rconn, ch, req) + go cluster.handleHeartbeat(ctx, rconn, ch, req) } -func (s *server) handleNewCluster(conn net.Conn, sshConn *ssh.ServerConn, nch ssh.NewChannel) { +func (s *server) handleNewCluster(ctx context.Context, conn net.Conn, sshConn *ssh.ServerConn, nch ssh.NewChannel) { // add the incoming site (cluster) to the list of active connections: site, remoteConn, err := s.upsertRemoteCluster(conn, sshConn) if err != nil { - s.log.Error(trace.Wrap(err)) + s.logger.ErrorContext(ctx, "failed to upsert remote cluster connection", "error", err) s.rejectRequest(nch, ssh.ConnectionFailed, "failed to accept incoming cluster connection") return } // accept the request and start the heartbeat on it: ch, req, err := nch.Accept() if err != nil { - s.log.Error(trace.Wrap(err)) + s.logger.ErrorContext(ctx, "Failed to accept on channel", "error", err) sshConn.Close() return } - go site.handleHeartbeat(remoteConn, ch, req) + go site.handleHeartbeat(ctx, remoteConn, ch, req) } func (s *server) requireLocalAgentForConn(sconn *ssh.ServerConn, connType types.TunnelType) error { @@ -864,15 +856,15 @@ func (s *server) getTrustedCAKeysByID(id types.CertAuthID) ([]ssh.PublicKey, err } func (s *server) keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (perm *ssh.Permissions, err error) { - logger := s.log.WithFields(log.Fields{ - "remote": conn.RemoteAddr(), - "user": conn.User(), - }) + logger := s.logger.With( + "remote_addr", logutils.StringerAttr(conn.RemoteAddr()), + "user", conn.User(), + ) // The crypto/x/ssh package won't log the returned error for us, do it // manually. defer func() { if err != nil { - logger.Warnf("Failed to authenticate client, err: %v.", err) + logger.WarnContext(s.ctx, "Failed to authenticate client", "error", err) } }() @@ -920,7 +912,7 @@ func (s *server) keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (perm *ssh.Pe return nil, trace.BadParameter("unsupported cert type: %v.", cert.CertType) } - if err := s.checkClientCert(logger, conn.User(), clusterName, cert, caType); err != nil { + if err := s.checkClientCert(conn.User(), clusterName, cert, caType); err != nil { return nil, trace.Wrap(err) } return &ssh.Permissions{ @@ -935,7 +927,7 @@ func (s *server) keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (perm *ssh.Pe // checkClientCert verifies that client certificate is signed by the recognized // certificate authority. 
-func (s *server) checkClientCert(logger *log.Entry, user string, clusterName string, cert *ssh.Certificate, caType types.CertAuthType) error { +func (s *server) checkClientCert(user string, clusterName string, cert *ssh.Certificate, caType types.CertAuthType) error { // fetch keys of the certificate authority to check // if there is a match keys, err := s.getTrustedCAKeysByID(types.CertAuthID{ @@ -1024,7 +1016,10 @@ func (s *server) upsertRemoteCluster(conn net.Conn, sshConn *ssh.ServerConn) (*r } s.remoteSites = append(s.remoteSites, site) } - site.logger.Infof("Connection <- %v, clusters: %d.", conn.RemoteAddr(), len(s.remoteSites)) + site.logger.InfoContext(s.ctx, "Processed inbound connection from remote cluster", + "source_addr", logutils.StringerAttr(conn.RemoteAddr()), + "tunnel_count", len(s.remoteSites), + ) // treat first connection as a registered heartbeat, // otherwise the connection information will appear after initial // heartbeat delay @@ -1146,7 +1141,7 @@ func (s *server) fanOutProxies(proxies []types.Server) { func (s *server) rejectRequest(ch ssh.NewChannel, reason ssh.RejectionReason, msg string) { if err := ch.Reject(reason, msg); err != nil { - s.log.Warnf("Failed rejecting new channel request: %v", err) + s.logger.WarnContext(s.ctx, "Failed rejecting new channel request", "error", err) } } @@ -1182,12 +1177,10 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, srv: srv, domainName: domainName, connInfo: connInfo, - logger: log.WithFields(log.Fields{ - teleport.ComponentKey: teleport.ComponentReverseTunnelServer, - teleport.ComponentFields: log.Fields{ - "cluster": domainName, - }, - }), + logger: slog.With( + teleport.ComponentKey, teleport.ComponentReverseTunnelServer, + "cluster", domainName, + ), ctx: closeContext, cancel: cancel, clock: srv.Clock, diff --git a/lib/reversetunnel/srv_test.go b/lib/reversetunnel/srv_test.go index 327b194ac6ae7..2477739df359a 100644 --- a/lib/reversetunnel/srv_test.go +++ b/lib/reversetunnel/srv_test.go @@ -29,7 +29,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -64,7 +63,7 @@ func TestServerKeyAuth(t *testing.T) { require.NoError(t, err) s := &server{ - log: utils.NewLoggerForTests(), + logger: utils.NewSlogLoggerForTests(), Config: Config{Clock: clockwork.NewRealClock()}, localAccessPoint: mockAccessPoint{ ca: ca, @@ -204,8 +203,8 @@ func TestOnlyAuthDial(t *testing.T) { badListenerAddr := acceptAndCloseListener(t, true) srv := &server{ - log: logrus.StandardLogger(), - ctx: ctx, + logger: utils.NewSlogLoggerForTests(), + ctx: ctx, Config: Config{ LocalAuthAddresses: []string{goodListenerAddr}, }, diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 2f3d5cfc697f4..b3d338e9bb62b 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -23,12 +23,12 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net" "net/netip" "time" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" @@ -61,7 +61,7 @@ func parseDialReq(payload []byte) *sshutils.DialReq { // transport is used to build a connection to the target host. 
type transport struct { component string - log logrus.FieldLogger + logger *slog.Logger closeContext context.Context authClient authclient.ProxyAccessPoint authServers []string @@ -126,7 +126,7 @@ func (p *transport) start() { return } case <-time.After(apidefaults.DefaultIOTimeout): - p.log.Warnf("Transport request failed: timed out waiting for request.") + p.logger.WarnContext(p.closeContext, "Transport request failed: timed out waiting for request") return } @@ -140,8 +140,12 @@ func (p *transport) start() { if !p.forwardClientAddress { // This shouldn't happen in normal operation. Either malicious user or misconfigured client. if dreq.ClientSrcAddr != "" || dreq.ClientDstAddr != "" { - p.log.Warnf("Received unexpected dial request with client source address %q, "+ "client destination address %q, when they should be empty.", dreq.ClientSrcAddr, dreq.ClientDstAddr) + const msg = "Received unexpected dial request with client source address and " + "client destination address populated, when they should be empty." + p.logger.WarnContext(p.closeContext, msg, + "client_src_addr", dreq.ClientSrcAddr, + "client_dest_addr", dreq.ClientDstAddr, + ) } // Make sure address fields are overwritten. @@ -154,7 +158,11 @@ func (p *transport) start() { } } - p.log.Debugf("Received out-of-band proxy transport request for %v [%v], from %v.", dreq.Address, dreq.ServerID, dreq.ClientSrcAddr) + p.logger.DebugContext(p.closeContext, "Received out-of-band proxy transport request", + "target_address", dreq.Address, + "target_server_id", dreq.ServerID, + "client_addr", dreq.ClientSrcAddr, + ) // directAddress will hold the address of the node to dial to, if we don't // have a tunnel for it. @@ -165,7 +173,7 @@ func (p *transport) start() { // Connect to an Auth Server. 
case reversetunnelclient.RemoteAuthServer: if len(p.authServers) == 0 { - p.log.Errorf("connection rejected: no auth servers configured") + p.logger.ErrorContext(p.closeContext, "connection rejected: no auth servers configured") p.reply(req, false, []byte("no auth servers configured")) return @@ -190,11 +198,14 @@ func (p *transport) start() { return } if err := req.Reply(true, []byte("Connected.")); err != nil { - p.log.Errorf("Failed responding OK to %q request: %v", req.Type, err) + p.logger.ErrorContext(p.closeContext, "Failed responding OK to request", + "request_type", req.Type, + "error", err, + ) return } - p.log.Debug("Handing off connection to a local kubernetes service") + p.logger.DebugContext(p.closeContext, "Handing off connection to a local kubernetes service") // If dreq has ClientSrcAddr we wrap connection var clientConn net.Conn = sshutils.NewChConn(p.sconn, p.channel) @@ -211,7 +222,7 @@ func (p *transport) start() { p.reply(req, false, []byte("connection rejected: configure kubernetes proxy for this cluster.")) return } - p.log.Debugf("Forwarding connection to %q", p.kubeDialAddr.Addr) + p.logger.DebugContext(p.closeContext, "Forwarding connection to kubernetes proxy", "kube_proxy_addr", p.kubeDialAddr.Addr) directAddress = p.kubeDialAddr.Addr } @@ -227,17 +238,20 @@ func (p *transport) start() { if p.server != nil { if p.sconn == nil { - p.log.Debug("Connection rejected: server connection missing") + p.logger.DebugContext(p.closeContext, "Connection rejected: server connection missing") p.reply(req, false, []byte("connection rejected: server connection missing")) return } if err := req.Reply(true, []byte("Connected.")); err != nil { - p.log.Errorf("Failed responding OK to %q request: %v", req.Type, err) + p.logger.ErrorContext(p.closeContext, "Failed responding OK to request", + "request_type", req.Type, + "error", err, + ) return } - p.log.Debugf("Handing off connection to a local %q service.", dreq.ConnType) + p.logger.DebugContext(p.closeContext, "Handing off connection to a local service.", "conn_type", dreq.ConnType) // If dreq has ClientSrcAddr we wrap connection var clientConn net.Conn = sshutils.NewChConn(p.sconn, p.channel) @@ -294,13 +308,19 @@ func (p *transport) start() { // Dial was successful. if err := req.Reply(true, []byte("Connected.")); err != nil { - p.log.Errorf("Failed responding OK to %q request: %v", req.Type, err) + p.logger.ErrorContext(p.closeContext, "Failed responding OK to request", + "request_type", req.Type, + "error", err, + ) if err := conn.Close(); err != nil { - p.log.Errorf("Failed closing connection: %v", err) + p.logger.ErrorContext(p.closeContext, "Failed closing connection", "error", err) } return } - p.log.Debugf("Successfully dialed to %v %q, start proxying.", dreq.Address, dreq.ServerID) + p.logger.DebugContext(p.closeContext, "Successfully dialed to target, starting to proxy", + "target_addr", dreq.Address, + "target_server_id", dreq.ServerID, + ) // Start processing channel requests. 
Pass in a context that wraps the passed // in context with a context that closes when this function returns to @@ -314,9 +334,9 @@ func (p *transport) start() { if len(signedHeader) > 0 { _, err = conn.Write(signedHeader) if err != nil { - p.log.Errorf("Could not write PROXY header to the connection: %v", err) + p.logger.ErrorContext(p.closeContext, "Could not write PROXY header to the connection", "error", err) if err := conn.Close(); err != nil { - p.log.Errorf("Failed closing connection: %v", err) + p.logger.ErrorContext(p.closeContext, "Failed closing connection", "error", err) } return } @@ -342,7 +362,7 @@ func (p *transport) start() { select { case <-errorCh: case <-p.closeContext.Done(): - p.log.Warnf("Proxy transport failed: closing context.") + p.logger.WarnContext(p.closeContext, "Proxy transport failed: closing context") return } } @@ -375,7 +395,7 @@ func (p *transport) handleChannelRequests(closeContext context.Context, useTunne func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, error) { // This function doesn't attempt to dial if a host with one of the // search names is not registered. It's a fast check. - p.log.Debugf("Attempting to dial through tunnel with server ID %q.", r.ServerID) + p.logger.DebugContext(p.closeContext, "Attempting to dial server through tunnel", "target_server_id", r.ServerID) conn, err := p.tunnelDial(r) if err != nil { if !trace.IsNotFound(err) { @@ -394,13 +414,13 @@ func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, e } errTun := err - p.log.Debugf("Attempting to dial directly %q.", addr) + p.logger.DebugContext(p.closeContext, "Attempting to dial server directly", "target_addr", addr) conn, err = p.directDial(addr) if err != nil { return nil, false, trace.ConnectionProblem(err, "failed dialing through tunnel (%v) or directly (%v)", errTun, err) } - p.log.Debugf("Returning direct dialed connection to %q.", addr) + p.logger.DebugContext(p.closeContext, "Returning direct dialed connection", "target_addr", addr) // Requests to get a connection to the remote auth server do not provide a ConnType, // and since an empty ConnType is converted to [types.NodeTunnel] in CheckAndSetDefaults, @@ -413,7 +433,7 @@ func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, e return conn, false, nil } - p.log.Debugf("Returning connection dialed through tunnel with server ID %v.", r.ServerID) + p.logger.DebugContext(p.closeContext, "Returning connection to server dialed through tunnel", "target_server_id", r.ServerID) if r.ConnType == types.NodeTunnel { return proxy.NewProxiedMetricConn(conn), true, nil @@ -449,10 +469,10 @@ func (p *transport) tunnelDial(r *sshutils.DialReq) (net.Conn, error) { func (p *transport) reply(req *ssh.Request, ok bool, msg []byte) { if !ok { - p.log.Debugf("Non-ok reply to %q request: %s", req.Type, msg) + p.logger.DebugContext(p.closeContext, "Non-ok reply to request", "request_type", req.Type, "error", string(msg)) } if err := req.Reply(ok, msg); err != nil { - p.log.Warnf("Failed sending reply to %q request on SSH channel: %v", req.Type, err) + p.logger.WarnContext(p.closeContext, "Failed sending reply to request", "request_type", req.Type, "error", err) } } diff --git a/lib/reversetunnelclient/api_with_roles.go b/lib/reversetunnelclient/api_with_roles.go index 4b4eeff886871..ddeb7ff090f50 100644 --- a/lib/reversetunnelclient/api_with_roles.go +++ b/lib/reversetunnelclient/api_with_roles.go @@ -20,9 +20,9 @@ package reversetunnelclient import ( "context" + 
"log/slog" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" @@ -75,7 +75,7 @@ func (t *TunnelWithRoles) GetSites() ([]RemoteSite, error) { if !trace.IsNotFound(err) { return nil, trace.Wrap(err) } - logrus.Warningf("Skipping dangling cluster %q, no remote cluster resource found.", cluster.GetName()) + slog.WarnContext(ctx, "Skipping dangling cluster, no remote cluster resource found", "cluster", cluster.GetName()) continue } if err := t.accessChecker.CheckAccessToRemoteCluster(rc); err != nil { diff --git a/lib/service/service.go b/lib/service/service.go index fb2b70e19934c..39666c2aa1d91 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4443,7 +4443,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { PollingPeriod: process.Config.PollingPeriod, FIPS: cfg.FIPS, Emitter: streamEmitter, - Log: process.log, Logger: process.logger, LockWatcher: lockWatcher, PeerClient: peerClient, @@ -5019,9 +5018,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return nil }) - rcWatchLog := logrus.WithFields(logrus.Fields{ - teleport.ComponentKey: teleport.Component(teleport.ComponentReverseTunnelAgent, process.id), - }) + rcWatchLog := process.logger.With(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelAgent, process.id)) // Create and register reverse tunnel AgentPool. rcWatcher, err := reversetunnel.NewRemoteClusterTunnelManager(reversetunnel.RemoteClusterTunnelManagerConfig{ @@ -5033,7 +5030,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { KubeDialAddr: utils.DialAddrFromListenAddr(kubeDialAddr(cfg.Proxy, clusterNetworkConfig.GetProxyListenerMode())), ReverseTunnelServer: tsrv, FIPS: process.Config.FIPS, - Log: rcWatchLog, + Logger: rcWatchLog, LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServerAddresses()), PROXYSigner: proxySigner, }) @@ -5042,7 +5039,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } process.RegisterCriticalFunc("proxy.reversetunnel.watcher", func() error { - rcWatchLog.Infof("Starting reverse tunnel agent pool.") + rcWatchLog.InfoContext(process.ExitContext(), "Starting reverse tunnel agent pool") done := make(chan struct{}) go func() { defer close(done)