support overwrite option in PUT
sfc-gh-ext-simba-jl authored and sfc-gh-pfus committed Oct 25, 2023
1 parent 99921d4 commit ae5fb54
Showing 9 changed files with 58 additions and 36 deletions.
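This change makes the driver take the PUT overwrite decision from the statement's own OVERWRITE option, as reported back by the server in the parsed-command response, instead of the client-side DisablePutOverwrite transfer option. Below is a minimal sketch of the resulting user-facing behavior; the DSN, file path, and stage name are placeholders, not values from this commit.

package main

import (
	"database/sql"
	"log"

	_ "github.com/snowflakedb/gosnowflake"
)

func main() {
	// Placeholder DSN; substitute your own account, credentials, database and schema.
	db, err := sql.Open("snowflake", "user:pass@account/db/schema")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// First upload: the file lands on the user stage.
	if _, err := db.Exec("PUT file:///tmp/data.txt @~/test_put_overwrite"); err != nil {
		log.Fatal(err)
	}
	// Re-uploading the same file without OVERWRITE is reported as SKIPPED;
	// with OVERWRITE=TRUE the staged copy is replaced (see TestPutOverwrite below).
	if _, err := db.Exec("PUT file:///tmp/data.txt @~/test_put_overwrite OVERWRITE=TRUE"); err != nil {
		log.Fatal(err)
	}
}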
4 changes: 4 additions & 0 deletions assert_test.go
@@ -25,6 +25,10 @@ func assertEqualF(t *testing.T, actual any, expected any, descriptions ...string
fatalOnNonEmpty(t, validateEqual(actual, expected, descriptions...))
}

func assertTrueF(t *testing.T, actual bool, descriptions ...string) {
fatalOnNonEmpty(t, validateEqual(actual, true, descriptions...))
}

func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) {
errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}
5 changes: 5 additions & 0 deletions doc.go
@@ -960,5 +960,10 @@ Remember, to encode slashes.
Example:
u:[email protected]/db/s?account=a.r.c&tmpDirPath=%2Fother%2Ftmp
## Using custom configuration for PUT/GET
If you want to override some default configuration options, you can use the `WithFileTransferOptions` context.
There are multiple config parameters, including progress bars and compression.
*/
package gosnowflake
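The new doc section above describes passing custom transfer options through the context. Here is a short, hedged sketch of that pattern, assuming a pre-opened *sql.DB and using the exported MultiPartThreshold field visible elsewhere in this diff; the threshold value and stage name are illustrative, and other option fields vary by driver version.

package example

import (
	"context"
	"database/sql"

	sf "github.com/snowflakedb/gosnowflake"
)

// putWithCustomOptions attaches custom file-transfer options to the context
// used for a PUT, as described in the doc.go addition above.
func putWithCustomOptions(db *sql.DB) error {
	ctx := sf.WithFileTransferOptions(context.Background(),
		&sf.SnowflakeFileTransferOptions{
			MultiPartThreshold: 64 * 1024 * 1024, // assumed: byte size above which multipart upload is used
		})
	_, err := db.ExecContext(ctx, "PUT file:///tmp/data.txt @~/mystage")
	return err
}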
3 changes: 1 addition & 2 deletions file_transfer_agent.go
@@ -92,7 +92,6 @@ type SnowflakeFileTransferOptions struct {
compressSourceFromStream bool

/* PUT */
DisablePutOverwrite bool
putCallback *snowflakeProgressPercentage
putAzureCallback *snowflakeProgressPercentage
putCallbackOutputStream *io.Writer
@@ -268,7 +267,7 @@ func (sfa *snowflakeFileTransferAgent) parseCommand() error {
if sfa.data.Parallel != 0 {
sfa.parallel = sfa.data.Parallel
}
sfa.overwrite = !sfa.options.DisablePutOverwrite
sfa.overwrite = sfa.data.Overwrite
sfa.stageLocationType = cloudType(strings.ToUpper(sfa.data.StageInfo.LocationType))
sfa.stageInfo = &sfa.data.StageInfo
sfa.presignedURLs = make([]string, 0)
3 changes: 0 additions & 3 deletions file_transfer_agent_test.go
@@ -237,9 +237,6 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) {
EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1},
},
},
options: &SnowflakeFileTransferOptions{
DisablePutOverwrite: false,
},
}

err = sfa.parseCommand()
63 changes: 39 additions & 24 deletions put_get_test.go
@@ -288,6 +288,9 @@ func TestPutWithAutoCompressFalse(t *testing.T) {
}

func TestPutOverwrite(t *testing.T) {
if runningOnGCP() {
t.Skip("Overwriting is default as long as presigned URLs are enabled")
}
tmpDir := t.TempDir()
testData := filepath.Join(tmpDir, "data.txt")
f, err := os.Create(testData)
@@ -298,9 +301,6 @@ func TestPutOverwrite(t *testing.T) {
f.Close()

runDBTest(t, func(dbt *DBTest) {
if _, err = dbt.exec("use role sysadmin"); err != nil {
t.Skip("snowflake admin account not accessible")
}
dbt.mustExec("rm @~/test_put_overwrite")

f, _ = os.Open(testData)
@@ -321,51 +321,67 @@
t.Fatalf("expected UPLOADED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
uploadTime := s3

f, _ = os.Open(testData)
ctx := WithFileTransferOptions(context.Background(),
&SnowflakeFileTransferOptions{
DisablePutOverwrite: true,
})
rows = dbt.mustQueryContext(
WithFileStream(ctx, f),
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
strings.ReplaceAll(testData, "\\", "\\\\")))
defer rows.Close()
f.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
if s6 != skipped.String() {
t.Fatalf("expected SKIPPED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
assertTrueF(t, rows.Next(), "expected new rows")

if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
if s3 != uploadTime {
t.Fatalf("upload time should have stayed the same, expected: %v, got: %v", uploadTime, s3)
}

f, _ = os.Open(testData)
rows = dbt.mustQueryContext(
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite overwrite=true",
strings.ReplaceAll(testData, "\\", "\\\\")))
defer rows.Close()
f.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
if s6 != uploaded.String() {
t.Fatalf("expected UPLOADED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
if s0 != fmt.Sprintf("test_put_overwrite/%v.gz", baseName(testData)) {
t.Fatalf("expected test_put_overwrite/%v.gz, got %v", baseName(testData), s0)
}
if s3 == uploadTime {
t.Fatalf("file should have been overwritten.")
}
})
}

@@ -417,10 +433,9 @@ func testPutGet(t *testing.T, isStream bool) {
defer rows.Close()

var s0, s1, s2, s3, s4, s5, s6, s7 string
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
if s6 != uploaded.String() {
t.Fatalf("expected %v, got: %v", uploaded, s6)
1 change: 1 addition & 0 deletions query.go
@@ -127,6 +127,7 @@ type execResponseData struct {
Parallel int64 `json:"parallel,omitempty"`
Threshold int64 `json:"threshold,omitempty"`
AutoCompress bool `json:"autoCompress,omitempty"`
Overwrite bool `json:"overwrite,omitempty"`
SourceCompression string `json:"sourceCompression,omitempty"`
ShowEncryptionParameter bool `json:"clientShowEncryptionParameter,omitempty"`
EncryptionMaterial encryptionWrapper `json:"encryptionMaterial,omitempty"`
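For orientation only: the new Overwrite field means the server's parsed-command response for a PUT can carry an overwrite flag, which parseCommand() copies onto the transfer agent (file_transfer_agent.go above). The struct and JSON below are an illustrative mirror, not the driver-internal execResponseData and not a captured server response.

package example

import (
	"encoding/json"
	"fmt"
)

// putCommandData mirrors only the response fields relevant to this change.
type putCommandData struct {
	AutoCompress      bool   `json:"autoCompress,omitempty"`
	Overwrite         bool   `json:"overwrite,omitempty"`
	SourceCompression string `json:"sourceCompression,omitempty"`
}

// demoOverwriteFlag shows how a made-up response fragment maps onto the new field.
func demoOverwriteFlag() {
	raw := []byte(`{"autoCompress":true,"overwrite":true,"sourceCompression":"auto_detect"}`)
	var data putCommandData
	if err := json.Unmarshal(raw, &data); err != nil {
		panic(err)
	}
	fmt.Println(data.Overwrite) // true: the transfer agent will overwrite the staged file
}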
6 changes: 1 addition & 5 deletions s3_storage_client.go
@@ -80,11 +80,7 @@ func (util *snowflakeS3Client) getFileHeader(meta *fileMetadata, filename string
if errors.As(err, &ae) {
if ae.ErrorCode() == notFound {
meta.resStatus = notFoundFile
return &fileHeader{
digest: "",
contentLength: 0,
encryptionMetadata: nil,
}, nil
return nil, errors.New("could not find file")
} else if ae.ErrorCode() == expiredToken {
meta.resStatus = renewToken
return nil, fmt.Errorf("received expired token. renewing")
2 changes: 1 addition & 1 deletion s3_storage_client_test.go
@@ -296,7 +296,7 @@ func TestGetHeaderNotFoundError(t *testing.T) {
}

_, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt")
if err != nil {
if err != nil && err.Error() != "could not find file" {
t.Error(err)
}

7 changes: 6 additions & 1 deletion storage_client.go
@@ -88,7 +88,12 @@ func (rsu *remoteStorageUtil) uploadOneFile(meta *fileMetadata) error {
for retry := 0; retry < maxRetry; retry++ {
if !meta.overwrite {
header, err := utilClass.getFileHeader(meta, meta.dstFileName)
if err != nil {
if meta.resStatus == notFoundFile {
err := utilClass.uploadFile(dataFile, meta, encryptMeta, maxConcurrency, meta.options.MultiPartThreshold)
if err != nil {
logger.Warnf("Error uploading %v. err: %v", dataFile, err)
}

Codecov / codecov/patch: added lines storage_client.go#L94-L95 were not covered by tests.
} else if err != nil {
return err
}
if header != nil && meta.resStatus == uploaded {
