From 22d07160699ea9b0ca5ef6202cebe86ac4a62fee Mon Sep 17 00:00:00 2001 From: David Vilaverde Date: Fri, 7 Jun 2024 07:40:41 -0400 Subject: [PATCH] add support for driver.NamedValueChecker on driver connection --- README.md | 35 +++++++++++++++++++++++ driver/driver.go | 35 +++++++++++++++++++++++ driver/driver_options_test.go | 52 +++++++++++++++++++++++++++++++---- 3 files changed, 117 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e108469e0..b3042e7e7 100644 --- a/README.md +++ b/README.md @@ -448,6 +448,41 @@ func main() { } ``` +### Custom NamedValueChecker + +Golang allows for custom handling of query arguments before they are passed to the driver +with the implementation of a [NamedValueChecker](https://pkg.go.dev/database/sql/driver#NamedValueChecker). By doing a full import of the driver (not by side-effects only), +a custom NamedValueChecker can be implemented. + +```golang +import ( + "database/sql" + + "github.com/go-mysql-org/go-mysql/driver" +) + +func main() { + driver.AddNamedValueChecker(func(nv *sqlDriver.NamedValue) error { + rv := reflect.ValueOf(nv.Value) + if rv.Kind() != reflect.Uint64 { + // fallback to the default value converter when the value is not a uint64 + return sqlDriver.ErrSkip + } + + return nil + }) + + conn, err := sql.Open("mysql", "root@127.0.0.1:3306/test") + defer conn.Close() + + stmt, err := conn.Prepare("select * from table where id = ?") + defer stmt.Close() + var val uint64 = math.MaxUint64 + // without the NamedValueChecker this query would fail + result, err := stmt.Query(val) +} +``` + We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-) diff --git a/driver/driver.go b/driver/driver.go index 8f132d2b3..7ecd64419 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "database/sql" sqldriver "database/sql/driver" + goErrors "errors" "fmt" "io" "net/url" @@ -25,6 +26,10 @@ var customTLSMutex sync.Mutex var ( customTLSConfigMap = make(map[string]*tls.Config) options = make(map[string]DriverOption) + + // can be provided by clients to allow more control in handling Go and database + // types beyond the default Value types allowed + namedValueCheckers []CheckNamedValueFunc ) type driver struct { @@ -154,10 +159,32 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { return &conn{c}, nil } +type CheckNamedValueFunc func(*sqldriver.NamedValue) error + +var _ sqldriver.NamedValueChecker = &conn{} + type conn struct { *client.Conn } +func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error { + for _, nvChecker := range namedValueCheckers { + err := nvChecker(nv) + if err == nil { + // we've found a CheckNamedValueFunc that handled this named value + // no need to keep looking + return nil + } else { + // we've found an error, if the error is driver.ErrSkip then + // keep looking otherwise return the unknown error + if !goErrors.Is(sqldriver.ErrSkip, err) { + return err + } + } + } + return sqldriver.ErrSkip +} + func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { st, err := c.Conn.Prepare(query) if err != nil { @@ -375,3 +402,11 @@ func SetDSNOptions(customOptions map[string]DriverOption) { options[o] = f } } + +// AddNamedValueChecker sets a custom NamedValueChecker for the driver connection which +// allows for more control in handling Go and database types beyond the default Value types. +// See https://pkg.go.dev/database/sql/driver#NamedValueChecker +// Usage requires a full import of the driver (not by side-effects only). +func AddNamedValueChecker(nvCheckFunc ...CheckNamedValueFunc) { + namedValueCheckers = append(namedValueCheckers, nvCheckFunc...) +} diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index 32431932a..90d14f945 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -3,8 +3,11 @@ package driver import ( "context" "database/sql" + sqlDriver "database/sql/driver" "fmt" + "math" "net" + "reflect" "strings" "testing" "time" @@ -94,16 +97,50 @@ func TestDriverOptions_writeTimeout(t *testing.T) { srv := CreateMockServer(t) defer srv.Stop() - conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=10") + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1ns") require.NoError(t, err) result, err := conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);") - require.Nil(t, result) require.Error(t, err) + require.Nil(t, result) conn.Close() } +func TestDriverOptions_namedValueChecker(t *testing.T) { + AddNamedValueChecker(func(nv *sqlDriver.NamedValue) error { + rv := reflect.ValueOf(nv.Value) + if rv.Kind() != reflect.Uint64 { + // fallback to the default value converter when the value is not a uint64 + return sqlDriver.ErrSkip + } + + return nil + }) + + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1s") + require.NoError(t, err) + defer conn.Close() + + stmt, err := conn.Prepare("select a, b from fast where uint64 = ?") + require.NoError(t, err) + defer stmt.Close() + + var val uint64 = math.MaxUint64 + result, err := stmt.Query(val) + require.NoError(t, err) + require.NotNil(t, result) + + var a uint64 + var b string + require.True(t, result.Next()) + require.NoError(t, result.Scan(&a, &b)) + require.True(t, math.MaxUint64 == a) +} + func CreateMockServer(t *testing.T) *testServer { inMemProvider := server.NewInMemoryProvider() inMemProvider.AddUser(*testUser, *testPassword) @@ -167,8 +204,13 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err time.Sleep(time.Second * 5) } + var aValue uint64 = 1 + if strings.Contains(query, "uint64") { + aValue = math.MaxUint64 + } + r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ - {1, "hello world"}, + {aValue, "hello world"}, }, binary) } @@ -206,13 +248,13 @@ func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*my func (h *mockHandler) HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) { params = 1 - columns = 0 + columns = 2 return params, columns, nil, nil } func (h *mockHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) { if strings.HasPrefix(strings.ToLower(query), "select") { - return h.HandleQuery(query) + return h.handleQuery(query, true) } return &mysql.Result{