Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CollectExactlyOneRow function #1720

Merged
merged 2 commits into from
Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,12 @@ func (ident Identifier) Sanitize() string {
return strings.Join(parts, ".")
}

// ErrNoRows occurs when rows are expected but none are returned.
var ErrNoRows = errors.New("no rows in result set")
var (
// ErrNoRows occurs when rows are expected but none are returned.
ErrNoRows = errors.New("no rows in result set")
// ErrTooManyRows occurs when more rows than expected are returned.
ErrTooManyRows = errors.New("too many rows in result set")
)

var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
Expand Down
35 changes: 34 additions & 1 deletion rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type Rows interface {
// Callers should check rows.Err() after rows.Next() returns false to detect
// whether result-set reading ended prematurely due to an error. See
// Conn.Query for details.
//
//
// For simpler error handling, consider using the higher-level pgx v5
// CollectRows() and ForEachRow() helpers instead.
Next() bool
Expand Down Expand Up @@ -465,6 +465,39 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
return value, rows.Err()
}

// CollectExactlyOneRow calls fn for the first row in rows and returns the result.
// - If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true.
func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
defer rows.Close()

var (
err error
value T
)

if !rows.Next() {
if err = rows.Err(); err != nil {
return value, err
}

return value, ErrNoRows
}

value, err = fn(rows)
if err != nil {
return value, err
}

if rows.Next() {
var zero T

return zero, ErrTooManyRows
}

return value, rows.Err()
}

// RowTo returns a T scanned from row.
func RowTo[T any](row CollectableRow) (T, error) {
var value T
Expand Down
39 changes: 39 additions & 0 deletions rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,45 @@ func TestCollectOneRowPrefersPostgreSQLErrorOverErrNoRows(t *testing.T) {
})
}

func TestCollectExactlyOneRow(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
rows, _ := conn.Query(ctx, `select 42`)
n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) {
var n int32
err := row.Scan(&n)
return n, err
})
assert.NoError(t, err)
assert.Equal(t, int32(42), n)
})
}

func TestCollectExactlyOneRowNotFound(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
rows, _ := conn.Query(ctx, `select 42 where false`)
n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) {
var n int32
err := row.Scan(&n)
return n, err
})
assert.ErrorIs(t, err, pgx.ErrNoRows)
assert.Equal(t, int32(0), n)
})
}

func TestCollectExactlyOneRowExtraRows(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
rows, _ := conn.Query(ctx, `select n from generate_series(42, 99) n`)
n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) {
var n int32
err := row.Scan(&n)
return n, err
})
assert.ErrorIs(t, err, pgx.ErrTooManyRows)
assert.Equal(t, int32(0), n)
})
}

func TestRowTo(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)
Expand Down