diff --git a/.golangci.yml b/.golangci.yml index 67db52444..2919b9b4d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -2,13 +2,10 @@ linters: disable-all: true enable: # All code is ready for: - - deadcode - errcheck - staticcheck - - structcheck - typecheck - unused - - varcheck - misspell - nolintlint - goimports diff --git a/client/auth.go b/client/auth.go index 1f4d7c1de..5a8f3937a 100644 --- a/client/auth.go +++ b/client/auth.go @@ -284,10 +284,6 @@ func (c *Conn) writeAuthHandshake() error { // the 23 bytes of filler is used to send the right middle 8 bits of the collation id. // see https://github.com/mysql/mysql-server/pull/541 data[12] = byte(collation.ID & 0xff) - // if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of - // padding the filler with a 0. If ID is > 255 then the first byte of filler will contain - // the right middle 8 bits of the collation ID. - data[13] = byte((collation.ID & 0xff00) >> 8) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -309,12 +305,8 @@ func (c *Conn) writeAuthHandshake() error { } // Filler [23 bytes] (all 0x00) - // the filler starts at position 13, but the first byte of the filler - // has been set with the collation id earlier, so position 13 at this point - // will be either 0x00, or the right middle 8 bits of the collation id. - // Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00. - pos := 14 - for ; pos < 14+22; pos++ { + pos := 13 + for ; pos < 13+23; pos++ { data[pos] = 0 } diff --git a/client/auth_test.go b/client/auth_test.go index 0837f1767..49ae7d2da 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -66,16 +66,16 @@ func TestConnCollation(t *testing.T) { // if the collation ID is <= 255 the collation ID is stored in the 12th byte if collation.ID <= 255 { require.Equal(t, byte(collation.ID), handShakeResponse[12]) - // the 13th byte should always be 0x00 - require.Equal(t, byte(0x00), handShakeResponse[13]) } else { - // if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes + // if the collation ID is > 255 the collation ID should just be the lower-8 bits require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12]) - require.Equal(t, byte(collation.ID>>8), handShakeResponse[13]) } + // the 13th byte should always be 0x00 + require.Equal(t, byte(0x00), handShakeResponse[13]) + // sanity check: validate the 22 bytes of filler with value 0x00 are set correctly - for i := 14; i < 14+22; i++ { + for i := 13; i < 13+23; i++ { require.Equal(t, byte(0x00), handShakeResponse[i]) } diff --git a/client/conn.go b/client/conn.go index 9fc7faf16..bef9b2de9 100644 --- a/client/conn.go +++ b/client/conn.go @@ -11,6 +11,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/charset" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" @@ -133,6 +134,21 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st c.Conn.Compression = MYSQL_COMPRESS_ZSTD } + // if a collation was set with a ID of > 255, then we need to call SET NAMES ... + // since the auth handshake response only support collations with 1-byte ids + if len(c.collation) != 0 { + collation, err := charset.GetCollationByName(c.collation) + if err != nil { + return nil, errors.Trace(fmt.Errorf("invalid collation name %s", c.collation)) + } + + if collation.ID > 255 { + if _, err := c.exec(fmt.Sprintf("SET NAMES %s COLLATE %s", c.charset, c.collation)); err != nil { + return nil, errors.Trace(err) + } + } + } + return c, nil }