diff --git a/AUTHORS b/AUTHORS index 63ee516e5..4021b96cc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -132,6 +132,7 @@ GitHub Inc. Google Inc. InfoSum Ltd. Keybase Inc. +Microsoft Corp. Multiplay Ltd. Percona LLC PingCAP Inc. diff --git a/connector.go b/connector.go index 3cef7963f..a0ee62839 100644 --- a/connector.go +++ b/connector.go @@ -66,12 +66,22 @@ func newConnector(cfg *Config) *connector { func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { var err error + // Invoke beforeConnect if present, with a copy of the configuration + cfg := c.cfg + if c.cfg.beforeConnect != nil { + cfg = c.cfg.Clone() + err = c.cfg.beforeConnect(ctx, cfg) + if err != nil { + return nil, err + } + } + // New mysqlConn mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), - cfg: c.cfg, + cfg: cfg, connector: c, } mc.parseTime = mc.cfg.ParseTime diff --git a/driver_test.go b/driver_test.go index 5934caab6..001957244 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2044,6 +2044,40 @@ func TestCustomDial(t *testing.T) { } } +func TestBeforeConnect(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // dbname is set in the BeforeConnect handle + cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_")) + if err != nil { + t.Fatalf("error parsing DSN: %v", err) + } + + cfg.Apply(BeforeConnect(func(ctx context.Context, c *Config) error { + c.DBName = dbname + return nil + })) + + connector, err := NewConnector(cfg) + if err != nil { + t.Fatalf("error creating connector: %v", err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + var connectedDb string + err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb) + if err != nil { + t.Fatalf("error executing query: %v", err) + } + if connectedDb != dbname { + t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb) + } +} + func TestSQLInjection(t *testing.T) { createTest := func(arg string) func(dbt *DBTest) { return func(dbt *DBTest) { diff --git a/dsn.go b/dsn.go index d0fbf3bd9..65f5a0242 100644 --- a/dsn.go +++ b/dsn.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "context" "crypto/rsa" "crypto/tls" "errors" @@ -71,8 +72,9 @@ type Config struct { // unexported fields. new options should be come here - pubKey *rsa.PublicKey // Server public key - timeTruncate time.Duration // Truncate time.Time values to the specified duration + beforeConnect func(context.Context, *Config) error // Invoked before a connection is established + pubKey *rsa.PublicKey // Server public key + timeTruncate time.Duration // Truncate time.Time values to the specified duration } // Functional Options Pattern @@ -112,6 +114,14 @@ func TimeTruncate(d time.Duration) Option { } } +// BeforeConnect sets the function to be invoked before a connection is established. +func BeforeConnect(fn func(context.Context, *Config) error) Option { + return func(cfg *Config) error { + cfg.beforeConnect = fn + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil {