Skip to content

Commit

Permalink
Merge pull request #18 from VarusHsu/time_optimize
Browse files Browse the repository at this point in the history
Time optimize
  • Loading branch information
lqs authored Dec 28, 2023
2 parents d830480 + c4d4a07 commit 1327f63
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 17 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
* Context support
* Transaction support
* Interceptor support
* Golang time.Time is supported now, but you can still use the string type by adding `-timeAsString` when generating the model

## Database Support Status
| Database | Status |
Expand Down Expand Up @@ -104,3 +105,5 @@ func main() {
Execute()
}
```
## TODO
* Millisecond time is not currently supported
57 changes: 53 additions & 4 deletions cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ func preparePointers(val reflect.Value, scans *[]interface{}) error {
reflect.Float32, reflect.Float64,
reflect.String:
*scans = append(*scans, val.Addr().Interface())
case reflect.Struct:
if toType == reflect.TypeOf(time.Time{}) {
*scans = append(*scans, val.Addr().Interface())
} else {
to := reflect.New(toType).Elem()
val.Set(to.Addr())
err := preparePointers(to, scans)
if err != nil {
return nil
}
}
default:
to := reflect.New(toType).Elem()
val.Set(to.Addr())
Expand Down Expand Up @@ -131,16 +142,27 @@ func (c cursor) Scan(dest ...interface{}) error {

pbs := make(map[int]*bool)
ppbs := make(map[int]**bool)
pts := make(map[int]*time.Time)
ppts := make(map[int]**time.Time)

for i, scan := range scans {
if pb, ok := scan.(*bool); ok {
switch scan.(type) {
case *bool:
var s []uint8
scans[i] = &s
pbs[i] = pb
} else if ppb, ok := scan.(**bool); ok {
pbs[i] = scan.(*bool)
case **bool:
var s *[]uint8
scans[i] = &s
ppbs[i] = ppb
ppbs[i] = scan.(**bool)
case *time.Time:
var s string
scans[i] = &s
pts[i] = scan.(*time.Time)
case **time.Time:
var s sql.NullString
scans[i] = &s
ppts[i] = scan.(**time.Time)
}
}

Expand Down Expand Up @@ -169,6 +191,33 @@ func (c cursor) Scan(dest ...interface{}) error {
*ppb = &b
}
}
for i := range pts {
s := scans[i].(*string)
if s == nil {
return fmt.Errorf("field %d is null", i)
}
t, err := time.Parse("2006-01-02 15:04:05", *s)
if err != nil {
return err
}
*pts[i] = t

}
for i := range ppts {
nullString := scans[i].(*sql.NullString)
if nullString == nil {
return fmt.Errorf("field %d is null", i)
}
if !nullString.Valid {
*ppts[i] = nil
} else {
t, err := time.Parse("2006-01-02 15:04:05", nullString.String)
if err != nil {
return err
}
*ppts[i] = &t
}
}

return err
}
Expand Down
26 changes: 22 additions & 4 deletions cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type mockRows struct {
}

func (m mockRows) Columns() []string {
return []string{"a", "b", "c", "d", "e", "f", "g", "h"}[:m.columnCount]
return []string{"a", "b", "c", "d", "e", "f", "g", "h", "j", "k", "l"}[:m.columnCount]
}

func (m mockRows) Close() error {
Expand Down Expand Up @@ -62,6 +62,12 @@ func (m *mockRows) Next(dest []driver.Value) error {
dest[i] = nil
case 7:
dest[i] = time.Now()
case 8:
dest[i] = "2023-09-06 18:37:46.828"
case 9:
dest[i] = "2023-09-06 18:37:46"
case 10:
dest[i] = "2023-09-06 18:37:46"
}
}
return nil
Expand Down Expand Up @@ -103,21 +109,33 @@ func TestCursor(t *testing.T) {
var g *int // always null
var h string

Check failure on line 110 in cursor_test.go

View workflow job for this annotation

GitHub Actions / build

other declaration of h

var h *time.Time

Check failure on line 112 in cursor_test.go

View workflow job for this annotation

GitHub Actions / build

h redeclared in this block
var j time.Time
var k *time.Time
var l time.Time
tmh, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828")
tmj, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828")
tmk, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46")
tml, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46")
for i := 1; i <= 10; i++ {
if !cursor.Next() {
t.Error()
}
g = &i
if err := cursor.Scan(&a, &b, &cde, &f, &g, &h); err != nil {
t.Fatalf("%v", err)
if err := cursor.Scan(&a, &b, &cde, &f, &g, &h, &j, &k, &l); err != nil {
t.Errorf("%v", err)
}
if a != i ||
b != strconv.Itoa(i) ||
cde.C != float32(i) ||
cde.DE.D != (i%2 == 1) ||
cde.DE.E != cde.DE.D ||
****f != i ||
g != nil {
g != nil ||
*h != tmh ||

Check failure on line 135 in cursor_test.go

View workflow job for this annotation

GitHub Actions / build

invalid operation: cannot indirect h (variable of type string)
j != tmj ||
*k != tmk ||
l != tml {
t.Error(a, b, cde.C, cde.DE.D, cde.DE.E, ****f, g)
}
if err := cursor.Scan(); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *mockConn) Begin() (driver.Tx, error) {
}

var sharedMockConn = &mockConn{
columnCount: 8,
columnCount: 11,
rowCount: 10,
}

Expand Down
16 changes: 16 additions & 0 deletions expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"reflect"
"strconv"
"strings"
"time"
"unsafe"
)

Expand Down Expand Up @@ -90,6 +91,10 @@ type ArrayExpression interface {
Expression
}

type DateExpression interface {
Expression
}

// UnknownExpression is the interface of an SQL expression with unknown value.
type UnknownExpression interface {
Expression
Expand Down Expand Up @@ -310,6 +315,17 @@ func getSQL(scope scope, value interface{}) (sql string, priority priority, err
sql = value.(Table).GetSQL(scope)
case CaseExpression:
sql, err = value.(CaseExpression).End().GetSQL(scope)
case time.Time:
tmStr := value.(time.Time).Format("2006-01-02 15:04:05")
sql = quoteString(tmStr)
case *time.Time:
tm := value.(*time.Time)
if tm == nil {
sql = "NULL"
} else {
tmStr := tm.Format("2006-01-02 15:04:05")
sql = quoteString(tmStr)
}
default:
v := reflect.ValueOf(value)
sql, priority, err = getSQLFromReflectValue(scope, v)
Expand Down
10 changes: 10 additions & 0 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type ArrayField interface {
GetTable() Table
}

type DateField interface {
DateExpression
GetTable() Table
}

type actualField struct {
expression
table Table
Expand Down Expand Up @@ -84,6 +89,11 @@ func NewStringField(table Table, fieldName string) StringField {
return newField(table, fieldName)
}

// NewDateField creates a reference to a time.Time field. It should only be called from generated code.
func NewDateField(table Table, fieldName string) DateField {
return newField(table, fieldName)
}

type fieldList []Field

func (fields fieldList) GetSQL(scope scope) (string, error) {
Expand Down
2 changes: 2 additions & 0 deletions generator/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ func parseArgs(exampleDataSourceName string) (options options) {
printUsageAndExit(exampleDataSourceName)
}
parseForceCases = true
case "timeAsString":
timeAsString = true
default:
printUsageAndExit(exampleDataSourceName)
}
Expand Down
2 changes: 2 additions & 0 deletions generator/fetcher_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strconv"
)

var timeAsString = false

type mysqlSchemaFetcher struct {
db *sql.DB
}
Expand Down
49 changes: 41 additions & 8 deletions generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func getType(fieldDescriptor fieldDescriptor) (goType string, fieldClass string,
case "float", "double", "decimal", "real":
goType = "float64"
fieldClass = "NumberField"
case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "enum", "datetime", "date", "time", "timestamp", "json", "numeric", "character varying", "timestamp without time zone", "timestamp with time zone", "jsonb", "uuid":
case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "enum", "date", "time", "json", "numeric", "character varying", "timestamp without time zone", "timestamp with time zone", "jsonb", "uuid":
goType = "string"
fieldClass = "StringField"
case "year":
Expand All @@ -103,6 +103,14 @@ func getType(fieldDescriptor fieldDescriptor) (goType string, fieldClass string,
// TODO: Switch to specific type instead of interface.
goType = "[]interface{}"
fieldClass = "ArrayField"
case "datetime", "timestamp":
if !timeAsString {
goType = "time.Time"
fieldClass = "DateField"
} else {
goType = "string"
fieldClass = "StringField"
}
case "geometry", "point", "linestring", "polygon", "multipoint", "multilinestring", "multipolygon", "geometrycollection":
goType = "sqlingo.WellKnownBinary"
fieldClass = "WellKnownBinaryField"
Expand Down Expand Up @@ -174,10 +182,38 @@ func Generate(driverName string, exampleDataSourceName string) (string, error) {
return "", errors.New("no database selected")
}

if len(options.tableNames) == 0 {
options.tableNames, err = schemaFetcher.GetTableNames()
if err != nil {
return "", err
}
}

needImportTime := false
for _, tableName := range options.tableNames {
fieldDescriptors, err := schemaFetcher.GetFieldDescriptors(tableName)
if err != nil {
return "", err
}
for _, fieldDescriptor := range fieldDescriptors {
if !timeAsString && fieldDescriptor.Type == "datetime" || fieldDescriptor.Type == "timestamp" {
needImportTime = true
break
}
}
}

code := "// This file is generated by sqlingo (https://github.com/lqs/sqlingo)\n"
code += "// DO NOT EDIT.\n\n"
code += "package " + ensureIdentifier(dbName) + "_dsl\n"
code += "import \"github.com/lqs/sqlingo\"\n\n"
if needImportTime {
code += "import (\n"
code += "\t\"time\"\n"
code += "\t\"github.com/lqs/sqlingo\"\n"
code += ")\n\n"
} else {
code += "import \"github.com/lqs/sqlingo\"\n\n"
}

code += "type sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame uint32\n\n"

Expand Down Expand Up @@ -205,12 +241,9 @@ func Generate(driverName string, exampleDataSourceName string) (string, error) {
code += "\tsqlingo.ArrayField\n"
code += "}\n\n"

if len(options.tableNames) == 0 {
options.tableNames, err = schemaFetcher.GetTableNames()
if err != nil {
return "", err
}
}
code += "type dateField interface {\n"
code += "\tsqlingo.DateField\n"
code += "}\n\n"

var wg sync.WaitGroup

Expand Down

0 comments on commit 1327f63

Please sign in to comment.