From ee48e789c8edea845522cdab60303e63f61f45be Mon Sep 17 00:00:00 2001 From: David Vilaverde Date: Wed, 24 Jul 2024 22:12:46 -0400 Subject: [PATCH] allow disabling the default golang database/sql retry behavior (#899) * allow disabling the default golang database retry behavior * fixing comment * fixing comment * fix(canal): handle fake rotate events correctly for MariaDB 11.4 (#894) After upgrading to MariaDB 11.4, the canal module stopped detecting row updates within transactions due to incorrect handling of fake rotate events. MariaDB 11.4 does not set LogPos for certain events, causing these events to be ignored. This fix modifies the handling to consider fake rotate events only for ROTATE_EVENTs with timestamp = 0, aligning with MariaDB and MySQL documentation. * incorporating PR feedback --------- Co-authored-by: Bulat Aikaev --- README.md | 12 +++++++ driver/driver.go | 50 +++++++++++++++++++++++----- driver/driver_options_test.go | 62 +++++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index d5cd6e4a5..a9ef516fd 100644 --- a/README.md +++ b/README.md @@ -426,6 +426,18 @@ golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. | --------- | --------- | ----------------------------------------------- | | duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s | +#### `retries` + +Allows disabling the golang `database/sql` default behavior to retry errors +when `ErrBadConn` is returned by the driver. When retries are disabled +this driver will not return `ErrBadConn` from the `database/sql` package. + +Valid values are `on` (default) and `off`. + +| Type | Default | Example | +| --------- | --------- | ----------------------------------------------- | +| string | on | user:pass@localhost/mydb?retries=off | + ### Custom Driver Options The driver package exposes the function `SetDSNOptions`, allowing for modification of the diff --git a/driver/driver.go b/driver/driver.go index d2343fdad..94ebabbf0 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -97,7 +97,11 @@ func parseDSN(dsn string) (connInfo, error) { // Open takes a supplied DSN string and opens a connection // See ParseDSN for more information on the form of the DSN func (d driver) Open(dsn string) (sqldriver.Conn, error) { - var c *client.Conn + var ( + c *client.Conn + // by default database/sql driver retries will be enabled + retries = true + ) ci, err := parseDSN(dsn) @@ -134,6 +138,10 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { if timeout, err = time.ParseDuration(value[0]); err != nil { return nil, errors.Wrap(err, "invalid duration value for timeout option") } + } else if key == "retries" && len(value) > 0 { + // by default keep the golang database/sql retry behavior enabled unless + // the retries driver option is explicitly set to 'off' + retries = !strings.EqualFold(value[0], "off") } else { if option, ok := options[key]; ok { opt := func(o DriverOption, v string) client.Option { @@ -161,15 +169,28 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { return nil, err } - return &conn{c}, nil + // if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3 + // retries by the database/sql package. If retries are 'off' then we'll return + // the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry. + // In this case the sqldriver.Validator interface is implemented and will return + // false for IsValid() signaling the connection is bad and should be discarded. + return &conn{Conn: c, state: &state{valid: true, useStdLibErrors: retries}}, nil } type CheckNamedValueFunc func(*sqldriver.NamedValue) error var _ sqldriver.NamedValueChecker = &conn{} +var _ sqldriver.Validator = &conn{} + +type state struct { + valid bool + // when true, the driver connection will return ErrBadConn from the golang Standard Library + useStdLibErrors bool +} type conn struct { *client.Conn + state *state } func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error { @@ -190,13 +211,17 @@ func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error { return sqldriver.ErrSkip } +func (c *conn) IsValid() bool { + return c.state.valid +} + func (c *conn) Prepare(query string) (sqldriver.Stmt, error) { st, err := c.Conn.Prepare(query) if err != nil { return nil, errors.Trace(err) } - return &stmt{st}, nil + return &stmt{Stmt: st, connectionState: c.state}, nil } func (c *conn) Close() error { @@ -222,10 +247,16 @@ func buildArgs(args []sqldriver.Value) []interface{} { return a } -func replyError(err error) error { - if mysql.ErrorEqual(err, mysql.ErrBadConn) { +func (st *state) replyError(err error) error { + isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn) + + if st.useStdLibErrors && isBadConnection { return sqldriver.ErrBadConn } else { + // if we have a bad connection, this mark the state of this connection as not valid + // do the database/sql package can discard it instead of placing it back in the + // sql.DB pool. + st.valid = !isBadConnection return errors.Trace(err) } } @@ -234,7 +265,7 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err a := buildArgs(args) r, err := c.Conn.Execute(query, a...) if err != nil { - return nil, replyError(err) + return nil, c.state.replyError(err) } return &result{r}, nil } @@ -243,13 +274,14 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro a := buildArgs(args) r, err := c.Conn.Execute(query, a...) if err != nil { - return nil, replyError(err) + return nil, c.state.replyError(err) } return newRows(r.Resultset) } type stmt struct { *client.Stmt + connectionState *state } func (s *stmt) Close() error { @@ -264,7 +296,7 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) { a := buildArgs(args) r, err := s.Stmt.Execute(a...) if err != nil { - return nil, replyError(err) + return nil, s.connectionState.replyError(err) } return &result{r}, nil } @@ -273,7 +305,7 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) { a := buildArgs(args) r, err := s.Stmt.Execute(a...) if err != nil { - return nil, replyError(err) + return nil, s.connectionState.replyError(err) } return newRows(r.Resultset) } diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index e0a9820a8..eaa863fbc 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -31,6 +31,51 @@ type testServer struct { } type mockHandler struct { + // the number of times a query executed + queryCount int +} + +func TestDriverOptions_SetRetriesOn(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=1s") + defer func() { + _ = conn.Close() + }() + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from slow;") + require.Nil(t, rows) + + // we want to get a golang database/sql/driver ErrBadConn + require.ErrorIs(t, err, sqlDriver.ErrBadConn) + + // here we issue assert that even though we only issued 1 query, that the retries + // remained on and there were 3 calls to the DB. + require.Equal(t, 3, srv.handler.queryCount) +} + +func TestDriverOptions_SetRetriesOff(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=1s&retries=off") + defer func() { + _ = conn.Close() + }() + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from slow;") + require.Nil(t, rows) + // we want the native error from this driver implementation + require.ErrorIs(t, err, mysql.ErrBadConn) + + // here we issue assert that even though we only issued 1 query, that the retries + // remained on and there were 3 calls to the DB. + require.Equal(t, 1, srv.handler.queryCount) } func TestDriverOptions_SetCollation(t *testing.T) { @@ -65,6 +110,9 @@ func TestDriverOptions_ConnectTimeout(t *testing.T) { defer srv.Stop() conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?timeout=1s") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) rows, err := conn.QueryContext(context.TODO(), "select * from table;") @@ -88,6 +136,9 @@ func TestDriverOptions_BufferSize(t *testing.T) { }) conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?bufferSize=4096") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) rows, err := conn.QueryContext(context.TODO(), "select * from table;") @@ -103,6 +154,9 @@ func TestDriverOptions_ReadTimeout(t *testing.T) { defer srv.Stop() conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=1s") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) rows, err := conn.QueryContext(context.TODO(), "select * from slow;") @@ -134,11 +188,15 @@ func TestDriverOptions_writeTimeout(t *testing.T) { require.Contains(t, err.Error(), "missing unit in duration") require.Error(t, err) require.Nil(t, result) + require.NoError(t, conn.Close()) // use an almost zero (1ns) writeTimeout to ensure the insert statement // can't write before the timeout. Just want to make sure ExecContext() // will throw an error. conn, err = sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1ns") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) // ExecContext() should fail due to the write timeout of 1ns @@ -165,6 +223,9 @@ func TestDriverOptions_namedValueChecker(t *testing.T) { srv := CreateMockServer(t) defer srv.Stop() conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1s") + defer func() { + _ = conn.Close() + }() require.NoError(t, err) defer conn.Close() @@ -248,6 +309,7 @@ func (h *mockHandler) UseDB(dbName string) error { } func (h *mockHandler) handleQuery(query string, binary bool, args []interface{}) (*mysql.Result, error) { + h.queryCount++ ss := strings.Split(query, " ") switch strings.ToLower(ss[0]) { case "select":