Skip to content

Commit

Permalink
Add ForEachResource and CollectResources utils funtions
Browse files Browse the repository at this point in the history
  • Loading branch information
smallinsky committed Nov 10, 2024
1 parent 82da7f4 commit 4866a7b
Show file tree
Hide file tree
Showing 2 changed files with 263 additions and 0 deletions.
110 changes: 110 additions & 0 deletions lib/utils/iterators.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package utils

import (
"context"
"errors"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/utils/pagination"
)

// ErrStopIteration is value that signals to stop iteration from the caller injected function.
var ErrStopIteration = errors.New("stop iteration")

// ForEachOptions specifies options for ForEachResource.
type ForEachOptions struct {
// PageSize is the number of items to fetch per page.
PageSize int
}

// ForEachOptionFunc is a function that sets an option on ForEachOptions.
type ForEachOptionFunc func(*ForEachOptions)

// WithPageSize sets the page size option.
func WithPageSize(pageSize int) ForEachOptionFunc {
return func(opts *ForEachOptions) {
opts.PageSize = pageSize
}
}

// TokenLister is a function that lists resources with a page token.
type TokenLister[T any] func(context.Context, int, string) ([]T, string, error)

// ForEachResource iterates over resources.
// Example:
//
// count := 0
// err := ForEachResource(ctx, svc.ListAccessLists, func(acl accesslist.AccessList) error {
// count++
// return nil
// })
// if err != nil {
// return trace.Wrap(err)
// }
// fmt.Printf("Total access lists: %v", count)
func ForEachResource[T any](ctx context.Context, listFn TokenLister[T], fn func(T) error, opts ...ForEachOptionFunc) error {
var options ForEachOptions
for _, opt := range opts {
opt(&options)
}
pageToken := ""
for {
items, nextToken, err := listFn(ctx, options.PageSize, pageToken)
if err != nil {
return trace.Wrap(err)
}
for _, item := range items {
if err := fn(item); err != nil {
if errors.Is(err, ErrStopIteration) {
return nil
}
return trace.Wrap(err)
}
}
if nextToken == "" {
return nil
}
pageToken = nextToken
}
}

// ListerWithPageToken is a function that lists resources with a page token.
type ListerWithPageToken[T any] func(context.Context, int, *pagination.PageRequestToken) ([]T, pagination.NextPageToken, error)

// AdaptPageTokenLister adapts a listener with page token to a lister.
func AdaptPageTokenLister[T any](listFn ListerWithPageToken[T]) TokenLister[T] {
return func(ctx context.Context, pageSize int, pageToken string) ([]T, string, error) {
var pageRequestToken pagination.PageRequestToken
pageRequestToken.Update(pagination.NextPageToken(pageToken))
resources, nextPageToken, err := listFn(ctx, pageSize, &pageRequestToken)
if err != nil {
return nil, "", err
}
return resources, string(nextPageToken), nil
}
}

// CollectResources collects resources.
// Example usage:
//
// count := 0
// err := ForEachResource(ctx, svc.ListAccessLists, func(acl accesslist.AccessList) error {
// count++
// return nil
// })
// if err != nil {
// return trace.Wrap(err)
// }
// fmt.Printf("Total access lists: %v", count)
func CollectResources[T any](ctx context.Context, listFn TokenLister[T], opts ...ForEachOptionFunc) ([]T, error) {
var results []T
err := ForEachResource(ctx, listFn, func(item T) error {
results = append(results, item)
return nil
}, opts...)
if err != nil {
return nil, err
}
return results, nil
}
153 changes: 153 additions & 0 deletions lib/utils/iterators_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package utils

import (
"context"
"strconv"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils/pagination"
)

func TestCollect(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}
ctx := context.Background()
results, err := CollectResources(ctx, mock.List)
require.NoError(t, err)
require.Equal(t, []int{1, 2, 3, 4, 5}, results)
}

func TestCollect_WithAdaptedLister(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}
ctx := context.Background()
results, err := CollectResources(ctx, AdaptPageTokenLister(mock.ListWithPagination))
require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3, 4, 5}, results)
}

func TestForEachResource(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}

ctx := context.Background()
var count int

err := ForEachResource(ctx, mock.List, func(item int) error {
count++
return nil
}, WithPageSize(2))

require.NoError(t, err)
require.Len(t, mock.items, 5)
}

func TestForEachResource_StopIteration(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}

ctx := context.Background()
var count int

err := ForEachResource(ctx, mock.List, func(item int) error {
count++
if item == 3 {
return ErrStopIteration
}
return nil
}, WithPageSize(2))

require.NoError(t, err)
require.Equal(t, 3, count)
}

func TestForEachResource_AdaptPageTokenLister(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}

ctx := context.Background()
var count int

err := ForEachResource(ctx, AdaptPageTokenLister(mock.ListWithPagination), func(item int) error {
count++
return nil
}, WithPageSize(2))

require.NoError(t, err)
require.Equal(t, 5, count)
}

func TestMockBackendLister_List(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}
ctx := context.Background()

expectedResults := [][]int{
{1}, {2}, {3}, {4}, {5},
}
pageToken := ""

for _, expected := range expectedResults {
results, nextToken, err := mock.List(ctx, 1, pageToken)
require.NoError(t, err)
require.Equal(t, expected, results)
pageToken = nextToken
}

require.Equal(t, "", pageToken)

pageToken = ""
results, nextToken, err := mock.List(ctx, 2, pageToken)
require.NoError(t, err)
require.Equal(t, []int{1, 2}, results)
require.NotEmpty(t, nextToken)

results, nextToken, err = mock.List(ctx, 2, nextToken)
require.NoError(t, err)
require.Equal(t, []int{3, 4}, results)
require.NotEmpty(t, nextToken)

results, nextToken, err = mock.List(ctx, 2, nextToken)
require.NoError(t, err)
require.Equal(t, []int{5}, results)
require.Equal(t, "", nextToken)
}

type mockBackendLister struct {
items []int
}

func (s *mockBackendLister) List(ctx context.Context, pageSize int, pageToken string) ([]int, string, error) {
if pageToken == "" {
pageToken = "0"
}
if pageSize <= 0 {
pageSize = 2
}
startIndex, err := strconv.Atoi(pageToken)
if err != nil {
return nil, "", trace.Wrap(err)
}
endIndex := startIndex + pageSize
if endIndex > len(s.items) {
endIndex = len(s.items)
}
items := s.items[startIndex:endIndex]
if endIndex < len(s.items) {
return items, strconv.Itoa(endIndex), nil
}
return items, "", nil
}

func (s *mockBackendLister) ListWithPagination(ctx context.Context, pageSize int, page *pagination.PageRequestToken) ([]int, pagination.NextPageToken, error) {
if pageSize == 0 {
pageSize = 1
}
pageToken, err := page.Consume()
if err != nil {
return nil, "", trace.Wrap(err)
}
resp, nextPage, err := s.List(ctx, pageSize, pageToken)
if err != nil {
return nil, "", trace.Wrap(err)
}
return resp, pagination.NextPageToken(nextPage), nil
}

0 comments on commit 4866a7b

Please sign in to comment.