Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert lib/reversetunnel to use slog #50362

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"strings"
"sync"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/constants"
Expand All @@ -42,6 +42,7 @@ import (
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/reversetunnel/track"
"github.com/gravitational/teleport/lib/utils"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

type AgentState string
Expand Down Expand Up @@ -113,8 +114,8 @@ type agentConfig struct {
// clock is use to get the current time. Mock clocks can be used for
// testing.
clock clockwork.Clock
// log is an optional logger.
log logrus.FieldLogger
// logger is an optional logger.
logger *slog.Logger
// localAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
localAuthAddresses []string
Expand Down Expand Up @@ -145,12 +146,13 @@ func (c *agentConfig) checkAndSetDefaults() error {
if c.clock == nil {
c.clock = clockwork.NewRealClock()
}
if c.log == nil {
c.log = logrus.New()
if c.logger == nil {
c.logger = slog.Default()
}
c.log = c.log.
WithField("leaseID", c.lease.ID()).
WithField("target", c.addr.String())
c.logger = c.logger.With(
"lease_id", c.lease.ID(),
"target", c.addr.String(),
)

return nil
}
Expand Down Expand Up @@ -284,7 +286,10 @@ func (a *agent) updateState(state AgentState) (AgentState, error) {

prevState := a.state
a.state = state
a.log.Debugf("Changing state %s -> %s.", prevState, state)
a.logger.DebugContext(a.ctx, "Agent state updated",
"previous_state", prevState,
"current_state", state,
)

if a.agentConfig.stateCallback != nil {
go a.agentConfig.stateCallback(a.state)
Expand All @@ -296,7 +301,7 @@ func (a *agent) updateState(state AgentState) (AgentState, error) {
// Start starts an agent returning after successfully connecting and sending
// the first heartbeat.
func (a *agent) Start(ctx context.Context) error {
a.log.Debugf("Starting agent %v", a.addr)
a.logger.DebugContext(ctx, "Starting agent", "addr", a.addr.FullAddress())

var err error
defer func() {
Expand Down Expand Up @@ -325,7 +330,7 @@ func (a *agent) Start(ctx context.Context) error {
a.wg.Add(1)
go func() {
if err := a.handleGlobalRequests(a.ctx, a.client.GlobalRequests()); err != nil {
a.log.WithError(err).Debug("Failed to handle global requests.")
a.logger.DebugContext(a.ctx, "Failed to handle global requests", "error", err)
}
a.wg.Done()
a.Stop()
Expand All @@ -336,7 +341,7 @@ func (a *agent) Start(ctx context.Context) error {
a.wg.Add(1)
go func() {
if err := a.handleDrainChannels(); err != nil {
a.log.WithError(err).Debug("Failed to handle drainable channels.")
a.logger.DebugContext(a.ctx, "Failed to handle drainable channels", "error", err)
}
a.wg.Done()
a.Stop()
Expand All @@ -345,7 +350,7 @@ func (a *agent) Start(ctx context.Context) error {
a.wg.Add(1)
go func() {
if err := a.handleChannels(); err != nil {
a.log.WithError(err).Debug("Failed to handle channels.")
a.logger.DebugContext(a.ctx, "Failed to handle channels", "error", err)
}
a.wg.Done()
a.Stop()
Expand Down Expand Up @@ -460,23 +465,23 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R
case versionRequest:
version, err := a.versionGetter.getVersion(ctx)
if err != nil {
a.log.WithError(err).Warnf("Failed to retrieve auth version in response to %v request.", r.Type)
a.logger.WarnContext(ctx, "Failed to retrieve auth version in response to x-teleport-version request", "error", err)
if err := a.client.Reply(r, false, []byte("Failed to retrieve auth version")); err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.logger.DebugContext(ctx, "Failed to reply to x-teleport-version request", "error", err)
continue
}
}

if err := a.client.Reply(r, true, []byte(version)); err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.logger.DebugContext(ctx, "Failed to reply to x-teleport-version request", "error", err)
continue
}
case reconnectRequest:
a.log.Debugf("Received reconnect advisory request from proxy.")
a.logger.DebugContext(ctx, "Received reconnect advisory request from proxy")
if r.WantReply {
err := a.client.Reply(r, true, nil)
if err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.logger.DebugContext(ctx, "Failed to reply to [email protected] request", "error", err)
}
}

Expand All @@ -487,7 +492,7 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R
// This handles keep-alive messages and matches the behavior of OpenSSH.
err := a.client.Reply(r, false, nil)
if err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.logger.DebugContext(ctx, "Failed to reply to global request", "request_type", r.Type, "error", err)
continue
}
}
Expand Down Expand Up @@ -555,10 +560,10 @@ func (a *agent) handleDrainChannels() error {
bytes, _ := a.clock.Now().UTC().MarshalText()
_, err := a.hbChannel.SendRequest(a.ctx, "ping", false, bytes)
if err != nil {
a.log.Error(err)
a.logger.ErrorContext(a.ctx, "failed to send ping request", "error", err)
return trace.Wrap(err)
}
a.log.Debugf("Ping -> %v.", a.client.RemoteAddr())
a.logger.DebugContext(a.ctx, "Sent ping request", "target_addr", logutils.StringerAttr(a.client.RemoteAddr()))
// Handle transport requests.
case nch := <-a.transportC:
if nch == nil {
Expand All @@ -567,15 +572,15 @@ func (a *agent) handleDrainChannels() error {
if a.isDraining() {
err := nch.Reject(ssh.ConnectionFailed, "agent connection is draining")
if err != nil {
a.log.WithError(err).Warningf("Failed to reject transport channel.")
a.logger.WarnContext(a.ctx, "Failed to reject transport channel", "error", err)
}
continue
}

a.log.Debugf("Transport request: %v.", nch.ChannelType())
a.logger.DebugContext(a.ctx, "Received trransport request", "channel_type", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Warningf("Failed to accept transport request: %v.", err)
a.logger.WarnContext(a.ctx, "Failed to accept transport request", "error", err)
continue
}

Expand All @@ -601,10 +606,10 @@ func (a *agent) handleChannels() error {
if nch == nil {
continue
}
a.log.Debugf("Discovery request channel opened: %v.", nch.ChannelType())
a.logger.DebugContext(a.ctx, "Discovery request channel opened", "channel_type", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Warningf("Failed to accept discovery channel request: %v.", err)
a.logger.WarnContext(a.ctx, "Failed to accept discovery channel request", "error", err)
continue
}

Expand All @@ -624,11 +629,11 @@ func (a *agent) handleChannels() error {
// ch : SSH channel which received "teleport-transport" out-of-band request
// reqC : request payload
func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) {
a.log.Debugf("handleDiscovery requests channel.")
a.logger.DebugContext(a.ctx, "handleDiscovery requests channel")
sshutils.DiscardChannelData(ch)
defer func() {
if err := ch.Close(); err != nil {
a.log.Warnf("Failed to close discovery channel: %v", err)
a.logger.WarnContext(a.ctx, "Failed to close discovery channel", "error", err)
}
}()

Expand All @@ -639,17 +644,17 @@ func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) {
return
case req = <-reqC:
if req == nil {
a.log.Infof("Connection closed, returning")
a.logger.InfoContext(a.ctx, "Connection closed, returning")
return
}

var r discoveryRequest
if err := json.Unmarshal(req.Payload, &r); err != nil {
a.log.WithError(err).Warn("Bad payload")
a.logger.WarnContext(a.ctx, "Received discovery request with bad payload", "error", err)
return
}

a.log.Debugf("Received discovery request: %s", &r)
a.logger.DebugContext(a.ctx, "Received discovery request", "discovery_request", logutils.StringerAttr(&r))
a.tracker.TrackExpected(r.TrackProxies()...)
}
}
Expand Down
10 changes: 5 additions & 5 deletions lib/reversetunnel/agent_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ package reversetunnel

import (
"context"
"log/slog"
"strings"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

apidefaults "github.com/gravitational/teleport/api/defaults"
Expand Down Expand Up @@ -55,7 +55,7 @@ type agentDialer struct {
authMethods []ssh.AuthMethod
fips bool
options []proxy.DialerOptionFunc
log logrus.FieldLogger
logger *slog.Logger
isClaimed func(principals ...string) bool
}

Expand All @@ -65,7 +65,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC
dialer := proxy.DialerFromEnvironment(addr.Addr, d.options...)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultIOTimeout)
if err != nil {
d.log.WithError(err).Debugf("Failed to dial %s.", addr.Addr)
d.logger.DebugContext(ctx, "Failed to dial", "error", err, "target_addr", addr.Addr)
return nil, trace.Wrap(err)
}

Expand All @@ -75,7 +75,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC
GetHostCheckers: d.hostCheckerFunc(ctx),
OnCheckCert: func(c *ssh.Certificate) error {
if d.isClaimed != nil && d.isClaimed(c.ValidPrincipals...) {
d.log.Debugf("Aborting SSH handshake because the proxy %q is already claimed by some other agent.", c.ValidPrincipals[0])
d.logger.DebugContext(ctx, "Aborting SSH handshake because the proxy is already claimed by some other agent.", "proxy_id", c.ValidPrincipals[0])
// the error message must end with
// [proxyAlreadyClaimedError] to be recognized by
// [isProxyAlreadyClaimed]
Expand All @@ -88,7 +88,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC
FIPS: d.fips,
})
if err != nil {
d.log.Debugf("Failed to create host key callback for %v: %v.", addr.Addr, err)
d.logger.DebugContext(ctx, "Failed to create host key callback", "target_addr", addr.Addr, "error", err)
return nil, trace.Wrap(err)
}

Expand Down
3 changes: 1 addition & 2 deletions lib/reversetunnel/agent_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"testing"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

Expand Down Expand Up @@ -90,7 +89,7 @@ func TestAgentCertChecker(t *testing.T) {
dialer := agentDialer{
client: &fakeClient{caKey: ca.PublicKey()},
authMethods: []ssh.AuthMethod{ssh.PublicKeys(signer)},
log: logrus.New(),
logger: utils.NewSlogLoggerForTests(),
}

_, err = dialer.DialContext(context.Background(), *utils.MustParseAddr(sshServer.Addr()))
Expand Down
Loading
Loading