From 891f48aba4e057f75227e39bb65d313a307893c1 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Thu, 12 Oct 2023 09:26:46 +0200 Subject: [PATCH 01/12] SNOW-859636 Add codecov.yml with partials_as_hits parameter (#925) --- codecov.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 codecov.yml diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000..d23803439 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,3 @@ +parsers: + go: + partials_as_hits: true \ No newline at end of file From 18d84a8a76a4a36d7276f1515cfe02eba3a113ab Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Thu, 12 Oct 2023 10:57:26 +0200 Subject: [PATCH 02/12] SNOW-859636 Exclude codecov[bot] from Jira Issue comments (#927) --- .github/workflows/jira_comment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jira_comment.yml b/.github/workflows/jira_comment.yml index 954929fa6..33dbc89dd 100644 --- a/.github/workflows/jira_comment.yml +++ b/.github/workflows/jira_comment.yml @@ -23,7 +23,7 @@ jobs: echo ::set-output name=jira::$jira - name: Comment on issue uses: atlassian/gajira-comment@master - if: startsWith(steps.extract.outputs.jira, 'SNOW-') + if: startsWith(steps.extract.outputs.jira, 'SNOW-') && ${{ github.event.comment.user.login }} != 'codecov[bot]' with: issue: "${{ steps.extract.outputs.jira }}" comment: "${{ github.event.comment.user.login }} commented:\n\n${{ github.event.comment.body }}\n\n${{ github.event.comment.html_url }}" From 46bf2977c1c3454311cb329c12320b8d4aea2234 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Thu, 12 Oct 2023 13:23:13 +0200 Subject: [PATCH 03/12] Snow 859636 codecov jira comment (#929) --- .github/workflows/jira_comment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jira_comment.yml b/.github/workflows/jira_comment.yml index 33dbc89dd..11769c0a0 100644 --- a/.github/workflows/jira_comment.yml +++ b/.github/workflows/jira_comment.yml @@ -23,7 +23,7 @@ jobs: echo ::set-output name=jira::$jira - name: Comment on issue uses: atlassian/gajira-comment@master - if: startsWith(steps.extract.outputs.jira, 'SNOW-') && ${{ github.event.comment.user.login }} != 'codecov[bot]' + if: startsWith(steps.extract.outputs.jira, 'SNOW-') && github.event.comment.user.login != 'codecov[bot]' with: issue: "${{ steps.extract.outputs.jira }}" comment: "${{ github.event.comment.user.login }} commented:\n\n${{ github.event.comment.body }}\n\n${{ github.event.comment.html_url }}" From dfb1c18624ebd4023e27839cd32a51955758bd75 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Fri, 13 Oct 2023 13:35:49 +0200 Subject: [PATCH 04/12] SNOW-856228 easy logging parser (#924) SNOW-856228 easy logging parser --- client_configuration.go | 87 +++++++++++++++++ client_configuration_test.go | 175 +++++++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 client_configuration.go create mode 100644 client_configuration_test.go diff --git a/client_configuration.go b/client_configuration.go new file mode 100644 index 000000000..381ed8d86 --- /dev/null +++ b/client_configuration.go @@ -0,0 +1,87 @@ +// Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "strings" +) + +// log levels for easy logging +const ( + levelOff string = "OFF" // log level for logging switched off + levelError string = "ERROR" // error log level + levelWarn string = "WARN" // warn log level + levelInfo string = "INFO" // info log level + levelDebug string = "DEBUG" // debug log level + levelTrace string = "TRACE" // trace log level +) + +// ClientConfig config root +type ClientConfig struct { + Common *ClientConfigCommonProps `json:"common"` +} + +// ClientConfigCommonProps properties from "common" section +type ClientConfigCommonProps struct { + LogLevel string `json:"log_level,omitempty"` + LogPath string `json:"log_path,omitempty"` +} + +func parseClientConfiguration(filePath string) (*ClientConfig, error) { + if filePath == "" { + return nil, nil + } + fileContents, err := os.ReadFile(filePath) + if err != nil { + return nil, parsingClientConfigError(err) + } + var clientConfig ClientConfig + err = json.Unmarshal(fileContents, &clientConfig) + if err != nil { + return nil, parsingClientConfigError(err) + } + err = validateClientConfiguration(&clientConfig) + if err != nil { + return nil, parsingClientConfigError(err) + } + return &clientConfig, nil +} + +func parsingClientConfigError(err error) error { + return fmt.Errorf("parsing client config failed: %w", err) +} + +func validateClientConfiguration(clientConfig *ClientConfig) error { + if clientConfig == nil { + return errors.New("client config not found") + } + if clientConfig.Common == nil { + return errors.New("common section in client config not found") + } + return validateLogLevel(*clientConfig) +} + +func validateLogLevel(clientConfig ClientConfig) error { + var logLevel = clientConfig.Common.LogLevel + if logLevel != "" { + _, error := toLogLevel(logLevel) + if error != nil { + return error + } + } + return nil +} + +func toLogLevel(logLevelString string) (string, error) { + var logLevel = strings.ToUpper(logLevelString) + switch logLevel { + case levelOff, levelError, levelWarn, levelInfo, levelDebug, levelTrace: + return logLevel, nil + default: + return "", errors.New("unknown log level: " + logLevelString) + } +} diff --git a/client_configuration_test.go b/client_configuration_test.go new file mode 100644 index 000000000..b7ecdd2f0 --- /dev/null +++ b/client_configuration_test.go @@ -0,0 +1,175 @@ +// Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "fmt" + "os" + "path" + "strings" + "testing" +) + +func TestParseConfiguration(t *testing.T) { + dir := t.TempDir() + testCases := []struct { + testName string + fileName string + fileContents string + expectedLogLevel string + expectedLogPath string + }{ + { + testName: "TestWithLogLevelUpperCase", + fileName: "config_1.json", + fileContents: `{ + "common": { + "log_level" : "INFO", + "log_path" : "/some-path/some-directory" + } + }`, + expectedLogLevel: "INFO", + expectedLogPath: "/some-path/some-directory", + }, + { + testName: "TestWithLogLevelLowerCase", + fileName: "config_2.json", + fileContents: `{ + "common": { + "log_level" : "info", + "log_path" : "/some-path/some-directory" + } + }`, + expectedLogLevel: "info", + expectedLogPath: "/some-path/some-directory", + }, + { + testName: "TestWithMissingValues", + fileName: "config_3.json", + fileContents: `{ + "common": {} + }`, + expectedLogLevel: "", + expectedLogPath: "", + }, + } + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + fileName := createFile(t, tc.fileName, tc.fileContents, dir) + + config, err := parseClientConfiguration(fileName) + + if err != nil { + t.Fatalf("Error should be nil but was %s", err) + } + if config.Common.LogLevel != tc.expectedLogLevel { + t.Errorf("Log level should be %s but was %s", tc.expectedLogLevel, config.Common.LogLevel) + } + if config.Common.LogPath != tc.expectedLogPath { + t.Errorf("Log path should be %s but was %s", tc.expectedLogPath, config.Common.LogPath) + } + }) + } +} + +func TestParseAllLogLevels(t *testing.T) { + dir := t.TempDir() + for _, logLevel := range []string{"OFF", "ERROR", "WARN", "INFO", "DEBUG", "TRACE"} { + t.Run(logLevel, func(t *testing.T) { + fileContents := fmt.Sprintf(`{ + "common": { + "log_level" : "%s", + "log_path" : "/some-path/some-directory" + } + }`, logLevel) + fileName := createFile(t, fmt.Sprintf("config_%s.json", logLevel), fileContents, dir) + + config, err := parseClientConfiguration(fileName) + + if err != nil { + t.Fatalf("Error should be nil but was: %s", err) + } + if config.Common.LogLevel != logLevel { + t.Errorf("Log level should be %s but was %s", logLevel, config.Common.LogLevel) + } + }) + } +} + +func TestParseConfigurationFails(t *testing.T) { + dir := t.TempDir() + testCases := []struct { + testName string + fileName string + FileContents string + expectedErrorMessageToContain string + }{ + { + testName: "TestWithWrongLogLevel", + fileName: "config_1.json", + FileContents: `{ + "common": { + "log_level" : "something weird", + "log_path" : "/some-path/some-directory" + } + }`, + expectedErrorMessageToContain: "unknown log level", + }, + { + testName: "TestWithWrongTypeOfLogLevel", + fileName: "config_2.json", + FileContents: `{ + "common": { + "log_level" : 15, + "log_path" : "/some-path/some-directory" + } + }`, + expectedErrorMessageToContain: "ClientConfigCommonProps.common.log_level", + }, + { + testName: "TestWithWrongTypeOfLogPath", + fileName: "config_3.json", + FileContents: `{ + "common": { + "log_level" : "INFO", + "log_path" : true + } + }`, + expectedErrorMessageToContain: "ClientConfigCommonProps.common.log_path", + }, + { + testName: "TestWithoutCommon", + fileName: "config_4.json", + FileContents: "{}", + expectedErrorMessageToContain: "common section in client config not found", + }, + } + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + fileName := createFile(t, tc.fileName, tc.FileContents, dir) + + _, err := parseClientConfiguration(fileName) + + if err == nil { + t.Fatal("Error should not be nil but was nil") + } + errMessage := fmt.Sprint(err) + expectedPrefix := "parsing client config failed" + if !strings.HasPrefix(errMessage, expectedPrefix) { + t.Errorf("Error message: \"%s\" should start with prefix: \"%s\"", errMessage, expectedPrefix) + } + if !strings.Contains(errMessage, tc.expectedErrorMessageToContain) { + t.Errorf("Error message: \"%s\" should contain given phrase: \"%s\"", errMessage, tc.expectedErrorMessageToContain) + } + }) + } +} + +func createFile(t *testing.T, fileName string, fileContents string, directory string) string { + fullFileName := path.Join(directory, fileName) + err := os.WriteFile(fullFileName, []byte(fileContents), 0644) + if err != nil { + t.Fatal("Could not create file") + } + return fullFileName +} From fcdd18a7ffbe2a8e94eb28331bc3bb43a0d21d94 Mon Sep 17 00:00:00 2001 From: Eng Zer Jun Date: Mon, 16 Oct 2023 15:32:23 +0800 Subject: [PATCH 05/12] test: use `T.TempDir` to create temporary test directory (#647) --- encrypt_util.go | 3 +++ encrypt_util_test.go | 29 ++++++----------------------- put_get_test.go | 37 ++++++++++--------------------------- put_get_user_stage_test.go | 13 ++++--------- put_get_with_aws_test.go | 20 ++++---------------- 5 files changed, 27 insertions(+), 75 deletions(-) diff --git a/encrypt_util.go b/encrypt_util.go index 435ecbafc..08179891d 100644 --- a/encrypt_util.go +++ b/encrypt_util.go @@ -176,10 +176,13 @@ func encryptFile( if err != nil { return nil, "", err } + defer tmpOutputFile.Close() infile, err := os.OpenFile(filename, os.O_CREATE|os.O_RDONLY, readWriteFileMode) if err != nil { return nil, "", err } + defer infile.Close() + meta, err := encryptStream(sfe, infile, tmpOutputFile, chunkSize) if err != nil { return nil, "", err diff --git a/encrypt_util_test.go b/encrypt_util_test.go index bfcb6e5b2..03f075d2b 100644 --- a/encrypt_util_test.go +++ b/encrypt_util_test.go @@ -91,11 +91,7 @@ func TestEncryptDecryptFilePadding(t *testing.T) { for _, test := range testcases { t.Run(fmt.Sprintf("%v_%v", test.numberOfBytesInEachRow, test.numberOfLines), func(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "data") - if err != nil { - t.Error(err) - } - tmpDir, err = generateKLinesOfNByteRows(test.numberOfLines, test.numberOfBytesInEachRow, tmpDir) + tmpDir, err := generateKLinesOfNByteRows(test.numberOfLines, test.numberOfBytesInEachRow, t.TempDir()) if err != nil { t.Error(err) } @@ -114,11 +110,7 @@ func TestEncryptDecryptLargeFile(t *testing.T) { numberOfFiles := 1 numberOfLines := 10000 - tmpDir, err := os.MkdirTemp("", "data") - if err != nil { - t.Error(err) - } - tmpDir, err = generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, tmpDir) + tmpDir, err := generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, t.TempDir()) if err != nil { t.Error(err) } @@ -127,7 +119,6 @@ func TestEncryptDecryptLargeFile(t *testing.T) { } func encryptDecryptFile(t *testing.T, encMat snowflakeFileEncryption, expected int, tmpDir string) { - defer os.RemoveAll(tmpDir) files, err := filepath.Glob(filepath.Join(tmpDir, "file*")) if err != nil { t.Error(err) @@ -150,6 +141,8 @@ func encryptDecryptFile(t *testing.T, encMat snowflakeFileEncryption, expected i if err != nil { t.Error(err) } + defer fd.Close() + scanner := bufio.NewScanner(fd) for scanner.Scan() { cnt++ @@ -163,12 +156,6 @@ func encryptDecryptFile(t *testing.T, encMat snowflakeFileEncryption, expected i } func generateKLinesOfNByteRows(numLines int, numBytes int, tmpDir string) (string, error) { - if tmpDir == "" { - _, err := os.MkdirTemp(tmpDir, "data") - if err != nil { - return "", err - } - } fname := path.Join(tmpDir, "file"+strconv.FormatInt(int64(numLines*numBytes), 10)) f, err := os.Create(fname) if err != nil { @@ -185,12 +172,6 @@ func generateKLinesOfNByteRows(numLines int, numBytes int, tmpDir string) (strin } func generateKLinesOfNFiles(k int, n int, compress bool, tmpDir string) (string, error) { - if tmpDir == "" { - _, err := os.MkdirTemp(tmpDir, "data") - if err != nil { - return "", err - } - } for i := 0; i < n; i++ { fname := path.Join(tmpDir, "file"+strconv.FormatInt(int64(i), 10)) f, err := os.Create(fname) @@ -248,6 +229,8 @@ func generateKLinesOfNFiles(k int, n int, compress bool, tmpDir string) (string, return "", err } w.Close() + fOut.Close() + fIn.Close() } } } diff --git a/put_get_test.go b/put_get_test.go index 730efbcbb..fee0516ba 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -34,19 +34,17 @@ func TestPutError(t *testing.T) { if isWindows { t.Skip("permission model is different") } - tmpDir, err := os.MkdirTemp("", "putfiledir") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() file1 := filepath.Join(tmpDir, "file1") remoteLocation := filepath.Join(tmpDir, "remote_loc") f, err := os.Create(file1) if err != nil { t.Error(err) } + defer f.Close() f.WriteString("test1") os.Chmod(file1, 0000) + defer os.Chmod(file1, 0644) data := &execResponseData{ Command: string(uploadCommand), @@ -253,11 +251,7 @@ func TestPutWithAutoCompressFalse(t *testing.T) { if runningOnGithubAction() && !runningOnAWS() { t.Skip("skipping non aws environment") } - tmpDir, err := os.MkdirTemp("", "put") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() testData := filepath.Join(tmpDir, "data.txt") f, err := os.Create(testData) if err != nil { @@ -294,11 +288,7 @@ func TestPutWithAutoCompressFalse(t *testing.T) { } func TestPutOverwrite(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "data") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() testData := filepath.Join(tmpDir, "data.txt") f, err := os.Create(testData) if err != nil { @@ -388,11 +378,7 @@ func TestPutGetStream(t *testing.T) { } func testPutGet(t *testing.T, isStream bool) { - tmpDir, err := os.MkdirTemp("", "put_get") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() fname := filepath.Join(tmpDir, "test_put_get.txt.gz") originalContents := "123,test1\n456,test2\n" tableName := randomString(5) @@ -401,23 +387,19 @@ func testPutGet(t *testing.T, isStream bool) { gzw := gzip.NewWriter(&b) gzw.Write([]byte(originalContents)) gzw.Close() - if err = os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { + if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { t.Fatal("could not write to gzip file") } runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table " + tableName + " (a int, b string)") + defer dbt.mustExec("drop table " + tableName) fileStream, err := os.Open(fname) if err != nil { t.Error(err) } - defer func() { - defer dbt.mustExec("drop table " + tableName) - if fileStream != nil { - fileStream.Close() - } - }() + defer fileStream.Close() var sqlText string var rows *RowsExtended @@ -489,6 +471,7 @@ func testPutGet(t *testing.T, isStream bool) { if err != nil { t.Error(err) } + defer gz.Close() var contents string for { c := make([]byte, defaultChunkBufferSize) diff --git a/put_get_user_stage_test.go b/put_get_user_stage_test.go index 1f8021f77..f95926072 100644 --- a/put_get_user_stage_test.go +++ b/put_get_user_stage_test.go @@ -14,29 +14,24 @@ func TestPutGetFileSmallDataViaUserStage(t *testing.T) { if os.Getenv("AWS_ACCESS_KEY_ID") == "" { t.Skip("this test requires to change the internal parameter") } - putGetUserStage(t, "", 5, 1, false) + putGetUserStage(t, 5, 1, false) } func TestPutGetStreamSmallDataViaUserStage(t *testing.T) { if os.Getenv("AWS_ACCESS_KEY_ID") == "" { t.Skip("this test requires to change the internal parameter") } - putGetUserStage(t, "", 1, 1, true) + putGetUserStage(t, 1, 1, true) } -func putGetUserStage(t *testing.T, tmpDir string, numberOfFiles int, numberOfLines int, isStream bool) { +func putGetUserStage(t *testing.T, numberOfFiles int, numberOfLines int, isStream bool) { if os.Getenv("AWS_SECRET_ACCESS_KEY") == "" { t.Fatal("no aws secret access key found") } - tmpDir, err := os.MkdirTemp(tmpDir, "data") + tmpDir, err := generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, t.TempDir()) if err != nil { t.Error(err) } - tmpDir, err = generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, tmpDir) - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) var files string if isStream { list, err := os.ReadDir(tmpDir) diff --git a/put_get_with_aws_test.go b/put_get_with_aws_test.go index b5753d295..1bc499d50 100644 --- a/put_get_with_aws_test.go +++ b/put_get_with_aws_test.go @@ -91,11 +91,7 @@ func TestPutWithInvalidToken(t *testing.T) { if !runningOnAWS() { t.Skip("skipping non aws environment") } - tmpDir, err := os.MkdirTemp("", "aws_put") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") originalContents := "123,test1\n456,test2\n" @@ -189,11 +185,7 @@ func TestPretendToPutButList(t *testing.T) { if runningOnGithubAction() && !runningOnAWS() { t.Skip("skipping non aws environment") } - tmpDir, err := os.MkdirTemp("", "aws_put") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") originalContents := "123,test1\n456,test2\n" @@ -244,11 +236,7 @@ func TestPutGetAWSStage(t *testing.T) { t.Skip("skipping non aws environment") } - tmpDir, err := os.MkdirTemp("", "put_get") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() name := "test_put_get.txt.gz" fname := filepath.Join(tmpDir, name) originalContents := "123,test1\n456,test2\n" @@ -258,7 +246,7 @@ func TestPutGetAWSStage(t *testing.T) { gzw := gzip.NewWriter(&b) gzw.Write([]byte(originalContents)) gzw.Close() - if err = os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { + if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { t.Fatal("could not write to gzip file") } From 5218d7a669b03282cc0b304243561b3a4d3e8e70 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Mon, 16 Oct 2023 11:17:44 +0200 Subject: [PATCH 06/12] SNOW-894815 Fix access to params from htap (#934) --- htap.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/htap.go b/htap.go index 93806801f..3d1e7c702 100644 --- a/htap.go +++ b/htap.go @@ -79,7 +79,10 @@ func (qcc *queryContextCache) prune(size int) { } func (qcc *queryContextCache) getQueryContextCacheSize(sc *snowflakeConn) int { - if sizeStr, ok := sc.cfg.Params[queryContextCacheSizeParamName]; ok { + paramsMutex.Lock() + sizeStr, ok := sc.cfg.Params[queryContextCacheSizeParamName] + paramsMutex.Unlock() + if ok { size, err := strconv.Atoi(*sizeStr) if err != nil { logger.Warnf("cannot parse %v as int as query context cache size: %v", sizeStr, err) From bd79bae87d3c3a5b2f66f2fc8c9ef144741862b6 Mon Sep 17 00:00:00 2001 From: etsheks <71740970+etsheks@users.noreply.github.com> Date: Tue, 17 Oct 2023 18:18:14 +1100 Subject: [PATCH 07/12] add log depencies into doc.go (#592) --- doc.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc.go b/doc.go index 93def4075..e2342635c 100644 --- a/doc.go +++ b/doc.go @@ -7,6 +7,8 @@ Clients can use the database/sql package directly. For example: "database/sql" _ "github.com/snowflakedb/gosnowflake" + + "log" ) func main() { From 973638621b84e3160439e2fc80b5910679243270 Mon Sep 17 00:00:00 2001 From: sivchari Date: Tue, 17 Oct 2023 17:39:25 +0900 Subject: [PATCH 08/12] fix: error format (#649) --- auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auth.go b/auth.go index fa2f651c6..e2f426908 100644 --- a/auth.go +++ b/auth.go @@ -235,7 +235,7 @@ func postAuth( var respd authResponse err = json.NewDecoder(resp.Body).Decode(&respd) if err != nil { - logger.Error("failed to decode JSON. err: %v", err) + logger.Errorf("failed to decode JSON. err: %v", err) return nil, err } return &respd, nil From c219d9d0d5bc6223bf689d883ef59cd6ab5ba3f1 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Tue, 17 Oct 2023 12:02:10 +0200 Subject: [PATCH 09/12] SNOW-856228 better assertions in easy logging (#933) * SNOW-856228 better assertions in easy logging --- assert_test.go | 94 ++++++++++++++++++++++++++++++++++++ client_configuration_test.go | 37 ++++---------- 2 files changed, 103 insertions(+), 28 deletions(-) create mode 100644 assert_test.go diff --git a/assert_test.go b/assert_test.go new file mode 100644 index 000000000..0394b4b08 --- /dev/null +++ b/assert_test.go @@ -0,0 +1,94 @@ +// Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "fmt" + "reflect" + "strings" + "testing" +) + +func assertNilF(t *testing.T, actual any, descriptions ...string) { + fatalOnNonEmpty(t, validateNil(actual, descriptions...)) +} + +func assertNotNilF(t *testing.T, actual any, descriptions ...string) { + fatalOnNonEmpty(t, validateNotNil(actual, descriptions...)) +} + +func assertEqualE(t *testing.T, actual any, expected any, descriptions ...string) { + errorOnNonEmpty(t, validateEqual(actual, expected, descriptions...)) +} + +func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) { + errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...)) +} + +func assertHasPrefixE(t *testing.T, actual string, expectedPrefix string, descriptions ...string) { + errorOnNonEmpty(t, validateHasPrefix(actual, expectedPrefix, descriptions...)) +} + +func fatalOnNonEmpty(t *testing.T, errMsg string) { + if errMsg != "" { + t.Fatal(errMsg) + } +} + +func errorOnNonEmpty(t *testing.T, errMsg string) { + if errMsg != "" { + t.Error(errMsg) + } +} + +func validateNil(actual any, descriptions ...string) string { + if isNil(actual) { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%s\" to be nil but was not. %s", actual, desc) +} + +func validateNotNil(actual any, descriptions ...string) string { + if !isNil(actual) { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected to be not nil but was not. %s", desc) +} + +func validateEqual(actual any, expected any, descriptions ...string) string { + if expected == actual { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%s\" to be equal to \"%s\" but was not. %s", actual, expected, desc) +} + +func validateStringContains(actual string, expectedToContain string, descriptions ...string) string { + if strings.Contains(actual, expectedToContain) { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%s\" to contain \"%s\" but did not. %s", actual, expectedToContain, desc) +} + +func validateHasPrefix(actual string, expectedPrefix string, descriptions ...string) string { + if strings.HasPrefix(actual, expectedPrefix) { + return "" + } + desc := joinDescriptions(descriptions...) + return fmt.Sprintf("expected \"%s\" to start with \"%s\" but did not. %s", actual, expectedPrefix, desc) +} + +func joinDescriptions(descriptions ...string) string { + return strings.Join(descriptions, " ") +} + +func isNil(value any) bool { + if value == nil { + return true + } + val := reflect.ValueOf(value) + return val.Kind() == reflect.Pointer && val.IsNil() +} diff --git a/client_configuration_test.go b/client_configuration_test.go index b7ecdd2f0..8da3be3ec 100644 --- a/client_configuration_test.go +++ b/client_configuration_test.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "path" - "strings" "testing" ) @@ -59,15 +58,9 @@ func TestParseConfiguration(t *testing.T) { config, err := parseClientConfiguration(fileName) - if err != nil { - t.Fatalf("Error should be nil but was %s", err) - } - if config.Common.LogLevel != tc.expectedLogLevel { - t.Errorf("Log level should be %s but was %s", tc.expectedLogLevel, config.Common.LogLevel) - } - if config.Common.LogPath != tc.expectedLogPath { - t.Errorf("Log path should be %s but was %s", tc.expectedLogPath, config.Common.LogPath) - } + assertNilF(t, err, "parse client configuration error") + assertEqualE(t, config.Common.LogLevel, tc.expectedLogLevel, "log level") + assertEqualE(t, config.Common.LogPath, tc.expectedLogPath, "log path") }) } } @@ -86,12 +79,8 @@ func TestParseAllLogLevels(t *testing.T) { config, err := parseClientConfiguration(fileName) - if err != nil { - t.Fatalf("Error should be nil but was: %s", err) - } - if config.Common.LogLevel != logLevel { - t.Errorf("Log level should be %s but was %s", logLevel, config.Common.LogLevel) - } + assertNilF(t, err, "parse client config error") + assertEqualE(t, config.Common.LogLevel, logLevel, "log level") }) } } @@ -150,17 +139,11 @@ func TestParseConfigurationFails(t *testing.T) { _, err := parseClientConfiguration(fileName) - if err == nil { - t.Fatal("Error should not be nil but was nil") - } + assertNotNilF(t, err, "parse client configuration error") errMessage := fmt.Sprint(err) expectedPrefix := "parsing client config failed" - if !strings.HasPrefix(errMessage, expectedPrefix) { - t.Errorf("Error message: \"%s\" should start with prefix: \"%s\"", errMessage, expectedPrefix) - } - if !strings.Contains(errMessage, tc.expectedErrorMessageToContain) { - t.Errorf("Error message: \"%s\" should contain given phrase: \"%s\"", errMessage, tc.expectedErrorMessageToContain) - } + assertHasPrefixE(t, errMessage, expectedPrefix, "error message") + assertStringContainsE(t, errMessage, tc.expectedErrorMessageToContain, "error message") }) } } @@ -168,8 +151,6 @@ func TestParseConfigurationFails(t *testing.T) { func createFile(t *testing.T, fileName string, fileContents string, directory string) string { fullFileName := path.Join(directory, fileName) err := os.WriteFile(fullFileName, []byte(fileContents), 0644) - if err != nil { - t.Fatal("Could not create file") - } + assertNilF(t, err, "create file error") return fullFileName } From c6c2afdf81761f360648d4270aeb04ab4007892b Mon Sep 17 00:00:00 2001 From: sivchari Date: Tue, 17 Oct 2023 21:47:54 +0900 Subject: [PATCH 10/12] improve: use context.Background instead of context.TODO (#651) * improve: use context.Background instead of context.TODO * refactor --------- Co-authored-by: Piotr Fus --- auth_test.go | 53 +++++++++++++++++++------------------ authexternalbrowser_test.go | 8 +++--- authokta_test.go | 32 +++++++++++----------- connection_test.go | 2 +- ctx_test.go | 4 +-- driver.go | 2 +- ocsp.go | 4 +-- ocsp_test.go | 8 +++--- restful.go | 2 +- retry_test.go | 16 +++++------ 10 files changed, 66 insertions(+), 65 deletions(-) diff --git a/auth_test.go b/auth_test.go index 43123d760..c3a39a707 100644 --- a/auth_test.go +++ b/auth_test.go @@ -29,27 +29,27 @@ func TestUnitPostAuth(t *testing.T) { bodyCreator := func() ([]byte, error) { return []byte{0x12, 0x34}, nil } - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) + _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err != nil { t.Fatalf("err: %v", err) } sr.FuncAuthPost = postAuthTestError - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) + _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } sr.FuncAuthPost = postAuthTestAppBadGatewayError - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) + _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } sr.FuncAuthPost = postAuthTestAppForbiddenError - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) + _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } sr.FuncAuthPost = postAuthTestAppUnexpectedError - _, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) + _, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0) if err == nil { t.Fatal("should have failed to auth for unknown reason") } @@ -131,7 +131,8 @@ func postAuthCheckOAuth( _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, - _ time.Duration) (*authResponse, error) { + _ time.Duration, +) (*authResponse, error) { var ar authRequest jsonBody, _ := bodyCreator() if err := json.Unmarshal(jsonBody, &ar); err != nil { @@ -408,7 +409,7 @@ func TestUnitAuthenticateWithTokenAccessor(t *testing.T) { sc.rest = sr // FuncPostAuth is set to fail, but AuthTypeTokenAccessor should not even make a call to FuncPostAuth - resp, err := authenticate(context.TODO(), sc, []byte{}, []byte{}) + resp, err := authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("should not have failed, err %v", err) } @@ -449,7 +450,7 @@ func TestUnitAuthenticate(t *testing.T) { } sc.rest = sr - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err == nil { t.Fatal("should have failed.") } @@ -458,7 +459,7 @@ func TestUnitAuthenticate(t *testing.T) { t.Fatalf("Snowflake error is expected. err: %v", driverErr) } sr.FuncPostAuth = postAuthFailWrongAccount - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err == nil { t.Fatal("should have failed.") } @@ -467,7 +468,7 @@ func TestUnitAuthenticate(t *testing.T) { t.Fatalf("Snowflake error is expected. err: %v", driverErr) } sr.FuncPostAuth = postAuthFailUnknown - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err == nil { t.Fatal("should have failed.") } @@ -477,7 +478,7 @@ func TestUnitAuthenticate(t *testing.T) { } ta.SetTokens("bad-token", "bad-master-token", 1) sr.FuncPostAuth = postAuthSuccessWithErrorCode - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err == nil { t.Fatal("should have failed.") } @@ -491,7 +492,7 @@ func TestUnitAuthenticate(t *testing.T) { } ta.SetTokens("bad-token", "bad-master-token", 1) sr.FuncPostAuth = postAuthSuccessWithInvalidErrorCode - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err == nil { t.Fatal("should have failed.") } @@ -501,7 +502,7 @@ func TestUnitAuthenticate(t *testing.T) { } sr.FuncPostAuth = postAuthSuccess var resp *authResponseMain - resp, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + resp, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to auth. err: %v", err) } @@ -533,7 +534,7 @@ func TestUnitAuthenticateSaml(t *testing.T) { Host: "blah.okta.com", } sc.rest = sr - _, err = authenticate(context.TODO(), sc, []byte("HTML data in bytes from"), []byte{}) + _, err = authenticate(context.Background(), sc, []byte("HTML data in bytes from"), []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } @@ -550,7 +551,7 @@ func TestUnitAuthenticateOAuth(t *testing.T) { sc.cfg.Token = "oauthToken" sc.cfg.Authenticator = AuthTypeOAuth sc.rest = sr - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } @@ -566,14 +567,14 @@ func TestUnitAuthenticatePasscode(t *testing.T) { sc.cfg.Passcode = "987654321" sc.rest = sr - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } sr.FuncPostAuth = postAuthCheckPasscodeInPassword sc.rest = sr sc.cfg.PasscodeInPassword = true - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } @@ -594,7 +595,7 @@ func TestUnitAuthenticateJWT(t *testing.T) { sc.rest = sr // A valid JWT token should pass - if _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}); err != nil { + if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err != nil { t.Fatalf("failed to run. err: %v", err) } @@ -604,7 +605,7 @@ func TestUnitAuthenticateJWT(t *testing.T) { t.Error(err) } sc.cfg.PrivateKey = invalidPrivateKey - if _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}); err == nil { + if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err == nil { t.Fatalf("invalid token passed") } } @@ -619,20 +620,20 @@ func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) { sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA sc.cfg.ClientRequestMfaToken = ConfigBoolTrue sc.rest = sr - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaToken sc.cfg.MfaToken = "mockedMfaToken" - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaFailed - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err == nil { t.Fatal("should have failed") } @@ -648,7 +649,7 @@ func TestUnitAuthenticateWithConfigMFA(t *testing.T) { sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA sc.cfg.ClientRequestMfaToken = ConfigBoolTrue sc.rest = sr - sc.ctx = context.TODO() + sc.ctx = context.Background() err = authenticateWithConfig(sc) if err != nil { t.Fatalf("failed to run. err: %v", err) @@ -665,20 +666,20 @@ func TestUnitAuthenticateExternalBrowser(t *testing.T) { sc.cfg.Authenticator = AuthTypeExternalBrowser sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue sc.rest = sr - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } sr.FuncPostAuth = postAuthCheckExternalBrowserToken sc.cfg.IDToken = "mockedIDToken" - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err != nil { t.Fatalf("failed to run. err: %v", err) } sr.FuncPostAuth = postAuthCheckExternalBrowserFailed - _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}) + _, err = authenticate(context.Background(), sc, []byte{}, []byte{}) if err == nil { t.Fatal("should have failed") } diff --git a/authexternalbrowser_test.go b/authexternalbrowser_test.go index b8889a1fc..a6650dc78 100644 --- a/authexternalbrowser_test.go +++ b/authexternalbrowser_test.go @@ -91,17 +91,17 @@ func TestUnitAuthenticateByExternalBrowser(t *testing.T) { FuncPostAuthSAML: postAuthExternalBrowserError, TokenAccessor: getSimpleTokenAccessor(), } - _, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout) + _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthExternalBrowserFail - _, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout) + _, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode - _, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout) + _, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout) if err == nil { t.Fatal("should have failed.") } @@ -128,7 +128,7 @@ func TestAuthenticationTimeout(t *testing.T) { FuncPostAuthSAML: postAuthExternalBrowserError, TokenAccessor: getSimpleTokenAccessor(), } - _, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout) + _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout) if err.Error() != "authentication timed out" { t.Fatal("should have timed out") } diff --git a/authokta_test.go b/authokta_test.go index dd1cf7af2..a9a4b2772 100644 --- a/authokta_test.go +++ b/authokta_test.go @@ -64,17 +64,17 @@ func TestUnitPostAuthSAML(t *testing.T) { TokenAccessor: getSimpleTokenAccessor(), } var err error - _, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{}, 0) + _, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{}, 0) if err == nil { t.Fatal("should have failed.") } sr.FuncPost = postTestAppBadGatewayError - _, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{}, 0) + _, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{}, 0) if err == nil { t.Fatal("should have failed.") } sr.FuncPost = postTestSuccessButInvalidJSON - _, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{0x12, 0x34}, 0) + _, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{0x12, 0x34}, 0) if err == nil { t.Fatalf("should have failed to post") } @@ -86,17 +86,17 @@ func TestUnitPostAuthOKTA(t *testing.T) { TokenAccessor: getSimpleTokenAccessor(), } var err error - _, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{}, "hahah", 0) + _, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{}, "hahah", 0) if err == nil { t.Fatal("should have failed.") } sr.FuncPost = postTestAppBadGatewayError - _, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{}, "hahah", 0) + _, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{}, "hahah", 0) if err == nil { t.Fatal("should have failed.") } sr.FuncPost = postTestSuccessButInvalidJSON - _, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{0x12, 0x34}, "haha", 0) + _, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{0x12, 0x34}, "haha", 0) if err == nil { t.Fatal("should have failed to run post request after the renewal") } @@ -108,17 +108,17 @@ func TestUnitGetSSO(t *testing.T) { TokenAccessor: getSimpleTokenAccessor(), } var err error - _, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0) + _, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0) if err == nil { t.Fatal("should have failed.") } sr.FuncGet = getTestAppBadGatewayError - _, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0) + _, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0) if err == nil { t.Fatal("should have failed.") } sr.FuncGet = getTestHTMLSuccess - _, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0) + _, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0) if err != nil { t.Fatalf("failed to get HTML content. err: %v", err) } @@ -194,17 +194,17 @@ func TestUnitAuthenticateBySAML(t *testing.T) { TokenAccessor: getSimpleTokenAccessor(), } var err error - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthSAMLAuthFail - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } @@ -217,23 +217,23 @@ func TestUnitAuthenticateBySAML(t *testing.T) { } sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess sr.FuncPostAuthOKTA = postAuthOKTAError - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthOKTA = postAuthOKTASuccess sr.FuncGetSSO = getSSOError - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } sr.FuncGetSSO = getSSOSuccessButInvalidURL - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err == nil { t.Fatal("should have failed.") } sr.FuncGetSSO = getSSOSuccess - _, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password) + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) if err != nil { t.Fatalf("failed. err: %v", err) } diff --git a/connection_test.go b/connection_test.go index 0a23684d2..a76cdb7c1 100644 --- a/connection_test.go +++ b/connection_test.go @@ -191,7 +191,7 @@ func TestServiceName(t *testing.T) { expectServiceName := serviceNameStub for i := 0; i < 5; i++ { - sc.exec(context.TODO(), "", false, /* noResult */ + sc.exec(context.Background(), "", false, /* noResult */ false /* isInternal */, false /* describeOnly */, nil) if actualServiceName, ok := sc.cfg.Params[serviceName]; ok { if *actualServiceName != expectServiceName { diff --git a/ctx_test.go b/ctx_test.go index f9bd5adeb..9422c6654 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -40,8 +40,8 @@ func TestCtxVal(t *testing.T) { func TestLogEntryCtx(t *testing.T) { var log = logger - var ctx1 = context.WithValue(context.TODO(), SFSessionIDKey, "sessID1") - var ctx2 = context.WithValue(context.TODO(), SFSessionUserKey, "admin") + var ctx1 = context.WithValue(context.Background(), SFSessionIDKey, "sessID1") + var ctx2 = context.WithValue(context.Background(), SFSessionUserKey, "admin") fs1 := context2Fields(ctx1) fs2 := context2Fields(ctx2) diff --git a/driver.go b/driver.go index af29d66e1..6a565be4e 100644 --- a/driver.go +++ b/driver.go @@ -18,7 +18,7 @@ type SnowflakeDriver struct{} // Open creates a new connection. func (d SnowflakeDriver) Open(dsn string) (driver.Conn, error) { logger.Info("Open") - ctx := context.TODO() + ctx := context.Background() cfg, err := ParseDSN(dsn) if err != nil { return nil, err diff --git a/ocsp.go b/ocsp.go index 6feb71692..297a999db 100644 --- a/ocsp.go +++ b/ocsp.go @@ -751,7 +751,7 @@ func downloadOCSPCacheServer() { Timeout: timeout, Transport: snowflakeInsecureTransport, } - ret, ocspStatus := checkOCSPCacheServer(context.TODO(), ocspClient, http.NewRequest, u, timeout) + ret, ocspStatus := checkOCSPCacheServer(context.Background(), ocspClient, http.NewRequest, u, timeout) if ocspStatus.code != ocspSuccess { return } @@ -788,7 +788,7 @@ func getAllRevocationStatus(ctx context.Context, verifiedChains []*x509.Certific // verifyPeerCertificateSerial verifies the certificate revocation status in serial. func verifyPeerCertificateSerial(_ [][]byte, verifiedChains [][]*x509.Certificate) (err error) { overrideCacheDir() - return verifyPeerCertificate(context.TODO(), verifiedChains) + return verifyPeerCertificate(context.Background(), verifiedChains) } func overrideCacheDir() { diff --git a/ocsp_test.go b/ocsp_test.go index e2deb515f..c89a7a6ee 100644 --- a/ocsp_test.go +++ b/ocsp_test.go @@ -318,7 +318,7 @@ func TestOCSPRetry(t *testing.T) { body: []byte{1, 2, 3}, } res, b, st := retryOCSP( - context.TODO(), + context.Background(), client, emptyRequest, dummyOCSPHost, make(map[string]string), []byte{0}, certs[len(certs)-1], 10*time.Second) @@ -331,7 +331,7 @@ func TestOCSPRetry(t *testing.T) { body: []byte{1, 2, 3}, } res, b, st = retryOCSP( - context.TODO(), + context.Background(), client, fakeRequestFunc, dummyOCSPHost, make(map[string]string), []byte{0}, certs[len(certs)-1], 5*time.Second) @@ -389,7 +389,7 @@ func TestOCSPCacheServerRetry(t *testing.T) { body: []byte{1, 2, 3}, } res, st := checkOCSPCacheServer( - context.TODO(), client, fakeRequestFunc, dummyOCSPHost, 20*time.Second) + context.Background(), client, fakeRequestFunc, dummyOCSPHost, 20*time.Second) if st.err == nil { t.Errorf("should fail: %v", res) } @@ -399,7 +399,7 @@ func TestOCSPCacheServerRetry(t *testing.T) { body: []byte{1, 2, 3}, } res, st = checkOCSPCacheServer( - context.TODO(), client, fakeRequestFunc, dummyOCSPHost, 10*time.Second) + context.Background(), client, fakeRequestFunc, dummyOCSPHost, 10*time.Second) if st.err == nil { t.Errorf("should fail: %v", res) } diff --git a/restful.go b/restful.go index f6948c4c3..6b10dd4b3 100644 --- a/restful.go +++ b/restful.go @@ -217,7 +217,7 @@ func postRestfulQuery( return data, err } - if err = sr.FuncCancelQuery(context.TODO(), sr, requestID, timeout); err != nil { + if err = sr.FuncCancelQuery(context.Background(), sr, requestID, timeout); err != nil { return nil, err } return nil, ctx.Err() diff --git a/retry_test.go b/retry_test.go index ea9ee18b9..397da9edd 100644 --- a/retry_test.go +++ b/retry_test.go @@ -225,7 +225,7 @@ func TestRetryQuerySuccess(t *testing.T) { if err != nil { t.Fatal("failed to parse the test URL") } - _, err = newRetryHTTP(context.TODO(), + _, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolTrue}).doPost().setBody([]byte{0}).execute() if err != nil { @@ -274,7 +274,7 @@ func TestRetryQuerySuccessWithRetryReasonDisabled(t *testing.T) { if err != nil { t.Fatal("failed to parse the test URL") } - _, err = newRetryHTTP(context.TODO(), + _, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolFalse}).doPost().setBody([]byte{0}).execute() if err != nil { @@ -320,7 +320,7 @@ func TestRetryQuerySuccessWithTimeout(t *testing.T) { if err != nil { t.Fatal("failed to parse the test URL") } - _, err = newRetryHTTP(context.TODO(), + _, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), nil).doPost().setBody([]byte{0}).execute() if err != nil { @@ -350,7 +350,7 @@ func TestRetryQueryFail(t *testing.T) { if err != nil { t.Fatal("failed to parse the test URL") } - _, err = newRetryHTTP(context.TODO(), + _, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() if err == nil { @@ -396,7 +396,7 @@ func TestRetryLoginRequest(t *testing.T) { if err != nil { t.Fatal("failed to parse the test URL") } - _, err = newRetryHTTP(context.TODO(), + _, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() if err != nil { @@ -416,7 +416,7 @@ func TestRetryLoginRequest(t *testing.T) { success: false, timeout: true, } - _, err = newRetryHTTP(context.TODO(), + _, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 10*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() if err == nil { @@ -447,7 +447,7 @@ func TestRetryAuthLoginRequest(t *testing.T) { execID++ return []byte(fmt.Sprintf("execID: %d", execID)), nil } - _, err = newRetryHTTP(context.TODO(), + _, err = newRetryHTTP(context.Background(), client, http.NewRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBodyCreator(bodyCreator).execute() if err != nil { @@ -468,7 +468,7 @@ func TestLoginRetry429(t *testing.T) { if err != nil { t.Fatal("failed to parse the test URL") } - _, err = newRetryHTTP(context.TODO(), + _, err = newRetryHTTP(context.Background(), client, emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doRaise4XX(true).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX if err != nil { From 0e5dfdd07df1e2d326127df0bc43cd4c0267ade2 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Thu, 19 Oct 2023 07:20:20 +0200 Subject: [PATCH 11/12] SNOW-894815 Fix TestConcurrentReadOnParams cleaning up the connection too early (#926) --- .github/workflows/build-test.yml | 1 + connection_test.go | 11 ++++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 291b1ee6a..b61a03c18 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -48,6 +48,7 @@ jobs: env: PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} CLOUD_PROVIDER: ${{ matrix.cloud }} + GORACE: history_size=7 run: ./ci/test.sh - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/connection_test.go b/connection_test.go index a76cdb7c1..c5d508d32 100644 --- a/connection_test.go +++ b/connection_test.go @@ -476,19 +476,19 @@ func TestExecWithServerSideError(t *testing.T) { } func TestConcurrentReadOnParams(t *testing.T) { - t.Skip("Fails randomly") config, err := ParseDSN(dsn) if err != nil { t.Fatal("Failed to parse dsn") } connector := NewConnector(SnowflakeDriver{}, *config) db := sql.OpenDB(connector) + defer db.Close() wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { for c := 0; c < 10; c++ { - stmt, err := db.PrepareContext(context.Background(), "SELECT * FROM information_schema.columns WHERE table_schema = ?") + stmt, err := db.PrepareContext(context.Background(), "SELECT table_schema FROM information_schema.columns WHERE table_schema = ? LIMIT 1") if err != nil { t.Error(err) } @@ -499,13 +499,18 @@ func TestConcurrentReadOnParams(t *testing.T) { if rows == nil { continue } + rows.Next() + var tableName string + err = rows.Scan(&tableName) + if err != nil { + t.Error(err) + } _ = rows.Close() } wg.Done() }() } wg.Wait() - defer db.Close() } func postQueryTest(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) { From debf383feae2adad08dd8030dbb74eeb7961d4f3 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Fri, 20 Oct 2023 09:47:08 +0200 Subject: [PATCH 12/12] SNOW-856228 easy logging finder (#939) SNOW-856228 easy logging finder --- assert_test.go | 4 ++ client_configuration.go | 63 +++++++++++++++++ client_configuration_test.go | 130 +++++++++++++++++++++++++++++++++++ 3 files changed, 197 insertions(+) diff --git a/assert_test.go b/assert_test.go index 0394b4b08..58d4cc458 100644 --- a/assert_test.go +++ b/assert_test.go @@ -21,6 +21,10 @@ func assertEqualE(t *testing.T, actual any, expected any, descriptions ...string errorOnNonEmpty(t, validateEqual(actual, expected, descriptions...)) } +func assertEqualF(t *testing.T, actual any, expected any, descriptions ...string) { + fatalOnNonEmpty(t, validateEqual(actual, expected, descriptions...)) +} + func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) { errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...)) } diff --git a/client_configuration.go b/client_configuration.go index 381ed8d86..52d55de73 100644 --- a/client_configuration.go +++ b/client_configuration.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "os" + "path" "strings" ) @@ -20,6 +21,68 @@ const ( levelTrace string = "TRACE" // trace log level ) +const ( + defaultConfigName = "sf_client_config.json" + clientConfEnvName = "SF_CLIENT_CONFIG_FILE" +) + +func getClientConfig(filePathFromConnectionString string) (*ClientConfig, error) { + configPredefinedFilePaths := clientConfigPredefinedDirs() + filePath, err := findClientConfigFilePath(filePathFromConnectionString, configPredefinedFilePaths) + if err != nil { + return nil, err + } + if filePath == "" { // we did not find a config file + return nil, nil + } + return parseClientConfiguration(filePath) +} + +func findClientConfigFilePath(filePathFromConnectionString string, configPredefinedDirs []string) (string, error) { + if filePathFromConnectionString != "" { + return filePathFromConnectionString, nil + } + envConfigFilePath := os.Getenv(clientConfEnvName) + if envConfigFilePath != "" { + return envConfigFilePath, nil + } + return searchForConfigFile(configPredefinedDirs) +} + +func searchForConfigFile(directories []string) (string, error) { + for _, dir := range directories { + filePath := path.Join(dir, defaultConfigName) + exists, err := existsFile(filePath) + if err != nil { + return "", err + } + if exists { + return filePath, nil + } + } + return "", nil +} + +func existsFile(filePath string) (bool, error) { + _, err := os.Stat(filePath) + if err == nil { + return true, nil + } + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + return false, err +} + +func clientConfigPredefinedDirs() []string { + homeDir, err := os.UserHomeDir() + if err != nil { + logger.Warnf("Home dir could not be determined: %w", err) + return []string{".", os.TempDir()} + } + return []string{".", homeDir, os.TempDir()} +} + // ClientConfig config root type ClientConfig struct { Common *ClientConfigCommonProps `json:"common"` diff --git a/client_configuration_test.go b/client_configuration_test.go index 8da3be3ec..a63eb6753 100644 --- a/client_configuration_test.go +++ b/client_configuration_test.go @@ -9,6 +9,105 @@ import ( "testing" ) +func TestFindConfigFileFromConnectionParameters(t *testing.T) { + dirs := createTestDirectories(t) + connParameterConfigPath := createFile(t, "conn_parameters_config.json", "random content", dirs.dir) + envConfigPath := createFile(t, "env_var_config.json", "random content", dirs.dir) + t.Setenv(clientConfEnvName, envConfigPath) + createFile(t, defaultConfigName, "random content", dirs.predefinedDir1) + createFile(t, defaultConfigName, "random content", dirs.predefinedDir2) + + clientConfigFilePath, err := findClientConfigFilePath(connParameterConfigPath, predefinedTestDirs(dirs)) + + assertNilF(t, err, "get client config error") + assertEqualE(t, clientConfigFilePath, connParameterConfigPath, "config file path") +} + +func TestFindConfigFileFromEnvVariable(t *testing.T) { + dirs := createTestDirectories(t) + envConfigPath := createFile(t, "env_var_config.json", "random content", dirs.dir) + t.Setenv(clientConfEnvName, envConfigPath) + createFile(t, defaultConfigName, "random content", dirs.predefinedDir1) + createFile(t, defaultConfigName, "random content", dirs.predefinedDir2) + + clientConfigFilePath, err := findClientConfigFilePath("", predefinedTestDirs(dirs)) + + assertNilF(t, err, "get client config error") + assertEqualE(t, clientConfigFilePath, envConfigPath, "config file path") +} + +func TestFindConfigFileFromFirstPredefinedDir(t *testing.T) { + dirs := createTestDirectories(t) + configPath := createFile(t, defaultConfigName, "random content", dirs.predefinedDir1) + createFile(t, defaultConfigName, "random content", dirs.predefinedDir2) + + clientConfigFilePath, err := findClientConfigFilePath("", predefinedTestDirs(dirs)) + + assertNilF(t, err, "get client config error") + assertEqualE(t, clientConfigFilePath, configPath, "config file path") +} + +func TestFindConfigFileFromSubsequentDirectoryIfNotFoundInPreviousOne(t *testing.T) { + dirs := createTestDirectories(t) + createFile(t, "wrong_file_name.json", "random content", dirs.predefinedDir1) + configPath := createFile(t, defaultConfigName, "random content", dirs.predefinedDir2) + + clientConfigFilePath, err := findClientConfigFilePath("", predefinedTestDirs(dirs)) + + assertNilF(t, err, "get client config error") + assertEqualE(t, clientConfigFilePath, configPath, "config file path") +} + +func TestNotFindConfigFileWhenNotDefined(t *testing.T) { + dirs := createTestDirectories(t) + createFile(t, "wrong_file_name.json", "random content", dirs.predefinedDir1) + createFile(t, "wrong_file_name.json", "random content", dirs.predefinedDir2) + + clientConfigFilePath, err := findClientConfigFilePath("", predefinedTestDirs(dirs)) + + assertNilF(t, err, "get client config error") + assertEqualE(t, clientConfigFilePath, "", "config file path") +} + +func TestCreatePredefinedDirs(t *testing.T) { + homeDir, err := os.UserHomeDir() + assertNilF(t, err, "get home dir error") + + locations := clientConfigPredefinedDirs() + + assertEqualF(t, len(locations), 3, "size") + assertEqualE(t, locations[0], ".", "driver directory") + assertEqualE(t, locations[1], homeDir, "home directory") + assertEqualE(t, locations[2], os.TempDir(), "temp directory") +} + +func TestGetClientConfig(t *testing.T) { + dir := t.TempDir() + fileName := "config.json" + configContents := `{ + "common": { + "log_level" : "INFO", + "log_path" : "/some-path/some-directory" + } + }` + createFile(t, fileName, configContents, dir) + filePath := path.Join(dir, fileName) + + clientConfigFilePath, err := getClientConfig(filePath) + + assertNilF(t, err) + assertNotNilF(t, clientConfigFilePath) + assertEqualE(t, clientConfigFilePath.Common.LogLevel, "INFO", "log level") + assertEqualE(t, clientConfigFilePath.Common.LogPath, "/some-path/some-directory", "log path") +} + +func TestNoResultForGetClientConfigWhenNoFileFound(t *testing.T) { + clientConfigFilePath, err := getClientConfig("") + + assertNilF(t, err) + assertNilF(t, clientConfigFilePath) +} + func TestParseConfiguration(t *testing.T) { dir := t.TempDir() testCases := []struct { @@ -154,3 +253,34 @@ func createFile(t *testing.T, fileName string, fileContents string, directory st assertNilF(t, err, "create file error") return fullFileName } + +func createTestDirectories(t *testing.T) struct { + dir string + predefinedDir1 string + predefinedDir2 string +} { + dir := t.TempDir() + predefinedDir1 := path.Join(dir, "dir1") + err := os.Mkdir(predefinedDir1, 0755) + assertNilF(t, err, "predefined dir1 error") + predefinedDir2 := path.Join(dir, "dir2") + err = os.Mkdir(predefinedDir2, 0755) + assertNilF(t, err, "predefined dir2 error") + return struct { + dir string + predefinedDir1 string + predefinedDir2 string + }{ + dir: dir, + predefinedDir1: predefinedDir1, + predefinedDir2: predefinedDir2, + } +} + +func predefinedTestDirs(dirs struct { + dir string + predefinedDir1 string + predefinedDir2 string +}) []string { + return []string{dirs.predefinedDir1, dirs.predefinedDir2} +}