From dbfdf1ecaf58b5c11a95582aed5de5af6be6e301 Mon Sep 17 00:00:00 2001 From: Pawel Kopiczko Date: Wed, 20 Nov 2024 12:10:29 +0000 Subject: [PATCH] Fix broken auth Access Request creation tests This got exposed while working on Access Request reason required PR: https://github.com/gravitational/teleport/pull/49124 --- lib/auth/auth_with_roles_test.go | 54 ++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index ea0068c141e66..0cade14983748 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -8052,7 +8052,7 @@ func TestCreateAccessRequest(t *testing.T) { clock := srv.Clock() alice, bob, admin := createSessionTestUsers(t, srv.Auth()) - searchRole, err := types.NewRole("requestRole", types.RoleSpecV6{ + searchRole, err := types.NewRole("searchRole", types.RoleSpecV6{ Allow: types.RoleConditions{ Request: &types.AccessRequestConditions{ Roles: []string{"requestRole"}, @@ -8062,11 +8062,32 @@ func TestCreateAccessRequest(t *testing.T) { }) require.NoError(t, err) - requestRole, err := types.NewRole("requestRole", types.RoleSpecV6{}) + requestRole, err := types.NewRole("requestRole", types.RoleSpecV6{ + Allow: types.RoleConditions{ + GroupLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + NodeLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + }, + }) require.NoError(t, err) - srv.Auth().CreateRole(ctx, searchRole) - srv.Auth().CreateRole(ctx, requestRole) + nodeAllowedByRequestRole, err := types.NewServerWithLabels( + "test-node", + types.KindNode, + types.ServerSpecV2{}, + map[string]string{"any-key": "any-val"}, + ) + require.NoError(t, err) + + _, err = srv.Auth().UpsertNode(ctx, nodeAllowedByRequestRole) + require.NoError(t, err) + _, err = srv.Auth().CreateRole(ctx, requestRole) + require.NoError(t, err) + _, err = srv.Auth().CreateRole(ctx, searchRole) + require.NoError(t, err) user, err := srv.Auth().GetUser(ctx, alice, true) require.NoError(t, err) @@ -8109,12 +8130,12 @@ func TestCreateAccessRequest(t *testing.T) { user: alice, accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), nodeAllowedByRequestRole.GetKind(), nodeAllowedByRequestRole.GetName()), }), errAssertionFunc: require.NoError, expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), nodeAllowedByRequestRole.GetKind(), nodeAllowedByRequestRole.GetName()), }), }, { @@ -8122,12 +8143,15 @@ func TestCreateAccessRequest(t *testing.T) { user: admin, accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), }), errAssertionFunc: require.NoError, expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), + mustResourceID(srv.ClusterName(), types.KindApp, userGroup1.GetApplications()[0]), + mustResourceID(srv.ClusterName(), types.KindApp, userGroup1.GetApplications()[1]), + mustResourceID(srv.ClusterName(), types.KindApp, userGroup1.GetApplications()[2]), }), }, { @@ -8135,7 +8159,7 @@ func TestCreateAccessRequest(t *testing.T) { user: bob, accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), }), errAssertionFunc: require.Error, }, @@ -8144,7 +8168,7 @@ func TestCreateAccessRequest(t *testing.T) { user: alice, accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), nodeAllowedByRequestRole.GetKind(), nodeAllowedByRequestRole.GetName()), mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), mustResourceID(srv.ClusterName(), types.KindApp, "app1"), mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()), @@ -8153,7 +8177,7 @@ func TestCreateAccessRequest(t *testing.T) { errAssertionFunc: require.NoError, expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), []string{requestRole.GetName()}, []types.ResourceID{ - mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), nodeAllowedByRequestRole.GetKind(), nodeAllowedByRequestRole.GetName()), mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), mustResourceID(srv.ClusterName(), types.KindApp, "app1"), mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()), @@ -8388,9 +8412,13 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { require.NoError(t, err) paymentsServer.SetStaticLabels(map[string]string{"service": "payments"}) - idServer, err := types.NewServer("server-identity", types.KindNode, types.ServerSpecV2{}) + idServer, err := types.NewServerWithLabels( + "server-identity", + types.KindNode, + types.ServerSpecV2{}, + map[string]string{"service": "identity"}, + ) require.NoError(t, err) - idServer.SetStaticLabels(map[string]string{"service": "payments"}) ctx := context.Background() srv := newTestTLSServer(t)