Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v15] pagination & sorting for access requests (#38633) #39868

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ limitations under the License.
package client

import (
"cmp"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"slices"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -1187,12 +1189,96 @@ func (c *Client) GetBotUsers(ctx context.Context) ([]types.User, error) {

// GetAccessRequests retrieves a list of all access requests matching the provided filter.
func (c *Client) GetAccessRequests(ctx context.Context, filter types.AccessRequestFilter) ([]types.AccessRequest, error) {
requests, err := c.ListAllAccessRequests(ctx, &proto.ListAccessRequestsRequest{
Filter: &filter,
})
if err != nil {
return nil, trace.Wrap(err)
}

ireqs := make([]types.AccessRequest, 0, len(requests))
for _, r := range requests {
ireqs = append(ireqs, r)
}

return ireqs, nil
}

// ListAccessRequests is an access request getter with pagination and sorting options.
func (c *Client) ListAccessRequests(ctx context.Context, req *proto.ListAccessRequestsRequest) (*proto.ListAccessRequestsResponse, error) {
rsp, err := c.grpc.ListAccessRequests(ctx, req)
return rsp, trace.Wrap(err)
}

// ListAllAccessRequests aggregates all access requests via the ListAccessRequests api. This is equivalent to calling GetAccessRequests
// except that it supports custom sort order/indexes. Calling this method rather than ListAccessRequests also provides the advantage
// that it can fallback to calling the old GetAccessRequests grpc method if it encounters and outdated control plane. For that reason,
// implementations that don't actually *need* pagination are better served by calling this method.
func (c *Client) ListAllAccessRequests(ctx context.Context, req *proto.ListAccessRequestsRequest) ([]*types.AccessRequestV3, error) {
var requests []*types.AccessRequestV3
for {
rsp, err := c.ListAccessRequests(ctx, req)
if err != nil {
if trace.IsNotImplemented(err) {
return c.listAllAccessRequestsCompat(ctx, req)
}

return nil, trace.Wrap(err)
}

requests = append(requests, rsp.AccessRequests...)

req.StartKey = rsp.NextKey
if req.StartKey == "" {
break
}
}

return requests, nil
}

// listAllAccessRequestsCompat is a helper that simulates ListAllAccessRequests behavior via the old GetAccessRequests method.
func (c *Client) listAllAccessRequestsCompat(ctx context.Context, req *proto.ListAccessRequestsRequest) ([]*types.AccessRequestV3, error) {
var filter types.AccessRequestFilter
if req.Filter != nil {
filter = *req.Filter
}
requests, err := c.getAccessRequests(ctx, filter)
if err != nil {
return nil, trace.Wrap(err)
}

switch req.Sort {
case proto.AccessRequestSort_DEFAULT:
// no custom sort order needed
case proto.AccessRequestSort_CREATED:
slices.SortFunc(requests, func(a, b *types.AccessRequestV3) int {
return a.GetCreationTime().Compare(b.GetCreationTime())
})
case proto.AccessRequestSort_STATE:
slices.SortFunc(requests, func(a, b *types.AccessRequestV3) int {
return cmp.Compare(a.GetState().String(), b.GetState().String())
})
default:
return nil, trace.BadParameter("list access request compat fallback does not support sort order %q", req.Sort)
}

if req.Descending {
slices.Reverse(requests)
}

return requests, nil
}

// getAccessRequests calls the old GetAccessRequests method. used by back-compat logic when interacting with
// an outdated control-plane that doesn't support ListAccessRequests.
func (c *Client) getAccessRequests(ctx context.Context, filter types.AccessRequestFilter) ([]*types.AccessRequestV3, error) {
stream, err := c.grpc.GetAccessRequestsV2(ctx, &filter)
if err != nil {
return nil, trace.Wrap(err)
}

var reqs []types.AccessRequest
var reqs []*types.AccessRequestV3
for {
req, err := stream.Recv()
if errors.Is(err, io.EOF) {
Expand Down
Loading
Loading