Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Augment unified resource requests with login information #38559

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 177 additions & 38 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ limitations under the License.
package client

import (
"cmp"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"slices"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -1136,12 +1138,96 @@ func (c *Client) CreateResetPasswordToken(ctx context.Context, req *proto.Create

// GetAccessRequests retrieves a list of all access requests matching the provided filter.
func (c *Client) GetAccessRequests(ctx context.Context, filter types.AccessRequestFilter) ([]types.AccessRequest, error) {
requests, err := c.ListAllAccessRequests(ctx, &proto.ListAccessRequestsRequest{
Filter: &filter,
})
if err != nil {
return nil, trace.Wrap(err)
}

ireqs := make([]types.AccessRequest, 0, len(requests))
for _, r := range requests {
ireqs = append(ireqs, r)
}

return ireqs, nil
}

// ListAccessRequests is an access request getter with pagination and sorting options.
func (c *Client) ListAccessRequests(ctx context.Context, req *proto.ListAccessRequestsRequest) (*proto.ListAccessRequestsResponse, error) {
rsp, err := c.grpc.ListAccessRequests(ctx, req)
return rsp, trace.Wrap(err)
}

// ListAllAccessRequests aggregates all access requests via the ListAccessRequests api. This is equivalent to calling GetAccessRequests
// except that it supports custom sort order/indexes. Calling this method rather than ListAccessRequests also provides the advantage
// that it can fallback to calling the old GetAccessRequests grpc method if it encounters and outdated control plane. For that reason,
// implementations that don't actually *need* pagination are better served by calling this method.
func (c *Client) ListAllAccessRequests(ctx context.Context, req *proto.ListAccessRequestsRequest) ([]*types.AccessRequestV3, error) {
var requests []*types.AccessRequestV3
for {
rsp, err := c.ListAccessRequests(ctx, req)
if err != nil {
if trace.IsNotImplemented(err) {
return c.listAllAccessRequestsCompat(ctx, req)
}

return nil, trace.Wrap(err)
}

requests = append(requests, rsp.AccessRequests...)

req.StartKey = rsp.NextKey
if req.StartKey == "" {
break
}
}

return requests, nil
}

// listAllAccessRequestsCompat is a helper that simulates ListAllAccessRequests behavior via the old GetAccessRequests method.
func (c *Client) listAllAccessRequestsCompat(ctx context.Context, req *proto.ListAccessRequestsRequest) ([]*types.AccessRequestV3, error) {
var filter types.AccessRequestFilter
if req.Filter != nil {
filter = *req.Filter
}
requests, err := c.getAccessRequests(ctx, filter)
if err != nil {
return nil, trace.Wrap(err)
}

switch req.Sort {
case proto.AccessRequestSort_DEFAULT:
// no custom sort order needed
case proto.AccessRequestSort_CREATED:
slices.SortFunc(requests, func(a, b *types.AccessRequestV3) int {
return a.GetCreationTime().Compare(b.GetCreationTime())
})
case proto.AccessRequestSort_STATE:
slices.SortFunc(requests, func(a, b *types.AccessRequestV3) int {
return cmp.Compare(a.GetState().String(), b.GetState().String())
})
default:
return nil, trace.BadParameter("list access request compat fallback does not support sort order %q", req.Sort)
}

if req.Descending {
slices.Reverse(requests)
}

return requests, nil
}

// getAccessRequests calls the old GetAccessRequests method. used by back-compat logic when interacting with
// an outdated control-plane that doesn't support ListAccessRequests.
func (c *Client) getAccessRequests(ctx context.Context, filter types.AccessRequestFilter) ([]*types.AccessRequestV3, error) {
stream, err := c.grpc.GetAccessRequestsV2(ctx, &filter)
if err != nil {
return nil, trace.Wrap(err)
}

var reqs []types.AccessRequest
var reqs []*types.AccessRequestV3
for {
req, err := stream.Recv()
if errors.Is(err, io.EOF) {
Expand Down Expand Up @@ -3442,52 +3528,41 @@ type ResourcePage[T types.ResourceWithLabels] struct {
NextKey string
}

// getResourceFromProtoPage extracts the resource from the PaginatedResource returned
// from the rpc ListUnifiedResources
func getResourceFromProtoPage(resource *proto.PaginatedResource) (types.ResourceWithLabels, error) {
var out types.ResourceWithLabels
// convertEnrichedResource extracts the resource and any enriched information from the
// PaginatedResource returned from the rpc ListUnifiedResources.
func convertEnrichedResource(resource *proto.PaginatedResource) (*types.EnrichedResource, error) {
if r := resource.GetNode(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r, Logins: resource.Logins}, nil
} else if r := resource.GetDatabaseServer(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetDatabaseService(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetAppServerOrSAMLIdPServiceProvider(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetWindowsDesktop(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r, Logins: resource.Logins}, nil
} else if r := resource.GetWindowsDesktopService(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetKubeCluster(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetKubernetesServer(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetUserGroup(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetAppServer(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else {
return nil, trace.BadParameter("received unsupported resource %T", resource.Resource)
}
}

// ListUnifiedResourcePage is a helper for getting a single page of unified resources that match the provided request.
func ListUnifiedResourcePage(ctx context.Context, clt ListUnifiedResourcesClient, req *proto.ListUnifiedResourcesRequest) (ResourcePage[types.ResourceWithLabels], error) {
var out ResourcePage[types.ResourceWithLabels]
// GetUnifiedResourcePage is a helper for getting a single page of unified resources that match the provided request.
func GetUnifiedResourcePage(ctx context.Context, clt ListUnifiedResourcesClient, req *proto.ListUnifiedResourcesRequest) ([]*types.EnrichedResource, string, error) {
var out []*types.EnrichedResource

// Set the limit to the default size if one was not provided within
// an acceptable range.
if req.Limit == 0 || req.Limit > int32(defaults.DefaultChunkSize) {
if req.Limit <= 0 || req.Limit > int32(defaults.DefaultChunkSize) {
req.Limit = int32(defaults.DefaultChunkSize)
}

Expand All @@ -3499,24 +3574,88 @@ func ListUnifiedResourcePage(ctx context.Context, clt ListUnifiedResourcesClient
req.Limit /= 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if req.Limit == 0 {
return out, trace.Wrap(trace.Wrap(err), "resource is too large to retrieve")
return nil, "", trace.Wrap(err, "resource is too large to retrieve")
}

continue
}

return out, trace.Wrap(err)
return nil, "", trace.Wrap(err)
}

for _, respResource := range resp.Resources {
resource, err := getResourceFromProtoPage(respResource)
resource, err := convertEnrichedResource(respResource)
if err != nil {
return out, trace.Wrap(err)
return nil, "", trace.Wrap(err)
}
out.Resources = append(out.Resources, resource)
out = append(out, resource)
}

return out, resp.NextKey, nil
}
}

// GetEnrichedResourcePage returns a page of resources matching the provided request
// that are enriched with additional metadata.
func GetEnrichedResourcePage(ctx context.Context, clt GetResourcesClient, req *proto.ListResourcesRequest) (ResourcePage[*types.EnrichedResource], error) {
var out ResourcePage[*types.EnrichedResource]

// Set the limit to the default size if one was not provided within
// an acceptable range.
if req.Limit <= 0 || req.Limit > int32(defaults.DefaultChunkSize) {
req.Limit = int32(defaults.DefaultChunkSize)
}

for {
resp, err := clt.GetResources(ctx, req)
if err != nil {
if trace.IsLimitExceeded(err) {
// Cut chunkSize in half if gRPC max message size is exceeded.
req.Limit /= 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if req.Limit == 0 {
return out, trace.Wrap(err, "resource is too large to retrieve")
}

continue
}

return out, trace.Wrap(err)
}

for _, respResource := range resp.Resources {
var resource types.ResourceWithLabels
switch req.ResourceType {
case types.KindDatabaseServer:
resource = respResource.GetDatabaseServer()
case types.KindDatabaseService:
resource = respResource.GetDatabaseService()
case types.KindAppServer:
resource = respResource.GetAppServer()
case types.KindNode:
resource = respResource.GetNode()
case types.KindWindowsDesktop:
resource = respResource.GetWindowsDesktop()
case types.KindWindowsDesktopService:
resource = respResource.GetWindowsDesktopService()
case types.KindKubernetesCluster:
resource = respResource.GetKubeCluster()
case types.KindKubeServer:
resource = respResource.GetKubernetesServer()
case types.KindUserGroup:
resource = respResource.GetUserGroup()
case types.KindAppOrSAMLIdPServiceProvider:
resource = respResource.GetAppServerOrSAMLIdPServiceProvider()
default:
out.Resources = nil
return out, trace.NotImplemented("resource type %s does not support pagination", req.ResourceType)
}

out.Resources = append(out.Resources, &types.EnrichedResource{ResourceWithLabels: resource, Logins: respResource.Logins})
}

out.NextKey = resp.NextKey
out.Total = int(resp.TotalCount)

return out, nil
}
Expand All @@ -3528,7 +3667,7 @@ func GetResourcePage[T types.ResourceWithLabels](ctx context.Context, clt GetRes

// Set the limit to the default size if one was not provided within
// an acceptable range.
if req.Limit == 0 || req.Limit > int32(defaults.DefaultChunkSize) {
if req.Limit <= 0 || req.Limit > int32(defaults.DefaultChunkSize) {
req.Limit = int32(defaults.DefaultChunkSize)
}

Expand All @@ -3540,7 +3679,7 @@ func GetResourcePage[T types.ResourceWithLabels](ctx context.Context, clt GetRes
req.Limit /= 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if req.Limit == 0 {
return out, trace.Wrap(trace.Wrap(err), "resource is too large to retrieve")
return out, trace.Wrap(err, "resource is too large to retrieve")
}

continue
Expand Down Expand Up @@ -3653,7 +3792,7 @@ func GetResourcesWithFilters(ctx context.Context, clt ListResourcesClient, req p
chunkSize = chunkSize / 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if chunkSize == 0 {
return nil, trace.Wrap(trace.Wrap(err), "resource is too large to retrieve")
return nil, trace.Wrap(err, "resource is too large to retrieve")
}

continue
Expand Down Expand Up @@ -3704,7 +3843,7 @@ func GetKubernetesResourcesWithFilters(ctx context.Context, clt kubeproto.KubeSe
chunkSize = chunkSize / 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if chunkSize == 0 {
return nil, trace.Wrap(trace.Wrap(err), "resource is too large to retrieve")
return nil, trace.Wrap(err, "resource is too large to retrieve")
}
continue
}
Expand Down
Loading