From 7f346bb5cd0f45ea80ea07f2b30a9a2b0c294fd5 Mon Sep 17 00:00:00 2001 From: Joyce Ling <115662568+sfc-gh-ext-simba-jl@users.noreply.github.com> Date: Wed, 11 Dec 2024 01:54:20 -0800 Subject: [PATCH] fix test failures --- bind_uploader.go | 3 ++- connection_util.go | 28 +++++++++++++++++++++------- file_transfer_agent.go | 42 +++++++++++++++++++++++++++++++++++++++--- file_util.go | 37 +++++++++++++++++-------------------- 4 files changed, 79 insertions(+), 31 deletions(-) diff --git a/bind_uploader.go b/bind_uploader.go index 04a266a8e..b9728372d 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -96,7 +96,8 @@ func (bu *bindUploader) uploadStreamInternal( // prepare context for PUT command ctx := WithFileStream(bu.ctx, inputStream) ctx = WithFileTransferOptions(ctx, &SnowflakeFileTransferOptions{ - compressSourceFromStream: compressData}) + compressSourceFromStream: compressData, + arrayBindFromStream: true}) return bu.sc.exec(ctx, putCommand, false, true, false, []driver.NamedValue{}) } diff --git a/connection_util.go b/connection_util.go index 36022e20b..b2b884cdb 100644 --- a/connection_util.go +++ b/connection_util.go @@ -95,6 +95,9 @@ func (sc *snowflakeConn) processFileTransfer( options := &SnowflakeFileTransferOptions{ RaisePutGetError: true, } + var err error + var fs *bytes.Buffer + sfa := snowflakeFileTransferAgent{ ctx: ctx, sc: sc, @@ -103,7 +106,17 @@ func (sc *snowflakeConn) processFileTransfer( options: options, streamBuffer: new(bytes.Buffer), } - fs, err := getFileStream(ctx) + if op := getFileTransferOptions(ctx); op != nil { + sfa.options = op + } + if sfa.options.MultiPartThreshold == 0 { + sfa.options.MultiPartThreshold = dataSizeThreshold + } + if sfa.options.arrayBindFromStream { + fs = getFileStreamAll(ctx) + } else { + fs, err = getFileStream(ctx) + } if err != nil { return nil, err } @@ -113,12 +126,6 @@ func (sc *snowflakeConn) processFileTransfer( sfa.data.AutoCompress = false } } - if op := getFileTransferOptions(ctx); op != nil { - sfa.options = op - } - if sfa.options.MultiPartThreshold == 0 { - sfa.options.MultiPartThreshold = dataSizeThreshold - } if err = sfa.execute(); err != nil { return nil, err } @@ -143,6 +150,13 @@ func getReaderFromContext(ctx context.Context) io.Reader { return r } +func getFileStreamAll(ctx context.Context) *bytes.Buffer { + r := getReaderFromContext(ctx) + buf := new(bytes.Buffer) + buf.ReadFrom(r) + return buf +} + func getFileStream(ctx context.Context) (*bytes.Buffer, error) { r := getReaderFromContext(ctx) if r == nil { diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 825bd239d..4f5d522c2 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -88,6 +88,7 @@ type SnowflakeFileTransferOptions struct { /* streaming PUT */ compressSourceFromStream bool + arrayBindFromStream bool /* streaming GET */ GetFileToStream bool @@ -463,7 +464,17 @@ func (sfa *snowflakeFileTransferAgent) processFileCompressionType() error { if currentFileCompressionType == nil { var mtype *mimetype.MIME var err error - mtype, err = mimetype.DetectFile(fileName) + _, err = os.Stat(fileName) + if os.IsNotExist(err) { + r := getReaderFromBuffer(&meta.srcStream) + mtype, err = mimetype.DetectReader(r) + if err != nil { + return err + } + io.ReadAll(r) // flush out tee buffer + } else { + mtype, err = mimetype.DetectFile(fileName) + } if err != nil { return err } @@ -847,7 +858,7 @@ func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileM fileUtil := new(snowflakeFileUtil) if meta.requireCompress { if meta.srcStream != nil { - meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(sfa.ctx) + meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(sfa.ctx, &meta.srcStream) } else { meta.realSrcFileName, _, err = fileUtil.compressFileWithGzip(meta.srcFileName, tmpDir) } @@ -857,7 +868,32 @@ func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileM } if meta.srcStream != nil { - meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(sfa.ctx, &meta.realSrcStream, &meta.srcStream) + if meta.realSrcStream != nil { + // the whole file has been read in compressFileWithGzipFromStream + meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.realSrcStream) + } else { + r := getReaderFromContext(sfa.ctx) + if r == nil { + return nil, errors.New("failed to get the reader from context") + } + + var fullSrcStream bytes.Buffer + io.Copy(&fullSrcStream, meta.srcStream) + + // continue reading the rest of the data in chunks + chunk := make([]byte, fileChunkSize) + for { + n, err := r.Read(chunk) + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + fullSrcStream.Write(chunk[:n]) + } + io.Copy(meta.srcStream, &fullSrcStream) + meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.srcStream) + } } else { meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForFile(meta.realSrcFileName) } diff --git a/file_util.go b/file_util.go index 401293770..0215e145a 100644 --- a/file_util.go +++ b/file_util.go @@ -8,6 +8,7 @@ import ( "context" "crypto/sha256" "encoding/base64" + "errors" "io" "net/url" "os" @@ -24,26 +25,33 @@ const ( readWriteFileMode os.FileMode = 0666 ) -func (util *snowflakeFileUtil) compressFileWithGzipFromStream(ctx context.Context) (*bytes.Buffer, int, error) { +func (util *snowflakeFileUtil) compressFileWithGzipFromStream(ctx context.Context, srcStream **bytes.Buffer) (*bytes.Buffer, int, error) { var c bytes.Buffer w := gzip.NewWriter(&c) - buf := make([]byte, fileChunkSize) r := getReaderFromContext(ctx) if r == nil { - return nil, -1, nil + return nil, -1, errors.New("failed to get the reader from context") } - // read the whole file in chunks + // compress the first chunk of data which was read before + var streamBuf bytes.Buffer + io.Copy(&streamBuf, *srcStream) + if _, err := w.Write(streamBuf.Bytes()); err != nil { + return nil, -1, err + } + + // continue reading the rest of the data in chunks + chunk := make([]byte, fileChunkSize) for { - n, err := r.Read(buf) + n, err := r.Read(chunk) if err == io.EOF { break } if err != nil { return nil, -1, err } - // write buf to gzip writer - if _, err = w.Write(buf[:n]); err != nil { + // write chunk to gzip writer + if _, err = w.Write(chunk[:n]); err != nil { return nil, -1, err } } @@ -88,21 +96,10 @@ func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir stri return gzipFileName, stat.Size(), err } -func (util *snowflakeFileUtil) getDigestAndSizeForStream(ctx context.Context, realSrcStream **bytes.Buffer, srcStream **bytes.Buffer) (string, int64, error) { - var r io.Reader - var stream **bytes.Buffer - if realSrcStream != nil { - r = getReaderFromBuffer(srcStream) - stream = realSrcStream - } else { - r = getReaderFromContext(ctx) - stream = srcStream - } - if r == nil { - return "", 0, nil - } +func (util *snowflakeFileUtil) getDigestAndSizeForStream(stream **bytes.Buffer) (string, int64, error) { m := sha256.New() + r := getReaderFromBuffer(stream) chunk := make([]byte, fileChunkSize) for { n, err := r.Read(chunk)