Skip to content

Commit

Permalink
Merge branch 'SNOW-870356-increase-coverage' of https://github.com/sn…
Browse files Browse the repository at this point in the history
…owflakedb/gosnowflake into SNOW-870356-increase-coverage
  • Loading branch information
sfc-gh-ext-simba-jl committed Oct 27, 2023
2 parents ae61895 + ff1d08c commit 211d147
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 42 deletions.
4 changes: 3 additions & 1 deletion aaa_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package gosnowflake

import "testing"
import (
"testing"
)

func TestShowServerVersion(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
Expand Down
4 changes: 4 additions & 0 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ func assertEqualF(t *testing.T, actual any, expected any, descriptions ...string
fatalOnNonEmpty(t, validateEqual(actual, expected, descriptions...))
}

func assertTrueF(t *testing.T, actual bool, descriptions ...string) {
fatalOnNonEmpty(t, validateEqual(actual, true, descriptions...))
}

func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) {
errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...))
}
Expand Down
5 changes: 5 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,5 +960,10 @@ Remember, to encode slashes.
Example:
u:p@a.r.c/db/s?account=a.r.c&tmpDirPath=%2Fother%2Ftmp
## Using custom configuration for PUT/GET
If you want to override some default configuration options, you can use `WithFileTransferOptions` context.
There are multiple config parameters including progress bars or compression.
*/
package gosnowflake
3 changes: 1 addition & 2 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ type SnowflakeFileTransferOptions struct {
compressSourceFromStream bool

/* PUT */
DisablePutOverwrite bool
putCallback *snowflakeProgressPercentage
putAzureCallback *snowflakeProgressPercentage
putCallbackOutputStream *io.Writer
Expand Down Expand Up @@ -268,7 +267,7 @@ func (sfa *snowflakeFileTransferAgent) parseCommand() error {
if sfa.data.Parallel != 0 {
sfa.parallel = sfa.data.Parallel
}
sfa.overwrite = !sfa.options.DisablePutOverwrite
sfa.overwrite = sfa.data.Overwrite
sfa.stageLocationType = cloudType(strings.ToUpper(sfa.data.StageInfo.LocationType))
sfa.stageInfo = &sfa.data.StageInfo
sfa.presignedURLs = make([]string, 0)
Expand Down
3 changes: 0 additions & 3 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,6 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) {
EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1},
},
},
options: &SnowflakeFileTransferOptions{
DisablePutOverwrite: false,
},
}

err = sfa.parseCommand()
Expand Down
2 changes: 1 addition & 1 deletion gcs_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (util *snowflakeGcsClient) createClient(info *execResponseStageInfo, _ bool
logger.Debug("Using GCS downscoped token")
return info.Creds.GcsAccessToken, nil
}
logger.Debug("No access token received from GS, using presigned url")
logger.Debugf("No access token received from GS, using presigned url: %s", info.PresignedURL)
return "", nil
}

