-
Notifications
You must be signed in to change notification settings - Fork 991
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add support for driver.NamedValueChecker on driver connection (#887)
* add support for driver.NamedValueChecker on driver connection * added more tests
- Loading branch information
1 parent
b13191f
commit c607c3c
Showing
3 changed files
with
166 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -448,6 +448,41 @@ func main() { | |
} | ||
``` | ||
|
||
### Custom NamedValueChecker | ||
|
||
Golang allows for custom handling of query arguments before they are passed to the driver | ||
with the implementation of a [NamedValueChecker](https://pkg.go.dev/database/sql/driver#NamedValueChecker). By doing a full import of the driver (not by side-effects only), | ||
a custom NamedValueChecker can be implemented. | ||
|
||
```golang | ||
import ( | ||
"database/sql" | ||
|
||
"github.com/go-mysql-org/go-mysql/driver" | ||
) | ||
|
||
func main() { | ||
driver.AddNamedValueChecker(func(nv *sqlDriver.NamedValue) error { | ||
rv := reflect.ValueOf(nv.Value) | ||
if rv.Kind() != reflect.Uint64 { | ||
// fallback to the default value converter when the value is not a uint64 | ||
return sqlDriver.ErrSkip | ||
} | ||
|
||
return nil | ||
}) | ||
|
||
conn, err := sql.Open("mysql", "[email protected]:3306/test") | ||
defer conn.Close() | ||
|
||
stmt, err := conn.Prepare("select * from table where id = ?") | ||
defer stmt.Close() | ||
var val uint64 = math.MaxUint64 | ||
// without the NamedValueChecker this query would fail | ||
result, err := stmt.Query(val) | ||
} | ||
``` | ||
|
||
|
||
We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,11 @@ package driver | |
import ( | ||
"context" | ||
"database/sql" | ||
sqlDriver "database/sql/driver" | ||
"fmt" | ||
"math" | ||
"net" | ||
"reflect" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
@@ -94,16 +97,85 @@ func TestDriverOptions_writeTimeout(t *testing.T) { | |
srv := CreateMockServer(t) | ||
defer srv.Stop() | ||
|
||
// use a writeTimeout that will fail parsing by ParseDuration resulting | ||
// in a conn open error. The Open() won't fail until the ExecContext() | ||
// call though, because that's when golang database/sql package will open | ||
// the actual connection. | ||
conn, err := sql.Open("mysql", "[email protected]:3307/test?writeTimeout=10") | ||
require.NoError(t, err) | ||
require.NotNil(t, conn) | ||
|
||
result, err := conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);") | ||
// here we should fail because of the missing time unit in the duration. | ||
result, err := conn.ExecContext(context.TODO(), "select 1;") | ||
require.Contains(t, err.Error(), "missing unit in duration") | ||
require.Error(t, err) | ||
require.Nil(t, result) | ||
|
||
// use an almost zero (1ns) writeTimeout to ensure the insert statement | ||
// can't write before the timeout. Just want to make sure ExecContext() | ||
// will throw an error. | ||
conn, err = sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1ns") | ||
require.NoError(t, err) | ||
|
||
// ExecContext() should fail due to the write timeout of 1ns | ||
result, err = conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);") | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "i/o timeout") | ||
require.Nil(t, result) | ||
|
||
conn.Close() | ||
} | ||
|
||
func TestDriverOptions_namedValueChecker(t *testing.T) { | ||
AddNamedValueChecker(func(nv *sqlDriver.NamedValue) error { | ||
rv := reflect.ValueOf(nv.Value) | ||
if rv.Kind() != reflect.Uint64 { | ||
// fallback to the default value converter when the value is not a uint64 | ||
return sqlDriver.ErrSkip | ||
} | ||
|
||
return nil | ||
}) | ||
|
||
log.SetLevel(log.LevelDebug) | ||
srv := CreateMockServer(t) | ||
defer srv.Stop() | ||
conn, err := sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1s") | ||
require.NoError(t, err) | ||
defer conn.Close() | ||
|
||
// the NamedValueChecker will return ErrSkip for types that are NOT uint64, so make | ||
// sure those make it to the server ok first. | ||
int32Stmt, err := conn.Prepare("select ?") | ||
require.NoError(t, err) | ||
defer int32Stmt.Close() | ||
r1, err := int32Stmt.Query(math.MaxInt32) | ||
require.NoError(t, err) | ||
require.NotNil(t, r1) | ||
|
||
var i32 int32 | ||
require.True(t, r1.Next()) | ||
require.NoError(t, r1.Scan(&i32)) | ||
require.True(t, math.MaxInt32 == i32) | ||
|
||
// Now make sure that the uint64 makes it to the server as well, this case will be handled | ||
// by the NamedValueChecker (i.e. it will not return ErrSkip) | ||
stmt, err := conn.Prepare("select a, b from fast where uint64 = ?") | ||
require.NoError(t, err) | ||
defer stmt.Close() | ||
|
||
var val uint64 = math.MaxUint64 | ||
result, err := stmt.Query(val) | ||
require.NoError(t, err) | ||
require.NotNil(t, result) | ||
|
||
var a uint64 | ||
var b string | ||
require.True(t, result.Next()) | ||
require.NoError(t, result.Scan(&a, &b)) | ||
require.True(t, math.MaxUint64 == a) | ||
} | ||
|
||
func CreateMockServer(t *testing.T) *testServer { | ||
inMemProvider := server.NewInMemoryProvider() | ||
inMemProvider.AddUser(*testUser, *testPassword) | ||
|
@@ -151,7 +223,7 @@ func (h *mockHandler) UseDB(dbName string) error { | |
return nil | ||
} | ||
|
||
func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { | ||
func (h *mockHandler) handleQuery(query string, binary bool, args []interface{}) (*mysql.Result, error) { | ||
ss := strings.Split(query, " ") | ||
switch strings.ToLower(ss[0]) { | ||
case "select": | ||
|
@@ -163,13 +235,24 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err | |
{mysql.MaxPayloadLen}, | ||
}, binary) | ||
} else { | ||
if strings.Contains(query, "slow") { | ||
time.Sleep(time.Second * 5) | ||
} | ||
if ss[1] == "?" { | ||
r, err = mysql.BuildSimpleResultset([]string{"a"}, [][]interface{}{ | ||
{args[0].(int64)}, | ||
}, binary) | ||
} else { | ||
if strings.Contains(query, "slow") { | ||
time.Sleep(time.Second * 5) | ||
} | ||
|
||
r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ | ||
{1, "hello world"}, | ||
}, binary) | ||
var aValue uint64 = 1 | ||
if strings.Contains(query, "uint64") { | ||
aValue = math.MaxUint64 | ||
} | ||
|
||
r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ | ||
{aValue, "hello world"}, | ||
}, binary) | ||
} | ||
} | ||
|
||
if err != nil { | ||
|
@@ -197,7 +280,7 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err | |
} | ||
|
||
func (h *mockHandler) HandleQuery(query string) (*mysql.Result, error) { | ||
return h.handleQuery(query, false) | ||
return h.handleQuery(query, false, nil) | ||
} | ||
|
||
func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { | ||
|
@@ -206,13 +289,13 @@ func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*my | |
|
||
func (h *mockHandler) HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) { | ||
params = 1 | ||
columns = 0 | ||
columns = 2 | ||
return params, columns, nil, nil | ||
} | ||
|
||
func (h *mockHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) { | ||
if strings.HasPrefix(strings.ToLower(query), "select") { | ||
return h.HandleQuery(query) | ||
return h.handleQuery(query, true, args) | ||
} | ||
|
||
return &mysql.Result{ | ||
|