Skip to content

Commit

Permalink
Convert lib/srv/regular to use slog
Browse files Browse the repository at this point in the history
This starts the transition of our SSH servers to use slog. The
regular.Server has been converted to use slog, though there
are still various logs emitted via the srv.ServerContext from the
server which use logrus.
  • Loading branch information
rosstimothy committed Dec 5, 2024
1 parent 0fef894 commit 2d814c6
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 218 deletions.
32 changes: 21 additions & 11 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net"
"os"
"strconv"
Expand Down Expand Up @@ -406,14 +407,14 @@ type ServerContext struct {
// the ServerContext is closed. The ctx parameter should be a child of the ctx
// associated with the scope of the parent ConnectionContext to ensure that
// cancellation of the ConnectionContext propagates to the ServerContext.
func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, monitorOpts ...func(*MonitorConfig)) (context.Context, *ServerContext, error) {
func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, monitorOpts ...func(*MonitorConfig)) (*ServerContext, error) {
netConfig, err := srv.GetAccessPoint().GetClusterNetworkingConfig(ctx)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}
recConfig, err := srv.GetAccessPoint().GetSessionRecordingConfig(ctx)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}

cancelContext, cancel := context.WithCancel(ctx)
Expand Down Expand Up @@ -452,7 +453,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
authPref, err := srv.GetAccessPoint().GetAuthPreference(ctx)
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}

child.disconnectExpiredCert = getDisconnectExpiredCertFromIdentityContext(
Expand All @@ -474,7 +475,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
clusterName, err := srv.GetAccessPoint().GetClusterName()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}

monitorConfig := MonitorConfig{
Expand All @@ -500,14 +501,14 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
err = StartMonitor(monitorConfig)
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}

// Create pipe used to send command to child process.
child.cmdr, child.cmdw, err = os.Pipe()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}
child.AddCloser(child.cmdr)
child.AddCloser(child.cmdw)
Expand All @@ -516,7 +517,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
child.contr, child.contw, err = os.Pipe()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}
child.AddCloser(child.contr)
child.AddCloser(child.contw)
Expand All @@ -525,20 +526,20 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
child.readyr, child.readyw, err = os.Pipe()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}
child.AddCloser(child.readyr)
child.AddCloser(child.readyw)

child.killShellr, child.killShellw, err = os.Pipe()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}
child.AddCloser(child.killShellr)
child.AddCloser(child.killShellw)

return ctx, child, nil
return child, nil
}

// Parent grants access to the connection-level context of which this
Expand Down Expand Up @@ -945,6 +946,15 @@ func (c *ServerContext) String() string {
return fmt.Sprintf("ServerContext(%v->%v, user=%v, id=%v)", c.ServerConn.RemoteAddr(), c.ServerConn.LocalAddr(), c.ServerConn.User(), c.id)
}

func (c *ServerContext) LogValue() slog.Value {
return slog.GroupValue(
slog.String("remote_addr", c.ServerConn.RemoteAddr().String()),
slog.String("local_addr", c.ServerConn.LocalAddr().String()),
slog.String("user", c.ServerConn.User()),
slog.Int("id", c.id),
)
}

