Skip to content

Commit

Permalink
Implement Go 1.23 iterator function for iterating rows
Browse files Browse the repository at this point in the history
  • Loading branch information
hlubek committed Sep 19, 2024
1 parent 5a0678b commit b096a25
Show file tree
Hide file tree
Showing 11 changed files with 573 additions and 74 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: '1.20'
go-version: '1.21'
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
Expand All @@ -26,7 +26,7 @@ jobs:
test:
strategy:
matrix:
go-version: [ '1.20', '1.21' ]
go-version: [ '1.21', '1.22', '1.23' ]
platform: [ 'ubuntu-latest' ]
runs-on: ${{ matrix.platform }}
steps:
Expand All @@ -48,7 +48,7 @@ jobs:
if: success()
uses: actions/setup-go@v4
with:
go-version: '1.20'
go-version: '1.23'
- name: Calc coverage
run: |
go test -cover ./... -coverpkg=github.com/networkteam/construct/v2/... -coverprofile=coverage.txt
Expand Down
10 changes: 6 additions & 4 deletions constructsql/constructsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import (
)

// CollectRows collects all rows to the given target type from a ExecutiveQueryBuilder.Query result.
func CollectRows[T any](rows Rows, err error) ([]T, error) {
if err != nil {
return nil, err
func CollectRows[T any](rows Rows, queryErr error) (result []T, err error) {
if queryErr != nil {
return nil, queryErr
}

defer rows.Close()
defer func() {
err = errors.Join(err, rows.Close())
}()

slice := []T{}

Expand Down
55 changes: 55 additions & 0 deletions constructsql/constructsql_iterator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//go:build go1.23

package constructsql

import (
"errors"
"iter"
"log"
)

// IterateRows returns an iterator over the rows of a database query and scans them to T.
// It returns a single-use iterator.
func IterateRows[T any](rows Rows, err error) iter.Seq2[T, error] {
if err != nil {
return func(yield func(T, error) bool) {
var result T
yield(result, err)
}
}

return func(yield func(T, error) bool) {
var err error
var iteratorClosed bool

defer func() {
closeErr := rows.Close()
if iteratorClosed && closeErr != nil {
log.Printf("constructsql: Error closing rows after function for loop body returned false: %v", closeErr)
return
}

err = errors.Join(err, closeErr)
if err != nil {
var result T
yield(result, err)
}
}()

var value T
for rows.Next() {
value, err = scanRow[T](rows)
if err != nil {
return
}
if !yield(value, nil) {
iteratorClosed = true
return
}
}

if err = rows.Err(); err != nil {
return
}
}
}
258 changes: 258 additions & 0 deletions constructsql/constructsql_iterator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
//go:build go1.23

package constructsql_test

import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/networkteam/construct/v2/constructsql"
)

func TestIterateRows(t *testing.T) {
t.Run("iterate rows without error", func(t *testing.T) {
rows := mockRows{
rows: []mockRow{
{
scanJSON: []byte(`{"id": 1, "name": "test"}`),
},
},
}

count := 0
for record, err := range constructsql.IterateRows[user](&rows, nil) {
require.NoError(t, err)
assert.Equal(t, user{ID: 1, Name: "test"}, record)
count++
}
assert.Equal(t, 1, count)

assert.True(t, rows.closed)
})

t.Run("iterate empty rows without error", func(t *testing.T) {
rows := mockRows{
rows: []mockRow{},
}

count := 0
for range constructsql.IterateRows[user](&rows, nil) {
count++
}
assert.Equal(t, 0, count)

assert.True(t, rows.closed)
})

t.Run("iterate rows with early break", func(t *testing.T) {
rows := mockRows{
rows: []mockRow{
{
scanJSON: []byte(`{"id": 1, "name": "test"}`),
},
{
scanJSON: []byte(`{"id": 2, "name": "test"}`),
},
},
}

count := 0
for record, err := range constructsql.IterateRows[user](&rows, nil) {
if count == 0 {
require.NoError(t, err)
assert.Equal(t, user{ID: 1, Name: "test"}, record)
} else {
break
}
count++
}
assert.Equal(t, 1, count)

assert.True(t, rows.closed)
})

t.Run("iterate rows with early break and close error", func(t *testing.T) {
closeErr := errors.New("some error on close")
rows := mockRows{
rows: []mockRow{
{
scanJSON: []byte(`{"id": 1, "name": "test"}`),
},
{
scanJSON: []byte(`{"id": 2, "name": "test"}`),
},
},
closeErr: closeErr,
}

count := 0
for record, err := range constructsql.IterateRows[user](&rows, nil) {
if count == 0 {
require.NoError(t, err)
assert.Equal(t, user{ID: 1, Name: "test"}, record)
} else {
break
}
count++
}
assert.Equal(t, 1, count)

assert.True(t, rows.closed)
})

t.Run("iterate rows with initial error", func(t *testing.T) {
rows := mockRows{}
initialErr := errors.New("some initial error")
count := 0
for _, err := range constructsql.IterateRows[user](&rows, initialErr) {
assert.ErrorIs(t, err, initialErr)
count++
}
assert.Equal(t, 1, count)

assert.False(t, rows.closed)
})

t.Run("iterate rows with scan error", func(t *testing.T) {
scanErr := errors.New("some scan error")
rows := mockRows{
rows: []mockRow{
{
scanJSON: []byte(`{"id": 1, "name": "test"}`),
},
{
scanErr: scanErr,
},
},
}

count := 0
for record, err := range constructsql.IterateRows[user](&rows, nil) {
if count == 0 {
require.NoError(t, err)
assert.Equal(t, user{ID: 1, Name: "test"}, record)
} else if count == 1 {
assert.ErrorIs(t, err, scanErr)
}
count++
}
assert.Equal(t, 2, count)

assert.True(t, rows.closed)
})

t.Run("iterate rows with iterate error", func(t *testing.T) {
iterateErr := errors.New("some iterate error")
rows := mockRows{
rows: []mockRow{
{
scanJSON: []byte(`{"id": 1, "name": "test"}`),
},
},
iterateErr: iterateErr,
}

count := 0
for record, err := range constructsql.IterateRows[user](&rows, nil) {
if count == 0 {
require.NoError(t, err)
assert.Equal(t, user{ID: 1, Name: "test"}, record)
} else if count == 1 {
assert.ErrorIs(t, err, iterateErr)
}
count++
}
assert.Equal(t, 2, count)

assert.True(t, rows.closed)
})

t.Run("iterate rows with close error", func(t *testing.T) {
closeErr := errors.New("some error on close")
rows := mockRows{
rows: []mockRow{
{
scanJSON: []byte(`{"id": 1, "name": "test"}`),
},
},
closeErr: closeErr,
}

count := 0
for record, err := range constructsql.IterateRows[user](&rows, nil) {
if count == 0 {
require.NoError(t, err)
assert.Equal(t, user{ID: 1, Name: "test"}, record)
} else if count == 1 {
assert.ErrorIs(t, err, closeErr)
}
count++
}
assert.Equal(t, 2, count)

assert.True(t, rows.closed)
})

t.Run("iterate rows with scan and close error", func(t *testing.T) {
scanErr := errors.New("some scan error")
closeErr := errors.New("some error on close")
rows := mockRows{
rows: []mockRow{
{
scanJSON: []byte(`{"id": 1, "name": "test"}`),
},
{
scanErr: scanErr,
},
},
closeErr: closeErr,
}

count := 0
for record, err := range constructsql.IterateRows[user](&rows, nil) {
if count == 0 {
require.NoError(t, err)
assert.Equal(t, user{ID: 1, Name: "test"}, record)
} else if count == 1 {
assert.ErrorIs(t, err, scanErr)
assert.ErrorIs(t, err, closeErr)
}
count++
}
assert.Equal(t, 2, count)

assert.True(t, rows.closed)
})

t.Run("iterate rows with iterate and close error", func(t *testing.T) {
iterateErr := errors.New("some iterate error")
closeErr := errors.New("some error on close")
rows := mockRows{
rows: []mockRow{
{
scanJSON: []byte(`{"id": 1, "name": "test"}`),
},
},
iterateErr: iterateErr,
closeErr: closeErr,
}

count := 0
for record, err := range constructsql.IterateRows[user](&rows, nil) {
if count == 0 {
require.NoError(t, err)
assert.Equal(t, user{ID: 1, Name: "test"}, record)
} else if count == 1 {
assert.ErrorIs(t, err, iterateErr)
assert.ErrorIs(t, err, closeErr)
}
count++
}
assert.Equal(t, 2, count)

assert.True(t, rows.closed)
})
}
Loading

0 comments on commit b096a25

Please sign in to comment.