diff --git a/cursor.go b/cursor.go index 059c264..913c88c 100644 --- a/cursor.go +++ b/cursor.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strconv" + "time" ) // Cursor is the interface of a row cursor. @@ -23,6 +24,13 @@ func (c cursor) Next() bool { return c.rows.Next() } +var timeType = reflect.TypeOf(time.Time{}) + +func isScanner(val reflect.Value) bool { + _, ok := val.Addr().Interface().(sql.Scanner) + return ok +} + func preparePointers(val reflect.Value, scans *[]interface{}) error { kind := val.Kind() switch kind { @@ -35,6 +43,10 @@ func preparePointers(val reflect.Value, scans *[]interface{}) error { *scans = append(*scans, addr.Interface()) } case reflect.Struct: + if canScan := val.Type() == timeType || isScanner(val); canScan { + *scans = append(*scans, val.Addr().Interface()) + return nil + } for j := 0; j < val.NumField(); j++ { field := val.Field(j) if field.Kind() == reflect.Interface { @@ -85,6 +97,19 @@ func parseBool(s []byte) (bool, error) { } func (c cursor) Scan(dest ...interface{}) error { + columns, err := c.rows.Columns() + if err != nil { + return err + } + values := make([]interface{}, len(columns)) + pointers := make([]interface{}, len(columns)) + for i := range columns { + pointers[i] = &values[i] + } + if err := c.rows.Scan(pointers...); err != nil { + return err + } + if len(dest) == 0 { // dry run return nil @@ -119,8 +144,7 @@ func (c cursor) Scan(dest ...interface{}) error { } } - err := c.rows.Scan(scans...) - if err != nil { + if err := c.rows.Scan(scans...); err != nil { return err } diff --git a/cursor_test.go b/cursor_test.go index c5a0149..f15b653 100644 --- a/cursor_test.go +++ b/cursor_test.go @@ -1,10 +1,12 @@ package sqlingo import ( + "database/sql" "database/sql/driver" "io" "strconv" "testing" + "time" ) type mockDriver struct{} @@ -30,7 +32,7 @@ type mockRows struct { } func (m mockRows) Columns() []string { - return []string{"a", "b", "c", "d", "e", "f", "g"}[:m.columnCount] + return []string{"a", "b", "c", "d", "e", "f", "g", "h"}[:m.columnCount] } func (m mockRows) Close() error { @@ -58,6 +60,8 @@ func (m *mockRows) Next(dest []driver.Value) error { dest[i] = dest[0] case 6: dest[i] = nil + case 7: + dest[i] = time.Now() } } return nil @@ -97,14 +101,15 @@ func TestCursor(t *testing.T) { } var f ****int // deep pointer var g *int // always null + var h string for i := 1; i <= 10; i++ { if !cursor.Next() { t.Error() } g = &i - if err := cursor.Scan(&a, &b, &cde, &f, &g); err != nil { - t.Errorf("%v", err) + if err := cursor.Scan(&a, &b, &cde, &f, &g, &h); err != nil { + t.Fatalf("%v", err) } if a != i || b != strconv.Itoa(i) || @@ -123,7 +128,8 @@ func TestCursor(t *testing.T) { var b ****bool var p *string var bs []byte - if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p); err != nil { + var u string + if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p, &u); err != nil { t.Error(err) } if ****b != (i%2 == 1) || @@ -141,6 +147,29 @@ func TestCursor(t *testing.T) { } +func TestScanTime(t *testing.T) { + db := newMockDatabase() + cursor, _ := db.Query("dummy sql") + defer cursor.Close() + + var row struct { + A sql.NullString + B []byte + C sql.NullInt32 + D sql.NullString + E sql.NullString + F sql.NullString + G sql.NullString + H time.Time + } + if !cursor.Next() { + t.Error() + } + if err := cursor.Scan(&row); err != nil { + t.Error(err) + } +} + func TestCursorMap(t *testing.T) { db := newMockDatabase() cursor, _ := db.Query("dummy sql") diff --git a/database_test.go b/database_test.go index 37bf183..eb67fd9 100644 --- a/database_test.go +++ b/database_test.go @@ -32,7 +32,7 @@ func (m *mockConn) Begin() (driver.Tx, error) { } var sharedMockConn = &mockConn{ - columnCount: 7, + columnCount: 8, rowCount: 10, }