diff --git a/client/client_test.go b/client/client_test.go index 10515e622..aaf72ff42 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -89,6 +89,17 @@ func (s *clientTestSuite) TestConn_Ping() { require.NoError(s.T(), err) } +func (s *clientTestSuite) TestConn_Compress() { + addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) + conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + conn.SetCapability(mysql.CLIENT_COMPRESS) + }) + require.NoError(s.T(), err) + + _, err = conn.Execute("SELECT VERSION()") + require.NoError(s.T(), err) +} + func (s *clientTestSuite) TestConn_SetCapability() { caps := []uint32{ mysql.CLIENT_LONG_PASSWORD, diff --git a/packet/conn.go b/packet/conn.go index 7250623fa..6096d4f06 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -67,6 +67,8 @@ type Conn struct { compressedHeader [7]byte compressedReader io.Reader + + compressedReaderActive bool } func NewConn(conn net.Conn) *Conn { @@ -106,12 +108,19 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { }() if c.Compression != MYSQL_COMPRESS_NONE { - if c.compressedReader == nil { + // it's possible that we're using compression but the server response with a compressed + // packet with uncompressed length of 0. In this case we leave compressedReader nil. The + // compressedReaderActive flag is important to track the state of the reader, allowing + // for the compressedReader to be reset after a packet write. Without this flag, when a + // compressed packet with uncompressed length of 0 is read, the compressedReader would + // be nil, and we'd incorrectly attempt to read the next packet as compressed. + if !c.compressedReaderActive { var err error c.compressedReader, err = c.newCompressedPacketReader() if err != nil { return nil, err } + c.compressedReaderActive = true } } @@ -312,6 +321,7 @@ func (c *Conn) WritePacket(data []byte) error { return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) } c.compressedReader = nil + c.compressedReaderActive = false default: return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set") }