SNOW-1029631: Optimize streaming PUT memory usage #1266

Draft · wants to merge 6 commits into base: master
3 changes: 2 additions & 1 deletion bind_uploader.go
@@ -96,7 +96,8 @@ func (bu *bindUploader) uploadStreamInternal(
 	// prepare context for PUT command
 	ctx := WithFileStream(bu.ctx, inputStream)
 	ctx = WithFileTransferOptions(ctx, &SnowflakeFileTransferOptions{
-		compressSourceFromStream: compressData})
+		compressSourceFromStream: compressData,
+		arrayBindFromStream:      true})
 	return bu.sc.exec(ctx, putCommand, false, true, false, []driver.NamedValue{})
 }

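For orientation, the new arrayBindFromStream flag rides on the same context plumbing a caller uses for a streaming PUT. A minimal sketch of that caller-side flow, assuming a *sql.DB opened with this driver; the stage and file name are illustrative:

package main

import (
    "context"
    "database/sql"
    "strings"

    sf "github.com/snowflakedb/gosnowflake"
)

// streamingPut uploads an in-memory payload to the user stage. The reader
// travels through the context; the path in the PUT statement only names the
// file as it will appear on the stage.
func streamingPut(db *sql.DB) error {
    data := strings.NewReader("1,foo\n2,bar\n")
    ctx := sf.WithFileStream(context.Background(), data)
    _, err := db.ExecContext(ctx, "PUT file:///tmp/data.csv @~ AUTO_COMPRESS=TRUE")
    return err
}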
50 changes: 37 additions & 13 deletions connection_util.go
@@ -95,6 +95,9 @@ func (sc *snowflakeConn) processFileTransfer(
 	options := &SnowflakeFileTransferOptions{
 		RaisePutGetError: true,
 	}
+	var err error
+	var fs *bytes.Buffer
+
 	sfa := snowflakeFileTransferAgent{
 		ctx:          ctx,
 		sc:           sc,
@@ -103,7 +106,17 @@ func (sc *snowflakeConn) processFileTransfer(
 		options:      options,
 		streamBuffer: new(bytes.Buffer),
 	}
-	fs, err := getFileStream(ctx)
+	if op := getFileTransferOptions(ctx); op != nil {
+		sfa.options = op
+	}
+	if sfa.options.MultiPartThreshold == 0 {
+		sfa.options.MultiPartThreshold = dataSizeThreshold
+	}
+	if sfa.options.arrayBindFromStream {
+		fs, err = getFileStreamAll(ctx)
+	} else {
+		fs, err = getFileStream(ctx)
+	}
 	if err != nil {
 		return nil, err
 	}
@@ -113,13 +126,7 @@ func (sc *snowflakeConn) processFileTransfer(
 			sfa.data.AutoCompress = false
 		}
 	}
-	if op := getFileTransferOptions(ctx); op != nil {
-		sfa.options = op
-	}
-	if sfa.options.MultiPartThreshold == 0 {
-		sfa.options.MultiPartThreshold = dataSizeThreshold
-	}
-	if err := sfa.execute(); err != nil {
+	if err = sfa.execute(); err != nil {
 		return nil, err
 	}
 	data, err = sfa.result()
@@ -134,20 +141,37 @@
 	return data, nil
 }
 
-func getFileStream(ctx context.Context) (*bytes.Buffer, error) {
+func getReaderFromContext(ctx context.Context) io.Reader {
 	s := ctx.Value(fileStreamFile)
 	if s == nil {
-		return nil, nil
+		return nil
 	}
 	r, ok := s.(io.Reader)
 	if !ok {
-		return nil, errors.New("incorrect io.Reader")
+		return nil
 	}
+	return r
+}
+
+func getFileStreamAll(ctx context.Context) (*bytes.Buffer, error) {
+	r := getReaderFromContext(ctx)
+	buf := new(bytes.Buffer)
+	_, err := buf.ReadFrom(r)
+	return buf, err
+}
+
+func getFileStream(ctx context.Context) (*bytes.Buffer, error) {
+	r := getReaderFromContext(ctx)
+	if r == nil {
+		return nil, nil
+	}
+
 	// read a small amount of data to check if file stream will be used
 	buf := make([]byte, defaultStringBufferSize)
 	n, err := r.Read(buf)
 	if err != nil {
 		return nil, err
 	}
 	return bytes.NewBuffer(buf[:n]), nil
 }

 func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions {
 	v := ctx.Value(fileTransferOptions)
 	if v == nil {
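The split above is the heart of the memory optimization: getFileStream now only peeks at the head of the stream, while getFileStreamAll buffers everything for the array-bind path. A self-contained sketch of the peek-then-continue pattern, with an illustrative peekSize standing in for defaultStringBufferSize:

package main

import (
    "bytes"
    "fmt"
    "io"
    "strings"
)

const peekSize = 5 // stand-in for defaultStringBufferSize

func main() {
    src := strings.NewReader("hello, streaming world")

    // Peek at a small prefix; the rest of src stays unread.
    prefix := make([]byte, peekSize)
    n, _ := src.Read(prefix)
    fmt.Printf("peeked: %q\n", prefix[:n])

    // Later, stitch the consumed prefix back in front of the remainder.
    full := io.MultiReader(bytes.NewReader(prefix[:n]), src)
    rest, _ := io.ReadAll(full)
    fmt.Printf("full:   %q\n", rest)
}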
37 changes: 32 additions & 5 deletions file_transfer_agent.go
@@ -88,6 +88,7 @@ type SnowflakeFileTransferOptions struct {
 
 	/* streaming PUT */
 	compressSourceFromStream bool
+	arrayBindFromStream      bool
 
 	/* streaming GET */
 	GetFileToStream bool
@@ -463,7 +464,8 @@ func (sfa *snowflakeFileTransferAgent) processFileCompressionType() error {
 	if currentFileCompressionType == nil {
 		var mtype *mimetype.MIME
 		var err error
-		if meta.srcStream != nil {
+		_, err = os.Stat(fileName)
+		if os.IsNotExist(err) {
 			r := getReaderFromBuffer(&meta.srcStream)
 			mtype, err = mimetype.DetectReader(r)
 			if err != nil {
@@ -474,9 +476,9 @@
 			}
 		} else {
 			mtype, err = mimetype.DetectFile(fileName)
-			if err != nil {
-				return err
-			}
 		}
+		if err != nil {
+			return err
+		}
 		currentFileCompressionType = lookupByExtension(mtype.Extension())
 	}
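With this change, compression detection keys off whether the named file exists on disk rather than off meta.srcStream, since the stream may now hold only the peeked prefix. A minimal sketch of the same fallback using the github.com/gabriel-vasile/mimetype package the agent relies on; the function and file names here are illustrative:

package main

import (
    "bytes"
    "fmt"
    "os"

    "github.com/gabriel-vasile/mimetype"
)

// detectExtension mirrors the fallback above: detect from the buffered stream
// when fileName does not exist on disk, otherwise from the file itself.
func detectExtension(fileName string, stream *bytes.Buffer) (string, error) {
    var m *mimetype.MIME
    var err error
    if _, statErr := os.Stat(fileName); os.IsNotExist(statErr) {
        m, err = mimetype.DetectReader(bytes.NewReader(stream.Bytes()))
    } else {
        m, err = mimetype.DetectFile(fileName)
    }
    if err != nil {
        return "", err
    }
    return m.Extension(), nil
}

func main() {
    ext, _ := detectExtension("no-such-file.csv", bytes.NewBufferString("a,b\n1,2\n"))
    fmt.Println(ext) // e.g. ".txt" or ".csv", depending on detection
}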
@@ -869,7 +871,7 @@ func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileM
 	fileUtil := new(snowflakeFileUtil)
 	if meta.requireCompress {
 		if meta.srcStream != nil {
-			meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(&meta.srcStream)
+			meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(sfa.ctx, &meta.srcStream)
 		} else {
 			meta.realSrcFileName, _, err = fileUtil.compressFileWithGzip(meta.srcFileName, tmpDir)
 		}
@@ -880,8 +882,33 @@
 
 	if meta.srcStream != nil {
 		if meta.realSrcStream != nil {
+			// the file has been fully read in compressFileWithGzipFromStream
 			meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.realSrcStream)
 		} else {
+			r := getReaderFromContext(sfa.ctx)
+			if r == nil {
+				return nil, errors.New("failed to get the reader from context")
+			}
+
+			var fullSrcStream bytes.Buffer
+			if _, err = io.Copy(&fullSrcStream, meta.srcStream); err != nil {
+				return nil, err
+			}
+
+			// continue reading the rest of the data in chunks
+			chunk := make([]byte, fileChunkSize)
+			for {
+				n, err := r.Read(chunk)
+				if err == io.EOF {
+					break
+				} else if err != nil {
+					return nil, err
+				}
+				fullSrcStream.Write(chunk[:n])
+			}
+			if _, err = io.Copy(meta.srcStream, &fullSrcStream); err != nil {
+				return nil, err
+			}
 			meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.srcStream)
 		}
 	} else {
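The loop above drains the remainder of the context reader one fileChunkSize piece at a time, so only a single chunk is in flight per read. A standalone sketch of that drain with an illustrative chunk size; checking n before the error handles readers that return data together with io.EOF:

package main

import (
    "bytes"
    "fmt"
    "io"
    "strings"
)

const chunkSize = 8 // stand-in for fileChunkSize

// drain copies r into dst one fixed-size chunk at a time, so each read is
// bounded by chunkSize regardless of how large the stream is.
func drain(dst *bytes.Buffer, r io.Reader) error {
    chunk := make([]byte, chunkSize)
    for {
        n, err := r.Read(chunk)
        if n > 0 {
            dst.Write(chunk[:n])
        }
        if err == io.EOF {
            return nil
        }
        if err != nil {
            return err
        }
    }
}

func main() {
    var buf bytes.Buffer
    if err := drain(&buf, strings.NewReader("rest of the streamed file body")); err != nil {
        panic(err)
    }
    fmt.Println(buf.String())
}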
7 changes: 4 additions & 3 deletions file_transfer_agent_test.go
@@ -7,7 +7,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"github.com/aws/aws-sdk-go-v2/service/s3"
 	"io"
 	"net/url"
 	"os"
@@ -18,6 +17,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+
 	"github.com/aws/smithy-go"
 )

@@ -980,8 +981,8 @@ func testUploadDownloadOneFile(t *testing.T, isStream bool) {
 
 	if isStream {
 		fileStream, _ := os.Open(uploadFile)
-		ctx := WithFileStream(context.Background(), fileStream)
-		uploadMeta.srcStream, err = getFileStream(ctx)
+		sfa.ctx = WithFileStream(context.Background(), fileStream)
+		uploadMeta.srcStream, err = getFileStream(sfa.ctx)
 		assertNilF(t, err)
 	}
 
39 changes: 31 additions & 8 deletions file_util.go
@@ -5,8 +5,10 @@ package gosnowflake
 import (
 	"bytes"
 	"compress/gzip"
+	"context"
 	"crypto/sha256"
 	"encoding/base64"
+	"errors"
 	"io"
 	"net/url"
 	"os"
@@ -23,17 +25,38 @@
 	readWriteFileMode os.FileMode = 0666
 )
 
-func (util *snowflakeFileUtil) compressFileWithGzipFromStream(srcStream **bytes.Buffer) (*bytes.Buffer, int, error) {
-	r := getReaderFromBuffer(srcStream)
-	buf, err := io.ReadAll(r)
-	if err != nil {
-		return nil, -1, err
-	}
+func (util *snowflakeFileUtil) compressFileWithGzipFromStream(ctx context.Context, srcStream **bytes.Buffer) (*bytes.Buffer, int, error) {
 	var c bytes.Buffer
 	w := gzip.NewWriter(&c)
-	if _, err := w.Write(buf); err != nil { // write buf to gzip writer
+	r := getReaderFromContext(ctx)
+	if r == nil {
+		return nil, -1, errors.New("failed to get the reader from context")
+	}
+
+	// compress the first chunk of data which was read before
+	var streamBuf bytes.Buffer
+	if _, err := io.Copy(&streamBuf, *srcStream); err != nil {
+		return nil, -1, err
+	}
+	if _, err := w.Write(streamBuf.Bytes()); err != nil {
 		return nil, -1, err
 	}
+
+	// continue reading the rest of the data in chunks
+	chunk := make([]byte, fileChunkSize)
+	for {
+		n, err := r.Read(chunk)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, -1, err
+		}
+		// write chunk to gzip writer
+		if _, err = w.Write(chunk[:n]); err != nil {
+			return nil, -1, err
+		}
+	}
 	if err := w.Close(); err != nil {
 		return nil, -1, err
 	}
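This replaces the old io.ReadAll of the whole payload with incremental feeding of the gzip writer. A compact sketch of the same idea: gzip.Writer consumes input as it arrives, so driving it with io.Copy keeps memory bounded without materializing the payload; sizes here are illustrative.

package main

import (
    "bytes"
    "compress/gzip"
    "fmt"
    "io"
    "strings"
)

func main() {
    src := strings.NewReader(strings.Repeat("streaming PUT payload ", 1000))

    var compressed bytes.Buffer
    w := gzip.NewWriter(&compressed)
    if _, err := io.Copy(w, src); err != nil { // feeds gzip incrementally
        panic(err)
    }
    if err := w.Close(); err != nil { // Close flushes the gzip trailer
        panic(err)
    }
    fmt.Printf("compressed %d bytes down to %d\n", 22*1000, compressed.Len())
}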
@@ -76,10 +99,10 @@ func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir stri
 }
 
 func (util *snowflakeFileUtil) getDigestAndSizeForStream(stream **bytes.Buffer) (string, int64, error) {
-
 	m := sha256.New()
 	r := getReaderFromBuffer(stream)
 	chunk := make([]byte, fileChunkSize)
+
 	for {
 		n, err := r.Read(chunk)
 		if err == io.EOF {
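For reference, a minimal sketch of the chunked digest pattern this function implements, assuming the driver's convention of a base64-encoded SHA-256 digest; io.Copy performs the same bounded, chunk-at-a-time reads as the loop above.

package main

import (
    "crypto/sha256"
    "encoding/base64"
    "fmt"
    "io"
    "strings"
)

// digestAndSize hashes a stream without buffering it wholesale: io.Copy feeds
// the hash in bounded chunks and reports the total byte count as it goes.
func digestAndSize(r io.Reader) (string, int64, error) {
    h := sha256.New()
    size, err := io.Copy(h, r)
    if err != nil {
        return "", -1, err
    }
    return base64.StdEncoding.EncodeToString(h.Sum(nil)), size, nil
}

func main() {
    digest, size, _ := digestAndSize(strings.NewReader("streamed payload"))
    fmt.Println(digest, size)
}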