func getPAMConfig(c *ServerContext) (*PAMConfig, error) {
// PAM should be disabled.
if c.srv.Component() != teleport.ComponentNode {
Expand Down
6 changes: 3 additions & 3 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha

// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
if err != nil {
if err := nch.Reject(ssh.ConnectionFailed, "failed to open server context"); err != nil {
s.log.WithError(err).Errorf("Error rejecting %s channel.", teleport.ChanForwardedTCPIP)
Expand Down Expand Up @@ -1070,7 +1070,7 @@ func (s *Server) handleChannel(ctx context.Context, nch ssh.NewChannel) {
func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
if err != nil {
s.log.Errorf("Unable to create connection context: %v.", err)
s.stderrWrite(ch, "Unable to create connection context.")
Expand Down Expand Up @@ -1121,7 +1121,7 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) {
// There is no need for the forwarding server to initiate disconnects,
// based on teleport business logic, because this logic is already
// done on the server's terminating side.
ctx, scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
if err != nil {
s.log.Warnf("Server context setup failed: %v", err)
if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("server context setup failed: %v", err)); err != nil {
Expand Down
6 changes: 5 additions & 1 deletion lib/srv/regular/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
package regular

import (
"context"
"log/slog"
"testing"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/srv"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

func FuzzParseProxySubsys(f *testing.F) {
Expand All @@ -40,12 +43,13 @@ func FuzzParseProxySubsys(f *testing.F) {
server := &Server{
hostname: "redhorse",
proxyMode: true,
logger: slog.New(logutils.DiscardHandler{}),
}

ctx := &srv.ServerContext{}

require.NotPanics(t, func() {
parseProxySubsys(request, server, ctx)
server.parseProxySubsys(context.Background(), request, ctx)
})
})
}
65 changes: 30 additions & 35 deletions lib/srv/regular/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ package regular
import (
"context"
"fmt"
"log/slog"
"net"
"strings"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -53,7 +53,7 @@ type proxySubsys struct {
proxySubsysRequest
router *proxy.Router
ctx *srv.ServerContext
log *logrus.Entry
logger *slog.Logger
closeC chan error
proxySigner PROXYHeaderSigner
localCluster string
Expand All @@ -68,8 +68,8 @@ type proxySubsys struct {
// "proxy:@clustername" - Teleport request to connect to an auth server for cluster with name 'clustername'
// "proxy:host:22@clustername" - Teleport request to connect to host:22 on cluster 'clustername'
// "proxy:host:22@namespace@clustername"
func parseProxySubsysRequest(request string) (proxySubsysRequest, error) {
log.Debugf("parse_proxy_subsys(%q)", request)
func (s *Server) parseProxySubsysRequest(ctx context.Context, request string) (proxySubsysRequest, error) {
s.logger.DebugContext(ctx, "parsing proxy subsystem request", "request", request)
var (
clusterName string
targetHost string
Expand Down Expand Up @@ -124,12 +124,12 @@ func parseProxySubsysRequest(request string) (proxySubsysRequest, error) {

// parseProxySubsys decodes a proxy subsystem request and sets up a proxy subsystem instance.
// See parseProxySubsysRequest for details on the request format.
func parseProxySubsys(request string, srv *Server, ctx *srv.ServerContext) (*proxySubsys, error) {
req, err := parseProxySubsysRequest(request)
func (s *Server) parseProxySubsys(ctx context.Context, request string, serverContext *srv.ServerContext) (*proxySubsys, error) {
req, err := s.parseProxySubsysRequest(ctx, request)
if err != nil {
return nil, trace.Wrap(err)
}
subsys, err := newProxySubsys(ctx, srv, req)
subsys, err := newProxySubsys(ctx, serverContext, s, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -163,16 +163,17 @@ func (p *proxySubsysRequest) SetDefaults() {
// newProxySubsys is a helper that creates a proxy subsystem from
// a port forwarding request, used to implement ProxyJump feature in proxy
// and reuse the code
func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) (*proxySubsys, error) {
func newProxySubsys(ctx context.Context, serverContext *srv.ServerContext, srv *Server, req proxySubsysRequest) (*proxySubsys, error) {
req.SetDefaults()
if req.clusterName == "" && ctx.Identity.RouteToCluster != "" {
log.Debugf("Proxy subsystem: routing user %q to cluster %q based on the route to cluster extension.",
ctx.Identity.TeleportUser, ctx.Identity.RouteToCluster,
if req.clusterName == "" && serverContext.Identity.RouteToCluster != "" {
srv.logger.DebugContext(ctx, "Proxy subsystem: routing user to cluster based on the route to cluster extension",
"user", serverContext.Identity.TeleportUser,
"cluster", serverContext.Identity.RouteToCluster,
)
req.clusterName = ctx.Identity.RouteToCluster
req.clusterName = serverContext.Identity.RouteToCluster
}
if req.clusterName != "" && srv.proxyTun != nil {
checker, err := srv.tunnelWithAccessChecker(ctx)
checker, err := srv.tunnelWithAccessChecker(serverContext)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -181,18 +182,15 @@ func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest)
return nil, trace.BadParameter("invalid format for proxy request: unknown cluster %q", req.clusterName)
}
}
log.Debugf("newProxySubsys(%v).", req)
srv.logger.DebugContext(ctx, "successfully created proxy subsystem request", "request", &req)
return &proxySubsys{
proxySubsysRequest: req,
ctx: ctx,
log: logrus.WithFields(logrus.Fields{
teleport.ComponentKey: teleport.ComponentSubsystemProxy,
teleport.ComponentFields: map[string]string{},
}),
closeC: make(chan error),
router: srv.router,
proxySigner: srv.proxySigner,
localCluster: ctx.ClusterName,
ctx: serverContext,
logger: slog.With(teleport.ComponentKey, teleport.ComponentSubsystemProxy),
closeC: make(chan error),
router: srv.router,
proxySigner: srv.proxySigner,
localCluster: serverContext.ClusterName,
}, nil
}

Expand All @@ -205,15 +203,12 @@ func (t *proxySubsys) String() string {
// a mapping connection between a client & remote node we're proxying to)
func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.Channel, req *ssh.Request, serverContext *srv.ServerContext) error {
// once we start the connection, update logger to include component fields
t.log = logrus.WithFields(logrus.Fields{
teleport.ComponentKey: teleport.ComponentSubsystemProxy,
teleport.ComponentFields: map[string]string{
"src": sconn.RemoteAddr().String(),
"dst": sconn.LocalAddr().String(),
"subsystem": t.String(),
},
})
t.log.Debugf("Starting subsystem")
t.logger = t.logger.With(
"src", sconn.RemoteAddr().String(),
"dst", sconn.LocalAddr().String(),
"subsystem", t.String(),
)
t.logger.DebugContext(ctx, "Starting subsystem")

clientAddr := sconn.RemoteAddr()

Expand All @@ -229,13 +224,13 @@ func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.C
// proxyToSite establishes a proxy connection from the connected SSH client to the
// auth server of the requested remote site
func (t *proxySubsys) proxyToSite(ctx context.Context, ch ssh.Channel, clusterName string, clientSrcAddr, clientDstAddr net.Addr) error {
t.log.Debugf("Connecting from cluster %q to site: %q", t.localCluster, clusterName)
t.logger.DebugContext(ctx, "attempting to proxy connection to auth server", "local_cluster", t.localCluster, "proxied_cluster", clusterName)

conn, err := t.router.DialSite(ctx, clusterName, clientSrcAddr, clientDstAddr)
if err != nil {
return trace.Wrap(err)
}
t.log.Infof("Connected to cluster %v at %v", clusterName, conn.RemoteAddr())
t.logger.InfoContext(ctx, "Connected to cluster", "cluster", clusterName, "address", conn.RemoteAddr())

go func() {
t.close(utils.ProxyConn(ctx, ch, conn))
Expand All @@ -246,7 +241,7 @@ func (t *proxySubsys) proxyToSite(ctx context.Context, ch ssh.Channel, clusterNa
// proxyToHost establishes a proxy connection from the connected SSH client to the
// requested remote node (t.host:t.port) via the given site
func (t *proxySubsys) proxyToHost(ctx context.Context, ch ssh.Channel, clientSrcAddr, clientDstAddr net.Addr) error {
t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v", t.host, t.port, t.SpecifiedPort())
t.logger.DebugContext(ctx, "proxying connection to target host", "host", t.host, "port", t.port, "exact_port", t.SpecifiedPort())

authClient, err := t.router.GetSiteClient(ctx, t.localCluster)
if err != nil {
Expand Down
14 changes: 12 additions & 2 deletions lib/srv/regular/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
package regular

import (
"context"
"log/slog"
"testing"

"github.com/stretchr/testify/require"

apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/lib/srv"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

func TestParseProxyRequest(t *testing.T) {
Expand Down Expand Up @@ -86,6 +89,12 @@ func TestParseProxyRequest(t *testing.T) {
},
}

server := &Server{
hostname: "redhorse",
proxyMode: true,
logger: slog.New(logutils.DiscardHandler{}),
}

for i, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
if tt.expected.namespace == "" {
Expand All @@ -94,7 +103,7 @@ func TestParseProxyRequest(t *testing.T) {
// never actually be empty.
tt.expected.namespace = apidefaults.Namespace
}
req, err := parseProxySubsysRequest(tt.req)
req, err := server.parseProxySubsysRequest(context.Background(), tt.req)
require.NoError(t, err, "Test case %d: req=%s, expected=%+v", i, tt.req, tt.expected)
require.Equal(t, tt.expected, req, "Test case %d: req=%s, expected=%+v", i, tt.req, tt.expected)
})
Expand All @@ -107,6 +116,7 @@ func TestParseBadRequests(t *testing.T) {
server := &Server{
hostname: "redhorse",
proxyMode: true,
logger: slog.New(logutils.DiscardHandler{}),
}

ctx := &srv.ServerContext{}
Expand All @@ -122,7 +132,7 @@ func TestParseBadRequests(t *testing.T) {
}
for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
subsystem, err := parseProxySubsys(tt.input, server, ctx)
subsystem, err := server.parseProxySubsys(context.Background(), tt.input, ctx)
require.Error(t, err, "test case: %q", tt.input)
require.Nil(t, subsystem, "test case: %q", tt.input)
})
Expand Down
Loading

0 comments on commit 2d814c6

Please sign in to comment.