Expand Down
78 changes: 50 additions & 28 deletions put_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,13 @@ func TestPutOverwrite(t *testing.T) {
f.Close()

runDBTest(t, func(dbt *DBTest) {
if _, err = dbt.exec("use role sysadmin"); err != nil {
t.Skip("snowflake admin account not accessible")
}
dbt.mustExec("rm @~/test_put_overwrite")

f, _ = os.Open(testData)
rows := dbt.mustQueryContext(
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
strings.ReplaceAll(testData, "\\", "\\\\")))
strings.ReplaceAll(testData, "\\", "/")))
defer rows.Close()
f.Close()
defer dbt.mustExec("rm @~/test_put_overwrite")
Expand All @@ -321,51 +318,77 @@ func TestPutOverwrite(t *testing.T) {
t.Fatalf("expected UPLOADED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
md5Column := s2

f, _ = os.Open(testData)
ctx := WithFileTransferOptions(context.Background(),
&SnowflakeFileTransferOptions{
DisablePutOverwrite: true,
})
rows = dbt.mustQueryContext(
WithFileStream(ctx, f),
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
strings.ReplaceAll(testData, "\\", "\\\\")))
strings.ReplaceAll(testData, "\\", "/")))
defer rows.Close()
f.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
if s6 != skipped.String() {
if runningOnGCP() && s6 != uploaded.String() {
// when this condition fails it means that presigned URLs are replaced with downscoped tokens
// when it happens, all clouds should not overwrite by default, so all clouds should pass the `s6 == skipped` test
t.Fatalf("expected UPLOADED as long as presigned URLs are used, got %v", s6)
} else if !runningOnGCP() && s6 != skipped.String() {
t.Fatalf("expected SKIPPED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
assertTrueF(t, rows.Next(), "expected new rows")

if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
if runningOnGCP() {
if s2 == md5Column {
// when this condition fails it means that presigned URLs are replaced with downscoped tokens
// when it happens, all clouds should not overwrite by default, so all clouds should pass the `s2 == md5Column` check
t.Fatal("For GCP and presigned URLs (currently on GitHub Actions) it should be overwritten by default")
}
} else if s2 != md5Column {
t.Fatal("The MD5 column should have stayed the same")
}

f, _ = os.Open(testData)
rows = dbt.mustQueryContext(
WithFileStream(context.Background(), f),
fmt.Sprintf("put 'file://%v' @~/test_put_overwrite overwrite=true",
strings.ReplaceAll(testData, "\\", "\\\\")))
strings.ReplaceAll(testData, "\\", "/")))
defer rows.Close()
f.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
if s6 != uploaded.String() {
t.Fatalf("expected UPLOADED, got %v", s6)
}

rows = dbt.mustQuery("ls @~/test_put_overwrite")
defer rows.Close()
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
t.Fatal(err)
}
if s0 != fmt.Sprintf("test_put_overwrite/%v.gz", baseName(testData)) {
t.Fatalf("expected test_put_overwrite/%v.gz, got %v", baseName(testData), s0)
}
if s2 == md5Column {
t.Fatalf("file should have been overwritten.")
}
})
}

Expand Down Expand Up @@ -417,10 +440,9 @@ func testPutGet(t *testing.T, isStream bool) {
defer rows.Close()

var s0, s1, s2, s3, s4, s5, s6, s7 string
if rows.Next() {
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
assertTrueF(t, rows.Next(), "expected new rows")
if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
t.Fatal(err)
}
if s6 != uploaded.String() {
t.Fatalf("expected %v, got: %v", uploaded, s6)
Expand Down
1 change: 1 addition & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type execResponseData struct {
Parallel int64 `json:"parallel,omitempty"`
Threshold int64 `json:"threshold,omitempty"`
AutoCompress bool `json:"autoCompress,omitempty"`
Overwrite bool `json:"overwrite,omitempty"`
SourceCompression string `json:"sourceCompression,omitempty"`
ShowEncryptionParameter bool `json:"clientShowEncryptionParameter,omitempty"`
EncryptionMaterial encryptionWrapper `json:"encryptionMaterial,omitempty"`
Expand Down
6 changes: 1 addition & 5 deletions s3_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ func (util *snowflakeS3Client) getFileHeader(meta *fileMetadata, filename string
if errors.As(err, &ae) {
if ae.ErrorCode() == notFound {
meta.resStatus = notFoundFile
return &fileHeader{
digest: "",
contentLength: 0,
encryptionMetadata: nil,
}, nil
return nil, fmt.Errorf("could not find file")
} else if ae.ErrorCode() == expiredToken {
meta.resStatus = renewToken
return nil, fmt.Errorf("received expired token. renewing")
Expand Down
2 changes: 1 addition & 1 deletion s3_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func TestGetHeaderNotFoundError(t *testing.T) {
}

_, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt")
if err != nil {
if err != nil && err.Error() != "could not find file" {
t.Error(err)
}

Expand Down
7 changes: 6 additions & 1 deletion storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ func (rsu *remoteStorageUtil) uploadOneFile(meta *fileMetadata) error {
for retry := 0; retry < maxRetry; retry++ {
if !meta.overwrite {
header, err := utilClass.getFileHeader(meta, meta.dstFileName)
if err != nil {
if meta.resStatus == notFoundFile {
err := utilClass.uploadFile(dataFile, meta, encryptMeta, maxConcurrency, meta.options.MultiPartThreshold)
if err != nil {
logger.Warnf("Error uploading %v. err: %v", dataFile, err)
}
} else if err != nil {
return err
}
if header != nil && meta.resStatus == uploaded {
Expand Down

0 comments on commit 211d147

Please sign in to comment.