Skip to content

Commit

Permalink
addressing some PR feedback and adding regular/sshserver tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eriktate committed Dec 4, 2024
1 parent 6c7483d commit b897de1
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 54 deletions.
1 change: 0 additions & 1 deletion lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3205,7 +3205,6 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
Roles: req.checker.RoleNames(),
CertificateFormat: certificateFormat,
PermitPortForwarding: req.checker.CanPortForward(),
SSHPortForwardMode: req.checker.SSHPortForwardMode(),
PermitAgentForwarding: req.checker.CanForwardAgents(),
PermitX11Forwarding: req.checker.PermitX11Forwarding(),
RouteToCluster: req.routeToCluster,
Expand Down
10 changes: 5 additions & 5 deletions lib/services/access_checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ func TestSSHPortForwarding(t *testing.T) {
})

allow := newRole(func(rv *types.RoleV6) {
rv.SetName("all-allow")
rv.SetName("allow")
rv.SetOptions(types.RoleOptions{
SSHPortForwarding: &types.SSHPortForwarding{
Remote: &types.SSHRemotePortForwarding{Enabled: types.NewBoolOption(true)},
Expand All @@ -604,7 +604,7 @@ func TestSSHPortForwarding(t *testing.T) {
})

deny := newRole(func(rv *types.RoleV6) {
rv.SetName("all-deny")
rv.SetName("deny")
rv.SetOptions(types.RoleOptions{
SSHPortForwarding: &types.SSHPortForwarding{
Remote: &types.SSHRemotePortForwarding{Enabled: types.NewBoolOption(false)},
Expand All @@ -623,15 +623,15 @@ func TestSSHPortForwarding(t *testing.T) {
})

legacyDeny := newRole(func(rv *types.RoleV6) {
rv.SetName("legacy-allow")
rv.SetName("legacy-deny")
rv.SetOptions(types.RoleOptions{
PortForwarding: types.NewBoolOption(false),
})
rv.SetNodeLabels(types.Allow, anyLabels)
})

remoteAllow := newRole(func(rv *types.RoleV6) {
rv.SetName("remote-deny")
rv.SetName("remote-allow")
rv.SetOptions(types.RoleOptions{
SSHPortForwarding: &types.SSHPortForwarding{
Remote: &types.SSHRemotePortForwarding{Enabled: types.NewBoolOption(true)},
Expand All @@ -651,7 +651,7 @@ func TestSSHPortForwarding(t *testing.T) {
})

localAllow := newRole(func(rv *types.RoleV6) {
rv.SetName("local-deny")
rv.SetName("local-allow")
rv.SetOptions(types.RoleOptions{
SSHPortForwarding: &types.SSHPortForwarding{
Local: &types.SSHLocalPortForwarding{Enabled: types.NewBoolOption(true)},
Expand Down
28 changes: 21 additions & 7 deletions lib/services/presets.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,16 @@ func NewPresetEditorRole() types.Role {
Options: types.RoleOptions{
CertificateFormat: constants.CertificateFormatStandard,
MaxSessionTTL: types.NewDuration(apidefaults.MaxCertDuration),
PortForwarding: types.NewBoolOption(true),
ForwardAgent: types.NewBool(true),
BPF: apidefaults.EnhancedEvents(),
SSHPortForwarding: &types.SSHPortForwarding{
Remote: &types.SSHRemotePortForwarding{
Enabled: types.NewBoolOption(true),
},
Local: &types.SSHLocalPortForwarding{
Enabled: types.NewBoolOption(true),
},
},
ForwardAgent: types.NewBool(true),
BPF: apidefaults.EnhancedEvents(),
RecordSession: &types.RecordSession{
Desktop: types.NewBoolOption(false),
},
Expand Down Expand Up @@ -208,10 +215,17 @@ func NewPresetAccessRole() types.Role {
Options: types.RoleOptions{
CertificateFormat: constants.CertificateFormatStandard,
MaxSessionTTL: types.NewDuration(apidefaults.MaxCertDuration),
PortForwarding: types.NewBoolOption(true),
ForwardAgent: types.NewBool(true),
BPF: apidefaults.EnhancedEvents(),
RecordSession: &types.RecordSession{Desktop: types.NewBoolOption(true)},
SSHPortForwarding: &types.SSHPortForwarding{
Remote: &types.SSHRemotePortForwarding{
Enabled: types.NewBoolOption(true),
},
Local: &types.SSHLocalPortForwarding{
Enabled: types.NewBoolOption(true),
},
},
ForwardAgent: types.NewBool(true),
BPF: apidefaults.EnhancedEvents(),
RecordSession: &types.RecordSession{Desktop: types.NewBoolOption(true)},
},
Allow: types.RoleConditions{
Namespaces: []string{apidefaults.Namespace},
Expand Down
17 changes: 9 additions & 8 deletions lib/services/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,17 @@ func withWarningReporter(f func(error)) validateRoleOption {
}
}

// ValidateRole parses validates the role, and sets default values.
// ValidateRole parses, validates, and sets default values on a role.
func ValidateRole(r types.Role, opts ...validateRoleOption) error {
options := defaultValidateRoleOptions()
for _, opt := range opts {
opt(&options)
}

if r.GetOptions().SSHPortForwarding != nil && r.GetOptions().PortForwarding != nil {
return trace.BadParameter("options define both 'port_forwarding' and 'ssh_port_forwarding', only one can be set")
}

if err := CheckAndSetDefaults(r); err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -2825,6 +2829,7 @@ func (set RoleSet) CanForwardAgents() bool {
return false
}

// SSHPortForwardMode enumerates the possible SSH port forwarding modes available at a given time.
type SSHPortForwardMode int

const (
Expand Down Expand Up @@ -2863,14 +2868,11 @@ func (set RoleSet) SSHPortForwardMode() SSHPortForwardMode {
config := role.GetOptions().SSHPortForwarding
// only consider legacy allows when config isn't provided on the same role
if config == nil {
// TODO (eriktate): remove legacy check in v20
//nolint:staticcheck // this field is preserved for existing deployments, but shouldn't be used going forward
legacy := role.GetOptions().PortForwarding
if legacy != nil {
//nolint:staticcheck // this field is preserved for backwards compatibility, but shouldn't be used going forward
if legacy := role.GetOptions().PortForwarding; legacy != nil {
if legacy.Value {
return SSHPortForwardModeOn
}

legacyDeny = true
}

Expand All @@ -2880,7 +2882,6 @@ func (set RoleSet) SSHPortForwardMode() SSHPortForwardMode {
if config.Remote != nil && config.Remote.Enabled != nil {
if !config.Remote.Enabled.Value {
denyRemote = true

}

// an explicit legacy deny is only possible if no explicit SSHPortForwarding config has been provided
Expand Down Expand Up @@ -2912,7 +2913,7 @@ func (set RoleSet) SSHPortForwardMode() SSHPortForwardMode {
}
}

// CanPortForward returns true if a role in the RoleSet allows port forwarding.
// CanPortForward returns true if the RoleSet allows both local and remote port forwarding.
func (set RoleSet) CanPortForward() bool {
return set.SSHPortForwardMode() == SSHPortForwardModeOn
}
Expand Down
13 changes: 13 additions & 0 deletions lib/services/role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,19 @@ func TestValidateRole(t *testing.T) {
},
},
},
{
name: "invalid port forwarding config",
spec: types.RoleSpecV6{
Options: types.RoleOptions{
PortForwarding: types.NewBoolOption(true),
SSHPortForwarding: &types.SSHPortForwarding{},
},
Allow: types.RoleConditions{
Logins: []string{`{{external["http://schemas.microsoft.com/ws/2008/06/identity/claims/windowsaccountname"]}}`},
},
},
expectError: trace.BadParameter("options define both 'port_forwarding' and 'ssh_port_forwarding', only one can be set"),
},
{
name: "invalid role condition login syntax",
spec: types.RoleSpecV6{
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/authhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,13 @@ func (h *AuthHandlers) CheckFileCopying(ctx *ServerContext) error {
}

// CheckPortForward checks if port forwarding is allowed for the users RoleSet.
func (h *AuthHandlers) CheckPortForward(addr string, ctx *ServerContext, mode services.SSHPortForwardMode) error {
checkedMode := ctx.Identity.AccessChecker.SSHPortForwardMode()
if checkedMode == services.SSHPortForwardModeOn {
func (h *AuthHandlers) CheckPortForward(addr string, ctx *ServerContext, requestedMode services.SSHPortForwardMode) error {
allowedMode := ctx.Identity.AccessChecker.SSHPortForwardMode()
if allowedMode == services.SSHPortForwardModeOn {
return nil
}

if checkedMode == services.SSHPortForwardModeOff || checkedMode != mode {
if allowedMode == services.SSHPortForwardModeOff || allowedMode != requestedMode {
systemErrorMessage := fmt.Sprintf("port forwarding not allowed by role set: %v", ctx.Identity.AccessChecker.RoleNames())
userErrorMessage := "port forwarding not allowed"

Expand Down
24 changes: 19 additions & 5 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,18 @@ func (s *Server) checkTCPIPForwardRequest(r *ssh.Request) error {
return err
}

// RBAC checks should only necessary when connecting to an agentless node
if s.targetServer != nil && s.targetServer.IsOpenSSHNode() {
_, scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext)
if err != nil {
return err
}

if err := s.authHandlers.CheckPortForward(scx.DstAddr, scx, services.SSHPortForwardModeRemote); err != nil {
return trace.Wrap(err)
}
}

return nil
}

Expand Down Expand Up @@ -1084,11 +1096,13 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r

ch = scx.TrackActivity(ch)

// Check if the role allows port forwarding for this user.
err = s.authHandlers.CheckPortForward(scx.DstAddr, scx, services.SSHPortForwardModeLocal)
if err != nil {
s.stderrWrite(ch, err.Error())
return
// RBAC checks should only necessary when connecting to an agentless node
if s.targetServer != nil && s.targetServer.IsOpenSSHNode() {
err = s.authHandlers.CheckPortForward(scx.DstAddr, scx, services.SSHPortForwardModeLocal)
if err != nil {
s.stderrWrite(ch, err.Error())
return
}
}

s.log.Debugf("Opening direct-tcpip channel from %v to %v in context %v.", scx.SrcAddr, scx.DstAddr, scx.ID())
Expand Down
26 changes: 25 additions & 1 deletion lib/srv/forward/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/gravitational/teleport/api/utils/keys"
apisshutils "github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -190,6 +191,7 @@ func TestDirectTCPIP(t *testing.T) {
cases := []struct {
name string
login string
accessChecker services.AccessChecker
expectAccepted bool
expectRejected bool
}{
Expand All @@ -211,6 +213,19 @@ func TestDirectTCPIP(t *testing.T) {
// which return errors on accept.
expectRejected: true,
},
{
name: "port forwarding denied",
login: func() string {
u, err := user.Current()
require.NoError(t, err)
return u.Username
}(),
accessChecker: &fakePortForwardChecker{mode: services.SSHPortForwardModeOff},
expectAccepted: false,
// expectRejected is set to true because we are using mock channel
// which return errors on accept.
expectRejected: true,
},
}

for _, tt := range cases {
Expand All @@ -220,7 +235,7 @@ func TestDirectTCPIP(t *testing.T) {

s := Server{
log: utils.NewLoggerForTests().WithField(teleport.ComponentKey, "test"),
identityContext: srv.IdentityContext{Login: tt.login},
identityContext: srv.IdentityContext{Login: tt.login, AccessChecker: tt.accessChecker},
}

nch := &newChannelMock{channelType: teleport.ChanDirectTCPIP}
Expand Down Expand Up @@ -271,5 +286,14 @@ func TestCheckTCPIPForward(t *testing.T) {
}
}

type fakePortForwardChecker struct {
services.AccessChecker
mode services.SSHPortForwardMode
}

func (f *fakePortForwardChecker) SSHPortForwardMode() services.SSHPortForwardMode {
return f.mode
}

// TODO(atburke): Add test for handleForwardedTCPIPRequest once we have
// infrastructure for higher-level tests here.
2 changes: 1 addition & 1 deletion lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ func (s *Server) canPortForward(scx *srv.ServerContext, mode services.SSHPortFor
// Check if the role allows port forwarding for this user.
err := s.authHandlers.CheckPortForward(scx.DstAddr, scx, mode)
if err != nil {
return err
return trace.Wrap(err)
}

return nil
Expand Down
Loading

0 comments on commit b897de1

Please sign in to comment.