fix test failures
sfc-gh-ext-simba-jl committed Dec 11, 2024
1 parent a841c0e commit 7f346bb
Showing 4 changed files with 79 additions and 31 deletions.
bind_uploader.go (3 changes: 2 additions & 1 deletion)
@@ -96,7 +96,8 @@ func (bu *bindUploader) uploadStreamInternal(
     // prepare context for PUT command
     ctx := WithFileStream(bu.ctx, inputStream)
     ctx = WithFileTransferOptions(ctx, &SnowflakeFileTransferOptions{
-        compressSourceFromStream: compressData})
+        compressSourceFromStream: compressData,
+        arrayBindFromStream:      true})
     return bu.sc.exec(ctx, putCommand, false, true, false, []driver.NamedValue{})
 }
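For orientation, the path exercised here: past an internal bind-size threshold, a bulk array-bind INSERT is serialized to CSV and uploaded with an internal streaming PUT like the one prepared above. A hedged usage sketch (standard database/sql; `dsn`, `intCol`, and `stringCol` are assumptions for illustration):

    // A large array-bind INSERT; the driver stages the bound data via PUT.
    db, err := sql.Open("snowflake", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()
    _, err = db.Exec(
        "INSERT INTO t VALUES (?, ?)",
        Array(&intCol),    // Array is gosnowflake's array-bind wrapper
        Array(&stringCol),
    )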

connection_util.go (28 changes: 21 additions & 7 deletions)
@@ -95,6 +95,9 @@ func (sc *snowflakeConn) processFileTransfer(
     options := &SnowflakeFileTransferOptions{
         RaisePutGetError: true,
     }
+    var err error
+    var fs *bytes.Buffer
+
     sfa := snowflakeFileTransferAgent{
         ctx: ctx,
         sc:  sc,
@@ -103,7 +106,17 @@
         options:      options,
         streamBuffer: new(bytes.Buffer),
     }
-    fs, err := getFileStream(ctx)
+    if op := getFileTransferOptions(ctx); op != nil {
+        sfa.options = op
+    }
+    if sfa.options.MultiPartThreshold == 0 {
+        sfa.options.MultiPartThreshold = dataSizeThreshold
+    }
+    if sfa.options.arrayBindFromStream {
+        fs = getFileStreamAll(ctx)
+    } else {
+        fs, err = getFileStream(ctx)
+    }
     if err != nil {
         return nil, err
     }
@@ -113,12 +126,6 @@
             sfa.data.AutoCompress = false
         }
     }
-    if op := getFileTransferOptions(ctx); op != nil {
-        sfa.options = op
-    }
-    if sfa.options.MultiPartThreshold == 0 {
-        sfa.options.MultiPartThreshold = dataSizeThreshold
-    }
     if err = sfa.execute(); err != nil {
         return nil, err
     }
@@ -143,6 +150,13 @@ func getReaderFromContext(ctx context.Context) io.Reader {
     return r
 }

+func getFileStreamAll(ctx context.Context) *bytes.Buffer {
+    r := getReaderFromContext(ctx)
+    buf := new(bytes.Buffer)
+    buf.ReadFrom(r)
    [Check failure on line 156 in connection_util.go, GitHub Actions / Check linter: Error return value of `buf.ReadFrom` is not checked (errcheck)]
+    return buf
+}
+
 func getFileStream(ctx context.Context) (*bytes.Buffer, error) {
     r := getReaderFromContext(ctx)
     if r == nil {
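The flagged `buf.ReadFrom` could propagate its error instead; one possible errcheck-clean shape (the error return and the nil guard are assumptions, not what this commit ships):

    func getFileStreamAll(ctx context.Context) (*bytes.Buffer, error) {
        r := getReaderFromContext(ctx)
        if r == nil {
            return nil, errors.New("no reader in context") // mirrors getFileStream's guard
        }
        buf := new(bytes.Buffer)
        if _, err := buf.ReadFrom(r); err != nil {
            return nil, err // surface read failures to the caller
        }
        return buf, nil
    }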
file_transfer_agent.go (42 changes: 39 additions & 3 deletions)
Expand Up @@ -88,6 +88,7 @@ type SnowflakeFileTransferOptions struct {

/* streaming PUT */
compressSourceFromStream bool
arrayBindFromStream bool

/* streaming GET */
GetFileToStream bool
@@ -463,7 +464,17 @@ func (sfa *snowflakeFileTransferAgent) processFileCompressionType() error {
     if currentFileCompressionType == nil {
         var mtype *mimetype.MIME
         var err error
-        mtype, err = mimetype.DetectFile(fileName)
+        _, err = os.Stat(fileName)
+        if os.IsNotExist(err) {
+            r := getReaderFromBuffer(&meta.srcStream)
+            mtype, err = mimetype.DetectReader(r)
+            if err != nil {
+                return err
+            }
+            io.ReadAll(r) // flush out tee buffer
    [Check failure on line 474 in file_transfer_agent.go, GitHub Actions / Check linter: Error return value of `io.ReadAll` is not checked (errcheck)]
+        } else {
+            mtype, err = mimetype.DetectFile(fileName)
+        }
         if err != nil {
             return err
         }
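This branch relies on `getReaderFromBuffer` returning a tee: the bytes `mimetype.DetectReader` consumes are mirrored back into the source buffer, and the trailing `io.ReadAll` drains the remainder through that tee. A self-contained sketch of the pattern (hypothetical helper, not the driver's code):

    // detectKeepingBytes sniffs a buffer's content type without consuming it.
    func detectKeepingBytes(buf **bytes.Buffer) (*mimetype.MIME, error) {
        var kept bytes.Buffer
        r := io.TeeReader(*buf, &kept) // every byte read is mirrored into kept
        mtype, err := mimetype.DetectReader(r)
        if err != nil {
            return nil, err
        }
        if _, err := io.Copy(&kept, *buf); err != nil { // drain what detection left behind
            return nil, err
        }
        *buf = &kept // hand back an intact copy of the original bytes
        return mtype, nil
    }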
@@ -847,7 +858,7 @@ func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileM
     fileUtil := new(snowflakeFileUtil)
     if meta.requireCompress {
         if meta.srcStream != nil {
-            meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(sfa.ctx)
+            meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(sfa.ctx, &meta.srcStream)
         } else {
             meta.realSrcFileName, _, err = fileUtil.compressFileWithGzip(meta.srcFileName, tmpDir)
         }
@@ -857,7 +868,32 @@
     }

     if meta.srcStream != nil {
-        meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(sfa.ctx, &meta.realSrcStream, &meta.srcStream)
+        if meta.realSrcStream != nil {
+            // the whole file has been read in compressFileWithGzipFromStream
+            meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.realSrcStream)
+        } else {
+            r := getReaderFromContext(sfa.ctx)
+            if r == nil {
+                return nil, errors.New("failed to get the reader from context")
+            }
+
+            var fullSrcStream bytes.Buffer
+            io.Copy(&fullSrcStream, meta.srcStream)
    [Check failure on line 881 in file_transfer_agent.go, GitHub Actions / Check linter: Error return value of `io.Copy` is not checked (errcheck)]
+
+            // continue reading the rest of the data in chunks
+            chunk := make([]byte, fileChunkSize)
+            for {
+                n, err := r.Read(chunk)
+                if err == io.EOF {
+                    break
+                } else if err != nil {
+                    return nil, err
+                }
+                fullSrcStream.Write(chunk[:n])
+            }
+            io.Copy(meta.srcStream, &fullSrcStream)
    [Check failure on line 894 in file_transfer_agent.go, GitHub Actions / Check linter: Error return value of `io.Copy` is not checked (errcheck)]
+            meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.srcStream)
+        }
     } else {
         meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForFile(meta.realSrcFileName)
     }
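The two unchecked `io.Copy` calls here are what the linter flags; note also that `io.Copy` already stops cleanly at EOF, so the manual chunk loop could collapse into a second copy. A hedged errcheck-clean sketch of the same draining step (same variables as above; the reshaping is an assumption):

    var fullSrcStream bytes.Buffer
    if _, err := io.Copy(&fullSrcStream, meta.srcStream); err != nil {
        return nil, err // first chunk, read earlier into the buffer
    }
    if _, err := io.Copy(&fullSrcStream, r); err != nil {
        return nil, err // remainder of the stream; io.Copy returns nil at EOF
    }
    *meta.srcStream = fullSrcStream // replace the drained buffer wholesale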
file_util.go (37 changes: 17 additions & 20 deletions)
@@ -8,6 +8,7 @@ import (
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"io"
"net/url"
"os"
@@ -24,26 +25,33 @@ const (
     readWriteFileMode os.FileMode = 0666
 )

-func (util *snowflakeFileUtil) compressFileWithGzipFromStream(ctx context.Context) (*bytes.Buffer, int, error) {
+func (util *snowflakeFileUtil) compressFileWithGzipFromStream(ctx context.Context, srcStream **bytes.Buffer) (*bytes.Buffer, int, error) {
     var c bytes.Buffer
     w := gzip.NewWriter(&c)
-    buf := make([]byte, fileChunkSize)
     r := getReaderFromContext(ctx)
     if r == nil {
-        return nil, -1, nil
+        return nil, -1, errors.New("failed to get the reader from context")
     }

-    // read the whole file in chunks
+    // compress the first chunk of data which was read before
+    var streamBuf bytes.Buffer
+    io.Copy(&streamBuf, *srcStream)
    [Check failure on line 38 in file_util.go, GitHub Actions / Check linter: Error return value of `io.Copy` is not checked (errcheck)]
+    if _, err := w.Write(streamBuf.Bytes()); err != nil {
+        return nil, -1, err
+    }
+
+    // continue reading the rest of the data in chunks
+    chunk := make([]byte, fileChunkSize)
     for {
-        n, err := r.Read(buf)
+        n, err := r.Read(chunk)
         if err == io.EOF {
             break
         }
         if err != nil {
             return nil, -1, err
         }
-        // write buf to gzip writer
-        if _, err = w.Write(buf[:n]); err != nil {
+        // write chunk to gzip writer
+        if _, err = w.Write(chunk[:n]); err != nil {
             return nil, -1, err
         }
     }
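The `io.Copy` into `streamBuf` above is the flagged call; since `gzip.Writer` is itself an `io.Writer`, the intermediate buffer could go away entirely. A checked sketch (an assumption, not this commit's code):

    // Stream the already-buffered prefix straight into the gzip writer.
    if _, err := io.Copy(w, *srcStream); err != nil {
        return nil, -1, err
    }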
@@ -88,21 +96,10 @@ func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir stri
     return gzipFileName, stat.Size(), err
 }

-func (util *snowflakeFileUtil) getDigestAndSizeForStream(ctx context.Context, realSrcStream **bytes.Buffer, srcStream **bytes.Buffer) (string, int64, error) {
-    var r io.Reader
-    var stream **bytes.Buffer
-    if realSrcStream != nil {
-        r = getReaderFromBuffer(srcStream)
-        stream = realSrcStream
-    } else {
-        r = getReaderFromContext(ctx)
-        stream = srcStream
-    }
-    if r == nil {
-        return "", 0, nil
-    }
+func (util *snowflakeFileUtil) getDigestAndSizeForStream(stream **bytes.Buffer) (string, int64, error) {

     m := sha256.New()
+    r := getReaderFromBuffer(stream)
     chunk := make([]byte, fileChunkSize)
     for {
         n, err := r.Read(chunk)
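The tail of this loop is collapsed in the view above; for reference, a self-contained sketch of the digest-and-size computation the helper performs (hashing every chunk and base64-encoding the sum is assumed to match the hidden remainder):

    // digestAndSize: SHA-256 (base64) and byte count of a buffer's contents,
    // hashing from a throwaway reader so the buffer itself is not drained.
    func digestAndSize(stream *bytes.Buffer) (string, int64, error) {
        m := sha256.New()
        n, err := io.Copy(m, bytes.NewReader(stream.Bytes()))
        if err != nil {
            return "", 0, err
        }
        return base64.StdEncoding.EncodeToString(m.Sum(nil)), n, nil
    }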
