From b897de11d8905995441b38b64a66d5e5ec5a31d6 Mon Sep 17 00:00:00 2001 From: Erik Tate Date: Wed, 4 Dec 2024 09:04:15 -0500 Subject: [PATCH] addressing some PR feedback and adding regular/sshserver tests --- lib/auth/auth.go | 1 - lib/services/access_checker_test.go | 10 +-- lib/services/presets.go | 28 ++++-- lib/services/role.go | 17 ++-- lib/services/role_test.go | 13 +++ lib/srv/authhandlers.go | 8 +- lib/srv/forward/sshserver.go | 24 +++-- lib/srv/forward/sshserver_test.go | 26 +++++- lib/srv/regular/sshserver.go | 2 +- lib/srv/regular/sshserver_test.go | 131 +++++++++++++++++++++++----- 10 files changed, 206 insertions(+), 54 deletions(-) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index a520709e9447e..e43b1a986b8da 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -3205,7 +3205,6 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. Roles: req.checker.RoleNames(), CertificateFormat: certificateFormat, PermitPortForwarding: req.checker.CanPortForward(), - SSHPortForwardMode: req.checker.SSHPortForwardMode(), PermitAgentForwarding: req.checker.CanForwardAgents(), PermitX11Forwarding: req.checker.PermitX11Forwarding(), RouteToCluster: req.routeToCluster, diff --git a/lib/services/access_checker_test.go b/lib/services/access_checker_test.go index 9362e2c97700f..86b49d07d5cb0 100644 --- a/lib/services/access_checker_test.go +++ b/lib/services/access_checker_test.go @@ -593,7 +593,7 @@ func TestSSHPortForwarding(t *testing.T) { }) allow := newRole(func(rv *types.RoleV6) { - rv.SetName("all-allow") + rv.SetName("allow") rv.SetOptions(types.RoleOptions{ SSHPortForwarding: &types.SSHPortForwarding{ Remote: &types.SSHRemotePortForwarding{Enabled: types.NewBoolOption(true)}, @@ -604,7 +604,7 @@ func TestSSHPortForwarding(t *testing.T) { }) deny := newRole(func(rv *types.RoleV6) { - rv.SetName("all-deny") + rv.SetName("deny") rv.SetOptions(types.RoleOptions{ SSHPortForwarding: &types.SSHPortForwarding{ Remote: &types.SSHRemotePortForwarding{Enabled: types.NewBoolOption(false)}, @@ -623,7 +623,7 @@ func TestSSHPortForwarding(t *testing.T) { }) legacyDeny := newRole(func(rv *types.RoleV6) { - rv.SetName("legacy-allow") + rv.SetName("legacy-deny") rv.SetOptions(types.RoleOptions{ PortForwarding: types.NewBoolOption(false), }) @@ -631,7 +631,7 @@ func TestSSHPortForwarding(t *testing.T) { }) remoteAllow := newRole(func(rv *types.RoleV6) { - rv.SetName("remote-deny") + rv.SetName("remote-allow") rv.SetOptions(types.RoleOptions{ SSHPortForwarding: &types.SSHPortForwarding{ Remote: &types.SSHRemotePortForwarding{Enabled: types.NewBoolOption(true)}, @@ -651,7 +651,7 @@ func TestSSHPortForwarding(t *testing.T) { }) localAllow := newRole(func(rv *types.RoleV6) { - rv.SetName("local-deny") + rv.SetName("local-allow") rv.SetOptions(types.RoleOptions{ SSHPortForwarding: &types.SSHPortForwarding{ Local: &types.SSHLocalPortForwarding{Enabled: types.NewBoolOption(true)}, diff --git a/lib/services/presets.go b/lib/services/presets.go index 2601730179647..4cf83c21ed3e2 100644 --- a/lib/services/presets.go +++ b/lib/services/presets.go @@ -117,9 +117,16 @@ func NewPresetEditorRole() types.Role { Options: types.RoleOptions{ CertificateFormat: constants.CertificateFormatStandard, MaxSessionTTL: types.NewDuration(apidefaults.MaxCertDuration), - PortForwarding: types.NewBoolOption(true), - ForwardAgent: types.NewBool(true), - BPF: apidefaults.EnhancedEvents(), + SSHPortForwarding: &types.SSHPortForwarding{ + Remote: &types.SSHRemotePortForwarding{ + Enabled: types.NewBoolOption(true), + }, + Local: &types.SSHLocalPortForwarding{ + Enabled: types.NewBoolOption(true), + }, + }, + ForwardAgent: types.NewBool(true), + BPF: apidefaults.EnhancedEvents(), RecordSession: &types.RecordSession{ Desktop: types.NewBoolOption(false), }, @@ -208,10 +215,17 @@ func NewPresetAccessRole() types.Role { Options: types.RoleOptions{ CertificateFormat: constants.CertificateFormatStandard, MaxSessionTTL: types.NewDuration(apidefaults.MaxCertDuration), - PortForwarding: types.NewBoolOption(true), - ForwardAgent: types.NewBool(true), - BPF: apidefaults.EnhancedEvents(), - RecordSession: &types.RecordSession{Desktop: types.NewBoolOption(true)}, + SSHPortForwarding: &types.SSHPortForwarding{ + Remote: &types.SSHRemotePortForwarding{ + Enabled: types.NewBoolOption(true), + }, + Local: &types.SSHLocalPortForwarding{ + Enabled: types.NewBoolOption(true), + }, + }, + ForwardAgent: types.NewBool(true), + BPF: apidefaults.EnhancedEvents(), + RecordSession: &types.RecordSession{Desktop: types.NewBoolOption(true)}, }, Allow: types.RoleConditions{ Namespaces: []string{apidefaults.Namespace}, diff --git a/lib/services/role.go b/lib/services/role.go index 3b36f121d1329..607d34133ae75 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -248,13 +248,17 @@ func withWarningReporter(f func(error)) validateRoleOption { } } -// ValidateRole parses validates the role, and sets default values. +// ValidateRole parses, validates, and sets default values on a role. func ValidateRole(r types.Role, opts ...validateRoleOption) error { options := defaultValidateRoleOptions() for _, opt := range opts { opt(&options) } + if r.GetOptions().SSHPortForwarding != nil && r.GetOptions().PortForwarding != nil { + return trace.BadParameter("options define both 'port_forwarding' and 'ssh_port_forwarding', only one can be set") + } + if err := CheckAndSetDefaults(r); err != nil { return trace.Wrap(err) } @@ -2825,6 +2829,7 @@ func (set RoleSet) CanForwardAgents() bool { return false } +// SSHPortForwardMode enumerates the possible SSH port forwarding modes available at a given time. type SSHPortForwardMode int const ( @@ -2863,14 +2868,11 @@ func (set RoleSet) SSHPortForwardMode() SSHPortForwardMode { config := role.GetOptions().SSHPortForwarding // only consider legacy allows when config isn't provided on the same role if config == nil { - // TODO (eriktate): remove legacy check in v20 - //nolint:staticcheck // this field is preserved for existing deployments, but shouldn't be used going forward - legacy := role.GetOptions().PortForwarding - if legacy != nil { + //nolint:staticcheck // this field is preserved for backwards compatibility, but shouldn't be used going forward + if legacy := role.GetOptions().PortForwarding; legacy != nil { if legacy.Value { return SSHPortForwardModeOn } - legacyDeny = true } @@ -2880,7 +2882,6 @@ func (set RoleSet) SSHPortForwardMode() SSHPortForwardMode { if config.Remote != nil && config.Remote.Enabled != nil { if !config.Remote.Enabled.Value { denyRemote = true - } // an explicit legacy deny is only possible if no explicit SSHPortForwarding config has been provided @@ -2912,7 +2913,7 @@ func (set RoleSet) SSHPortForwardMode() SSHPortForwardMode { } } -// CanPortForward returns true if a role in the RoleSet allows port forwarding. +// CanPortForward returns true if the RoleSet allows both local and remote port forwarding. func (set RoleSet) CanPortForward() bool { return set.SSHPortForwardMode() == SSHPortForwardModeOn } diff --git a/lib/services/role_test.go b/lib/services/role_test.go index e518c9de75b85..5aa150b3682a9 100644 --- a/lib/services/role_test.go +++ b/lib/services/role_test.go @@ -832,6 +832,19 @@ func TestValidateRole(t *testing.T) { }, }, }, + { + name: "invalid port forwarding config", + spec: types.RoleSpecV6{ + Options: types.RoleOptions{ + PortForwarding: types.NewBoolOption(true), + SSHPortForwarding: &types.SSHPortForwarding{}, + }, + Allow: types.RoleConditions{ + Logins: []string{`{{external["http://schemas.microsoft.com/ws/2008/06/identity/claims/windowsaccountname"]}}`}, + }, + }, + expectError: trace.BadParameter("options define both 'port_forwarding' and 'ssh_port_forwarding', only one can be set"), + }, { name: "invalid role condition login syntax", spec: types.RoleSpecV6{ diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 0a05f02180a55..eed3b37c0a53f 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -262,13 +262,13 @@ func (h *AuthHandlers) CheckFileCopying(ctx *ServerContext) error { } // CheckPortForward checks if port forwarding is allowed for the users RoleSet. -func (h *AuthHandlers) CheckPortForward(addr string, ctx *ServerContext, mode services.SSHPortForwardMode) error { - checkedMode := ctx.Identity.AccessChecker.SSHPortForwardMode() - if checkedMode == services.SSHPortForwardModeOn { +func (h *AuthHandlers) CheckPortForward(addr string, ctx *ServerContext, requestedMode services.SSHPortForwardMode) error { + allowedMode := ctx.Identity.AccessChecker.SSHPortForwardMode() + if allowedMode == services.SSHPortForwardModeOn { return nil } - if checkedMode == services.SSHPortForwardModeOff || checkedMode != mode { + if allowedMode == services.SSHPortForwardModeOff || allowedMode != requestedMode { systemErrorMessage := fmt.Sprintf("port forwarding not allowed by role set: %v", ctx.Identity.AccessChecker.RoleNames()) userErrorMessage := "port forwarding not allowed" diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 53282eb70c3dd..475d8a113017e 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -1017,6 +1017,18 @@ func (s *Server) checkTCPIPForwardRequest(r *ssh.Request) error { return err } + // RBAC checks should only necessary when connecting to an agentless node + if s.targetServer != nil && s.targetServer.IsOpenSSHNode() { + _, scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext) + if err != nil { + return err + } + + if err := s.authHandlers.CheckPortForward(scx.DstAddr, scx, services.SSHPortForwardModeRemote); err != nil { + return trace.Wrap(err) + } + } + return nil } @@ -1084,11 +1096,13 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r ch = scx.TrackActivity(ch) - // Check if the role allows port forwarding for this user. - err = s.authHandlers.CheckPortForward(scx.DstAddr, scx, services.SSHPortForwardModeLocal) - if err != nil { - s.stderrWrite(ch, err.Error()) - return + // RBAC checks should only necessary when connecting to an agentless node + if s.targetServer != nil && s.targetServer.IsOpenSSHNode() { + err = s.authHandlers.CheckPortForward(scx.DstAddr, scx, services.SSHPortForwardModeLocal) + if err != nil { + s.stderrWrite(ch, err.Error()) + return + } } s.log.Debugf("Opening direct-tcpip channel from %v to %v in context %v.", scx.SrcAddr, scx.DstAddr, scx.ID()) diff --git a/lib/srv/forward/sshserver_test.go b/lib/srv/forward/sshserver_test.go index 98e409e63b692..9e189bd923b90 100644 --- a/lib/srv/forward/sshserver_test.go +++ b/lib/srv/forward/sshserver_test.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport/api/utils/keys" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -190,6 +191,7 @@ func TestDirectTCPIP(t *testing.T) { cases := []struct { name string login string + accessChecker services.AccessChecker expectAccepted bool expectRejected bool }{ @@ -211,6 +213,19 @@ func TestDirectTCPIP(t *testing.T) { // which return errors on accept. expectRejected: true, }, + { + name: "port forwarding denied", + login: func() string { + u, err := user.Current() + require.NoError(t, err) + return u.Username + }(), + accessChecker: &fakePortForwardChecker{mode: services.SSHPortForwardModeOff}, + expectAccepted: false, + // expectRejected is set to true because we are using mock channel + // which return errors on accept. + expectRejected: true, + }, } for _, tt := range cases { @@ -220,7 +235,7 @@ func TestDirectTCPIP(t *testing.T) { s := Server{ log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"), - identityContext: srv.IdentityContext{Login: tt.login}, + identityContext: srv.IdentityContext{Login: tt.login, AccessChecker: tt.accessChecker}, } nch := &newChannelMock{channelType: teleport.ChanDirectTCPIP} @@ -271,5 +286,14 @@ func TestCheckTCPIPForward(t *testing.T) { } } +type fakePortForwardChecker struct { + services.AccessChecker + mode services.SSHPortForwardMode +} + +func (f *fakePortForwardChecker) SSHPortForwardMode() services.SSHPortForwardMode { + return f.mode +} + // TODO(atburke): Add test for handleForwardedTCPIPRequest once we have // infrastructure for higher-level tests here. diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 14c9ca2f2fedd..de400ac756989 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1429,7 +1429,7 @@ func (s *Server) canPortForward(scx *srv.ServerContext, mode services.SSHPortFor // Check if the role allows port forwarding for this user. err := s.authHandlers.CheckPortForward(scx.DstAddr, scx, mode) if err != nil { - return err + return trace.Wrap(err) } return nil diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index c2f5990339255..339c4ebcf654a 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -769,15 +769,43 @@ func TestLockInForce(t *testing.T) { require.NoError(t, err) } +func setPortForwarding(t *testing.T, ctx context.Context, f *sshTestFixture, legacy, remote, local *types.BoolOption) { + roleName := services.RoleNameForUser(f.user) + role, err := f.testSrv.Auth().GetRole(ctx, roleName) + require.NoError(t, err) + roleOptions := role.GetOptions() + roleOptions.PermitX11Forwarding = types.NewBool(true) + roleOptions.ForwardAgent = types.NewBool(true) + //nolint:staticcheck // this field is preserved for existing deployments, but shouldn't be used going forward + roleOptions.PortForwarding = legacy + + if remote != nil || local != nil { + roleOptions.SSHPortForwarding = &types.SSHPortForwarding{ + Remote: &types.SSHRemotePortForwarding{ + Enabled: remote, + }, + Local: &types.SSHLocalPortForwarding{ + Enabled: local, + }, + } + } + + role.SetOptions(roleOptions) + _, err = f.testSrv.Auth().UpsertRole(ctx, role) + require.NoError(t, err) +} + // TestDirectTCPIP ensures that the server can create a "direct-tcpip" // channel to the target address. The "direct-tcpip" channel is what port // forwarding is built upon. func TestDirectTCPIP(t *testing.T) { + ctx := context.Background() t.Parallel() f := newFixtureWithoutDiskBasedLogging(t) // Startup a test server that will reply with "hello, world\n" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "hello, world") })) defer ts.Close() @@ -786,26 +814,58 @@ func TestDirectTCPIP(t *testing.T) { u, err := url.Parse(ts.URL) require.NoError(t, err) - // Build a http.Client that will dial through the server to establish the - // connection. That's why a custom dialer is used and the dialer uses - // s.clt.Dial (which performs the "direct-tcpip" request). - httpClient := http.Client{ - Transport: &http.Transport{ - Dial: func(network string, addr string) (net.Conn, error) { - return f.ssh.clt.DialContext(context.Background(), "tcp", u.Host) + t.Run("Local forwarding is successful", func(t *testing.T) { + // Build a http.Client that will dial through the server to establish the + // connection. That's why a custom dialer is used and the dialer uses + // s.clt.Dial (which performs the "direct-tcpip" request). + httpClient := http.Client{ + Transport: &http.Transport{ + Dial: func(network string, addr string) (net.Conn, error) { + return f.ssh.clt.DialContext(context.Background(), "tcp", u.Host) + }, }, - }, - } + } - // Perform a HTTP GET to the test HTTP server through a "direct-tcpip" request. - resp, err := httpClient.Get(ts.URL) - require.NoError(t, err) - defer resp.Body.Close() + // Perform a HTTP GET to the test HTTP server through a "direct-tcpip" request. + resp, err := httpClient.Get(ts.URL) + require.NoError(t, err) + defer resp.Body.Close() - // Make sure the response is what was expected. - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, []byte("hello, world\n"), body) + // Make sure the response is what was expected. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, []byte("hello, world\n"), body) + }) + + t.Run("Local forwarding fails when access is denied", func(t *testing.T) { + httpClient := http.Client{ + Transport: &http.Transport{ + Dial: func(network string, addr string) (net.Conn, error) { + return f.ssh.clt.DialContext(context.Background(), "tcp", u.Host) + }, + }, + } + + setPortForwarding(t, ctx, f, nil, nil, types.NewBoolOption(false)) + // Perform a HTTP GET to the test HTTP server through a "direct-tcpip" request. + _, err := httpClient.Get(ts.URL) + require.Error(t, err) + }) + + t.Run("Local forwarding fails when access is denied by legacy config", func(t *testing.T) { + httpClient := http.Client{ + Transport: &http.Transport{ + Dial: func(network string, addr string) (net.Conn, error) { + return f.ssh.clt.DialContext(context.Background(), "tcp", u.Host) + }, + }, + } + + setPortForwarding(t, ctx, f, types.NewBoolOption(false), nil, nil) + // Perform a HTTP GET to the test HTTP server through a "direct-tcpip" request. + _, err := httpClient.Get(ts.URL) + require.Error(t, err) + }) t.Run("SessionJoinPrincipal cannot use direct-tcpip", func(t *testing.T) { // Ensure that ssh client using SessionJoinPrincipal as Login, cannot @@ -832,8 +892,12 @@ func TestTCPIPForward(t *testing.T) { hostname, err := os.Hostname() require.NoError(t, err) tests := []struct { - name string - listenAddr string + name string + listenAddr string + legacyAllow *types.BoolOption + remoteAllow *types.BoolOption + localAllow *types.BoolOption + expectErr bool }{ { name: "localhost", @@ -847,14 +911,37 @@ func TestTCPIPForward(t *testing.T) { name: "hostname", listenAddr: hostname + ":0", }, + { + name: "remote deny", + listenAddr: "localhost:0", + remoteAllow: types.NewBoolOption(false), + expectErr: true, + }, + { + name: "legacy deny", + listenAddr: "localhost:0", + legacyAllow: types.NewBoolOption(false), + expectErr: true, + }, + { + name: "local deny", + listenAddr: "localhost:0", + localAllow: types.NewBoolOption(false), + expectErr: false, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { f := newFixtureWithoutDiskBasedLogging(t) - + setPortForwarding(t, context.Background(), f, tc.legacyAllow, tc.remoteAllow, tc.localAllow) // Request a listener from the server. listener, err := f.ssh.clt.Listen("tcp", tc.listenAddr) - require.NoError(t, err) + if tc.expectErr { + require.Error(t, err) + return + } else { + require.NoError(t, err) + } // Start up a test server that uses the port forwarded listener. ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -870,7 +957,7 @@ func TestTCPIPForward(t *testing.T) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) require.NoError(t, err) resp, err := ts.Client().Do(req) - require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, resp.Body.Close()) })