Skip to content

Commit

Permalink
support overwrite option in PUT
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-jl committed Oct 25, 2023
1 parent 8ed5c10 commit f5a0ec6
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 16 deletions.
3 changes: 1 addition & 2 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ type SnowflakeFileTransferOptions struct {
compressSourceFromStream bool

/* PUT */
DisablePutOverwrite bool
putCallback *snowflakeProgressPercentage
putAzureCallback *snowflakeProgressPercentage
putCallbackOutputStream *io.Writer
Expand Down Expand Up @@ -268,7 +267,7 @@ func (sfa *snowflakeFileTransferAgent) parseCommand() error {
if sfa.data.Parallel != 0 {
sfa.parallel = sfa.data.Parallel
}
sfa.overwrite = !sfa.options.DisablePutOverwrite
sfa.overwrite = sfa.data.Overwrite
sfa.stageLocationType = cloudType(strings.ToUpper(sfa.data.StageInfo.LocationType))
sfa.stageInfo = &sfa.data.StageInfo
sfa.presignedURLs = make([]string, 0)
Expand Down
3 changes: 0 additions & 3 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,6 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) {
EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1},
},
},
options: &SnowflakeFileTransferOptions{
DisablePutOverwrite: false,
},
}

err = sfa.parseCommand()
Expand Down
30 changes: 25 additions & 5 deletions put_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,18 @@ func TestPutOverwrite(t *testing.T) {
t.Fatalf("expected UPLOADED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
}
uploadTime := s3

f, _ = os.Open(testData)
ctx := WithFileTransferOptions(context.Background(),
&SnowflakeFileTransferOptions{
DisablePutOverwrite: true,
})
rows = dbt.mustQueryContext(
WithFileStream(ctx, f),
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
strings.ReplaceAll(testData, "\\", "\\\\")))
defer rows.Close()
Expand All @@ -341,11 +346,23 @@ func TestPutOverwrite(t *testing.T) {
t.Fatalf("expected SKIPPED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
}
if s3 != uploadTime {
t.Fatalf("upload time should have stayed the same, expected: %v, got: %v", uploadTime, s3)
}

f, _ = os.Open(testData)
rows = dbt.mustQueryContext(
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite overwrite=true",
strings.ReplaceAll(testData, "\\", "\\\\")))
defer rows.Close()
f.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
Expand All @@ -366,6 +383,9 @@ func TestPutOverwrite(t *testing.T) {
if s0 != fmt.Sprintf("test_put_overwrite/%v.gz", baseName(testData)) {
t.Fatalf("expected test_put_overwrite/%v.gz, got %v", baseName(testData), s0)
}
if s3 == uploadTime {
t.Fatalf("file should have been overwritten.")
}
})
}

Expand Down
1 change: 1 addition & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type execResponseData struct {
Parallel int64 `json:"parallel,omitempty"`
Threshold int64 `json:"threshold,omitempty"`
AutoCompress bool `json:"autoCompress,omitempty"`
Overwrite bool `json:"overwrite,omitempty"`
SourceCompression string `json:"sourceCompression,omitempty"`
ShowEncryptionParameter bool `json:"clientShowEncryptionParameter,omitempty"`
EncryptionMaterial encryptionWrapper `json:"encryptionMaterial,omitempty"`
Expand Down
6 changes: 1 addition & 5 deletions s3_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ func (util *snowflakeS3Client) getFileHeader(meta *fileMetadata, filename string
if errors.As(err, &ae) {
if ae.ErrorCode() == notFound {
meta.resStatus = notFoundFile
return &fileHeader{
digest: "",
contentLength: 0,
encryptionMetadata: nil,
}, nil
return nil, fmt.Errorf("could not find file")
} else if ae.ErrorCode() == expiredToken {
meta.resStatus = renewToken
return nil, fmt.Errorf("received expired token. renewing")
Expand Down
7 changes: 6 additions & 1 deletion storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ func (rsu *remoteStorageUtil) uploadOneFile(meta *fileMetadata) error {
for retry := 0; retry < maxRetry; retry++ {
if !meta.overwrite {
header, err := utilClass.getFileHeader(meta, meta.dstFileName)
if err != nil {
if meta.resStatus == notFoundFile {
err := utilClass.uploadFile(dataFile, meta, encryptMeta, maxConcurrency, meta.options.MultiPartThreshold)
if err != nil {
logger.Debugf("Error uploading %v. err: %v", dataFile, err)
}
} else if err != nil {
return err
}
if header != nil && meta.resStatus == uploaded {
Expand Down

0 comments on commit f5a0ec6

Please sign in to comment.