Skip to content

Commit

Permalink
add Scan support of sql.Scanner and time.Time
Browse files Browse the repository at this point in the history
  • Loading branch information
lqs committed Dec 27, 2023
1 parent b092004 commit d830480
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 7 deletions.
28 changes: 26 additions & 2 deletions cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"strconv"
"time"
)

// Cursor is the interface of a row cursor.
Expand All @@ -23,6 +24,13 @@ func (c cursor) Next() bool {
return c.rows.Next()
}

var timeType = reflect.TypeOf(time.Time{})

func isScanner(val reflect.Value) bool {
_, ok := val.Addr().Interface().(sql.Scanner)
return ok
}

func preparePointers(val reflect.Value, scans *[]interface{}) error {
kind := val.Kind()
switch kind {
Expand All @@ -35,6 +43,10 @@ func preparePointers(val reflect.Value, scans *[]interface{}) error {
*scans = append(*scans, addr.Interface())
}
case reflect.Struct:
if canScan := val.Type() == timeType || isScanner(val); canScan {
*scans = append(*scans, val.Addr().Interface())
return nil
}
for j := 0; j < val.NumField(); j++ {
field := val.Field(j)
if field.Kind() == reflect.Interface {
Expand Down Expand Up @@ -85,6 +97,19 @@ func parseBool(s []byte) (bool, error) {
}

func (c cursor) Scan(dest ...interface{}) error {
columns, err := c.rows.Columns()
if err != nil {
return err
}
values := make([]interface{}, len(columns))
pointers := make([]interface{}, len(columns))
for i := range columns {
pointers[i] = &values[i]
}
if err := c.rows.Scan(pointers...); err != nil {
return err
}

if len(dest) == 0 {
// dry run
return nil
Expand Down Expand Up @@ -119,8 +144,7 @@ func (c cursor) Scan(dest ...interface{}) error {
}
}

err := c.rows.Scan(scans...)
if err != nil {
if err := c.rows.Scan(scans...); err != nil {
return err
}

Expand Down
37 changes: 33 additions & 4 deletions cursor_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package sqlingo

import (
"database/sql"
"database/sql/driver"
"io"
"strconv"
"testing"
"time"
)

type mockDriver struct{}
Expand All @@ -30,7 +32,7 @@ type mockRows struct {
}

func (m mockRows) Columns() []string {
return []string{"a", "b", "c", "d", "e", "f", "g"}[:m.columnCount]
return []string{"a", "b", "c", "d", "e", "f", "g", "h"}[:m.columnCount]
}

func (m mockRows) Close() error {
Expand Down Expand Up @@ -58,6 +60,8 @@ func (m *mockRows) Next(dest []driver.Value) error {
dest[i] = dest[0]
case 6:
dest[i] = nil
case 7:
dest[i] = time.Now()
}
}
return nil
Expand Down Expand Up @@ -97,14 +101,15 @@ func TestCursor(t *testing.T) {
}
var f ****int // deep pointer
var g *int // always null
var h string

for i := 1; i <= 10; i++ {
if !cursor.Next() {
t.Error()
}
g = &i
if err := cursor.Scan(&a, &b, &cde, &f, &g); err != nil {
t.Errorf("%v", err)
if err := cursor.Scan(&a, &b, &cde, &f, &g, &h); err != nil {
t.Fatalf("%v", err)
}
if a != i ||
b != strconv.Itoa(i) ||
Expand All @@ -123,7 +128,8 @@ func TestCursor(t *testing.T) {
var b ****bool
var p *string
var bs []byte
if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p); err != nil {
var u string
if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p, &u); err != nil {
t.Error(err)
}
if ****b != (i%2 == 1) ||
Expand All @@ -141,6 +147,29 @@ func TestCursor(t *testing.T) {

}

func TestScanTime(t *testing.T) {
db := newMockDatabase()
cursor, _ := db.Query("dummy sql")
defer cursor.Close()

var row struct {
A sql.NullString
B []byte
C sql.NullInt32
D sql.NullString
E sql.NullString
F sql.NullString
G sql.NullString
H time.Time
}
if !cursor.Next() {
t.Error()
}
if err := cursor.Scan(&row); err != nil {
t.Error(err)
}
}

func TestCursorMap(t *testing.T) {
db := newMockDatabase()
cursor, _ := db.Query("dummy sql")
Expand Down
2 changes: 1 addition & 1 deletion database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *mockConn) Begin() (driver.Tx, error) {
}

var sharedMockConn = &mockConn{
columnCount: 7,
columnCount: 8,
rowCount: 10,
}

Expand Down

0 comments on commit d830480

Please sign in to comment.