diff --git a/assert_test.go b/assert_test.go index 58d4cc458..8eb34a55f 100644 --- a/assert_test.go +++ b/assert_test.go @@ -25,6 +25,10 @@ func assertEqualF(t *testing.T, actual any, expected any, descriptions ...string fatalOnNonEmpty(t, validateEqual(actual, expected, descriptions...)) } +func assertTrueF(t *testing.T, actual bool, descriptions ...string) { + fatalOnNonEmpty(t, validateEqual(actual, true, descriptions...)) +} + func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) { errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...)) } diff --git a/doc.go b/doc.go index e2342635c..df4a61042 100644 --- a/doc.go +++ b/doc.go @@ -960,5 +960,10 @@ Remember, to encode slashes. Example: u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&tmpDirPath=%2Fother%2Ftmp + +## Using custom configuration for PUT/GET + +If you want to override some default configuration options, you can use `WithFileTransferOptions` context. +There are multiple config parameters including progress bars or compression. */ package gosnowflake diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 9e3a64697..ae63f9a86 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -92,7 +92,6 @@ type SnowflakeFileTransferOptions struct { compressSourceFromStream bool /* PUT */ - DisablePutOverwrite bool putCallback *snowflakeProgressPercentage putAzureCallback *snowflakeProgressPercentage putCallbackOutputStream *io.Writer @@ -268,7 +267,7 @@ func (sfa *snowflakeFileTransferAgent) parseCommand() error { if sfa.data.Parallel != 0 { sfa.parallel = sfa.data.Parallel } - sfa.overwrite = !sfa.options.DisablePutOverwrite + sfa.overwrite = sfa.data.Overwrite sfa.stageLocationType = cloudType(strings.ToUpper(sfa.data.StageInfo.LocationType)) sfa.stageInfo = &sfa.data.StageInfo sfa.presignedURLs = make([]string, 0) diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index 163fe3ce1..f4ec224c6 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -237,9 +237,6 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) { EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1}, }, }, - options: &SnowflakeFileTransferOptions{ - DisablePutOverwrite: false, - }, } err = sfa.parseCommand() diff --git a/put_get_test.go b/put_get_test.go index fee0516ba..7eef02fba 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -288,6 +288,9 @@ func TestPutWithAutoCompressFalse(t *testing.T) { } func TestPutOverwrite(t *testing.T) { + if runningOnGCP() { + t.Skip("Overwriting is default as long as presigned URLs are enabled") + } tmpDir := t.TempDir() testData := filepath.Join(tmpDir, "data.txt") f, err := os.Create(testData) @@ -298,9 +301,6 @@ func TestPutOverwrite(t *testing.T) { f.Close() runDBTest(t, func(dbt *DBTest) { - if _, err = dbt.exec("use role sysadmin"); err != nil { - t.Skip("snowflake admin account not accessible") - } dbt.mustExec("rm @~/test_put_overwrite") f, _ = os.Open(testData) @@ -321,36 +321,50 @@ func TestPutOverwrite(t *testing.T) { t.Fatalf("expected UPLOADED, got %v", s6) } + rows = dbt.mustQuery("ls @~/test_put_overwrite") + defer rows.Close() + assertTrueF(t, rows.Next(), "expected new rows") + if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { + t.Fatal(err) + } + uploadTime := s3 + f, _ = os.Open(testData) - ctx := WithFileTransferOptions(context.Background(), - &SnowflakeFileTransferOptions{ - DisablePutOverwrite: true, - }) rows = dbt.mustQueryContext( - WithFileStream(ctx, f), + WithFileStream(context.Background(), f), fmt.Sprintf("put 'file://%v' @~/test_put_overwrite", strings.ReplaceAll(testData, "\\", "\\\\"))) defer rows.Close() f.Close() - if rows.Next() { - if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil { - t.Fatal(err) - } + assertTrueF(t, rows.Next(), "expected new rows") + if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil { + t.Fatal(err) } if s6 != skipped.String() { t.Fatalf("expected SKIPPED, got %v", s6) } + rows = dbt.mustQuery("ls @~/test_put_overwrite") + defer rows.Close() + assertTrueF(t, rows.Next(), "expected new rows") + + if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { + t.Fatal(err) + } + if s3 != uploadTime { + t.Fatalf("upload time should have stayed the same, expected: %v, got: %v", uploadTime, s3) + } + f, _ = os.Open(testData) rows = dbt.mustQueryContext( WithFileStream(context.Background(), f), fmt.Sprintf("put 'file://%v' @~/test_put_overwrite overwrite=true", strings.ReplaceAll(testData, "\\", "\\\\"))) + defer rows.Close() f.Close() - if rows.Next() { - if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil { - t.Fatal(err) - } + assertTrueF(t, rows.Next(), "expected new rows") + if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil { + t.Fatal(err) } if s6 != uploaded.String() { t.Fatalf("expected UPLOADED, got %v", s6) @@ -358,14 +372,16 @@ func TestPutOverwrite(t *testing.T) { rows = dbt.mustQuery("ls @~/test_put_overwrite") defer rows.Close() - if rows.Next() { - if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { - t.Fatal(err) - } + assertTrueF(t, rows.Next(), "expected new rows") + if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { + t.Fatal(err) } if s0 != fmt.Sprintf("test_put_overwrite/%v.gz", baseName(testData)) { t.Fatalf("expected test_put_overwrite/%v.gz, got %v", baseName(testData), s0) } + if s3 == uploadTime { + t.Fatalf("file should have been overwritten.") + } }) } @@ -417,10 +433,9 @@ func testPutGet(t *testing.T, isStream bool) { defer rows.Close() var s0, s1, s2, s3, s4, s5, s6, s7 string - if rows.Next() { - if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil { - t.Fatal(err) - } + assertTrueF(t, rows.Next(), "expected new rows") + if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil { + t.Fatal(err) } if s6 != uploaded.String() { t.Fatalf("expected %v, got: %v", uploaded, s6) diff --git a/query.go b/query.go index 300233d0e..162b45236 100644 --- a/query.go +++ b/query.go @@ -127,6 +127,7 @@ type execResponseData struct { Parallel int64 `json:"parallel,omitempty"` Threshold int64 `json:"threshold,omitempty"` AutoCompress bool `json:"autoCompress,omitempty"` + Overwrite bool `json:"overwrite,omitempty"` SourceCompression string `json:"sourceCompression,omitempty"` ShowEncryptionParameter bool `json:"clientShowEncryptionParameter,omitempty"` EncryptionMaterial encryptionWrapper `json:"encryptionMaterial,omitempty"` diff --git a/s3_storage_client.go b/s3_storage_client.go index ed3bca59a..27171e1cd 100644 --- a/s3_storage_client.go +++ b/s3_storage_client.go @@ -80,11 +80,7 @@ func (util *snowflakeS3Client) getFileHeader(meta *fileMetadata, filename string if errors.As(err, &ae) { if ae.ErrorCode() == notFound { meta.resStatus = notFoundFile - return &fileHeader{ - digest: "", - contentLength: 0, - encryptionMetadata: nil, - }, nil + return nil, errors.New("could not find file") } else if ae.ErrorCode() == expiredToken { meta.resStatus = renewToken return nil, fmt.Errorf("received expired token. renewing") diff --git a/s3_storage_client_test.go b/s3_storage_client_test.go index 8843dcdbf..1de79c24c 100644 --- a/s3_storage_client_test.go +++ b/s3_storage_client_test.go @@ -296,7 +296,7 @@ func TestGetHeaderNotFoundError(t *testing.T) { } _, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt") - if err != nil { + if err != nil && err.Error() != "could not find file" { t.Error(err) } diff --git a/storage_client.go b/storage_client.go index a7385c278..ee746a648 100644 --- a/storage_client.go +++ b/storage_client.go @@ -88,7 +88,12 @@ func (rsu *remoteStorageUtil) uploadOneFile(meta *fileMetadata) error { for retry := 0; retry < maxRetry; retry++ { if !meta.overwrite { header, err := utilClass.getFileHeader(meta, meta.dstFileName) - if err != nil { + if meta.resStatus == notFoundFile { + err := utilClass.uploadFile(dataFile, meta, encryptMeta, maxConcurrency, meta.options.MultiPartThreshold) + if err != nil { + logger.Warnf("Error uploading %v. err: %v", dataFile, err) + } + } else if err != nil { return err } if header != nil && meta.resStatus == uploaded {