Skip to content

Commit

Permalink
Add ForEachResource and CollectResources utils funtions (#48729)
Browse files Browse the repository at this point in the history
  • Loading branch information
smallinsky authored Nov 11, 2024
1 parent 8cf2a73 commit d1fdd85
Show file tree
Hide file tree
Showing 2 changed files with 292 additions and 0 deletions.
121 changes: 121 additions & 0 deletions lib/utils/iterators.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package utils

import (
"context"
"errors"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/utils/pagination"
)

// ErrStopIteration is value that signals to stop iteration from the caller injected function.
var ErrStopIteration = errors.New("stop iteration")

// ForEachOptions specifies options for ForEachResource.
type ForEachOptions struct {
// PageSize is the number of items to fetch per page.
PageSize int
}

// ForEachOptionFunc is a function that sets an option on ForEachOptions.
type ForEachOptionFunc func(*ForEachOptions)

// WithPageSize sets the page size option.
func WithPageSize(pageSize int) ForEachOptionFunc {
return func(opts *ForEachOptions) {
opts.PageSize = pageSize
}
}

// TokenLister is a function that lists resources with a page token.
type TokenLister[T any] func(context.Context, int, string) ([]T, string, error)

// ForEachResource iterates over resources.
// Example:
//
// count := 0
// err := ForEachResource(ctx, svc.ListAccessLists, func(acl accesslist.AccessList) error {
// count++
// return nil
// })
// if err != nil {
// return trace.Wrap(err)
// }
// fmt.Printf("Total access lists: %v", count)
func ForEachResource[T any](ctx context.Context, listFn TokenLister[T], fn func(T) error, opts ...ForEachOptionFunc) error {
var options ForEachOptions
for _, opt := range opts {
opt(&options)
}
pageToken := ""
for {
items, nextToken, err := listFn(ctx, options.PageSize, pageToken)
if err != nil {
return trace.Wrap(err)
}
for _, item := range items {
if err := fn(item); err != nil {
if errors.Is(err, ErrStopIteration) {
return nil
}
return trace.Wrap(err)
}
}
if nextToken == "" {
return nil
}
pageToken = nextToken
}
}

// ListerWithPageToken is a function that lists resources with a page token.
type ListerWithPageToken[T any] func(context.Context, int, *pagination.PageRequestToken) ([]T, pagination.NextPageToken, error)

// AdaptPageTokenLister adapts a listener with page token to a lister.
func AdaptPageTokenLister[T any](listFn ListerWithPageToken[T]) TokenLister[T] {
return func(ctx context.Context, pageSize int, pageToken string) ([]T, string, error) {
var pageRequestToken pagination.PageRequestToken
pageRequestToken.Update(pagination.NextPageToken(pageToken))
resources, nextPageToken, err := listFn(ctx, pageSize, &pageRequestToken)
if err != nil {
return nil, "", trace.Wrap(err)
}
return resources, string(nextPageToken), nil
}
}

// CollectResources collects resources.
// Example usage:
//
// accessLists err := ForEachResource(ctx, svc.ListAccessLists)
// fmt.Printf("Total access lists: %v", len(accessLists))
func CollectResources[T any](ctx context.Context, listFn TokenLister[T], opts ...ForEachOptionFunc) ([]T, error) {
var results []T
err := ForEachResource(ctx, listFn, func(item T) error {
results = append(results, item)
return nil
}, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return results, nil
}
171 changes: 171 additions & 0 deletions lib/utils/iterators_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package utils

import (
"context"
"strconv"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils/pagination"
)

func TestCollect(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}
ctx := context.Background()
results, err := CollectResources(ctx, mock.List)
require.NoError(t, err)
require.Equal(t, []int{1, 2, 3, 4, 5}, results)
}

func TestCollect_WithAdaptedLister(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}
ctx := context.Background()
results, err := CollectResources(ctx, AdaptPageTokenLister(mock.ListWithPagination))
require.NoError(t, err)
assert.Equal(t, []int{1, 2, 3, 4, 5}, results)
}

func TestForEachResource(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}

ctx := context.Background()
var count int

err := ForEachResource(ctx, mock.List, func(item int) error {
count++
return nil
}, WithPageSize(2))

require.NoError(t, err)
require.Len(t, mock.items, 5)
}

func TestForEachResource_StopIteration(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}

ctx := context.Background()
var count int

err := ForEachResource(ctx, mock.List, func(item int) error {
count++
if item == 3 {
return ErrStopIteration
}
return nil
}, WithPageSize(2))

require.NoError(t, err)
require.Equal(t, 3, count)
}

func TestForEachResource_AdaptPageTokenLister(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}

ctx := context.Background()
var count int

err := ForEachResource(ctx, AdaptPageTokenLister(mock.ListWithPagination), func(item int) error {
count++
return nil
}, WithPageSize(2))

require.NoError(t, err)
require.Equal(t, 5, count)
}

func TestMockBackendLister_List(t *testing.T) {
mock := &mockBackendLister{items: []int{1, 2, 3, 4, 5}}
ctx := context.Background()

expectedResults := [][]int{
{1}, {2}, {3}, {4}, {5},
}
pageToken := ""

for _, expected := range expectedResults {
results, nextToken, err := mock.List(ctx, 1, pageToken)
require.NoError(t, err)
require.Equal(t, expected, results)
pageToken = nextToken
}

require.Equal(t, "", pageToken)

pageToken = ""
results, nextToken, err := mock.List(ctx, 2, pageToken)
require.NoError(t, err)
require.Equal(t, []int{1, 2}, results)
require.NotEmpty(t, nextToken)

results, nextToken, err = mock.List(ctx, 2, nextToken)
require.NoError(t, err)
require.Equal(t, []int{3, 4}, results)
require.NotEmpty(t, nextToken)

results, nextToken, err = mock.List(ctx, 2, nextToken)
require.NoError(t, err)
require.Equal(t, []int{5}, results)
require.Equal(t, "", nextToken)
}

type mockBackendLister struct {
items []int
}

func (s *mockBackendLister) List(ctx context.Context, pageSize int, pageToken string) ([]int, string, error) {
if pageToken == "" {
pageToken = "0"
}
if pageSize <= 0 {
pageSize = 2
}
startIndex, err := strconv.Atoi(pageToken)
if err != nil {
return nil, "", trace.Wrap(err)
}
endIndex := startIndex + pageSize
if endIndex > len(s.items) {
endIndex = len(s.items)
}
items := s.items[startIndex:endIndex]
if endIndex < len(s.items) {
return items, strconv.Itoa(endIndex), nil
}
return items, "", nil
}

func (s *mockBackendLister) ListWithPagination(ctx context.Context, pageSize int, page *pagination.PageRequestToken) ([]int, pagination.NextPageToken, error) {
if pageSize == 0 {
pageSize = 1
}
pageToken, err := page.Consume()
if err != nil {
return nil, "", trace.Wrap(err)
}
resp, nextPage, err := s.List(ctx, pageSize, pageToken)
if err != nil {
return nil, "", trace.Wrap(err)
}
return resp, pagination.NextPageToken(nextPage), nil
}

0 comments on commit d1fdd85

Please sign in to comment.