diff --git a/README.md b/README.md index d5cd6e4a5..a9ef516fd 100644 --- a/README.md +++ b/README.md @@ -426,6 +426,18 @@ golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. | --------- | --------- | ----------------------------------------------- | | duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s | +#### `retries` + +Allows disabling the golang `database/sql` default behavior to retry errors +when `ErrBadConn` is returned by the driver. When retries are disabled +this driver will not return `ErrBadConn` from the `database/sql` package. + +Valid values are `on` (default) and `off`. + +| Type | Default | Example | +| --------- | --------- | ----------------------------------------------- | +| string | on | user:pass@localhost/mydb?retries=off | + ### Custom Driver Options The driver package exposes the function `SetDSNOptions`, allowing for modification of the diff --git a/driver/driver.go b/driver/driver.go index d2343fdad..94ebabbf0 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -97,7 +97,11 @@ func parseDSN(dsn string) (connInfo, error) { // Open takes a supplied DSN string and opens a connection // See ParseDSN for more information on the form of the DSN func (d driver) Open(dsn string) (sqldriver.Conn, error) { - var c *client.Conn + var ( + c *client.Conn + // by default database/sql driver retries will be enabled + retries = true + ) ci, err := parseDSN(dsn) @@ -134,6 +138,10 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { if timeout, err = time.ParseDuration(value[0]); err != nil { return nil, errors.Wrap(err, "invalid duration value for timeout option") } + } else if key == "retries" && len(value) > 0 { + // by default keep the golang database/sql retry behavior enabled unless + // the retries driver option is explicitly set to 'off' + retries = !strings.EqualFold(value[0], "off") } else { if option, ok := options[key]; ok { opt := func(o DriverOption, v string) client.Option { @@ -161,15 +169,28 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { return nil, err } - return &conn{c}, nil + // if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3 + // retries by the database/sql package. If retries are 'off' then we'll return + // the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry. + // In this case the sqldriver.Validator interface is implemented and will return + // false for IsValid() signaling the connection is bad and should be discarded. + return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil } type CheckNamedValueFunc func(*sqldriver.NamedValue) error var _ sqldriver.NamedValueChecker = &conn{} +var _ sqldriver.Validator = &conn{} + +type state struct { + valid bool + // when true, the driver connection will return ErrBadConn from the golang Standard Library + useStdLibErrors bool +} type conn struct { *client.Conn + state *state } func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error { @@ -190,13 +211,17 @@ func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error { return sqldriver.ErrSkip } +func (c *conn) IsValid() bool { + return c.state.valid +} + func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { st, err := c.Conn.Prepare(query) if err != nil { return nil, errors.Trace(err) } - return &stmt{st}, nil + return &stmt{Stmt: st, connectionState: c.state}, nil } func (c *conn) Close() error { @@ -222,10 +247,16 @@ func buildArgs(args []sqldriver.Value) []interface{} { return a } -func replyError(err error) error { - if mysql.ErrorEqual(err, mysql.ErrBadConn) { +func (st *state) replyError(err error) error { + isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn) + + if st.useStdLibErrors && isBadConnection { return sqldriver.ErrBadConn } else { + // if we have a bad connection, this mark the state of this connection as not valid + // do the database/sql package can discard it instead of placing it back in the + // sql.DB pool. + st.valid = !isBadConnection return errors.Trace(err) } } @@ -234,7 +265,7 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err a := buildArgs(args) r, err := c.Conn.Execute(query, a...) if err != nil { - return nil, replyError(err) + return nil, c.state.replyError(err) } return &result{r}, nil } @@ -243,13 +274,14 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro a := buildArgs(args) r, err := c.Conn.Execute(query, a...) if err != nil { - return nil, replyError(err) + return nil, c.state.replyError(err) } return newRows(r.Resultset) } type stmt struct { *client.Stmt + connectionState *state } func (s *stmt) Close() error { @@ -264,7 +296,7 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) { a := buildArgs(args) r, err := s.Stmt.Execute(a...) if err != nil { - return nil, replyError(err) + return nil, s.connectionState.replyError(err) } return &result{r}, nil } @@ -273,7 +305,7 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) { a := buildArgs(args) r, err := s.Stmt.Execute(a...) if err != nil { - return nil, replyError(err) + return nil, s.connectionState.replyError(err) } return newRows(r.Resultset) } diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index e0a9820a8..eaa863fbc 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -31,6 +31,51 @@ type testServer struct { } type mockHandler struct { + // the number of times a query executed + queryCount int +} + +func TestDriverOptions_SetRetriesOn(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=1s") + defer func() { + _ = conn.Close() + }() + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from slow;") + require.Nil(t, rows) + + // we want to get a golang database/sql/driver ErrBadConn + require.ErrorIs(t, err, sqlDriver.ErrBadConn) + + // here we issue assert that even though we only issued 1 query, that the retries + // remained on and there were 3 calls to the DB. + require.Equal(t, 3, srv.handler.queryCount) +} + +func TestDriverOptions_SetRetriesOff(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=1s&retries=off") + defer func() { + _ = conn.Close() + }() + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from slow;") + require.Nil(t, rows) + // we want the native error from this driver implementation + require.ErrorIs(t, err, mysql.ErrBadConn) + + // here we issue assert that even though we only issued 1 query, that the retries + // remained on and there were 3 calls to the DB. + require.Equal(t, 1, srv.handler.queryCount) } func TestDriverOptions_SetCollation(t *testing.T) { @@ -65,6 +110,9 @@ func TestDriverOptions_ConnectTimeout(t *testing.T) { defer srv.Stop() conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?timeout=1s") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) rows, err := conn.QueryContext(context.TODO(), "select * from table;") @@ -88,6 +136,9 @@ func TestDriverOptions_BufferSize(t *testing.T) { }) conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?bufferSize=4096") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) rows, err := conn.QueryContext(context.TODO(), "select * from table;") @@ -103,6 +154,9 @@ func TestDriverOptions_ReadTimeout(t *testing.T) { defer srv.Stop() conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=1s") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) rows, err := conn.QueryContext(context.TODO(), "select * from slow;") @@ -134,11 +188,15 @@ func TestDriverOptions_writeTimeout(t *testing.T) { require.Contains(t, err.Error(), "missing unit in duration") require.Error(t, err) require.Nil(t, result) + require.NoError(t, conn.Close()) // use an almost zero (1ns) writeTimeout to ensure the insert statement // can't write before the timeout. Just want to make sure ExecContext() // will throw an error. conn, err = sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1ns") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) // ExecContext() should fail due to the write timeout of 1ns @@ -165,6 +223,9 @@ func TestDriverOptions_namedValueChecker(t *testing.T) { srv := CreateMockServer(t) defer srv.Stop() conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1s") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) defer conn.Close() @@ -248,6 +309,7 @@ func (h *mockHandler) UseDB(dbName string) error { } func (h *mockHandler) handleQuery(query string, binary bool, args []interface{}) (*mysql.Result, error) { + h.queryCount++ ss := strings.Split(query, " ") switch strings.ToLower(ss[0]) { case "select":