diff --git a/lib/auth/auth.go b/lib/auth/auth.go index e8fc39d45141a..689c0688a8747 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -2771,8 +2771,8 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. CertificateType: events.CertificateTypeUser, Identity: &eventIdentity, ClientMetadata: apievents.ClientMetadata{ - //TODO(greedy52) currently only user-agent from GRPC clients are - //fetched. Need to propagate user-agent from HTTP calls. + // TODO(greedy52) currently only user-agent from GRPC clients are + // fetched. Need to propagate user-agent from HTTP calls. UserAgent: trimUserAgent(metadata.UserAgentFromContext(ctx)), }, }); err != nil { @@ -5566,6 +5566,21 @@ func (a *Server) isMFARequired(ctx context.Context, checker services.AccessCheck return nil, trace.BadParameter("empty Login field") } + // state.MFARequired is "per-role", so if the user is joining + // a session, MFA is required no matter what node they are + // connecting to. We don't preform an RBAC check like we do + // below when users are starting a session to selectively + // require MFA because we don't know what session the user + // is joining, nor do we know what role allowed the session + // creator to start the session that is attempting to be joined. + // We need this info to be able to selectively skip MFA in + // this case. + if t.Node.Login == teleport.SSHSessionJoinPrincipal { + return &proto.IsMFARequiredResponse{ + Required: true, + }, nil + } + // Find the target node and check whether MFA is required. matches, err := client.GetResourcesWithFilters(ctx, a, proto.ListResourcesRequest{ ResourceType: types.KindNode, diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index f3df021b73c45..b7fd957b9e6db 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -20,10 +20,12 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" @@ -339,6 +341,203 @@ func TestCreateAuthenticateChallenge_WithRecoveryStartToken(t *testing.T) { } } +func TestCreateAuthenticateChallenge_mfaVerification(t *testing.T) { + t.Parallel() + + testServer := newTestTLSServer(t) + ctx := context.Background() + + adminClient, err := testServer.NewClient(TestBuiltin(types.RoleAdmin)) + require.NoError(t, err, "NewClient(types.RoleAdmin)") + + // Register a couple of SSH nodes. + registerNode := func(node, env string) error { + _, err := adminClient.UpsertNode(ctx, &types.ServerV2{ + Kind: types.KindNode, + Version: types.V2, + Metadata: types.Metadata{ + Name: uuid.NewString(), + Labels: map[string]string{ + "env": env, + }, + }, + Spec: types.ServerSpecV2{ + Hostname: node, + }, + }) + return err + } + const devNode = "node1" + const prodNode = "node2" + require.NoError(t, registerNode(devNode, "dev"), "registerNode(%q)", devNode) + require.NoError(t, registerNode(prodNode, "prod"), "registerNode(%q)", prodNode) + + // Create an MFA required role for "prod" nodes. + prodRole, err := types.NewRole("prod_access", types.RoleSpecV6{ + Options: types.RoleOptions{ + RequireMFAType: types.RequireMFAType_SESSION, + }, + Allow: types.RoleConditions{ + Logins: []string{"{{internal.logins}}"}, + NodeLabels: types.Labels{ + "env": []string{"prod"}, + }, + }, + }) + require.NoError(t, err, "NewRole(prod)") + err = adminClient.UpsertRole(ctx, prodRole) + require.NoError(t, err, "UpsertRole(%q)", prodRole.GetName()) + + // Create a role that requires MFA when joining sessions + joinMFARole, err := types.NewRole("mfa_session_join", types.RoleSpecV6{ + Options: types.RoleOptions{ + RequireMFAType: types.RequireMFAType_SESSION, + }, + Allow: types.RoleConditions{ + Logins: []string{"{{internal.logins}}"}, + NodeLabels: types.Labels{ + "env": []string{"*"}, + }, + JoinSessions: []*types.SessionJoinPolicy{ + { + Name: "session_join", + Roles: []string{"access"}, + Kinds: []string{string(types.SSHSessionKind)}, + Modes: []string{string(types.SessionPeerMode)}, + }, + }, + }, + }) + require.NoError(t, err, "NewRole(joinMFA)") + err = adminClient.UpsertRole(ctx, joinMFARole) + require.NoError(t, err, "UpsertRole(%q)", joinMFARole.GetName()) + + // Create a role that doesn't require MFA when joining sessions + joinNoMFARole, err := types.NewRole("no_mfa_session_join", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Logins: []string{"{{internal.logins}}"}, + NodeLabels: types.Labels{ + "env": []string{"*"}, + }, + JoinSessions: []*types.SessionJoinPolicy{ + { + Name: "session_join", + Roles: []string{"access"}, + Kinds: []string{string(types.SSHSessionKind)}, + Modes: []string{string(types.SessionPeerMode)}, + }, + }, + }, + }) + require.NoError(t, err, "NewRole(joinNoMFA)") + err = adminClient.UpsertRole(ctx, joinNoMFARole) + require.NoError(t, err, "UpsertRole(%q)", joinNoMFARole.GetName()) + + const normalLogin = "llama" + createUser := func(role types.Role) *Client { + // Create a user with MFA devices... + userCreds, err := createUserWithSecondFactors(testServer) + require.NoError(t, err, "createUserWithSecondFactors") + username := userCreds.username + + // ...and assign the user a sane unix login, plus the specified role. + user, err := adminClient.GetUser(username, false /* withSecrets */) + require.NoError(t, err, "GetUser(%q)", username) + + user.SetLogins(append(user.GetLogins(), normalLogin)) + user.AddRole(role.GetName()) + err = adminClient.UpdateUser(ctx, user.(*types.UserV2)) + require.NoError(t, err, "UpdateUser(%q)", username) + + userClient, err := testServer.NewClient(TestUser(username)) + require.NoError(t, err, "NewClient(%q)", username) + + return userClient + } + + prodAccessClient := createUser(prodRole) + joinMFAClient := createUser(joinMFARole) + joinNoMFAClient := createUser(joinNoMFARole) + + createReqForNode := func(node, login string) *proto.IsMFARequiredRequest { + return &proto.IsMFARequiredRequest{ + Target: &proto.IsMFARequiredRequest_Node{ + Node: &proto.NodeLogin{ + Node: node, + Login: login, + }, + }, + } + } + + tests := []struct { + name string + userClient *Client + req *proto.IsMFARequiredRequest + wantMFARequired bool + }{ + { + name: "MFA not required to start session, no challenges issued", + userClient: prodAccessClient, + req: createReqForNode(devNode, normalLogin), + wantMFARequired: false, + }, + { + name: "MFA required to start session", + userClient: prodAccessClient, + req: createReqForNode(prodNode, normalLogin), + wantMFARequired: true, + }, + { + name: "MFA required to join session on prod node (prod role)", + userClient: prodAccessClient, + req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: true, + }, + { + name: "MFA required to join session on dev node (prod role)", + userClient: prodAccessClient, + req: createReqForNode(devNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: true, + }, + { + name: "MFA required to join session on prod node (join MFA role)", + userClient: joinMFAClient, + req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: true, + }, + { + name: "MFA required to join session dev node (join MFA role)", + userClient: joinMFAClient, + req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: true, + }, + { + name: "MFA not required to join session, no challenges issued on dev node (join no MFA role)", + userClient: joinNoMFAClient, + req: createReqForNode(devNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: false, + }, + { + name: "MFA not required to join session, no challenges issued on prod node (join no MFA role)", + userClient: joinNoMFAClient, + req: createReqForNode(prodNode, teleport.SSHSessionJoinPrincipal), + wantMFARequired: false, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + resp, err := test.userClient.IsMFARequired(ctx, test.req) + require.NoError(t, err, "IsMFARequired") + + assert.Equal(t, test.wantMFARequired, resp.Required, "resp.Required mismatch") + }) + } +} + func TestCreateRegisterChallenge(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 3074c5749a309..2ec5d73938218 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -644,7 +644,7 @@ func (a *ahLoginChecker) canLoginWithRBAC(cert *ssh.Certificate, ca types.CertAu auth.RoleSupportsModeratedSessions(accessChecker.Roles()) { // allow joining if cluster wide MFA is not required - if state.MFARequired != services.MFARequiredAlways { + if state.MFARequired == services.MFARequiredNever { return nil } diff --git a/lib/srv/authhandlers_test.go b/lib/srv/authhandlers_test.go index 40083d8bfe0b3..5e30294a4d2ed 100644 --- a/lib/srv/authhandlers_test.go +++ b/lib/srv/authhandlers_test.go @@ -26,7 +26,9 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/auth/testauthority" @@ -34,13 +36,18 @@ import ( "github.com/gravitational/teleport/lib/services" ) -type mockCAGetter struct { +type mockCAandAuthPrefGetter struct { AccessPoint - cas map[types.CertAuthType]types.CertAuthority + authPref types.AuthPreference + cas map[types.CertAuthType]types.CertAuthority } -func (m mockCAGetter) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error) { +func (m mockCAandAuthPrefGetter) GetAuthPreference(s_12345678 context.Context) (types.AuthPreference, error) { + return m.authPref, nil +} + +func (m mockCAandAuthPrefGetter) GetCertAuthorities(_ context.Context, caType types.CertAuthType, _ bool) ([]types.CertAuthority, error) { ca, ok := m.cas[caType] if !ok { return nil, trace.NotFound("CA not found") @@ -158,8 +165,9 @@ func TestRBAC(t *testing.T) { err = server.auth.SetClusterName(clusterName) require.NoError(t, err) - accessPoint := mockCAGetter{ + accessPoint := mockCAandAuthPrefGetter{ AccessPoint: server.auth, + authPref: types.DefaultAuthPreference(), cas: map[types.CertAuthType]types.CertAuthority{ types.UserCA: userCA, }, @@ -223,3 +231,178 @@ func TestRBAC(t *testing.T) { }) } } + +// TestRBACJoinMFA tests that MFA is enforced correctly when joining +// sessions depending on the cluster auth preference and roles presented. +func TestRBACJoinMFA(t *testing.T) { + t.Parallel() + + const clusterName = "localhost" + const username = "testuser" + + // create User CA + userTA := testauthority.New() + userCAPriv, err := userTA.GeneratePrivateKey() + require.NoError(t, err) + userCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.UserCA, + ClusterName: clusterName, + ActiveKeys: types.CAKeySet{ + SSH: []*types.SSHKeyPair{ + { + PublicKey: userCAPriv.MarshalSSHPublicKey(), + PrivateKey: userCAPriv.PrivateKeyPEM(), + PrivateKeyType: types.PrivateKeyType_RAW, + }, + }, + }, + }) + require.NoError(t, err) + + // create mock SSH server and add a cluster name + server := newMockServer(t) + cn, err := types.NewClusterName(types.ClusterNameSpecV2{ + ClusterName: clusterName, + ClusterID: "cluster_id", + }) + require.NoError(t, err) + err = server.auth.SetClusterName(cn) + require.NoError(t, err) + ctx := context.Background() + + accessPoint := &mockCAandAuthPrefGetter{ + AccessPoint: server.auth, + cas: map[types.CertAuthType]types.CertAuthority{ + types.UserCA: userCA, + }, + } + + // create auth handler and dummy node + config := &AuthHandlerConfig{ + Server: server, + Emitter: &eventstest.MockRecorderEmitter{}, + AccessPoint: accessPoint, + } + ah, err := NewAuthHandlers(config) + require.NoError(t, err) + + node, err := types.NewServer("testie_node", types.KindNode, types.ServerSpecV2{ + Addr: "1.2.3.4:22", + Hostname: "testie", + Version: types.V2, + }) + require.NoError(t, err) + + mfaAuthPref, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ + SecondFactor: constants.SecondFactorOTP, + RequireMFAType: types.RequireMFAType_HARDWARE_KEY_TOUCH, + }) + require.NoError(t, err) + + noMFAAuthPref, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ + SecondFactor: constants.SecondFactorOTP, + RequireMFAType: types.RequireMFAType_OFF, + }) + require.NoError(t, err) + + // create roles + joinMFARole, err := types.NewRole("joinMFA", types.RoleSpecV6{ + Options: types.RoleOptions{ + RequireMFAType: types.RequireMFAType_SESSION, + }, + Allow: types.RoleConditions{ + NodeLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + }, + }) + require.NoError(t, err) + err = server.auth.CreateRole(ctx, joinMFARole) + require.NoError(t, err) + + joinRole, err := types.NewRole("join", types.RoleSpecV6{ + Allow: types.RoleConditions{ + NodeLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + }, + }) + require.NoError(t, err) + err = server.auth.CreateRole(ctx, joinRole) + require.NoError(t, err) + + tests := []struct { + name string + authPref types.AuthPreference + role string + testError func(t *testing.T, err error) + }{ + { + name: "MFA cluster auth, MFA role", + authPref: mfaAuthPref, + role: joinMFARole.GetName(), + testError: func(t *testing.T, err error) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }, + }, + { + name: "MFA cluster auth, no MFA role", + authPref: mfaAuthPref, + role: joinRole.GetName(), + testError: func(t *testing.T, err error) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }, + }, + { + name: "no MFA cluster auth, MFA role", + authPref: noMFAAuthPref, + role: joinMFARole.GetName(), + testError: func(t *testing.T, err error) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }, + }, + { + name: "no MFA cluster auth, no MFA role", + authPref: noMFAAuthPref, + role: joinRole.GetName(), + testError: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + accessPoint.authPref = tt.authPref + + // create SSH certificate + caSigner, err := ssh.NewSignerFromKey(userCAPriv) + require.NoError(t, err) + keygen := testauthority.New() + privateKey, err := native.GeneratePrivateKey() + require.NoError(t, err) + + c, err := keygen.GenerateUserCert(services.UserCertParams{ + CASigner: caSigner, + PublicUserKey: ssh.MarshalAuthorizedKey(privateKey.SSHPublicKey()), + Username: username, + AllowedLogins: []string{username}, + Traits: wrappers.Traits{ + teleport.TraitInternalPrefix: []string{""}, + }, + Roles: []string{tt.role}, + CertificateFormat: constants.CertificateFormatStandard, + }) + require.NoError(t, err) + + cert, err := sshutils.ParseCertificate(c) + require.NoError(t, err) + + err = ah.canLoginWithRBAC(cert, userCA, clusterName, node, username, teleport.SSHSessionJoinPrincipal) + tt.testError(t, err) + }) + } +}