diff --git a/mysql/resultset_helper.go b/mysql/resultset_helper.go index baf0a037f..fee929506 100644 --- a/mysql/resultset_helper.go +++ b/mysql/resultset_helper.go @@ -3,6 +3,7 @@ package mysql import ( "bytes" "encoding/binary" + "fmt" "math" "strconv" "time" @@ -43,7 +44,7 @@ func FormatTextValue(value interface{}) ([]byte, error) { case string: return utils.StringToByteSlice(v), nil case time.Time: - return hack.Slice(v.Format(time.DateTime)), nil + return utils.StringToByteSlice(v.Format(time.DateTime)), nil case nil: return nil, nil default: @@ -55,7 +56,7 @@ func toBinaryDateTime(t time.Time) ([]byte, error) { var buf bytes.Buffer if t.IsZero() { - return nil, nil + return nil, fmt.Errorf("zero time") } year, month, day := t.Year(), t.Month(), t.Day() diff --git a/mysql/util_test.go b/mysql/util_test.go index 345817f47..9d326e768 100644 --- a/mysql/util_test.go +++ b/mysql/util_test.go @@ -1,9 +1,6 @@ package mysql import ( - "encoding/binary" - "fmt" - "strings" "testing" "time" @@ -56,92 +53,65 @@ func TestFormatBinaryTime(t *testing.T) { } } -// mysql driver parse binary datetime -func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (time.Time, error) { - switch num { - case 0: - return time.Time{}, nil - case 4: - return time.Date( - int(binary.LittleEndian.Uint16(data[:2])), // year - time.Month(data[2]), // month - int(data[3]), // day - 0, 0, 0, 0, - loc, - ), nil - case 7: - return time.Date( - int(binary.LittleEndian.Uint16(data[:2])), // year - time.Month(data[2]), // month - int(data[3]), // day - int(data[4]), // hour - int(data[5]), // minutes - int(data[6]), // seconds - 0, - loc, - ), nil - case 11: - return time.Date( - int(binary.LittleEndian.Uint16(data[:2])), // year - time.Month(data[2]), // month - int(data[3]), // day - int(data[4]), // hour - int(data[5]), // minutes - int(data[6]), // seconds - int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds - loc, - ), nil - } - return time.Time{}, fmt.Errorf("invalid DATETIME packet length %d", num) -} - func TestToBinaryDateTime(t *testing.T) { + var ( + DateTimeNano = "2006-01-02 15:04:05.000000" + formatBinaryDateTime = func(n int, data []byte) string { + date, err := FormatBinaryDateTime(n, data) + if err != nil { + return "" + } + return string(date) + } + ) + tests := []struct { - name string - input time.Time - expected string + Name string + Data time.Time + Expect func(n int, data []byte) string + Error bool }{ { - name: "Zero time", - input: time.Time{}, - expected: "", + Name: "Zero time", + Data: time.Time{}, + Expect: nil, + Error: true, }, { - name: "Date with nanoseconds", - input: time.Date(2023, 10, 10, 10, 10, 10, 123456000, time.UTC), - expected: "2023-10-10 10:10:10.123456 +0000 UTC", + Name: "Date with nanoseconds", + Data: time.Date(2023, 10, 10, 10, 10, 10, 123456000, time.UTC), + Expect: formatBinaryDateTime, }, { - name: "Date with time", - input: time.Date(2023, 10, 10, 10, 10, 10, 0, time.UTC), - expected: "2023-10-10 10:10:10 +0000 UTC", + Name: "Date with time", + Data: time.Date(2023, 10, 10, 10, 10, 10, 0, time.UTC), + Expect: formatBinaryDateTime, }, { - name: "Date only", - input: time.Date(2023, 10, 10, 0, 0, 0, 0, time.UTC), - expected: "2023-10-10 00:00:00 +0000 UTC", + Name: "Date only", + Data: time.Date(2023, 10, 10, 0, 0, 0, 0, time.UTC), + Expect: formatBinaryDateTime, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := toBinaryDateTime(tt.input) - if err != nil { - t.Fatalf("unexpected error: %v", err) + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + got, err := toBinaryDateTime(test.Data) + if test.Error { + require.Error(t, err) + } else { + require.NoError(t, err) } - if len(result) == 0 { + if len(got) == 0 { return } - num := uint64(result[0]) - data := result[1:] - date, err := parseBinaryDateTime(num, data, time.UTC) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if !strings.EqualFold(date.String(), tt.expected) { - t.Errorf("expected %v, got %v", tt.expected, result) + tmp := test.Expect(int(got[0]), got[1:]) + if int(got[0]) < 11 { + require.Equal(t, tmp, test.Data.Format(time.DateTime), "test case %v", test.Data.String()) + } else { + require.Equal(t, tmp, test.Data.Format(DateTimeNano), "test case %v", test.Data.String()) } }) } + }