Skip to content

Commit

Permalink
add support for driver.NamedValueChecker on driver connection (#887)
Browse files Browse the repository at this point in the history
* add support for driver.NamedValueChecker on driver connection

* added more tests
  • Loading branch information
dvilaverde authored Jun 12, 2024
1 parent b13191f commit c607c3c
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 11 deletions.
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,41 @@ func main() {
}
```

### Custom NamedValueChecker

Golang allows for custom handling of query arguments before they are passed to the driver
with the implementation of a [NamedValueChecker](https://pkg.go.dev/database/sql/driver#NamedValueChecker). By doing a full import of the driver (not by side-effects only),
a custom NamedValueChecker can be implemented.

```golang
import (
"database/sql"

"github.com/go-mysql-org/go-mysql/driver"
)

func main() {
driver.AddNamedValueChecker(func(nv *sqlDriver.NamedValue) error {
rv := reflect.ValueOf(nv.Value)
if rv.Kind() != reflect.Uint64 {
// fallback to the default value converter when the value is not a uint64
return sqlDriver.ErrSkip
}

return nil
})

conn, err := sql.Open("mysql", "[email protected]:3306/test")
defer conn.Close()

stmt, err := conn.Prepare("select * from table where id = ?")
defer stmt.Close()
var val uint64 = math.MaxUint64
// without the NamedValueChecker this query would fail
result, err := stmt.Query(val)
}
```


We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-)

Expand Down
37 changes: 37 additions & 0 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/tls"
"database/sql"
sqldriver "database/sql/driver"
goErrors "errors"
"fmt"
"io"
"net/url"
Expand All @@ -25,6 +26,10 @@ var customTLSMutex sync.Mutex
var (
customTLSConfigMap = make(map[string]*tls.Config)
options = make(map[string]DriverOption)

// can be provided by clients to allow more control in handling Go and database
// types beyond the default Value types allowed
namedValueCheckers []CheckNamedValueFunc
)

type driver struct {
Expand Down Expand Up @@ -154,10 +159,32 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
return &conn{c}, nil
}

type CheckNamedValueFunc func(*sqldriver.NamedValue) error

var _ sqldriver.NamedValueChecker = &conn{}

type conn struct {
*client.Conn
}

func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
for _, nvChecker := range namedValueCheckers {
err := nvChecker(nv)
if err == nil {
// we've found a CheckNamedValueFunc that handled this named value
// no need to keep looking
return nil
} else {
// we've found an error, if the error is driver.ErrSkip then
// keep looking otherwise return the unknown error
if !goErrors.Is(sqldriver.ErrSkip, err) {
return err
}
}
}
return sqldriver.ErrSkip
}

func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
st, err := c.Conn.Prepare(query)
if err != nil {
Expand Down Expand Up @@ -375,3 +402,13 @@ func SetDSNOptions(customOptions map[string]DriverOption) {
options[o] = f
}
}

// AddNamedValueChecker sets a custom NamedValueChecker for the driver connection which
// allows for more control in handling Go and database types beyond the default Value types.
// See https://pkg.go.dev/database/sql/driver#NamedValueChecker
// Usage requires a full import of the driver (not by side-effects only). Also note that
// this function is not concurrent-safe, and should only be executed while setting up the driver
// before establishing any connections via `sql.Open()`.
func AddNamedValueChecker(nvCheckFunc ...CheckNamedValueFunc) {
namedValueCheckers = append(namedValueCheckers, nvCheckFunc...)
}
105 changes: 94 additions & 11 deletions driver/driver_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package driver
import (
"context"
"database/sql"
sqlDriver "database/sql/driver"
"fmt"
"math"
"net"
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -94,16 +97,85 @@ func TestDriverOptions_writeTimeout(t *testing.T) {
srv := CreateMockServer(t)
defer srv.Stop()

// use a writeTimeout that will fail parsing by ParseDuration resulting
// in a conn open error. The Open() won't fail until the ExecContext()
// call though, because that's when golang database/sql package will open
// the actual connection.
conn, err := sql.Open("mysql", "[email protected]:3307/test?writeTimeout=10")
require.NoError(t, err)
require.NotNil(t, conn)

result, err := conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);")
// here we should fail because of the missing time unit in the duration.
result, err := conn.ExecContext(context.TODO(), "select 1;")
require.Contains(t, err.Error(), "missing unit in duration")
require.Error(t, err)
require.Nil(t, result)

// use an almost zero (1ns) writeTimeout to ensure the insert statement
// can't write before the timeout. Just want to make sure ExecContext()
// will throw an error.
conn, err = sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1ns")
require.NoError(t, err)

// ExecContext() should fail due to the write timeout of 1ns
result, err = conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);")
require.Error(t, err)
require.Contains(t, err.Error(), "i/o timeout")
require.Nil(t, result)

conn.Close()
}

func TestDriverOptions_namedValueChecker(t *testing.T) {
AddNamedValueChecker(func(nv *sqlDriver.NamedValue) error {
rv := reflect.ValueOf(nv.Value)
if rv.Kind() != reflect.Uint64 {
// fallback to the default value converter when the value is not a uint64
return sqlDriver.ErrSkip
}

return nil
})

log.SetLevel(log.LevelDebug)
srv := CreateMockServer(t)
defer srv.Stop()
conn, err := sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1s")
require.NoError(t, err)
defer conn.Close()

// the NamedValueChecker will return ErrSkip for types that are NOT uint64, so make
// sure those make it to the server ok first.
int32Stmt, err := conn.Prepare("select ?")
require.NoError(t, err)
defer int32Stmt.Close()
r1, err := int32Stmt.Query(math.MaxInt32)
require.NoError(t, err)
require.NotNil(t, r1)

var i32 int32
require.True(t, r1.Next())
require.NoError(t, r1.Scan(&i32))
require.True(t, math.MaxInt32 == i32)

// Now make sure that the uint64 makes it to the server as well, this case will be handled
// by the NamedValueChecker (i.e. it will not return ErrSkip)
stmt, err := conn.Prepare("select a, b from fast where uint64 = ?")
require.NoError(t, err)
defer stmt.Close()

var val uint64 = math.MaxUint64
result, err := stmt.Query(val)
require.NoError(t, err)
require.NotNil(t, result)

var a uint64
var b string
require.True(t, result.Next())
require.NoError(t, result.Scan(&a, &b))
require.True(t, math.MaxUint64 == a)
}

func CreateMockServer(t *testing.T) *testServer {
inMemProvider := server.NewInMemoryProvider()
inMemProvider.AddUser(*testUser, *testPassword)
Expand Down Expand Up @@ -151,7 +223,7 @@ func (h *mockHandler) UseDB(dbName string) error {
return nil
}

func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, error) {
func (h *mockHandler) handleQuery(query string, binary bool, args []interface{}) (*mysql.Result, error) {
ss := strings.Split(query, " ")
switch strings.ToLower(ss[0]) {
case "select":
Expand All @@ -163,13 +235,24 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err
{mysql.MaxPayloadLen},
}, binary)
} else {
if strings.Contains(query, "slow") {
time.Sleep(time.Second * 5)
}
if ss[1] == "?" {
r, err = mysql.BuildSimpleResultset([]string{"a"}, [][]interface{}{
{args[0].(int64)},
}, binary)
} else {
if strings.Contains(query, "slow") {
time.Sleep(time.Second * 5)
}

r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
{1, "hello world"},
}, binary)
var aValue uint64 = 1
if strings.Contains(query, "uint64") {
aValue = math.MaxUint64
}

r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
{aValue, "hello world"},
}, binary)
}
}

if err != nil {
Expand Down Expand Up @@ -197,7 +280,7 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err
}

func (h *mockHandler) HandleQuery(query string) (*mysql.Result, error) {
return h.handleQuery(query, false)
return h.handleQuery(query, false, nil)
}

func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) {
Expand All @@ -206,13 +289,13 @@ func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*my

func (h *mockHandler) HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) {
params = 1
columns = 0
columns = 2
return params, columns, nil, nil
}

func (h *mockHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) {
if strings.HasPrefix(strings.ToLower(query), "select") {
return h.HandleQuery(query)
return h.handleQuery(query, true, args)
}

return &mysql.Result{
Expand Down

0 comments on commit c607c3c

Please sign in to comment.