Skip to content

Commit

Permalink
support overwrite option in PUT
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-jl authored and sfc-gh-pfus committed Oct 25, 2023
1 parent 99921d4 commit afb952c
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 14 deletions.
4 changes: 4 additions & 0 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func assertEqualF(t *testing.T, actual any, expected any, descriptions ...string
fatalOnNonEmpty(t, validateEqual(actual, expected, descriptions...))
}

func assertTrueF(t *testing.T, actual bool, descriptions ...string) {
fatalOnNonEmpty(t, validateEqual(actual, true, descriptions...))
}

func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) {
errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}
Expand Down
5 changes: 5 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,5 +960,10 @@ Remember, to encode slashes.
Example:
u:[email protected]/db/s?account=a.r.c&tmpDirPath=%2Fother%2Ftmp
## Using custom configuration for PUT/GET
If you want to override some default configuration options, you can use `WithFileTransferOptions` context.
There are multiple config parameters including progress bars or compression.
*/
package gosnowflake
3 changes: 1 addition & 2 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ type SnowflakeFileTransferOptions struct {
compressSourceFromStream bool

/* PUT */
DisablePutOverwrite bool
putCallback *snowflakeProgressPercentage
putAzureCallback *snowflakeProgressPercentage
putCallbackOutputStream *io.Writer
Expand Down Expand Up @@ -268,7 +267,7 @@ func (sfa *snowflakeFileTransferAgent) parseCommand() error {
if sfa.data.Parallel != 0 {
sfa.parallel = sfa.data.Parallel
}
sfa.overwrite = !sfa.options.DisablePutOverwrite
sfa.overwrite = sfa.data.Overwrite
sfa.stageLocationType = cloudType(strings.ToUpper(sfa.data.StageInfo.LocationType))
sfa.stageInfo = &sfa.data.StageInfo
sfa.presignedURLs = make([]string, 0)
Expand Down
3 changes: 0 additions & 3 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,6 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) {
EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1},
},
},
options: &SnowflakeFileTransferOptions{
DisablePutOverwrite: false,
},
}

err = sfa.parseCommand()
Expand Down
36 changes: 27 additions & 9 deletions put_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,17 @@ func TestPutOverwrite(t *testing.T) {
t.Fatalf("expected UPLOADED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
uploadTime := s3

f, _ = os.Open(testData)
ctx := WithFileTransferOptions(context.Background(),
&SnowflakeFileTransferOptions{
DisablePutOverwrite: true,
})
rows = dbt.mustQueryContext(
WithFileStream(ctx, f),
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
strings.ReplaceAll(testData, "\\", "\\\\")))
defer rows.Close()
Expand All @@ -341,11 +345,23 @@ func TestPutOverwrite(t *testing.T) {
t.Fatalf("expected SKIPPED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
}
if s3 != uploadTime {
t.Fatalf("upload time should have stayed the same, expected: %v, got: %v", uploadTime, s3)
}

f, _ = os.Open(testData)
rows = dbt.mustQueryContext(
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite overwrite=true",
strings.ReplaceAll(testData, "\\", "\\\\")))
defer rows.Close()
f.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
Expand All @@ -358,14 +374,16 @@ func TestPutOverwrite(t *testing.T) {

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
if s0 != fmt.Sprintf("test_put_overwrite/%v.gz", baseName(testData)) {
t.Fatalf("expected test_put_overwrite/%v.gz, got %v", baseName(testData), s0)
}
if s3 == uploadTime {
t.Fatalf("file should have been overwritten.")
}
})
}

Expand Down
1 change: 1 addition & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type execResponseData struct {
Parallel int64 `json:"parallel,omitempty"`
Threshold int64 `json:"threshold,omitempty"`
AutoCompress bool `json:"autoCompress,omitempty"`
Overwrite bool `json:"overwrite,omitempty"`
SourceCompression string `json:"sourceCompression,omitempty"`
ShowEncryptionParameter bool `json:"clientShowEncryptionParameter,omitempty"`
EncryptionMaterial encryptionWrapper `json:"encryptionMaterial,omitempty"`
Expand Down
5 changes: 5 additions & 0 deletions storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ func (rsu *remoteStorageUtil) uploadOneFile(meta *fileMetadata) error {
header, err := utilClass.getFileHeader(meta, meta.dstFileName)
if err != nil {
return err
} else if meta.resStatus == notFoundFile {
err := utilClass.uploadFile(dataFile, meta, encryptMeta, maxConcurrency, meta.options.MultiPartThreshold)
if err != nil {
logger.Warnf("Error uploading %v. err: %v", dataFile, err)
}
}
if header != nil && meta.resStatus == uploaded {
meta.dstFileSize = 0
Expand Down

0 comments on commit afb952c

Please sign in to comment.