Skip to content

Commit

Permalink
fixing bad connection error when reading large compressed packets (#863)
Browse files Browse the repository at this point in the history
* fixing bad connection error when reading large compressed packets

* fixing linting errors

* minor cleanup and some more comments

* minor cleanup and some more comments

* fixing issue when net_buffer_length=1024

* fixing packet reader lookup condition

* handle possible nil access violation when attempting to read next compressed packet

* removed deprecated linters that no longer exist in golangci-lint 1.58.0

* addressing PR feedback

* addressing PR feedback

* removed compressedReaderActive

---------

Co-authored-by: dvilaverde <[email protected]>
Co-authored-by: lance6716 <[email protected]>
  • Loading branch information
3 people authored May 7, 2024
1 parent 0ad0d03 commit 007f306
Showing 1 changed file with 83 additions and 42 deletions.
125 changes: 83 additions & 42 deletions packet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/sha1"
"crypto/x509"
"encoding/pem"
goErrors "errors"
"io"
"net"
"sync"
Expand Down Expand Up @@ -65,8 +66,6 @@ type Conn struct {

compressedHeader [7]byte

compressedReaderActive bool

compressedReader io.Reader
}

Expand Down Expand Up @@ -107,42 +106,17 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
}()

if c.Compression != MYSQL_COMPRESS_NONE {
if !c.compressedReaderActive {
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
}

compressedSequence := c.compressedHeader[3]
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
if compressedSequence != c.CompressedSequence {
return nil, errors.Errorf("invalid compressed sequence %d != %d",
compressedSequence, c.CompressedSequence)
}

if uncompressedLength > 0 {
var err error
switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
c.compressedReader, err = zlib.NewReader(c.reader)
case MYSQL_COMPRESS_ZSTD:
c.compressedReader, err = zstd.NewReader(c.reader)
}
if err != nil {
return nil, err
}
if c.compressedReader == nil {
var err error
c.compressedReader, err = c.newCompressedPacketReader()
if err != nil {
return nil, err
}
c.compressedReaderActive = true
}
}

if c.compressedReader != nil {
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
return nil, errors.Trace(err)
}
} else {
if err := c.ReadPacketTo(buf, c.reader); err != nil {
return nil, errors.Trace(err)
}
if err := c.ReadPacketTo(buf); err != nil {
return nil, errors.Trace(err)
}

readBytes := buf.Bytes()
Expand All @@ -167,22 +141,78 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
return result, nil
}

func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
// newCompressedPacketReader creates a new compressed packet reader.
func (c *Conn) newCompressedPacketReader() (io.Reader, error) {
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
}

compressedSequence := c.compressedHeader[3]
if compressedSequence != c.CompressedSequence {
return nil, errors.Errorf("invalid compressed sequence %d != %d",
compressedSequence, c.CompressedSequence)
}

compressedLength := int(uint32(c.compressedHeader[0]) | uint32(c.compressedHeader[1])<<8 | uint32(c.compressedHeader[2])<<16)
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
if uncompressedLength > 0 {
limitedReader := io.LimitReader(c.reader, int64(compressedLength))
switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
return zlib.NewReader(limitedReader)
case MYSQL_COMPRESS_ZSTD:
return zstd.NewReader(limitedReader)
}
}

return nil, nil
}

func (c *Conn) currentPacketReader() io.Reader {
if c.Compression == MYSQL_COMPRESS_NONE || c.compressedReader == nil {
return c.reader
} else {
return c.compressedReader
}
}

func (c *Conn) copyN(dst io.Writer, n int64) (int64, error) {
var written int64

for n > 0 {
bcap := cap(c.copyNBuf)
if int64(bcap) > n {
bcap = int(n)
}
buf := c.copyNBuf[:bcap]

rd, err := io.ReadAtLeast(src, buf, bcap)
// Call ReadAtLeast with the currentPacketReader as it may change on every iteration
// of this loop.
rd, err := io.ReadAtLeast(c.currentPacketReader(), buf, bcap)

n -= int64(rd)

// ReadAtLeast will return EOF or ErrUnexpectedEOF when fewer than the min
// bytes are read. In this case, and when we have compression then advance
// the sequence number and reset the compressed reader to continue reading
// the remaining bytes in the next compressed packet.
if c.Compression != MYSQL_COMPRESS_NONE &&
(goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF)) {
// we have read to EOF and read an incomplete uncompressed packet
// so advance the compressed sequence number and reset the compressed reader
// to get the remaining unread uncompressed bytes from the next compressed packet.
c.CompressedSequence++
if c.compressedReader, err = c.newCompressedPacketReader(); err != nil {
return written, errors.Trace(err)
}
}

if err != nil {
return written, errors.Trace(err)
}

wr, err := dst.Write(buf)
// careful to only write from the buffer the number of bytes read
wr, err := dst.Write(buf[:rd])
written += int64(wr)
if err != nil {
return written, errors.Trace(err)
Expand All @@ -192,9 +222,21 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
return written, nil
}

func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
if _, err := io.ReadFull(r, c.header[:4]); err != nil {
func (c *Conn) ReadPacketTo(w io.Writer) error {
b := utils.BytesBufferGet()
defer func() {
utils.BytesBufferPut(b)
}()

// packets that come in a compressed packet may be partial
// so use the copyN function to read the packet header into a
// buffer, since copyN is capable of getting the next compressed
// packet and updating the Conn state with a new compressedReader.
if _, err := c.copyN(b, 4); err != nil {
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
} else {
// copy was successful so copy the 4 bytes from the buffer to the header
copy(c.header[:4], b.Bytes()[:4])
}

length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16)
Expand All @@ -211,7 +253,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
buf.Grow(length)
}

if n, err := c.copyN(w, r, int64(length)); err != nil {
if n, err := c.copyN(w, int64(length)); err != nil {
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
} else if n != int64(length) {
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
Expand All @@ -220,7 +262,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
return nil
}

if err = c.ReadPacketTo(w, r); err != nil {
if err = c.ReadPacketTo(w); err != nil {
return errors.Wrap(err, "ReadPacketTo failed")
}
}
Expand Down Expand Up @@ -270,7 +312,6 @@ func (c *Conn) WritePacket(data []byte) error {
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
}
c.compressedReader = nil
c.compressedReaderActive = false
default:
return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set")
}
Expand Down

0 comments on commit 007f306

Please sign in to comment.