Convert lib/srv/regular to use slog #49792

Merged · 1 commit · Dec 6, 2024
31 changes: 20 additions & 11 deletions lib/srv/ctx.go
@@ -407,14 +407,14 @@ type ServerContext struct {
// the ServerContext is closed. The ctx parameter should be a child of the ctx
// associated with the scope of the parent ConnectionContext to ensure that
// cancellation of the ConnectionContext propagates to the ServerContext.
func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, monitorOpts ...func(*MonitorConfig)) (context.Context, *ServerContext, error) {
func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, monitorOpts ...func(*MonitorConfig)) (*ServerContext, error) {
netConfig, err := srv.GetAccessPoint().GetClusterNetworkingConfig(ctx)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}
recConfig, err := srv.GetAccessPoint().GetSessionRecordingConfig(ctx)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, trace.Wrap(err)
}

cancelContext, cancel := context.WithCancel(ctx)
@@ -453,7 +453,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
authPref, err := srv.GetAccessPoint().GetAuthPreference(ctx)
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}

child.disconnectExpiredCert = getDisconnectExpiredCertFromIdentityContext(
@@ -475,7 +475,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
clusterName, err := srv.GetAccessPoint().GetClusterName()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}

monitorConfig := MonitorConfig{
@@ -503,14 +503,14 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
err = StartMonitor(monitorConfig)
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}

// Create pipe used to send command to child process.
child.cmdr, child.cmdw, err = os.Pipe()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}
child.AddCloser(child.cmdr)
child.AddCloser(child.cmdw)
@@ -519,7 +519,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
child.contr, child.contw, err = os.Pipe()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}
child.AddCloser(child.contr)
child.AddCloser(child.contw)
@@ -528,20 +528,20 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
child.readyr, child.readyw, err = os.Pipe()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}
child.AddCloser(child.readyr)
child.AddCloser(child.readyw)

child.killShellr, child.killShellw, err = os.Pipe()
if err != nil {
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
return nil, trace.NewAggregate(err, childErr)
}
child.AddCloser(child.killShellr)
child.AddCloser(child.killShellw)

return ctx, child, nil
return child, nil
}

// Parent grants access to the connection-level context of which this
@@ -948,6 +948,15 @@ func (c *ServerContext) String() string {
return fmt.Sprintf("ServerContext(%v->%v, user=%v, id=%v)", c.ServerConn.RemoteAddr(), c.ServerConn.LocalAddr(), c.ServerConn.User(), c.id)
}

func (c *ServerContext) LogValue() slog.Value {
return slog.GroupValue(
slog.String("remote_addr", c.ServerConn.RemoteAddr().String()),
slog.String("local_addr", c.ServerConn.LocalAddr().String()),
slog.String("user", c.ServerConn.User()),
slog.Int("id", c.id),
)
}

func getPAMConfig(c *ServerContext) (*PAMConfig, error) {
// PAM should be disabled.
if c.srv.Component() != teleport.ComponentNode {
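The new LogValue method makes ServerContext implement the slog.LogValuer interface, so a *ServerContext can be passed directly as a single log attribute and handlers expand it lazily into the grouped fields above. A minimal, self-contained sketch of that mechanism; connInfo is a stand-in for ServerContext, not Teleport code:

package main

import (
	"log/slog"
	"os"
)

// connInfo mirrors the shape of ServerContext.LogValue: implementing
// slog.LogValuer lets handlers expand the value into grouped attributes,
// and the work is deferred until a record is actually emitted.
type connInfo struct {
	remoteAddr, localAddr, user string
	id                          int
}

func (c connInfo) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("remote_addr", c.remoteAddr),
		slog.String("local_addr", c.localAddr),
		slog.String("user", c.user),
		slog.Int("id", c.id),
	)
}

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	// Emits ctx.remote_addr=... ctx.local_addr=... ctx.user=alice ctx.id=7
	logger.Info("session started", "ctx", connInfo{"10.0.0.1:55414", "10.0.0.2:3022", "alice", 7})
}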
6 changes: 3 additions & 3 deletions lib/srv/forward/sshserver.go
@@ -909,7 +909,7 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha

// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
if err != nil {
if err := nch.Reject(ssh.ConnectionFailed, "failed to open server context"); err != nil {
s.log.WithError(err).Errorf("Error rejecting %s channel.", teleport.ChanForwardedTCPIP)
@@ -1070,7 +1070,7 @@ func (s *Server) handleChannel(ctx context.Context, nch ssh.NewChannel) {
func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
if err != nil {
s.log.Errorf("Unable to create connection context: %v.", err)
s.stderrWrite(ch, "Unable to create connection context.")
@@ -1121,7 +1121,7 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) {
// There is no need for the forwarding server to initiate disconnects,
// based on teleport business logic, because this logic is already
// done on the server's terminating side.
ctx, scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext)
if err != nil {
s.log.Warnf("Server context setup failed: %v", err)
if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("server context setup failed: %v", err)); err != nil {
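Note that the forward server's own log calls (s.log.WithError(...).Errorf(...)) remain logrus in these hunks; this PR converts lib/srv/regular, so only the NewServerContext call sites change here. For reference, logrus's WithError chaining maps onto slog as a plain error attribute on the call itself. A runnable sketch; the channel name string is illustrative, not Teleport's constant:

package main

import (
	"context"
	"errors"
	"log/slog"
	"os"
)

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	err := errors.New("connection refused")
	// logrus: log.WithError(err).Errorf("Error rejecting %s channel.", name)
	// slog:   no WithError helper; pass the error as a key-value pair.
	logger.ErrorContext(context.Background(), "Error rejecting channel",
		"channel", "forwarded-tcpip",
		"error", err,
	)
}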
6 changes: 5 additions & 1 deletion lib/srv/regular/fuzz_test.go
@@ -19,11 +19,14 @@
package regular

import (
"context"
"log/slog"
"testing"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/srv"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

func FuzzParseProxySubsys(f *testing.F) {
@@ -40,12 +43,13 @@ func FuzzParseProxySubsys(f *testing.F) {
server := &Server{
hostname: "redhorse",
proxyMode: true,
logger: slog.New(logutils.DiscardHandler{}),
}

ctx := &srv.ServerContext{}

require.NotPanics(t, func() {
parseProxySubsys(request, server, ctx)
server.parseProxySubsys(context.Background(), request, ctx)
})
})
}
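The fuzz target now builds the Server with a non-nil logger backed by logutils.DiscardHandler, which drops every record so fuzz iterations spend no time formatting output. A handler like that is tiny; a sketch of what such a type looks like (Teleport's actual implementation may differ, and the standard library later gained an equivalent slog.DiscardHandler in Go 1.24):

package logutils

import (
	"context"
	"log/slog"
)

// DiscardHandler drops every record: Enabled reports false, so slog
// skips attribute formatting entirely before Handle is ever called.
type DiscardHandler struct{}

func (DiscardHandler) Enabled(context.Context, slog.Level) bool  { return false }
func (DiscardHandler) Handle(context.Context, slog.Record) error { return nil }
func (DiscardHandler) WithAttrs([]slog.Attr) slog.Handler        { return DiscardHandler{} }
func (DiscardHandler) WithGroup(string) slog.Handler             { return DiscardHandler{} }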
65 changes: 30 additions & 35 deletions lib/srv/regular/proxy.go
@@ -21,11 +21,11 @@ package regular
import (
"context"
"fmt"
"log/slog"
"net"
"strings"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
@@ -53,7 +53,7 @@ type proxySubsys struct {
proxySubsysRequest
router *proxy.Router
ctx *srv.ServerContext
log *logrus.Entry
logger *slog.Logger
closeC chan error
proxySigner PROXYHeaderSigner
localCluster string
@@ -68,8 +68,8 @@
// "proxy:@clustername" - Teleport request to connect to an auth server for cluster with name 'clustername'
// "proxy:host:22@clustername" - Teleport request to connect to host:22 on cluster 'clustername'
// "proxy:host:22@namespace@clustername"
func parseProxySubsysRequest(request string) (proxySubsysRequest, error) {
log.Debugf("parse_proxy_subsys(%q)", request)
func (s *Server) parseProxySubsysRequest(ctx context.Context, request string) (proxySubsysRequest, error) {
s.logger.DebugContext(ctx, "parsing proxy subsystem request", "request", request)
var (
clusterName string
targetHost string
@@ -124,12 +124,12 @@ func parseProxySubsysRequest(request string) (proxySubsysRequest, error) {

// parseProxySubsys decodes a proxy subsystem request and sets up a proxy subsystem instance.
// See parseProxySubsysRequest for details on the request format.
func parseProxySubsys(request string, srv *Server, ctx *srv.ServerContext) (*proxySubsys, error) {
req, err := parseProxySubsysRequest(request)
func (s *Server) parseProxySubsys(ctx context.Context, request string, serverContext *srv.ServerContext) (*proxySubsys, error) {
req, err := s.parseProxySubsysRequest(ctx, request)
if err != nil {
return nil, trace.Wrap(err)
}
subsys, err := newProxySubsys(ctx, srv, req)
subsys, err := newProxySubsys(ctx, serverContext, s, req)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -163,16 +163,17 @@ func (p *proxySubsysRequest) SetDefaults() {
// newProxySubsys is a helper that creates a proxy subsystem from
// a port forwarding request, used to implement ProxyJump feature in proxy
// and reuse the code
func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) (*proxySubsys, error) {
func newProxySubsys(ctx context.Context, serverContext *srv.ServerContext, srv *Server, req proxySubsysRequest) (*proxySubsys, error) {
req.SetDefaults()
if req.clusterName == "" && ctx.Identity.RouteToCluster != "" {
log.Debugf("Proxy subsystem: routing user %q to cluster %q based on the route to cluster extension.",
ctx.Identity.TeleportUser, ctx.Identity.RouteToCluster,
if req.clusterName == "" && serverContext.Identity.RouteToCluster != "" {
srv.logger.DebugContext(ctx, "Proxy subsystem: routing user to cluster based on the route to cluster extension",
"user", serverContext.Identity.TeleportUser,
"cluster", serverContext.Identity.RouteToCluster,
)
req.clusterName = ctx.Identity.RouteToCluster
req.clusterName = serverContext.Identity.RouteToCluster
}
if req.clusterName != "" && srv.proxyTun != nil {
checker, err := srv.tunnelWithAccessChecker(ctx)
checker, err := srv.tunnelWithAccessChecker(serverContext)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -181,18 +181,15 @@ func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest)
return nil, trace.BadParameter("invalid format for proxy request: unknown cluster %q", req.clusterName)
}
}
log.Debugf("newProxySubsys(%v).", req)
srv.logger.DebugContext(ctx, "successfully created proxy subsystem request", "request", &req)
return &proxySubsys{
proxySubsysRequest: req,
ctx: ctx,
log: logrus.WithFields(logrus.Fields{
teleport.ComponentKey: teleport.ComponentSubsystemProxy,
teleport.ComponentFields: map[string]string{},
}),
closeC: make(chan error),
router: srv.router,
proxySigner: srv.proxySigner,
localCluster: ctx.ClusterName,
ctx: serverContext,
logger: slog.With(teleport.ComponentKey, teleport.ComponentSubsystemProxy),
closeC: make(chan error),
router: srv.router,
proxySigner: srv.proxySigner,
localCluster: serverContext.ClusterName,
}, nil
}

@@ -205,15 +203,12 @@ func (t *proxySubsys) String() string {
// a mapping connection between a client & remote node we're proxying to)
func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.Channel, req *ssh.Request, serverContext *srv.ServerContext) error {
// once we start the connection, update logger to include component fields
t.log = logrus.WithFields(logrus.Fields{
teleport.ComponentKey: teleport.ComponentSubsystemProxy,
teleport.ComponentFields: map[string]string{
"src": sconn.RemoteAddr().String(),
"dst": sconn.LocalAddr().String(),
"subsystem": t.String(),
},
})
t.log.Debugf("Starting subsystem")
t.logger = t.logger.With(
"src", sconn.RemoteAddr().String(),
"dst", sconn.LocalAddr().String(),
"subsystem", t.String(),
)
t.logger.DebugContext(ctx, "Starting subsystem")

clientAddr := sconn.RemoteAddr()

@@ -229,13 +224,13 @@ func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.C
// proxyToSite establishes a proxy connection from the connected SSH client to the
// auth server of the requested remote site
func (t *proxySubsys) proxyToSite(ctx context.Context, ch ssh.Channel, clusterName string, clientSrcAddr, clientDstAddr net.Addr) error {
t.log.Debugf("Connecting from cluster %q to site: %q", t.localCluster, clusterName)
t.logger.DebugContext(ctx, "attempting to proxy connection to auth server", "local_cluster", t.localCluster, "proxied_cluster", clusterName)

conn, err := t.router.DialSite(ctx, clusterName, clientSrcAddr, clientDstAddr)
if err != nil {
return trace.Wrap(err)
}
t.log.Infof("Connected to cluster %v at %v", clusterName, conn.RemoteAddr())
t.logger.InfoContext(ctx, "Connected to cluster", "cluster", clusterName, "address", conn.RemoteAddr())

go func() {
t.close(utils.ProxyConn(ctx, ch, conn))
@@ -246,7 +241,7 @@ func (t *proxySubsys) proxyToSite(ctx context.Context, ch ssh.Channel, clusterNa
// proxyToHost establishes a proxy connection from the connected SSH client to the
// requested remote node (t.host:t.port) via the given site
func (t *proxySubsys) proxyToHost(ctx context.Context, ch ssh.Channel, clientSrcAddr, clientDstAddr net.Addr) error {
t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v", t.host, t.port, t.SpecifiedPort())
t.logger.DebugContext(ctx, "proxying connection to target host", "host", t.host, "port", t.port, "exact_port", t.SpecifiedPort())

authClient, err := t.router.GetSiteClient(ctx, t.localCluster)
if err != nil {
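The proxy.go hunks show the conversion pattern used throughout this PR: printf-style logrus calls become slog calls that take a context and structured key-value pairs, and logrus.WithFields becomes Logger.With. Side by side, restating one pair of lines from the diff above:

// Before (logrus): values interpolated into the format string, no context.
t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v", t.host, t.port, t.SpecifiedPort())

// After (slog): constant message, context threaded through, data as attributes.
t.logger.DebugContext(ctx, "proxying connection to target host",
	"host", t.host,
	"port", t.port,
	"exact_port", t.SpecifiedPort(),
)

The constant message plus attributes makes the records machine-filterable, and passing ctx lets handlers attach request-scoped data.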
14 changes: 12 additions & 2 deletions lib/srv/regular/proxy_test.go
@@ -19,12 +19,15 @@
package regular

import (
"context"
"log/slog"
"testing"

"github.com/stretchr/testify/require"

apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/lib/srv"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

func TestParseProxyRequest(t *testing.T) {
@@ -86,6 +89,12 @@ func TestParseProxyRequest(t *testing.T) {
},
}

server := &Server{
hostname: "redhorse",
proxyMode: true,
logger: slog.New(logutils.DiscardHandler{}),
Contributor commented: Should we use a test logger so that it gets the logs when we run the tests in verbose mode (-v)?

Contributor Author replied: I don't know that the logs would provide any additional/meaningful information for this test.

}

for i, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
if tt.expected.namespace == "" {
@@ -94,7 +103,7 @@ func TestParseProxyRequest(t *testing.T) {
// never actually be empty.
tt.expected.namespace = apidefaults.Namespace
}
req, err := parseProxySubsysRequest(tt.req)
req, err := server.parseProxySubsysRequest(context.Background(), tt.req)
require.NoError(t, err, "Test case %d: req=%s, expected=%+v", i, tt.req, tt.expected)
require.Equal(t, tt.expected, req, "Test case %d: req=%s, expected=%+v", i, tt.req, tt.expected)
})
@@ -107,6 +116,7 @@ func TestParseBadRequests(t *testing.T) {
server := &Server{
hostname: "redhorse",
proxyMode: true,
logger: slog.New(logutils.DiscardHandler{}),
}

ctx := &srv.ServerContext{}
@@ -122,7 +132,7 @@ func TestParseBadRequests(t *testing.T) {
}
for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
subsystem, err := parseProxySubsys(tt.input, server, ctx)
subsystem, err := server.parseProxySubsys(context.Background(), tt.input, ctx)
require.Error(t, err, "test case: %q", tt.input)
require.Nil(t, subsystem, "test case: %q", tt.input)
})
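The review thread above weighed a test logger against the discard handler. Where test log output is actually useful, slog can be routed into testing.T so records surface only under go test -v or on failure. A sketch of that approach, assuming no existing Teleport helper; testLogWriter and newTestLogger are hypothetical names:

package regular

import (
	"log/slog"
	"testing"
)

// testLogWriter adapts *testing.T to io.Writer so slog output lands in
// the test log rather than stderr.
type testLogWriter struct{ t *testing.T }

func (w testLogWriter) Write(p []byte) (int, error) {
	w.t.Log(string(p))
	return len(p), nil
}

// newTestLogger returns a logger whose records are captured per-test.
func newTestLogger(t *testing.T) *slog.Logger {
	return slog.New(slog.NewTextHandler(testLogWriter{t: t}, nil))
}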