Skip to content

Commit

Permalink
added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dvilaverde committed Jun 11, 2024
1 parent 22d0716 commit 199a13e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 16 deletions.
4 changes: 3 additions & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,9 @@ func SetDSNOptions(customOptions map[string]DriverOption) {
// AddNamedValueChecker sets a custom NamedValueChecker for the driver connection which
// allows for more control in handling Go and database types beyond the default Value types.
// See https://pkg.go.dev/database/sql/driver#NamedValueChecker
// Usage requires a full import of the driver (not by side-effects only).
// Usage requires a full import of the driver (not by side-effects only). Also note that
// this function is not concurrent-safe, and should only be executed while setting up the driver
// before establishing any connections via `sql.Open()`.
func AddNamedValueChecker(nvCheckFunc ...CheckNamedValueFunc) {
namedValueCheckers = append(namedValueCheckers, nvCheckFunc...)
}
71 changes: 56 additions & 15 deletions driver/driver_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,32 @@ func TestDriverOptions_writeTimeout(t *testing.T) {
srv := CreateMockServer(t)
defer srv.Stop()

conn, err := sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1ns")
// use a writeTimeout that will fail parsing by ParseDuration resulting
// in a conn open error. The Open() won't fail until the ExecContext()
// call though, because that's when golang database/sql package will open
// the actual connection.
conn, err := sql.Open("mysql", "[email protected]:3307/test?writeTimeout=10")
require.NoError(t, err)
require.NotNil(t, conn)

result, err := conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);")
// here we should fail because of the missing time unit in the duration.
result, err := conn.ExecContext(context.TODO(), "select 1;")
require.Contains(t, err.Error(), "missing unit in duration")
require.Error(t, err)
require.Nil(t, result)

// use an almost zero (1ns) writeTimeout to ensure the insert statement
// can't write before the timeout. Just want to make sure ExecContext()
// will throw an error.
conn, err = sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1ns")
require.NoError(t, err)

// ExecContext() should fail due to the write timeout of 1ns
result, err = conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);")
require.Error(t, err)
require.Contains(t, err.Error(), "i/o timeout")
require.Nil(t, result)

conn.Close()
}

Expand All @@ -125,6 +144,22 @@ func TestDriverOptions_namedValueChecker(t *testing.T) {
require.NoError(t, err)
defer conn.Close()

// the NamedValueChecker will return ErrSkip for types that are NOT uint64, so make
// sure those make it to the server ok first.
int32Stmt, err := conn.Prepare("select ?")
require.NoError(t, err)
defer int32Stmt.Close()
r1, err := int32Stmt.Query(math.MaxInt32)
require.NoError(t, err)
require.NotNil(t, r1)

var i32 int32
require.True(t, r1.Next())
require.NoError(t, r1.Scan(&i32))
require.True(t, math.MaxInt32 == i32)

// Now make sure that the uint64 makes it to the server as well, this case will be handled
// by the NamedValueChecker (i.e. it will not return ErrSkip)
stmt, err := conn.Prepare("select a, b from fast where uint64 = ?")
require.NoError(t, err)
defer stmt.Close()
Expand Down Expand Up @@ -188,7 +223,7 @@ func (h *mockHandler) UseDB(dbName string) error {
return nil
}

func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, error) {
func (h *mockHandler) handleQuery(query string, binary bool, args []interface{}) (*mysql.Result, error) {
ss := strings.Split(query, " ")
switch strings.ToLower(ss[0]) {
case "select":
Expand All @@ -200,18 +235,24 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err
{mysql.MaxPayloadLen},
}, binary)
} else {
if strings.Contains(query, "slow") {
time.Sleep(time.Second * 5)
}
if ss[1] == "?" {
r, err = mysql.BuildSimpleResultset([]string{"a"}, [][]interface{}{
{args[0].(int64)},
}, binary)
} else {
if strings.Contains(query, "slow") {
time.Sleep(time.Second * 5)
}

var aValue uint64 = 1
if strings.Contains(query, "uint64") {
aValue = math.MaxUint64
}
var aValue uint64 = 1
if strings.Contains(query, "uint64") {
aValue = math.MaxUint64
}

r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
{aValue, "hello world"},
}, binary)
r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
{aValue, "hello world"},
}, binary)
}
}

if err != nil {
Expand Down Expand Up @@ -239,7 +280,7 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err
}

func (h *mockHandler) HandleQuery(query string) (*mysql.Result, error) {
return h.handleQuery(query, false)
return h.handleQuery(query, false, nil)
}

func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) {
Expand All @@ -254,7 +295,7 @@ func (h *mockHandler) HandleStmtPrepare(query string) (params int, columns int,

func (h *mockHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) {
if strings.HasPrefix(strings.ToLower(query), "select") {
return h.handleQuery(query, true)
return h.handleQuery(query, true, args)
}

return &mysql.Result{
Expand Down

0 comments on commit 199a13e

Please sign in to comment.